238 lines
6.6 KiB
Python
238 lines
6.6 KiB
Python
import shutil
|
|
from functools import partial
|
|
from logging import getLogger
|
|
from os import environ, path
|
|
from pathlib import Path
|
|
from typing import Dict, List, Optional, Tuple, Union
|
|
|
|
import requests
|
|
import safetensors
|
|
import torch
|
|
from huggingface_hub.utils.tqdm import tqdm
|
|
from yaml import safe_load
|
|
|
|
from ..server import ServerContext
|
|
|
|
logger = getLogger(__name__)


# Dict-form model entry: string keys mapped to string or integer values.
ModelDict = Dict[str, Union[str, int]]

# Positional (legacy) model entry, read by the tuple_to_* helpers as
# (name, source, third, half, opset). The third slot holds a bool for
# diffusion models (single_vae) and an int for correction/upscaling models
# (scale), so it must accept both.
LegacyModel = Tuple[
    str, str, Optional[Union[bool, int]], Optional[bool], Optional[int]
]
|
|
|
|
|
|
class ConversionContext(ServerContext):
    """
    Server context extended with the settings used while converting models.

    On top of the base ServerContext this stores the half-precision flag, the
    opset version, an optional token, and the torch device used for
    conversion (``training_device`` / ``map_location``).
    """

    def __init__(
        self,
        model_path: Optional[str] = None,
        cache_path: Optional[str] = None,
        device: Optional[str] = None,
        half: Optional[bool] = False,
        opset: Optional[int] = None,
        token: Optional[str] = None,
        **kwargs,
    ) -> None:
        # super() is already bound to this instance; passing `self` again would
        # shift it into the first positional parameter of ServerContext.__init__.
        # Any remaining keyword arguments belong to the base context and are
        # forwarded instead of being silently discarded.
        super().__init__(model_path=model_path, cache_path=cache_path, **kwargs)

        self.half = half
        self.opset = opset
        self.token = token

        # an explicit device wins; otherwise prefer CUDA when it is available
        if device is not None:
            self.training_device = device
        else:
            self.training_device = "cuda" if torch.cuda.is_available() else "cpu"

        # torch.device built from the chosen device name
        self.map_location = torch.device(self.training_device)
|
|
|
|
|
|
def _download_file(url: str, dest_path: Path) -> None:
    """Stream ``url`` into ``dest_path`` through a temporary ``.part`` file."""
    part_path = dest_path.with_name(dest_path.name + ".part")

    # the context manager closes the streamed connection even on error
    with requests.get(
        url,
        stream=True,
        allow_redirects=True,
        headers={
            "User-Agent": "onnx-web-api",
        },
    ) as req:
        if req.status_code != 200:
            # raise_for_status only raises for 4xx/5xx responses, so any other
            # non-200 status still needs an explicit error
            req.raise_for_status()
            raise RuntimeError(
                "Request to %s failed with status code: %s" % (url, req.status_code)
            )

        total = int(req.headers.get("Content-Length", 0))
        desc = "unknown" if total == 0 else ""
        # have urllib3 undo any Content-Encoding (e.g. gzip) while streaming
        req.raw.read = partial(req.raw.read, decode_content=True)
        with tqdm.wrapattr(req.raw, "read", total=total, desc=desc) as data:
            with part_path.open("wb") as f:
                shutil.copyfileobj(data, f)

    # expose the file under its final name only once it is complete, so an
    # interrupted download is never mistaken for an existing file later
    part_path.replace(dest_path)


def download_progress(urls: List[Tuple[str, str]]):
    """
    Download each ``(url, destination)`` pair while showing a progress bar.

    Destinations that already exist are skipped, but processing continues
    with the remaining pairs.

    Args:
        urls: pairs of source URL and destination path; ``~`` is expanded and
            missing parent directories are created.

    Returns:
        The absolute path of the last destination as a string, or ``None``
        when ``urls`` is empty.

    Raises:
        requests.HTTPError: for 4xx/5xx responses.
        RuntimeError: for any other non-200 response.
    """
    last_path: Optional[str] = None

    for url, dest in urls:
        dest_path = Path(dest).expanduser().resolve()
        dest_path.parent.mkdir(parents=True, exist_ok=True)
        last_path = str(dest_path.absolute())

        if dest_path.exists():
            logger.debug("destination already exists: %s", dest_path)
            # skip only this entry; later URLs still need to be fetched
            continue

        _download_file(url, dest_path)

    return last_path
