feat: add UNet stride as its own parameter
This commit is contained in:
parent
98386cb385
commit
3b02fc5768
|
@ -125,14 +125,17 @@ def load_pipeline(
|
|||
logger.debug("reusing existing diffusion pipeline")
|
||||
pipe = cache_pipe
|
||||
|
||||
cache_pipe.vae_encoder.set_tiled(tiled=params.tiled_vae)
|
||||
cache_pipe.vae_decoder.set_tiled(tiled=params.tiled_vae)
|
||||
|
||||
# update panorama params
|
||||
if pipeline == "panorama":
|
||||
latent_window = params.tiles // 8
|
||||
latent_stride = params.stride() // 8
|
||||
latent_stride = params.stride // 8
|
||||
|
||||
cache_pipe.set_window_size(latent_window, latent_stride)
|
||||
cache_pipe.vae_encoder.set_window_size(latent_window, latent_stride)
|
||||
cache_pipe.vae_decoder.set_window_size(latent_window, latent_stride)
|
||||
cache_pipe.vae_encoder.set_window_size(latent_window, params.overlap)
|
||||
cache_pipe.vae_decoder.set_window_size(latent_window, params.overlap)
|
||||
|
||||
# update scheduler
|
||||
cache_scheduler = server.cache.get("scheduler", scheduler_key)
|
||||
|
@ -332,7 +335,7 @@ def load_pipeline(
|
|||
# additional options for panorama pipeline
|
||||
if pipeline == "panorama":
|
||||
components["window"] = params.tiles // 8
|
||||
components["stride"] = params.stride() // 8
|
||||
components["stride"] = params.stride // 8
|
||||
|
||||
pipeline_class = available_pipelines.get(pipeline, OnnxStableDiffusionPipeline)
|
||||
logger.debug("loading pretrained SD pipeline for %s", pipeline_class.__name__)
|
||||
|
@ -433,16 +436,16 @@ def patch_pipeline(
|
|||
server,
|
||||
original_decoder,
|
||||
decoder=True,
|
||||
tiles=params.tiles,
|
||||
stride=params.stride(),
|
||||
window=params.tiles,
|
||||
overlap=params.overlap,
|
||||
)
|
||||
original_encoder = pipe.vae_encoder
|
||||
pipe.vae_encoder = VAEWrapper(
|
||||
server,
|
||||
original_encoder,
|
||||
decoder=False,
|
||||
tiles=params.tiles,
|
||||
stride=params.stride(),
|
||||
window=params.tiles,
|
||||
overlap=params.overlap,
|
||||
)
|
||||
elif hasattr(pipe, "vae"):
|
||||
pass # TODO: current wrapper does not work with upscaling VAE
|
||||
|
|
|
@ -28,21 +28,22 @@ class VAEWrapper(object):
|
|||
server: ServerContext,
|
||||
wrapped: OnnxRuntimeModel,
|
||||
decoder: bool,
|
||||
tiles: int,
|
||||
stride: int,
|
||||
window: int,
|
||||
overlap: float,
|
||||
):
|
||||
self.server = server
|
||||
self.wrapped = wrapped
|
||||
self.decoder = decoder
|
||||
self.set_window_size(tiles, stride)
|
||||
self.tiled = False
|
||||
self.set_window_size(window, overlap)
|
||||
|
||||
def set_window_size(self, window: int, stride: int):
|
||||
self.window = window
|
||||
self.stride = stride
|
||||
def set_tiled(self, tiled: bool = True):
|
||||
self.tiled = tiled
|
||||
|
||||
self.tile_latent_min_size = self.window
|
||||
self.tile_sample_min_size = self.window * 8
|
||||
self.tile_overlap_factor = self.stride / self.window
|
||||
def set_window_size(self, window: int, overlap: float):
|
||||
self.tile_latent_min_size = window
|
||||
self.tile_sample_min_size = window * 8
|
||||
self.tile_overlap_factor = overlap
|
||||
|
||||
def __call__(self, latent_sample=None, sample=None, **kwargs):
|
||||
global timestep_dtype
|
||||
|
@ -62,7 +63,7 @@ class VAEWrapper(object):
|
|||
logger.debug("converting VAE sample dtype")
|
||||
sample = sample.astype(timestep_dtype)
|
||||
|
||||
if self.window is not None and self.stride is not None:
|
||||
if self.tiled:
|
||||
if self.decoder:
|
||||
return self.tiled_decode(latent_sample, **kwargs)
|
||||
else:
|
||||
|
|
|
@ -369,6 +369,7 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
|||
)
|
||||
return views
|
||||
|
||||
@torch.no_grad()
|
||||
def text2img(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
|
@ -619,6 +620,7 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
|||
images=image, nsfw_content_detected=has_nsfw_concept
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def img2img(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
|
|
|
@ -214,6 +214,7 @@ class ImageParams:
|
|||
tiled_vae: bool = False,
|
||||
tiles: int = 512,
|
||||
overlap: float = 0.25,
|
||||
stride: int = 64,
|
||||
) -> None:
|
||||
self.model = model
|
||||
self.pipeline = pipeline
|
||||
|
@ -232,6 +233,7 @@ class ImageParams:
|
|||
self.tiled_vae = tiled_vae
|
||||
self.tiles = tiles
|
||||
self.overlap = overlap
|
||||
self.stride = stride
|
||||
|
||||
def do_cfg(self):
|
||||
return self.cfg > 1.0
|
||||
|
@ -260,9 +262,6 @@ class ImageParams:
|
|||
def lpw(self):
|
||||
return self.pipeline == "lpw"
|
||||
|
||||
def stride(self):
|
||||
return int(self.tiles * self.overlap)
|
||||
|
||||
def tojson(self) -> Dict[str, Optional[Param]]:
|
||||
return {
|
||||
"model": self.model,
|
||||
|
@ -282,6 +281,7 @@ class ImageParams:
|
|||
"tiled_vae": self.tiled_vae,
|
||||
"tiles": self.tiles,
|
||||
"overlap": self.overlap,
|
||||
"stride": self.stride,
|
||||
}
|
||||
|
||||
def with_args(self, **kwargs):
|
||||
|
@ -303,6 +303,7 @@ class ImageParams:
|
|||
kwargs.get("tiled_vae", self.tiled_vae),
|
||||
kwargs.get("tiles", self.tiles),
|
||||
kwargs.get("overlap", self.overlap),
|
||||
kwargs.get("stride", self.stride),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -140,6 +140,13 @@ def pipeline_from_request(
|
|||
get_config_value("overlap", "max"),
|
||||
get_config_value("overlap", "min"),
|
||||
)
|
||||
stride = get_and_clamp_float(
|
||||
request.args,
|
||||
"stride",
|
||||
get_config_value("stride"),
|
||||
get_config_value("stride", "max"),
|
||||
get_config_value("stride", "min"),
|
||||
)
|
||||
|
||||
seed = int(request.args.get("seed", -1))
|
||||
if seed == -1:
|
||||
|
@ -177,6 +184,7 @@ def pipeline_from_request(
|
|||
tiled_vae=tiled_vae,
|
||||
tiles=tiles,
|
||||
overlap=overlap,
|
||||
stride=stride,
|
||||
)
|
||||
size = Size(width, height)
|
||||
return (device, params, size)
|
||||
|
|
|
@ -186,6 +186,12 @@
|
|||
"max": 1,
|
||||
"step": 0.01
|
||||
},
|
||||
"stride": {
|
||||
"default": 128,
|
||||
"min": 64,
|
||||
"max": 512,
|
||||
"step": 64
|
||||
},
|
||||
"tiledVAE": {
|
||||
"default": false
|
||||
},
|
||||
|
|
|
@ -51,6 +51,7 @@ export interface BaseImgParams {
|
|||
tiledVAE: boolean;
|
||||
tiles: number;
|
||||
overlap: number;
|
||||
stride: number;
|
||||
|
||||
cfg: number;
|
||||
steps: number;
|
||||
|
@ -409,6 +410,7 @@ export function makeImageURL(root: string, type: string, params: BaseImgParams):
|
|||
url.searchParams.append('tiledVAE', String(params.tiledVAE));
|
||||
url.searchParams.append('tiles', params.tiles.toFixed(FIXED_INTEGER));
|
||||
url.searchParams.append('overlap', params.overlap.toFixed(FIXED_FLOAT));
|
||||
url.searchParams.append('stride', params.stride.toFixed(FIXED_FLOAT));
|
||||
|
||||
if (doesExist(params.scheduler)) {
|
||||
url.searchParams.append('scheduler', params.scheduler);
|
||||
|
|
|
@ -178,6 +178,21 @@ export function ImageControl(props: ImageControlProps) {
|
|||
}
|
||||
}}
|
||||
/>
|
||||
<NumericField
|
||||
label={t('parameter.stride')}
|
||||
min={params.stride.min}
|
||||
max={params.stride.max}
|
||||
step={params.stride.step}
|
||||
value={controlState.stride}
|
||||
onChange={(stride) => {
|
||||
if (doesExist(props.onChange)) {
|
||||
props.onChange({
|
||||
...controlState,
|
||||
stride,
|
||||
});
|
||||
}
|
||||
}}
|
||||
/>
|
||||
<FormControlLabel
|
||||
label={t('parameter.tiledVAE')}
|
||||
control={<Checkbox
|
||||
|
|
|
@ -184,6 +184,12 @@
|
|||
"max": 1,
|
||||
"step": 0.01
|
||||
},
|
||||
"stride": {
|
||||
"default": 128,
|
||||
"min": 64,
|
||||
"max": 512,
|
||||
"step": 64
|
||||
},
|
||||
"tiledVAE": {
|
||||
"default": false
|
||||
},
|
||||
|
|
|
@ -130,6 +130,7 @@ export const I18N_STRINGS_DE = {
|
|||
sourceFilter: '',
|
||||
steps: 'Schritte',
|
||||
strength: 'Stärke',
|
||||
stride: '',
|
||||
tiledVAE: '',
|
||||
tiles: '',
|
||||
tileOrder: '',
|
||||
|
|
|
@ -182,6 +182,7 @@ export const I18N_STRINGS_EN = {
|
|||
sourceFilter: 'Source Filter',
|
||||
steps: 'Steps',
|
||||
strength: 'Strength',
|
||||
stride: 'UNet Stride',
|
||||
tiledVAE: 'Tiled VAE',
|
||||
tiles: 'Tile Size',
|
||||
tileOrder: 'Tile Order',
|
||||
|
|
|
@ -130,6 +130,7 @@ export const I18N_STRINGS_ES = {
|
|||
sourceFilter: '',
|
||||
steps: 'Pasos',
|
||||
strength: 'Fuerza',
|
||||
stride: '',
|
||||
tiledVAE: '',
|
||||
tiles: '',
|
||||
tileOrder: 'Orden de secciones',
|
||||
|
|
|
@ -130,6 +130,7 @@ export const I18N_STRINGS_FR = {
|
|||
sourceFilter: '',
|
||||
steps: '',
|
||||
strength: '',
|
||||
stride: '',
|
||||
tiledVAE: '',
|
||||
tiles: '',
|
||||
tileOrder: '',
|
||||
|
|
Loading…
Reference in New Issue