diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py index ad55051c..9793b1ca 100644 --- a/api/onnx_web/convert/__main__.py +++ b/api/onnx_web/convert/__main__.py @@ -138,12 +138,16 @@ base_models: Models = { def fetch_model( - ctx: ConversionContext, name: str, source: str, model_format: Optional[str] = None + ctx: ConversionContext, + name: str, + source: str, + dest: Optional[str] = None, + format: Optional[str] = None, ) -> str: - cache_path = path.join(ctx.cache_path, name) + cache_path = path.join(dest or ctx.cache_path, name) # add an extension if possible, some of the conversion code checks for it - if model_format is None: + if format is None: url = urlparse(source) ext = path.basename(url.path) _filename, ext = path.splitext(ext) @@ -152,7 +156,7 @@ def fetch_model( else: cache_name = cache_path else: - cache_name = f"{cache_path}.{model_format}" + cache_name = f"{cache_path}.{format}" if path.exists(cache_name): logger.debug("model already exists in cache, skipping fetch") @@ -199,7 +203,7 @@ def convert_models(ctx: ConversionContext, args, models: Models): source = model["source"] try: - dest = fetch_model(ctx, name, source, model_format=model_format) + dest = fetch_model(ctx, name, source, format=model_format) logger.info("finished downloading source: %s -> %s", source, dest) except Exception: logger.exception("error fetching source %s", name) @@ -216,7 +220,7 @@ def convert_models(ctx: ConversionContext, args, models: Models): try: source = fetch_model( - ctx, name, model["source"], model_format=model_format + ctx, name, model["source"], format=model_format ) if model_format in model_formats_original: @@ -235,6 +239,9 @@ def convert_models(ctx: ConversionContext, args, models: Models): # keep track of which models have been blended blend_models = {} + inversion_dest = path.join(ctx.model_path, "inversion") + lora_dest = path.join(ctx.model_path, "lora") + for inversion in model.get("inversions", []): if "text_encoder" not in blend_models: blend_models["text_encoder"] = load_model(path.join(ctx.model_path, model, "text_encoder", "model.onnx")) @@ -246,7 +253,10 @@ def convert_models(ctx: ConversionContext, args, models: Models): inversion_source = inversion["source"] inversion_format = inversion.get("format", "embeddings") inversion_source = fetch_model( - ctx, f"{name}-inversion-{inversion_name}", inversion_source + ctx, + f"{name}-inversion-{inversion_name}", + inversion_source, + dest=inversion_dest, ) inversion_token = inversion.get("token", inversion_name) inversion_weight = inversion.get("weight", 1.0) @@ -272,7 +282,10 @@ def convert_models(ctx: ConversionContext, args, models: Models): lora_name = lora["name"] lora_source = lora["source"] lora_source = fetch_model( - ctx, f"{name}-lora-{lora_name}", lora_source + ctx, + f"{name}-lora-{lora_name}", + lora_source + dest=lora_dest, ) lora_weight = lora.get("weight", 1.0) @@ -321,7 +334,7 @@ def convert_models(ctx: ConversionContext, args, models: Models): try: source = fetch_model( - ctx, name, model["source"], model_format=model_format + ctx, name, model["source"], format=model_format ) convert_upscale_resrgan(ctx, model, source) except Exception: @@ -341,7 +354,7 @@ def convert_models(ctx: ConversionContext, args, models: Models): model_format = source_format(model) try: source = fetch_model( - ctx, name, model["source"], model_format=model_format + ctx, name, model["source"], format=model_format ) convert_correction_gfpgan(ctx, model, source) except Exception: