2023-02-11 04:41:24 +00:00
|
|
|
import shutil
|
|
|
|
from functools import partial
|
|
|
|
from logging import getLogger
|
2023-02-11 18:36:54 +00:00
|
|
|
from os import environ, path
|
2023-02-11 04:41:24 +00:00
|
|
|
from pathlib import Path
|
2023-02-11 05:32:16 +00:00
|
|
|
from typing import Dict, List, Optional, Tuple, Union
|
2023-02-11 04:41:24 +00:00
|
|
|
|
|
|
|
import requests
|
2023-02-17 02:18:42 +00:00
|
|
|
import safetensors
|
2023-02-09 04:35:54 +00:00
|
|
|
import torch
|
2023-02-18 15:25:01 +00:00
|
|
|
from huggingface_hub.utils.tqdm import tqdm
|
2023-04-12 00:29:25 +00:00
|
|
|
from onnx import load_model, save_model
|
|
|
|
from onnx.shape_inference import infer_shapes_path
|
|
|
|
from onnxruntime.transformers.float16 import convert_float_to_float16
|
|
|
|
from packaging import version
|
|
|
|
from torch.onnx import export
|
2023-02-11 04:41:24 +00:00
|
|
|
|
2023-04-12 00:29:25 +00:00
|
|
|
from ..constants import ONNX_WEIGHTS
|
2023-08-21 03:28:40 +00:00
|
|
|
from ..errors import RequestException
|
2023-02-19 02:28:21 +00:00
|
|
|
from ..server import ServerContext
|
2023-05-15 01:04:43 +00:00
|
|
|
from ..utils import get_boolean
|
2023-02-18 15:25:01 +00:00
|
|
|
|
2023-02-11 04:41:24 +00:00
|
|
|
logger = getLogger(__name__)


# True when the installed torch is >= 2.0 (comparing the base version,
# ignoring local/dev suffixes); used to gate version-specific behavior
is_torch_2_0 = version.parse(
    version.parse(torch.__version__).base_version
) >= version.parse("2.0")


# model entry normalized into a dict of string/int fields
ModelDict = Dict[str, Union[str, int]]
# legacy positional form: (name, source, then optional flags/opset)
LegacyModel = Tuple[str, str, Optional[bool], Optional[bool], Optional[int]]

# ONNX opset used when none is configured explicitly
DEFAULT_OPSET = 14
# name prefixes that mark a model as a diffusion model to the server
DIFFUSION_PREFIX = [
    "diffusion-",
    "diffusion/",
    "diffusion\\",
    "stable-diffusion-",
    "upscaling-",  # SD upscaling
]

