1
0
Fork 0

split up download clients, add to plugin exports

This commit is contained in:
Sean Sube 2023-12-09 18:04:34 -06:00
parent e7da2cf8a6
commit cc8e564d26
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
10 changed files with 550 additions and 326 deletions

View File

@ -3,16 +3,21 @@ from argparse import ArgumentParser
from logging import getLogger
from os import makedirs, path
from sys import exit
from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import urlparse
from typing import Any, Dict, List, Union
from huggingface_hub.file_download import hf_hub_download
from jsonschema import ValidationError, validate
from onnx import load_model, save_model
from transformers import CLIPTokenizer
from onnx_web.server.plugin import load_plugins
from ..constants import ONNX_MODEL, ONNX_WEIGHTS
from ..utils import load_config
from .client.base import BaseClient
from .client.civitai import CivitaiClient
from .client.file import FileClient
from .client.http import HttpClient
from .client.huggingface import HuggingfaceClient
from .correction.gfpgan import convert_correction_gfpgan
from .diffusion.control import convert_diffusion_control
from .diffusion.diffusion import convert_diffusion_diffusers
@ -25,9 +30,8 @@ from .upscaling.swinir import convert_upscaling_swinir
from .utils import (
DEFAULT_OPSET,
ConversionContext,
download_progress,
fetch_model,
fix_diffusion_name,
remove_prefix,
source_format,
tuple_to_correction,
tuple_to_diffusion,
@ -45,16 +49,27 @@ warnings.filterwarnings(
".*Converting a tensor to a Python boolean might cause the trace to be incorrect.*",
)
Models = Dict[str, List[Any]]
logger = getLogger(__name__)
ModelDict = Dict[str, Union[float, int, str]]
Models = Dict[str, List[ModelDict]]
model_sources: Dict[str, Tuple[str, str]] = {
"civitai://": ("Civitai", "https://civitai.com/api/download/models/%s"),
model_sources: Dict[str, BaseClient] = {
CivitaiClient.protocol: CivitaiClient,
FileClient.protocol: FileClient,
HttpClient.insecure_protocol: HttpClient,
HttpClient.protocol: HttpClient,
HuggingfaceClient.protocol: HuggingfaceClient,
}
model_source_huggingface = "huggingface://"
model_converters: Dict[str, Any] = {
"img2img": convert_diffusion_diffusers,
"img2img-sdxl": convert_diffusion_diffusers_xl,
"inpaint": convert_diffusion_diffusers,
"txt2img": convert_diffusion_diffusers,
"txt2img-sdxl": convert_diffusion_diffusers_xl,
}
# recommended models
base_models: Models = {
@ -62,15 +77,15 @@ base_models: Models = {
# v1.x
(
"stable-diffusion-onnx-v1-5",
model_source_huggingface + "runwayml/stable-diffusion-v1-5",
HuggingfaceClient.protocol + "runwayml/stable-diffusion-v1-5",
),
(
"stable-diffusion-onnx-v1-inpainting",
model_source_huggingface + "runwayml/stable-diffusion-inpainting",
HuggingfaceClient.protocol + "runwayml/stable-diffusion-inpainting",
),
(
"upscaling-stable-diffusion-x4",
model_source_huggingface + "stabilityai/stable-diffusion-x4-upscaler",
HuggingfaceClient.protocol + "stabilityai/stable-diffusion-x4-upscaler",
True,
),
],
@ -201,69 +216,226 @@ base_models: Models = {
}
def fetch_model(
conversion: ConversionContext,
name: str,
source: str,
dest: Optional[str] = None,
format: Optional[str] = None,
hf_hub_fetch: bool = False,
hf_hub_filename: Optional[str] = None,
) -> Tuple[str, bool]:
cache_path = dest or conversion.cache_path
cache_name = path.join(cache_path, name)
def add_model_source(proto: str, client: BaseClient):
global model_sources
# add an extension if possible, some of the conversion code checks for it
if format is None:
url = urlparse(source)
ext = path.basename(url.path)
_filename, ext = path.splitext(ext)
if ext is not None:
cache_name = cache_name + ext
if proto in model_sources:
raise ValueError("protocol has already been taken")
model_sources[proto] = client
def convert_source_model(conversion: ConversionContext, model):
model_format = source_format(model)
name = model["name"]
source = model["source"]
dest_path = None
if "dest" in model:
dest_path = path.join(conversion.model_path, model["dest"])
dest, hf = fetch_model(
conversion, name, source, format=model_format, dest=dest_path
)
logger.info("finished downloading source: %s -> %s", source, dest)
def convert_network_model(conversion: ConversionContext, network):
network_format = source_format(network)
network_model = network.get("model", None)
name = network["name"]
network_type = network["type"]
source = network["source"]
if network_type == "control":
dest, hf = fetch_model(
conversion,
name,
source,
format=network_format,
)
convert_diffusion_control(
conversion,
network,
dest,
path.join(conversion.model_path, network_type, name),
)
if network_type == "inversion" and network_model == "concept":
dest, hf = fetch_model(
conversion,
name,
source,
dest=path.join(conversion.model_path, network_type),
format=network_format,
hf_hub_fetch=True,
hf_hub_filename="learned_embeds.bin",
)
else:
cache_name = f"{cache_name}.{format}"
dest, hf = fetch_model(
conversion,
name,
source,
dest=path.join(conversion.model_path, network_type),
format=network_format,
)
if path.exists(cache_name):
logger.debug("model already exists in cache, skipping fetch")
return cache_name, False
logger.info("finished downloading network: %s -> %s", source, dest)
for proto in model_sources:
api_name, api_root = model_sources.get(proto)
if source.startswith(proto):
api_source = api_root % (remove_prefix(source, proto))
logger.info(
"downloading model from %s: %s -> %s", api_name, api_source, cache_name
def convert_diffusion_model(conversion: ConversionContext, model):
# fix up entries with missing prefixes
name = fix_diffusion_name(model["name"])
if name != model["name"]:
# update the model in-memory if the name changed
model["name"] = name
model_format = source_format(model)
source, hf = fetch_model(conversion, name, model["source"], format=model_format)
pipeline = model.get("pipeline", "txt2img")
converter = model_converters.get(pipeline)
converted, dest = converter(
conversion,
model,
source,
model_format,
hf=hf,
)
# make sure blending only happens once, not every run
if converted:
# keep track of which models have been blended
blend_models = {}
inversion_dest = path.join(conversion.model_path, "inversion")
lora_dest = path.join(conversion.model_path, "lora")
for inversion in model.get("inversions", []):
if "text_encoder" not in blend_models:
blend_models["text_encoder"] = load_model(
path.join(
dest,
"text_encoder",
ONNX_MODEL,
)
)
if "tokenizer" not in blend_models:
blend_models["tokenizer"] = CLIPTokenizer.from_pretrained(
dest,
subfolder="tokenizer",
)
inversion_name = inversion["name"]
inversion_source = inversion["source"]
inversion_format = inversion.get("format", None)
inversion_source, hf = fetch_model(
conversion,
inversion_name,
inversion_source,
dest=inversion_dest,
)
return download_progress([(api_source, cache_name)]), False
inversion_token = inversion.get("token", inversion_name)
inversion_weight = inversion.get("weight", 1.0)
if source.startswith(model_source_huggingface):
hub_source = remove_prefix(source, model_source_huggingface)
logger.info("downloading model from Huggingface Hub: %s", hub_source)
# from_pretrained has a bunch of useful logic that snapshot_download by itself down not
if hf_hub_fetch:
return (
hf_hub_download(
repo_id=hub_source,
filename=hf_hub_filename,
cache_dir=cache_path,
force_filename=f"{name}.bin",
),
False,
blend_textual_inversions(
conversion,
blend_models["text_encoder"],
blend_models["tokenizer"],
[
(
inversion_source,
inversion_weight,
inversion_token,
inversion_format,
)
],
)
else:
return hub_source, True
elif source.startswith("https://"):
logger.info("downloading model from: %s", source)
return download_progress([(source, cache_name)]), False
elif source.startswith("http://"):
logger.warning("downloading model from insecure source: %s", source)
return download_progress([(source, cache_name)]), False
elif source.startswith(path.sep) or source.startswith("."):
logger.info("using local model: %s", source)
return source, False
for lora in model.get("loras", []):
if "text_encoder" not in blend_models:
blend_models["text_encoder"] = load_model(
path.join(
dest,
"text_encoder",
ONNX_MODEL,
)
)
if "unet" not in blend_models:
blend_models["unet"] = load_model(path.join(dest, "unet", ONNX_MODEL))
# load models if not loaded yet
lora_name = lora["name"]
lora_source = lora["source"]
lora_source, hf = fetch_model(
conversion,
f"{name}-lora-{lora_name}",
lora_source,
dest=lora_dest,
)
lora_weight = lora.get("weight", 1.0)
blend_loras(
conversion,
blend_models["text_encoder"],
[(lora_source, lora_weight)],
"text_encoder",
)
blend_loras(
conversion,
blend_models["unet"],
[(lora_source, lora_weight)],
"unet",
)
if "tokenizer" in blend_models:
dest_path = path.join(dest, "tokenizer")
logger.debug("saving blended tokenizer to %s", dest_path)
blend_models["tokenizer"].save_pretrained(dest_path)
for name in ["text_encoder", "unet"]:
if name in blend_models:
dest_path = path.join(dest, name, ONNX_MODEL)
logger.debug("saving blended %s model to %s", name, dest_path)
save_model(
blend_models[name],
dest_path,
save_as_external_data=True,
all_tensors_to_one_file=True,
location=ONNX_WEIGHTS,
)
def convert_upscaling_model(conversion: ConversionContext, model):
model_format = source_format(model)
name = model["name"]
source, hf = fetch_model(conversion, name, model["source"], format=model_format)
model_type = model.get("model", "resrgan")
if model_type == "bsrgan":
convert_upscaling_bsrgan(conversion, model, source)
elif model_type == "resrgan":
convert_upscale_resrgan(conversion, model, source)
elif model_type == "swinir":
convert_upscaling_swinir(conversion, model, source)
else:
logger.info("unknown model location, using path as provided: %s", source)
return source, False
logger.error("unknown upscaling model type %s for %s", model_type, name)
model_errors.append(name)
def convert_correction_model(conversion: ConversionContext, model):
model_format = source_format(model)
name = model["name"]
source, hf = fetch_model(conversion, name, model["source"], format=model_format)
model_type = model.get("model", "gfpgan")
if model_type == "gfpgan":
convert_correction_gfpgan(conversion, model, source)
else:
logger.error("unknown correction model type %s for %s", model_type, name)
model_errors.append(name)
def convert_models(conversion: ConversionContext, args, models: Models):
@ -277,18 +449,8 @@ def convert_models(conversion: ConversionContext, args, models: Models):
if name in args.skip:
logger.info("skipping source: %s", name)
else:
model_format = source_format(model)
source = model["source"]
try:
dest_path = None
if "dest" in model:
dest_path = path.join(conversion.model_path, model["dest"])
dest, hf = fetch_model(
conversion, name, source, format=model_format, dest=dest_path
)
logger.info("finished downloading source: %s -> %s", source, dest)
convert_source_model(model)
except Exception:
logger.exception("error fetching source %s", name)
model_errors.append(name)
@ -300,46 +462,8 @@ def convert_models(conversion: ConversionContext, args, models: Models):
if name in args.skip:
logger.info("skipping network: %s", name)
else:
network_format = source_format(network)
network_model = network.get("model", None)
network_type = network["type"]
source = network["source"]
try:
if network_type == "control":
dest, hf = fetch_model(
conversion,
name,
source,
format=network_format,
)
convert_diffusion_control(
conversion,
network,
dest,
path.join(conversion.model_path, network_type, name),
)
if network_type == "inversion" and network_model == "concept":
dest, hf = fetch_model(
conversion,
name,
source,
dest=path.join(conversion.model_path, network_type),
format=network_format,
hf_hub_fetch=True,
hf_hub_filename="learned_embeds.bin",
)
else:
dest, hf = fetch_model(
conversion,
name,
source,
dest=path.join(conversion.model_path, network_type),
format=network_format,
)
logger.info("finished downloading network: %s -> %s", source, dest)
convert_network_model(conversion, model)
except Exception:
logger.exception("error fetching network %s", name)
model_errors.append(name)
@ -352,148 +476,8 @@ def convert_models(conversion: ConversionContext, args, models: Models):
if name in args.skip:
logger.info("skipping model: %s", name)
else:
# fix up entries with missing prefixes
name = fix_diffusion_name(name)
if name != model["name"]:
# update the model in-memory if the name changed
model["name"] = name
model_format = source_format(model)
try:
source, hf = fetch_model(
conversion, name, model["source"], format=model_format
)
pipeline = model.get("pipeline", "txt2img")
if pipeline.endswith("-sdxl"):
converted, dest = convert_diffusion_diffusers_xl(
conversion,
model,
source,
model_format,
hf=hf,
)
else:
converted, dest = convert_diffusion_diffusers(
conversion,
model,
source,
model_format,
hf=hf,
)
# make sure blending only happens once, not every run
if converted:
# keep track of which models have been blended
blend_models = {}
inversion_dest = path.join(conversion.model_path, "inversion")
lora_dest = path.join(conversion.model_path, "lora")
for inversion in model.get("inversions", []):
if "text_encoder" not in blend_models:
blend_models["text_encoder"] = load_model(
path.join(
dest,
"text_encoder",
ONNX_MODEL,
)
)
if "tokenizer" not in blend_models:
blend_models[
"tokenizer"
] = CLIPTokenizer.from_pretrained(
dest,
subfolder="tokenizer",
)
inversion_name = inversion["name"]
inversion_source = inversion["source"]
inversion_format = inversion.get("format", None)
inversion_source, hf = fetch_model(
conversion,
inversion_name,
inversion_source,
dest=inversion_dest,
)
inversion_token = inversion.get("token", inversion_name)
inversion_weight = inversion.get("weight", 1.0)
blend_textual_inversions(
conversion,
blend_models["text_encoder"],
blend_models["tokenizer"],
[
(
inversion_source,
inversion_weight,
inversion_token,
inversion_format,
)
],
)
for lora in model.get("loras", []):
if "text_encoder" not in blend_models:
blend_models["text_encoder"] = load_model(
path.join(
dest,
"text_encoder",
ONNX_MODEL,
)
)
if "unet" not in blend_models:
blend_models["unet"] = load_model(
path.join(dest, "unet", ONNX_MODEL)
)
# load models if not loaded yet
lora_name = lora["name"]
lora_source = lora["source"]
lora_source, hf = fetch_model(
conversion,
f"{name}-lora-{lora_name}",
lora_source,
dest=lora_dest,
)
lora_weight = lora.get("weight", 1.0)
blend_loras(
conversion,
blend_models["text_encoder"],
[(lora_source, lora_weight)],
"text_encoder",
)
blend_loras(
conversion,
blend_models["unet"],
[(lora_source, lora_weight)],
"unet",
)
if "tokenizer" in blend_models:
dest_path = path.join(dest, "tokenizer")
logger.debug("saving blended tokenizer to %s", dest_path)
blend_models["tokenizer"].save_pretrained(dest_path)
for name in ["text_encoder", "unet"]:
if name in blend_models:
dest_path = path.join(dest, name, ONNX_MODEL)
logger.debug(
"saving blended %s model to %s", name, dest_path
)
save_model(
blend_models[name],
dest_path,
save_as_external_data=True,
all_tensors_to_one_file=True,
location=ONNX_WEIGHTS,
)
convert_diffusion_model(conversion, model)
except Exception:
logger.exception(
"error converting diffusion model %s",
@ -509,24 +493,8 @@ def convert_models(conversion: ConversionContext, args, models: Models):
if name in args.skip:
logger.info("skipping model: %s", name)
else:
model_format = source_format(model)
try:
source, hf = fetch_model(
conversion, name, model["source"], format=model_format
)
model_type = model.get("model", "resrgan")
if model_type == "bsrgan":
convert_upscaling_bsrgan(conversion, model, source)
elif model_type == "resrgan":
convert_upscale_resrgan(conversion, model, source)
elif model_type == "swinir":
convert_upscaling_swinir(conversion, model, source)
else:
logger.error(
"unknown upscaling model type %s for %s", model_type, name
)
model_errors.append(name)
convert_upscaling_model(conversion, model)
except Exception:
logger.exception(
"error converting upscaling model %s",
@ -542,19 +510,8 @@ def convert_models(conversion: ConversionContext, args, models: Models):
if name in args.skip:
logger.info("skipping model: %s", name)
else:
model_format = source_format(model)
try:
source, hf = fetch_model(
conversion, name, model["source"], format=model_format
)
model_type = model.get("model", "gfpgan")
if model_type == "gfpgan":
convert_correction_gfpgan(conversion, model, source)
else:
logger.error(
"unknown correction model type %s for %s", model_type, name
)
model_errors.append(name)
convert_correction_model(conversion, model)
except Exception:
logger.exception(
"error converting correction model %s",
@ -566,12 +523,26 @@ def convert_models(conversion: ConversionContext, args, models: Models):
logger.error("error while converting models: %s", model_errors)
def register_plugins(conversion: ConversionContext):
logger.info("loading conversion plugins")
exports = load_plugins(conversion)
for proto, client in exports.clients:
try:
add_model_source(proto, client)
except Exception:
logger.exception("error loading client for protocol: %s", proto)
# TODO: add converters
def main(args=None) -> int:
parser = ArgumentParser(
prog="onnx-web model converter", description="convert checkpoint models to ONNX"
)
# model groups
parser.add_argument("--base", action="store_true", default=True)
parser.add_argument("--networks", action="store_true", default=True)
parser.add_argument("--sources", action="store_true", default=True)
parser.add_argument("--correction", action="store_true", default=False)
@ -609,16 +580,20 @@ def main(args=None) -> int:
server.half = args.half or server.has_optimization("onnx-fp16")
server.opset = args.opset
server.token = args.token
register_plugins(server)
logger.info(
"converting models in %s using %s", server.model_path, server.training_device
"converting models into %s using %s", server.model_path, server.training_device
)
if not path.exists(server.model_path):
logger.info("model path does not existing, creating: %s", server.model_path)
makedirs(server.model_path)
logger.info("converting base models")
convert_models(server, args, base_models)
if args.base:
logger.info("converting base models")
convert_models(server, args, base_models)
extras = []
extras.extend(server.extra_models)

View File

View File

@ -0,0 +1,13 @@
from typing import Optional
from ..utils import ConversionContext
class BaseClient:
def download(
self,
conversion: ConversionContext,
name: str,
source: str,
format: Optional[str],
) -> str:
raise NotImplementedError()

View File

@ -0,0 +1,45 @@
from ..utils import (
ConversionContext,
build_cache_paths,
download_progress,
get_first_exists,
remove_prefix,
)
from .base import BaseClient
from typing import Optional
from logging import getLogger
logger = getLogger(__name__)
CIVITAI_ROOT = "https://civitai.com/api/download/models/%s"
class CivitaiClient(BaseClient):
protocol = "civitai://"
root: str
token: Optional[str]
def __init__(self, token: Optional[str] = None, root=CIVITAI_ROOT):
self.root = root
self.token = token
def download(
self,
conversion: ConversionContext,
name: str,
source: str,
format: Optional[str],
) -> str:
"""
TODO: download with auth token
"""
cache_paths = build_cache_paths(
conversion, name, client=CivitaiClient.name, format=format
)
cached = get_first_exists(cache_paths)
if cached:
return cached
source = self.root % (remove_prefix(source, CivitaiClient.protocol))
logger.info("downloading model from Civitai: %s -> %s", source, cache_paths[0])
return download_progress(source, cache_paths[0])

View File

@ -0,0 +1,20 @@
from .base import BaseClient
from logging import getLogger
from os import path
from urllib.parse import urlparse
logger = getLogger(__name__)
class FileClient(BaseClient):
protocol = "file://"
root: str
def __init__(self, root: str):
self.root = root
def download(self, uri: str) -> str:
parts = urlparse(uri)
logger.info("loading model from: %s", parts.path)
return path.join(self.root, parts.path)

View File

@ -0,0 +1,39 @@
from ..utils import (
ConversionContext,
build_cache_paths,
download_progress,
get_first_exists,
remove_prefix,
)
from .base import BaseClient
from typing import Dict, Optional
from logging import getLogger
logger = getLogger(__name__)
class HttpClient(BaseClient):
name = "http"
protocol = "https://"
insecure_protocol = "http://"
headers: Dict[str, str]
def __init__(self, headers: Optional[Dict[str, str]] = None):
self.headers = headers or {}
def download(self, conversion: ConversionContext, name: str, uri: str) -> str:
cache_paths = build_cache_paths(
conversion, name, client=HttpClient.name, format=format
)
cached = get_first_exists(cache_paths)
if cached:
return cached
if uri.startswith(HttpClient.protocol):
source = remove_prefix(uri, HttpClient.protocol)
logger.info("downloading model from: %s", source)
elif uri.startswith(HttpClient.insecure_protocol):
logger.warning("downloading model from insecure source: %s", source)
return download_progress(source, cache_paths[0])

View File

@ -0,0 +1,61 @@
from ..utils import (
ConversionContext,
build_cache_paths,
get_first_exists,
remove_prefix,
)
from .base import BaseClient
from typing import Optional, Any
from logging import getLogger
from huggingface_hub.file_download import hf_hub_download
logger = getLogger(__name__)
class HuggingfaceClient(BaseClient):
name = "huggingface"
protocol = "huggingface://"
download: Any
token: Optional[str]
def __init__(self, token: Optional[str] = None, download=hf_hub_download):
self.download = download
self.token = token
def download(
self,
conversion: ConversionContext,
name: str,
source: str,
format: Optional[str],
) -> str:
"""
TODO: download with auth
"""
hf_hub_fetch = TODO
hf_hub_filename = TODO
cache_paths = build_cache_paths(
conversion, name, client=HuggingfaceClient.name, format=format
)
cached = get_first_exists(cache_paths)
if cached:
return cached
source = remove_prefix(source, HuggingfaceClient.protocol)
logger.info("downloading model from Huggingface Hub: %s", source)
# from_pretrained has a bunch of useful logic that snapshot_download by itself down not
if hf_hub_fetch:
return (
hf_hub_download(
repo_id=source,
filename=hf_hub_filename,
cache_dir=cache_paths[0],
force_filename=f"{name}.bin",
),
False,
)
else:
return source

View File

@ -15,6 +15,8 @@ from onnxruntime.transformers.float16 import convert_float_to_float16
from packaging import version
from torch.onnx import export
from onnx_web.convert.client.file import FileClient
from ..constants import ONNX_WEIGHTS
from ..errors import RequestException
from ..server import ServerContext
@ -85,38 +87,37 @@ class ConversionContext(ServerContext):
return torch.device(self.training_device)
def download_progress(urls: List[Tuple[str, str]]):
for url, dest in urls:
dest_path = Path(dest).expanduser().resolve()
dest_path.parent.mkdir(parents=True, exist_ok=True)
if dest_path.exists():
logger.debug("destination already exists: %s", dest_path)
return str(dest_path.absolute())
req = requests.get(
url,
stream=True,
allow_redirects=True,
headers={
"User-Agent": "onnx-web-api",
},
)
if req.status_code != 200:
req.raise_for_status() # Only works for 4xx errors, per SO answer
raise RequestException(
"request to %s failed with status code: %s" % (url, req.status_code)
)
total = int(req.headers.get("Content-Length", 0))
desc = "unknown" if total == 0 else ""
req.raw.read = partial(req.raw.read, decode_content=True)
with tqdm.wrapattr(req.raw, "read", total=total, desc=desc) as data:
with dest_path.open("wb") as f:
shutil.copyfileobj(data, f)
def download_progress(source: str, dest: str):
dest_path = Path(dest).expanduser().resolve()
dest_path.parent.mkdir(parents=True, exist_ok=True)
if dest_path.exists():
logger.debug("destination already exists: %s", dest_path)
return str(dest_path.absolute())
req = requests.get(
source,
stream=True,
allow_redirects=True,
headers={
"User-Agent": "onnx-web-api",
},
)
if req.status_code != 200:
req.raise_for_status() # Only works for 4xx errors, per SO answer
raise RequestException(
"request to %s failed with status code: %s" % (source, req.status_code)
)
total = int(req.headers.get("Content-Length", 0))
desc = "unknown" if total == 0 else ""
req.raw.read = partial(req.raw.read, decode_content=True)
with tqdm.wrapattr(req.raw, "read", total=total, desc=desc) as data:
with dest_path.open("wb") as f:
shutil.copyfileobj(data, f)
return str(dest_path.absolute())
def tuple_to_source(model: Union[ModelDict, LegacyModel]):
if isinstance(model, list) or isinstance(model, tuple):
@ -347,7 +348,13 @@ def onnx_export(
)
DIFFUSION_PREFIX = ["diffusion-", "diffusion/", "diffusion\\", "stable-diffusion-", "upscaling-"]
DIFFUSION_PREFIX = [
"diffusion-",
"diffusion/",
"diffusion\\",
"stable-diffusion-",
"upscaling-",
]
def fix_diffusion_name(name: str):
@ -359,3 +366,61 @@ def fix_diffusion_name(name: str):
return f"diffusion-{name}"
return name
def fetch_model(
conversion: ConversionContext,
name: str,
source: str,
format: Optional[str],
) -> Tuple[str, bool]:
# TODO: switch to urlparse's default scheme
if source.startswith(path.sep) or source.startswith("."):
logger.info("adding file protocol to local path source: %s", source)
source = FileClient.protocol + source
for proto, client_type in model_sources.items():
if source.startswith(proto):
client = client_type()
return client.download(proto)
logger.warning("unknown model protocol, using path as provided: %s", source)
return source, False
def build_cache_paths(
conversion: ConversionContext,
name: str,
client: Optional[str] = None,
dest: Optional[str] = None,
format: Optional[str] = None,
) -> List[str]:
cache_path = dest or conversion.cache_path
# add an extension if possible, some of the conversion code checks for it
if format is not None:
basename = path.basename(name)
_filename, ext = path.splitext(basename)
if ext is None:
name = f"{name}.{format}"
paths = [
path.join(cache_path, name),
]
if client is not None:
client_path = path.join(cache_path, client)
paths.append(path.join(client_path, name))
return paths
def get_first_exists(
paths: List[str],
) -> Optional[str]:
for name in paths:
if path.exists(name):
logger.debug("model already exists in cache, skipping fetch: %s", name)
return name
return None

View File

@ -2,18 +2,24 @@ from importlib import import_module
from logging import getLogger
from typing import Any, Callable, Dict
from onnx_web.chain.stages import add_stage
from onnx_web.diffusers.load import add_pipeline
from onnx_web.server.context import ServerContext
from ..chain.stages import add_stage
from ..diffusers.load import add_pipeline
from ..server.context import ServerContext
logger = getLogger(__name__)
class PluginExports:
clients: Dict[str, Any]
converter: Dict[str, Any]
pipelines: Dict[str, Any]
stages: Dict[str, Any]
def __init__(self, pipelines=None, stages=None) -> None:
def __init__(
self, clients=None, converter=None, pipelines=None, stages=None
) -> None:
self.clients = clients or {}
self.converter = converter or {}
self.pipelines = pipelines or {}
self.stages = stages or {}

View File

@ -27,7 +27,7 @@ class ConversionContextTests(unittest.TestCase):
class DownloadProgressTests(unittest.TestCase):
def test_download_example(self):
path = download_progress([("https://example.com", "/tmp/example-dot-com")])
path = download_progress("https://example.com", "/tmp/example-dot-com")
self.assertEqual(path, "/tmp/example-dot-com")