1
0
Fork 0

fix(api): only run SDXL LoRA node matching on XL models

This commit is contained in:
Sean Sube 2023-08-29 19:03:57 -05:00
parent 5d0d904463
commit ea9023c2eb
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 8 additions and 4 deletions

View File

@ -166,6 +166,7 @@ def blend_loras(
loras: List[Tuple[str, float]],
model_type: Literal["text_encoder", "unet"],
model_index: Optional[int] = None,
xl: Optional[bool] = False,
):
# always load to CPU for blending
device = torch.device("cpu")
@ -394,14 +395,14 @@ def blend_loras(
blended[base_key] = np_weights
# rewrite node names for XL
nodes = list(base_model.graph.node)
blended = fix_xl_names(blended, nodes)
if xl:
nodes = list(base_model.graph.node)
blended = fix_xl_names(blended, nodes)
logger.trace(
"updating %s of %s initializers, %s missed",
"updating %s of %s initializers",
len(blended.keys()),
len(base_model.graph.initializer),
len(nodes),
)
fixed_initializer_names = [

View File

@ -247,6 +247,7 @@ def load_pipeline(
list(zip(lora_models, lora_weights)),
"text_encoder",
1 if params.is_xl() else None,
params.is_xl(),
)
(text_encoder, text_encoder_data) = buffer_external_data_tensors(
text_encoder
@ -284,6 +285,7 @@ def load_pipeline(
list(zip(lora_models, lora_weights)),
"text_encoder",
2,
params.is_xl()
)
(text_encoder2, text_encoder2_data) = buffer_external_data_tensors(
text_encoder2
@ -311,6 +313,7 @@ def load_pipeline(
unet,
list(zip(lora_models, lora_weights)),
"unet",
xl=params.is_xl(),
)
(unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
unet_names, unet_values = zip(*unet_data)