1
0
Fork 0

feat: add UNet stride as its own parameter

This commit is contained in:
Sean Sube 2023-05-03 19:15:05 -05:00
parent 98386cb385
commit 3b02fc5768
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
13 changed files with 69 additions and 21 deletions

View File

@ -125,14 +125,17 @@ def load_pipeline(
logger.debug("reusing existing diffusion pipeline") logger.debug("reusing existing diffusion pipeline")
pipe = cache_pipe pipe = cache_pipe
cache_pipe.vae_encoder.set_tiled(tiled=params.tiled_vae)
cache_pipe.vae_decoder.set_tiled(tiled=params.tiled_vae)
# update panorama params # update panorama params
if pipeline == "panorama": if pipeline == "panorama":
latent_window = params.tiles // 8 latent_window = params.tiles // 8
latent_stride = params.stride() // 8 latent_stride = params.stride // 8
cache_pipe.set_window_size(latent_window, latent_stride) cache_pipe.set_window_size(latent_window, latent_stride)
cache_pipe.vae_encoder.set_window_size(latent_window, latent_stride) cache_pipe.vae_encoder.set_window_size(latent_window, params.overlap)
cache_pipe.vae_decoder.set_window_size(latent_window, latent_stride) cache_pipe.vae_decoder.set_window_size(latent_window, params.overlap)
# update scheduler # update scheduler
cache_scheduler = server.cache.get("scheduler", scheduler_key) cache_scheduler = server.cache.get("scheduler", scheduler_key)
@ -332,7 +335,7 @@ def load_pipeline(
# additional options for panorama pipeline # additional options for panorama pipeline
if pipeline == "panorama": if pipeline == "panorama":
components["window"] = params.tiles // 8 components["window"] = params.tiles // 8
components["stride"] = params.stride() // 8 components["stride"] = params.stride // 8
pipeline_class = available_pipelines.get(pipeline, OnnxStableDiffusionPipeline) pipeline_class = available_pipelines.get(pipeline, OnnxStableDiffusionPipeline)
logger.debug("loading pretrained SD pipeline for %s", pipeline_class.__name__) logger.debug("loading pretrained SD pipeline for %s", pipeline_class.__name__)
@ -433,16 +436,16 @@ def patch_pipeline(
server, server,
original_decoder, original_decoder,
decoder=True, decoder=True,
tiles=params.tiles, window=params.tiles,
stride=params.stride(), overlap=params.overlap,
) )
original_encoder = pipe.vae_encoder original_encoder = pipe.vae_encoder
pipe.vae_encoder = VAEWrapper( pipe.vae_encoder = VAEWrapper(
server, server,
original_encoder, original_encoder,
decoder=False, decoder=False,
tiles=params.tiles, window=params.tiles,
stride=params.stride(), overlap=params.overlap,
) )
elif hasattr(pipe, "vae"): elif hasattr(pipe, "vae"):
pass # TODO: current wrapper does not work with upscaling VAE pass # TODO: current wrapper does not work with upscaling VAE

View File

@ -28,21 +28,22 @@ class VAEWrapper(object):
server: ServerContext, server: ServerContext,
wrapped: OnnxRuntimeModel, wrapped: OnnxRuntimeModel,
decoder: bool, decoder: bool,
tiles: int, window: int,
stride: int, overlap: float,
): ):
self.server = server self.server = server
self.wrapped = wrapped self.wrapped = wrapped
self.decoder = decoder self.decoder = decoder
self.set_window_size(tiles, stride) self.tiled = False
self.set_window_size(window, overlap)
def set_window_size(self, window: int, stride: int): def set_tiled(self, tiled: bool = True):
self.window = window self.tiled = tiled
self.stride = stride
self.tile_latent_min_size = self.window def set_window_size(self, window: int, overlap: float):
self.tile_sample_min_size = self.window * 8 self.tile_latent_min_size = window
self.tile_overlap_factor = self.stride / self.window self.tile_sample_min_size = window * 8
self.tile_overlap_factor = overlap
def __call__(self, latent_sample=None, sample=None, **kwargs): def __call__(self, latent_sample=None, sample=None, **kwargs):
global timestep_dtype global timestep_dtype
@ -62,7 +63,7 @@ class VAEWrapper(object):
logger.debug("converting VAE sample dtype") logger.debug("converting VAE sample dtype")
sample = sample.astype(timestep_dtype) sample = sample.astype(timestep_dtype)
if self.window is not None and self.stride is not None: if self.tiled:
if self.decoder: if self.decoder:
return self.tiled_decode(latent_sample, **kwargs) return self.tiled_decode(latent_sample, **kwargs)
else: else:

