feat: add UNet stride as its own parameter
This commit is contained in:
parent
98386cb385
commit
3b02fc5768
|
@ -125,14 +125,17 @@ def load_pipeline(
|
||||||
logger.debug("reusing existing diffusion pipeline")
|
logger.debug("reusing existing diffusion pipeline")
|
||||||
pipe = cache_pipe
|
pipe = cache_pipe
|
||||||
|
|
||||||
|
cache_pipe.vae_encoder.set_tiled(tiled=params.tiled_vae)
|
||||||
|
cache_pipe.vae_decoder.set_tiled(tiled=params.tiled_vae)
|
||||||
|
|
||||||
# update panorama params
|
# update panorama params
|
||||||
if pipeline == "panorama":
|
if pipeline == "panorama":
|
||||||
latent_window = params.tiles // 8
|
latent_window = params.tiles // 8
|
||||||
latent_stride = params.stride() // 8
|
latent_stride = params.stride // 8
|
||||||
|
|
||||||
cache_pipe.set_window_size(latent_window, latent_stride)
|
cache_pipe.set_window_size(latent_window, latent_stride)
|
||||||
cache_pipe.vae_encoder.set_window_size(latent_window, latent_stride)
|
cache_pipe.vae_encoder.set_window_size(latent_window, params.overlap)
|
||||||
cache_pipe.vae_decoder.set_window_size(latent_window, latent_stride)
|
cache_pipe.vae_decoder.set_window_size(latent_window, params.overlap)
|
||||||
|
|
||||||
# update scheduler
|
# update scheduler
|
||||||
cache_scheduler = server.cache.get("scheduler", scheduler_key)
|
cache_scheduler = server.cache.get("scheduler", scheduler_key)
|
||||||
|
@ -332,7 +335,7 @@ def load_pipeline(
|
||||||
# additional options for panorama pipeline
|
# additional options for panorama pipeline
|
||||||
if pipeline == "panorama":
|
if pipeline == "panorama":
|
||||||
components["window"] = params.tiles // 8
|
components["window"] = params.tiles // 8
|
||||||
components["stride"] = params.stride() // 8
|
components["stride"] = params.stride // 8
|
||||||
|
|
||||||
pipeline_class = available_pipelines.get(pipeline, OnnxStableDiffusionPipeline)
|
pipeline_class = available_pipelines.get(pipeline, OnnxStableDiffusionPipeline)
|
||||||
logger.debug("loading pretrained SD pipeline for %s", pipeline_class.__name__)
|
logger.debug("loading pretrained SD pipeline for %s", pipeline_class.__name__)
|
||||||
|
@ -433,16 +436,16 @@ def patch_pipeline(
|
||||||
server,
|
server,
|
||||||
original_decoder,
|
original_decoder,
|
||||||
decoder=True,
|
decoder=True,
|
||||||
tiles=params.tiles,
|
window=params.tiles,
|
||||||
stride=params.stride(),
|
overlap=params.overlap,
|
||||||
)
|
)
|
||||||
original_encoder = pipe.vae_encoder
|
original_encoder = pipe.vae_encoder
|
||||||
pipe.vae_encoder = VAEWrapper(
|
pipe.vae_encoder = VAEWrapper(
|
||||||
server,
|
server,
|
||||||
original_encoder,
|
original_encoder,
|
||||||
decoder=False,
|
decoder=False,
|
||||||
tiles=params.tiles,
|
window=params.tiles,
|
||||||
stride=params.stride(),
|
overlap=params.overlap,
|
||||||
)
|
)
|
||||||
elif hasattr(pipe, "vae"):
|
elif hasattr(pipe, "vae"):
|
||||||
pass # TODO: current wrapper does not work with upscaling VAE
|
pass # TODO: current wrapper does not work with upscaling VAE
|
||||||
|
|
|
@ -28,21 +28,22 @@ class VAEWrapper(object):
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
wrapped: OnnxRuntimeModel,
|
wrapped: OnnxRuntimeModel,
|
||||||
decoder: bool,
|
decoder: bool,
|
||||||
tiles: int,
|
window: int,
|
||||||
stride: int,
|
overlap: float,
|
||||||
):
|
):
|
||||||
self.server = server
|
self.server = server
|
||||||
self.wrapped = wrapped
|
self.wrapped = wrapped
|
||||||
self.decoder = decoder
|
self.decoder = decoder
|
||||||
self.set_window_size(tiles, stride)
|
self.tiled = False
|
||||||
|
self.set_window_size(window, overlap)
|
||||||
|
|
||||||
def set_window_size(self, window: int, stride: int):
|
def set_tiled(self, tiled: bool = True):
|
||||||
self.window = window
|
self.tiled = tiled
|
||||||
self.stride = stride
|
|
||||||
|
|
||||||
self.tile_latent_min_size = self.window
|
def set_window_size(self, window: int, overlap: float):
|
||||||
self.tile_sample_min_size = self.window * 8
|
self.tile_latent_min_size = window
|
||||||
self.tile_overlap_factor = self.stride / self.window
|
self.tile_sample_min_size = window * 8
|
||||||
|
self.tile_overlap_factor = overlap
|
||||||
|
|
||||||
def __call__(self, latent_sample=None, sample=None, **kwargs):
|
def __call__(self, latent_sample=None, sample=None, **kwargs):
|
||||||
global timestep_dtype
|
global timestep_dtype
|
||||||
|
@ -62,7 +63,7 @@ class VAEWrapper(object):
|
||||||
logger.debug("converting VAE sample dtype")
|
logger.debug("converting VAE sample dtype")
|
||||||
sample = sample.astype(timestep_dtype)
|
sample = sample.astype(timestep_dtype)
|
||||||
|
|
||||||
if self.window is not None and self.stride is not None:
|
if self.tiled:
|
||||||
if self.decoder:
|
if self.decoder:
|
||||||
return self.tiled_decode(latent_sample, **kwargs)
|
return self.tiled_decode(latent_sample, **kwargs)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -369,6 +369,7 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
||||||
)
|
)
|
||||||
return views
|
return views
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def text2img(
|
def text2img(
|
||||||
self,
|
self,
|
||||||
prompt: Union[str, List[str]] = None,
|
prompt: Union[str, List[str]] = None,
|
||||||
|
@ -619,6 +620,7 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
||||||
images=image, nsfw_content_detected=has_nsfw_concept
|
images=image, nsfw_content_detected=has_nsfw_concept
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def img2img(
|
def img2img(
|
||||||
self,
|
self,
|
||||||
prompt: Union[str, List[str]] = None,
|
prompt: Union[str, List[str]] = None,
|
||||||
|
|
|
@ -214,6 +214,7 @@ class ImageParams:
|
||||||
tiled_vae: bool = False,
|
tiled_vae: bool = False,
|
||||||
tiles: int = 512,
|
tiles: int = 512,
|
||||||
overlap: float = 0.25,
|
overlap: float = 0.25,
|
||||||
|
stride: int = 64,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.model = model
|
self.model = model
|
||||||
self.pipeline = pipeline
|
self.pipeline = pipeline
|
||||||
|
@ -232,6 +233,7 @@ class ImageParams:
|
||||||
self.tiled_vae = tiled_vae
|
self.tiled_vae = tiled_vae
|
||||||
self.tiles = tiles
|
self.tiles = tiles
|
||||||
self.overlap = overlap
|
self.overlap = overlap
|
||||||
|
self.stride = stride
|
||||||
|
|
||||||
def do_cfg(self):
|
def do_cfg(self):
|
||||||
return self.cfg > 1.0
|
return self.cfg > 1.0
|
||||||
|
@ -260,9 +262,6 @@ class ImageParams:
|
||||||
def lpw(self):
|
def lpw(self):
|
||||||
return self.pipeline == "lpw"
|
return self.pipeline == "lpw"
|
||||||
|
|
||||||
def stride(self):
|
|
||||||
return int(self.tiles * self.overlap)
|
|
||||||
|
|
||||||
def tojson(self) -> Dict[str, Optional[Param]]:
|
def tojson(self) -> Dict[str, Optional[Param]]:
|
||||||
return {
|
return {
|
||||||
"model": self.model,
|
"model": self.model,
|
||||||
|
@ -282,6 +281,7 @@ class ImageParams:
|
||||||
"tiled_vae": self.tiled_vae,
|
"tiled_vae": self.tiled_vae,
|
||||||
"tiles": self.tiles,
|
"tiles": self.tiles,
|
||||||
"overlap": self.overlap,
|
"overlap": self.overlap,
|
||||||
|
"stride": self.stride,
|
||||||
}
|
}
|
||||||
|
|
||||||
def with_args(self, **kwargs):
|
def with_args(self, **kwargs):
|
||||||
|
@ -303,6 +303,7 @@ class ImageParams:
|
||||||
kwargs.get("tiled_vae", self.tiled_vae),
|
kwargs.get("tiled_vae", self.tiled_vae),
|
||||||
kwargs.get("tiles", self.tiles),
|
kwargs.get("tiles", self.tiles),
|
||||||
kwargs.get("overlap", self.overlap),
|
kwargs.get("overlap", self.overlap),
|
||||||
|
kwargs.get("stride", self.stride),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -140,6 +140,13 @@ def pipeline_from_request(
|
||||||
get_config_value("overlap", "max"),
|
get_config_value("overlap", "max"),
|
||||||
get_config_value("overlap", "min"),
|
get_config_value("overlap", "min"),
|
||||||
)
|
)
|
||||||
|
stride = get_and_clamp_float(
|
||||||
|
request.args,
|
||||||
|
"stride",
|
||||||
|
get_config_value("stride"),
|
||||||
|
get_config_value("stride", "max"),
|
||||||
|
get_config_value("stride", "min"),
|
||||||
|
)
|
||||||
|
|
||||||
seed = int(request.args.get("seed", -1))
|
seed = int(request.args.get("seed", -1))
|
||||||
if seed == -1:
|
if seed == -1:
|
||||||
|
@ -177,6 +184,7 @@ def pipeline_from_request(
|
||||||
tiled_vae=tiled_vae,
|
tiled_vae=tiled_vae,
|
||||||
tiles=tiles,
|
tiles=tiles,
|
||||||
overlap=overlap,
|
overlap=overlap,
|
||||||
|
stride=stride,
|
||||||
)
|
)
|
||||||
size = Size(width, height)
|
size = Size(width, height)
|
||||||
return (device, params, size)
|
return (device, params, size)
|
||||||
|
|
|
@ -186,6 +186,12 @@
|
||||||
"max": 1,
|
"max": 1,
|
||||||
"step": 0.01
|
"step": 0.01
|
||||||
},
|
},
|
||||||
|
"stride": {
|
||||||
|
"default": 128,
|
||||||
|
"min": 64,
|
||||||
|
"max": 512,
|
||||||
|
"step": 64
|
||||||
|
},
|
||||||
"tiledVAE": {
|
"tiledVAE": {
|
||||||
"default": false
|
"default": false
|
||||||
},
|
},
|
||||||
|
|
|
@ -51,6 +51,7 @@ export interface BaseImgParams {
|
||||||
tiledVAE: boolean;
|
tiledVAE: boolean;
|
||||||
tiles: number;
|
tiles: number;
|
||||||
overlap: number;
|
overlap: number;
|
||||||
|
stride: number;
|
||||||
|
|
||||||
cfg: number;
|
cfg: number;
|
||||||
steps: number;
|
steps: number;
|
||||||
|
@ -409,6 +410,7 @@ export function makeImageURL(root: string, type: string, params: BaseImgParams):
|
||||||
url.searchParams.append('tiledVAE', String(params.tiledVAE));
|
url.searchParams.append('tiledVAE', String(params.tiledVAE));
|
||||||
url.searchParams.append('tiles', params.tiles.toFixed(FIXED_INTEGER));
|
url.searchParams.append('tiles', params.tiles.toFixed(FIXED_INTEGER));
|
||||||
url.searchParams.append('overlap', params.overlap.toFixed(FIXED_FLOAT));
|
url.searchParams.append('overlap', params.overlap.toFixed(FIXED_FLOAT));
|
||||||
|
url.searchParams.append('stride', params.stride.toFixed(FIXED_FLOAT));
|
||||||
|
|
||||||
if (doesExist(params.scheduler)) {
|
if (doesExist(params.scheduler)) {
|
||||||
url.searchParams.append('scheduler', params.scheduler);
|
url.searchParams.append('scheduler', params.scheduler);
|
||||||
|
|
|
@ -178,6 +178,21 @@ export function ImageControl(props: ImageControlProps) {
|
||||||
}
|
}
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
|
<NumericField
|
||||||
|
label={t('parameter.stride')}
|
||||||
|
min={params.stride.min}
|
||||||
|
max={params.stride.max}
|
||||||
|
step={params.stride.step}
|
||||||
|
value={controlState.stride}
|
||||||
|
onChange={(stride) => {
|
||||||
|
if (doesExist(props.onChange)) {
|
||||||
|
props.onChange({
|
||||||
|
...controlState,
|
||||||
|
stride,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
/>
|
||||||
<FormControlLabel
|
<FormControlLabel
|
||||||
label={t('parameter.tiledVAE')}
|
label={t('parameter.tiledVAE')}
|
||||||
control={<Checkbox
|
control={<Checkbox
|
||||||
|
|
|
@ -184,6 +184,12 @@
|
||||||
"max": 1,
|
"max": 1,
|
||||||
"step": 0.01
|
"step": 0.01
|
||||||
},
|
},
|
||||||
|
"stride": {
|
||||||
|
"default": 128,
|
||||||
|
"min": 64,
|
||||||
|
"max": 512,
|
||||||
|
"step": 64
|
||||||
|
},
|
||||||
"tiledVAE": {
|
"tiledVAE": {
|
||||||
"default": false
|
"default": false
|
||||||
},
|
},
|
||||||
|
|
|
@ -130,6 +130,7 @@ export const I18N_STRINGS_DE = {
|
||||||
sourceFilter: '',
|
sourceFilter: '',
|
||||||
steps: 'Schritte',
|
steps: 'Schritte',
|
||||||
strength: 'Stärke',
|
strength: 'Stärke',
|
||||||
|
stride: '',
|
||||||
tiledVAE: '',
|
tiledVAE: '',
|
||||||
tiles: '',
|
tiles: '',
|
||||||
tileOrder: '',
|
tileOrder: '',
|
||||||
|
|
|
@ -182,6 +182,7 @@ export const I18N_STRINGS_EN = {
|
||||||
sourceFilter: 'Source Filter',
|
sourceFilter: 'Source Filter',
|
||||||
steps: 'Steps',
|
steps: 'Steps',
|
||||||
strength: 'Strength',
|
strength: 'Strength',
|
||||||
|
stride: 'UNet Stride',
|
||||||
tiledVAE: 'Tiled VAE',
|
tiledVAE: 'Tiled VAE',
|
||||||
tiles: 'Tile Size',
|
tiles: 'Tile Size',
|
||||||
tileOrder: 'Tile Order',
|
tileOrder: 'Tile Order',
|
||||||
|
|
|
@ -130,6 +130,7 @@ export const I18N_STRINGS_ES = {
|
||||||
sourceFilter: '',
|
sourceFilter: '',
|
||||||
steps: 'Pasos',
|
steps: 'Pasos',
|
||||||
strength: 'Fuerza',
|
strength: 'Fuerza',
|
||||||
|
stride: '',
|
||||||
tiledVAE: '',
|
tiledVAE: '',
|
||||||
tiles: '',
|
tiles: '',
|
||||||
tileOrder: 'Orden de secciones',
|
tileOrder: 'Orden de secciones',
|
||||||
|
|
|
@ -130,6 +130,7 @@ export const I18N_STRINGS_FR = {
|
||||||
sourceFilter: '',
|
sourceFilter: '',
|
||||||
steps: '',
|
steps: '',
|
||||||
strength: '',
|
strength: '',
|
||||||
|
stride: '',
|
||||||
tiledVAE: '',
|
tiledVAE: '',
|
||||||
tiles: '',
|
tiles: '',
|
||||||
tileOrder: '',
|
tileOrder: '',
|
||||||
|
|
Loading…
Reference in New Issue