1
0
Fork 0

apply lint

This commit is contained in:
Sean Sube 2023-03-18 07:41:29 -05:00
parent e3bf04ab8f
commit af62c1c3b6
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 40 additions and 9 deletions

View File

@ -220,7 +220,13 @@ def convert_models(ctx: ConversionContext, args, models: Models):
source = network["source"] source = network["source"]
try: try:
dest = fetch_model(ctx, name, source, dest=path.join(ctx.model_path, network_type), format=network_format) dest = fetch_model(
ctx,
name,
source,
dest=path.join(ctx.model_path, network_type),
format=network_format,
)
logger.info("finished downloading network: %s -> %s", source, dest) logger.info("finished downloading network: %s -> %s", source, dest)
except Exception: except Exception:
logger.exception("error fetching network %s", name) logger.exception("error fetching network %s", name)
@ -264,10 +270,22 @@ def convert_models(ctx: ConversionContext, args, models: Models):
for inversion in model.get("inversions", []): for inversion in model.get("inversions", []):
if "text_encoder" not in blend_models: if "text_encoder" not in blend_models:
blend_models["text_encoder"] = load_model(path.join(ctx.model_path, model, "text_encoder", "model.onnx")) blend_models["text_encoder"] = load_model(
path.join(
ctx.model_path,
model,
"text_encoder",
"model.onnx",
)
)
if "tokenizer" not in blend_models: if "tokenizer" not in blend_models:
blend_models["tokenizer"] = CLIPTokenizer.from_pretrained(path.join(ctx.model_path, model), subfolder="tokenizer") blend_models[
"tokenizer"
] = CLIPTokenizer.from_pretrained(
path.join(ctx.model_path, model),
subfolder="tokenizer",
)
inversion_name = inversion["name"] inversion_name = inversion["name"]
inversion_source = inversion["source"] inversion_source = inversion["source"]
@ -293,10 +311,21 @@ def convert_models(ctx: ConversionContext, args, models: Models):
for lora in model.get("loras", []): for lora in model.get("loras", []):
if "text_encoder" not in blend_models: if "text_encoder" not in blend_models:
blend_models["text_encoder"] = load_model(path.join(ctx.model_path, model, "text_encoder", "model.onnx")) blend_models["text_encoder"] = load_model(
path.join(
ctx.model_path,
model,
"text_encoder",
"model.onnx",
)
)
if "unet" not in blend_models: if "unet" not in blend_models:
blend_models["text_encoder"] = load_model(path.join(ctx.model_path, model, "unet", "model.onnx")) blend_models["text_encoder"] = load_model(
path.join(
ctx.model_path, model, "unet", "model.onnx"
)
)
# load models if not loaded yet # load models if not loaded yet
lora_name = lora["name"] lora_name = lora["name"]
@ -325,8 +354,12 @@ def convert_models(ctx: ConversionContext, args, models: Models):
for name in ["text_encoder", "unet"]: for name in ["text_encoder", "unet"]:
if name in blend_models: if name in blend_models:
dest_path = path.join(ctx.model_path, model, name, "model.onnx") dest_path = path.join(
logger.debug("saving blended %s model to %s", name, dest_path) ctx.model_path, model, name, "model.onnx"
)
logger.debug(
"saving blended %s model to %s", name, dest_path
)
save_model( save_model(
blend_models[name], blend_models[name],
dest_path, dest_path,
@ -335,7 +368,6 @@ def convert_models(ctx: ConversionContext, args, models: Models):
location="weights.pb", location="weights.pb",
) )
except Exception: except Exception:
logger.exception( logger.exception(
"error converting diffusion model %s", "error converting diffusion model %s",

View File

@ -1706,4 +1706,3 @@ def convert_diffusion_original(
logger.info("ONNX pipeline saved to %s", name) logger.info("ONNX pipeline saved to %s", name)
return result return result