1
0
Fork 0

fix(api): use consistent cache key for each model type

This commit is contained in:
Sean Sube 2023-07-03 11:33:56 -05:00
parent a9fa76737e
commit 47b10945ff
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
7 changed files with 29 additions and 21 deletions

View File

@ -6,7 +6,7 @@ import numpy as np
from PIL import Image
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..server import ServerContext
from ..server import ModelTypes, ServerContext
from ..utils import run_gc
from ..worker import WorkerContext
from .stage import BaseStage
@ -28,7 +28,7 @@ class CorrectGFPGANStage(BaseStage):
face_path = path.join(server.cache_path, "%s.pth" % (upscale.correction_model))
cache_key = (face_path,)
cache_pipe = server.cache.get("gfpgan", cache_key)
cache_pipe = server.cache.get(ModelTypes.correction, cache_key)
if cache_pipe is not None:
logger.info("reusing existing GFPGAN pipeline")
@ -46,7 +46,7 @@ class CorrectGFPGANStage(BaseStage):
upscale=upscale.face_outscale,
)
server.cache.set("gfpgan", cache_key, gfpgan)
server.cache.set(ModelTypes.correction, cache_key, gfpgan)
run_gc([device])
return gfpgan

View File

@ -7,7 +7,7 @@ from PIL import Image
from ..models.onnx import OnnxModel
from ..params import DeviceParams, ImageParams, Size, StageParams, UpscaleParams
from ..server import ServerContext
from ..server import ModelTypes, ServerContext
from ..utils import run_gc
from ..worker import WorkerContext
from .stage import BaseStage
@ -28,7 +28,7 @@ class UpscaleBSRGANStage(BaseStage):
# must be within the load function for patch to take effect
model_path = path.join(server.model_path, "%s.onnx" % (upscale.upscale_model))
cache_key = (model_path,)
cache_pipe = server.cache.get("bsrgan", cache_key)
cache_pipe = server.cache.get(ModelTypes.upscaling, cache_key)
if cache_pipe is not None:
logger.debug("reusing existing BSRGAN pipeline")
@ -43,7 +43,7 @@ class UpscaleBSRGANStage(BaseStage):
sess_options=device.sess_options(),
)
server.cache.set("bsrgan", cache_key, pipe)
server.cache.set(ModelTypes.upscaling, cache_key, pipe)
run_gc([device])
return pipe

View File

@ -7,7 +7,7 @@ from PIL import Image
from ..onnx import OnnxRRDBNet
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..server import ServerContext
from ..server import ModelTypes, ServerContext
from ..utils import run_gc
from ..worker import WorkerContext
from .stage import BaseStage
@ -29,7 +29,7 @@ class UpscaleRealESRGANStage(BaseStage):
model_path = path.join(server.model_path, model_file)
cache_key = (model_path, params.format)
cache_pipe = server.cache.get("resrgan", cache_key)
cache_pipe = server.cache.get(ModelTypes.upscaling, cache_key)
if cache_pipe is not None:
logger.info("reusing existing Real ESRGAN pipeline")
return cache_pipe
@ -66,7 +66,7 @@ class UpscaleRealESRGANStage(BaseStage):
half=False, # TODO: use server optimizations
)
server.cache.set("resrgan", cache_key, upsampler)
server.cache.set(ModelTypes.upscaling, cache_key, upsampler)
run_gc([device])
return upsampler

View File