View File

@ -369,6 +369,7 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
) )
return views return views
@torch.no_grad()
def text2img( def text2img(
self, self,
prompt: Union[str, List[str]] = None, prompt: Union[str, List[str]] = None,
@ -619,6 +620,7 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
images=image, nsfw_content_detected=has_nsfw_concept images=image, nsfw_content_detected=has_nsfw_concept
) )
@torch.no_grad()
def img2img( def img2img(
self, self,
prompt: Union[str, List[str]] = None, prompt: Union[str, List[str]] = None,

View File

@ -214,6 +214,7 @@ class ImageParams:
tiled_vae: bool = False, tiled_vae: bool = False,
tiles: int = 512, tiles: int = 512,
overlap: float = 0.25, overlap: float = 0.25,
stride: int = 64,
) -> None: ) -> None:
self.model = model self.model = model
self.pipeline = pipeline self.pipeline = pipeline
@ -232,6 +233,7 @@ class ImageParams:
self.tiled_vae = tiled_vae self.tiled_vae = tiled_vae
self.tiles = tiles self.tiles = tiles
self.overlap = overlap self.overlap = overlap
self.stride = stride
def do_cfg(self): def do_cfg(self):
return self.cfg > 1.0 return self.cfg > 1.0
@ -260,9 +262,6 @@ class ImageParams:
def lpw(self): def lpw(self):
return self.pipeline == "lpw" return self.pipeline == "lpw"
def stride(self):
return int(self.tiles * self.overlap)
def tojson(self) -> Dict[str, Optional[Param]]: def tojson(self) -> Dict[str, Optional[Param]]:
return { return {
"model": self.model, "model": self.model,
@ -282,6 +281,7 @@ class ImageParams:
"tiled_vae": self.tiled_vae, "tiled_vae": self.tiled_vae,
"tiles": self.tiles, "tiles": self.tiles,
"overlap": self.overlap, "overlap": self.overlap,
"stride": self.stride,
} }
def with_args(self, **kwargs): def with_args(self, **kwargs):
@ -303,6 +303,7 @@ class ImageParams:
kwargs.get("tiled_vae", self.tiled_vae), kwargs.get("tiled_vae", self.tiled_vae),
kwargs.get("tiles", self.tiles), kwargs.get("tiles", self.tiles),
kwargs.get("overlap", self.overlap), kwargs.get("overlap", self.overlap),
kwargs.get("stride", self.stride),
) )

View File

@ -140,6 +140,13 @@ def pipeline_from_request(
get_config_value("overlap", "max"), get_config_value("overlap", "max"),
get_config_value("overlap", "min"), get_config_value("overlap", "min"),
) )
stride = get_and_clamp_float(
request.args,
"stride",
get_config_value("stride"),
get_config_value("stride", "max"),
get_config_value("stride", "min"),
)
seed = int(request.args.get("seed", -1)) seed = int(request.args.get("seed", -1))
if seed == -1: if seed == -1:
@ -177,6 +184,7 @@ def pipeline_from_request(
tiled_vae=tiled_vae, tiled_vae=tiled_vae,
tiles=tiles, tiles=tiles,
overlap=overlap, overlap=overlap,
stride=stride,
) )
size = Size(width, height) size = Size(width, height)
return (device, params, size) return (device, params, size)

View File

