1
0
Fork 0

fix(api): blend embeddings into second tokenizer/text encoder for SDXL

This commit is contained in:
Sean Sube 2023-09-25 18:24:16 -05:00
parent e338fcd0e0
commit fc02fa6be1
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 34 additions and 10 deletions

View File

@ -201,11 +201,6 @@ def load_pipeline(
vae_components = load_vae(server, device, model, params)
components.update(vae_components)
# additional options for panorama pipeline
if params.is_panorama():
components["window"] = params.tiles // 8
components["stride"] = params.stride // 8
pipeline_class = available_pipelines.get(pipeline, OnnxStableDiffusionPipeline)
logger.debug("loading pretrained SD pipeline for %s", pipeline_class.__name__)
pipe = pipeline_class.from_pretrained(
@ -228,8 +223,15 @@ def load_pipeline(
components["text_encoder_2_session"], pipe
)
if "tokenizer" in components:
pipe.tokenizer = components["tokenizer"]
if "tokenizer_2" in components:
pipe.tokenizer_2 = components["tokenizer_2"]
if "unet_session" in components:
# unload old UNet
logger.debug("unloading previous Unet")
pipe.unet = None
run_gc([device])
@ -300,20 +302,25 @@ def load_text_encoders(
torch_dtype,
params: ImageParams,
):
text_encoder = load_model(path.join(model, "text_encoder", ONNX_MODEL))
tokenizer = CLIPTokenizer.from_pretrained(
model,
subfolder="tokenizer",
torch_dtype=torch_dtype,
)
components = {}
components["tokenizer"] = tokenizer
text_encoder = load_model(path.join(model, "text_encoder", ONNX_MODEL))
text_encoder_2 = None
components = {
"tokenizer": tokenizer,
}
if params.is_xl():
text_encoder_2 = load_model(path.join(model, "text_encoder_2", ONNX_MODEL))
tokenizer_2 = CLIPTokenizer.from_pretrained(
model,
subfolder="tokenizer_2",
torch_dtype=torch_dtype,
)
components["tokenizer_2"] = tokenizer_2
# blend embeddings, if any
if embeddings is not None and len(embeddings) > 0:
@ -339,6 +346,23 @@ def load_text_encoders(
)
),
)
components["tokenizer"] = tokenizer
if params.is_xl():
text_encoder_2, tokenizer_2 = blend_textual_inversions(
server,
text_encoder_2,
tokenizer_2,
list(
zip(
embedding_models,
embedding_weights,
embedding_names,
[None] * len(embedding_models),
)
),
)
components["tokenizer_2"] = tokenizer_2
# blend LoRAs, if any
if loras is not None and len(loras) > 0: