fix(api): build SDXL pipeline to avoid optimum patches
This commit is contained in:
parent
d7c95a4a4f
commit
3f3811e16a
|
@ -217,6 +217,21 @@ def load_pipeline(
|
|||
components.update(vae_components)
|
||||
|
||||
pipeline_class = available_pipelines.get(pipeline, OnnxStableDiffusionPipeline)
|
||||
|
||||
if params.is_xl():
|
||||
logger.debug("assembling SDXL pipeline for %s", pipeline_class.__name__)
|
||||
pipe = pipeline_class(
|
||||
components["vae_decoder_session"],
|
||||
components["text_encoder_session"],
|
||||
components["unet_session"],
|
||||
{}, # empty config
|
||||
components["tokenizer"],
|
||||
scheduler,
|
||||
vae_encoder_session=components.get("vae_encoder_session", None),
|
||||
text_encoder_2_session=components.get("text_encoder_2_session", None),
|
||||
tokenizer_2=components.get("tokenizer_2", None),
|
||||
)
|
||||
else:
|
||||
logger.debug("loading pretrained SD pipeline for %s", pipeline_class.__name__)
|
||||
pipe = pipeline_class.from_pretrained(
|
||||
model,
|
||||
|
@ -227,47 +242,6 @@ def load_pipeline(
|
|||
**components,
|
||||
)
|
||||
|
||||
if pipe.scheduler != scheduler:
|
||||
pipe.scheduler = scheduler
|
||||
|
||||
# make sure XL models are actually being used
|
||||
if "text_encoder_session" in components:
|
||||
pipe.text_encoder = ORTModelTextEncoder(
|
||||
components["text_encoder_session"], pipe
|
||||
)
|
||||
|
||||
if "text_encoder_2_session" in components:
|
||||
pipe.text_encoder_2 = ORTModelTextEncoder(
|
||||
components["text_encoder_2_session"], pipe
|
||||
)
|
||||
|
||||
if "tokenizer" in components:
|
||||
pipe.tokenizer = components["tokenizer"]
|
||||
|
||||
if "tokenizer_2" in components:
|
||||
pipe.tokenizer_2 = components["tokenizer_2"]
|
||||
|
||||
if "unet_session" in components:
|
||||
# unload old UNet
|
||||
logger.debug("unloading previous Unet")
|
||||
pipe.unet = None
|
||||
run_gc([device])
|
||||
|
||||
# attach correct one
|
||||
pipe.unet = ORTModelUnet(components["unet_session"], pipe)
|
||||
|
||||
if "vae_decoder_session" in components:
|
||||
pipe.vae_decoder = ORTModelVaeDecoder(
|
||||
components["vae_decoder_session"],
|
||||
pipe,
|
||||
)
|
||||
|
||||
if "vae_encoder_session" in components:
|
||||
pipe.vae_encoder = ORTModelVaeEncoder(
|
||||
components["vae_encoder_session"],
|
||||
pipe,
|
||||
)
|
||||
|
||||
if not server.show_progress:
|
||||
pipe.set_progress_bar_config(disable=True)
|
||||
|
||||
|
|
Loading…
Reference in New Issue