@ -186,6 +186,12 @@
"max": 1, "max": 1,
"step": 0.01 "step": 0.01
}, },
"stride": {
"default": 128,
"min": 64,
"max": 512,
"step": 64
},
"tiledVAE": { "tiledVAE": {
"default": false "default": false
}, },

View File

@ -51,6 +51,7 @@ export interface BaseImgParams {
tiledVAE: boolean; tiledVAE: boolean;
tiles: number; tiles: number;
overlap: number; overlap: number;
stride: number;
cfg: number; cfg: number;
steps: number; steps: number;
@ -409,6 +410,7 @@ export function makeImageURL(root: string, type: string, params: BaseImgParams):
url.searchParams.append('tiledVAE', String(params.tiledVAE)); url.searchParams.append('tiledVAE', String(params.tiledVAE));
url.searchParams.append('tiles', params.tiles.toFixed(FIXED_INTEGER)); url.searchParams.append('tiles', params.tiles.toFixed(FIXED_INTEGER));
url.searchParams.append('overlap', params.overlap.toFixed(FIXED_FLOAT)); url.searchParams.append('overlap', params.overlap.toFixed(FIXED_FLOAT));
url.searchParams.append('stride', params.stride.toFixed(FIXED_FLOAT));
if (doesExist(params.scheduler)) { if (doesExist(params.scheduler)) {
url.searchParams.append('scheduler', params.scheduler); url.searchParams.append('scheduler', params.scheduler);

View File

@ -178,6 +178,21 @@ export function ImageControl(props: ImageControlProps) {
} }
}} }}
/> />
<NumericField
label={t('parameter.stride')}
min={params.stride.min}
max={params.stride.max}
step={params.stride.step}
value={controlState.stride}
onChange={(stride) => {
if (doesExist(props.onChange)) {
props.onChange({
...controlState,
stride,
});
}
}}
/>
<FormControlLabel <FormControlLabel
label={t('parameter.tiledVAE')} label={t('parameter.tiledVAE')}
control={<Checkbox control={<Checkbox

View File

@ -184,6 +184,12 @@
"max": 1, "max": 1,
"step": 0.01 "step": 0.01
}, },
"stride": {
"default": 128,
"min": 64,
"max": 512,
"step": 64
},
"tiledVAE": { "tiledVAE": {
"default": false "default": false
}, },

View File

@ -130,6 +130,7 @@ export const I18N_STRINGS_DE = {
sourceFilter: '', sourceFilter: '',
steps: 'Schritte', steps: 'Schritte',
strength: 'Stärke', strength: 'Stärke',
stride: '',
tiledVAE: '', tiledVAE: '',
tiles: '', tiles: '',
tileOrder: '', tileOrder: '',

View File

@ -182,6 +182,7 @@ export const I18N_STRINGS_EN = {
sourceFilter: 'Source Filter', sourceFilter: 'Source Filter',
steps: 'Steps', steps: 'Steps',
strength: 'Strength', strength: 'Strength',
stride: 'UNet Stride',
tiledVAE: 'Tiled VAE', tiledVAE: 'Tiled VAE',
tiles: 'Tile Size', tiles: 'Tile Size',
tileOrder: 'Tile Order', tileOrder: 'Tile Order',

View File

@ -130,6 +130,7 @@ export const I18N_STRINGS_ES = {
sourceFilter: '', sourceFilter: '',
steps: 'Pasos', steps: 'Pasos',
strength: 'Fuerza', strength: 'Fuerza',
stride: '',
tiledVAE: '', tiledVAE: '',
tiles: '', tiles: '',
tileOrder: 'Orden de secciones', tileOrder: 'Orden de secciones',

View File

@ -130,6 +130,7 @@ export const I18N_STRINGS_FR = {
sourceFilter: '', sourceFilter: '',
steps: '', steps: '',
strength: '', strength: '',
stride: '',
tiledVAE: '', tiledVAE: '',
tiles: '', tiles: '',
tileOrder: '', tileOrder: '',