1
0
Fork 0

fix(api): download additional networks to their own subdir in models

This commit is contained in:
Sean Sube 2023-03-18 07:07:05 -05:00
parent 84bd852837
commit 32b2a76a0b
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 23 additions and 10 deletions

View File

@ -138,12 +138,16 @@ base_models: Models = {
def fetch_model(
ctx: ConversionContext, name: str, source: str, model_format: Optional[str] = None
ctx: ConversionContext,
name: str,
source: str,
dest: Optional[str] = None,
format: Optional[str] = None,
) -> str:
cache_path = path.join(ctx.cache_path, name)
cache_path = path.join(dest or ctx.cache_path, name)
# add an extension if possible, some of the conversion code checks for it
if model_format is None:
if format is None:
url = urlparse(source)
ext = path.basename(url.path)
_filename, ext = path.splitext(ext)
@ -152,7 +156,7 @@ def fetch_model(
else:
cache_name = cache_path
else:
cache_name = f"{cache_path}.{model_format}"
cache_name = f"{cache_path}.{format}"
if path.exists(cache_name):
logger.debug("model already exists in cache, skipping fetch")
@ -199,7 +203,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
source = model["source"]
try:
dest = fetch_model(ctx, name, source, model_format=model_format)
dest = fetch_model(ctx, name, source, format=model_format)
logger.info("finished downloading source: %s -> %s", source, dest)
except Exception:
logger.exception("error fetching source %s", name)
@ -216,7 +220,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
try:
source = fetch_model(
ctx, name, model["source"], model_format=model_format
ctx, name, model["source"], format=model_format
)
if model_format in model_formats_original:
@ -235,6 +239,9 @@ def convert_models(ctx: ConversionContext, args, models: Models):
# keep track of which models have been blended
blend_models = {}
inversion_dest = path.join(ctx.model_path, "inversion")
lora_dest = path.join(ctx.model_path, "lora")
for inversion in model.get("inversions", []):
if "text_encoder" not in blend_models:
blend_models["text_encoder"] = load_model(path.join(ctx.model_path, model, "text_encoder", "model.onnx"))
@ -246,7 +253,10 @@ def convert_models(ctx: ConversionContext, args, models: Models):
inversion_source = inversion["source"]
inversion_format = inversion.get("format", "embeddings")
inversion_source = fetch_model(
ctx, f"{name}-inversion-{inversion_name}", inversion_source
ctx,
f"{name}-inversion-{inversion_name}",
inversion_source,
dest=inversion_dest,
)
inversion_token = inversion.get("token", inversion_name)
inversion_weight = inversion.get("weight", 1.0)
@ -272,7 +282,10 @@ def convert_models(ctx: ConversionContext, args, models: Models):
lora_name = lora["name"]
lora_source = lora["source"]
lora_source = fetch_model(
ctx, f"{name}-lora-{lora_name}", lora_source
ctx,
f"{name}-lora-{lora_name}",
lora_source
dest=lora_dest,
)
lora_weight = lora.get("weight", 1.0)
@ -321,7 +334,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
try:
source = fetch_model(
ctx, name, model["source"], model_format=model_format
ctx, name, model["source"], format=model_format
)
convert_upscale_resrgan(ctx, model, source)
except Exception:
@ -341,7 +354,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
model_format = source_format(model)
try:
source = fetch_model(
ctx, name, model["source"], model_format=model_format
ctx, name, model["source"], format=model_format
)
convert_correction_gfpgan(ctx, model, source)
except Exception: