From ae3bcf3b8ba7de8f58c367bf7dda19f972409fb0 Mon Sep 17 00:00:00 2001
From: Sean Sube
Date: Sun, 19 Mar 2023 20:32:21 -0500
Subject: [PATCH] fix(api): add HF hub download to fetch logic for Inversion
 concepts

---
 api/onnx_web/convert/__main__.py | 31 ++++++++++++++++++++++++-------
 1 file changed, 24 insertions(+), 7 deletions(-)

diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py
index 5ec93076..6b63b05a 100644
--- a/api/onnx_web/convert/__main__.py
+++ b/api/onnx_web/convert/__main__.py
@@ -144,8 +144,11 @@ def fetch_model(
     source: str,
     dest: Optional[str] = None,
     format: Optional[str] = None,
+    hf_hub_fetch: bool = False,
+    hf_hub_filename: Optional[str] = None,
 ) -> str:
-    cache_path = path.join(dest or ctx.cache_path, name)
+    cache_path = dest or ctx.cache_path
+    cache_name = path.join(cache_path, name)
 
     # add an extension if possible, some of the conversion code checks for it
     if format is None:
@@ -153,11 +156,11 @@ def fetch_model(
         ext = path.basename(url.path)
         _filename, ext = path.splitext(ext)
         if ext is not None:
-            cache_name = cache_path + ext
+            cache_name = cache_name + ext
         else:
-            cache_name = cache_path
+            cache_name = cache_name
     else:
-        cache_name = f"{cache_path}.{format}"
+        cache_name = f"{cache_name}.{format}"
 
     if path.exists(cache_name):
         logger.debug("model already exists in cache, skipping fetch")
@@ -176,7 +179,15 @@ def fetch_model(
         hub_source = remove_prefix(source, model_source_huggingface)
         logger.info("downloading model from Huggingface Hub: %s", hub_source)
         # from_pretrained has a bunch of useful logic that snapshot_download by itself down not
-        return hub_source
+        if hf_hub_fetch:
+            return hf_hub_download(
+                repo_id=source,
+                filename=hf_hub_filename,
+                cache_dir=cache_path,
+                force_filename=f"{name}.bin",
+            )
+        else:
+            return hub_source
     elif source.startswith("https://"):
         logger.info("downloading model from: %s", source)
         return download_progress([(source, cache_name)])
@@ -223,8 +234,14 @@ def convert_models(ctx: ConversionContext, args, models: Models):
 
             try:
                 if network_type == "inversion" and network_model == "concept":
-                    dest = hf_hub_download(
-                        repo_id=source, filename="learned_embeds.bin"
+                    dest = fetch_model(
+                        ctx,
+                        name,
+                        source,
+                        dest=path.join(ctx.model_path, network_type),
+                        format=network_format,
+                        hf_hub_fetch=True,
+                        hf_hub_filename="learned_embeds.bin",
                     )
                 else:
                     dest = fetch_model(
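
For reference, a minimal standalone sketch (not part of the patch) of the Hub download behavior that fetch_model now performs when hf_hub_fetch=True: a single file is pulled from a Hugging Face Hub repository into the cache directory under a fixed name. The repository id, network name, and cache path below are placeholder assumptions, not values taken from this patch.

    # Sketch of the hf_hub_download call that fetch_model makes when hf_hub_fetch=True.
    # All concrete values here are placeholders.
    from huggingface_hub import hf_hub_download

    name = "example-concept"                        # placeholder network name
    source = "sd-concepts-library/example-concept"  # placeholder Hub repo id
    cache_path = "/tmp/onnx-web/cache"              # stands in for dest or ctx.cache_path

    dest = hf_hub_download(
        repo_id=source,
        filename="learned_embeds.bin",  # the Textual Inversion embedding file
        cache_dir=cache_path,
        force_filename=f"{name}.bin",   # mirrors the patch; force_filename may be deprecated in newer huggingface_hub releases
    )
    print(dest)  # local path to the downloaded embedding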