From 8ea33e9874ccd51fc0ae444cf477f6ffc3ba7fd8 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 12 Feb 2023 06:25:44 -0600 Subject: [PATCH] fix(api): patch various download fns to use cache (#95) --- api/onnx_web/chain/correct_codeformer.py | 3 +- api/onnx_web/convert/__main__.py | 11 +++- api/onnx_web/hacks.py | 71 ++++++++++++++++++++++++ api/onnx_web/serve.py | 3 + 4 files changed, 85 insertions(+), 3 deletions(-) create mode 100644 api/onnx_web/hacks.py diff --git a/api/onnx_web/chain/correct_codeformer.py b/api/onnx_web/chain/correct_codeformer.py index 5dc4bdd8..e8d74a1a 100644 --- a/api/onnx_web/chain/correct_codeformer.py +++ b/api/onnx_web/chain/correct_codeformer.py @@ -1,6 +1,5 @@ from logging import getLogger -from codeformer import CodeFormer from PIL import Image from ..device_pool import JobContext @@ -23,6 +22,8 @@ def correct_codeformer( upscale: UpscaleParams, **kwargs, ) -> Image.Image: + from codeformer import CodeFormer + device = job.get_device() # TODO: terrible names, fix image = source or source_image diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py index 5d6923be..430c6f82 100644 --- a/api/onnx_web/convert/__main__.py +++ b/api/onnx_web/convert/__main__.py @@ -4,6 +4,7 @@ from logging import getLogger from os import makedirs, path from sys import exit from typing import Dict, List, Optional, Tuple +from urllib.parse import urlparse from jsonschema import ValidationError, validate from yaml import safe_load @@ -107,8 +108,14 @@ def fetch_model( ctx: ConversionContext, name: str, source: str, model_format: Optional[str] = None ) -> str: cache_name = path.join(ctx.cache_path, name) - if model_format is not None: - # add an extension if possible, some of the conversion code checks for it + + # add an extension if possible, some of the conversion code checks for it + if model_format is None: + url = urlparse(source) + ext = path.basename(url.path) + if ext is not None: + cache_name = "%s.%s" % (cache_name, ext) + else: cache_name = "%s.%s" % (cache_name, model_format) for proto in model_sources: diff --git a/api/onnx_web/hacks.py b/api/onnx_web/hacks.py new file mode 100644 index 00000000..9c3e2a36 --- /dev/null +++ b/api/onnx_web/hacks.py @@ -0,0 +1,71 @@ +import sys +import codeformer.facelib.utils.misc +import basicsr.utils.download_util + +from functools import partial +from logging import getLogger +from urllib.parse import urlparse +from os import path +from .utils import ServerContext, base_join + +logger = getLogger(__name__) + +def unload(exclude): + """ + Remove package modules from cache except excluded ones. + On next import they will be reloaded. + From https://medium.com/@chipiga86/python-monkey-patching-like-a-boss-87d7ddb8098e + + Args: + exclude (iter): Sequence of module paths. + """ + pkgs = [] + for mod in exclude: + pkg = mod.split('.', 1)[0] + pkgs.append(pkg) + + to_unload = [] + for mod in sys.modules: + if mod in exclude: + continue + + if mod in pkgs: + to_unload.append(mod) + continue + + for pkg in pkgs: + if mod.startswith(pkg + '.'): + to_unload.append(mod) + break + + logger.debug("Unloading modules for patching: %s", to_unload) + for mod in to_unload: + del sys.modules[mod] + + +def patch_not_impl(): + raise NotImplementedError() + +def patch_cache_path(ctx: ServerContext, url: str, **kwargs) -> str: + url = urlparse(url) + base = path.basename(url.path) + cache_path = path.join(ctx.model_path, ".cache", base) + + logger.debug("Patching download path: %s -> %s", path, cache_path) + + return cache_path + +def apply_patch_codeformer(ctx: ServerContext): + logger.debug("Patching CodeFormer module...") + codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl + codeformer.facelib.utils.misc.load_file_from_url = partial(patch_cache_path, ctx) + +def apply_patch_basicsr(ctx: ServerContext): + logger.debug("Patching BasicSR module...") + basicsr.utils.download_util.download_file_from_google_drive = patch_not_impl + basicsr.utils.download_util.load_file_from_url = partial(patch_cache_path, ctx) + +def apply_patches(ctx: ServerContext): + apply_patch_basicsr(ctx) + apply_patch_codeformer(ctx) + unload(["basicsr.utils.download_util", "codeformer.facelib.utils.misc"]) diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py index e135dba4..e2a05d2e 100644 --- a/api/onnx_web/serve.py +++ b/api/onnx_web/serve.py @@ -29,6 +29,8 @@ from jsonschema import validate from onnxruntime import get_available_providers from PIL import Image +from onnx_web.hacks import apply_patches + from .chain import ( ChainPipeline, blend_img2img, @@ -405,6 +407,7 @@ def load_platforms(context: ServerContext): context = ServerContext.from_environ() +apply_patches(context) check_paths(context) load_models(context) load_params(context)