1
0
Fork 0

fix(api): only run SDXL LoRA node matching on XL models

This commit is contained in:
Sean Sube 2023-08-29 19:03:57 -05:00
parent 5d0d904463
commit ea9023c2eb
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 8 additions and 4 deletions

View File

@ -166,6 +166,7 @@ def blend_loras(
loras: List[Tuple[str, float]], loras: List[Tuple[str, float]],
model_type: Literal["text_encoder", "unet"], model_type: Literal["text_encoder", "unet"],
model_index: Optional[int] = None, model_index: Optional[int] = None,
xl: Optional[bool] = False,
): ):
# always load to CPU for blending # always load to CPU for blending
device = torch.device("cpu") device = torch.device("cpu")
@ -394,14 +395,14 @@ def blend_loras(
blended[base_key] = np_weights blended[base_key] = np_weights
# rewrite node names for XL # rewrite node names for XL
if xl:
nodes = list(base_model.graph.node) nodes = list(base_model.graph.node)
blended = fix_xl_names(blended, nodes) blended = fix_xl_names(blended, nodes)
logger.trace( logger.trace(
"updating %s of %s initializers, %s missed", "updating %s of %s initializers",
len(blended.keys()), len(blended.keys()),
len(base_model.graph.initializer), len(base_model.graph.initializer),
len(nodes),
) )
fixed_initializer_names = [ fixed_initializer_names = [

View File

@ -247,6 +247,7 @@ def load_pipeline(
list(zip(lora_models, lora_weights)), list(zip(lora_models, lora_weights)),
"text_encoder", "text_encoder",
1 if params.is_xl() else None, 1 if params.is_xl() else None,
params.is_xl(),
) )
(text_encoder, text_encoder_data) = buffer_external_data_tensors( (text_encoder, text_encoder_data) = buffer_external_data_tensors(
text_encoder text_encoder
@ -284,6 +285,7 @@ def load_pipeline(
list(zip(lora_models, lora_weights)), list(zip(lora_models, lora_weights)),
"text_encoder", "text_encoder",
2, 2,
params.is_xl()
) )
(text_encoder2, text_encoder2_data) = buffer_external_data_tensors( (text_encoder2, text_encoder2_data) = buffer_external_data_tensors(
text_encoder2 text_encoder2
@ -311,6 +313,7 @@ def load_pipeline(
unet, unet,
list(zip(lora_models, lora_weights)), list(zip(lora_models, lora_weights)),
"unet", "unet",
xl=params.is_xl(),
) )
(unet_model, unet_data) = buffer_external_data_tensors(blended_unet) (unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
unet_names, unet_values = zip(*unet_data) unet_names, unet_values = zip(*unet_data)