1
0
Fork 0

feat: add UNet stride as its own parameter

This commit is contained in:
Sean Sube 2023-05-03 19:15:05 -05:00
parent 98386cb385
commit 3b02fc5768
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
13 changed files with 69 additions and 21 deletions

View File

@ -125,14 +125,17 @@ def load_pipeline(
logger.debug("reusing existing diffusion pipeline")
pipe = cache_pipe
cache_pipe.vae_encoder.set_tiled(tiled=params.tiled_vae)
cache_pipe.vae_decoder.set_tiled(tiled=params.tiled_vae)
# update panorama params
if pipeline == "panorama":
latent_window = params.tiles // 8
latent_stride = params.stride() // 8
latent_stride = params.stride // 8
cache_pipe.set_window_size(latent_window, latent_stride)
cache_pipe.vae_encoder.set_window_size(latent_window, latent_stride)
cache_pipe.vae_decoder.set_window_size(latent_window, latent_stride)
cache_pipe.vae_encoder.set_window_size(latent_window, params.overlap)
cache_pipe.vae_decoder.set_window_size(latent_window, params.overlap)
# update scheduler
cache_scheduler = server.cache.get("scheduler", scheduler_key)
@ -332,7 +335,7 @@ def load_pipeline(
# additional options for panorama pipeline
if pipeline == "panorama":
components["window"] = params.tiles // 8
components["stride"] = params.stride() // 8
components["stride"] = params.stride // 8
pipeline_class = available_pipelines.get(pipeline, OnnxStableDiffusionPipeline)
logger.debug("loading pretrained SD pipeline for %s", pipeline_class.__name__)
@ -433,16 +436,16 @@ def patch_pipeline(
server,
original_decoder,
decoder=True,
tiles=params.tiles,
stride=params.stride(),
window=params.tiles,
overlap=params.overlap,
)
original_encoder = pipe.vae_encoder
pipe.vae_encoder = VAEWrapper(
server,
original_encoder,
decoder=False,
tiles=params.tiles,
stride=params.stride(),
window=params.tiles,
overlap=params.overlap,
)
elif hasattr(pipe, "vae"):
pass # TODO: current wrapper does not work with upscaling VAE

View File

@ -28,21 +28,22 @@ class VAEWrapper(object):
server: ServerContext,
wrapped: OnnxRuntimeModel,
decoder: bool,
tiles: int,
stride: int,
window: int,
overlap: float,
):
self.server = server
self.wrapped = wrapped
self.decoder = decoder
self.set_window_size(tiles, stride)
self.tiled = False
self.set_window_size(window, overlap)
def set_window_size(self, window: int, stride: int):
self.window = window
self.stride = stride
def set_tiled(self, tiled: bool = True):
self.tiled = tiled
self.tile_latent_min_size = self.window
self.tile_sample_min_size = self.window * 8
self.tile_overlap_factor = self.stride / self.window
def set_window_size(self, window: int, overlap: float):
self.tile_latent_min_size = window
self.tile_sample_min_size = window * 8
self.tile_overlap_factor = overlap
def __call__(self, latent_sample=None, sample=None, **kwargs):
global timestep_dtype
@ -62,7 +63,7 @@ class VAEWrapper(object):
logger.debug("converting VAE sample dtype")
sample = sample.astype(timestep_dtype)
if self.window is not None and self.stride is not None:
if self.tiled:
if self.decoder:
return self.tiled_decode(latent_sample, **kwargs)
else:

View File

@ -369,6 +369,7 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
)
return views
@torch.no_grad()
def text2img(
self,
prompt: Union[str, List[str]] = None,
@ -619,6 +620,7 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
images=image, nsfw_content_detected=has_nsfw_concept
)
@torch.no_grad()
def img2img(
self,
prompt: Union[str, List[str]] = None,

View File

@ -214,6 +214,7 @@ class ImageParams:
tiled_vae: bool = False,
tiles: int = 512,
overlap: float = 0.25,
stride: int = 64,
) -> None:
self.model = model
self.pipeline = pipeline
@ -232,6 +233,7 @@ class ImageParams:
self.tiled_vae = tiled_vae
self.tiles = tiles
self.overlap = overlap
self.stride = stride
def do_cfg(self):
return self.cfg > 1.0
@ -260,9 +262,6 @@ class ImageParams:
def lpw(self):
return self.pipeline == "lpw"
def stride(self):
return int(self.tiles * self.overlap)
def tojson(self) -> Dict[str, Optional[Param]]:
return {
"model": self.model,
@ -282,6 +281,7 @@ class ImageParams:
"tiled_vae": self.tiled_vae,
"tiles": self.tiles,
"overlap": self.overlap,
"stride": self.stride,
}
def with_args(self, **kwargs):
@ -303,6 +303,7 @@ class ImageParams:
kwargs.get("tiled_vae", self.tiled_vae),
kwargs.get("tiles", self.tiles),
kwargs.get("overlap", self.overlap),
kwargs.get("stride", self.stride),
)

