1
0
Fork 0
onnx-web/api/onnx_web/server/hacks.py

190 lines
6.3 KiB
Python
Raw Normal View History

import sys
from functools import partial
from logging import getLogger
from os import path
from urllib.parse import urlparse
from ..utils import run_gc
from .context import ServerContext
# module-level logger, named after this module
logger = getLogger(__name__)
def unload(exclude):
    """
    Remove package modules from cache except excluded ones.
    On next import they will be reloaded.

    Every module whose top-level package matches the top-level package of any
    excluded module is removed from `sys.modules`, except the excluded modules
    themselves. A garbage collection pass runs afterwards.

    From https://medium.com/@chipiga86/python-monkey-patching-like-a-boss-87d7ddb8098e

    Args:
        exclude (iter<str>): Sequence of module paths.
    """
    # materialize once: `exclude` is used both to derive the packages and for
    # membership tests, which would silently match nothing on a one-shot iterator
    excluded = set(exclude)
    pkgs = {mod.split(".", 1)[0] for mod in excluded}

    to_unload = []
    # iterate a snapshot: sys.modules may change size during iteration as a side
    # effect of imports here or on other threads
    for mod in tuple(sys.modules):
        if mod in excluded:
            continue

        # matches the package itself and any `pkg.` submodule
        if mod.split(".", 1)[0] in pkgs:
            to_unload.append(mod)

    logger.debug("unloading modules for patching: %s", to_unload)
    for mod in to_unload:
        # tolerate entries that were already removed since the snapshot
        sys.modules.pop(mod, None)

    run_gc()
