fix(api): blend embeddings into second tokenizer/text encoder for SDXL
This commit is contained in:
parent
e338fcd0e0
commit
fc02fa6be1
|
@ -201,11 +201,6 @@ def load_pipeline(
|
||||||
vae_components = load_vae(server, device, model, params)
|
vae_components = load_vae(server, device, model, params)
|
||||||
components.update(vae_components)
|
components.update(vae_components)
|
||||||
|
|
||||||
# additional options for panorama pipeline
|
|
||||||
if params.is_panorama():
|
|
||||||
components["window"] = params.tiles // 8
|
|
||||||
components["stride"] = params.stride // 8
|
|
||||||
|
|
||||||
pipeline_class = available_pipelines.get(pipeline, OnnxStableDiffusionPipeline)
|
pipeline_class = available_pipelines.get(pipeline, OnnxStableDiffusionPipeline)
|
||||||
logger.debug("loading pretrained SD pipeline for %s", pipeline_class.__name__)
|
logger.debug("loading pretrained SD pipeline for %s", pipeline_class.__name__)
|
||||||
pipe = pipeline_class.from_pretrained(
|
pipe = pipeline_class.from_pretrained(
|
||||||
|
@ -228,8 +223,15 @@ def load_pipeline(
|
||||||
components["text_encoder_2_session"], pipe
|
components["text_encoder_2_session"], pipe
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if "tokenizer" in components:
|
||||||
|
pipe.tokenizer = components["tokenizer"]
|
||||||
|
|
||||||
|
if "tokenizer_2" in components:
|
||||||
|
pipe.tokenizer_2 = components["tokenizer_2"]
|
||||||
|
|
||||||
if "unet_session" in components:
|
if "unet_session" in components:
|
||||||
# unload old UNet
|
# unload old UNet
|
||||||
|
logger.debug("unloading previous Unet")
|
||||||
pipe.unet = None
|
pipe.unet = None
|
||||||
run_gc([device])
|
run_gc([device])
|
||||||
|
|
||||||
|
@ -300,20 +302,25 @@ def load_text_encoders(
|
||||||
torch_dtype,
|
torch_dtype,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
):
|
):
|
||||||
|
text_encoder = load_model(path.join(model, "text_encoder", ONNX_MODEL))
|
||||||
tokenizer = CLIPTokenizer.from_pretrained(
|
tokenizer = CLIPTokenizer.from_pretrained(
|
||||||
model,
|
model,
|
||||||
subfolder="tokenizer",
|
subfolder="tokenizer",
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
components = {}
|
components = {
|
||||||
components["tokenizer"] = tokenizer
|
"tokenizer": tokenizer,
|
||||||
|
}
|
||||||
text_encoder = load_model(path.join(model, "text_encoder", ONNX_MODEL))
|
|
||||||
text_encoder_2 = None
|
|
||||||
|
|
||||||
if params.is_xl():
|
if params.is_xl():
|
||||||
text_encoder_2 = load_model(path.join(model, "text_encoder_2", ONNX_MODEL))
|
text_encoder_2 = load_model(path.join(model, "text_encoder_2", ONNX_MODEL))
|
||||||
|
tokenizer_2 = CLIPTokenizer.from_pretrained(
|
||||||
|
model,
|
||||||
|
subfolder="tokenizer_2",
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
|
)
|
||||||
|
components["tokenizer_2"] = tokenizer_2
|
||||||
|
|
||||||
# blend embeddings, if any
|
# blend embeddings, if any
|
||||||
if embeddings is not None and len(embeddings) > 0:
|
if embeddings is not None and len(embeddings) > 0:
|
||||||
|
@ -339,6 +346,23 @@ def load_text_encoders(
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
components["tokenizer"] = tokenizer
|
||||||
|
|
||||||
|
if params.is_xl():
|
||||||
|
text_encoder_2, tokenizer_2 = blend_textual_inversions(
|
||||||
|
server,
|
||||||
|
text_encoder_2,
|
||||||
|
tokenizer_2,
|
||||||
|
list(
|
||||||
|
zip(
|
||||||
|
embedding_models,
|
||||||
|
embedding_weights,
|
||||||
|
embedding_names,
|
||||||
|
[None] * len(embedding_models),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
components["tokenizer_2"] = tokenizer_2
|
||||||
|
|
||||||
# blend LoRAs, if any
|
# blend LoRAs, if any
|
||||||
if loras is not None and len(loras) > 0:
|
if loras is not None and len(loras) > 0:
|
||||||
|
|
Loading…
Reference in New Issue