fix(api): patch various download fns to use cache (#95)
This commit is contained in:
parent
6983691e98
commit
8ea33e9874
|
@ -1,6 +1,5 @@
|
|||
from logging import getLogger
|
||||
|
||||
from codeformer import CodeFormer
|
||||
from PIL import Image
|
||||
|
||||
from ..device_pool import JobContext
|
||||
|
@ -23,6 +22,8 @@ def correct_codeformer(
|
|||
upscale: UpscaleParams,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
from codeformer import CodeFormer
|
||||
|
||||
device = job.get_device()
|
||||
# TODO: terrible names, fix
|
||||
image = source or source_image
|
||||
|
|
|
@ -4,6 +4,7 @@ from logging import getLogger
|
|||
from os import makedirs, path
|
||||
from sys import exit
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from jsonschema import ValidationError, validate
|
||||
from yaml import safe_load
|
||||
|
@ -107,8 +108,14 @@ def fetch_model(
|
|||
ctx: ConversionContext, name: str, source: str, model_format: Optional[str] = None
|
||||
) -> str:
|
||||
cache_name = path.join(ctx.cache_path, name)
|
||||
if model_format is not None:
|
||||
# add an extension if possible, some of the conversion code checks for it
|
||||
|
||||
# add an extension if possible, some of the conversion code checks for it
|
||||
if model_format is None:
|
||||
url = urlparse(source)
|
||||
ext = path.basename(url.path)
|
||||
if ext is not None:
|
||||
cache_name = "%s.%s" % (cache_name, ext)
|
||||
else:
|
||||
cache_name = "%s.%s" % (cache_name, model_format)
|
||||
|
||||
for proto in model_sources:
|
||||
|
|
|
@ -0,0 +1,71 @@
|
|||
import sys
|
||||
import codeformer.facelib.utils.misc
|
||||
import basicsr.utils.download_util
|
||||
|
||||
from functools import partial
|
||||
from logging import getLogger
|
||||
from urllib.parse import urlparse
|
||||
from os import path
|
||||
from .utils import ServerContext, base_join
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
def unload(exclude):
|
||||
"""
|
||||
Remove package modules from cache except excluded ones.
|
||||
On next import they will be reloaded.
|
||||
From https://medium.com/@chipiga86/python-monkey-patching-like-a-boss-87d7ddb8098e
|
||||
|
||||
Args:
|
||||
exclude (iter<str>): Sequence of module paths.
|
||||
"""
|
||||
pkgs = []
|
||||
for mod in exclude:
|
||||
pkg = mod.split('.', 1)[0]
|
||||
pkgs.append(pkg)
|
||||
|
||||
to_unload = []
|
||||
for mod in sys.modules:
|
||||
if mod in exclude:
|
||||
continue
|
||||
|
||||
if mod in pkgs:
|
||||
to_unload.append(mod)
|
||||
continue
|
||||
|
||||
for pkg in pkgs:
|
||||
if mod.startswith(pkg + '.'):
|
||||
to_unload.append(mod)
|
||||
break
|
||||
|
||||
logger.debug("Unloading modules for patching: %s", to_unload)
|
||||
for mod in to_unload:
|
||||
del sys.modules[mod]
|
||||
|
||||
|
||||
def patch_not_impl():
|
||||
raise NotImplementedError()
|
||||
|
||||
def patch_cache_path(ctx: ServerContext, url: str, **kwargs) -> str:
|
||||
url = urlparse(url)
|
||||
base = path.basename(url.path)
|
||||
cache_path = path.join(ctx.model_path, ".cache", base)
|
||||
|
||||
logger.debug("Patching download path: %s -> %s", path, cache_path)
|
||||
|
||||
return cache_path
|
||||
|
||||
def apply_patch_codeformer(ctx: ServerContext):
|
||||
logger.debug("Patching CodeFormer module...")
|
||||
codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl
|
||||
codeformer.facelib.utils.misc.load_file_from_url = partial(patch_cache_path, ctx)
|
||||
|
||||
def apply_patch_basicsr(ctx: ServerContext):
|
||||
logger.debug("Patching BasicSR module...")
|
||||
basicsr.utils.download_util.download_file_from_google_drive = patch_not_impl
|
||||
basicsr.utils.download_util.load_file_from_url = partial(patch_cache_path, ctx)
|
||||
|
||||
def apply_patches(ctx: ServerContext):
|
||||
apply_patch_basicsr(ctx)
|
||||
apply_patch_codeformer(ctx)
|
||||
unload(["basicsr.utils.download_util", "codeformer.facelib.utils.misc"])
|
|
@ -29,6 +29,8 @@ from jsonschema import validate
|
|||
from onnxruntime import get_available_providers
|
||||
from PIL import Image
|
||||
|
||||
from onnx_web.hacks import apply_patches
|
||||
|
||||
from .chain import (
|
||||
ChainPipeline,
|
||||
blend_img2img,
|
||||
|
@ -405,6 +407,7 @@ def load_platforms(context: ServerContext):
|
|||
|
||||
|
||||
context = ServerContext.from_environ()
|
||||
apply_patches(context)
|
||||
check_paths(context)
|
||||
load_models(context)
|
||||
load_params(context)
|
||||
|
|
Loading…
Reference in New Issue