
skip loading component models for XL

Sean Sube 2023-08-20 22:28:08 -05:00
parent 273ace7693
commit 44004730ea
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
6 changed files with 39 additions and 22 deletions

View File

@@ -51,20 +51,20 @@ class BlendImg2ImgStage(BaseStage):
         )
 
         pipe_params = {}
-        if pipe_type == "controlnet":
+        if params.is_control():
             pipe_params["controlnet_conditioning_scale"] = strength
+        elif params.is_lpw():
+            pipe_params["strength"] = strength
+        elif params.is_panorama():
+            pipe_params["strength"] = strength
         elif pipe_type == "img2img":
             pipe_params["strength"] = strength
-        elif pipe_type == "lpw":
-            pipe_params["strength"] = strength
-        elif pipe_type == "panorama":
-            pipe_params["strength"] = strength
         elif pipe_type == "pix2pix":
             pipe_params["image_guidance_scale"] = strength
 
         outputs = []
         for source in sources:
-            if params.lpw():
+            if params.is_lpw():
                 logger.debug("using LPW pipeline for img2img")
                 rng = torch.manual_seed(params.seed)
                 result = pipe.img2img(
@@ -82,7 +82,9 @@ class BlendImg2ImgStage(BaseStage):
                 prompt_embeds = encode_prompt(
                     pipe, prompt_pairs, params.batch, params.do_cfg()
                 )
-                pipe.unet.set_prompts(prompt_embeds)
+                if not params.is_xl():
+                    pipe.unet.set_prompts(prompt_embeds)
 
                 rng = np.random.RandomState(params.seed)
                 result = pipe(
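Note (not part of the diff): the guard added above skips the set_prompts() call for SDXL because this commit also skips patch_pipeline() for XL, so the patched UNet wrapper that exposes set_prompts() is never installed. A minimal sketch of the same guard; the helper name apply_prompt_embeds is hypothetical, onnx-web makes the call inline:

# Illustrative sketch only, mirroring the guard added in this hunk.
def apply_prompt_embeds(pipe, params, prompt_embeds) -> None:
    # set_prompts() only exists on the patched UNet wrapper, which is not
    # installed for SDXL pipelines in this commit.
    if params.is_xl():
        return
    pipe.unet.set_prompts(prompt_embeds)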

View File

@@ -74,7 +74,7 @@ class SourceTxt2ImgStage(BaseStage):
             loras=loras,
         )
 
-        if params.lpw():
+        if params.is_lpw():
             logger.debug("using LPW pipeline for txt2img")
             rng = torch.manual_seed(params.seed)
             result = pipe.text2img(
@@ -95,7 +95,9 @@ class SourceTxt2ImgStage(BaseStage):
             prompt_embeds = encode_prompt(
                 pipe, prompt_pairs, params.batch, params.do_cfg()
             )
-            pipe.unet.set_prompts(prompt_embeds)
+            if not params.is_xl():
+                pipe.unet.set_prompts(prompt_embeds)
 
             rng = np.random.RandomState(params.seed)
             result = pipe(

View File

@@ -81,7 +81,7 @@ class UpscaleOutpaintStage(BaseStage):
             else:
                 latents = get_tile_latents(latents, params.seed, latent_size, dims)
 
-            if params.lpw():
+            if params.is_lpw():
                 logger.debug("using LPW pipeline for inpaint")
                 rng = torch.manual_seed(params.seed)
                 result = pipe.inpaint(
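Note (not part of the diff): across these stages the LPW branch seeds a torch generator while the default branch seeds a NumPy RandomState, both from params.seed, so each branch is deterministic per request. A small sketch of the two constructions; the seed value is only an example:

# Sketch of the two RNG flavors used in these stages.
import numpy as np
import torch

seed = 5555                              # onnx-web takes this from params.seed
torch_rng = torch.manual_seed(seed)      # LPW pipelines (torch.Generator)
numpy_rng = np.random.RandomState(seed)  # default ONNX pipelines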

View File

@@ -171,7 +171,7 @@ def load_pipeline(
         unet_type = "unet"
 
         # ControlNet component
-        if pipeline == "controlnet" and params.control is not None:
+        if params.is_control() and params.control is not None:
             cnet_path = path.join(
                 server.model_path, "control", f"{params.control.name}.onnx"
             )
@@ -282,7 +282,7 @@ def load_pipeline(
             )
 
         # make sure a UNet has been loaded
-        if "unet" not in components:
+        if not params.is_xl() and "unet" not in components:
             unet = path.join(model, unet_type, ONNX_MODEL)
             logger.debug("loading UNet (%s) from %s", unet_type, unet)
             components["unet"] = OnnxRuntimeModel(
@@ -298,7 +298,7 @@ def load_pipeline(
         vae_decoder = path.join(model, "vae_decoder", ONNX_MODEL)
         vae_encoder = path.join(model, "vae_encoder", ONNX_MODEL)
 
-        if path.exists(vae):
+        if not params.is_xl() and path.exists(vae):
             logger.debug("loading VAE from %s", vae)
             components["vae"] = OnnxRuntimeModel(
                 OnnxRuntimeModel.load_model(
@@ -307,7 +307,9 @@ def load_pipeline(
                     sess_options=device.sess_options(),
                 )
             )
-        elif path.exists(vae_decoder) and path.exists(vae_encoder):
+        elif (
+            not params.is_xl() and path.exists(vae_decoder) and path.exists(vae_encoder)
+        ):
             logger.debug("loading VAE decoder from %s", vae_decoder)
             components["vae_decoder"] = OnnxRuntimeModel(
                 OnnxRuntimeModel.load_model(
@@ -327,7 +329,7 @@ def load_pipeline(
             )
 
         # additional options for panorama pipeline
-        if pipeline == "panorama":
+        if params.is_panorama():
             components["window"] = params.tiles // 8
             components["stride"] = params.stride // 8
@@ -346,18 +348,21 @@ def load_pipeline(
         pipe.set_progress_bar_config(disable=True)
         optimize_pipeline(server, pipe)
 
-        patch_pipeline(server, pipe, pipeline, pipeline_class, params)
+        if not params.is_xl():
+            patch_pipeline(server, pipe, pipeline, pipeline_class, params)
 
         server.cache.set(ModelTypes.diffusion, pipe_key, pipe)
         server.cache.set(ModelTypes.scheduler, scheduler_key, components["scheduler"])
 
-    if hasattr(pipe, "vae_decoder"):
+    if not params.is_xl() and hasattr(pipe, "vae_decoder"):
         pipe.vae_decoder.set_tiled(tiled=params.tiled_vae)
-    if hasattr(pipe, "vae_encoder"):
+    if not params.is_xl() and hasattr(pipe, "vae_encoder"):
         pipe.vae_encoder.set_tiled(tiled=params.tiled_vae)
 
     # update panorama params
-    if pipeline == "panorama":
+    if params.is_panorama():
         latent_window = params.tiles // 8
         latent_stride = params.stride // 8
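Note (not part of the diff): the load_pipeline() changes above all follow one pattern: separate UNet/VAE component models, the pipeline patch, and the tiled-VAE settings are skipped when params.is_xl() is true. A condensed sketch of that pattern; the function, its load_model parameter, and the "model.onnx" filename are assumptions standing in for the OnnxRuntimeModel calls in load.py:

# Condensed, hypothetical sketch of the guard pattern added to load_pipeline().
from os import path
from typing import Callable, Dict

def load_component_models(
    params,                                # needs ImageParams.is_xl()
    model: str,                            # converted model directory
    components: Dict[str, object],
    load_model: Callable[[str], object],   # stand-in for OnnxRuntimeModel loading
) -> None:
    if params.is_xl():
        # per this commit, no separate component models are loaded for SDXL
        return

    unet = path.join(model, "unet", "model.onnx")
    if "unet" not in components and path.exists(unet):
        components["unet"] = load_model(unet)

    vae = path.join(model, "vae", "model.onnx")
    if path.exists(vae):
        components["vae"] = load_model(vae)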

View File

@@ -43,8 +43,7 @@ def run_txt2img_pipeline(
     highres: HighresParams,
 ) -> None:
     # if using panorama, the pipeline will tile itself (views)
-    pipe_type = params.get_valid_pipeline("txt2img")
-    if pipe_type == "panorama":
+    if params.is_panorama():
         tile_size = max(params.tiles, size.width, size.height)
     else:
         tile_size = params.tiles
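Note (not part of the diff): the hunk above picks the tile size for the chain. Because the panorama pipeline tiles itself (views), it gets a single tile covering the whole canvas, while other pipelines keep the configured tile size. A tiny worked sketch with illustrative names:

# Worked sketch of the tile-size rule above.
def select_tile_size(is_panorama: bool, tiles: int, width: int, height: int) -> int:
    if is_panorama:
        # panorama tiles internally, so cover the whole canvas at once
        return max(tiles, width, height)
    return tiles

assert select_tile_size(True, 512, 1024, 768) == 1024   # panorama on a 1024x768 canvas
assert select_tile_size(False, 512, 1024, 768) == 512   # everything else keeps tiles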

View File

@@ -271,9 +271,18 @@ class ImageParams:
             logger.debug("pipeline %s is not valid for %s", pipeline, group)
             return group
 
-    def lpw(self):
+    def is_control(self):
+        return self.pipeline == "controlnet"
+
+    def is_lpw(self):
         return self.pipeline == "lpw"
 
+    def is_panorama(self):
+        return self.pipeline == "panorama"
+
+    def is_xl(self):
+        return self.pipeline.endswith("-sdxl")
+
     def tojson(self) -> Dict[str, Optional[Param]]:
         return {
             "model": self.model,