1
0
Fork 0

use conversion dest path when applying additional nets

This commit is contained in:
Sean Sube 2023-03-18 11:34:05 -05:00
parent 1f6105a8fe
commit f465120cad
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 7 additions and 14 deletions

View File

@@ -248,13 +248,13 @@ def convert_models(ctx: ConversionContext, args, models: Models):
         converted = False
         if model_format in model_formats_original:
-            converted, _dest = convert_diffusion_original(
+            converted, dest = convert_diffusion_original(
                 ctx,
                 model,
                 source,
             )
         else:
-            converted, _dest = convert_diffusion_diffusers(
+            converted, dest = convert_diffusion_diffusers(
                 ctx,
                 model,
                 source,
@@ -272,8 +272,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
                 if "text_encoder" not in blend_models:
                     blend_models["text_encoder"] = load_model(
                         path.join(
-                            ctx.model_path,
-                            model,
+                            dest,
                             "text_encoder",
                             "model.onnx",
                         )
@@ -283,7 +282,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
                     blend_models[
                         "tokenizer"
                     ] = CLIPTokenizer.from_pretrained(
-                        path.join(ctx.model_path, model),
+                        dest,
                         subfolder="tokenizer",
                     )
@@ -292,7 +291,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
                 inversion_format = inversion.get("format", None)
                 inversion_source = fetch_model(
                     ctx,
-                    f"{name}-inversion-{inversion_name}",
+                    inversion_name,
                     inversion_source,
                     dest=inversion_dest,
                 )
@@ -317,8 +316,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
                 if "text_encoder" not in blend_models:
                     blend_models["text_encoder"] = load_model(
                         path.join(
-                            ctx.model_path,
-                            model,
+                            dest,
                             "text_encoder",
                             "model.onnx",
                         )
@@ -326,9 +324,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
                 if "unet" not in blend_models:
                     blend_models["text_encoder"] = load_model(
-                        path.join(
-                            ctx.model_path, model, "unet", "model.onnx"
-                        )
+                        path.join(dest, "unet", "model.onnx")
                     )
         # load models if not loaded yet

View File

@@ -62,10 +62,7 @@ def blend_loras(
     model_type: Literal["text_encoder", "unet"],
 ):
     base_model = base_name if isinstance(base_name, ModelProto) else load(base_name)
-    lora_count = len(loras)
     lora_models = [load_file(name) for name, _weight in loras]
-    lora_weights = lora_weights or (np.ones((lora_count)) / lora_count)
 
     if model_type == "text_encoder":
         lora_prefix = "lora_te_"