From 0709c1dbf026018b53338754ca2386cd8e02d3db Mon Sep 17 00:00:00 2001
From: Sean Sube
Date: Mon, 13 Feb 2023 18:10:11 -0600
Subject: [PATCH] feat(api): add GFPGAN and Real ESRGAN to model cache

---
 api/onnx_web/chain/correct_gfpgan.py   | 20 +++++---------
 api/onnx_web/chain/upscale_resrgan.py  | 26 +++++++++----------
 .../chain/upscale_stable_diffusion.py  | 25 +++++++-----------
 api/onnx_web/server/model_cache.py     |  9 ++++---
 4 files changed, 35 insertions(+), 45 deletions(-)

diff --git a/api/onnx_web/chain/correct_gfpgan.py b/api/onnx_web/chain/correct_gfpgan.py
index db56f762..52d17bcf 100644
--- a/api/onnx_web/chain/correct_gfpgan.py
+++ b/api/onnx_web/chain/correct_gfpgan.py
@@ -13,24 +13,19 @@ from ..utils import ServerContext, run_gc
 
 logger = getLogger(__name__)
 
-last_pipeline_instance: Optional[GFPGANer] = None
-last_pipeline_params: Optional[str] = None
-
-
 def load_gfpgan(
     server: ServerContext,
-    stage: StageParams,
+    _stage: StageParams,
     upscale: UpscaleParams,
-    device: DeviceParams,
+    _device: DeviceParams,
 ):
-    global last_pipeline_instance
-    global last_pipeline_params
-
     face_path = path.join(server.model_path, "%s.pth" % (upscale.correction_model))
+    cache_key = (face_path,)
+    cache_pipe = server.cache.get("gfpgan", cache_key)
 
-    if last_pipeline_instance is not None and face_path == last_pipeline_params:
+    if cache_pipe is not None:
         logger.info("reusing existing GFPGAN pipeline")
-        return last_pipeline_instance
+        return cache_pipe
 
     logger.debug("loading GFPGAN model from %s", face_path)
 
@@ -43,8 +38,7 @@ def load_gfpgan(
         upscale=upscale.face_outscale,
     )
 
-    last_pipeline_instance = gfpgan
-    last_pipeline_params = face_path
+    server.cache.set("gfpgan", cache_key, gfpgan)
     run_gc()
 
     return gfpgan
diff --git a/api/onnx_web/chain/upscale_resrgan.py b/api/onnx_web/chain/upscale_resrgan.py
index cbb90241..1dee1551 100644
--- a/api/onnx_web/chain/upscale_resrgan.py
+++ b/api/onnx_web/chain/upscale_resrgan.py
@@ -21,21 +21,20 @@ x4_v3_tag = "real-esrgan-x4-v3"
 
 
 def load_resrgan(
-    ctx: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0
+    server: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0
 ):
-    global last_pipeline_instance
-    global last_pipeline_params
-
     model_file = "%s.%s" % (params.upscale_model, params.format)
-    model_path = path.join(ctx.model_path, model_file)
+    model_path = path.join(server.model_path, model_file)
+
+    cache_key = (model_path, params.format)
+    cache_pipe = server.cache.get("resrgan", cache_key)
+    if cache_pipe is not None:
+        logger.info("reusing existing Real ESRGAN pipeline")
+        return cache_pipe
+
     if not path.isfile(model_path):
         raise Exception("Real ESRGAN model not found at %s" % model_path)
 
-    cache_params = (model_path, params.format)
-    if last_pipeline_instance is not None and cache_params == last_pipeline_params:
-        logger.info("reusing existing Real ESRGAN pipeline")
-        return last_pipeline_instance
-
     if x4_v3_tag in model_file:
         # the x4-v3 model needs a different network
         model = SRVGGNetCompact(
@@ -49,7 +48,7 @@ def load_resrgan(
     elif params.format == "onnx":
         # use ONNX acceleration, if available
         model = OnnxNet(
-            ctx, model_file, provider=device.provider, provider_options=device.options
+            server, model_file, provider=device.provider, provider_options=device.options
         )
     elif params.format == "pth":
         model = RRDBNet(
@@ -72,7 +71,7 @@ def load_resrgan(
     logger.debug("loading Real ESRGAN upscale model from %s", model_path)
 
     # TODO: shouldn't need the PTH file
-    model_path_pth = path.join(ctx.model_path, "%s.pth" % params.upscale_model)
+    model_path_pth = path.join(server.model_path, "%s.pth" % params.upscale_model)
     upsampler = RealESRGANer(
         scale=params.scale,
         model_path=model_path_pth,
@@ -84,8 +83,7 @@ def load_resrgan(
         half=params.half,
     )
 
-    last_pipeline_instance = upsampler
-    last_pipeline_params = cache_params
+    server.cache.set("resrgan", cache_key, upsampler)
     run_gc()
 
     return upsampler
diff --git a/api/onnx_web/chain/upscale_stable_diffusion.py b/api/onnx_web/chain/upscale_stable_diffusion.py
index 85c23957..04c04e65 100644
--- a/api/onnx_web/chain/upscale_stable_diffusion.py
+++ b/api/onnx_web/chain/upscale_stable_diffusion.py
@@ -15,22 +15,18 @@ from ..utils import ServerContext, run_gc
 
 logger = getLogger(__name__)
 
-last_pipeline_instance = None
-last_pipeline_params = (None, None)
-
 
 def load_stable_diffusion(
-    ctx: ServerContext, upscale: UpscaleParams, device: DeviceParams
+    server: ServerContext, upscale: UpscaleParams, device: DeviceParams
 ):
-    global last_pipeline_instance
-    global last_pipeline_params
+    model_path = path.join(server.model_path, upscale.upscale_model)
 
-    model_path = path.join(ctx.model_path, upscale.upscale_model)
-    cache_params = (model_path, upscale.format)
+    cache_key = (model_path, upscale.format)
+    cache_pipe = server.cache.get("diffusion", cache_key)
 
-    if last_pipeline_instance is not None and cache_params == last_pipeline_params:
+    if cache_pipe is not None:
         logger.debug("reusing existing Stable Diffusion upscale pipeline")
-        return last_pipeline_instance
+        return cache_pipe
 
     if upscale.format == "onnx":
         logger.debug(
@@ -38,7 +34,7 @@ def load_stable_diffusion(
             model_path,
             device.provider,
         )
-        pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(
+        pipe = OnnxStableDiffusionUpscalePipeline.from_pretrained(
             model_path, provider=device.provider, provider_options=device.options
         )
     else:
@@ -47,15 +43,14 @@ def load_stable_diffusion(
             model_path,
             device.provider,
         )
-        pipeline = StableDiffusionUpscalePipeline.from_pretrained(
+        pipe = StableDiffusionUpscalePipeline.from_pretrained(
             model_path, provider=device.provider
         )
 
-    last_pipeline_instance = pipeline
-    last_pipeline_params = cache_params
+    server.cache.set("diffusion", cache_key, pipe)
     run_gc()
 
-    return pipeline
+    return pipe
 
 
 def upscale_stable_diffusion(
diff --git a/api/onnx_web/server/model_cache.py b/api/onnx_web/server/model_cache.py
index bc4cf8f7..676c17b8 100644
--- a/api/onnx_web/server/model_cache.py
+++ b/api/onnx_web/server/model_cache.py
@@ -12,8 +12,9 @@ class ModelCache:
         self.limit = limit
 
     def drop(self, tag: str, key: Any) -> None:
-        self.cache = [model for model in self.cache if model[0] != tag and model[1] != key]
-
+        self.cache = [
+            model for model in self.cache if model[0] != tag or model[1] != key
+        ]
 
     def get(self, tag: str, key: Any) -> Any:
         for t, k, v in self.cache:
@@ -38,7 +39,9 @@ class ModelCache:
     def prune(self):
         total = len(self.cache)
         if total > self.limit:
-            logger.info("Removing models from cache, %s of %s", (total - self.limit), total)
+            logger.info(
+                "Removing models from cache, %s of %s", (total - self.limit), total
+            )
             self.cache[:] = self.cache[: self.limit]
         else:
            logger.debug("Model cache below limit, %s of %s", total, self.limit)
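
Note: the hunks above only show the loaders' side of the contract; the body of
ModelCache.set is outside the diff context. The sketch below is a minimal,
self-contained model of the pattern all three loaders now share: probe the
cache by tag and key, load on a miss, then store the result. FakeUpscaler and
load_upscaler are hypothetical stand-ins for GFPGANer, RealESRGANer, and the
Stable Diffusion pipelines, and set() is assumed to replace any existing entry
and then prune; only drop, get, and prune mirror lines shown in the patch.

    from typing import Any, List, Tuple


    class ModelCache:
        def __init__(self, limit: int) -> None:
            self.cache: List[Tuple[str, Any, Any]] = []
            self.limit = limit

        def drop(self, tag: str, key: Any) -> None:
            # keep an entry unless both its tag and its key match
            self.cache = [m for m in self.cache if m[0] != tag or m[1] != key]

        def get(self, tag: str, key: Any) -> Any:
            for t, k, v in self.cache:
                if t == tag and k == key:
                    return v
            return None

        def set(self, tag: str, key: Any, value: Any) -> None:
            # assumption: replace any existing entry, then enforce the limit
            self.drop(tag, key)
            self.cache.append((tag, key, value))
            self.prune()

        def prune(self) -> None:
            # as in the patch, keep the head of the list when over the limit
            if len(self.cache) > self.limit:
                self.cache[:] = self.cache[: self.limit]


    class FakeUpscaler:
        # hypothetical stand-in for a loaded upscaling pipeline
        def __init__(self, model_path: str) -> None:
            self.model_path = model_path


    def load_upscaler(cache: ModelCache, model_path: str, fmt: str) -> FakeUpscaler:
        # same shape as load_resrgan: probe the cache, load on a miss, store
        cache_key = (model_path, fmt)
        cache_pipe = cache.get("upscaler", cache_key)
        if cache_pipe is not None:
            return cache_pipe

        pipe = FakeUpscaler(model_path)
        cache.set("upscaler", cache_key, pipe)
        return pipe


    cache = ModelCache(limit=2)
    first = load_upscaler(cache, "models/x4.onnx", "onnx")
    second = load_upscaler(cache, "models/x4.onnx", "onnx")
    assert first is second  # the second call reuses the cached pipeline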
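A few design notes on the cache itself. drop() must keep an entry unless both
its tag and its key match, so the keep-condition joins the two inequalities
with "or" rather than "and". Each loader keys its entry on everything that
changes the loaded object: Real ESRGAN and the Stable Diffusion upscaler use
(model_path, params.format), so switching between ONNX and PTH weights for the
same model is a clean miss rather than a stale hit, while GFPGAN varies only
by its weights file and uses the one-element tuple (face_path,). Eviction is
not LRU: prune() keeps only the entries at the head of the list, up to the
limit, so if set() appends new entries at the tail (its body is outside these
hunks), a full cache discards the most recently loaded models first, and
raising the limit is the way to keep more pipelines warm.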