1
0
Fork 0

feat(api): intercept model downloads in libs and use cached copy (fixes #95)

This commit is contained in:
Sean Sube 2023-02-12 09:28:37 -06:00
parent 8ea33e9874
commit 1179092028
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
5 changed files with 94 additions and 14 deletions

View File

@ -1,6 +1,6 @@
echo "Downloading and converting models to ONNX format..." echo "Downloading and converting models to ONNX format..."
IF "%ONNX_WEB_EXTRA_MODELS%"=="" (set ONNX_WEB_EXTRA_MODELS=extras.json) IF "%ONNX_WEB_EXTRA_MODELS%"=="" (set ONNX_WEB_EXTRA_MODELS=extras.json)
python -m onnx_web.convert --diffusion --upscaling --correction --extras=extras.json --token=%HF_TOKEN% python -m onnx_web.convert --sources --diffusion --upscaling --correction --extras=extras.json --token=%HF_TOKEN%
echo "Launching API server..." echo "Launching API server..."
flask --app=onnx_web.serve run --host=0.0.0.0 flask --app=onnx_web.serve run --host=0.0.0.0

View File

@ -17,6 +17,7 @@ fi
echo "Downloading and converting models to ONNX format..." echo "Downloading and converting models to ONNX format..."
python3 -m onnx_web.convert \ python3 -m onnx_web.convert \
--sources \
--diffusion \ --diffusion \
--upscaling \ --upscaling \
--correction \ --correction \

View File

@ -20,6 +20,7 @@ from .utils import (
source_format, source_format,
tuple_to_correction, tuple_to_correction,
tuple_to_diffusion, tuple_to_diffusion,
tuple_to_source,
tuple_to_upscaling, tuple_to_upscaling,
) )
@ -101,6 +102,33 @@ base_models: Models = {
4, 4,
), ),
], ],
# download only
"sources": [
(
"detection-resnet50-final",
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth",
),
(
"detection-mobilenet025-final",
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth",
),
(
"detection-yolo-v5-l",
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth",
),
(
"detection-yolo-v5-n",
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5n-face.pth",
),
(
"parsing-bisenet",
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_bisenet.pth",
),
(
"parsing-parsenet",
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth",
),
],
} }
@ -113,8 +141,9 @@ def fetch_model(
if model_format is None: if model_format is None:
url = urlparse(source) url = urlparse(source)
ext = path.basename(url.path) ext = path.basename(url.path)
file, ext = path.splitext(ext)
if ext is not None: if ext is not None:
cache_name = "%s.%s" % (cache_name, ext) cache_name += ext
else: else:
cache_name = "%s.%s" % (cache_name, model_format) cache_name = "%s.%s" % (cache_name, model_format)
@ -147,6 +176,19 @@ def fetch_model(
def convert_models(ctx: ConversionContext, args, models: Models): def convert_models(ctx: ConversionContext, args, models: Models):
if args.sources:
for model in models.get("sources"):
model = tuple_to_source(model)
name = model.get("name")
if name in args.skip:
logger.info("Skipping source: %s", name)
else:
model_format = source_format(model)
source = model["source"]
dest = fetch_model(ctx, name, source, model_format=model_format)
logger.info("Finished downloading source: %s -> %s", source, dest)
if args.diffusion: if args.diffusion:
for model in models.get("diffusion"): for model in models.get("diffusion"):
model = tuple_to_diffusion(model) model = tuple_to_diffusion(model)
@ -208,6 +250,7 @@ def main() -> int:
) )
# model groups # model groups
parser.add_argument("--sources", action="store_true", default=False)
parser.add_argument("--correction", action="store_true", default=False) parser.add_argument("--correction", action="store_true", default=False)
parser.add_argument("--diffusion", action="store_true", default=False) parser.add_argument("--diffusion", action="store_true", default=False)
parser.add_argument("--upscaling", action="store_true", default=False) parser.add_argument("--upscaling", action="store_true", default=False)

View File

@ -76,6 +76,18 @@ def download_progress(urls: List[Tuple[str, str]]):
return str(dest_path.absolute()) return str(dest_path.absolute())
def tuple_to_source(model: Union[ModelDict, LegacyModel]):
if isinstance(model, list) or isinstance(model, tuple):
name, source, *rest = model
return {
"name": name,
"source": source,
}
else:
return model
def tuple_to_correction(model: Union[ModelDict, LegacyModel]): def tuple_to_correction(model: Union[ModelDict, LegacyModel]):
if isinstance(model, list) or isinstance(model, tuple): if isinstance(model, list) or isinstance(model, tuple):
name, source, *rest = model name, source, *rest = model

