diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py index e05a46b9..618ea6a5 100644 --- a/api/onnx_web/convert/__main__.py +++ b/api/onnx_web/convert/__main__.py @@ -202,7 +202,7 @@ base_models: Models = { } -def convert_source_model(conversion: ConversionContext, model): +def convert_model_source(conversion: ConversionContext, model): model_format = source_format(model) name = model["name"] source = model["source"] @@ -215,9 +215,10 @@ def convert_source_model(conversion: ConversionContext, model): logger.info("finished downloading source: %s -> %s", source, dest) -def convert_network_model(conversion: ConversionContext, network): - network_format = source_format(network) +def convert_model_network(conversion: ConversionContext, network): + format = source_format(network) name = network["name"] + model = network["model"] network_type = network["type"] source = network["source"] @@ -226,7 +227,7 @@ def convert_network_model(conversion: ConversionContext, network): conversion, name, source, - format=network_format, + format=format, ) convert_diffusion_control( @@ -241,13 +242,14 @@ def convert_network_model(conversion: ConversionContext, network): name, source, dest=path.join(conversion.model_path, network_type), - format=network_format, + format=format, + embeds=(network_type == "inversion" and model == "concept"), ) logger.info("finished downloading network: %s -> %s", source, dest) -def convert_diffusion_model(conversion: ConversionContext, model): +def convert_model_diffusion(conversion: ConversionContext, model): # fix up entries with missing prefixes name = fix_diffusion_name(model["name"]) if name != model["name"]: @@ -372,7 +374,7 @@ def convert_diffusion_model(conversion: ConversionContext, model): ) -def convert_upscaling_model(conversion: ConversionContext, model): +def convert_model_upscaling(conversion: ConversionContext, model): model_format = source_format(model) name = model["name"] @@ -389,7 +391,7 @@ def convert_upscaling_model(conversion: ConversionContext, model): raise ValueError(name) -def convert_correction_model(conversion: ConversionContext, model): +def convert_model_correction(conversion: ConversionContext, model): model_format = source_format(model) name = model["name"] source = fetch_model(conversion, name, model["source"], format=model_format) @@ -413,7 +415,7 @@ def convert_models(conversion: ConversionContext, args, models: Models): logger.info("skipping source: %s", name) else: try: - convert_source_model(model) + convert_model_source(model) except Exception: logger.exception("error fetching source %s", name) model_errors.append(name) @@ -426,7 +428,7 @@ def convert_models(conversion: ConversionContext, args, models: Models): logger.info("skipping network: %s", name) else: try: - convert_network_model(conversion, model) + convert_model_network(conversion, model) except Exception: logger.exception("error fetching network %s", name) model_errors.append(name) @@ -440,7 +442,7 @@ def convert_models(conversion: ConversionContext, args, models: Models): logger.info("skipping model: %s", name) else: try: - convert_diffusion_model(conversion, model) + convert_model_diffusion(conversion, model) except Exception: logger.exception( "error converting diffusion model %s", @@ -457,7 +459,7 @@ def convert_models(conversion: ConversionContext, args, models: Models): logger.info("skipping model: %s", name) else: try: - convert_upscaling_model(conversion, model) + convert_model_upscaling(conversion, model) except Exception: logger.exception( "error converting upscaling model %s", @@ -474,7 +476,7 @@ def convert_models(conversion: ConversionContext, args, models: Models): logger.info("skipping model: %s", name) else: try: - convert_correction_model(conversion, model) + convert_model_correction(conversion, model) except Exception: logger.exception( "error converting correction model %s", diff --git a/api/onnx_web/convert/client/base.py b/api/onnx_web/convert/client/base.py index 0e91dab7..2ec2ae65 100644 --- a/api/onnx_web/convert/client/base.py +++ b/api/onnx_web/convert/client/base.py @@ -11,5 +11,6 @@ class BaseClient: source: str, format: Optional[str] = None, dest: Optional[str] = None, + **kwargs, ) -> str: raise NotImplementedError() diff --git a/api/onnx_web/convert/client/civitai.py b/api/onnx_web/convert/client/civitai.py index e56b49e5..706e5c34 100644 --- a/api/onnx_web/convert/client/civitai.py +++ b/api/onnx_web/convert/client/civitai.py @@ -31,6 +31,7 @@ class CivitaiClient(BaseClient): source: str, format: Optional[str] = None, dest: Optional[str] = None, + **kwargs, ) -> str: """ TODO: download with auth token diff --git a/api/onnx_web/convert/client/file.py b/api/onnx_web/convert/client/file.py index 142e503e..5619b35b 100644 --- a/api/onnx_web/convert/client/file.py +++ b/api/onnx_web/convert/client/file.py @@ -24,6 +24,7 @@ class FileClient(BaseClient): uri: str, format: Optional[str] = None, dest: Optional[str] = None, + **kwargs, ) -> str: parts = urlparse(uri) logger.info("loading model from: %s", parts.path) diff --git a/api/onnx_web/convert/client/http.py b/api/onnx_web/convert/client/http.py index d994dabe..9d9f840c 100644 --- a/api/onnx_web/convert/client/http.py +++ b/api/onnx_web/convert/client/http.py @@ -30,6 +30,7 @@ class HttpClient(BaseClient): uri: str, format: Optional[str] = None, dest: Optional[str] = None, + **kwargs, ) -> str: cache_paths = build_cache_paths( conversion, @@ -46,6 +47,7 @@ class HttpClient(BaseClient): source = remove_prefix(uri, HttpClient.protocol) logger.info("downloading model from: %s", source) elif uri.startswith(HttpClient.insecure_protocol): + source = remove_prefix(uri, HttpClient.insecure_protocol) logger.warning("downloading model from insecure source: %s", source) return download_progress(source, cache_paths[0]) diff --git a/api/onnx_web/convert/client/huggingface.py b/api/onnx_web/convert/client/huggingface.py index 6d1064fd..6db97928 100644 --- a/api/onnx_web/convert/client/huggingface.py +++ b/api/onnx_web/convert/client/huggingface.py @@ -1,14 +1,10 @@ from logging import getLogger -from typing import Any, Optional +from typing import Optional +from huggingface_hub import snapshot_download from huggingface_hub.file_download import hf_hub_download -from ..utils import ( - ConversionContext, - build_cache_paths, - get_first_exists, - remove_prefix, -) +from ..utils import ConversionContext, remove_prefix from .base import BaseClient logger = getLogger(__name__) @@ -18,11 +14,9 @@ class HuggingfaceClient(BaseClient): name = "huggingface" protocol = "huggingface://" - download: Any token: Optional[str] - def __init__(self, token: Optional[str] = None, download=hf_hub_download): - self.download = download + def __init__(self, token: Optional[str] = None): self.token = token def download( @@ -32,39 +26,22 @@ class HuggingfaceClient(BaseClient): source: str, format: Optional[str] = None, dest: Optional[str] = None, + embeds: bool = False, + **kwargs, ) -> str: - """ - TODO: download with auth - TODO: set fetch and filename - if network_type == "inversion" and network_model == "concept": - """ - hf_hub_fetch = True - hf_hub_filename = "learned_embeds.bin" - - cache_paths = build_cache_paths( - conversion, - name, - client=HuggingfaceClient.name, - format=format, - dest=dest, - ) - cached = get_first_exists(cache_paths) - if cached: - return cached - source = remove_prefix(source, HuggingfaceClient.protocol) logger.info("downloading model from Huggingface Hub: %s", source) - if hf_hub_fetch: - return ( - hf_hub_download( - repo_id=source, - filename=hf_hub_filename, - cache_dir=cache_paths[0], - force_filename=f"{name}.bin", - ), - False, + if embeds: + return hf_hub_download( + repo_id=source, + filename="learned_embeds.bin", + cache_dir=conversion.cache_path, + force_filename=f"{name}.bin", + token=self.token, ) else: - # TODO: download pretrained because load doesn't call from_pretrained anymore - return source + return snapshot_download( + repo_id=source, + token=self.token, + )