fix(api): blend embeddings into second tokenizer/text encoder for SDXL
This commit is contained in:
parent
e338fcd0e0
commit
fc02fa6be1
|
@ -201,11 +201,6 @@ def load_pipeline(
|
|||
vae_components = load_vae(server, device, model, params)
|
||||
components.update(vae_components)
|
||||
|
||||
# additional options for panorama pipeline
|
||||
if params.is_panorama():
|
||||
components["window"] = params.tiles // 8
|
||||
components["stride"] = params.stride // 8
|
||||
|
||||
pipeline_class = available_pipelines.get(pipeline, OnnxStableDiffusionPipeline)
|
||||
logger.debug("loading pretrained SD pipeline for %s", pipeline_class.__name__)
|
||||
pipe = pipeline_class.from_pretrained(
|
||||
|
@ -228,8 +223,15 @@ def load_pipeline(
|
|||
components["text_encoder_2_session"], pipe
|
||||
)
|
||||
|
||||
if "tokenizer" in components:
|
||||
pipe.tokenizer = components["tokenizer"]
|
||||
|
||||
if "tokenizer_2" in components:
|
||||
pipe.tokenizer_2 = components["tokenizer_2"]
|
||||
|
||||
if "unet_session" in components:
|
||||
# unload old UNet
|
||||
logger.debug("unloading previous Unet")
|
||||
pipe.unet = None
|
||||
run_gc([device])
|
||||
|
||||
|
@ -300,20 +302,25 @@ def load_text_encoders(
|
|||
torch_dtype,
|
||||
params: ImageParams,
|
||||
):
|
||||
text_encoder = load_model(path.join(model, "text_encoder", ONNX_MODEL))
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
model,
|
||||
subfolder="tokenizer",
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
|
||||
components = {}
|
||||
components["tokenizer"] = tokenizer
|
||||
|
||||
text_encoder = load_model(path.join(model, "text_encoder", ONNX_MODEL))
|
||||
text_encoder_2 = None
|
||||
components = {
|
||||
"tokenizer": tokenizer,
|
||||
}
|
||||
|
||||
if params.is_xl():
|
||||
text_encoder_2 = load_model(path.join(model, "text_encoder_2", ONNX_MODEL))
|
||||
tokenizer_2 = CLIPTokenizer.from_pretrained(
|
||||
model,
|
||||
subfolder="tokenizer_2",
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
components["tokenizer_2"] = tokenizer_2
|
||||
|
||||
# blend embeddings, if any
|
||||
if embeddings is not None and len(embeddings) > 0:
|
||||
|
@ -339,6 +346,23 @@ def load_text_encoders(
|
|||
)
|
||||
),
|
||||
)
|
||||
components["tokenizer"] = tokenizer
|
||||
|
||||
if params.is_xl():
|
||||
text_encoder_2, tokenizer_2 = blend_textual_inversions(
|
||||
server,
|
||||
text_encoder_2,
|
||||
tokenizer_2,
|
||||
list(
|
||||
zip(
|
||||
embedding_models,
|
||||
embedding_weights,
|
||||
embedding_names,
|
||||
[None] * len(embedding_models),
|
||||
)
|
||||
),
|
||||
)
|
||||
components["tokenizer_2"] = tokenizer_2
|
||||
|
||||
# blend LoRAs, if any
|
||||
if loras is not None and len(loras) > 0:
|
||||
|
|
Loading…
Reference in New Issue