View File

@ -140,6 +140,13 @@ def pipeline_from_request(
get_config_value("overlap", "max"),
get_config_value("overlap", "min"),
)
# Stride is an integer pixel count: ImageParams declares `stride: int` and the
# pipeline floor-divides it by 8 to get the latent stride. Coerce the clamped
# request value to int so a float does not propagate into that arithmetic.
# NOTE(review): a dedicated integer clamp helper may exist elsewhere in the
# project; if so, prefer it over wrapping the float helper.
stride = int(
    get_and_clamp_float(
        request.args,
        "stride",
        get_config_value("stride"),
        get_config_value("stride", "max"),
        get_config_value("stride", "min"),
    )
)
seed = int(request.args.get("seed", -1))
if seed == -1:
@ -177,6 +184,7 @@ def pipeline_from_request(
tiled_vae=tiled_vae,
tiles=tiles,
overlap=overlap,
stride=stride,
)
size = Size(width, height)
return (device, params, size)

View File

@ -186,6 +186,12 @@
"max": 1,
"step": 0.01
},
"stride": {
"default": 128,
"min": 64,
"max": 512,
"step": 64
},
"tiledVAE": {
"default": false
},

View File

@ -51,6 +51,7 @@ export interface BaseImgParams {
tiledVAE: boolean;
tiles: number;
overlap: number;
stride: number;
cfg: number;
steps: number;
@ -409,6 +410,7 @@ export function makeImageURL(root: string, type: string, params: BaseImgParams):
url.searchParams.append('tiledVAE', String(params.tiledVAE));
url.searchParams.append('tiles', params.tiles.toFixed(FIXED_INTEGER));
url.searchParams.append('overlap', params.overlap.toFixed(FIXED_FLOAT));
// stride is a whole number of pixels (the config steps it by 64), so format it
// with the same integer precision used for tiles rather than as a float
url.searchParams.append('stride', params.stride.toFixed(FIXED_INTEGER));
if (doesExist(params.scheduler)) {
url.searchParams.append('scheduler', params.scheduler);

View File

@ -178,6 +178,21 @@ export function ImageControl(props: ImageControlProps) {
}
}}
/>
<NumericField
label={t('parameter.stride')}
min={params.stride.min}
max={params.stride.max}
step={params.stride.step}
value={controlState.stride}
onChange={(stride) => {
if (doesExist(props.onChange)) {
props.onChange({
...controlState,
stride,
});
}
}}
/>
<FormControlLabel
label={t('parameter.tiledVAE')}
control={<Checkbox

View File

@ -184,6 +184,12 @@
"max": 1,
"step": 0.01
},
"stride": {
"default": 128,
"min": 64,
"max": 512,
"step": 64
},
"tiledVAE": {
"default": false
},

View File

@ -130,6 +130,7 @@ export const I18N_STRINGS_DE = {
sourceFilter: '',
steps: 'Schritte',
strength: 'Stärke',
stride: '',
tiledVAE: '',
tiles: '',
tileOrder: '',

View File

@ -182,6 +182,7 @@ export const I18N_STRINGS_EN = {
sourceFilter: 'Source Filter',
steps: 'Steps',
strength: 'Strength',
stride: 'UNet Stride',
tiledVAE: 'Tiled VAE',
tiles: 'Tile Size',
tileOrder: 'Tile Order',

View File

@ -130,6 +130,7 @@ export const I18N_STRINGS_ES = {
sourceFilter: '',
steps: 'Pasos',
strength: 'Fuerza',
stride: '',
tiledVAE: '',
tiles: '',
tileOrder: 'Orden de secciones',

View File

@ -130,6 +130,7 @@ export const I18N_STRINGS_FR = {
sourceFilter: '',
steps: '',
strength: '',
stride: '',
tiledVAE: '',
tiles: '',
tileOrder: '',