# file formats accepted as conversion sources
MODEL_FORMATS = ["onnx", "pth", "ckpt", "safetensors"]
# extensions tried, in order, when resolving a tensor path without one
RESOLVE_FORMATS = ["safetensors", "ckpt", "pt", "pth", "bin"]
|
2023-05-15 01:04:43 +00:00
|
|
|
|
2023-02-09 04:35:54 +00:00
|
|
|
|
2023-02-18 15:25:01 +00:00
|
|
|
class ConversionContext(ServerContext):
    """
    Server context extended with model-conversion settings: precision,
    opset, training device, and feature toggles for the converter.
    """

    def __init__(
        self,
        model_path: str = ".",
        cache_path: Optional[str] = None,
        device: Optional[str] = None,
        half: bool = False,
        opset: int = DEFAULT_OPSET,
        token: Optional[str] = None,
        prune: Optional[List[str]] = None,
        control: bool = True,
        reload: bool = True,
        share_unet: bool = True,
        extract: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(model_path=model_path, cache_path=cache_path, **kwargs)

        self.control = control
        self.extract = extract
        self.half = half
        self.opset = opset
        # copy into a fresh list so a None default never leaks a shared mutable
        self.prune = prune or []
        self.reload = reload
        self.share_unet = share_unet
        self.token = token

        # prefer an explicit device, otherwise use CUDA when available
        if device is not None:
            self.training_device = device
        else:
            self.training_device = "cuda" if torch.cuda.is_available() else "cpu"

    @classmethod
    def from_environ(cls):
        """Build a context from ONNX_WEB_CONVERT_* environment variables."""
        context = super().from_environ()
        context.control = get_boolean(environ, "ONNX_WEB_CONVERT_CONTROL", True)
        # NOTE(review): env default is True but __init__ defaults extract=False
        # — confirm which default is intended
        context.extract = get_boolean(environ, "ONNX_WEB_CONVERT_EXTRACT", True)
        context.reload = get_boolean(environ, "ONNX_WEB_CONVERT_RELOAD", True)
        context.share_unet = get_boolean(environ, "ONNX_WEB_CONVERT_SHARE_UNET", True)
        context.opset = int(environ.get("ONNX_WEB_CONVERT_OPSET", DEFAULT_OPSET))

        # CPU_ONLY overrides whatever training device was detected
        cpu_only = get_boolean(environ, "ONNX_WEB_CONVERT_CPU_ONLY", False)
        if cpu_only:
            context.training_device = "cpu"

        return context

    @property
    def map_location(self):
        # torch.device form of the training device, for torch.load(map_location=...)
        return torch.device(self.training_device)
|
|
|
|
|
2023-02-11 04:41:24 +00:00
|
|
|
|
2023-12-10 00:04:34 +00:00
|
|
|
def download_progress(source: str, dest: str):
    """
    Stream a remote file to a local path, showing a progress bar.

    The download is skipped when the destination already exists; in every
    successful case the absolute destination path is returned as a string.
    """
    target = Path(dest).expanduser().resolve()
    target.parent.mkdir(parents=True, exist_ok=True)

    if target.exists():
        logger.debug("destination already exists: %s", target)
        return str(target.absolute())

    req = requests.get(
        source,
        stream=True,
        allow_redirects=True,
        headers={
            "User-Agent": "onnx-web-api",
        },
    )
    if req.status_code != 200:
        # raise_for_status only covers HTTP error codes, so raise our own
        # exception for any other unexpected non-200 response
        req.raise_for_status()
        raise RequestException(
            "request to %s failed with status code: %s" % (source, req.status_code)
        )

    total = int(req.headers.get("Content-Length", 0))
    desc = "unknown" if total == 0 else ""
    # make raw reads decode gzip/deflate transparently while streaming
    req.raw.read = partial(req.raw.read, decode_content=True)
    with tqdm.wrapattr(req.raw, "read", total=total, desc=desc) as stream:
        with target.open("wb") as output:
            shutil.copyfileobj(stream, output)

    return str(target.absolute())
|
2023-02-11 04:41:24 +00:00
|
|
|
|
|
|
|
|
2023-02-12 15:28:37 +00:00
|
|
|
def tuple_to_source(model: Union[ModelDict, LegacyModel]):
    """
    Normalize a legacy (name, source, ...) tuple or list into a model dict.

    Extra positional entries are ignored; dict inputs pass through unchanged.
    """
    # isinstance accepts a tuple of types; clearer than two chained checks
    if isinstance(model, (list, tuple)):
        name, source, *_rest = model

        return {
            "name": name,
            "source": source,
        }
    else:
        return model
|
|
|
|
|
|
|
|
|
2023-02-11 04:41:24 +00:00
|
|
|
def tuple_to_correction(model: Union[ModelDict, LegacyModel]):
    """
    Normalize a legacy correction-model tuple into a dict, supplying
    defaults for scale, half, and opset when they were not provided.
    """
    if isinstance(model, list) or isinstance(model, tuple):
        name, source, *rest = model
        # consume optional trailing values in order, with defaults
        extras = iter(rest)
        scale = next(extras, 1)
        half = next(extras, False)
        opset = next(extras, None)

        return {
            "name": name,
            "source": source,
            "half": half,
            "opset": opset,
            "scale": scale,
        }
    else:
        return model
|
|
|
|
|
|
|
|
|
|
|
|
def tuple_to_diffusion(model: Union[ModelDict, LegacyModel]):
    """
    Normalize a legacy diffusion-model tuple into a dict, supplying
    defaults for single_vae, half, and opset when they were not provided.
    """
    if isinstance(model, (list, tuple)):
        name, source, *rest = model
        # zip stops at the shorter sequence, so missing trailing values
        # keep their defaults and extra values are ignored
        values = {"single_vae": False, "half": False, "opset": None}
        values.update(zip(("single_vae", "half", "opset"), rest))

        return {
            "name": name,
            "source": source,
            "half": values["half"],
            "opset": values["opset"],
            "single_vae": values["single_vae"],
        }
    else:
        return model
|
|
|
|
|
|
|
|
|
|
|
|
def tuple_to_upscaling(model: Union[ModelDict, LegacyModel]):
    """
    Normalize a legacy upscaling-model tuple into a dict, supplying
    defaults for scale, half, and opset when they were not provided.
    """
    if not isinstance(model, (list, tuple)):
        return model

    name, source, *rest = model
    # consume optional trailing values in order, with defaults
    extras = iter(rest)
    scale = next(extras, 1)
    half = next(extras, False)
    opset = next(extras, None)

    return {
        "name": name,
        "source": source,
        "half": half,
        "opset": opset,
        "scale": scale,
    }
|
|
|
|
|
|
|
|
|
2023-10-07 00:03:15 +00:00
|
|
|
def check_ext(name: str, exts: List[str]) -> Tuple[bool, str]:
    """
    Split the extension off a filename and report whether it is one of
    the given extensions. Returns (matched, extension-without-dot).
    """
    ext = path.splitext(name)[1].strip(".")

    return (ext in exts, ext)
|
2023-02-11 04:41:24 +00:00
|
|
|
|
2023-02-11 05:32:16 +00:00
|
|
|
|
2023-02-11 04:41:24 +00:00
|
|
|
def source_format(model: Dict) -> Optional[str]:
    """
    Determine a model's source format: the explicit "format" field when
    present, otherwise the extension of the "source" field when it is a
    known model format, otherwise None.
    """
    if "format" in model:
        return model["format"]

    if "source" in model:
        matched, ext = check_ext(model["source"], MODEL_FORMATS)
        if matched:
            return ext

    return None
|
2023-02-11 18:36:54 +00:00
|
|
|
|
|
|
|
|
2023-03-18 15:50:48 +00:00
|
|
|
def remove_prefix(name: str, prefix: str) -> str:
    """Return name without its leading prefix, or unchanged when absent."""
    return name[len(prefix):] if name.startswith(prefix) else name
|
2023-02-17 02:18:42 +00:00
|
|
|
|
|
|
|
|
2023-03-19 20:27:51 +00:00
|
|
|
def load_torch(name: str, map_location=None) -> Optional[Dict]:
    """
    Load a checkpoint with torch.load, falling back to torch.jit.load
    when regular pickle loading fails (e.g. TorchScript archives).

    TODO: move out of convert
    """
    try:
        logger.debug("loading tensor with Torch: %s", name)
        checkpoint = torch.load(name, map_location=map_location)
    except Exception:
        # fix: the first attempt is plain Torch, not Torch JIT — the old
        # message logged "Torch JIT" for both
        logger.exception(
            "error loading with Torch, trying with Torch JIT: %s", name
        )
        checkpoint = torch.jit.load(name)

    return checkpoint
|
|
|
|
|
|
|
|
|
|
|
|
def load_tensor(name: str, map_location=None) -> Optional[Dict]:
    """
    Load a checkpoint from a variety of tensor formats, choosing the loader
    by file extension and falling back to PyTorch for unknown types.

    When the path has no extension and does not exist as-is, each extension
    in RESOLVE_FORMATS is tried in order. A nested "state_dict" entry is
    unwrapped when present.

    Raises ValueError when nothing could be loaded.

    TODO: move out of convert
    """
    logger.debug("loading tensor: %s", name)
    _, extension = path.splitext(name)
    extension = extension[1:].lower()

    checkpoint = None
    if extension == "":
        # if no extension was intentional, do not search for others
        if path.exists(name):
            logger.debug("loading anonymous tensor")
            checkpoint = torch.load(name, map_location=map_location)
        else:
            logger.debug("searching for tensors with known extensions")
            for next_extension in RESOLVE_FORMATS:
                next_name = f"{name}.{next_extension}"
                if path.exists(next_name):
                    checkpoint = load_tensor(next_name, map_location=map_location)
                    if checkpoint is not None:
                        break
    elif extension == "safetensors":
        logger.debug("loading safetensors")
        try:
            environ["SAFETENSORS_FAST_GPU"] = "1"
            checkpoint = safetensors.torch.load_file(name, device="cpu")
        except Exception as e:
            logger.warning("error loading safetensor: %s", e)
    elif extension in ["bin", "ckpt", "pt"]:
        logger.debug("loading pickle tensor")
        try:
            checkpoint = load_torch(name, map_location=map_location)
        except Exception as e:
            logger.warning("error loading pickle tensor: %s", e)
    elif extension == "onnx":
        # "pt" was also listed here, but the pickle branch above already
        # matches it, so this branch could only ever see "onnx"
        logger.warning(
            "tensor has ONNX extension, attempting to use PyTorch anyways: %s",
            extension,
        )
        try:
            checkpoint = load_torch(name, map_location=map_location)
        except Exception as e:
            logger.warning("error loading tensor: %s", e)
    else:
        logger.warning("unknown tensor type, falling back to PyTorch: %s", extension)
        try:
            checkpoint = load_torch(name, map_location=map_location)
        except Exception as e:
            logger.warning("error loading tensor: %s", e)

    if checkpoint is None:
        raise ValueError("error loading tensor")

    # checkpoint is guaranteed non-None past the raise above, so no extra
    # None check is needed here
    if "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]

    return checkpoint
|
2023-04-12 00:29:25 +00:00
|
|
|
|
|
|
|
|
2023-06-26 22:24:34 +00:00
|
|
|
def resolve_tensor(name: str) -> Optional[str]:
    """
    Find an existing tensor file for a base path by trying each known
    extension in order; return the first match or None.

    TODO: move out of convert
    """
    logger.debug("searching for tensors with known extensions: %s", name)
    return next(
        (
            candidate
            for candidate in (f"{name}.{ext}" for ext in RESOLVE_FORMATS)
            if path.exists(candidate)
        ),
        None,
    )
|
|
|
|
|
|
|
|
|
2023-04-12 00:29:25 +00:00
|
|
|
def onnx_export(
    model,
    model_args: tuple,
    output_path: Path,
    ordered_input_names,
    output_names,
    dynamic_axes,
    opset,
    half=False,
    external_data=False,
    v2=False,
    op_block_list=None,
):
    """
    Export a Torch model to ONNX at output_path, optionally converting the
    exported graph to fp16 in place.

    From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_file = output_path.absolute().as_posix()

    export(
        model,
        model_args,
        f=output_file,
        input_names=ordered_input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        do_constant_folding=True,
        opset_version=opset,
    )

    # default attention ops to block from fp16 for v2 models, unless the
    # caller provided an explicit list
    if v2 and op_block_list is None:
        op_block_list = ["Attention", "MultiHeadAttention"]

    if half:
        # fp16 conversion works on the already-exported file: infer shapes
        # on disk, reload, convert the weights, then save over the output
        logger.info("converting model to fp16 internally: %s", output_file)
        infer_shapes_path(output_file)
        base_model = load_model(output_file)
        opt_model = convert_float_to_float16(
            base_model,
            disable_shape_infer=True,
            force_fp16_initializers=True,
            keep_io_types=True,
            op_block_list=op_block_list,
        )
        save_model(
            opt_model,
            f"{output_file}",
            save_as_external_data=external_data,
            all_tensors_to_one_file=True,
            location=ONNX_WEIGHTS,
        )
|
2023-12-09 04:26:19 +00:00
|
|
|
|
|
|
|
|
|
|
|
def fix_diffusion_name(name: str):
    """
    Ensure a diffusion model name carries a recognized diffusion prefix,
    prepending "diffusion-" (with a warning) when it does not.
    """
    # str.startswith accepts a tuple of prefixes, avoiding a throwaway
    # list comprehension inside any()
    if not name.startswith(tuple(DIFFUSION_PREFIX)):
        logger.warning(
            "diffusion models must have names starting with diffusion- to be recognized by the server: %s does not match",
            name,
        )
        return f"diffusion-{name}"

    return name
|
2023-12-10 00:04:34 +00:00
|
|
|
|
|
|
|
|
|
|
|
def build_cache_paths(
    conversion: ConversionContext,
    name: str,
    client: Optional[str] = None,
    dest: Optional[str] = None,
    format: Optional[str] = None,
) -> List[str]:
    """
    Build the candidate cache locations for a model file, in lookup order:
    the cache root first, then the client-specific subfolder when given.
    """
    cache_path = dest or conversion.cache_path

    # add an extension if possible, some of the conversion code checks for it
    if format is not None:
        _, ext = path.splitext(path.basename(name))
        if not ext:
            name = f"{name}.{format}"

    paths = [path.join(cache_path, name)]

    if client is not None:
        paths.append(path.join(cache_path, client, name))

    return paths
|
|
|
|
|
|
|
|
|
|
|
|
def get_first_exists(
    paths: List[str],
) -> Optional[str]:
    """
    Return the first path in the list that exists on disk, or None when
    none of them do.
    """
    for candidate in paths:
        if not path.exists(candidate):
            continue

        logger.debug("model already exists in cache, skipping fetch: %s", candidate)
        return candidate

    return None
|