diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py index 262bbf44..3bb66b3c 100644 --- a/api/onnx_web/convert/__main__.py +++ b/api/onnx_web/convert/__main__.py @@ -3,16 +3,21 @@ from argparse import ArgumentParser from logging import getLogger from os import makedirs, path from sys import exit -from typing import Any, Dict, List, Optional, Tuple -from urllib.parse import urlparse +from typing import Any, Dict, List, Union -from huggingface_hub.file_download import hf_hub_download from jsonschema import ValidationError, validate from onnx import load_model, save_model from transformers import CLIPTokenizer +from onnx_web.server.plugin import load_plugins + from ..constants import ONNX_MODEL, ONNX_WEIGHTS from ..utils import load_config +from .client.base import BaseClient +from .client.civitai import CivitaiClient +from .client.file import FileClient +from .client.http import HttpClient +from .client.huggingface import HuggingfaceClient from .correction.gfpgan import convert_correction_gfpgan from .diffusion.control import convert_diffusion_control from .diffusion.diffusion import convert_diffusion_diffusers @@ -25,9 +30,8 @@ from .upscaling.swinir import convert_upscaling_swinir from .utils import ( DEFAULT_OPSET, ConversionContext, - download_progress, + fetch_model, fix_diffusion_name, - remove_prefix, source_format, tuple_to_correction, tuple_to_diffusion, @@ -45,16 +49,27 @@ warnings.filterwarnings( ".*Converting a tensor to a Python boolean might cause the trace to be incorrect.*", ) -Models = Dict[str, List[Any]] - logger = getLogger(__name__) +ModelDict = Dict[str, Union[float, int, str]] +Models = Dict[str, List[ModelDict]] -model_sources: Dict[str, Tuple[str, str]] = { - "civitai://": ("Civitai", "https://civitai.com/api/download/models/%s"), + +model_sources: Dict[str, BaseClient] = { + CivitaiClient.protocol: CivitaiClient, + FileClient.protocol: FileClient, + HttpClient.insecure_protocol: HttpClient, + HttpClient.protocol: HttpClient, + HuggingfaceClient.protocol: HuggingfaceClient, } -model_source_huggingface = "huggingface://" +model_converters: Dict[str, Any] = { + "img2img": convert_diffusion_diffusers, + "img2img-sdxl": convert_diffusion_diffusers_xl, + "inpaint": convert_diffusion_diffusers, + "txt2img": convert_diffusion_diffusers, + "txt2img-sdxl": convert_diffusion_diffusers_xl, +} # recommended models base_models: Models = { @@ -62,15 +77,15 @@ base_models: Models = { # v1.x ( "stable-diffusion-onnx-v1-5", - model_source_huggingface + "runwayml/stable-diffusion-v1-5", + HuggingfaceClient.protocol + "runwayml/stable-diffusion-v1-5", ), ( "stable-diffusion-onnx-v1-inpainting", - model_source_huggingface + "runwayml/stable-diffusion-inpainting", + HuggingfaceClient.protocol + "runwayml/stable-diffusion-inpainting", ), ( "upscaling-stable-diffusion-x4", - model_source_huggingface + "stabilityai/stable-diffusion-x4-upscaler", + HuggingfaceClient.protocol + "stabilityai/stable-diffusion-x4-upscaler", True, ), ], @@ -201,69 +216,226 @@ base_models: Models = { } -def fetch_model( - conversion: ConversionContext, - name: str, - source: str, - dest: Optional[str] = None, - format: Optional[str] = None, - hf_hub_fetch: bool = False, - hf_hub_filename: Optional[str] = None, -) -> Tuple[str, bool]: - cache_path = dest or conversion.cache_path - cache_name = path.join(cache_path, name) +def add_model_source(proto: str, client: BaseClient): + global model_sources - # add an extension if possible, some of the conversion code checks for it - if format is 
None: - url = urlparse(source) - ext = path.basename(url.path) - _filename, ext = path.splitext(ext) - if ext is not None: - cache_name = cache_name + ext + if proto in model_sources: + raise ValueError("protocol has already been taken") + + model_sources[proto] = client + + +def convert_source_model(conversion: ConversionContext, model): + model_format = source_format(model) + name = model["name"] + source = model["source"] + + dest_path = None + if "dest" in model: + dest_path = path.join(conversion.model_path, model["dest"]) + + dest, hf = fetch_model( + conversion, name, source, format=model_format, dest=dest_path + ) + logger.info("finished downloading source: %s -> %s", source, dest) + + +def convert_network_model(conversion: ConversionContext, network): + network_format = source_format(network) + network_model = network.get("model", None) + name = network["name"] + network_type = network["type"] + source = network["source"] + + if network_type == "control": + dest, hf = fetch_model( + conversion, + name, + source, + format=network_format, + ) + + convert_diffusion_control( + conversion, + network, + dest, + path.join(conversion.model_path, network_type, name), + ) + if network_type == "inversion" and network_model == "concept": + dest, hf = fetch_model( + conversion, + name, + source, + dest=path.join(conversion.model_path, network_type), + format=network_format, + hf_hub_fetch=True, + hf_hub_filename="learned_embeds.bin", + ) else: - cache_name = f"{cache_name}.{format}" + dest, hf = fetch_model( + conversion, + name, + source, + dest=path.join(conversion.model_path, network_type), + format=network_format, + ) - if path.exists(cache_name): - logger.debug("model already exists in cache, skipping fetch") - return cache_name, False + logger.info("finished downloading network: %s -> %s", source, dest) - for proto in model_sources: - api_name, api_root = model_sources.get(proto) - if source.startswith(proto): - api_source = api_root % (remove_prefix(source, proto)) - logger.info( - "downloading model from %s: %s -> %s", api_name, api_source, cache_name + +def convert_diffusion_model(conversion: ConversionContext, model): + # fix up entries with missing prefixes + name = fix_diffusion_name(model["name"]) + if name != model["name"]: + # update the model in-memory if the name changed + model["name"] = name + + model_format = source_format(model) + source, hf = fetch_model(conversion, name, model["source"], format=model_format) + + pipeline = model.get("pipeline", "txt2img") + converter = model_converters.get(pipeline) + converted, dest = converter( + conversion, + model, + source, + model_format, + hf=hf, + ) + + # make sure blending only happens once, not every run + if converted: + # keep track of which models have been blended + blend_models = {} + + inversion_dest = path.join(conversion.model_path, "inversion") + lora_dest = path.join(conversion.model_path, "lora") + + for inversion in model.get("inversions", []): + if "text_encoder" not in blend_models: + blend_models["text_encoder"] = load_model( + path.join( + dest, + "text_encoder", + ONNX_MODEL, + ) + ) + + if "tokenizer" not in blend_models: + blend_models["tokenizer"] = CLIPTokenizer.from_pretrained( + dest, + subfolder="tokenizer", + ) + + inversion_name = inversion["name"] + inversion_source = inversion["source"] + inversion_format = inversion.get("format", None) + inversion_source, hf = fetch_model( + conversion, + inversion_name, + inversion_source, + dest=inversion_dest, ) - return download_progress([(api_source, 
cache_name)]), False + inversion_token = inversion.get("token", inversion_name) + inversion_weight = inversion.get("weight", 1.0) - if source.startswith(model_source_huggingface): - hub_source = remove_prefix(source, model_source_huggingface) - logger.info("downloading model from Huggingface Hub: %s", hub_source) - # from_pretrained has a bunch of useful logic that snapshot_download by itself down not - if hf_hub_fetch: - return ( - hf_hub_download( - repo_id=hub_source, - filename=hf_hub_filename, - cache_dir=cache_path, - force_filename=f"{name}.bin", - ), - False, + blend_textual_inversions( + conversion, + blend_models["text_encoder"], + blend_models["tokenizer"], + [ + ( + inversion_source, + inversion_weight, + inversion_token, + inversion_format, + ) + ], ) - else: - return hub_source, True - elif source.startswith("https://"): - logger.info("downloading model from: %s", source) - return download_progress([(source, cache_name)]), False - elif source.startswith("http://"): - logger.warning("downloading model from insecure source: %s", source) - return download_progress([(source, cache_name)]), False - elif source.startswith(path.sep) or source.startswith("."): - logger.info("using local model: %s", source) - return source, False + + for lora in model.get("loras", []): + if "text_encoder" not in blend_models: + blend_models["text_encoder"] = load_model( + path.join( + dest, + "text_encoder", + ONNX_MODEL, + ) + ) + + if "unet" not in blend_models: + blend_models["unet"] = load_model(path.join(dest, "unet", ONNX_MODEL)) + + # load models if not loaded yet + lora_name = lora["name"] + lora_source = lora["source"] + lora_source, hf = fetch_model( + conversion, + f"{name}-lora-{lora_name}", + lora_source, + dest=lora_dest, + ) + lora_weight = lora.get("weight", 1.0) + + blend_loras( + conversion, + blend_models["text_encoder"], + [(lora_source, lora_weight)], + "text_encoder", + ) + + blend_loras( + conversion, + blend_models["unet"], + [(lora_source, lora_weight)], + "unet", + ) + + if "tokenizer" in blend_models: + dest_path = path.join(dest, "tokenizer") + logger.debug("saving blended tokenizer to %s", dest_path) + blend_models["tokenizer"].save_pretrained(dest_path) + + for name in ["text_encoder", "unet"]: + if name in blend_models: + dest_path = path.join(dest, name, ONNX_MODEL) + logger.debug("saving blended %s model to %s", name, dest_path) + save_model( + blend_models[name], + dest_path, + save_as_external_data=True, + all_tensors_to_one_file=True, + location=ONNX_WEIGHTS, + ) + + +def convert_upscaling_model(conversion: ConversionContext, model): + model_format = source_format(model) + name = model["name"] + + source, hf = fetch_model(conversion, name, model["source"], format=model_format) + model_type = model.get("model", "resrgan") + if model_type == "bsrgan": + convert_upscaling_bsrgan(conversion, model, source) + elif model_type == "resrgan": + convert_upscale_resrgan(conversion, model, source) + elif model_type == "swinir": + convert_upscaling_swinir(conversion, model, source) else: - logger.info("unknown model location, using path as provided: %s", source) - return source, False + logger.error("unknown upscaling model type %s for %s", model_type, name) + model_errors.append(name) + + +def convert_correction_model(conversion: ConversionContext, model): + model_format = source_format(model) + name = model["name"] + source, hf = fetch_model(conversion, name, model["source"], format=model_format) + model_type = model.get("model", "gfpgan") + if model_type == "gfpgan": + 
convert_correction_gfpgan(conversion, model, source) + else: + logger.error("unknown correction model type %s for %s", model_type, name) + model_errors.append(name) def convert_models(conversion: ConversionContext, args, models: Models): @@ -277,18 +449,8 @@ def convert_models(conversion: ConversionContext, args, models: Models): if name in args.skip: logger.info("skipping source: %s", name) else: - model_format = source_format(model) - source = model["source"] - try: - dest_path = None - if "dest" in model: - dest_path = path.join(conversion.model_path, model["dest"]) - - dest, hf = fetch_model( - conversion, name, source, format=model_format, dest=dest_path - ) - logger.info("finished downloading source: %s -> %s", source, dest) + convert_source_model(model) except Exception: logger.exception("error fetching source %s", name) model_errors.append(name) @@ -300,46 +462,8 @@ def convert_models(conversion: ConversionContext, args, models: Models): if name in args.skip: logger.info("skipping network: %s", name) else: - network_format = source_format(network) - network_model = network.get("model", None) - network_type = network["type"] - source = network["source"] - try: - if network_type == "control": - dest, hf = fetch_model( - conversion, - name, - source, - format=network_format, - ) - - convert_diffusion_control( - conversion, - network, - dest, - path.join(conversion.model_path, network_type, name), - ) - if network_type == "inversion" and network_model == "concept": - dest, hf = fetch_model( - conversion, - name, - source, - dest=path.join(conversion.model_path, network_type), - format=network_format, - hf_hub_fetch=True, - hf_hub_filename="learned_embeds.bin", - ) - else: - dest, hf = fetch_model( - conversion, - name, - source, - dest=path.join(conversion.model_path, network_type), - format=network_format, - ) - - logger.info("finished downloading network: %s -> %s", source, dest) + convert_network_model(conversion, model) except Exception: logger.exception("error fetching network %s", name) model_errors.append(name) @@ -352,148 +476,8 @@ def convert_models(conversion: ConversionContext, args, models: Models): if name in args.skip: logger.info("skipping model: %s", name) else: - # fix up entries with missing prefixes - name = fix_diffusion_name(name) - if name != model["name"]: - # update the model in-memory if the name changed - model["name"] = name - - model_format = source_format(model) - try: - source, hf = fetch_model( - conversion, name, model["source"], format=model_format - ) - - pipeline = model.get("pipeline", "txt2img") - if pipeline.endswith("-sdxl"): - converted, dest = convert_diffusion_diffusers_xl( - conversion, - model, - source, - model_format, - hf=hf, - ) - else: - converted, dest = convert_diffusion_diffusers( - conversion, - model, - source, - model_format, - hf=hf, - ) - - # make sure blending only happens once, not every run - if converted: - # keep track of which models have been blended - blend_models = {} - - inversion_dest = path.join(conversion.model_path, "inversion") - lora_dest = path.join(conversion.model_path, "lora") - - for inversion in model.get("inversions", []): - if "text_encoder" not in blend_models: - blend_models["text_encoder"] = load_model( - path.join( - dest, - "text_encoder", - ONNX_MODEL, - ) - ) - - if "tokenizer" not in blend_models: - blend_models[ - "tokenizer" - ] = CLIPTokenizer.from_pretrained( - dest, - subfolder="tokenizer", - ) - - inversion_name = inversion["name"] - inversion_source = inversion["source"] - inversion_format = 
inversion.get("format", None) - inversion_source, hf = fetch_model( - conversion, - inversion_name, - inversion_source, - dest=inversion_dest, - ) - inversion_token = inversion.get("token", inversion_name) - inversion_weight = inversion.get("weight", 1.0) - - blend_textual_inversions( - conversion, - blend_models["text_encoder"], - blend_models["tokenizer"], - [ - ( - inversion_source, - inversion_weight, - inversion_token, - inversion_format, - ) - ], - ) - - for lora in model.get("loras", []): - if "text_encoder" not in blend_models: - blend_models["text_encoder"] = load_model( - path.join( - dest, - "text_encoder", - ONNX_MODEL, - ) - ) - - if "unet" not in blend_models: - blend_models["unet"] = load_model( - path.join(dest, "unet", ONNX_MODEL) - ) - - # load models if not loaded yet - lora_name = lora["name"] - lora_source = lora["source"] - lora_source, hf = fetch_model( - conversion, - f"{name}-lora-{lora_name}", - lora_source, - dest=lora_dest, - ) - lora_weight = lora.get("weight", 1.0) - - blend_loras( - conversion, - blend_models["text_encoder"], - [(lora_source, lora_weight)], - "text_encoder", - ) - - blend_loras( - conversion, - blend_models["unet"], - [(lora_source, lora_weight)], - "unet", - ) - - if "tokenizer" in blend_models: - dest_path = path.join(dest, "tokenizer") - logger.debug("saving blended tokenizer to %s", dest_path) - blend_models["tokenizer"].save_pretrained(dest_path) - - for name in ["text_encoder", "unet"]: - if name in blend_models: - dest_path = path.join(dest, name, ONNX_MODEL) - logger.debug( - "saving blended %s model to %s", name, dest_path - ) - save_model( - blend_models[name], - dest_path, - save_as_external_data=True, - all_tensors_to_one_file=True, - location=ONNX_WEIGHTS, - ) - + convert_diffusion_model(conversion, model) except Exception: logger.exception( "error converting diffusion model %s", @@ -509,24 +493,8 @@ def convert_models(conversion: ConversionContext, args, models: Models): if name in args.skip: logger.info("skipping model: %s", name) else: - model_format = source_format(model) - try: - source, hf = fetch_model( - conversion, name, model["source"], format=model_format - ) - model_type = model.get("model", "resrgan") - if model_type == "bsrgan": - convert_upscaling_bsrgan(conversion, model, source) - elif model_type == "resrgan": - convert_upscale_resrgan(conversion, model, source) - elif model_type == "swinir": - convert_upscaling_swinir(conversion, model, source) - else: - logger.error( - "unknown upscaling model type %s for %s", model_type, name - ) - model_errors.append(name) + convert_upscaling_model(conversion, model) except Exception: logger.exception( "error converting upscaling model %s", @@ -542,19 +510,8 @@ def convert_models(conversion: ConversionContext, args, models: Models): if name in args.skip: logger.info("skipping model: %s", name) else: - model_format = source_format(model) try: - source, hf = fetch_model( - conversion, name, model["source"], format=model_format - ) - model_type = model.get("model", "gfpgan") - if model_type == "gfpgan": - convert_correction_gfpgan(conversion, model, source) - else: - logger.error( - "unknown correction model type %s for %s", model_type, name - ) - model_errors.append(name) + convert_correction_model(conversion, model) except Exception: logger.exception( "error converting correction model %s", @@ -566,12 +523,26 @@ def convert_models(conversion: ConversionContext, args, models: Models): logger.error("error while converting models: %s", model_errors) +def 
register_plugins(conversion: ConversionContext):
+    logger.info("loading conversion plugins")
+    exports = load_plugins(conversion)
+
+    for proto, client in exports.clients.items():
+        try:
+            add_model_source(proto, client)
+        except Exception:
+            logger.exception("error loading client for protocol: %s", proto)
+
+    # TODO: add converters
+
+
 def main(args=None) -> int:
     parser = ArgumentParser(
         prog="onnx-web model converter", description="convert checkpoint models to ONNX"
     )
 
     # model groups
+    parser.add_argument("--base", action="store_true", default=True)
     parser.add_argument("--networks", action="store_true", default=True)
     parser.add_argument("--sources", action="store_true", default=True)
     parser.add_argument("--correction", action="store_true", default=False)
@@ -609,16 +580,20 @@ def main(args=None) -> int:
     server.half = args.half or server.has_optimization("onnx-fp16")
     server.opset = args.opset
     server.token = args.token
+
+    register_plugins(server)
+
     logger.info(
-        "converting models in %s using %s", server.model_path, server.training_device
+        "converting models into %s using %s", server.model_path, server.training_device
     )
 
     if not path.exists(server.model_path):
         logger.info("model path does not existing, creating: %s", server.model_path)
         makedirs(server.model_path)
 
-    logger.info("converting base models")
-    convert_models(server, args, base_models)
+    if args.base:
+        logger.info("converting base models")
+        convert_models(server, args, base_models)
 
     extras = []
     extras.extend(server.extra_models)
diff --git a/api/onnx_web/convert/client/__init__.py b/api/onnx_web/convert/client/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/api/onnx_web/convert/client/base.py b/api/onnx_web/convert/client/base.py
new file mode 100644
index 00000000..c38ec082
--- /dev/null
+++ b/api/onnx_web/convert/client/base.py
@@ -0,0 +1,13 @@
+from typing import Optional
+from ..utils import ConversionContext
+
+
+class BaseClient:
+    def download(
+        self,
+        conversion: ConversionContext,
+        name: str,
+        source: str,
+        format: Optional[str],
+    ) -> str:
+        raise NotImplementedError()
diff --git a/api/onnx_web/convert/client/civitai.py b/api/onnx_web/convert/client/civitai.py
new file mode 100644
index 00000000..1bc8c4a9
--- /dev/null
+++ b/api/onnx_web/convert/client/civitai.py
@@ -0,0 +1,46 @@
+from ..utils import (
+    ConversionContext,
+    build_cache_paths,
+    download_progress,
+    get_first_exists,
+    remove_prefix,
+)
+from .base import BaseClient
+from typing import Optional
+from logging import getLogger
+
+logger = getLogger(__name__)
+
+CIVITAI_ROOT = "https://civitai.com/api/download/models/%s"
+
+
+class CivitaiClient(BaseClient):
+    name = "civitai"
+    protocol = "civitai://"
+    root: str
+    token: Optional[str]
+
+    def __init__(self, token: Optional[str] = None, root=CIVITAI_ROOT):
+        self.root = root
+        self.token = token
+
+    def download(
+        self,
+        conversion: ConversionContext,
+        name: str,
+        source: str,
+        format: Optional[str],
+    ) -> str:
+        """
+        TODO: download with auth token
+        """
+        cache_paths = build_cache_paths(
+            conversion, name, client=CivitaiClient.name, format=format
+        )
+        cached = get_first_exists(cache_paths)
+        if cached:
+            return cached
+
+        source = self.root % (remove_prefix(source, CivitaiClient.protocol))
+        logger.info("downloading model from Civitai: %s -> %s", source, cache_paths[0])
+        return download_progress(source, cache_paths[0])
diff --git a/api/onnx_web/convert/client/file.py b/api/onnx_web/convert/client/file.py
new file mode 100644
index 00000000..8c44c0d0
--- /dev/null
+++ b/api/onnx_web/convert/client/file.py
@@ -0,0 +1,28 @@
+from .base import BaseClient
+from ..utils import ConversionContext
+from logging import getLogger
+from os import path
+from typing import Optional
+from urllib.parse import urlparse
+
+logger = getLogger(__name__)
+
+
+class FileClient(BaseClient):
+    protocol = "file://"
+
+    root: str
+
+    def __init__(self, root: str = "."):
+        self.root = root
+
+    def download(
+        self,
+        conversion: ConversionContext,
+        name: str,
+        source: str,
+        format: Optional[str],
+    ) -> str:
+        parts = urlparse(source)
+        logger.info("loading model from: %s", parts.path)
+        return path.join(self.root, parts.path)
diff --git a/api/onnx_web/convert/client/http.py b/api/onnx_web/convert/client/http.py
new file mode 100644
index 00000000..5259402f
--- /dev/null
+++ b/api/onnx_web/convert/client/http.py
@@ -0,0 +1,43 @@
+from ..utils import (
+    ConversionContext,
+    build_cache_paths,
+    download_progress,
+    get_first_exists,
+)
+from .base import BaseClient
+from typing import Dict, Optional
+from logging import getLogger
+
+logger = getLogger(__name__)
+
+
+class HttpClient(BaseClient):
+    name = "http"
+    protocol = "https://"
+    insecure_protocol = "http://"
+
+    headers: Dict[str, str]
+
+    def __init__(self, headers: Optional[Dict[str, str]] = None):
+        self.headers = headers or {}
+
+    def download(
+        self,
+        conversion: ConversionContext,
+        name: str,
+        source: str,
+        format: Optional[str],
+    ) -> str:
+        cache_paths = build_cache_paths(
+            conversion, name, client=HttpClient.name, format=format
+        )
+        cached = get_first_exists(cache_paths)
+        if cached:
+            return cached
+
+        if source.startswith(HttpClient.insecure_protocol):
+            logger.warning("downloading model from insecure source: %s", source)
+        else:
+            logger.info("downloading model from: %s", source)
+
+        return download_progress(source, cache_paths[0])
diff --git a/api/onnx_web/convert/client/huggingface.py b/api/onnx_web/convert/client/huggingface.py
new file mode 100644
index 00000000..61796e99
--- /dev/null
+++ b/api/onnx_web/convert/client/huggingface.py
@@ -0,0 +1,59 @@
+from ..utils import (
+    ConversionContext,
+    build_cache_paths,
+    get_first_exists,
+    remove_prefix,
+)
+from .base import BaseClient
+from typing import Optional, Any
+from logging import getLogger
+from huggingface_hub.file_download import hf_hub_download
+
+logger = getLogger(__name__)
+
+
+class HuggingfaceClient(BaseClient):
+    name = "huggingface"
+    protocol = "huggingface://"
+
+    client: Any
+    token: Optional[str]
+
+    def __init__(self, token: Optional[str] = None, download=hf_hub_download):
+        self.client = download
+        self.token = token
+
+    def download(
+        self,
+        conversion: ConversionContext,
+        name: str,
+        source: str,
+        format: Optional[str],
+    ) -> str:
+        """
+        TODO: download with auth
+        """
+        # TODO: fetch single-file models (concept inversions) through hf_hub_download
+        hf_hub_fetch = False
+        hf_hub_filename = None
+
+        cache_paths = build_cache_paths(
+            conversion, name, client=HuggingfaceClient.name, format=format
+        )
+        cached = get_first_exists(cache_paths)
+        if cached:
+            return cached
+
+        source = remove_prefix(source, HuggingfaceClient.protocol)
+        logger.info("downloading model from Huggingface Hub: %s", source)
+
+        # from_pretrained has a bunch of useful logic that snapshot_download by itself does not
+        if hf_hub_fetch:
+            return self.client(
+                repo_id=source,
+                filename=hf_hub_filename,
+                cache_dir=cache_paths[0],
+                force_filename=f"{name}.bin",
+            )
+
+        return source
diff --git a/api/onnx_web/convert/utils.py b/api/onnx_web/convert/utils.py
index e56e05ae..19e6a309 100644
--- a/api/onnx_web/convert/utils.py
+++ b/api/onnx_web/convert/utils.py
@@ -15,6 +15,8 @@ from
onnxruntime.transformers.float16 import convert_float_to_float16 from packaging import version from torch.onnx import export +from onnx_web.convert.client.file import FileClient + from ..constants import ONNX_WEIGHTS from ..errors import RequestException from ..server import ServerContext @@ -85,38 +87,37 @@ class ConversionContext(ServerContext): return torch.device(self.training_device) -def download_progress(urls: List[Tuple[str, str]]): - for url, dest in urls: - dest_path = Path(dest).expanduser().resolve() - dest_path.parent.mkdir(parents=True, exist_ok=True) - - if dest_path.exists(): - logger.debug("destination already exists: %s", dest_path) - return str(dest_path.absolute()) - - req = requests.get( - url, - stream=True, - allow_redirects=True, - headers={ - "User-Agent": "onnx-web-api", - }, - ) - if req.status_code != 200: - req.raise_for_status() # Only works for 4xx errors, per SO answer - raise RequestException( - "request to %s failed with status code: %s" % (url, req.status_code) - ) - - total = int(req.headers.get("Content-Length", 0)) - desc = "unknown" if total == 0 else "" - req.raw.read = partial(req.raw.read, decode_content=True) - with tqdm.wrapattr(req.raw, "read", total=total, desc=desc) as data: - with dest_path.open("wb") as f: - shutil.copyfileobj(data, f) +def download_progress(source: str, dest: str): + dest_path = Path(dest).expanduser().resolve() + dest_path.parent.mkdir(parents=True, exist_ok=True) + if dest_path.exists(): + logger.debug("destination already exists: %s", dest_path) return str(dest_path.absolute()) + req = requests.get( + source, + stream=True, + allow_redirects=True, + headers={ + "User-Agent": "onnx-web-api", + }, + ) + if req.status_code != 200: + req.raise_for_status() # Only works for 4xx errors, per SO answer + raise RequestException( + "request to %s failed with status code: %s" % (source, req.status_code) + ) + + total = int(req.headers.get("Content-Length", 0)) + desc = "unknown" if total == 0 else "" + req.raw.read = partial(req.raw.read, decode_content=True) + with tqdm.wrapattr(req.raw, "read", total=total, desc=desc) as data: + with dest_path.open("wb") as f: + shutil.copyfileobj(data, f) + + return str(dest_path.absolute()) + def tuple_to_source(model: Union[ModelDict, LegacyModel]): if isinstance(model, list) or isinstance(model, tuple): @@ -347,7 +348,13 @@ def onnx_export( ) -DIFFUSION_PREFIX = ["diffusion-", "diffusion/", "diffusion\\", "stable-diffusion-", "upscaling-"] +DIFFUSION_PREFIX = [ + "diffusion-", + "diffusion/", + "diffusion\\", + "stable-diffusion-", + "upscaling-", +] def fix_diffusion_name(name: str): @@ -359,3 +366,61 @@ def fix_diffusion_name(name: str): return f"diffusion-{name}" return name + + +def fetch_model( + conversion: ConversionContext, + name: str, + source: str, + format: Optional[str], +) -> Tuple[str, bool]: + # TODO: switch to urlparse's default scheme + if source.startswith(path.sep) or source.startswith("."): + logger.info("adding file protocol to local path source: %s", source) + source = FileClient.protocol + source + + for proto, client_type in model_sources.items(): + if source.startswith(proto): + client = client_type() + return client.download(proto) + + logger.warning("unknown model protocol, using path as provided: %s", source) + return source, False + + +def build_cache_paths( + conversion: ConversionContext, + name: str, + client: Optional[str] = None, + dest: Optional[str] = None, + format: Optional[str] = None, +) -> List[str]: + cache_path = dest or conversion.cache_path + + # 
add an extension if possible, some of the conversion code checks for it
+    if format is not None:
+        basename = path.basename(name)
+        _filename, ext = path.splitext(basename)
+        if not ext:
+            name = f"{name}.{format}"
+
+    paths = [
+        path.join(cache_path, name),
+    ]
+
+    if client is not None:
+        client_path = path.join(cache_path, client)
+        paths.append(path.join(client_path, name))
+
+    return paths
+
+
+def get_first_exists(
+    paths: List[str],
+) -> Optional[str]:
+    for name in paths:
+        if path.exists(name):
+            logger.debug("model already exists in cache, skipping fetch: %s", name)
+            return name
+
+    return None
diff --git a/api/onnx_web/server/plugin.py b/api/onnx_web/server/plugin.py
index 022047df..2adcabb3 100644
--- a/api/onnx_web/server/plugin.py
+++ b/api/onnx_web/server/plugin.py
@@ -2,18 +2,24 @@ from importlib import import_module
 from logging import getLogger
 from typing import Any, Callable, Dict
 
-from onnx_web.chain.stages import add_stage
-from onnx_web.diffusers.load import add_pipeline
-from onnx_web.server.context import ServerContext
+from ..chain.stages import add_stage
+from ..diffusers.load import add_pipeline
+from ..server.context import ServerContext
 
 logger = getLogger(__name__)
 
 
 class PluginExports:
+    clients: Dict[str, Any]
+    converter: Dict[str, Any]
     pipelines: Dict[str, Any]
     stages: Dict[str, Any]
 
-    def __init__(self, pipelines=None, stages=None) -> None:
+    def __init__(
+        self, clients=None, converter=None, pipelines=None, stages=None
+    ) -> None:
+        self.clients = clients or {}
+        self.converter = converter or {}
         self.pipelines = pipelines or {}
         self.stages = stages or {}
 
diff --git a/api/tests/convert/test_utils.py b/api/tests/convert/test_utils.py
index 4281adbc..34b0bf9b 100644
--- a/api/tests/convert/test_utils.py
+++ b/api/tests/convert/test_utils.py
@@ -27,7 +27,7 @@ class ConversionContextTests(unittest.TestCase):
 
 class DownloadProgressTests(unittest.TestCase):
     def test_download_example(self):
-        path = download_progress([("https://example.com", "/tmp/example-dot-com")])
+        path = download_progress("https://example.com", "/tmp/example-dot-com")
 
         self.assertEqual(path, "/tmp/example-dot-com")
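
A rough sketch of how a conversion plugin might use the new PluginExports.clients mapping introduced above: the plugin exports a protocol-to-client mapping, and register_plugins() in convert/__main__.py feeds each entry to add_model_source(). The S3Client class, the "s3://" protocol, and the make_exports entry-point name are illustrative assumptions, not part of this change.

    from typing import Optional

    from onnx_web.convert.client.base import BaseClient
    from onnx_web.convert.utils import ConversionContext
    from onnx_web.server.plugin import PluginExports


    class S3Client(BaseClient):
        # hypothetical protocol prefix handled by this client
        protocol = "s3://"

        def download(
            self,
            conversion: ConversionContext,
            name: str,
            source: str,
            format: Optional[str],
        ) -> str:
            # fetch the object into the conversion cache and return the local path
            raise NotImplementedError()


    def make_exports() -> PluginExports:
        # register_plugins() iterates exports.clients and calls
        # add_model_source(proto, client) for each entry
        return PluginExports(clients={S3Client.protocol: S3Client})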