import sys
from functools import partial
from logging import getLogger
from os import path
from urllib.parse import urlparse

import basicsr.utils.download_util
import codeformer.facelib.utils.misc
import facexlib.utils

from ..utils import ServerContext

logger = getLogger(__name__)


def unload(exclude):
    """
    Remove package modules from cache except excluded ones.
    On next import they will be reloaded.

    Every module that belongs to the same top-level package as one of the
    excluded modules is dropped from `sys.modules`; the excluded modules
    themselves are left in place.

    Args:
        exclude (iter): Sequence of module paths.
    """
    keep = set(exclude)
    pkgs = {mod.split(".", 1)[0] for mod in keep}
    # str.startswith accepts a tuple, so one call covers every package prefix;
    # an empty tuple never matches, which keeps the empty-exclude case a no-op
    prefixes = tuple(pkg + "." for pkg in pkgs)

    # iterate over a snapshot: an import running elsewhere may resize
    # sys.modules while it is being scanned
    to_unload = [
        mod
        for mod in list(sys.modules)
        if mod not in keep and (mod in pkgs or mod.startswith(prefixes))
    ]

    logger.debug("unloading modules for patching: %s", to_unload)
    for mod in to_unload:
        # pop instead of del, the entry may have vanished since the snapshot
        sys.modules.pop(mod, None)


# these should be the same sources and names as `convert.base_models.sources`, but inverted so the source is the key
# NOTE(review): every key below is the empty string, so each entry overwrites
# the previous one and only the last value survives at runtime. Presumably the
# source URLs were stripped from this file - restore them before relying on
# this map.
cache_path_map = {
    "": ("pt-inception-2015-12-05-6726825d.pth"),
    "": ("detection-resnet50-final.pth"),
    "": ("alignment-wflw-4hg.pth"),
    "": ("assessment-hyperiqa.pth"),
    "": ("detection-mobilenet-025-final.pth"),
    "": ("headpose-hopenet.pth"),
    "": ("matting-modnet-portrait.pth"),
    "": ("parsing-bisenet.pth"),
    "": ("parsing-parsenet.pth"),
    "": ("recognition-arcface-ir-se50.pth"),
    "": ("correction-codeformer.pth"),
    "": ("detection-resnet50-final.pth"),
    "": ("detection-mobilenet025-final.pth"),
    "": ("parsing-bisenet.pth"),
    "": ("parsing-parsenet.pth"),
    "": ("upscaling-real-esrgan-x2-plus"),
    "": ("detection-yolo-v5-l.pth"),
    "": ("detection-yolo-v5-n.pth"),
    "": ("correct-gfpgan-v1-3.pth"),
    "": ("detection-resnet50-final.pth"),
    "": ("parsing-parsenet.pth"),
}


def patch_not_impl(*args, **kwargs):
    """
    Replacement for library functions that must not be used.

    Accepts and ignores any arguments, so that it fails with
    NotImplementedError no matter how the replaced function is called.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()


def patch_cache_path(ctx: ServerContext, url: str, **kwargs) -> str:
    """
    Resolve a download URL to a file inside the server cache directory.

    The file name comes from `cache_path_map` when the URL is listed there,
    otherwise from the last segment of the URL path. Nothing is downloaded.

    Args:
        ctx: Server context, only `ctx.cache_path` is read.
        url: URL that the patched library wanted to download.
        **kwargs: Accepted and ignored.

    Returns:
        Path of the existing file in the cache directory.

    Raises:
        FileNotFoundError: when the resolved path does not exist.
    """
    cache_name = cache_path_map.get(url)
    if cache_name is None:
        # unknown source, fall back to the file name from the URL
        cache_name = path.basename(urlparse(url).path)

    cache_path = path.join(ctx.cache_path, cache_name)
    logger.debug("patching download path: %s -> %s", url, cache_path)

    if path.exists(cache_path):
        return cache_path

    raise FileNotFoundError("missing cache file: %s" % (cache_path))


def apply_patch_basicsr(ctx: ServerContext):
    """
    Replace the download helpers in `basicsr.utils.download_util`.
    """
    logger.debug("patching BasicSR module")
    basicsr.utils.download_util.download_file_from_google_drive = patch_not_impl
    basicsr.utils.download_util.load_file_from_url = partial(patch_cache_path, ctx)


def apply_patch_codeformer(ctx: ServerContext):
    """
    Replace the download helpers in `codeformer.facelib.utils.misc`.
    """
    logger.debug("patching CodeFormer module")
    codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl
    codeformer.facelib.utils.misc.load_file_from_url = partial(patch_cache_path, ctx)


def apply_patch_facexlib(ctx: ServerContext):
    """
    Replace `load_file_from_url` in `facexlib.utils`.
    """
    logger.debug("patching Facexlib module")
    facexlib.utils.load_file_from_url = partial(patch_cache_path, ctx)


def apply_patches(ctx: ServerContext):
    """
    Apply all of the download patches, then unload the other modules of the
    patched packages so they are reloaded on their next import.
    """
    apply_patch_basicsr(ctx)
    apply_patch_codeformer(ctx)
    apply_patch_facexlib(ctx)

    # the patched modules are excluded so they stay in sys.modules
    unload(
        [
            "basicsr.utils.download_util",
            "codeformer.facelib.utils.misc",
            "facexlib.utils",
        ]
    )