diff --git a/api/launch.bat b/api/launch.bat index dd1cd87c..e46b55fa 100644 --- a/api/launch.bat +++ b/api/launch.bat @@ -1,6 +1,6 @@ echo "Downloading and converting models to ONNX format..." IF "%ONNX_WEB_EXTRA_MODELS%"=="" (set ONNX_WEB_EXTRA_MODELS=extras.json) -python -m onnx_web.convert --diffusion --upscaling --correction --extras=extras.json --token=%HF_TOKEN% +python -m onnx_web.convert --sources --diffusion --upscaling --correction --extras=extras.json --token=%HF_TOKEN% echo "Launching API server..." flask --app=onnx_web.serve run --host=0.0.0.0 diff --git a/api/launch.sh b/api/launch.sh index 84b377ed..96db9bb2 100755 --- a/api/launch.sh +++ b/api/launch.sh @@ -17,6 +17,7 @@ fi echo "Downloading and converting models to ONNX format..." python3 -m onnx_web.convert \ + --sources \ --diffusion \ --upscaling \ --correction \ diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py index 430c6f82..28531f2e 100644 --- a/api/onnx_web/convert/__main__.py +++ b/api/onnx_web/convert/__main__.py @@ -20,6 +20,7 @@ from .utils import ( source_format, tuple_to_correction, tuple_to_diffusion, + tuple_to_source, tuple_to_upscaling, ) @@ -101,6 +102,33 @@ base_models: Models = { 4, ), ], + # download only + "sources": [ + ( + "detection-resnet50-final", + "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth", + ), + ( + "detection-mobilenet025-final", + "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth", + ), + ( + "detection-yolo-v5-l", + "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth", + ), + ( + "detection-yolo-v5-n", + "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5n-face.pth", + ), + ( + "parsing-bisenet", + "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_bisenet.pth", + ), + ( + "parsing-parsenet", + "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth", + ), + ], } @@ -113,8 +141,9 @@ def fetch_model( if model_format is None: url = urlparse(source) ext = path.basename(url.path) + file, ext = path.splitext(ext) if ext is not None: - cache_name = "%s.%s" % (cache_name, ext) + cache_name += ext else: cache_name = "%s.%s" % (cache_name, model_format) @@ -147,6 +176,19 @@ def fetch_model( def convert_models(ctx: ConversionContext, args, models: Models): + if args.sources: + for model in models.get("sources"): + model = tuple_to_source(model) + name = model.get("name") + + if name in args.skip: + logger.info("Skipping source: %s", name) + else: + model_format = source_format(model) + source = model["source"] + dest = fetch_model(ctx, name, source, model_format=model_format) + logger.info("Finished downloading source: %s -> %s", source, dest) + if args.diffusion: for model in models.get("diffusion"): model = tuple_to_diffusion(model) @@ -208,6 +250,7 @@ def main() -> int: ) # model groups + parser.add_argument("--sources", action="store_true", default=False) parser.add_argument("--correction", action="store_true", default=False) parser.add_argument("--diffusion", action="store_true", default=False) parser.add_argument("--upscaling", action="store_true", default=False) diff --git a/api/onnx_web/convert/utils.py b/api/onnx_web/convert/utils.py index 0466e5dc..689af692 100644 --- a/api/onnx_web/convert/utils.py +++ b/api/onnx_web/convert/utils.py @@ -76,6 +76,18 @@ def download_progress(urls: List[Tuple[str, str]]): return str(dest_path.absolute()) +def tuple_to_source(model: Union[ModelDict, LegacyModel]): + if isinstance(model, list) or isinstance(model, tuple): + name, source, *rest = model + + return { + "name": name, + "source": source, + } + else: + return model + + def tuple_to_correction(model: Union[ModelDict, LegacyModel]): if isinstance(model, list) or isinstance(model, tuple): name, source, *rest = model diff --git a/api/onnx_web/hacks.py b/api/onnx_web/hacks.py index 9c3e2a36..996f6186 100644 --- a/api/onnx_web/hacks.py +++ b/api/onnx_web/hacks.py @@ -1,15 +1,17 @@ import sys -import codeformer.facelib.utils.misc -import basicsr.utils.download_util - from functools import partial from logging import getLogger -from urllib.parse import urlparse from os import path -from .utils import ServerContext, base_join +from urllib.parse import urlparse + +import basicsr.utils.download_util +import codeformer.facelib.utils.misc + +from .utils import ServerContext logger = getLogger(__name__) + def unload(exclude): """ Remove package modules from cache except excluded ones. @@ -21,7 +23,7 @@ def unload(exclude): """ pkgs = [] for mod in exclude: - pkg = mod.split('.', 1)[0] + pkg = mod.split(".", 1)[0] pkgs.append(pkg) to_unload = [] @@ -34,7 +36,7 @@ def unload(exclude): continue for pkg in pkgs: - if mod.startswith(pkg + '.'): + if mod.startswith(pkg + "."): to_unload.append(mod) break @@ -43,28 +45,50 @@ def unload(exclude): del sys.modules[mod] +# these should be the same sources and names as `convert.base_models.sources`, but inverted so the source is the key +cache_path_map = { + "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth": "correction-codeformer.pth", + "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth": "detection-resnet50-final.pth", + "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth": "detection-mobilenet025-final.pth", + "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_bisenet.pth": "parsing-bisenet.pth", + "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth": "parsing-parsenet.pth", + "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth": "upscaling-real-esrgan-x2-plus", + "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth": "detection-yolo-v5-l.pth", + "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5n-face.pth": "detection-yolo-v5-n.pth", + "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth": "correct-gfpgan-v1-3.pth", + "https://s3.eu-central-1.wasabisys.com/nextml-model-data/codeformer/weights/facelib/detection_Resnet50_Final.pth": "detection-resnet50-final.pth", + "https://s3.eu-central-1.wasabisys.com/nextml-model-data/codeformer/weights/facelib/parsing_parsenet.pth": "parsing-parsenet.pth", +} + + def patch_not_impl(): raise NotImplementedError() + def patch_cache_path(ctx: ServerContext, url: str, **kwargs) -> str: - url = urlparse(url) - base = path.basename(url.path) - cache_path = path.join(ctx.model_path, ".cache", base) - - logger.debug("Patching download path: %s -> %s", path, cache_path) + if url in cache_path_map: + cache_path = cache_path_map.get(url) + else: + parsed = urlparse(url) + cache_path = path.basename(parsed.path) + cache_path = path.join(ctx.model_path, ".cache", cache_path) + logger.debug("Patching download path: %s -> %s", url, cache_path) return cache_path + def apply_patch_codeformer(ctx: ServerContext): logger.debug("Patching CodeFormer module...") codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl codeformer.facelib.utils.misc.load_file_from_url = partial(patch_cache_path, ctx) + def apply_patch_basicsr(ctx: ServerContext): logger.debug("Patching BasicSR module...") basicsr.utils.download_util.download_file_from_google_drive = patch_not_impl basicsr.utils.download_util.load_file_from_url = partial(patch_cache_path, ctx) + def apply_patches(ctx: ServerContext): apply_patch_basicsr(ctx) apply_patch_codeformer(ctx)