1
0
Fork 0

fix(api): patch various download fns to use cache (#95)

This commit is contained in:
Sean Sube 2023-02-12 06:25:44 -06:00
parent 6983691e98
commit 8ea33e9874
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 85 additions and 3 deletions

View File

@ -1,6 +1,5 @@
from logging import getLogger from logging import getLogger
from codeformer import CodeFormer
from PIL import Image from PIL import Image
from ..device_pool import JobContext from ..device_pool import JobContext
@ -23,6 +22,8 @@ def correct_codeformer(
upscale: UpscaleParams, upscale: UpscaleParams,
**kwargs, **kwargs,
) -> Image.Image: ) -> Image.Image:
from codeformer import CodeFormer
device = job.get_device() device = job.get_device()
# TODO: terrible names, fix # TODO: terrible names, fix
image = source or source_image image = source or source_image

View File

@ -4,6 +4,7 @@ from logging import getLogger
from os import makedirs, path from os import makedirs, path
from sys import exit from sys import exit
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
from urllib.parse import urlparse
from jsonschema import ValidationError, validate from jsonschema import ValidationError, validate
from yaml import safe_load from yaml import safe_load
@ -107,8 +108,14 @@ def fetch_model(
ctx: ConversionContext, name: str, source: str, model_format: Optional[str] = None ctx: ConversionContext, name: str, source: str, model_format: Optional[str] = None
) -> str: ) -> str:
cache_name = path.join(ctx.cache_path, name) cache_name = path.join(ctx.cache_path, name)
if model_format is not None:
# add an extension if possible, some of the conversion code checks for it # add an extension if possible, some of the conversion code checks for it
if model_format is None:
url = urlparse(source)
ext = path.basename(url.path)
if ext is not None:
cache_name = "%s.%s" % (cache_name, ext)
else:
cache_name = "%s.%s" % (cache_name, model_format) cache_name = "%s.%s" % (cache_name, model_format)
for proto in model_sources: for proto in model_sources:

71
api/onnx_web/hacks.py Normal file
View File

@ -0,0 +1,71 @@
import sys
import codeformer.facelib.utils.misc
import basicsr.utils.download_util
from functools import partial
from logging import getLogger
from urllib.parse import urlparse
from os import path
from .utils import ServerContext, base_join
logger = getLogger(__name__)
def unload(exclude):
"""
Remove package modules from cache except excluded ones.
On next import they will be reloaded.
From https://medium.com/@chipiga86/python-monkey-patching-like-a-boss-87d7ddb8098e
Args:
exclude (iter<str>): Sequence of module paths.
"""
pkgs = []
for mod in exclude:
pkg = mod.split('.', 1)[0]
pkgs.append(pkg)
to_unload = []
for mod in sys.modules:
if mod in exclude:
continue
if mod in pkgs:
to_unload.append(mod)
continue
for pkg in pkgs:
if mod.startswith(pkg + '.'):
to_unload.append(mod)
break
logger.debug("Unloading modules for patching: %s", to_unload)
for mod in to_unload:
del sys.modules[mod]
def patch_not_impl():
raise NotImplementedError()
def patch_cache_path(ctx: ServerContext, url: str, **kwargs) -> str:
url = urlparse(url)
base = path.basename(url.path)
cache_path = path.join(ctx.model_path, ".cache", base)
logger.debug("Patching download path: %s -> %s", path, cache_path)
return cache_path
def apply_patch_codeformer(ctx: ServerContext):
logger.debug("Patching CodeFormer module...")
codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl
codeformer.facelib.utils.misc.load_file_from_url = partial(patch_cache_path, ctx)
def apply_patch_basicsr(ctx: ServerContext):
logger.debug("Patching BasicSR module...")
basicsr.utils.download_util.download_file_from_google_drive = patch_not_impl
basicsr.utils.download_util.load_file_from_url = partial(patch_cache_path, ctx)
def apply_patches(ctx: ServerContext):
apply_patch_basicsr(ctx)
apply_patch_codeformer(ctx)
unload(["basicsr.utils.download_util", "codeformer.facelib.utils.misc"])

View File

@ -29,6 +29,8 @@ from jsonschema import validate
from onnxruntime import get_available_providers from onnxruntime import get_available_providers
from PIL import Image from PIL import Image
from onnx_web.hacks import apply_patches
from .chain import ( from .chain import (
ChainPipeline, ChainPipeline,
blend_img2img, blend_img2img,
@ -405,6 +407,7 @@ def load_platforms(context: ServerContext):
context = ServerContext.from_environ() context = ServerContext.from_environ()
apply_patches(context)
check_paths(context) check_paths(context)
load_models(context) load_models(context)
load_params(context) load_params(context)