|
|
|
|
|
|
def tuple_to_source(model: Union[ModelDict, LegacyModel]):
    """
    Normalize a model entry to a dict holding ``name`` and ``source``.

    Legacy list/tuple entries are reduced to their first two items, and any
    further items are ignored. Non-sequence entries are returned unchanged.
    """
    if not isinstance(model, (list, tuple)):
        return model

    name, source, *_extra = model
    return {"name": name, "source": source}
|
|
|
|
|
|
def tuple_to_correction(model: Union[ModelDict, LegacyModel]):
    """
    Normalize a correction model entry to a dict.

    Legacy list/tuple entries are read positionally as
    ``(name, source, scale=1, half=False, opset=None)``; trailing fields may
    be omitted and take the defaults shown. Non-sequence entries are returned
    unchanged.
    """
    if isinstance(model, list) or isinstance(model, tuple):
        name, source, *rest = model
        # each optional field is read from its own position; reading rest[0]
        # for all of them would copy the scale into half and opset
        scale = rest[0] if len(rest) > 0 else 1
        half = rest[1] if len(rest) > 1 else False
        opset = rest[2] if len(rest) > 2 else None

        return {
            "name": name,
            "source": source,
            "half": half,
            "opset": opset,
            "scale": scale,
        }
    else:
        return model
|
|
|
|
|
|
def tuple_to_diffusion(model: Union[ModelDict, LegacyModel]):
    """
    Normalize a diffusion model entry to a dict.

    Legacy list/tuple entries are read positionally as
    ``(name, source, single_vae=False, half=False, opset=None)``; trailing
    fields may be omitted and take the defaults shown. Non-sequence entries
    are returned unchanged.
    """
    if isinstance(model, list) or isinstance(model, tuple):
        name, source, *rest = model
        # each optional field is read from its own position; reading rest[0]
        # for all of them would copy single_vae into half and opset
        single_vae = rest[0] if len(rest) > 0 else False
        half = rest[1] if len(rest) > 1 else False
        opset = rest[2] if len(rest) > 2 else None

        return {
            "name": name,
            "source": source,
            "half": half,
            "opset": opset,
            "single_vae": single_vae,
        }
    else:
        return model
|
|
|
|
|
|
def tuple_to_upscaling(model: Union[ModelDict, LegacyModel]):
    """
    Normalize an upscaling model entry to a dict.

    Legacy list/tuple entries are read positionally as
    ``(name, source, scale=1, half=False, opset=None)``; trailing fields may
    be omitted and take the defaults shown. Non-sequence entries are returned
    unchanged.
    """
    if isinstance(model, list) or isinstance(model, tuple):
        name, source, *rest = model
        # each optional field is read from its own position; reading rest[0]
        # for all of them would copy the scale into half and opset
        scale = rest[0] if len(rest) > 0 else 1
        half = rest[1] if len(rest) > 1 else False
        opset = rest[2] if len(rest) > 2 else None

        return {
            "name": name,
            "source": source,
            "half": half,
            "opset": opset,
            "scale": scale,
        }
    else:
        return model
|
|
|
|
|
|
# file formats recognized by source_format, as bare extensions without a dot
model_formats = ["onnx", "pth", "ckpt", "safetensors"]

# subset of the formats above; presumably the original (pre-conversion)
# checkpoint formats — TODO confirm against callers
model_formats_original = ["ckpt", "safetensors"]


def source_format(model: Dict) -> Optional[str]:
    """
    Determine the file format of a model entry.

    An explicit ``format`` key always wins. Otherwise the extension of
    ``source`` is used (without its leading dot, lower-cased) when it is one of
    ``model_formats``.

    Returns:
        The format name (e.g. ``"safetensors"``), or ``None`` if it cannot be
        determined.
    """
    if "format" in model:
        return model["format"]

    if "source" in model:
        # splitext returns (root, ext) and ext keeps its leading dot, while
        # model_formats lists bare names
        _, ext = path.splitext(model["source"])
        ext = ext[1:].lower()
        if ext in model_formats:
            return ext

    return None
