fix(api): only run SDXL LoRA node matching on XL models
This commit is contained in:
parent
5d0d904463
commit
ea9023c2eb
|
@ -166,6 +166,7 @@ def blend_loras(
|
||||||
loras: List[Tuple[str, float]],
|
loras: List[Tuple[str, float]],
|
||||||
model_type: Literal["text_encoder", "unet"],
|
model_type: Literal["text_encoder", "unet"],
|
||||||
model_index: Optional[int] = None,
|
model_index: Optional[int] = None,
|
||||||
|
xl: Optional[bool] = False,
|
||||||
):
|
):
|
||||||
# always load to CPU for blending
|
# always load to CPU for blending
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
|
@ -394,14 +395,14 @@ def blend_loras(
|
||||||
blended[base_key] = np_weights
|
blended[base_key] = np_weights
|
||||||
|
|
||||||
# rewrite node names for XL
|
# rewrite node names for XL
|
||||||
|
if xl:
|
||||||
nodes = list(base_model.graph.node)
|
nodes = list(base_model.graph.node)
|
||||||
blended = fix_xl_names(blended, nodes)
|
blended = fix_xl_names(blended, nodes)
|
||||||
|
|
||||||
logger.trace(
|
logger.trace(
|
||||||
"updating %s of %s initializers, %s missed",
|
"updating %s of %s initializers",
|
||||||
len(blended.keys()),
|
len(blended.keys()),
|
||||||
len(base_model.graph.initializer),
|
len(base_model.graph.initializer),
|
||||||
len(nodes),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
fixed_initializer_names = [
|
fixed_initializer_names = [
|
||||||
|
|
|
@ -247,6 +247,7 @@ def load_pipeline(
|
||||||
list(zip(lora_models, lora_weights)),
|
list(zip(lora_models, lora_weights)),
|
||||||
"text_encoder",
|
"text_encoder",
|
||||||
1 if params.is_xl() else None,
|
1 if params.is_xl() else None,
|
||||||
|
params.is_xl(),
|
||||||
)
|
)
|
||||||
(text_encoder, text_encoder_data) = buffer_external_data_tensors(
|
(text_encoder, text_encoder_data) = buffer_external_data_tensors(
|
||||||
text_encoder
|
text_encoder
|
||||||
|
@ -284,6 +285,7 @@ def load_pipeline(
|
||||||
list(zip(lora_models, lora_weights)),
|
list(zip(lora_models, lora_weights)),
|
||||||
"text_encoder",
|
"text_encoder",
|
||||||
2,
|
2,
|
||||||
|
params.is_xl()
|
||||||
)
|
)
|
||||||
(text_encoder2, text_encoder2_data) = buffer_external_data_tensors(
|
(text_encoder2, text_encoder2_data) = buffer_external_data_tensors(
|
||||||
text_encoder2
|
text_encoder2
|
||||||
|
@ -311,6 +313,7 @@ def load_pipeline(
|
||||||
unet,
|
unet,
|
||||||
list(zip(lora_models, lora_weights)),
|
list(zip(lora_models, lora_weights)),
|
||||||
"unet",
|
"unet",
|
||||||
|
xl=params.is_xl(),
|
||||||
)
|
)
|
||||||
(unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
|
(unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
|
||||||
unet_names, unet_values = zip(*unet_data)
|
unet_names, unet_values = zip(*unet_data)
|
||||||
|
|
Loading…
Reference in New Issue