
split up download clients, add to plugin exports

This commit is contained in:
Sean Sube 2023-12-09 18:04:34 -06:00
parent e7da2cf8a6
commit cc8e564d26
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
10 changed files with 550 additions and 326 deletions
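The plugin side of this change works roughly as follows: load_plugins() collects PluginExports from each configured plugin module, and register_plugins() passes every protocol/client pair in exports.clients to add_model_source(). A minimal sketch of a plugin exporting a custom download client (the module, class, and function names here are hypothetical, not part of this commit):

    # example_plugin.py -- illustrative sketch only
    from typing import Optional

    from onnx_web.convert.client.base import BaseClient
    from onnx_web.convert.utils import ConversionContext
    from onnx_web.server.plugin import PluginExports


    class MirrorClient(BaseClient):
        # hypothetical client handling a mirror:// protocol prefix
        protocol = "mirror://"

        def download(
            self,
            conversion: ConversionContext,
            name: str,
            source: str,
            format: Optional[str],
        ) -> str:
            raise NotImplementedError("left as a stub in this sketch")


    def make_exports() -> PluginExports:
        # how a plugin module hands its exports back to the server is assumed here;
        # register_plugins() only needs exports.clients to map protocol -> client
        return PluginExports(clients={MirrorClient.protocol: MirrorClient})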


@@ -3,16 +3,21 @@ from argparse import ArgumentParser
 from logging import getLogger
 from os import makedirs, path
 from sys import exit
-from typing import Any, Dict, List, Optional, Tuple
-from urllib.parse import urlparse
-
-from huggingface_hub.file_download import hf_hub_download
+from typing import Any, Dict, List, Union
+
 from jsonschema import ValidationError, validate
 from onnx import load_model, save_model
 from transformers import CLIPTokenizer
+
+from onnx_web.server.plugin import load_plugins
 
 from ..constants import ONNX_MODEL, ONNX_WEIGHTS
 from ..utils import load_config
+from .client.base import BaseClient
+from .client.civitai import CivitaiClient
+from .client.file import FileClient
+from .client.http import HttpClient
+from .client.huggingface import HuggingfaceClient
 from .correction.gfpgan import convert_correction_gfpgan
 from .diffusion.control import convert_diffusion_control
 from .diffusion.diffusion import convert_diffusion_diffusers
@@ -25,9 +30,8 @@ from .upscaling.swinir import convert_upscaling_swinir
 from .utils import (
     DEFAULT_OPSET,
     ConversionContext,
-    download_progress,
+    fetch_model,
     fix_diffusion_name,
-    remove_prefix,
     source_format,
     tuple_to_correction,
     tuple_to_diffusion,
@@ -45,16 +49,27 @@ warnings.filterwarnings(
     ".*Converting a tensor to a Python boolean might cause the trace to be incorrect.*",
 )
 
-Models = Dict[str, List[Any]]
 logger = getLogger(__name__)
 
+ModelDict = Dict[str, Union[float, int, str]]
+Models = Dict[str, List[ModelDict]]
+
-model_sources: Dict[str, Tuple[str, str]] = {
-    "civitai://": ("Civitai", "https://civitai.com/api/download/models/%s"),
+model_sources: Dict[str, BaseClient] = {
+    CivitaiClient.protocol: CivitaiClient,
+    FileClient.protocol: FileClient,
+    HttpClient.insecure_protocol: HttpClient,
+    HttpClient.protocol: HttpClient,
+    HuggingfaceClient.protocol: HuggingfaceClient,
 }
 
-model_source_huggingface = "huggingface://"
+model_converters: Dict[str, Any] = {
+    "img2img": convert_diffusion_diffusers,
+    "img2img-sdxl": convert_diffusion_diffusers_xl,
+    "inpaint": convert_diffusion_diffusers,
+    "txt2img": convert_diffusion_diffusers,
+    "txt2img-sdxl": convert_diffusion_diffusers_xl,
+}
 
 # recommended models
 base_models: Models = {
@@ -62,15 +77,15 @@ base_models: Models = {
         # v1.x
         (
             "stable-diffusion-onnx-v1-5",
-            model_source_huggingface + "runwayml/stable-diffusion-v1-5",
+            HuggingfaceClient.protocol + "runwayml/stable-diffusion-v1-5",
         ),
         (
             "stable-diffusion-onnx-v1-inpainting",
-            model_source_huggingface + "runwayml/stable-diffusion-inpainting",
+            HuggingfaceClient.protocol + "runwayml/stable-diffusion-inpainting",
         ),
         (
             "upscaling-stable-diffusion-x4",
-            model_source_huggingface + "stabilityai/stable-diffusion-x4-upscaler",
+            HuggingfaceClient.protocol + "stabilityai/stable-diffusion-x4-upscaler",
             True,
         ),
     ],
@@ -201,69 +216,226 @@ base_models: Models = {
 }
 
-def fetch_model(
-    conversion: ConversionContext,
-    name: str,
-    source: str,
-    dest: Optional[str] = None,
-    format: Optional[str] = None,
-    hf_hub_fetch: bool = False,
-    hf_hub_filename: Optional[str] = None,
-) -> Tuple[str, bool]:
-    cache_path = dest or conversion.cache_path
-    cache_name = path.join(cache_path, name)
-
-    # add an extension if possible, some of the conversion code checks for it
-    if format is None:
-        url = urlparse(source)
-        ext = path.basename(url.path)
-        _filename, ext = path.splitext(ext)
-        if ext is not None:
-            cache_name = cache_name + ext
-    else:
-        cache_name = f"{cache_name}.{format}"
-
-    if path.exists(cache_name):
-        logger.debug("model already exists in cache, skipping fetch")
-        return cache_name, False
-
-    for proto in model_sources:
-        api_name, api_root = model_sources.get(proto)
-        if source.startswith(proto):
-            api_source = api_root % (remove_prefix(source, proto))
-            logger.info(
-                "downloading model from %s: %s -> %s", api_name, api_source, cache_name
-            )
-            return download_progress([(api_source, cache_name)]), False
-
-    if source.startswith(model_source_huggingface):
-        hub_source = remove_prefix(source, model_source_huggingface)
-        logger.info("downloading model from Huggingface Hub: %s", hub_source)
-        # from_pretrained has a bunch of useful logic that snapshot_download by itself down not
-        if hf_hub_fetch:
-            return (
-                hf_hub_download(
-                    repo_id=hub_source,
-                    filename=hf_hub_filename,
-                    cache_dir=cache_path,
-                    force_filename=f"{name}.bin",
-                ),
-                False,
-            )
-        else:
-            return hub_source, True
-    elif source.startswith("https://"):
-        logger.info("downloading model from: %s", source)
-        return download_progress([(source, cache_name)]), False
-    elif source.startswith("http://"):
-        logger.warning("downloading model from insecure source: %s", source)
-        return download_progress([(source, cache_name)]), False
-    elif source.startswith(path.sep) or source.startswith("."):
-        logger.info("using local model: %s", source)
-        return source, False
-    else:
-        logger.info("unknown model location, using path as provided: %s", source)
-        return source, False
+def add_model_source(proto: str, client: BaseClient):
+    global model_sources
+
+    if proto in model_sources:
+        raise ValueError("protocol has already been taken")
+
+    model_sources[proto] = client
+
+
+def convert_source_model(conversion: ConversionContext, model):
+    model_format = source_format(model)
+    name = model["name"]
+    source = model["source"]
+
+    dest_path = None
+    if "dest" in model:
+        dest_path = path.join(conversion.model_path, model["dest"])
+
+    dest, hf = fetch_model(
+        conversion, name, source, format=model_format, dest=dest_path
+    )
+
+    logger.info("finished downloading source: %s -> %s", source, dest)
+
+
+def convert_network_model(conversion: ConversionContext, network):
+    network_format = source_format(network)
+    network_model = network.get("model", None)
+    name = network["name"]
+    network_type = network["type"]
+    source = network["source"]
+
+    if network_type == "control":
+        dest, hf = fetch_model(
+            conversion,
+            name,
+            source,
+            format=network_format,
+        )
+
+        convert_diffusion_control(
+            conversion,
+            network,
+            dest,
+            path.join(conversion.model_path, network_type, name),
+        )
+
+    if network_type == "inversion" and network_model == "concept":
+        dest, hf = fetch_model(
+            conversion,
+            name,
+            source,
+            dest=path.join(conversion.model_path, network_type),
+            format=network_format,
+            hf_hub_fetch=True,
+            hf_hub_filename="learned_embeds.bin",
+        )
+    else:
+        dest, hf = fetch_model(
+            conversion,
+            name,
+            source,
+            dest=path.join(conversion.model_path, network_type),
+            format=network_format,
+        )
+
+    logger.info("finished downloading network: %s -> %s", source, dest)
+
+
+def convert_diffusion_model(conversion: ConversionContext, model):
+    # fix up entries with missing prefixes
+    name = fix_diffusion_name(model["name"])
+    if name != model["name"]:
+        # update the model in-memory if the name changed
+        model["name"] = name
+
+    model_format = source_format(model)
+
+    source, hf = fetch_model(conversion, name, model["source"], format=model_format)
+    pipeline = model.get("pipeline", "txt2img")
+    converter = model_converters.get(pipeline)
+    converted, dest = converter(
+        conversion,
+        model,
+        source,
+        model_format,
+        hf=hf,
+    )
+
+    # make sure blending only happens once, not every run
+    if converted:
+        # keep track of which models have been blended
+        blend_models = {}
+
+        inversion_dest = path.join(conversion.model_path, "inversion")
+        lora_dest = path.join(conversion.model_path, "lora")
+
+        for inversion in model.get("inversions", []):
+            if "text_encoder" not in blend_models:
+                blend_models["text_encoder"] = load_model(
+                    path.join(
+                        dest,
+                        "text_encoder",
+                        ONNX_MODEL,
+                    )
+                )
+
+            if "tokenizer" not in blend_models:
+                blend_models["tokenizer"] = CLIPTokenizer.from_pretrained(
+                    dest,
+                    subfolder="tokenizer",
+                )
+
+            inversion_name = inversion["name"]
+            inversion_source = inversion["source"]
+            inversion_format = inversion.get("format", None)
+            inversion_source, hf = fetch_model(
+                conversion,
+                inversion_name,
+                inversion_source,
+                dest=inversion_dest,
+            )
+            inversion_token = inversion.get("token", inversion_name)
+            inversion_weight = inversion.get("weight", 1.0)
+
+            blend_textual_inversions(
+                conversion,
+                blend_models["text_encoder"],
+                blend_models["tokenizer"],
+                [
+                    (
+                        inversion_source,
+                        inversion_weight,
+                        inversion_token,
+                        inversion_format,
+                    )
+                ],
+            )
+
+        for lora in model.get("loras", []):
+            if "text_encoder" not in blend_models:
+                blend_models["text_encoder"] = load_model(
+                    path.join(
+                        dest,
+                        "text_encoder",
+                        ONNX_MODEL,
+                    )
+                )
+
+            if "unet" not in blend_models:
+                blend_models["unet"] = load_model(path.join(dest, "unet", ONNX_MODEL))
+
+            # load models if not loaded yet
+            lora_name = lora["name"]
+            lora_source = lora["source"]
+            lora_source, hf = fetch_model(
+                conversion,
+                f"{name}-lora-{lora_name}",
+                lora_source,
+                dest=lora_dest,
+            )
+            lora_weight = lora.get("weight", 1.0)
+
+            blend_loras(
+                conversion,
+                blend_models["text_encoder"],
+                [(lora_source, lora_weight)],
+                "text_encoder",
+            )
+            blend_loras(
+                conversion,
+                blend_models["unet"],
+                [(lora_source, lora_weight)],
+                "unet",
+            )
+
+        if "tokenizer" in blend_models:
+            dest_path = path.join(dest, "tokenizer")
+            logger.debug("saving blended tokenizer to %s", dest_path)
+            blend_models["tokenizer"].save_pretrained(dest_path)
+
+        for name in ["text_encoder", "unet"]:
+            if name in blend_models:
+                dest_path = path.join(dest, name, ONNX_MODEL)
+                logger.debug("saving blended %s model to %s", name, dest_path)
+                save_model(
+                    blend_models[name],
+                    dest_path,
+                    save_as_external_data=True,
+                    all_tensors_to_one_file=True,
+                    location=ONNX_WEIGHTS,
+                )
+
+
+def convert_upscaling_model(conversion: ConversionContext, model):
+    model_format = source_format(model)
+    name = model["name"]
+    source, hf = fetch_model(conversion, name, model["source"], format=model_format)
+
+    model_type = model.get("model", "resrgan")
+    if model_type == "bsrgan":
+        convert_upscaling_bsrgan(conversion, model, source)
+    elif model_type == "resrgan":
+        convert_upscale_resrgan(conversion, model, source)
+    elif model_type == "swinir":
+        convert_upscaling_swinir(conversion, model, source)
+    else:
+        logger.error("unknown upscaling model type %s for %s", model_type, name)
+        model_errors.append(name)
+
+
+def convert_correction_model(conversion: ConversionContext, model):
+    model_format = source_format(model)
+    name = model["name"]
+    source, hf = fetch_model(conversion, name, model["source"], format=model_format)
+
+    model_type = model.get("model", "gfpgan")
+    if model_type == "gfpgan":
+        convert_correction_gfpgan(conversion, model, source)
+    else:
+        logger.error("unknown correction model type %s for %s", model_type, name)
+        model_errors.append(name)
 
 
 def convert_models(conversion: ConversionContext, args, models: Models):
@@ -277,18 +449,8 @@ def convert_models(conversion: ConversionContext, args, models: Models):
             if name in args.skip:
                 logger.info("skipping source: %s", name)
             else:
-                model_format = source_format(model)
-                source = model["source"]
-
                 try:
-                    dest_path = None
-                    if "dest" in model:
-                        dest_path = path.join(conversion.model_path, model["dest"])
-                    dest, hf = fetch_model(
-                        conversion, name, source, format=model_format, dest=dest_path
-                    )
-
-                    logger.info("finished downloading source: %s -> %s", source, dest)
+                    convert_source_model(model)
                 except Exception:
                     logger.exception("error fetching source %s", name)
                     model_errors.append(name)
@@ -300,46 +462,8 @@ def convert_models(conversion: ConversionContext, args, models: Models):
             if name in args.skip:
                 logger.info("skipping network: %s", name)
             else:
-                network_format = source_format(network)
-                network_model = network.get("model", None)
-                network_type = network["type"]
-                source = network["source"]
-
                 try:
-                    if network_type == "control":
-                        dest, hf = fetch_model(
-                            conversion,
-                            name,
-                            source,
-                            format=network_format,
-                        )
-
-                        convert_diffusion_control(
-                            conversion,
-                            network,
-                            dest,
-                            path.join(conversion.model_path, network_type, name),
-                        )
-
-                    if network_type == "inversion" and network_model == "concept":
-                        dest, hf = fetch_model(
-                            conversion,
-                            name,
-                            source,
-                            dest=path.join(conversion.model_path, network_type),
-                            format=network_format,
-                            hf_hub_fetch=True,
-                            hf_hub_filename="learned_embeds.bin",
-                        )
-                    else:
-                        dest, hf = fetch_model(
-                            conversion,
-                            name,
-                            source,
-                            dest=path.join(conversion.model_path, network_type),
-                            format=network_format,
-                        )
-
-                    logger.info("finished downloading network: %s -> %s", source, dest)
+                    convert_network_model(conversion, model)
                 except Exception:
                     logger.exception("error fetching network %s", name)
                     model_errors.append(name)
@@ -352,148 +476,8 @@ def convert_models(conversion: ConversionContext, args, models: Models):
             if name in args.skip:
                 logger.info("skipping model: %s", name)
             else:
-                # fix up entries with missing prefixes
-                name = fix_diffusion_name(name)
-                if name != model["name"]:
-                    # update the model in-memory if the name changed
-                    model["name"] = name
-
-                model_format = source_format(model)
-
                 try:
-                    source, hf = fetch_model(
-                        conversion, name, model["source"], format=model_format
-                    )
-
-                    pipeline = model.get("pipeline", "txt2img")
-                    if pipeline.endswith("-sdxl"):
-                        converted, dest = convert_diffusion_diffusers_xl(
-                            conversion,
-                            model,
-                            source,
-                            model_format,
-                            hf=hf,
-                        )
-                    else:
-                        converted, dest = convert_diffusion_diffusers(
-                            conversion,
-                            model,
-                            source,
-                            model_format,
-                            hf=hf,
-                        )
-
-                    # make sure blending only happens once, not every run
-                    if converted:
-                        # keep track of which models have been blended
-                        blend_models = {}
-
-                        inversion_dest = path.join(conversion.model_path, "inversion")
-                        lora_dest = path.join(conversion.model_path, "lora")
-
-                        for inversion in model.get("inversions", []):
-                            if "text_encoder" not in blend_models:
-                                blend_models["text_encoder"] = load_model(
-                                    path.join(
-                                        dest,
-                                        "text_encoder",
-                                        ONNX_MODEL,
-                                    )
-                                )
-
-                            if "tokenizer" not in blend_models:
-                                blend_models[
-                                    "tokenizer"
-                                ] = CLIPTokenizer.from_pretrained(
-                                    dest,
-                                    subfolder="tokenizer",
-                                )
-
-                            inversion_name = inversion["name"]
-                            inversion_source = inversion["source"]
-                            inversion_format = inversion.get("format", None)
-                            inversion_source, hf = fetch_model(
-                                conversion,
-                                inversion_name,
-                                inversion_source,
-                                dest=inversion_dest,
-                            )
-                            inversion_token = inversion.get("token", inversion_name)
-                            inversion_weight = inversion.get("weight", 1.0)
-
-                            blend_textual_inversions(
-                                conversion,
-                                blend_models["text_encoder"],
-                                blend_models["tokenizer"],
-                                [
-                                    (
-                                        inversion_source,
-                                        inversion_weight,
-                                        inversion_token,
-                                        inversion_format,
-                                    )
-                                ],
-                            )
-
-                        for lora in model.get("loras", []):
-                            if "text_encoder" not in blend_models:
-                                blend_models["text_encoder"] = load_model(
-                                    path.join(
-                                        dest,
-                                        "text_encoder",
-                                        ONNX_MODEL,
-                                    )
-                                )
-
-                            if "unet" not in blend_models:
-                                blend_models["unet"] = load_model(
-                                    path.join(dest, "unet", ONNX_MODEL)
-                                )
-
-                            # load models if not loaded yet
-                            lora_name = lora["name"]
-                            lora_source = lora["source"]
-                            lora_source, hf = fetch_model(
-                                conversion,
-                                f"{name}-lora-{lora_name}",
-                                lora_source,
-                                dest=lora_dest,
-                            )
-                            lora_weight = lora.get("weight", 1.0)
-
-                            blend_loras(
-                                conversion,
-                                blend_models["text_encoder"],
-                                [(lora_source, lora_weight)],
-                                "text_encoder",
-                            )
-                            blend_loras(
-                                conversion,
-                                blend_models["unet"],
-                                [(lora_source, lora_weight)],
-                                "unet",
-                            )
-
-                        if "tokenizer" in blend_models:
-                            dest_path = path.join(dest, "tokenizer")
-                            logger.debug("saving blended tokenizer to %s", dest_path)
-                            blend_models["tokenizer"].save_pretrained(dest_path)
-
-                        for name in ["text_encoder", "unet"]:
-                            if name in blend_models:
-                                dest_path = path.join(dest, name, ONNX_MODEL)
-                                logger.debug(
-                                    "saving blended %s model to %s", name, dest_path
-                                )
-                                save_model(
-                                    blend_models[name],
-                                    dest_path,
-                                    save_as_external_data=True,
-                                    all_tensors_to_one_file=True,
-                                    location=ONNX_WEIGHTS,
-                                )
+                    convert_diffusion_model(conversion, model)
                 except Exception:
                     logger.exception(
                         "error converting diffusion model %s",
@@ -509,24 +493,8 @@ def convert_models(conversion: ConversionContext, args, models: Models):
             if name in args.skip:
                 logger.info("skipping model: %s", name)
             else:
-                model_format = source_format(model)
-
                 try:
-                    source, hf = fetch_model(
-                        conversion, name, model["source"], format=model_format
-                    )
-                    model_type = model.get("model", "resrgan")
-                    if model_type == "bsrgan":
-                        convert_upscaling_bsrgan(conversion, model, source)
-                    elif model_type == "resrgan":
-                        convert_upscale_resrgan(conversion, model, source)
-                    elif model_type == "swinir":
-                        convert_upscaling_swinir(conversion, model, source)
-                    else:
-                        logger.error(
-                            "unknown upscaling model type %s for %s", model_type, name
-                        )
-                        model_errors.append(name)
+                    convert_upscaling_model(conversion, model)
                 except Exception:
                     logger.exception(
                         "error converting upscaling model %s",
@@ -542,19 +510,8 @@ def convert_models(conversion: ConversionContext, args, models: Models):
             if name in args.skip:
                 logger.info("skipping model: %s", name)
             else:
-                model_format = source_format(model)
-
                 try:
-                    source, hf = fetch_model(
-                        conversion, name, model["source"], format=model_format
-                    )
-                    model_type = model.get("model", "gfpgan")
-                    if model_type == "gfpgan":
-                        convert_correction_gfpgan(conversion, model, source)
-                    else:
-                        logger.error(
-                            "unknown correction model type %s for %s", model_type, name
-                        )
-                        model_errors.append(name)
+                    convert_correction_model(conversion, model)
                 except Exception:
                     logger.exception(
                         "error converting correction model %s",
@@ -566,12 +523,26 @@ def convert_models(conversion: ConversionContext, args, models: Models):
         logger.error("error while converting models: %s", model_errors)
 
 
+def register_plugins(conversion: ConversionContext):
+    logger.info("loading conversion plugins")
+    exports = load_plugins(conversion)
+
+    for proto, client in exports.clients:
+        try:
+            add_model_source(proto, client)
+        except Exception:
+            logger.exception("error loading client for protocol: %s", proto)
+
+    # TODO: add converters
+
+
 def main(args=None) -> int:
     parser = ArgumentParser(
         prog="onnx-web model converter", description="convert checkpoint models to ONNX"
     )
 
     # model groups
+    parser.add_argument("--base", action="store_true", default=True)
     parser.add_argument("--networks", action="store_true", default=True)
     parser.add_argument("--sources", action="store_true", default=True)
     parser.add_argument("--correction", action="store_true", default=False)
@@ -609,16 +580,20 @@ def main(args=None) -> int:
     server.half = args.half or server.has_optimization("onnx-fp16")
     server.opset = args.opset
     server.token = args.token
+
+    register_plugins(server)
+
     logger.info(
-        "converting models in %s using %s", server.model_path, server.training_device
+        "converting models into %s using %s", server.model_path, server.training_device
     )
 
     if not path.exists(server.model_path):
         logger.info("model path does not existing, creating: %s", server.model_path)
         makedirs(server.model_path)
 
-    logger.info("converting base models")
-    convert_models(server, args, base_models)
+    if args.base:
+        logger.info("converting base models")
+        convert_models(server, args, base_models)
 
     extras = []
     extras.extend(server.extra_models)


@@ -0,0 +1,13 @@
+from typing import Optional
+
+from ..utils import ConversionContext
+
+
+class BaseClient:
+    def download(
+        self,
+        conversion: ConversionContext,
+        name: str,
+        source: str,
+        format: Optional[str],
+    ) -> str:
+        raise NotImplementedError()


@@ -0,0 +1,45 @@
+from ..utils import (
+    ConversionContext,
+    build_cache_paths,
+    download_progress,
+    get_first_exists,
+    remove_prefix,
+)
+from .base import BaseClient
+
+from typing import Optional
+from logging import getLogger
+
+logger = getLogger(__name__)
+
+CIVITAI_ROOT = "https://civitai.com/api/download/models/%s"
+
+
+class CivitaiClient(BaseClient):
+    protocol = "civitai://"
+
+    root: str
+    token: Optional[str]
+
+    def __init__(self, token: Optional[str] = None, root=CIVITAI_ROOT):
+        self.root = root
+        self.token = token
+
+    def download(
+        self,
+        conversion: ConversionContext,
+        name: str,
+        source: str,
+        format: Optional[str],
+    ) -> str:
+        """
+        TODO: download with auth token
+        """
+        cache_paths = build_cache_paths(
+            conversion, name, client=CivitaiClient.name, format=format
+        )
+        cached = get_first_exists(cache_paths)
+        if cached:
+            return cached
+
+        source = self.root % (remove_prefix(source, CivitaiClient.protocol))
+        logger.info("downloading model from Civitai: %s -> %s", source, cache_paths[0])
+        return download_progress(source, cache_paths[0])


@@ -0,0 +1,20 @@
+from .base import BaseClient
+
+from logging import getLogger
+from os import path
+from urllib.parse import urlparse
+
+logger = getLogger(__name__)
+
+
+class FileClient(BaseClient):
+    protocol = "file://"
+
+    root: str
+
+    def __init__(self, root: str):
+        self.root = root
+
+    def download(self, uri: str) -> str:
+        parts = urlparse(uri)
+        logger.info("loading model from: %s", parts.path)
+        return path.join(self.root, parts.path)


@@ -0,0 +1,39 @@
+from ..utils import (
+    ConversionContext,
+    build_cache_paths,
+    download_progress,
+    get_first_exists,
+    remove_prefix,
+)
+from .base import BaseClient
+
+from typing import Dict, Optional
+from logging import getLogger
+
+logger = getLogger(__name__)
+
+
+class HttpClient(BaseClient):
+    name = "http"
+    protocol = "https://"
+    insecure_protocol = "http://"
+
+    headers: Dict[str, str]
+
+    def __init__(self, headers: Optional[Dict[str, str]] = None):
+        self.headers = headers or {}
+
+    def download(self, conversion: ConversionContext, name: str, uri: str) -> str:
+        cache_paths = build_cache_paths(
+            conversion, name, client=HttpClient.name, format=format
+        )
+        cached = get_first_exists(cache_paths)
+        if cached:
+            return cached
+
+        if uri.startswith(HttpClient.protocol):
+            source = remove_prefix(uri, HttpClient.protocol)
+            logger.info("downloading model from: %s", source)
+        elif uri.startswith(HttpClient.insecure_protocol):
+            logger.warning("downloading model from insecure source: %s", source)
+
+        return download_progress(source, cache_paths[0])


@@ -0,0 +1,61 @@
+from ..utils import (
+    ConversionContext,
+    build_cache_paths,
+    get_first_exists,
+    remove_prefix,
+)
+from .base import BaseClient
+
+from typing import Optional, Any
+from logging import getLogger
+
+from huggingface_hub.file_download import hf_hub_download
+
+logger = getLogger(__name__)
+
+
+class HuggingfaceClient(BaseClient):
+    name = "huggingface"
+    protocol = "huggingface://"
+
+    download: Any
+    token: Optional[str]
+
+    def __init__(self, token: Optional[str] = None, download=hf_hub_download):
+        self.download = download
+        self.token = token
+
+    def download(
+        self,
+        conversion: ConversionContext,
+        name: str,
+        source: str,
+        format: Optional[str],
+    ) -> str:
+        """
+        TODO: download with auth
+        """
+        hf_hub_fetch = TODO
+        hf_hub_filename = TODO
+
+        cache_paths = build_cache_paths(
+            conversion, name, client=HuggingfaceClient.name, format=format
+        )
+        cached = get_first_exists(cache_paths)
+        if cached:
+            return cached
+
+        source = remove_prefix(source, HuggingfaceClient.protocol)
+        logger.info("downloading model from Huggingface Hub: %s", source)
+
+        # from_pretrained has a bunch of useful logic that snapshot_download by itself down not
+        if hf_hub_fetch:
+            return (
+                hf_hub_download(
+                    repo_id=source,
+                    filename=hf_hub_filename,
+                    cache_dir=cache_paths[0],
+                    force_filename=f"{name}.bin",
+                ),
+                False,
+            )
+        else:
+            return source


@@ -15,6 +15,8 @@ from onnxruntime.transformers.float16 import convert_float_to_float16
 from packaging import version
 from torch.onnx import export
 
+from onnx_web.convert.client.file import FileClient
+
 from ..constants import ONNX_WEIGHTS
 from ..errors import RequestException
 from ..server import ServerContext
@@ -85,38 +87,37 @@ class ConversionContext(ServerContext):
         return torch.device(self.training_device)
 
 
-def download_progress(urls: List[Tuple[str, str]]):
-    for url, dest in urls:
-        dest_path = Path(dest).expanduser().resolve()
-        dest_path.parent.mkdir(parents=True, exist_ok=True)
-
-        if dest_path.exists():
-            logger.debug("destination already exists: %s", dest_path)
-            return str(dest_path.absolute())
-
-        req = requests.get(
-            url,
-            stream=True,
-            allow_redirects=True,
-            headers={
-                "User-Agent": "onnx-web-api",
-            },
-        )
-        if req.status_code != 200:
-            req.raise_for_status()  # Only works for 4xx errors, per SO answer
-            raise RequestException(
-                "request to %s failed with status code: %s" % (url, req.status_code)
-            )
-
-        total = int(req.headers.get("Content-Length", 0))
-        desc = "unknown" if total == 0 else ""
-        req.raw.read = partial(req.raw.read, decode_content=True)
-        with tqdm.wrapattr(req.raw, "read", total=total, desc=desc) as data:
-            with dest_path.open("wb") as f:
-                shutil.copyfileobj(data, f)
-
-        return str(dest_path.absolute())
+def download_progress(source: str, dest: str):
+    dest_path = Path(dest).expanduser().resolve()
+    dest_path.parent.mkdir(parents=True, exist_ok=True)
+
+    if dest_path.exists():
+        logger.debug("destination already exists: %s", dest_path)
+        return str(dest_path.absolute())
+
+    req = requests.get(
+        source,
+        stream=True,
+        allow_redirects=True,
+        headers={
+            "User-Agent": "onnx-web-api",
+        },
+    )
+    if req.status_code != 200:
+        req.raise_for_status()  # Only works for 4xx errors, per SO answer
+        raise RequestException(
+            "request to %s failed with status code: %s" % (source, req.status_code)
+        )
+
+    total = int(req.headers.get("Content-Length", 0))
+    desc = "unknown" if total == 0 else ""
+    req.raw.read = partial(req.raw.read, decode_content=True)
+    with tqdm.wrapattr(req.raw, "read", total=total, desc=desc) as data:
+        with dest_path.open("wb") as f:
+            shutil.copyfileobj(data, f)
+
+    return str(dest_path.absolute())
 
 
 def tuple_to_source(model: Union[ModelDict, LegacyModel]):
     if isinstance(model, list) or isinstance(model, tuple):
@@ -347,7 +348,13 @@ onnx_export(
     )
 
 
-DIFFUSION_PREFIX = ["diffusion-", "diffusion/", "diffusion\\", "stable-diffusion-", "upscaling-"]
+DIFFUSION_PREFIX = [
+    "diffusion-",
+    "diffusion/",
+    "diffusion\\",
+    "stable-diffusion-",
+    "upscaling-",
+]
 
 
 def fix_diffusion_name(name: str):
@@ -359,3 +366,61 @@ def fix_diffusion_name(name: str):
         return f"diffusion-{name}"
 
     return name
+
+
+def fetch_model(
+    conversion: ConversionContext,
+    name: str,
+    source: str,
+    format: Optional[str],
+) -> Tuple[str, bool]:
+    # TODO: switch to urlparse's default scheme
+    if source.startswith(path.sep) or source.startswith("."):
+        logger.info("adding file protocol to local path source: %s", source)
+        source = FileClient.protocol + source
+
+    for proto, client_type in model_sources.items():
+        if source.startswith(proto):
+            client = client_type()
+            return client.download(proto)
+
+    logger.warning("unknown model protocol, using path as provided: %s", source)
+    return source, False
+
+
+def build_cache_paths(
+    conversion: ConversionContext,
+    name: str,
+    client: Optional[str] = None,
+    dest: Optional[str] = None,
+    format: Optional[str] = None,
+) -> List[str]:
+    cache_path = dest or conversion.cache_path
+
+    # add an extension if possible, some of the conversion code checks for it
+    if format is not None:
+        basename = path.basename(name)
+        _filename, ext = path.splitext(basename)
+        if ext is None:
+            name = f"{name}.{format}"
+
+    paths = [
+        path.join(cache_path, name),
+    ]
+
+    if client is not None:
+        client_path = path.join(cache_path, client)
+        paths.append(path.join(client_path, name))
+
+    return paths
+
+
+def get_first_exists(
+    paths: List[str],
+) -> Optional[str]:
+    for name in paths:
+        if path.exists(name):
+            logger.debug("model already exists in cache, skipping fetch: %s", name)
+            return name
+
+    return None
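A usage sketch for the new cache helpers (the model name, format, and URL below are made up, and a ConversionContext with cache_path configured is assumed): build_cache_paths() returns the candidate locations, one directly under the cache directory and one under a per-client subdirectory, get_first_exists() picks the first that is already present, and download_progress() writes to the first candidate otherwise.

    from onnx_web.convert.utils import build_cache_paths, download_progress, get_first_exists

    paths = build_cache_paths(conversion, "example-model", client="civitai", format="safetensors")
    cached = get_first_exists(paths)
    source = cached or download_progress(
        "https://example.com/example-model.safetensors", paths[0]
    )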


@@ -2,18 +2,24 @@ from importlib import import_module
 from logging import getLogger
 from typing import Any, Callable, Dict
 
-from onnx_web.chain.stages import add_stage
-from onnx_web.diffusers.load import add_pipeline
-from onnx_web.server.context import ServerContext
+from ..chain.stages import add_stage
+from ..diffusers.load import add_pipeline
+from ..server.context import ServerContext
 
 logger = getLogger(__name__)
 
 
 class PluginExports:
+    clients: Dict[str, Any]
+    converter: Dict[str, Any]
     pipelines: Dict[str, Any]
     stages: Dict[str, Any]
 
-    def __init__(self, pipelines=None, stages=None) -> None:
+    def __init__(
+        self, clients=None, converter=None, pipelines=None, stages=None
+    ) -> None:
+        self.clients = clients or {}
+        self.converter = converter or {}
         self.pipelines = pipelines or {}
         self.stages = stages or {}


@@ -27,7 +27,7 @@ class ConversionContextTests(unittest.TestCase):
 class DownloadProgressTests(unittest.TestCase):
     def test_download_example(self):
-        path = download_progress([("https://example.com", "/tmp/example-dot-com")])
+        path = download_progress("https://example.com", "/tmp/example-dot-com")
 
         self.assertEqual(path, "/tmp/example-dot-com")