|
|
|
|
|
|
class Config(object):
    """
    Attribute-access wrapper around a plain dict (shim for pydantic-style config).

    Every key of the input becomes an attribute; values that are dicts are
    wrapped recursively in their own ``Config``. Iterating a ``Config`` yields
    its attribute names.
    """

    def __init__(self, kwargs):
        attributes = self.__dict__
        attributes.update(kwargs)
        # second pass: replace nested dict values with Config wrappers
        for name in list(attributes):
            Config.config_from_key(self, name, attributes[name])

    def __iter__(self):
        yield from self.__dict__

    @classmethod
    def config_from_key(cls, target, k, v):
        """Set attribute ``k`` on ``target`` to ``v``, wrapping dicts in ``Config``."""
        value = Config(v) if isinstance(v, dict) else v
        setattr(target, k, value)
|
|
|
|
|
|
def load_yaml(file: str) -> "Config":
    """
    Load a YAML file into a ``Config``.

    The file is parsed with ``safe_load``, so only standard YAML tags are
    constructed and arbitrary Python objects cannot be instantiated from it.

    Args:
        file: path to the YAML file, read as UTF-8.

    Returns:
        A ``Config`` wrapping the top-level mapping; an empty document yields
        an empty ``Config``.
    """
    with open(file, "r", encoding="utf-8") as f:
        data = safe_load(f)

    # safe_load returns None for an empty document, which Config cannot accept
    if data is None:
        data = {}

    return Config(data)
|
|
|
|
|
|
# punctuation kept by sanitize_name in addition to letters and digits
safe_chars = "._-"


def sanitize_name(name):
    """
    Drop every character of ``name`` that is neither alphanumeric nor in
    ``safe_chars``.

    ``str.isalnum`` is Unicode-aware, so non-ASCII letters and digits are kept.
    """
    kept = [ch for ch in name if ch.isalnum() or ch in safe_chars]
    return "".join(kept)
|
|
|
|
|
|
def remove_prefix(name, prefix):
    """
    Return ``name`` with a leading ``prefix`` stripped.

    If ``name`` does not start with ``prefix``, it is returned unchanged.
    """
    has_prefix = name.startswith(prefix)
    return name[len(prefix):] if has_prefix else name
|
|
|
|
|
|
def _unwrap_state_dict(checkpoint):
    """Return ``checkpoint["state_dict"]`` when present, otherwise ``checkpoint``."""
    # only dicts can carry a nested state dict; other objects pass through as-is
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        return checkpoint["state_dict"]
    return checkpoint


def load_tensor(name: str, map_location=None):
    """
    Load a model checkpoint from ``name``.

    Files with a ``.safetensors`` extension (matched case-insensitively) are
    loaded onto the CPU with safetensors. If that fails, the file is retried
    with ``torch.jit.load`` and then ``torch.load``. Any other extension goes
    straight to ``torch.load``. Results from ``torch.load`` are unwrapped from
    a top-level ``state_dict`` key when one exists.

    Args:
        name: path to the checkpoint file.
        map_location: forwarded to ``torch.load``; the safetensors and
            TorchScript loaders do not use it.

    Returns:
        The loaded checkpoint. For the ``torch.jit.load`` fallback this is
        whatever that call returns, without ``state_dict`` unwrapping.
    """
    logger.info("loading model from checkpoint")
    _, extension = path.splitext(name)

    if extension.lower() == ".safetensors":
        environ["SAFETENSORS_FAST_GPU"] = "1"
        try:
            logger.debug("loading safetensors")
            # `import safetensors` alone does not load the torch submodule, so
            # import it explicitly instead of relying on another module having
            # done so
            from safetensors.torch import load_file

            checkpoint = load_file(name, device="cpu")
        except Exception as e:
            try:
                logger.warning(
                    "failed to load as safetensors file, falling back to torch: %s", e
                )
                checkpoint = torch.jit.load(name)
            except Exception as e:
                logger.warning(
                    "failed to load with Torch JIT, falling back to PyTorch: %s", e
                )
                checkpoint = _unwrap_state_dict(
                    torch.load(name, map_location=map_location)
                )
    else:
        logger.debug("loading ckpt")
        checkpoint = _unwrap_state_dict(torch.load(name, map_location=map_location))

    return checkpoint
|