diff --git a/api/onnx_web/chain/correct_gfpgan.py b/api/onnx_web/chain/correct_gfpgan.py
index ef36f471..29caeaed 100644
--- a/api/onnx_web/chain/correct_gfpgan.py
+++ b/api/onnx_web/chain/correct_gfpgan.py
@@ -6,7 +6,7 @@ import numpy as np
 from PIL import Image
 
 from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
-from ..server import ServerContext
+from ..server import ModelTypes, ServerContext
 from ..utils import run_gc
 from ..worker import WorkerContext
 from .stage import BaseStage
@@ -28,7 +28,7 @@ class CorrectGFPGANStage(BaseStage):
 
         face_path = path.join(server.cache_path, "%s.pth" % (upscale.correction_model))
         cache_key = (face_path,)
-        cache_pipe = server.cache.get("gfpgan", cache_key)
+        cache_pipe = server.cache.get(ModelTypes.correction, cache_key)
 
         if cache_pipe is not None:
             logger.info("reusing existing GFPGAN pipeline")
@@ -46,7 +46,7 @@ class CorrectGFPGANStage(BaseStage):
             upscale=upscale.face_outscale,
         )
 
-        server.cache.set("gfpgan", cache_key, gfpgan)
+        server.cache.set(ModelTypes.correction, cache_key, gfpgan)
         run_gc([device])
 
         return gfpgan
diff --git a/api/onnx_web/chain/upscale_bsrgan.py b/api/onnx_web/chain/upscale_bsrgan.py
index 5d1ed922..b8774ffa 100644
--- a/api/onnx_web/chain/upscale_bsrgan.py
+++ b/api/onnx_web/chain/upscale_bsrgan.py
@@ -7,7 +7,7 @@ from PIL import Image
 
 from ..models.onnx import OnnxModel
 from ..params import DeviceParams, ImageParams, Size, StageParams, UpscaleParams
-from ..server import ServerContext
+from ..server import ModelTypes, ServerContext
 from ..utils import run_gc
 from ..worker import WorkerContext
 from .stage import BaseStage
@@ -28,7 +28,7 @@ class UpscaleBSRGANStage(BaseStage):
         # must be within the load function for patch to take effect
         model_path = path.join(server.model_path, "%s.onnx" % (upscale.upscale_model))
         cache_key = (model_path,)
-        cache_pipe = server.cache.get("bsrgan", cache_key)
+        cache_pipe = server.cache.get(ModelTypes.upscaling, cache_key)
 
         if cache_pipe is not None:
             logger.debug("reusing existing BSRGAN pipeline")
@@ -43,7 +43,7 @@ class UpscaleBSRGANStage(BaseStage):
             sess_options=device.sess_options(),
         )
 
-        server.cache.set("bsrgan", cache_key, pipe)
+        server.cache.set(ModelTypes.upscaling, cache_key, pipe)
         run_gc([device])
 
         return pipe
diff --git a/api/onnx_web/chain/upscale_resrgan.py b/api/onnx_web/chain/upscale_resrgan.py
index 915a2c7b..8538a2ee 100644
--- a/api/onnx_web/chain/upscale_resrgan.py
+++ b/api/onnx_web/chain/upscale_resrgan.py
@@ -7,7 +7,7 @@ from PIL import Image
 
 from ..onnx import OnnxRRDBNet
 from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
-from ..server import ServerContext
+from ..server import ModelTypes, ServerContext
 from ..utils import run_gc
 from ..worker import WorkerContext
 from .stage import BaseStage
@@ -29,7 +29,7 @@ class UpscaleRealESRGANStage(BaseStage):
 
         model_path = path.join(server.model_path, model_file)
         cache_key = (model_path, params.format)
-        cache_pipe = server.cache.get("resrgan", cache_key)
+        cache_pipe = server.cache.get(ModelTypes.upscaling, cache_key)
         if cache_pipe is not None:
             logger.info("reusing existing Real ESRGAN pipeline")
             return cache_pipe
@@ -66,7 +66,7 @@ class UpscaleRealESRGANStage(BaseStage):
             half=False,  # TODO: use server optimizations
         )
 
-        server.cache.set("resrgan", cache_key, upsampler)
+        server.cache.set(ModelTypes.upscaling, cache_key, upsampler)
         run_gc([device])
 
         return upsampler
diff --git a/api/onnx_web/chain/upscale_swinir.py b/api/onnx_web/chain/upscale_swinir.py
index e7379ff5..c6ccfebb 100644
--- a/api/onnx_web/chain/upscale_swinir.py
+++ b/api/onnx_web/chain/upscale_swinir.py
@@ -7,7 +7,7 @@ from PIL import Image
 
 from ..models.onnx import OnnxModel
 from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
-from ..server import ServerContext
+from ..server import ModelTypes, ServerContext
 from ..utils import run_gc
 from ..worker import WorkerContext
 from .stage import BaseStage
@@ -28,7 +28,7 @@ class UpscaleSwinIRStage(BaseStage):
         # must be within the load function for patch to take effect
         model_path = path.join(server.model_path, "%s.onnx" % (upscale.upscale_model))
         cache_key = (model_path,)
