fix(api): only run SDXL LoRA node matching on XL models
This commit is contained in:
parent
5d0d904463
commit
ea9023c2eb
|
@ -166,6 +166,7 @@ def blend_loras(
|
|||
loras: List[Tuple[str, float]],
|
||||
model_type: Literal["text_encoder", "unet"],
|
||||
model_index: Optional[int] = None,
|
||||
xl: Optional[bool] = False,
|
||||
):
|
||||
# always load to CPU for blending
|
||||
device = torch.device("cpu")
|
||||
|
@ -394,14 +395,14 @@ def blend_loras(
|
|||
blended[base_key] = np_weights
|
||||
|
||||
# rewrite node names for XL
|
||||
nodes = list(base_model.graph.node)
|
||||
blended = fix_xl_names(blended, nodes)
|
||||
if xl:
|
||||
nodes = list(base_model.graph.node)
|
||||
blended = fix_xl_names(blended, nodes)
|
||||
|
||||
logger.trace(
|
||||
"updating %s of %s initializers, %s missed",
|
||||
"updating %s of %s initializers",
|
||||
len(blended.keys()),
|
||||
len(base_model.graph.initializer),
|
||||
len(nodes),
|
||||
)
|
||||
|
||||
fixed_initializer_names = [
|
||||
|
|
|
@ -247,6 +247,7 @@ def load_pipeline(
|
|||
list(zip(lora_models, lora_weights)),
|
||||
"text_encoder",
|
||||
1 if params.is_xl() else None,
|
||||
params.is_xl(),
|
||||
)
|
||||
(text_encoder, text_encoder_data) = buffer_external_data_tensors(
|
||||
text_encoder
|
||||
|
@ -284,6 +285,7 @@ def load_pipeline(
|
|||
list(zip(lora_models, lora_weights)),
|
||||
"text_encoder",
|
||||
2,
|
||||
params.is_xl()
|
||||
)
|
||||
(text_encoder2, text_encoder2_data) = buffer_external_data_tensors(
|
||||
text_encoder2
|
||||
|
@ -311,6 +313,7 @@ def load_pipeline(
|
|||
unet,
|
||||
list(zip(lora_models, lora_weights)),
|
||||
"unet",
|
||||
xl=params.is_xl(),
|
||||
)
|
||||
(unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
|
||||
unet_names, unet_values = zip(*unet_data)
|
||||
|
|
Loading…
Reference in New Issue