View File

@ -1,15 +1,17 @@
import sys import sys
import codeformer.facelib.utils.misc
import basicsr.utils.download_util
from functools import partial from functools import partial
from logging import getLogger from logging import getLogger
from urllib.parse import urlparse
from os import path from os import path
from .utils import ServerContext, base_join from urllib.parse import urlparse
import basicsr.utils.download_util
import codeformer.facelib.utils.misc
from .utils import ServerContext
logger = getLogger(__name__) logger = getLogger(__name__)
def unload(exclude): def unload(exclude):
""" """
Remove package modules from cache except excluded ones. Remove package modules from cache except excluded ones.
@ -21,7 +23,7 @@ def unload(exclude):
""" """
pkgs = [] pkgs = []
for mod in exclude: for mod in exclude:
pkg = mod.split('.', 1)[0] pkg = mod.split(".", 1)[0]
pkgs.append(pkg) pkgs.append(pkg)
to_unload = [] to_unload = []
@ -34,7 +36,7 @@ def unload(exclude):
continue continue
for pkg in pkgs: for pkg in pkgs:
if mod.startswith(pkg + '.'): if mod.startswith(pkg + "."):
to_unload.append(mod) to_unload.append(mod)
break break
@ -43,28 +45,50 @@ def unload(exclude):
del sys.modules[mod] del sys.modules[mod]
# these should be the same sources and names as `convert.base_models.sources`, but inverted so the source is the key
cache_path_map = {
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth": "correction-codeformer.pth",
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth": "detection-resnet50-final.pth",
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth": "detection-mobilenet025-final.pth",
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_bisenet.pth": "parsing-bisenet.pth",
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth": "parsing-parsenet.pth",
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth": "upscaling-real-esrgan-x2-plus",
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth": "detection-yolo-v5-l.pth",
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5n-face.pth": "detection-yolo-v5-n.pth",
"https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth": "correct-gfpgan-v1-3.pth",
"https://s3.eu-central-1.wasabisys.com/nextml-model-data/codeformer/weights/facelib/detection_Resnet50_Final.pth": "detection-resnet50-final.pth",
"https://s3.eu-central-1.wasabisys.com/nextml-model-data/codeformer/weights/facelib/parsing_parsenet.pth": "parsing-parsenet.pth",
}
def patch_not_impl(): def patch_not_impl():
raise NotImplementedError() raise NotImplementedError()
def patch_cache_path(ctx: ServerContext, url: str, **kwargs) -> str: def patch_cache_path(ctx: ServerContext, url: str, **kwargs) -> str:
url = urlparse(url) if url in cache_path_map:
base = path.basename(url.path) cache_path = cache_path_map.get(url)
cache_path = path.join(ctx.model_path, ".cache", base) else:
parsed = urlparse(url)
logger.debug("Patching download path: %s -> %s", path, cache_path) cache_path = path.basename(parsed.path)
cache_path = path.join(ctx.model_path, ".cache", cache_path)
logger.debug("Patching download path: %s -> %s", url, cache_path)
return cache_path return cache_path
def apply_patch_codeformer(ctx: ServerContext): def apply_patch_codeformer(ctx: ServerContext):
logger.debug("Patching CodeFormer module...") logger.debug("Patching CodeFormer module...")
codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl
codeformer.facelib.utils.misc.load_file_from_url = partial(patch_cache_path, ctx) codeformer.facelib.utils.misc.load_file_from_url = partial(patch_cache_path, ctx)
def apply_patch_basicsr(ctx: ServerContext): def apply_patch_basicsr(ctx: ServerContext):
logger.debug("Patching BasicSR module...") logger.debug("Patching BasicSR module...")
basicsr.utils.download_util.download_file_from_google_drive = patch_not_impl basicsr.utils.download_util.download_file_from_google_drive = patch_not_impl
basicsr.utils.download_util.load_file_from_url = partial(patch_cache_path, ctx) basicsr.utils.download_util.load_file_from_url = partial(patch_cache_path, ctx)
def apply_patches(ctx: ServerContext): def apply_patches(ctx: ServerContext):
apply_patch_basicsr(ctx) apply_patch_basicsr(ctx)
apply_patch_codeformer(ctx) apply_patch_codeformer(ctx)