fix(api): use consistent cache key for each model type
This commit is contained in:
parent
a9fa76737e
commit
47b10945ff
|
@ -6,7 +6,7 @@ import numpy as np
|
|||
from PIL import Image
|
||||
|
||||
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
||||
from ..server import ServerContext
|
||||
from ..server import ModelTypes, ServerContext
|
||||
from ..utils import run_gc
|
||||
from ..worker import WorkerContext
|
||||
from .stage import BaseStage
|
||||
|
@ -28,7 +28,7 @@ class CorrectGFPGANStage(BaseStage):
|
|||
|
||||
face_path = path.join(server.cache_path, "%s.pth" % (upscale.correction_model))
|
||||
cache_key = (face_path,)
|
||||
cache_pipe = server.cache.get("gfpgan", cache_key)
|
||||
cache_pipe = server.cache.get(ModelTypes.correction, cache_key)
|
||||
|
||||
if cache_pipe is not None:
|
||||
logger.info("reusing existing GFPGAN pipeline")
|
||||
|
@ -46,7 +46,7 @@ class CorrectGFPGANStage(BaseStage):
|
|||
upscale=upscale.face_outscale,
|
||||
)
|
||||
|
||||
server.cache.set("gfpgan", cache_key, gfpgan)
|
||||
server.cache.set(ModelTypes.correction, cache_key, gfpgan)
|
||||
run_gc([device])
|
||||
|
||||
return gfpgan
|
||||
|
|
|
@ -7,7 +7,7 @@ from PIL import Image
|
|||
|
||||
from ..models.onnx import OnnxModel
|
||||
from ..params import DeviceParams, ImageParams, Size, StageParams, UpscaleParams
|
||||
from ..server import ServerContext
|
||||
from ..server import ModelTypes, ServerContext
|
||||
from ..utils import run_gc
|
||||
from ..worker import WorkerContext
|
||||
from .stage import BaseStage
|
||||
|
@ -28,7 +28,7 @@ class UpscaleBSRGANStage(BaseStage):
|
|||
# must be within the load function for patch to take effect
|
||||
model_path = path.join(server.model_path, "%s.onnx" % (upscale.upscale_model))
|
||||
cache_key = (model_path,)
|
||||
cache_pipe = server.cache.get("bsrgan", cache_key)
|
||||
cache_pipe = server.cache.get(ModelTypes.upscaling, cache_key)
|
||||
|
||||
if cache_pipe is not None:
|
||||
logger.debug("reusing existing BSRGAN pipeline")
|
||||
|
@ -43,7 +43,7 @@ class UpscaleBSRGANStage(BaseStage):
|
|||
sess_options=device.sess_options(),
|
||||
)
|
||||
|
||||
server.cache.set("bsrgan", cache_key, pipe)
|
||||
server.cache.set(ModelTypes.upscaling, cache_key, pipe)
|
||||
run_gc([device])
|
||||
|
||||
return pipe
|
||||
|
|
|
@ -7,7 +7,7 @@ from PIL import Image
|
|||
|
||||
from ..onnx import OnnxRRDBNet
|
||||
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
||||
from ..server import ServerContext
|
||||
from ..server import ModelTypes, ServerContext
|
||||
from ..utils import run_gc
|
||||
from ..worker import WorkerContext
|
||||
from .stage import BaseStage
|
||||
|
@ -29,7 +29,7 @@ class UpscaleRealESRGANStage(BaseStage):
|
|||
model_path = path.join(server.model_path, model_file)
|
||||
|
||||
cache_key = (model_path, params.format)
|
||||
cache_pipe = server.cache.get("resrgan", cache_key)
|
||||
cache_pipe = server.cache.get(ModelTypes.upscaling, cache_key)
|
||||
if cache_pipe is not None:
|
||||
logger.info("reusing existing Real ESRGAN pipeline")
|
||||
return cache_pipe
|
||||
|
@ -66,7 +66,7 @@ class UpscaleRealESRGANStage(BaseStage):
|
|||
half=False, # TODO: use server optimizations
|
||||
)
|
||||
|
||||
server.cache.set("resrgan", cache_key, upsampler)
|
||||
server.cache.set(ModelTypes.upscaling, cache_key, upsampler)
|
||||
run_gc([device])
|
||||
|
||||
return upsampler
|
||||
|
|
|
@ -7,7 +7,7 @@ from PIL import Image
|
|||
|
||||
from ..models.onnx import OnnxModel
|
||||
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
||||
from ..server import ServerContext
|
||||
from ..server import ModelTypes, ServerContext
|
||||
from ..utils import run_gc
|
||||
from ..worker import WorkerContext
|
||||
from .stage import BaseStage
|
||||
|
@ -28,7 +28,7 @@ class UpscaleSwinIRStage(BaseStage):
|
|||
# must be within the load function for patch to take effect
|
||||
model_path = path.join(server.model_path, "%s.onnx" % (upscale.upscale_model))
|
||||
cache_key = (model_path,)
|
||||
cache_pipe = server.cache.get("swinir", cache_key)
|
||||
cache_pipe = server.cache.get(ModelTypes.upscaling, cache_key)
|
||||
|
||||
if cache_pipe is not None:
|
||||
logger.info("reusing existing SwinIR pipeline")
|
||||
|
@ -43,7 +43,7 @@ class UpscaleSwinIRStage(BaseStage):
|
|||
sess_options=device.sess_options(),
|
||||
)
|
||||
|
||||
server.cache.set("swinir", cache_key, pipe)
|
||||
server.cache.set(ModelTypes.upscaling, cache_key, pipe)
|
||||
run_gc([device])
|
||||
|
||||
return pipe
|
||||
|
@ -75,7 +75,7 @@ class UpscaleSwinIRStage(BaseStage):
|
|||
image = np.array(source) / 255.0
|
||||
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
|
||||
image = np.expand_dims(image, axis=0)
|
||||
logger.info("SwinIR input shape: %s", image.shape)
|
||||
logger.trace("SwinIR input shape: %s", image.shape)
|
||||
|
||||
scale = upscale.outscale
|
||||
dest = np.zeros(
|
||||
|
@ -86,7 +86,7 @@ class UpscaleSwinIRStage(BaseStage):
|
|||
image.shape[3] * scale,
|
||||
)
|
||||
)
|
||||
logger.info("SwinIR output shape: %s", dest.shape)
|
||||
logger.trace("SwinIR output shape: %s", dest.shape)
|
||||
|
||||
dest = swinir(image)
|
||||
dest = np.clip(np.squeeze(dest, axis=0), 0, 1)
|
||||
|
|
|
@ -11,7 +11,7 @@ from ..convert.diffusion.textual_inversion import blend_textual_inversions
|
|||
from ..diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
|
||||
from ..diffusers.utils import expand_prompt
|
||||
from ..params import DeviceParams, ImageParams
|
||||
from ..server import ServerContext
|
||||
from ..server import ModelTypes, ServerContext
|
||||
from ..utils import run_gc
|
||||
from .patches.unet import UNetWrapper
|
||||
from .patches.vae import VAEWrapper
|
||||
|
@ -119,14 +119,14 @@ def load_pipeline(
|
|||
scheduler_key = (params.scheduler, model)
|
||||
scheduler_type = pipeline_schedulers[params.scheduler]
|
||||
|
||||
cache_pipe = server.cache.get("diffusion", pipe_key)
|
||||
cache_pipe = server.cache.get(ModelTypes.diffusion, pipe_key)
|
||||
|
||||
if cache_pipe is not None:
|
||||
logger.debug("reusing existing diffusion pipeline")
|
||||
pipe = cache_pipe
|
||||
|
||||
# update scheduler
|
||||
cache_scheduler = server.cache.get("scheduler", scheduler_key)
|
||||
cache_scheduler = server.cache.get(ModelTypes.scheduler, scheduler_key)
|
||||
if cache_scheduler is None:
|
||||
logger.debug("loading new diffusion scheduler")
|
||||
scheduler = scheduler_type.from_pretrained(
|
||||
|
@ -141,7 +141,7 @@ def load_pipeline(
|
|||
scheduler = scheduler.to(device.torch_str())
|
||||
|
||||
pipe.scheduler = scheduler
|
||||
server.cache.set("scheduler", scheduler_key, scheduler)
|
||||
server.cache.set(ModelTypes.scheduler, scheduler_key, scheduler)
|
||||
run_gc([device])
|
||||
|
||||
else:
|
||||
|
@ -342,8 +342,8 @@ def load_pipeline(
|
|||
optimize_pipeline(server, pipe)
|
||||
patch_pipeline(server, pipe, pipeline, pipeline_class, params)
|
||||
|
||||
server.cache.set("diffusion", pipe_key, pipe)
|
||||
server.cache.set("scheduler", scheduler_key, components["scheduler"])
|
||||
server.cache.set(ModelTypes.diffusion, pipe_key, pipe)
|
||||
server.cache.set(ModelTypes.scheduler, scheduler_key, components["scheduler"])
|
||||
|
||||
if hasattr(pipe, "vae_decoder"):
|
||||
pipe.vae_decoder.set_tiled(tiled=params.tiled_vae)
|
||||
|
|
|
@ -4,5 +4,5 @@ from .hacks import (
|
|||
apply_patch_facexlib,
|
||||
apply_patches,
|
||||
)
|
||||
from .model_cache import ModelCache
|
||||
from .model_cache import ModelCache, ModelTypes
|
||||
from .context import ServerContext
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from enum import Enum
|
||||
from logging import getLogger
|
||||
from typing import Any, List, Tuple
|
||||
|
||||
|
@ -6,6 +7,13 @@ logger = getLogger(__name__)
|
|||
cache: List[Tuple[str, Any, Any]] = []
|
||||
|
||||
|
||||
class ModelTypes(str, Enum):
    """Canonical cache-key names for each kind of loadable model.

    Mixes in ``str`` so members compare equal to (and can be used in place
    of) the plain strings previously passed to the model cache, keeping
    existing string-based lookups backward compatible.
    """

    # Face-correction pipelines (e.g. GFPGAN).
    correction = "correction"
    # Diffusion pipelines.
    diffusion = "diffusion"
    # Scheduler instances attached to diffusion pipelines.
    scheduler = "scheduler"
    # Super-resolution pipelines (BSRGAN, Real ESRGAN, SwinIR, etc.).
    upscaling = "upscaling"
|
||||
|
||||
|
||||
class ModelCache:
|
||||
# cache: List[Tuple[str, Any, Any]]
|
||||
limit: int
|
||||
|
|
Loading…
Reference in New Issue