# Lookup from upstream model download URL to the file name used in the server cache.
# These should mirror the sources and names in `convert.base_models.sources`, with the
# URL as the key instead of the name.
cache_path_map = {
    # pytorch-fid
    "https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth": "pt-inception-2015-12-05-6726825d.pth",
    # facexlib releases
    "https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth": "detection-resnet50-final.pth",
    "https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth": "alignment-wflw-4hg.pth",
    "https://github.com/xinntao/facexlib/releases/download/v0.2.0/assessment_hyperIQA.pth": "assessment-hyperiqa.pth",
    # NOTE(review): the CodeFormer copy of this model maps to "detection-mobilenet025-final.pth"
    # (no dash before 025), unlike the other duplicated models, which share one name - confirm
    # which name the converter writes
    "https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth": "detection-mobilenet-025-final.pth",
    "https://github.com/xinntao/facexlib/releases/download/v0.2.0/headpose_hopenet.pth": "headpose-hopenet.pth",
    "https://github.com/xinntao/facexlib/releases/download/v0.2.0/matting_modnet_portrait.pth": "matting-modnet-portrait.pth",
    "https://github.com/xinntao/facexlib/releases/download/v0.2.0/parsing_bisenet.pth": "parsing-bisenet.pth",
    "https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth": "parsing-parsenet.pth",
    "https://github.com/xinntao/facexlib/releases/download/v0.1.0/recognition_arcface_ir_se50.pth": "recognition-arcface-ir-se50.pth",
    # CodeFormer releases
    "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth": "correction-codeformer.pth",
    "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth": "detection-resnet50-final.pth",
    "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth": "detection-mobilenet025-final.pth",
    "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_bisenet.pth": "parsing-bisenet.pth",
    "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth": "parsing-parsenet.pth",
    # NOTE(review): every other cache name ends in ".pth" but this one does not - confirm
    "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth": "upscaling-real-esrgan-x2-plus",
    "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth": "detection-yolo-v5-l.pth",
    "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5n-face.pth": "detection-yolo-v5-n.pth",
    # GFPGAN releases
    "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth": "correct-gfpgan-v1-3.pth",
    # Wasabi S3 mirror of CodeFormer facelib weights
    "https://s3.eu-central-1.wasabisys.com/nextml-model-data/codeformer/weights/facelib/detection_Resnet50_Final.pth": "detection-resnet50-final.pth",
    "https://s3.eu-central-1.wasabisys.com/nextml-model-data/codeformer/weights/facelib/parsing_parsenet.pth": "parsing-parsenet.pth",
}
def patch_not_impl(*args, **kwargs):
    """
    Stand-in for download helpers that must never run (installed in place of the
    Google Drive downloaders by the patch functions in this module).

    Any positional or keyword arguments are accepted and ignored, so a patched call
    site reaches the NotImplementedError instead of failing earlier with a TypeError
    about the number of arguments.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()
def patch_cache_path(server: ServerContext, url: str, *args, **kwargs) -> str:
    """
    Resolve a model download URL to a file in the server cache instead of downloading it.

    Installed (via `partial(patch_cache_path, server)`) in place of the libraries'
    `load_file_from_url` helpers.

    Args:
        server: context whose `cache_path` is the cache directory.
        url: download URL requested by the library.
        *args, **kwargs: any further arguments the replaced download helper is called
            with; accepted for signature compatibility and ignored.

    Returns:
        The path of the cached file.

    Raises:
        FileNotFoundError: if the cached file does not exist.
    """
    cache_name = cache_path_map.get(url)
    if cache_name is None:
        # unknown URL: fall back to the file name from the URL path
        cache_name = path.basename(urlparse(url).path)

    cache_path = path.join(server.cache_path, cache_name)
    logger.debug("patching download path: %s -> %s", url, cache_path)

    if path.exists(cache_path):
        return cache_path

    raise FileNotFoundError("missing cache file: %s" % (cache_path))
def apply_patch_basicsr(server: ServerContext):
    """
    Redirect BasicSR's download helpers to the server cache.

    The Google Drive downloader is replaced with `patch_not_impl`, and the URL
    downloader with `patch_cache_path` bound to `server`. An ImportError or
    AttributeError is logged rather than raised.
    """
    logger.debug("patching BasicSR module")

    try:
        import basicsr.utils.download_util

        download_util = basicsr.utils.download_util
        download_util.download_file_from_google_drive = patch_not_impl
        download_util.load_file_from_url = partial(patch_cache_path, server)
    except ImportError:
        logger.info("unable to import basicsr utils for patching")
    except AttributeError:
        logger.warning("unable to patch basicsr utils")
def apply_patch_codeformer(server: ServerContext):
    """
    Redirect the download helpers of CodeFormer's bundled BasicSR to the server cache.

    The Google Drive downloader is replaced with `patch_not_impl`, and the URL
    downloader with `patch_cache_path` bound to `server`. An ImportError or
    AttributeError is logged rather than raised.
    """
    logger.debug("patching CodeFormer module")

    try:
        import codeformer.basicsr.utils.download_util

        download_util = codeformer.basicsr.utils.download_util
        download_util.download_file_from_google_drive = patch_not_impl
        download_util.load_file_from_url = partial(patch_cache_path, server)
    except ImportError:
        logger.info("unable to import codeformer utils for patching")
    except AttributeError:
        logger.warning("unable to patch codeformer utils")
def apply_patch_facexlib(server: ServerContext):
    """
    Redirect Facexlib's `load_file_from_url` to `patch_cache_path` bound to `server`.

    An ImportError or AttributeError is logged rather than raised.
    """
    logger.debug("patching Facexlib module")

    try:
        import facexlib.utils

        utils_module = facexlib.utils
        utils_module.load_file_from_url = partial(patch_cache_path, server)
    except ImportError:
        logger.info("unable to import facexlib for patching")
    except AttributeError:
        logger.warning("unable to patch facexlib utils")
def apply_patches(server: ServerContext):
    """
    Patch the model download helpers of BasicSR, CodeFormer, and Facexlib to use the
    server cache, then unload the other modules of those packages.

    The patched modules are excluded from the unload so they stay cached with their
    patches. Every other module in those packages is removed from `sys.modules` and
    reloaded on its next import.

    Args:
        server: server context passed through to each patch function.
    """
    apply_patch_basicsr(server)
    apply_patch_codeformer(server)
    apply_patch_facexlib(server)

    unload(
        [
            "basicsr.utils.download_util",
            # patched by apply_patch_codeformer; without this exclusion it is dropped
            # from sys.modules and a fresh import would load an unpatched copy
            "codeformer.basicsr.utils.download_util",
            # NOTE(review): apply_patch_codeformer no longer patches this module;
            # confirm whether it still needs to stay cached
            "codeformer.facelib.utils.misc",
            "facexlib.utils",
        ]
    )