apply lint
This commit is contained in:
parent
e3bf04ab8f
commit
af62c1c3b6
|
@ -220,7 +220,13 @@ def convert_models(ctx: ConversionContext, args, models: Models):
|
||||||
source = network["source"]
|
source = network["source"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
dest = fetch_model(ctx, name, source, dest=path.join(ctx.model_path, network_type), format=network_format)
|
dest = fetch_model(
|
||||||
|
ctx,
|
||||||
|
name,
|
||||||
|
source,
|
||||||
|
dest=path.join(ctx.model_path, network_type),
|
||||||
|
format=network_format,
|
||||||
|
)
|
||||||
logger.info("finished downloading network: %s -> %s", source, dest)
|
logger.info("finished downloading network: %s -> %s", source, dest)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("error fetching network %s", name)
|
logger.exception("error fetching network %s", name)
|
||||||
|
@ -264,10 +270,22 @@ def convert_models(ctx: ConversionContext, args, models: Models):
|
||||||
|
|
||||||
for inversion in model.get("inversions", []):
|
for inversion in model.get("inversions", []):
|
||||||
if "text_encoder" not in blend_models:
|
if "text_encoder" not in blend_models:
|
||||||
blend_models["text_encoder"] = load_model(path.join(ctx.model_path, model, "text_encoder", "model.onnx"))
|
blend_models["text_encoder"] = load_model(
|
||||||
|
path.join(
|
||||||
|
ctx.model_path,
|
||||||
|
model,
|
||||||
|
"text_encoder",
|
||||||
|
"model.onnx",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if "tokenizer" not in blend_models:
|
if "tokenizer" not in blend_models:
|
||||||
blend_models["tokenizer"] = CLIPTokenizer.from_pretrained(path.join(ctx.model_path, model), subfolder="tokenizer")
|
blend_models[
|
||||||
|
"tokenizer"
|
||||||
|
] = CLIPTokenizer.from_pretrained(
|
||||||
|
path.join(ctx.model_path, model),
|
||||||
|
subfolder="tokenizer",
|
||||||
|
)
|
||||||
|
|
||||||
inversion_name = inversion["name"]
|
inversion_name = inversion["name"]
|
||||||
inversion_source = inversion["source"]
|
inversion_source = inversion["source"]
|
||||||
|
@ -293,10 +311,21 @@ def convert_models(ctx: ConversionContext, args, models: Models):
|
||||||
|
|
||||||
for lora in model.get("loras", []):
|
for lora in model.get("loras", []):
|
||||||
if "text_encoder" not in blend_models:
|
if "text_encoder" not in blend_models:
|
||||||
blend_models["text_encoder"] = load_model(path.join(ctx.model_path, model, "text_encoder", "model.onnx"))
|
blend_models["text_encoder"] = load_model(
|
||||||
|
path.join(
|
||||||
|
ctx.model_path,
|
||||||
|
model,
|
||||||
|
"text_encoder",
|
||||||
|
"model.onnx",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if "unet" not in blend_models:
|
if "unet" not in blend_models:
|
||||||
blend_models["text_encoder"] = load_model(path.join(ctx.model_path, model, "unet", "model.onnx"))
|
blend_models["text_encoder"] = load_model(
|
||||||
|
path.join(
|
||||||
|
ctx.model_path, model, "unet", "model.onnx"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# load models if not loaded yet
|
# load models if not loaded yet
|
||||||
lora_name = lora["name"]
|
lora_name = lora["name"]
|
||||||
|
@ -325,8 +354,12 @@ def convert_models(ctx: ConversionContext, args, models: Models):
|
||||||
|
|
||||||
for name in ["text_encoder", "unet"]:
|
for name in ["text_encoder", "unet"]:
|
||||||
if name in blend_models:
|
if name in blend_models:
|
||||||
dest_path = path.join(ctx.model_path, model, name, "model.onnx")
|
dest_path = path.join(
|
||||||
logger.debug("saving blended %s model to %s", name, dest_path)
|
ctx.model_path, model, name, "model.onnx"
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
"saving blended %s model to %s", name, dest_path
|
||||||
|
)
|
||||||
save_model(
|
save_model(
|
||||||
blend_models[name],
|
blend_models[name],
|
||||||
dest_path,
|
dest_path,
|
||||||
|
@ -335,7 +368,6 @@ def convert_models(ctx: ConversionContext, args, models: Models):
|
||||||
location="weights.pb",
|
location="weights.pb",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"error converting diffusion model %s",
|
"error converting diffusion model %s",
|
||||||
|
|
|
@ -1706,4 +1706,3 @@ def convert_diffusion_original(
|
||||||
|
|
||||||
logger.info("ONNX pipeline saved to %s", name)
|
logger.info("ONNX pipeline saved to %s", name)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue