split up download clients, add to plugin exports
This commit is contained in:
parent
e7da2cf8a6
commit
cc8e564d26
@@ -3,16 +3,21 @@ from argparse import ArgumentParser
 from logging import getLogger
 from os import makedirs, path
 from sys import exit
-from typing import Any, Dict, List, Optional, Tuple
-from urllib.parse import urlparse
+from typing import Any, Dict, List, Union
 
-from huggingface_hub.file_download import hf_hub_download
 from jsonschema import ValidationError, validate
 from onnx import load_model, save_model
 from transformers import CLIPTokenizer
 
+from onnx_web.server.plugin import load_plugins
+
 from ..constants import ONNX_MODEL, ONNX_WEIGHTS
 from ..utils import load_config
+from .client.base import BaseClient
+from .client.civitai import CivitaiClient
+from .client.file import FileClient
+from .client.http import HttpClient
+from .client.huggingface import HuggingfaceClient
 from .correction.gfpgan import convert_correction_gfpgan
 from .diffusion.control import convert_diffusion_control
 from .diffusion.diffusion import convert_diffusion_diffusers
@@ -25,9 +30,8 @@ from .upscaling.swinir import convert_upscaling_swinir
 from .utils import (
     DEFAULT_OPSET,
     ConversionContext,
-    download_progress,
+    fetch_model,
     fix_diffusion_name,
-    remove_prefix,
     source_format,
     tuple_to_correction,
     tuple_to_diffusion,
@@ -45,16 +49,27 @@ warnings.filterwarnings(
     ".*Converting a tensor to a Python boolean might cause the trace to be incorrect.*",
 )
 
-Models = Dict[str, List[Any]]
-
 logger = getLogger(__name__)
 
+ModelDict = Dict[str, Union[float, int, str]]
+Models = Dict[str, List[ModelDict]]
+
-model_sources: Dict[str, Tuple[str, str]] = {
-    "civitai://": ("Civitai", "https://civitai.com/api/download/models/%s"),
+model_sources: Dict[str, BaseClient] = {
+    CivitaiClient.protocol: CivitaiClient,
+    FileClient.protocol: FileClient,
+    HttpClient.insecure_protocol: HttpClient,
+    HttpClient.protocol: HttpClient,
+    HuggingfaceClient.protocol: HuggingfaceClient,
 }
 
-model_source_huggingface = "huggingface://"
+model_converters: Dict[str, Any] = {
+    "img2img": convert_diffusion_diffusers,
+    "img2img-sdxl": convert_diffusion_diffusers_xl,
+    "inpaint": convert_diffusion_diffusers,
+    "txt2img": convert_diffusion_diffusers,
+    "txt2img-sdxl": convert_diffusion_diffusers_xl,
+}
 
 # recommended models
 base_models: Models = {
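model_sources now maps protocol prefixes to client classes instead of (name, URL template) tuples. A minimal standalone sketch of the prefix lookup, using stand-in classes rather than the real imports:

from typing import Dict, Type

# stand-ins for the real client classes; only the protocol attributes matter here
class CivitaiClient:
    protocol = "civitai://"

class HttpClient:
    protocol = "https://"
    insecure_protocol = "http://"

model_sources: Dict[str, Type] = {
    CivitaiClient.protocol: CivitaiClient,
    HttpClient.protocol: HttpClient,
    HttpClient.insecure_protocol: HttpClient,
}

def resolve_client(source: str):
    # return the first client whose protocol prefix matches the source
    for proto, client in model_sources.items():
        if source.startswith(proto):
            return client
    return None

print(resolve_client("civitai://12345"))  # <class '__main__.CivitaiClient'>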
@@ -62,15 +77,15 @@ base_models: Models = {
     # v1.x
     (
         "stable-diffusion-onnx-v1-5",
-        model_source_huggingface + "runwayml/stable-diffusion-v1-5",
+        HuggingfaceClient.protocol + "runwayml/stable-diffusion-v1-5",
     ),
     (
         "stable-diffusion-onnx-v1-inpainting",
-        model_source_huggingface + "runwayml/stable-diffusion-inpainting",
+        HuggingfaceClient.protocol + "runwayml/stable-diffusion-inpainting",
     ),
     (
         "upscaling-stable-diffusion-x4",
-        model_source_huggingface + "stabilityai/stable-diffusion-x4-upscaler",
+        HuggingfaceClient.protocol + "stabilityai/stable-diffusion-x4-upscaler",
         True,
     ),
 ],
@@ -201,69 +216,226 @@ base_models: Models = {
 }
 
 
-def fetch_model(
-    conversion: ConversionContext,
-    name: str,
-    source: str,
-    dest: Optional[str] = None,
-    format: Optional[str] = None,
-    hf_hub_fetch: bool = False,
-    hf_hub_filename: Optional[str] = None,
-) -> Tuple[str, bool]:
-    cache_path = dest or conversion.cache_path
-    cache_name = path.join(cache_path, name)
-
-    # add an extension if possible, some of the conversion code checks for it
-    if format is None:
-        url = urlparse(source)
-        ext = path.basename(url.path)
-        _filename, ext = path.splitext(ext)
-        if ext is not None:
-            cache_name = cache_name + ext
-    else:
-        cache_name = f"{cache_name}.{format}"
-
-    if path.exists(cache_name):
-        logger.debug("model already exists in cache, skipping fetch")
-        return cache_name, False
-
-    for proto in model_sources:
-        api_name, api_root = model_sources.get(proto)
-        if source.startswith(proto):
-            api_source = api_root % (remove_prefix(source, proto))
-            logger.info(
-                "downloading model from %s: %s -> %s", api_name, api_source, cache_name
-            )
-            return download_progress([(api_source, cache_name)]), False
-
-    if source.startswith(model_source_huggingface):
-        hub_source = remove_prefix(source, model_source_huggingface)
-        logger.info("downloading model from Huggingface Hub: %s", hub_source)
-        # from_pretrained has a bunch of useful logic that snapshot_download by itself does not
-        if hf_hub_fetch:
-            return (
-                hf_hub_download(
-                    repo_id=hub_source,
-                    filename=hf_hub_filename,
-                    cache_dir=cache_path,
-                    force_filename=f"{name}.bin",
-                ),
-                False,
-            )
-        else:
-            return hub_source, True
-    elif source.startswith("https://"):
-        logger.info("downloading model from: %s", source)
-        return download_progress([(source, cache_name)]), False
-    elif source.startswith("http://"):
-        logger.warning("downloading model from insecure source: %s", source)
-        return download_progress([(source, cache_name)]), False
-    elif source.startswith(path.sep) or source.startswith("."):
-        logger.info("using local model: %s", source)
-        return source, False
-    else:
-        logger.info("unknown model location, using path as provided: %s", source)
-        return source, False
+def add_model_source(proto: str, client: BaseClient):
+    global model_sources
+
+    if proto in model_sources:
+        raise ValueError("protocol has already been taken")
+
+    model_sources[proto] = client
+
+
+def convert_source_model(conversion: ConversionContext, model):
+    model_format = source_format(model)
+    name = model["name"]
+    source = model["source"]
+
+    dest_path = None
+    if "dest" in model:
+        dest_path = path.join(conversion.model_path, model["dest"])
+
+    dest, hf = fetch_model(
+        conversion, name, source, format=model_format, dest=dest_path
+    )
+    logger.info("finished downloading source: %s -> %s", source, dest)
+
+
+def convert_network_model(conversion: ConversionContext, network):
+    network_format = source_format(network)
+    network_model = network.get("model", None)
+    name = network["name"]
+    network_type = network["type"]
+    source = network["source"]
+
+    if network_type == "control":
+        dest, hf = fetch_model(
+            conversion,
+            name,
+            source,
+            format=network_format,
+        )
+
+        convert_diffusion_control(
+            conversion,
+            network,
+            dest,
+            path.join(conversion.model_path, network_type, name),
+        )
+    if network_type == "inversion" and network_model == "concept":
+        dest, hf = fetch_model(
+            conversion,
+            name,
+            source,
+            dest=path.join(conversion.model_path, network_type),
+            format=network_format,
+            hf_hub_fetch=True,
+            hf_hub_filename="learned_embeds.bin",
+        )
+    else:
+        dest, hf = fetch_model(
+            conversion,
+            name,
+            source,
+            dest=path.join(conversion.model_path, network_type),
+            format=network_format,
+        )
+
+    logger.info("finished downloading network: %s -> %s", source, dest)
+
+
+def convert_diffusion_model(conversion: ConversionContext, model):
+    # fix up entries with missing prefixes
+    name = fix_diffusion_name(model["name"])
+    if name != model["name"]:
+        # update the model in-memory if the name changed
+        model["name"] = name
+
+    model_format = source_format(model)
+    source, hf = fetch_model(conversion, name, model["source"], format=model_format)
+
+    pipeline = model.get("pipeline", "txt2img")
+    converter = model_converters.get(pipeline)
+    converted, dest = converter(
+        conversion,
+        model,
+        source,
+        model_format,
+        hf=hf,
+    )
+
+    # make sure blending only happens once, not every run
+    if converted:
+        # keep track of which models have been blended
+        blend_models = {}
+
+        inversion_dest = path.join(conversion.model_path, "inversion")
+        lora_dest = path.join(conversion.model_path, "lora")
+
+        for inversion in model.get("inversions", []):
+            if "text_encoder" not in blend_models:
+                blend_models["text_encoder"] = load_model(
+                    path.join(
+                        dest,
+                        "text_encoder",
+                        ONNX_MODEL,
+                    )
+                )
+
+            if "tokenizer" not in blend_models:
+                blend_models["tokenizer"] = CLIPTokenizer.from_pretrained(
+                    dest,
+                    subfolder="tokenizer",
+                )
+
+            inversion_name = inversion["name"]
+            inversion_source = inversion["source"]
+            inversion_format = inversion.get("format", None)
+            inversion_source, hf = fetch_model(
+                conversion,
+                inversion_name,
+                inversion_source,
+                dest=inversion_dest,
+            )
+            inversion_token = inversion.get("token", inversion_name)
+            inversion_weight = inversion.get("weight", 1.0)
+
+            blend_textual_inversions(
+                conversion,
+                blend_models["text_encoder"],
+                blend_models["tokenizer"],
+                [
+                    (
+                        inversion_source,
+                        inversion_weight,
+                        inversion_token,
+                        inversion_format,
+                    )
+                ],
+            )
+
+        for lora in model.get("loras", []):
+            if "text_encoder" not in blend_models:
+                blend_models["text_encoder"] = load_model(
+                    path.join(
+                        dest,
+                        "text_encoder",
+                        ONNX_MODEL,
+                    )
+                )
+
+            if "unet" not in blend_models:
+                blend_models["unet"] = load_model(path.join(dest, "unet", ONNX_MODEL))
+
+            # load models if not loaded yet
+            lora_name = lora["name"]
+            lora_source = lora["source"]
+            lora_source, hf = fetch_model(
+                conversion,
+                f"{name}-lora-{lora_name}",
+                lora_source,
+                dest=lora_dest,
+            )
+            lora_weight = lora.get("weight", 1.0)
+
+            blend_loras(
+                conversion,
+                blend_models["text_encoder"],
+                [(lora_source, lora_weight)],
+                "text_encoder",
+            )
+
+            blend_loras(
+                conversion,
+                blend_models["unet"],
+                [(lora_source, lora_weight)],
+                "unet",
+            )
+
+        if "tokenizer" in blend_models:
+            dest_path = path.join(dest, "tokenizer")
+            logger.debug("saving blended tokenizer to %s", dest_path)
+            blend_models["tokenizer"].save_pretrained(dest_path)
+
+        for name in ["text_encoder", "unet"]:
+            if name in blend_models:
+                dest_path = path.join(dest, name, ONNX_MODEL)
+                logger.debug("saving blended %s model to %s", name, dest_path)
+                save_model(
+                    blend_models[name],
+                    dest_path,
+                    save_as_external_data=True,
+                    all_tensors_to_one_file=True,
+                    location=ONNX_WEIGHTS,
+                )
+
+
+def convert_upscaling_model(conversion: ConversionContext, model):
+    model_format = source_format(model)
+    name = model["name"]
+
+    source, hf = fetch_model(conversion, name, model["source"], format=model_format)
+    model_type = model.get("model", "resrgan")
+    if model_type == "bsrgan":
+        convert_upscaling_bsrgan(conversion, model, source)
+    elif model_type == "resrgan":
+        convert_upscale_resrgan(conversion, model, source)
+    elif model_type == "swinir":
+        convert_upscaling_swinir(conversion, model, source)
+    else:
+        logger.error("unknown upscaling model type %s for %s", model_type, name)
+        model_errors.append(name)
+
+
+def convert_correction_model(conversion: ConversionContext, model):
+    model_format = source_format(model)
+    name = model["name"]
+    source, hf = fetch_model(conversion, name, model["source"], format=model_format)
+    model_type = model.get("model", "gfpgan")
+    if model_type == "gfpgan":
+        convert_correction_gfpgan(conversion, model, source)
+    else:
+        logger.error("unknown correction model type %s for %s", model_type, name)
+        model_errors.append(name)
+
+
 def convert_models(conversion: ConversionContext, args, models: Models):
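convert_diffusion_model replaces the old pipeline.endswith("-sdxl") branch with a lookup in model_converters. Note that dict.get returns None for an unrecognized pipeline, which would fail with a TypeError at the call site; a standalone sketch of the dispatch, using stub converters and an explicit guard:

from typing import Any, Callable, Dict, Tuple

# stub converters standing in for convert_diffusion_diffusers and the -xl variant
def convert_sd(model: Dict[str, Any]) -> Tuple[bool, str]:
    return True, f"/models/{model['name']}"

def convert_sdxl(model: Dict[str, Any]) -> Tuple[bool, str]:
    return True, f"/models/{model['name']}-xl"

model_converters: Dict[str, Callable[..., Tuple[bool, str]]] = {
    "txt2img": convert_sd,
    "img2img": convert_sd,
    "inpaint": convert_sd,
    "txt2img-sdxl": convert_sdxl,
    "img2img-sdxl": convert_sdxl,
}

model = {"name": "example", "pipeline": "txt2img-sdxl"}
converter = model_converters.get(model.get("pipeline", "txt2img"))
if converter is None:
    raise ValueError("no converter for pipeline")
converted, dest = converter(model)
print(converted, dest)  # True /models/example-xl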
@@ -277,18 +449,8 @@ def convert_models(conversion: ConversionContext, args, models: Models):
         if name in args.skip:
             logger.info("skipping source: %s", name)
         else:
-            model_format = source_format(model)
-            source = model["source"]
-
             try:
-                dest_path = None
-                if "dest" in model:
-                    dest_path = path.join(conversion.model_path, model["dest"])
-
-                dest, hf = fetch_model(
-                    conversion, name, source, format=model_format, dest=dest_path
-                )
-                logger.info("finished downloading source: %s -> %s", source, dest)
+                convert_source_model(conversion, model)
             except Exception:
                 logger.exception("error fetching source %s", name)
                 model_errors.append(name)
@@ -300,46 +462,8 @@ def convert_models(conversion: ConversionContext, args, models: Models):
         if name in args.skip:
             logger.info("skipping network: %s", name)
         else:
-            network_format = source_format(network)
-            network_model = network.get("model", None)
-            network_type = network["type"]
-            source = network["source"]
-
             try:
-                if network_type == "control":
-                    dest, hf = fetch_model(
-                        conversion,
-                        name,
-                        source,
-                        format=network_format,
-                    )
-
-                    convert_diffusion_control(
-                        conversion,
-                        network,
-                        dest,
-                        path.join(conversion.model_path, network_type, name),
-                    )
-                if network_type == "inversion" and network_model == "concept":
-                    dest, hf = fetch_model(
-                        conversion,
-                        name,
-                        source,
-                        dest=path.join(conversion.model_path, network_type),
-                        format=network_format,
-                        hf_hub_fetch=True,
-                        hf_hub_filename="learned_embeds.bin",
-                    )
-                else:
-                    dest, hf = fetch_model(
-                        conversion,
-                        name,
-                        source,
-                        dest=path.join(conversion.model_path, network_type),
-                        format=network_format,
-                    )
-
-                logger.info("finished downloading network: %s -> %s", source, dest)
+                convert_network_model(conversion, network)
             except Exception:
                 logger.exception("error fetching network %s", name)
                 model_errors.append(name)
@@ -352,148 +476,8 @@ def convert_models(conversion: ConversionContext, args, models: Models):
         if name in args.skip:
             logger.info("skipping model: %s", name)
         else:
-            # fix up entries with missing prefixes
-            name = fix_diffusion_name(name)
-            if name != model["name"]:
-                # update the model in-memory if the name changed
-                model["name"] = name
-
-            model_format = source_format(model)
-
             try:
-                source, hf = fetch_model(
-                    conversion, name, model["source"], format=model_format
-                )
-
-                pipeline = model.get("pipeline", "txt2img")
-                if pipeline.endswith("-sdxl"):
-                    converted, dest = convert_diffusion_diffusers_xl(
-                        conversion,
-                        model,
-                        source,
-                        model_format,
-                        hf=hf,
-                    )
-                else:
-                    converted, dest = convert_diffusion_diffusers(
-                        conversion,
-                        model,
-                        source,
-                        model_format,
-                        hf=hf,
-                    )
-
-                # make sure blending only happens once, not every run
-                if converted:
-                    # keep track of which models have been blended
-                    blend_models = {}
-
-                    inversion_dest = path.join(conversion.model_path, "inversion")
-                    lora_dest = path.join(conversion.model_path, "lora")
-
-                    for inversion in model.get("inversions", []):
-                        if "text_encoder" not in blend_models:
-                            blend_models["text_encoder"] = load_model(
-                                path.join(
-                                    dest,
-                                    "text_encoder",
-                                    ONNX_MODEL,
-                                )
-                            )
-
-                        if "tokenizer" not in blend_models:
-                            blend_models[
-                                "tokenizer"
-                            ] = CLIPTokenizer.from_pretrained(
-                                dest,
-                                subfolder="tokenizer",
-                            )
-
-                        inversion_name = inversion["name"]
-                        inversion_source = inversion["source"]
-                        inversion_format = inversion.get("format", None)
-                        inversion_source, hf = fetch_model(
-                            conversion,
-                            inversion_name,
-                            inversion_source,
-                            dest=inversion_dest,
-                        )
-                        inversion_token = inversion.get("token", inversion_name)
-                        inversion_weight = inversion.get("weight", 1.0)
-
-                        blend_textual_inversions(
-                            conversion,
-                            blend_models["text_encoder"],
-                            blend_models["tokenizer"],
-                            [
-                                (
-                                    inversion_source,
-                                    inversion_weight,
-                                    inversion_token,
-                                    inversion_format,
-                                )
-                            ],
-                        )
-
-                    for lora in model.get("loras", []):
-                        if "text_encoder" not in blend_models:
-                            blend_models["text_encoder"] = load_model(
-                                path.join(
-                                    dest,
-                                    "text_encoder",
-                                    ONNX_MODEL,
-                                )
-                            )
-
-                        if "unet" not in blend_models:
-                            blend_models["unet"] = load_model(
-                                path.join(dest, "unet", ONNX_MODEL)
-                            )
-
-                        # load models if not loaded yet
-                        lora_name = lora["name"]
-                        lora_source = lora["source"]
-                        lora_source, hf = fetch_model(
-                            conversion,
-                            f"{name}-lora-{lora_name}",
-                            lora_source,
-                            dest=lora_dest,
-                        )
-                        lora_weight = lora.get("weight", 1.0)
-
-                        blend_loras(
-                            conversion,
-                            blend_models["text_encoder"],
-                            [(lora_source, lora_weight)],
-                            "text_encoder",
-                        )
-
-                        blend_loras(
-                            conversion,
-                            blend_models["unet"],
-                            [(lora_source, lora_weight)],
-                            "unet",
-                        )
-
-                    if "tokenizer" in blend_models:
-                        dest_path = path.join(dest, "tokenizer")
-                        logger.debug("saving blended tokenizer to %s", dest_path)
-                        blend_models["tokenizer"].save_pretrained(dest_path)
-
-                    for name in ["text_encoder", "unet"]:
-                        if name in blend_models:
-                            dest_path = path.join(dest, name, ONNX_MODEL)
-                            logger.debug(
-                                "saving blended %s model to %s", name, dest_path
-                            )
-                            save_model(
-                                blend_models[name],
-                                dest_path,
-                                save_as_external_data=True,
-                                all_tensors_to_one_file=True,
-                                location=ONNX_WEIGHTS,
-                            )
-
+                convert_diffusion_model(conversion, model)
             except Exception:
                 logger.exception(
                     "error converting diffusion model %s",
@@ -509,24 +493,8 @@ def convert_models(conversion: ConversionContext, args, models: Models):
         if name in args.skip:
             logger.info("skipping model: %s", name)
         else:
-            model_format = source_format(model)
-
             try:
-                source, hf = fetch_model(
-                    conversion, name, model["source"], format=model_format
-                )
-                model_type = model.get("model", "resrgan")
-                if model_type == "bsrgan":
-                    convert_upscaling_bsrgan(conversion, model, source)
-                elif model_type == "resrgan":
-                    convert_upscale_resrgan(conversion, model, source)
-                elif model_type == "swinir":
-                    convert_upscaling_swinir(conversion, model, source)
-                else:
-                    logger.error(
-                        "unknown upscaling model type %s for %s", model_type, name
-                    )
-                    model_errors.append(name)
+                convert_upscaling_model(conversion, model)
             except Exception:
                 logger.exception(
                     "error converting upscaling model %s",
@@ -542,19 +510,8 @@ def convert_models(conversion: ConversionContext, args, models: Models):
         if name in args.skip:
             logger.info("skipping model: %s", name)
         else:
-            model_format = source_format(model)
             try:
-                source, hf = fetch_model(
-                    conversion, name, model["source"], format=model_format
-                )
-                model_type = model.get("model", "gfpgan")
-                if model_type == "gfpgan":
-                    convert_correction_gfpgan(conversion, model, source)
-                else:
-                    logger.error(
-                        "unknown correction model type %s for %s", model_type, name
-                    )
-                    model_errors.append(name)
+                convert_correction_model(conversion, model)
             except Exception:
                 logger.exception(
                     "error converting correction model %s",
@@ -566,12 +523,26 @@ def convert_models(conversion: ConversionContext, args, models: Models):
         logger.error("error while converting models: %s", model_errors)
 
 
+def register_plugins(conversion: ConversionContext):
+    logger.info("loading conversion plugins")
+    exports = load_plugins(conversion)
+
+    for proto, client in exports.clients.items():
+        try:
+            add_model_source(proto, client)
+        except Exception:
+            logger.exception("error loading client for protocol: %s", proto)
+
+    # TODO: add converters
+
+
 def main(args=None) -> int:
     parser = ArgumentParser(
         prog="onnx-web model converter", description="convert checkpoint models to ONNX"
     )
 
     # model groups
     parser.add_argument("--base", action="store_true", default=True)
     parser.add_argument("--networks", action="store_true", default=True)
     parser.add_argument("--sources", action="store_true", default=True)
     parser.add_argument("--correction", action="store_true", default=False)
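A standalone sketch of the plugin registration flow that register_plugins drives, with a hypothetical S3Client standing in for a plugin-provided download client:

from typing import Dict, Type

model_sources: Dict[str, Type] = {}

def add_model_source(proto: str, client: Type) -> None:
    if proto in model_sources:
        raise ValueError("protocol has already been taken")
    model_sources[proto] = client

# hypothetical client contributed by a plugin
class S3Client:
    protocol = "s3://"

# exports.clients is a Dict[str, Any]; .items() yields the (proto, client) pairs
plugin_clients = {S3Client.protocol: S3Client}
for proto, client in plugin_clients.items():
    add_model_source(proto, client)

print(model_sources)  # {'s3://': <class '__main__.S3Client'>}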
@@ -609,16 +580,20 @@ def main(args=None) -> int:
     server.half = args.half or server.has_optimization("onnx-fp16")
     server.opset = args.opset
     server.token = args.token
 
+    register_plugins(server)
+
     logger.info(
-        "converting models in %s using %s", server.model_path, server.training_device
+        "converting models into %s using %s", server.model_path, server.training_device
     )
 
     if not path.exists(server.model_path):
         logger.info("model path does not exist, creating: %s", server.model_path)
         makedirs(server.model_path)
 
-    logger.info("converting base models")
-    convert_models(server, args, base_models)
+    if args.base:
+        logger.info("converting base models")
+        convert_models(server, args, base_models)
 
     extras = []
     extras.extend(server.extra_models)
@@ -0,0 +1,13 @@
+from typing import Optional
+from ..utils import ConversionContext
+
+
+class BaseClient:
+    def download(
+        self,
+        conversion: ConversionContext,
+        name: str,
+        source: str,
+        format: Optional[str],
+    ) -> str:
+        raise NotImplementedError()
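A sketch of what a client built on this interface might look like; the EchoClient name and behavior are illustrative only, with a local stand-in for ConversionContext so the snippet runs on its own:

from typing import Optional

class ConversionContext:  # stand-in; the real class lives in onnx_web.convert.utils
    cache_path = "/tmp/cache"

class BaseClient:
    def download(
        self,
        conversion: ConversionContext,
        name: str,
        source: str,
        format: Optional[str],
    ) -> str:
        raise NotImplementedError()

class EchoClient(BaseClient):
    protocol = "echo://"

    def download(self, conversion, name, source, format) -> str:
        # strip the protocol prefix and hand back the remainder as a local path
        return source[len(EchoClient.protocol):]

print(EchoClient().download(ConversionContext(), "model", "echo:///models/foo.safetensors", None))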
@@ -0,0 +1,45 @@
+from ..utils import (
+    ConversionContext,
+    build_cache_paths,
+    download_progress,
+    get_first_exists,
+    remove_prefix,
+)
+from .base import BaseClient
+from typing import Optional
+from logging import getLogger
+
+logger = getLogger(__name__)
+
+CIVITAI_ROOT = "https://civitai.com/api/download/models/%s"
+
+
+class CivitaiClient(BaseClient):
+    protocol = "civitai://"
+    root: str
+    token: Optional[str]
+
+    def __init__(self, token: Optional[str] = None, root=CIVITAI_ROOT):
+        self.root = root
+        self.token = token
+
+    def download(
+        self,
+        conversion: ConversionContext,
+        name: str,
+        source: str,
+        format: Optional[str],
+    ) -> str:
+        """
+        TODO: download with auth token
+        """
+        cache_paths = build_cache_paths(
+            conversion, name, client=CivitaiClient.name, format=format
+        )
+        cached = get_first_exists(cache_paths)
+        if cached:
+            return cached
+
+        source = self.root % (remove_prefix(source, CivitaiClient.protocol))
+        logger.info("downloading model from Civitai: %s -> %s", source, cache_paths[0])
+        return download_progress(source, cache_paths[0])
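The docstring leaves token auth as a TODO, and the cache lookup references CivitaiClient.name, which the class does not define yet. A sketch of how a download token could be appended to the API URL; the token query parameter is an assumption, not confirmed API behavior:

from typing import Optional
from urllib.parse import urlencode

CIVITAI_ROOT = "https://civitai.com/api/download/models/%s"

def civitai_url(model_id: str, token: Optional[str] = None) -> str:
    url = CIVITAI_ROOT % model_id
    if token:
        # assumption: the API accepts the token as a query parameter
        url = f"{url}?{urlencode({'token': token})}"
    return url

print(civitai_url("12345", token="example-token"))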
@@ -0,0 +1,20 @@
+from .base import BaseClient
+from logging import getLogger
+from os import path
+from urllib.parse import urlparse
+
+logger = getLogger(__name__)
+
+
+class FileClient(BaseClient):
+    protocol = "file://"
+
+    root: str
+
+    def __init__(self, root: str):
+        self.root = root
+
+    def download(self, uri: str) -> str:
+        parts = urlparse(uri)
+        logger.info("loading model from: %s", parts.path)
+        return path.join(self.root, parts.path)
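FileClient.download takes only a uri, which diverges from the BaseClient signature the other clients use. A conforming sketch, assuming the extra arguments are simply unused:

from os import path
from typing import Optional
from urllib.parse import urlparse

class FileClient:
    protocol = "file://"

    def __init__(self, root: str = "."):
        self.root = root

    # same (conversion, name, source, format) shape as BaseClient.download
    def download(self, conversion, name: str, source: str, format: Optional[str]) -> str:
        parts = urlparse(source)
        return path.join(self.root, parts.path)

print(FileClient("/models").download(None, "foo", "file:///checkpoints/foo.ckpt", None))
# '/checkpoints/foo.ckpt' (os.path.join lets an absolute URI path override root)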
@@ -0,0 +1,39 @@
+from ..utils import (
+    ConversionContext,
+    build_cache_paths,
+    download_progress,
+    get_first_exists,
+    remove_prefix,
+)
+from .base import BaseClient
+from typing import Dict, Optional
+from logging import getLogger
+
+logger = getLogger(__name__)
+
+
+class HttpClient(BaseClient):
+    name = "http"
+    protocol = "https://"
+    insecure_protocol = "http://"
+
+    headers: Dict[str, str]
+
+    def __init__(self, headers: Optional[Dict[str, str]] = None):
+        self.headers = headers or {}
+
+    def download(self, conversion: ConversionContext, name: str, uri: str) -> str:
+        cache_paths = build_cache_paths(
+            conversion, name, client=HttpClient.name, format=format
+        )
+        cached = get_first_exists(cache_paths)
+        if cached:
+            return cached
+
+        if uri.startswith(HttpClient.protocol):
+            source = remove_prefix(uri, HttpClient.protocol)
+            logger.info("downloading model from: %s", source)
+        elif uri.startswith(HttpClient.insecure_protocol):
+            logger.warning("downloading model from insecure source: %s", source)
+
+        return download_progress(source, cache_paths[0])
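As committed, download references the format builtin where a parameter seems intended and, in the insecure branch, reads source before it is assigned. A sketch of the presumable intent, with the assumptions marked inline:

from logging import getLogger
from typing import Dict, Optional

logger = getLogger(__name__)

class HttpClient:
    name = "http"
    protocol = "https://"
    insecure_protocol = "http://"

    def __init__(self, headers: Optional[Dict[str, str]] = None):
        self.headers = headers or {}

    # assumptions: format becomes a parameter, and the full URI is what gets fetched
    def download(self, conversion, name: str, uri: str, format: Optional[str] = None) -> str:
        if uri.startswith(HttpClient.protocol):
            logger.info("downloading model from: %s", uri)
        elif uri.startswith(HttpClient.insecure_protocol):
            logger.warning("downloading model from insecure source: %s", uri)

        # stand-in for the build_cache_paths + download_progress pair
        return f"/tmp/cache/{name}"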
@@ -0,0 +1,61 @@
+from ..utils import (
+    ConversionContext,
+    build_cache_paths,
+    get_first_exists,
+    remove_prefix,
+)
+from .base import BaseClient
+from typing import Optional, Any
+from logging import getLogger
+from huggingface_hub.file_download import hf_hub_download
+
+logger = getLogger(__name__)
+
+
+class HuggingfaceClient(BaseClient):
+    name = "huggingface"
+    protocol = "huggingface://"
+
+    download: Any
+    token: Optional[str]
+
+    def __init__(self, token: Optional[str] = None, download=hf_hub_download):
+        self.download = download
+        self.token = token
+
+    def download(
+        self,
+        conversion: ConversionContext,
+        name: str,
+        source: str,
+        format: Optional[str],
+    ) -> str:
+        """
+        TODO: download with auth
+        """
+        hf_hub_fetch = TODO
+        hf_hub_filename = TODO
+
+        cache_paths = build_cache_paths(
+            conversion, name, client=HuggingfaceClient.name, format=format
+        )
+        cached = get_first_exists(cache_paths)
+        if cached:
+            return cached
+
+        source = remove_prefix(source, HuggingfaceClient.protocol)
+        logger.info("downloading model from Huggingface Hub: %s", source)
+
+        # from_pretrained has a bunch of useful logic that snapshot_download by itself does not
+        if hf_hub_fetch:
+            return (
+                hf_hub_download(
+                    repo_id=source,
+                    filename=hf_hub_filename,
+                    cache_dir=cache_paths[0],
+                    force_filename=f"{name}.bin",
+                ),
+                False,
+            )
+        else:
+            return source
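The class declares a download attribute (defaulting to hf_hub_download) and a download method with the same name, so the injected callable shadows the method on instances. A sketch that stores the callable under a different, hypothetical name:

from typing import Callable, Optional

class HuggingfaceClient:
    name = "huggingface"
    protocol = "huggingface://"

    def __init__(self, token: Optional[str] = None, hub_download: Optional[Callable] = None):
        # hub_download is a hypothetical rename for the injected hf_hub_download
        self.hub_download = hub_download
        self.token = token

    def download(self, conversion, name: str, source: str, format: Optional[str]) -> str:
        repo = source[len(HuggingfaceClient.protocol):]
        if self.hub_download is not None:
            return self.hub_download(repo_id=repo, filename=f"{name}.bin")
        return repo

client = HuggingfaceClient()
print(client.download(None, "model", "huggingface://runwayml/stable-diffusion-v1-5", None))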
@@ -15,6 +15,8 @@ from onnxruntime.transformers.float16 import convert_float_to_float16
 from packaging import version
 from torch.onnx import export
 
+from onnx_web.convert.client.file import FileClient
+
 from ..constants import ONNX_WEIGHTS
 from ..errors import RequestException
 from ..server import ServerContext
@@ -85,38 +87,37 @@ class ConversionContext(ServerContext):
         return torch.device(self.training_device)
 
 
-def download_progress(urls: List[Tuple[str, str]]):
-    for url, dest in urls:
-        dest_path = Path(dest).expanduser().resolve()
-        dest_path.parent.mkdir(parents=True, exist_ok=True)
-
-        if dest_path.exists():
-            logger.debug("destination already exists: %s", dest_path)
-            return str(dest_path.absolute())
-
-        req = requests.get(
-            url,
-            stream=True,
-            allow_redirects=True,
-            headers={
-                "User-Agent": "onnx-web-api",
-            },
-        )
-        if req.status_code != 200:
-            req.raise_for_status()  # Only works for 4xx errors, per SO answer
-            raise RequestException(
-                "request to %s failed with status code: %s" % (url, req.status_code)
-            )
-
-        total = int(req.headers.get("Content-Length", 0))
-        desc = "unknown" if total == 0 else ""
-        req.raw.read = partial(req.raw.read, decode_content=True)
-        with tqdm.wrapattr(req.raw, "read", total=total, desc=desc) as data:
-            with dest_path.open("wb") as f:
-                shutil.copyfileobj(data, f)
+def download_progress(source: str, dest: str):
+    dest_path = Path(dest).expanduser().resolve()
+    dest_path.parent.mkdir(parents=True, exist_ok=True)
+
+    if dest_path.exists():
+        logger.debug("destination already exists: %s", dest_path)
+        return str(dest_path.absolute())
+
+    req = requests.get(
+        source,
+        stream=True,
+        allow_redirects=True,
+        headers={
+            "User-Agent": "onnx-web-api",
+        },
+    )
+    if req.status_code != 200:
+        req.raise_for_status()  # Only works for 4xx errors, per SO answer
+        raise RequestException(
+            "request to %s failed with status code: %s" % (source, req.status_code)
+        )
+
+    total = int(req.headers.get("Content-Length", 0))
+    desc = "unknown" if total == 0 else ""
+    req.raw.read = partial(req.raw.read, decode_content=True)
+    with tqdm.wrapattr(req.raw, "read", total=total, desc=desc) as data:
+        with dest_path.open("wb") as f:
+            shutil.copyfileobj(data, f)
 
     return str(dest_path.absolute())
 
 
 def tuple_to_source(model: Union[ModelDict, LegacyModel]):
     if isinstance(model, list) or isinstance(model, tuple):
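download_progress now takes a single source/dest pair instead of a list of tuples, and always returns the absolute destination path. A usage sketch, assuming the helper is imported from onnx_web.convert.utils and the URL is reachable:

from onnx_web.convert.utils import download_progress

# example values; any reachable URL and writable destination work
dest = download_progress(
    "https://example.com/model.safetensors",
    "/tmp/cache/model.safetensors",
)
print(dest)  # absolute path to the downloaded (or already cached) file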
@@ -347,7 +348,13 @@
     )
 
 
-DIFFUSION_PREFIX = ["diffusion-", "diffusion/", "diffusion\\", "stable-diffusion-", "upscaling-"]
+DIFFUSION_PREFIX = [
+    "diffusion-",
+    "diffusion/",
+    "diffusion\\",
+    "stable-diffusion-",
+    "upscaling-",
+]
 
 
 def fix_diffusion_name(name: str):
@@ -359,3 +366,61 @@ def fix_diffusion_name(name: str):
         return f"diffusion-{name}"
 
     return name
+
+
+def fetch_model(
+    conversion: ConversionContext,
+    name: str,
+    source: str,
+    format: Optional[str],
+) -> Tuple[str, bool]:
+    # TODO: switch to urlparse's default scheme
+    if source.startswith(path.sep) or source.startswith("."):
+        logger.info("adding file protocol to local path source: %s", source)
+        source = FileClient.protocol + source
+
+    for proto, client_type in model_sources.items():
+        if source.startswith(proto):
+            client = client_type()
+            return client.download(proto)
+
+    logger.warning("unknown model protocol, using path as provided: %s", source)
+    return source, False
+
+
+def build_cache_paths(
+    conversion: ConversionContext,
+    name: str,
+    client: Optional[str] = None,
+    dest: Optional[str] = None,
+    format: Optional[str] = None,
+) -> List[str]:
+    cache_path = dest or conversion.cache_path
+
+    # add an extension if possible, some of the conversion code checks for it
+    if format is not None:
+        basename = path.basename(name)
+        _filename, ext = path.splitext(basename)
+        if ext is None:
+            name = f"{name}.{format}"
+
+    paths = [
+        path.join(cache_path, name),
+    ]
+
+    if client is not None:
+        client_path = path.join(cache_path, client)
+        paths.append(path.join(client_path, name))
+
+    return paths
+
+
+def get_first_exists(
+    paths: List[str],
+) -> Optional[str]:
+    for name in paths:
+        if path.exists(name):
+            logger.debug("model already exists in cache, skipping fetch: %s", name)
+            return name
+
+    return None
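build_cache_paths checks if ext is None, but path.splitext returns an empty string for a missing extension, so as written the format suffix is presumably never appended. A standalone sketch of the intended behavior, with the ConversionContext argument simplified to a plain cache path:

from os import path
from typing import List, Optional

def build_cache_paths(
    cache_path: str,
    name: str,
    client: Optional[str] = None,
    format: Optional[str] = None,
) -> List[str]:
    # append the format as an extension only when the name has none
    if format is not None:
        _filename, ext = path.splitext(path.basename(name))
        if not ext:  # splitext yields "" (falsy), never None
            name = f"{name}.{format}"

    paths = [path.join(cache_path, name)]
    if client is not None:
        paths.append(path.join(cache_path, client, name))
    return paths

def get_first_exists(paths: List[str]) -> Optional[str]:
    return next((p for p in paths if path.exists(p)), None)

print(build_cache_paths("/tmp/cache", "model-v1", client="civitai", format="safetensors"))
# ['/tmp/cache/model-v1.safetensors', '/tmp/cache/civitai/model-v1.safetensors']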
@@ -2,18 +2,24 @@ from importlib import import_module
 from logging import getLogger
 from typing import Any, Callable, Dict
 
-from onnx_web.chain.stages import add_stage
-from onnx_web.diffusers.load import add_pipeline
-from onnx_web.server.context import ServerContext
+from ..chain.stages import add_stage
+from ..diffusers.load import add_pipeline
+from ..server.context import ServerContext
 
 logger = getLogger(__name__)
 
 
 class PluginExports:
+    clients: Dict[str, Any]
+    converter: Dict[str, Any]
     pipelines: Dict[str, Any]
     stages: Dict[str, Any]
 
-    def __init__(self, pipelines=None, stages=None) -> None:
+    def __init__(
+        self, clients=None, converter=None, pipelines=None, stages=None
+    ) -> None:
+        self.clients = clients or {}
+        self.converter = converter or {}
         self.pipelines = pipelines or {}
         self.stages = stages or {}
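With clients and converter on PluginExports, a plugin can now ship a download client alongside pipelines and stages. A sketch of a hypothetical plugin module; the initialize entry point and MirrorClient are illustrative, standing in for whatever convention load_plugins expects:

# stand-ins for onnx_web's BaseClient and PluginExports, so the sketch runs alone
class BaseClient:
    def download(self, conversion, name, source, format) -> str:
        raise NotImplementedError()

class PluginExports:
    def __init__(self, clients=None, converter=None, pipelines=None, stages=None):
        self.clients = clients or {}
        self.converter = converter or {}
        self.pipelines = pipelines or {}
        self.stages = stages or {}

# hypothetical client for a private model mirror
class MirrorClient(BaseClient):
    protocol = "mirror://"

    def download(self, conversion, name, source, format) -> str:
        return f"/mirror-cache/{name}"

def initialize(server) -> PluginExports:
    return PluginExports(clients={MirrorClient.protocol: MirrorClient})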
@@ -27,7 +27,7 @@ class ConversionContextTests(unittest.TestCase):
 
 class DownloadProgressTests(unittest.TestCase):
     def test_download_example(self):
-        path = download_progress([("https://example.com", "/tmp/example-dot-com")])
+        path = download_progress("https://example.com", "/tmp/example-dot-com")
         self.assertEqual(path, "/tmp/example-dot-com")