@ -7,7 +7,7 @@ from PIL import Image
from ..models.onnx import OnnxModel
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..server import ServerContext
from ..server import ModelTypes, ServerContext
from ..utils import run_gc
from ..worker import WorkerContext
from .stage import BaseStage
@ -28,7 +28,7 @@ class UpscaleSwinIRStage(BaseStage):
# must be within the load function for patch to take effect
model_path = path.join(server.model_path, "%s.onnx" % (upscale.upscale_model))
cache_key = (model_path,)
cache_pipe = server.cache.get("swinir", cache_key)
cache_pipe = server.cache.get(ModelTypes.upscaling, cache_key)
if cache_pipe is not None:
logger.info("reusing existing SwinIR pipeline")
@ -43,7 +43,7 @@ class UpscaleSwinIRStage(BaseStage):
sess_options=device.sess_options(),
)
server.cache.set("swinir", cache_key, pipe)
server.cache.set(ModelTypes.upscaling, cache_key, pipe)
run_gc([device])
return pipe
@ -75,7 +75,7 @@ class UpscaleSwinIRStage(BaseStage):
image = np.array(source) / 255.0
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
image = np.expand_dims(image, axis=0)
logger.info("SwinIR input shape: %s", image.shape)
logger.trace("SwinIR input shape: %s", image.shape)
scale = upscale.outscale
dest = np.zeros(
@ -86,7 +86,7 @@ class UpscaleSwinIRStage(BaseStage):
image.shape[3] * scale,
)
)
logger.info("SwinIR output shape: %s", dest.shape)
logger.trace("SwinIR output shape: %s", dest.shape)
dest = swinir(image)
dest = np.clip(np.squeeze(dest, axis=0), 0, 1)

View File

@ -11,7 +11,7 @@ from ..convert.diffusion.textual_inversion import blend_textual_inversions
from ..diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
from ..diffusers.utils import expand_prompt
from ..params import DeviceParams, ImageParams
from ..server import ServerContext
from ..server import ModelTypes, ServerContext
from ..utils import run_gc
from .patches.unet import UNetWrapper
from .patches.vae import VAEWrapper
@ -119,14 +119,14 @@ def load_pipeline(
scheduler_key = (params.scheduler, model)
scheduler_type = pipeline_schedulers[params.scheduler]
cache_pipe = server.cache.get("diffusion", pipe_key)
cache_pipe = server.cache.get(ModelTypes.diffusion, pipe_key)
if cache_pipe is not None:
logger.debug("reusing existing diffusion pipeline")
pipe = cache_pipe
# update scheduler
cache_scheduler = server.cache.get("scheduler", scheduler_key)
cache_scheduler = server.cache.get(ModelTypes.scheduler, scheduler_key)
if cache_scheduler is None:
logger.debug("loading new diffusion scheduler")
scheduler = scheduler_type.from_pretrained(
@ -141,7 +141,7 @@ def load_pipeline(
scheduler = scheduler.to(device.torch_str())
pipe.scheduler = scheduler
server.cache.set("scheduler", scheduler_key, scheduler)
server.cache.set(ModelTypes.scheduler, scheduler_key, scheduler)
run_gc([device])
else:
@ -342,8 +342,8 @@ def load_pipeline(
optimize_pipeline(server, pipe)
patch_pipeline(server, pipe, pipeline, pipeline_class, params)
server.cache.set("diffusion", pipe_key, pipe)
server.cache.set("scheduler", scheduler_key, components["scheduler"])
server.cache.set(ModelTypes.diffusion, pipe_key, pipe)
server.cache.set(ModelTypes.scheduler, scheduler_key, components["scheduler"])
if hasattr(pipe, "vae_decoder"):
pipe.vae_decoder.set_tiled(tiled=params.tiled_vae)

View File

@ -4,5 +4,5 @@ from .hacks import (
apply_patch_facexlib,
apply_patches,
)
from .model_cache import ModelCache
from .model_cache import ModelCache, ModelTypes
from .context import ServerContext

View File

@ -1,3 +1,4 @@
from enum import Enum
from logging import getLogger
from typing import Any, List, Tuple
@ -6,6 +7,13 @@ logger = getLogger(__name__)
cache: List[Tuple[str, Any, Any]] = []
class ModelTypes(str, Enum):
correction = "correction"
diffusion = "diffusion"
scheduler = "scheduler"
upscaling = "upscaling"
class ModelCache:
# cache: List[Tuple[str, Any, Any]]
limit: int