1
0
Fork 0

fix(api): build SDXL pipeline to avoid optimum patches

This commit is contained in:
Sean Sube 2023-11-24 17:02:21 -06:00
parent d7c95a4a4f
commit 3f3811e16a
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 20 additions and 46 deletions

View File

@ -217,55 +217,29 @@ def load_pipeline(
components.update(vae_components)
pipeline_class = available_pipelines.get(pipeline, OnnxStableDiffusionPipeline)
logger.debug("loading pretrained SD pipeline for %s", pipeline_class.__name__)
pipe = pipeline_class.from_pretrained(
model,
provider=device.ort_provider(),
sess_options=device.sess_options(),
safety_checker=None,
torch_dtype=torch_dtype,
**components,
)
if pipe.scheduler != scheduler:
pipe.scheduler = scheduler
# make sure XL models are actually being used
if "text_encoder_session" in components:
pipe.text_encoder = ORTModelTextEncoder(
components["text_encoder_session"], pipe
)
if "text_encoder_2_session" in components:
pipe.text_encoder_2 = ORTModelTextEncoder(
components["text_encoder_2_session"], pipe
)
if "tokenizer" in components:
pipe.tokenizer = components["tokenizer"]
if "tokenizer_2" in components:
pipe.tokenizer_2 = components["tokenizer_2"]
if "unet_session" in components:
# unload old UNet
logger.debug("unloading previous Unet")
pipe.unet = None
run_gc([device])
# attach correct one
pipe.unet = ORTModelUnet(components["unet_session"], pipe)
if "vae_decoder_session" in components:
pipe.vae_decoder = ORTModelVaeDecoder(
if params.is_xl():
logger.debug("assembling SDXL pipeline for %s", pipeline_class.__name__)
pipe = pipeline_class(
components["vae_decoder_session"],
pipe,
components["text_encoder_session"],
components["unet_session"],
{}, # empty config
components["tokenizer"],
scheduler,
vae_encoder_session=components.get("vae_encoder_session", None),
text_encoder_2_session=components.get("text_encoder_2_session", None),
tokenizer_2=components.get("tokenizer_2", None),
)
if "vae_encoder_session" in components:
pipe.vae_encoder = ORTModelVaeEncoder(
components["vae_encoder_session"],
pipe,
else:
logger.debug("loading pretrained SD pipeline for %s", pipeline_class.__name__)
pipe = pipeline_class.from_pretrained(
model,
provider=device.ort_provider(),
sess_options=device.sess_options(),
safety_checker=None,
torch_dtype=torch_dtype,
**components,
)
if not server.show_progress: