feat(api): intercept model downloads in libs and use cached copy (fixes #95)
This commit is contained in:
parent
8ea33e9874
commit
1179092028
|
@ -1,6 +1,6 @@
|
||||||
echo "Downloading and converting models to ONNX format..."
|
echo "Downloading and converting models to ONNX format..."
|
||||||
IF "%ONNX_WEB_EXTRA_MODELS%"=="" (set ONNX_WEB_EXTRA_MODELS=extras.json)
|
IF "%ONNX_WEB_EXTRA_MODELS%"=="" (set ONNX_WEB_EXTRA_MODELS=extras.json)
|
||||||
python -m onnx_web.convert --diffusion --upscaling --correction --extras=extras.json --token=%HF_TOKEN%
|
python -m onnx_web.convert --sources --diffusion --upscaling --correction --extras=extras.json --token=%HF_TOKEN%
|
||||||
|
|
||||||
echo "Launching API server..."
|
echo "Launching API server..."
|
||||||
flask --app=onnx_web.serve run --host=0.0.0.0
|
flask --app=onnx_web.serve run --host=0.0.0.0
|
||||||
|
|
|
@ -17,6 +17,7 @@ fi
|
||||||
|
|
||||||
echo "Downloading and converting models to ONNX format..."
|
echo "Downloading and converting models to ONNX format..."
|
||||||
python3 -m onnx_web.convert \
|
python3 -m onnx_web.convert \
|
||||||
|
--sources \
|
||||||
--diffusion \
|
--diffusion \
|
||||||
--upscaling \
|
--upscaling \
|
||||||
--correction \
|
--correction \
|
||||||
|
|
|
@ -20,6 +20,7 @@ from .utils import (
|
||||||
source_format,
|
source_format,
|
||||||
tuple_to_correction,
|
tuple_to_correction,
|
||||||
tuple_to_diffusion,
|
tuple_to_diffusion,
|
||||||
|
tuple_to_source,
|
||||||
tuple_to_upscaling,
|
tuple_to_upscaling,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -101,6 +102,33 @@ base_models: Models = {
|
||||||
4,
|
4,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
|
# download only
|
||||||
|
"sources": [
|
||||||
|
(
|
||||||
|
"detection-resnet50-final",
|
||||||
|
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"detection-mobilenet025-final",
|
||||||
|
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"detection-yolo-v5-l",
|
||||||
|
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"detection-yolo-v5-n",
|
||||||
|
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5n-face.pth",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"parsing-bisenet",
|
||||||
|
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_bisenet.pth",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"parsing-parsenet",
|
||||||
|
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth",
|
||||||
|
),
|
||||||
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -113,8 +141,9 @@ def fetch_model(
|
||||||
if model_format is None:
|
if model_format is None:
|
||||||
url = urlparse(source)
|
url = urlparse(source)
|
||||||
ext = path.basename(url.path)
|
ext = path.basename(url.path)
|
||||||
|
file, ext = path.splitext(ext)
|
||||||
if ext is not None:
|
if ext is not None:
|
||||||
cache_name = "%s.%s" % (cache_name, ext)
|
cache_name += ext
|
||||||
else:
|
else:
|
||||||
cache_name = "%s.%s" % (cache_name, model_format)
|
cache_name = "%s.%s" % (cache_name, model_format)
|
||||||
|
|
||||||
|
@ -147,6 +176,19 @@ def fetch_model(
|
||||||
|
|
||||||
|
|
||||||
def convert_models(ctx: ConversionContext, args, models: Models):
|
def convert_models(ctx: ConversionContext, args, models: Models):
|
||||||
|
if args.sources:
|
||||||
|
for model in models.get("sources"):
|
||||||
|
model = tuple_to_source(model)
|
||||||
|
name = model.get("name")
|
||||||
|
|
||||||
|
if name in args.skip:
|
||||||
|
logger.info("Skipping source: %s", name)
|
||||||
|
else:
|
||||||
|
model_format = source_format(model)
|
||||||
|
source = model["source"]
|
||||||
|
dest = fetch_model(ctx, name, source, model_format=model_format)
|
||||||
|
logger.info("Finished downloading source: %s -> %s", source, dest)
|
||||||
|
|
||||||
if args.diffusion:
|
if args.diffusion:
|
||||||
for model in models.get("diffusion"):
|
for model in models.get("diffusion"):
|
||||||
model = tuple_to_diffusion(model)
|
model = tuple_to_diffusion(model)
|
||||||
|
@ -208,6 +250,7 @@ def main() -> int:
|
||||||
)
|
)
|
||||||
|
|
||||||
# model groups
|
# model groups
|
||||||
|
parser.add_argument("--sources", action="store_true", default=False)
|
||||||
parser.add_argument("--correction", action="store_true", default=False)
|
parser.add_argument("--correction", action="store_true", default=False)
|
||||||
parser.add_argument("--diffusion", action="store_true", default=False)
|
parser.add_argument("--diffusion", action="store_true", default=False)
|
||||||
parser.add_argument("--upscaling", action="store_true", default=False)
|
parser.add_argument("--upscaling", action="store_true", default=False)
|
||||||
|
|
|
@ -76,6 +76,18 @@ def download_progress(urls: List[Tuple[str, str]]):
|
||||||
return str(dest_path.absolute())
|
return str(dest_path.absolute())
|
||||||
|
|
||||||
|
|
||||||
|
def tuple_to_source(model: Union[ModelDict, LegacyModel]):
|
||||||
|
if isinstance(model, list) or isinstance(model, tuple):
|
||||||
|
name, source, *rest = model
|
||||||
|
|
||||||
|
return {
|
||||||
|
"name": name,
|
||||||
|
"source": source,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
def tuple_to_correction(model: Union[ModelDict, LegacyModel]):
|
def tuple_to_correction(model: Union[ModelDict, LegacyModel]):
|
||||||
if isinstance(model, list) or isinstance(model, tuple):
|
if isinstance(model, list) or isinstance(model, tuple):
|
||||||
name, source, *rest = model
|
name, source, *rest = model
|
||||||
|
|
|
@ -1,15 +1,17 @@
|
||||||
import sys
|
import sys
|
||||||
import codeformer.facelib.utils.misc
|
|
||||||
import basicsr.utils.download_util
|
|
||||||
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from urllib.parse import urlparse
|
|
||||||
from os import path
|
from os import path
|
||||||
from .utils import ServerContext, base_join
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
import basicsr.utils.download_util
|
||||||
|
import codeformer.facelib.utils.misc
|
||||||
|
|
||||||
|
from .utils import ServerContext
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def unload(exclude):
|
def unload(exclude):
|
||||||
"""
|
"""
|
||||||
Remove package modules from cache except excluded ones.
|
Remove package modules from cache except excluded ones.
|
||||||
|
@ -21,7 +23,7 @@ def unload(exclude):
|
||||||
"""
|
"""
|
||||||
pkgs = []
|
pkgs = []
|
||||||
for mod in exclude:
|
for mod in exclude:
|
||||||
pkg = mod.split('.', 1)[0]
|
pkg = mod.split(".", 1)[0]
|
||||||
pkgs.append(pkg)
|
pkgs.append(pkg)
|
||||||
|
|
||||||
to_unload = []
|
to_unload = []
|
||||||
|
@ -34,7 +36,7 @@ def unload(exclude):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for pkg in pkgs:
|
for pkg in pkgs:
|
||||||
if mod.startswith(pkg + '.'):
|
if mod.startswith(pkg + "."):
|
||||||
to_unload.append(mod)
|
to_unload.append(mod)
|
||||||
break
|
break
|
||||||
|
|
||||||
|
@ -43,28 +45,50 @@ def unload(exclude):
|
||||||
del sys.modules[mod]
|
del sys.modules[mod]
|
||||||
|
|
||||||
|
|
||||||
|
# these should be the same sources and names as `convert.base_models.sources`, but inverted so the source is the key
|
||||||
|
cache_path_map = {
|
||||||
|
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth": "correction-codeformer.pth",
|
||||||
|
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth": "detection-resnet50-final.pth",
|
||||||
|
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth": "detection-mobilenet025-final.pth",
|
||||||
|
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_bisenet.pth": "parsing-bisenet.pth",
|
||||||
|
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth": "parsing-parsenet.pth",
|
||||||
|
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth": "upscaling-real-esrgan-x2-plus",
|
||||||
|
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth": "detection-yolo-v5-l.pth",
|
||||||
|
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5n-face.pth": "detection-yolo-v5-n.pth",
|
||||||
|
"https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth": "correct-gfpgan-v1-3.pth",
|
||||||
|
"https://s3.eu-central-1.wasabisys.com/nextml-model-data/codeformer/weights/facelib/detection_Resnet50_Final.pth": "detection-resnet50-final.pth",
|
||||||
|
"https://s3.eu-central-1.wasabisys.com/nextml-model-data/codeformer/weights/facelib/parsing_parsenet.pth": "parsing-parsenet.pth",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def patch_not_impl():
|
def patch_not_impl():
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
def patch_cache_path(ctx: ServerContext, url: str, **kwargs) -> str:
|
def patch_cache_path(ctx: ServerContext, url: str, **kwargs) -> str:
|
||||||
url = urlparse(url)
|
if url in cache_path_map:
|
||||||
base = path.basename(url.path)
|
cache_path = cache_path_map.get(url)
|
||||||
cache_path = path.join(ctx.model_path, ".cache", base)
|
else:
|
||||||
|
parsed = urlparse(url)
|
||||||
logger.debug("Patching download path: %s -> %s", path, cache_path)
|
cache_path = path.basename(parsed.path)
|
||||||
|
|
||||||
|
cache_path = path.join(ctx.model_path, ".cache", cache_path)
|
||||||
|
logger.debug("Patching download path: %s -> %s", url, cache_path)
|
||||||
return cache_path
|
return cache_path
|
||||||
|
|
||||||
|
|
||||||
def apply_patch_codeformer(ctx: ServerContext):
|
def apply_patch_codeformer(ctx: ServerContext):
|
||||||
logger.debug("Patching CodeFormer module...")
|
logger.debug("Patching CodeFormer module...")
|
||||||
codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl
|
codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl
|
||||||
codeformer.facelib.utils.misc.load_file_from_url = partial(patch_cache_path, ctx)
|
codeformer.facelib.utils.misc.load_file_from_url = partial(patch_cache_path, ctx)
|
||||||
|
|
||||||
|
|
||||||
def apply_patch_basicsr(ctx: ServerContext):
|
def apply_patch_basicsr(ctx: ServerContext):
|
||||||
logger.debug("Patching BasicSR module...")
|
logger.debug("Patching BasicSR module...")
|
||||||
basicsr.utils.download_util.download_file_from_google_drive = patch_not_impl
|
basicsr.utils.download_util.download_file_from_google_drive = patch_not_impl
|
||||||
basicsr.utils.download_util.load_file_from_url = partial(patch_cache_path, ctx)
|
basicsr.utils.download_util.load_file_from_url = partial(patch_cache_path, ctx)
|
||||||
|
|
||||||
|
|
||||||
def apply_patches(ctx: ServerContext):
|
def apply_patches(ctx: ServerContext):
|
||||||
apply_patch_basicsr(ctx)
|
apply_patch_basicsr(ctx)
|
||||||
apply_patch_codeformer(ctx)
|
apply_patch_codeformer(ctx)
|
||||||
|
|
Loading…
Reference in New Issue