skip loading component models for XL
This commit is contained in:
parent
273ace7693
commit
44004730ea
|
@ -51,20 +51,20 @@ class BlendImg2ImgStage(BaseStage):
|
|||
)
|
||||
|
||||
pipe_params = {}
|
||||
if pipe_type == "controlnet":
|
||||
if params.is_control():
|
||||
pipe_params["controlnet_conditioning_scale"] = strength
|
||||
elif params.is_lpw():
|
||||
pipe_params["strength"] = strength
|
||||
elif params.is_panorama():
|
||||
pipe_params["strength"] = strength
|
||||
elif pipe_type == "img2img":
|
||||
pipe_params["strength"] = strength
|
||||
elif pipe_type == "lpw":
|
||||
pipe_params["strength"] = strength
|
||||
elif pipe_type == "panorama":
|
||||
pipe_params["strength"] = strength
|
||||
elif pipe_type == "pix2pix":
|
||||
pipe_params["image_guidance_scale"] = strength
|
||||
|
||||
outputs = []
|
||||
for source in sources:
|
||||
if params.lpw():
|
||||
if params.is_lpw():
|
||||
logger.debug("using LPW pipeline for img2img")
|
||||
rng = torch.manual_seed(params.seed)
|
||||
result = pipe.img2img(
|
||||
|
@ -82,7 +82,9 @@ class BlendImg2ImgStage(BaseStage):
|
|||
prompt_embeds = encode_prompt(
|
||||
pipe, prompt_pairs, params.batch, params.do_cfg()
|
||||
)
|
||||
pipe.unet.set_prompts(prompt_embeds)
|
||||
|
||||
if not params.xl():
|
||||
pipe.unet.set_prompts(prompt_embeds)
|
||||
|
||||
rng = np.random.RandomState(params.seed)
|
||||
result = pipe(
|
||||
|
|
|
@ -74,7 +74,7 @@ class SourceTxt2ImgStage(BaseStage):
|
|||
loras=loras,
|
||||
)
|
||||
|
||||
if params.lpw():
|
||||
if params.is_lpw():
|
||||
logger.debug("using LPW pipeline for txt2img")
|
||||
rng = torch.manual_seed(params.seed)
|
||||
result = pipe.text2img(
|
||||
|
@ -95,7 +95,9 @@ class SourceTxt2ImgStage(BaseStage):
|
|||
prompt_embeds = encode_prompt(
|
||||
pipe, prompt_pairs, params.batch, params.do_cfg()
|
||||
)
|
||||
pipe.unet.set_prompts(prompt_embeds)
|
||||
|
||||
if not params.is_xl():
|
||||
pipe.unet.set_prompts(prompt_embeds)
|
||||
|
||||
rng = np.random.RandomState(params.seed)
|
||||
result = pipe(
|
||||
|
|
|
@ -81,7 +81,7 @@ class UpscaleOutpaintStage(BaseStage):
|
|||
else:
|
||||
latents = get_tile_latents(latents, params.seed, latent_size, dims)
|
||||
|
||||
if params.lpw():
|
||||
if params.is_lpw():
|
||||
logger.debug("using LPW pipeline for inpaint")
|
||||
rng = torch.manual_seed(params.seed)
|
||||
result = pipe.inpaint(
|
||||
|
|
|
@ -171,7 +171,7 @@ def load_pipeline(
|
|||
unet_type = "unet"
|
||||
|
||||
# ControlNet component
|
||||
if pipeline == "controlnet" and params.control is not None:
|
||||
if params.is_control() and params.control is not None:
|
||||
cnet_path = path.join(
|
||||
server.model_path, "control", f"{params.control.name}.onnx"
|
||||
)
|
||||
|
@ -282,7 +282,7 @@ def load_pipeline(
|
|||
)
|
||||
|
||||
# make sure a UNet has been loaded
|
||||
if "unet" not in components:
|
||||
if not params.is_xl() and "unet" not in components:
|
||||
unet = path.join(model, unet_type, ONNX_MODEL)
|
||||
logger.debug("loading UNet (%s) from %s", unet_type, unet)
|
||||
components["unet"] = OnnxRuntimeModel(
|
||||
|
@ -298,7 +298,7 @@ def load_pipeline(
|
|||
vae_decoder = path.join(model, "vae_decoder", ONNX_MODEL)
|
||||
vae_encoder = path.join(model, "vae_encoder", ONNX_MODEL)
|
||||
|
||||
if path.exists(vae):
|
||||
if not params.is_xl() and path.exists(vae):
|
||||
logger.debug("loading VAE from %s", vae)
|
||||
components["vae"] = OnnxRuntimeModel(
|
||||
OnnxRuntimeModel.load_model(
|
||||
|
@ -307,7 +307,9 @@ def load_pipeline(
|
|||
sess_options=device.sess_options(),
|
||||
)
|
||||
)
|
||||
elif path.exists(vae_decoder) and path.exists(vae_encoder):
|
||||
elif (
|
||||
not params.is_xl() and path.exists(vae_decoder) and path.exists(vae_encoder)
|
||||
):
|
||||
logger.debug("loading VAE decoder from %s", vae_decoder)
|
||||
components["vae_decoder"] = OnnxRuntimeModel(
|
||||
OnnxRuntimeModel.load_model(
|
||||
|
@ -327,7 +329,7 @@ def load_pipeline(
|
|||
)
|
||||
|
||||
# additional options for panorama pipeline
|
||||
if pipeline == "panorama":
|
||||
if params.is_panorama():
|
||||
components["window"] = params.tiles // 8
|
||||
components["stride"] = params.stride // 8
|
||||
|
||||
|
@ -346,18 +348,21 @@ def load_pipeline(
|
|||
pipe.set_progress_bar_config(disable=True)
|
||||
|
||||
optimize_pipeline(server, pipe)
|
||||
patch_pipeline(server, pipe, pipeline, pipeline_class, params)
|
||||
|
||||
if not params.is_xl():
|
||||
patch_pipeline(server, pipe, pipeline, pipeline_class, params)
|
||||
|
||||
server.cache.set(ModelTypes.diffusion, pipe_key, pipe)
|
||||
server.cache.set(ModelTypes.scheduler, scheduler_key, components["scheduler"])
|
||||
|
||||
if hasattr(pipe, "vae_decoder"):
|
||||
if not params.is_xl() and hasattr(pipe, "vae_decoder"):
|
||||
pipe.vae_decoder.set_tiled(tiled=params.tiled_vae)
|
||||
if hasattr(pipe, "vae_encoder"):
|
||||
|
||||
if not params.is_xl() and hasattr(pipe, "vae_encoder"):
|
||||
pipe.vae_encoder.set_tiled(tiled=params.tiled_vae)
|
||||
|
||||
# update panorama params
|
||||
if pipeline == "panorama":
|
||||
if params.is_panorama():
|
||||
latent_window = params.tiles // 8
|
||||
latent_stride = params.stride // 8
|
||||
|
||||
|
|
|
@ -43,8 +43,7 @@ def run_txt2img_pipeline(
|
|||
highres: HighresParams,
|
||||
) -> None:
|
||||
# if using panorama, the pipeline will tile itself (views)
|
||||
pipe_type = params.get_valid_pipeline("txt2img")
|
||||
if pipe_type == "panorama":
|
||||
if params.is_panorama():
|
||||
tile_size = max(params.tiles, size.width, size.height)
|
||||
else:
|
||||
tile_size = params.tiles
|
||||
|
|
|
@ -271,9 +271,18 @@ class ImageParams:
|
|||
logger.debug("pipeline %s is not valid for %s", pipeline, group)
|
||||
return group
|
||||
|
||||
def lpw(self):
|
||||
def is_control(self):
|
||||
return self.pipeline == "controlnet"
|
||||
|
||||
def is_lpw(self):
|
||||
return self.pipeline == "lpw"
|
||||
|
||||
def is_panorama(self):
|
||||
return self.pipeline == "panorama"
|
||||
|
||||
def is_xl(self):
|
||||
return self.pipeline.endswith("-sdxl")
|
||||
|
||||
def tojson(self) -> Dict[str, Optional[Param]]:
|
||||
return {
|
||||
"model": self.model,
|
||||
|
|
Loading…
Reference in New Issue