From c432ab0795ae75c225d6399d3ef65ca4245ba2ad Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Mon, 13 Feb 2023 18:33:06 -0600 Subject: [PATCH] remove oldest items from model cache first --- api/onnx_web/chain/correct_gfpgan.py | 1 - api/onnx_web/chain/upscale_resrgan.py | 9 +++++++-- api/onnx_web/chain/upscale_stable_diffusion.py | 1 - api/onnx_web/server/hacks.py | 8 +++++++- api/onnx_web/server/model_cache.py | 2 +- 5 files changed, 15 insertions(+), 6 deletions(-) diff --git a/api/onnx_web/chain/correct_gfpgan.py b/api/onnx_web/chain/correct_gfpgan.py index 52d17bcf..8eda8cd4 100644 --- a/api/onnx_web/chain/correct_gfpgan.py +++ b/api/onnx_web/chain/correct_gfpgan.py @@ -1,6 +1,5 @@ from logging import getLogger from os import path -from typing import Optional import numpy as np from gfpgan import GFPGANer diff --git a/api/onnx_web/chain/upscale_resrgan.py b/api/onnx_web/chain/upscale_resrgan.py index 0d8f7b2c..5ce9eb51 100644 --- a/api/onnx_web/chain/upscale_resrgan.py +++ b/api/onnx_web/chain/upscale_resrgan.py @@ -48,7 +48,10 @@ def load_resrgan( elif params.format == "onnx": # use ONNX acceleration, if available model = OnnxNet( - server, model_file, provider=device.provider, provider_options=device.options + server, + model_file, + provider=device.provider, + provider_options=device.options, ) elif params.format == "pth": model = RRDBNet( @@ -71,7 +74,9 @@ def load_resrgan( logger.debug("loading Real ESRGAN upscale model from %s", model_path) # TODO: shouldn't need the PTH file - model_path_pth = path.join(server.model_path, ".cache", ("%s.pth" % params.upscale_model)) + model_path_pth = path.join( + server.model_path, ".cache", ("%s.pth" % params.upscale_model) + ) upsampler = RealESRGANer( scale=params.scale, model_path=model_path_pth, diff --git a/api/onnx_web/chain/upscale_stable_diffusion.py b/api/onnx_web/chain/upscale_stable_diffusion.py index 04c04e65..8b540b68 100644 --- a/api/onnx_web/chain/upscale_stable_diffusion.py +++ b/api/onnx_web/chain/upscale_stable_diffusion.py @@ -15,7 +15,6 @@ from ..utils import ServerContext, run_gc logger = getLogger(__name__) - def load_stable_diffusion( server: ServerContext, upscale: UpscaleParams, device: DeviceParams ): diff --git a/api/onnx_web/server/hacks.py b/api/onnx_web/server/hacks.py index 89c6587a..0cdb3215 100644 --- a/api/onnx_web/server/hacks.py +++ b/api/onnx_web/server/hacks.py @@ -121,4 +121,10 @@ def apply_patches(ctx: ServerContext): apply_patch_basicsr(ctx) apply_patch_codeformer(ctx) apply_patch_facexlib(ctx) - unload(["basicsr.utils.download_util", "codeformer.facelib.utils.misc", "facexlib.utils"]) + unload( + [ + "basicsr.utils.download_util", + "codeformer.facelib.utils.misc", + "facexlib.utils", + ] + ) diff --git a/api/onnx_web/server/model_cache.py b/api/onnx_web/server/model_cache.py index 728cfc81..04cc86cc 100644 --- a/api/onnx_web/server/model_cache.py +++ b/api/onnx_web/server/model_cache.py @@ -43,6 +43,6 @@ class ModelCache: logger.info( "Removing models from cache, %s of %s", (total - self.limit), total ) - self.cache[:] = self.cache[: self.limit] + self.cache[:] = self.cache[-self.limit :] else: logger.debug("Model cache below limit, %s of %s", total, self.limit)