fix(api): use consistent cache key for each model type

Author: Sean Sube
Date: 2023-07-03 11:33:56 -05:00
Commit: 47b10945ff
Parent: a9fa76737e
Signed by: ssube (GPG key ID: 3EED7B957D362AF1)

7 changed files with 29 additions and 21 deletions
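
Before this commit, every loader invented its own string tag for the model cache ("gfpgan", "bsrgan", "resrgan", "swinir", "diffusion", "scheduler"), so nothing enforced that loaders for the same kind of model shared a tag. The diff below replaces those strings with members of a new ModelTypes enum, one member per model type. A minimal sketch of the call pattern, with a stand-in ModelCache whose get/set signatures are assumed from the call sites in this diff, and a hypothetical model path for illustration:

    from enum import Enum

    class ModelTypes(str, Enum):  # as added in this commit
        correction = "correction"
        diffusion = "diffusion"
        scheduler = "scheduler"
        upscaling = "upscaling"

    class ModelCache:  # minimal stand-in, not the real implementation
        def __init__(self):
            self.cache = []

        def get(self, tag, key):
            for t, k, v in self.cache:
                if t == tag and k == key:
                    return v
            return None

        def set(self, tag, key, value):
            self.cache.append((tag, key, value))

    cache = ModelCache()
    cache_key = ("/models/swinir.onnx",)  # hypothetical path, for illustration

    # before: each stage used its own ad-hoc string tag, e.g. cache.get("swinir", cache_key)
    # after: all upscaling stages share one enum tag
    pipe = cache.get(ModelTypes.upscaling, cache_key)
    if pipe is None:
        pipe = object()  # stand-in for the real ONNX session load
        cache.set(ModelTypes.upscaling, cache_key, pipe)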

View File

@@ -6,7 +6,7 @@ import numpy as np
 from PIL import Image
 from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
-from ..server import ServerContext
+from ..server import ModelTypes, ServerContext
 from ..utils import run_gc
 from ..worker import WorkerContext
 from .stage import BaseStage
@@ -28,7 +28,7 @@ class CorrectGFPGANStage(BaseStage):
         face_path = path.join(server.cache_path, "%s.pth" % (upscale.correction_model))
         cache_key = (face_path,)
-        cache_pipe = server.cache.get("gfpgan", cache_key)
+        cache_pipe = server.cache.get(ModelTypes.correction, cache_key)
         if cache_pipe is not None:
             logger.info("reusing existing GFPGAN pipeline")
@@ -46,7 +46,7 @@ class CorrectGFPGANStage(BaseStage):
             upscale=upscale.face_outscale,
         )
-        server.cache.set("gfpgan", cache_key, gfpgan)
+        server.cache.set(ModelTypes.correction, cache_key, gfpgan)
         run_gc([device])
         return gfpgan

View File

@@ -7,7 +7,7 @@ from PIL import Image
 from ..models.onnx import OnnxModel
 from ..params import DeviceParams, ImageParams, Size, StageParams, UpscaleParams
-from ..server import ServerContext
+from ..server import ModelTypes, ServerContext
 from ..utils import run_gc
 from ..worker import WorkerContext
 from .stage import BaseStage
@@ -28,7 +28,7 @@ class UpscaleBSRGANStage(BaseStage):
         # must be within the load function for patch to take effect
         model_path = path.join(server.model_path, "%s.onnx" % (upscale.upscale_model))
         cache_key = (model_path,)
-        cache_pipe = server.cache.get("bsrgan", cache_key)
+        cache_pipe = server.cache.get(ModelTypes.upscaling, cache_key)
         if cache_pipe is not None:
             logger.debug("reusing existing BSRGAN pipeline")
@@ -43,7 +43,7 @@ class UpscaleBSRGANStage(BaseStage):
             sess_options=device.sess_options(),
         )
-        server.cache.set("bsrgan", cache_key, pipe)
+        server.cache.set(ModelTypes.upscaling, cache_key, pipe)
         run_gc([device])
         return pipe

View File

@@ -7,7 +7,7 @@ from PIL import Image
 from ..onnx import OnnxRRDBNet
 from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
-from ..server import ServerContext
+from ..server import ModelTypes, ServerContext
 from ..utils import run_gc
 from ..worker import WorkerContext
 from .stage import BaseStage
@@ -29,7 +29,7 @@ class UpscaleRealESRGANStage(BaseStage):
         model_path = path.join(server.model_path, model_file)
         cache_key = (model_path, params.format)
-        cache_pipe = server.cache.get("resrgan", cache_key)
+        cache_pipe = server.cache.get(ModelTypes.upscaling, cache_key)
         if cache_pipe is not None:
             logger.info("reusing existing Real ESRGAN pipeline")
             return cache_pipe
@@ -66,7 +66,7 @@ class UpscaleRealESRGANStage(BaseStage):
             half=False,  # TODO: use server optimizations
         )
-        server.cache.set("resrgan", cache_key, upsampler)
+        server.cache.set(ModelTypes.upscaling, cache_key, upsampler)
         run_gc([device])
         return upsampler

View File

@@ -7,7 +7,7 @@ from PIL import Image
 from ..models.onnx import OnnxModel
 from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
-from ..server import ServerContext
+from ..server import ModelTypes, ServerContext
 from ..utils import run_gc
 from ..worker import WorkerContext
 from .stage import BaseStage
@@ -28,7 +28,7 @@ class UpscaleSwinIRStage(BaseStage):
         # must be within the load function for patch to take effect
         model_path = path.join(server.model_path, "%s.onnx" % (upscale.upscale_model))
         cache_key = (model_path,)
-        cache_pipe = server.cache.get("swinir", cache_key)
+        cache_pipe = server.cache.get(ModelTypes.upscaling, cache_key)
         if cache_pipe is not None:
             logger.info("reusing existing SwinIR pipeline")
@@ -43,7 +43,7 @@ class UpscaleSwinIRStage(BaseStage):
             sess_options=device.sess_options(),
         )
-        server.cache.set("swinir", cache_key, pipe)
+        server.cache.set(ModelTypes.upscaling, cache_key, pipe)
         run_gc([device])
         return pipe
@@ -75,7 +75,7 @@ class UpscaleSwinIRStage(BaseStage):
         image = np.array(source) / 255.0
         image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
         image = np.expand_dims(image, axis=0)
-        logger.info("SwinIR input shape: %s", image.shape)
+        logger.trace("SwinIR input shape: %s", image.shape)

         scale = upscale.outscale
         dest = np.zeros(
@@ -86,7 +86,7 @@ class UpscaleSwinIRStage(BaseStage):
                 image.shape[3] * scale,
             )
         )
-        logger.info("SwinIR output shape: %s", dest.shape)
+        logger.trace("SwinIR output shape: %s", dest.shape)

         dest = swinir(image)
         dest = np.clip(np.squeeze(dest, axis=0), 0, 1)

View File

@@ -11,7 +11,7 @@ from ..convert.diffusion.textual_inversion import blend_textual_inversions
 from ..diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
 from ..diffusers.utils import expand_prompt
 from ..params import DeviceParams, ImageParams
-from ..server import ServerContext
+from ..server import ModelTypes, ServerContext
 from ..utils import run_gc
 from .patches.unet import UNetWrapper
 from .patches.vae import VAEWrapper
@@ -119,14 +119,14 @@ def load_pipeline(
     scheduler_key = (params.scheduler, model)
     scheduler_type = pipeline_schedulers[params.scheduler]

-    cache_pipe = server.cache.get("diffusion", pipe_key)
+    cache_pipe = server.cache.get(ModelTypes.diffusion, pipe_key)

     if cache_pipe is not None:
         logger.debug("reusing existing diffusion pipeline")
         pipe = cache_pipe

         # update scheduler
-        cache_scheduler = server.cache.get("scheduler", scheduler_key)
+        cache_scheduler = server.cache.get(ModelTypes.scheduler, scheduler_key)
         if cache_scheduler is None:
             logger.debug("loading new diffusion scheduler")
             scheduler = scheduler_type.from_pretrained(
@@ -141,7 +141,7 @@ def load_pipeline(
             scheduler = scheduler.to(device.torch_str())

             pipe.scheduler = scheduler
-            server.cache.set("scheduler", scheduler_key, scheduler)
+            server.cache.set(ModelTypes.scheduler, scheduler_key, scheduler)
             run_gc([device])

     else:
@@ -342,8 +342,8 @@ def load_pipeline(
         optimize_pipeline(server, pipe)
         patch_pipeline(server, pipe, pipeline, pipeline_class, params)

-        server.cache.set("diffusion", pipe_key, pipe)
-        server.cache.set("scheduler", scheduler_key, components["scheduler"])
+        server.cache.set(ModelTypes.diffusion, pipe_key, pipe)
+        server.cache.set(ModelTypes.scheduler, scheduler_key, components["scheduler"])

     if hasattr(pipe, "vae_decoder"):
         pipe.vae_decoder.set_tiled(tiled=params.tiled_vae)
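
Taken together, the load_pipeline hunks cache the pipeline and its scheduler under separate tags with separate keys, so switching schedulers swaps a small cached object instead of reloading the whole pipeline. A minimal sketch of that flow, with a plain dict standing in for server.cache and stand-in objects in place of the real loads; only the key shapes are taken from the diff:

    from enum import Enum

    class ModelTypes(str, Enum):
        diffusion = "diffusion"
        scheduler = "scheduler"

    cache = {}  # (tag, key) -> cached object, stand-in for ModelCache

    def load_pipeline(model, scheduler_name):
        pipe = cache.get((ModelTypes.diffusion, (model,)))
        if pipe is None:
            pipe = object()  # stand-in for the real pipeline load
            cache[(ModelTypes.diffusion, (model,))] = pipe

        scheduler_key = (scheduler_name, model)  # key shape taken from the diff
        scheduler = cache.get((ModelTypes.scheduler, scheduler_key))
        if scheduler is None:
            scheduler = object()  # stand-in for scheduler_type.from_pretrained(...)
            cache[(ModelTypes.scheduler, scheduler_key)] = scheduler

        return pipe, scheduler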

View File

@@ -4,5 +4,5 @@ from .hacks import (
     apply_patch_facexlib,
     apply_patches,
 )
-from .model_cache import ModelCache
+from .model_cache import ModelCache, ModelTypes
 from .context import ServerContext

View File

@@ -1,3 +1,4 @@
+from enum import Enum
 from logging import getLogger
 from typing import Any, List, Tuple

@@ -6,6 +7,13 @@ logger = getLogger(__name__)
 cache: List[Tuple[str, Any, Any]] = []


+class ModelTypes(str, Enum):
+    correction = "correction"
+    diffusion = "diffusion"
+    scheduler = "scheduler"
+    upscaling = "upscaling"
+
+
 class ModelCache:
     # cache: List[Tuple[str, Any, Any]]
     limit: int
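
Because ModelTypes mixes str into Enum, each member compares equal to, and hashes like, its string value, so lookups written with the old plain-string tags still find entries stored under the enum tags (and vice versa). A quick demonstration of that property:

    from enum import Enum

    class ModelTypes(str, Enum):
        correction = "correction"
        diffusion = "diffusion"
        scheduler = "scheduler"
        upscaling = "upscaling"

    # equality against the plain string holds, so list-scan caches keep matching
    assert ModelTypes.upscaling == "upscaling"

    # dict lookups work in both directions because the hash matches too
    entries = {ModelTypes.upscaling: "cached pipeline"}
    assert entries["upscaling"] == "cached pipeline"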