diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py
index 1198d85d..eeea47d4 100644
--- a/api/onnx_web/diffusers/load.py
+++ b/api/onnx_web/diffusers/load.py
@@ -201,11 +201,6 @@ def load_pipeline(
         vae_components = load_vae(server, device, model, params)
         components.update(vae_components)
 
-        # additional options for panorama pipeline
-        if params.is_panorama():
-            components["window"] = params.tiles // 8
-            components["stride"] = params.stride // 8
-
         pipeline_class = available_pipelines.get(pipeline, OnnxStableDiffusionPipeline)
         logger.debug("loading pretrained SD pipeline for %s", pipeline_class.__name__)
         pipe = pipeline_class.from_pretrained(
@@ -228,8 +223,15 @@ def load_pipeline(
                 components["text_encoder_2_session"], pipe
             )
 
+        if "tokenizer" in components:
+            pipe.tokenizer = components["tokenizer"]
+
+        if "tokenizer_2" in components:
+            pipe.tokenizer_2 = components["tokenizer_2"]
+
         if "unet_session" in components:
             # unload old UNet
+            logger.debug("unloading previous UNet")
             pipe.unet = None
             run_gc([device])
 
@@ -300,20 +302,25 @@ def load_text_encoders(
     torch_dtype,
     params: ImageParams,
 ):
+    text_encoder = load_model(path.join(model, "text_encoder", ONNX_MODEL))
     tokenizer = CLIPTokenizer.from_pretrained(
         model,
         subfolder="tokenizer",
         torch_dtype=torch_dtype,
     )
 
-    components = {}
-    components["tokenizer"] = tokenizer
-
-    text_encoder = load_model(path.join(model, "text_encoder", ONNX_MODEL))
-    text_encoder_2 = None
+    components = {
+        "tokenizer": tokenizer,
+    }
 
     if params.is_xl():
         text_encoder_2 = load_model(path.join(model, "text_encoder_2", ONNX_MODEL))
+        tokenizer_2 = CLIPTokenizer.from_pretrained(
+            model,
+            subfolder="tokenizer_2",
+            torch_dtype=torch_dtype,
+        )
+        components["tokenizer_2"] = tokenizer_2
 
     # blend embeddings, if any
     if embeddings is not None and len(embeddings) > 0:
@@ -339,6 +346,23 @@ def load_text_encoders(
                 )
             ),
         )
+        components["tokenizer"] = tokenizer
+
+        if params.is_xl():
+            text_encoder_2, tokenizer_2 = blend_textual_inversions(
+                server,
+                text_encoder_2,
+                tokenizer_2,
+                list(
+                    zip(
+                        embedding_models,
+                        embedding_weights,
+                        embedding_names,
+                        [None] * len(embedding_models),
+                    )
+                ),
+            )
+            components["tokenizer_2"] = tokenizer_2
 
     # blend LoRAs, if any
     if loras is not None and len(loras) > 0:
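
The apparent intent of the tokenizer wiring above: `from_pretrained()` builds the pipeline with the stock tokenizers from disk, so textual-inversion tokens blended into a tokenizer by `blend_textual_inversions()` would be lost unless the blended instances are re-attached afterwards. A minimal sketch of that contract, using stand-in objects (`attach_tokenizers` and the string placeholders are illustrative, not part of the codebase):

```python
from types import SimpleNamespace


def attach_tokenizers(pipe, components):
    # mirrors the new block in load_pipeline(): blended tokenizers in the
    # components dict replace whatever from_pretrained() loaded
    if "tokenizer" in components:
        pipe.tokenizer = components["tokenizer"]
    if "tokenizer_2" in components:
        pipe.tokenizer_2 = components["tokenizer_2"]
    return pipe


# stand-ins for the CLIPTokenizer instances returned by load_text_encoders();
# for SDXL models (params.is_xl()), both entries are present
components = {"tokenizer": "blended tokenizer", "tokenizer_2": "blended tokenizer_2"}
pipe = attach_tokenizers(
    SimpleNamespace(tokenizer="stock", tokenizer_2="stock"), components
)
assert pipe.tokenizer == "blended tokenizer"
assert pipe.tokenizer_2 == "blended tokenizer_2"
```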