-        cache_pipe = server.cache.get("swinir", cache_key)
+        cache_pipe = server.cache.get(ModelTypes.upscaling, cache_key)
 
         if cache_pipe is not None:
             logger.info("reusing existing SwinIR pipeline")
@@ -43,7 +43,7 @@ class UpscaleSwinIRStage(BaseStage):
             sess_options=device.sess_options(),
        )
 
-        server.cache.set("swinir", cache_key, pipe)
+        server.cache.set(ModelTypes.upscaling, cache_key, pipe)
         run_gc([device])
 
         return pipe
@@ -75,7 +75,7 @@ class UpscaleSwinIRStage(BaseStage):
         image = np.array(source) / 255.0
         image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
         image = np.expand_dims(image, axis=0)
-        logger.info("SwinIR input shape: %s", image.shape)
+        logger.trace("SwinIR input shape: %s", image.shape)
 
         scale = upscale.outscale
         dest = np.zeros(
@@ -86,7 +86,7 @@ class UpscaleSwinIRStage(BaseStage):
                 image.shape[3] * scale,
             )
         )
-        logger.info("SwinIR output shape: %s", dest.shape)
+        logger.trace("SwinIR output shape: %s", dest.shape)
 
         dest = swinir(image)
         dest = np.clip(np.squeeze(dest, axis=0), 0, 1)
diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py
index 7b339ee7..2432e83a 100644
--- a/api/onnx_web/diffusers/load.py
+++ b/api/onnx_web/diffusers/load.py
@@ -11,7 +11,7 @@ from ..convert.diffusion.textual_inversion import blend_textual_inversions
 from ..diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
 from ..diffusers.utils import expand_prompt
 from ..params import DeviceParams, ImageParams
-from ..server import ServerContext
+from ..server import ModelTypes, ServerContext
 from ..utils import run_gc
 from .patches.unet import UNetWrapper
 from .patches.vae import VAEWrapper
@@ -119,14 +119,14 @@ def load_pipeline(
     scheduler_key = (params.scheduler, model)
     scheduler_type = pipeline_schedulers[params.scheduler]
 
-    cache_pipe = server.cache.get("diffusion", pipe_key)
+    cache_pipe = server.cache.get(ModelTypes.diffusion, pipe_key)
 
     if cache_pipe is not None:
         logger.debug("reusing existing diffusion pipeline")
         pipe = cache_pipe
 
         # update scheduler
-        cache_scheduler = server.cache.get("scheduler", scheduler_key)
+        cache_scheduler = server.cache.get(ModelTypes.scheduler, scheduler_key)
         if cache_scheduler is None:
             logger.debug("loading new diffusion scheduler")
             scheduler = scheduler_type.from_pretrained(
@@ -141,7 +141,7 @@ def load_pipeline(
                 scheduler = scheduler.to(device.torch_str())
 
             pipe.scheduler = scheduler
-            server.cache.set("scheduler", scheduler_key, scheduler)
+            server.cache.set(ModelTypes.scheduler, scheduler_key, scheduler)
             run_gc([device])
 
     else:
@@ -342,8 +342,8 @@ def load_pipeline(
         optimize_pipeline(server, pipe)
         patch_pipeline(server, pipe, pipeline, pipeline_class, params)
 
-        server.cache.set("diffusion", pipe_key, pipe)
-        server.cache.set("scheduler", scheduler_key, components["scheduler"])
+        server.cache.set(ModelTypes.diffusion, pipe_key, pipe)
+        server.cache.set(ModelTypes.scheduler, scheduler_key, components["scheduler"])
 
     if hasattr(pipe, "vae_decoder"):
         pipe.vae_decoder.set_tiled(tiled=params.tiled_vae)
diff --git a/api/onnx_web/server/__init__.py b/api/onnx_web/server/__init__.py
index f02fa35a..b74014f1 100644
--- a/api/onnx_web/server/__init__.py
+++ b/api/onnx_web/server/__init__.py
@@ -4,5 +4,5 @@ from .hacks import (
     apply_patch_facexlib,
     apply_patches,
 )
-from .model_cache import ModelCache
+from .model_cache import ModelCache, ModelTypes
 from .context import ServerContext
diff --git a/api/onnx_web/server/model_cache.py b/api/onnx_web/server/model_cache.py
index b7207c41..21da25f4 100644
--- a/api/onnx_web/server/model_cache.py
+++ b/api/onnx_web/server/model_cache.py
@@ -1,3 +1,4 @@
+from enum import Enum
 from logging import getLogger
 from typing import Any, List, Tuple
 
@@ -6,6 +7,13 @@ logger = getLogger(__name__)
 cache: List[Tuple[str, Any, Any]] = []
 
 
+class ModelTypes(str, Enum):
+    correction = "correction"
+    diffusion = "diffusion"
+    scheduler = "scheduler"
+    upscaling = "upscaling"
+
+
 class ModelCache:
     # cache: List[Tuple[str, Any, Any]]
     limit: int