diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py
index 2d5577a1..453566d9 100644
--- a/api/onnx_web/convert/__main__.py
+++ b/api/onnx_web/convert/__main__.py
@@ -207,7 +207,7 @@ def fetch_model(
     format: Optional[str] = None,
     hf_hub_fetch: bool = False,
     hf_hub_filename: Optional[str] = None,
-) -> str:
+) -> Tuple[str, bool]:
     cache_path = dest or conversion.cache_path
     cache_name = path.join(cache_path, name)

@@ -223,7 +223,7 @@ def fetch_model(

     if path.exists(cache_name):
         logger.debug("model already exists in cache, skipping fetch")
-        return cache_name
+        return cache_name, False

     for proto in model_sources:
         api_name, api_root = model_sources.get(proto)
@@ -232,33 +232,36 @@ def fetch_model(
             logger.info(
                 "downloading model from %s: %s -> %s", api_name, api_source, cache_name
             )
-            return download_progress([(api_source, cache_name)])
+            return download_progress([(api_source, cache_name)]), False

     if source.startswith(model_source_huggingface):
         hub_source = remove_prefix(source, model_source_huggingface)
         logger.info("downloading model from Huggingface Hub: %s", hub_source)
         # from_pretrained has a bunch of useful logic that snapshot_download by itself down not
         if hf_hub_fetch:
-            return hf_hub_download(
-                repo_id=hub_source,
-                filename=hf_hub_filename,
-                cache_dir=cache_path,
-                force_filename=f"{name}.bin",
+            return (
+                hf_hub_download(
+                    repo_id=hub_source,
+                    filename=hf_hub_filename,
+                    cache_dir=cache_path,
+                    force_filename=f"{name}.bin",
+                ),
+                False,
             )
         else:
-            return hub_source
+            return hub_source, True
     elif source.startswith("https://"):
         logger.info("downloading model from: %s", source)
-        return download_progress([(source, cache_name)])
+        return download_progress([(source, cache_name)]), False
     elif source.startswith("http://"):
         logger.warning("downloading model from insecure source: %s", source)
-        return download_progress([(source, cache_name)])
+        return download_progress([(source, cache_name)]), False
     elif source.startswith(path.sep) or source.startswith("."):
         logger.info("using local model: %s", source)
-        return source
+        return source, False
     else:
         logger.info("unknown model location, using path as provided: %s", source)
-        return source
+        return source, False


 def convert_models(conversion: ConversionContext, args, models: Models):
@@ -280,7 +283,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
             if "dest" in model:
                 dest_path = path.join(conversion.model_path, model["dest"])

-            dest = fetch_model(
+            dest, hf = fetch_model(
                 conversion, name, source, format=model_format, dest=dest_path
             )
             logger.info("finished downloading source: %s -> %s", source, dest)
@@ -302,7 +305,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):

             try:
                 if network_type == "control":
-                    dest = fetch_model(
+                    dest, hf = fetch_model(
                         conversion,
                         name,
                         source,
@@ -315,7 +318,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
                         dest,
                     )
                 if network_type == "inversion" and network_model == "concept":
-                    dest = fetch_model(
+                    dest, hf = fetch_model(
                         conversion,
                         name,
                         source,
@@ -325,7 +328,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
                         hf_hub_filename="learned_embeds.bin",
                     )
                 else:
-                    dest = fetch_model(
+                    dest, hf = fetch_model(
                         conversion,
                         name,
                         source,
@@ -349,7 +352,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
             model_format = source_format(model)

             try:
-                source = fetch_model(
+                source, hf = fetch_model(
                     conversion, name, model["source"], format=model_format
                 )

@@ -358,6 +361,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
                         model,
                         source,
                         model_format,
+                        hf=hf,
                     )

                     # make sure blending only happens once, not every run
@@ -389,7 +393,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
                         inversion_name = inversion["name"]
                         inversion_source = inversion["source"]
                         inversion_format = inversion.get("format", None)
-                        inversion_source = fetch_model(
+                        inversion_source, hf = fetch_model(
                             conversion,
                             inversion_name,
                             inversion_source,
@@ -430,7 +434,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
                         # load models if not loaded yet
                         lora_name = lora["name"]
                         lora_source = lora["source"]
-                        lora_source = fetch_model(
+                        lora_source, hf = fetch_model(
                             conversion,
                             f"{name}-lora-{lora_name}",
                             lora_source,
@@ -489,7 +493,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
             model_format = source_format(model)

             try:
-                source = fetch_model(
+                source, hf = fetch_model(
                     conversion, name, model["source"], format=model_format
                 )
                 model_type = model.get("model", "resrgan")
@@ -521,7 +525,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
         else:
             model_format = source_format(model)
             try:
-                source = fetch_model(
+                source, hf = fetch_model(
                     conversion, name, model["source"], format=model_format
                 )
                 model_type = model.get("model", "gfpgan")
diff --git a/api/onnx_web/convert/diffusion/diffusers.py b/api/onnx_web/convert/diffusion/diffusers.py
index f27b04cc..6d085c6e 100644
--- a/api/onnx_web/convert/diffusion/diffusers.py
+++ b/api/onnx_web/convert/diffusion/diffusers.py
@@ -247,6 +247,7 @@ def convert_diffusion_diffusers(
     model: Dict,
     source: str,
     format: str,
+    hf: bool = False,
 ) -> Tuple[bool, str]:
     """
     From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
@@ -316,6 +317,13 @@ def convert_diffusion_diffusers(
             pipeline_class=pipe_class,
             **pipe_args,
         ).to(device, torch_dtype=dtype)
+    elif hf:
+        logger.debug("downloading pretrained model from Huggingface hub: %s", source)
+        pipeline = pipe_class.from_pretrained(
+            source,
+            torch_dtype=dtype,
+            use_auth_token=conversion.token,
+        ).to(device)
     else:
         logger.warning("pipeline source not found or not recognized: %s", source)
         raise ValueError(f"pipeline source not found or not recognized: {source}")
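
A minimal sketch of how a call site consumes the new return shape, assuming only the fetch_model signature shown in the diff above; convert_one and its arguments are hypothetical, and the ConversionContext instance must be supplied by the caller.

    from typing import Dict, Tuple

    from onnx_web.convert.__main__ import fetch_model


    def convert_one(conversion, model: Dict) -> Tuple[str, bool]:
        # fetch_model now returns (path_or_repo, hf); hf is True only when the
        # source is a Huggingface repo left for from_pretrained to resolve
        source, hf = fetch_model(
            conversion, model["name"], model["source"], format=model.get("format")
        )
        # the flag is then forwarded as convert_diffusion_diffusers(..., hf=hf) so
        # the converter downloads the repo instead of treating it as a local path
        return source, hf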