fix(api): use consistent cache key for each model type
This commit is contained in:
parent
a9fa76737e
commit
47b10945ff
|
@ -6,7 +6,7 @@ import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
||||||
from ..server import ServerContext
|
from ..server import ModelTypes, ServerContext
|
||||||
from ..utils import run_gc
|
from ..utils import run_gc
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .stage import BaseStage
|
from .stage import BaseStage
|
||||||
|
@ -28,7 +28,7 @@ class CorrectGFPGANStage(BaseStage):
|
||||||
|
|
||||||
face_path = path.join(server.cache_path, "%s.pth" % (upscale.correction_model))
|
face_path = path.join(server.cache_path, "%s.pth" % (upscale.correction_model))
|
||||||
cache_key = (face_path,)
|
cache_key = (face_path,)
|
||||||
cache_pipe = server.cache.get("gfpgan", cache_key)
|
cache_pipe = server.cache.get(ModelTypes.correction, cache_key)
|
||||||
|
|
||||||
if cache_pipe is not None:
|
if cache_pipe is not None:
|
||||||
logger.info("reusing existing GFPGAN pipeline")
|
logger.info("reusing existing GFPGAN pipeline")
|
||||||
|
@ -46,7 +46,7 @@ class CorrectGFPGANStage(BaseStage):
|
||||||
upscale=upscale.face_outscale,
|
upscale=upscale.face_outscale,
|
||||||
)
|
)
|
||||||
|
|
||||||
server.cache.set("gfpgan", cache_key, gfpgan)
|
server.cache.set(ModelTypes.correction, cache_key, gfpgan)
|
||||||
run_gc([device])
|
run_gc([device])
|
||||||
|
|
||||||
return gfpgan
|
return gfpgan
|
||||||
|
|
|
@ -7,7 +7,7 @@ from PIL import Image
|
||||||
|
|
||||||
from ..models.onnx import OnnxModel
|
from ..models.onnx import OnnxModel
|
||||||
from ..params import DeviceParams, ImageParams, Size, StageParams, UpscaleParams
|
from ..params import DeviceParams, ImageParams, Size, StageParams, UpscaleParams
|
||||||
from ..server import ServerContext
|
from ..server import ModelTypes, ServerContext
|
||||||
from ..utils import run_gc
|
from ..utils import run_gc
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .stage import BaseStage
|
from .stage import BaseStage
|
||||||
|
@ -28,7 +28,7 @@ class UpscaleBSRGANStage(BaseStage):
|
||||||
# must be within the load function for patch to take effect
|
# must be within the load function for patch to take effect
|
||||||
model_path = path.join(server.model_path, "%s.onnx" % (upscale.upscale_model))
|
model_path = path.join(server.model_path, "%s.onnx" % (upscale.upscale_model))
|
||||||
cache_key = (model_path,)
|
cache_key = (model_path,)
|
||||||
cache_pipe = server.cache.get("bsrgan", cache_key)
|
cache_pipe = server.cache.get(ModelTypes.upscaling, cache_key)
|
||||||
|
|
||||||
if cache_pipe is not None:
|
if cache_pipe is not None:
|
||||||
logger.debug("reusing existing BSRGAN pipeline")
|
logger.debug("reusing existing BSRGAN pipeline")
|
||||||
|
@ -43,7 +43,7 @@ class UpscaleBSRGANStage(BaseStage):
|
||||||
sess_options=device.sess_options(),
|
sess_options=device.sess_options(),
|
||||||
)
|
)
|
||||||
|
|
||||||
server.cache.set("bsrgan", cache_key, pipe)
|
server.cache.set(ModelTypes.upscaling, cache_key, pipe)
|
||||||
run_gc([device])
|
run_gc([device])
|
||||||
|
|
||||||
return pipe
|
return pipe
|
||||||
|
|
|
@ -7,7 +7,7 @@ from PIL import Image
|
||||||
|
|
||||||
from ..onnx import OnnxRRDBNet
|
from ..onnx import OnnxRRDBNet
|
||||||
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
||||||
from ..server import ServerContext
|
from ..server import ModelTypes, ServerContext
|
||||||
from ..utils import run_gc
|
from ..utils import run_gc
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .stage import BaseStage
|
from .stage import BaseStage
|
||||||
|
@ -29,7 +29,7 @@ class UpscaleRealESRGANStage(BaseStage):
|
||||||
model_path = path.join(server.model_path, model_file)
|
model_path = path.join(server.model_path, model_file)
|
||||||
|
|
||||||
cache_key = (model_path, params.format)
|
cache_key = (model_path, params.format)
|
||||||
cache_pipe = server.cache.get("resrgan", cache_key)
|
cache_pipe = server.cache.get(ModelTypes.upscaling, cache_key)
|
||||||
if cache_pipe is not None:
|
if cache_pipe is not None:
|
||||||
logger.info("reusing existing Real ESRGAN pipeline")
|
logger.info("reusing existing Real ESRGAN pipeline")
|
||||||
return cache_pipe
|
return cache_pipe
|
||||||
|
@ -66,7 +66,7 @@ class UpscaleRealESRGANStage(BaseStage):
|
||||||
half=False, # TODO: use server optimizations
|
half=False, # TODO: use server optimizations
|
||||||
)
|
)
|
||||||
|
|
||||||
server.cache.set("resrgan", cache_key, upsampler)
|
server.cache.set(ModelTypes.upscaling, cache_key, upsampler)
|
||||||
run_gc([device])
|
run_gc([device])
|
||||||
|
|
||||||
return upsampler
|
return upsampler
|
||||||
|
|
|
@ -7,7 +7,7 @@ from PIL import Image
|
||||||
|
|
||||||
from ..models.onnx import OnnxModel
|
from ..models.onnx import OnnxModel
|
||||||
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
||||||
from ..server import ServerContext
|
from ..server import ModelTypes, ServerContext
|
||||||
from ..utils import run_gc
|
from ..utils import run_gc
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .stage import BaseStage
|
from .stage import BaseStage
|
||||||
|
@ -28,7 +28,7 @@ class UpscaleSwinIRStage(BaseStage):
|
||||||
# must be within the load function for patch to take effect
|
# must be within the load function for patch to take effect
|
||||||
model_path = path.join(server.model_path, "%s.onnx" % (upscale.upscale_model))
|
model_path = path.join(server.model_path, "%s.onnx" % (upscale.upscale_model))
|
||||||
cache_key = (model_path,)
|
cache_key = (model_path,)
|
||||||
cache_pipe = server.cache.get("swinir", cache_key)
|
cache_pipe = server.cache.get(ModelTypes.upscaling, cache_key)
|
||||||
|
|
||||||
if cache_pipe is not None:
|
if cache_pipe is not None:
|
||||||
logger.info("reusing existing SwinIR pipeline")
|
logger.info("reusing existing SwinIR pipeline")
|
||||||
|
@ -43,7 +43,7 @@ class UpscaleSwinIRStage(BaseStage):
|
||||||
sess_options=device.sess_options(),
|
sess_options=device.sess_options(),
|
||||||
)
|
)
|
||||||
|
|
||||||
server.cache.set("swinir", cache_key, pipe)
|
server.cache.set(ModelTypes.upscaling, cache_key, pipe)
|
||||||
run_gc([device])
|
run_gc([device])
|
||||||
|
|
||||||
return pipe
|
return pipe
|
||||||
|
@ -75,7 +75,7 @@ class UpscaleSwinIRStage(BaseStage):
|
||||||
image = np.array(source) / 255.0
|
image = np.array(source) / 255.0
|
||||||
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
|
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
|
||||||
image = np.expand_dims(image, axis=0)
|
image = np.expand_dims(image, axis=0)
|
||||||
logger.info("SwinIR input shape: %s", image.shape)
|
logger.trace("SwinIR input shape: %s", image.shape)
|
||||||
|
|
||||||
scale = upscale.outscale
|
scale = upscale.outscale
|
||||||
dest = np.zeros(
|
dest = np.zeros(
|
||||||
|
@ -86,7 +86,7 @@ class UpscaleSwinIRStage(BaseStage):
|
||||||
image.shape[3] * scale,
|
image.shape[3] * scale,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
logger.info("SwinIR output shape: %s", dest.shape)
|
logger.trace("SwinIR output shape: %s", dest.shape)
|
||||||
|
|
||||||
dest = swinir(image)
|
dest = swinir(image)
|
||||||
dest = np.clip(np.squeeze(dest, axis=0), 0, 1)
|
dest = np.clip(np.squeeze(dest, axis=0), 0, 1)
|
||||||
|
|
|
@ -11,7 +11,7 @@ from ..convert.diffusion.textual_inversion import blend_textual_inversions
|
||||||
from ..diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
|
from ..diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
|
||||||
from ..diffusers.utils import expand_prompt
|
from ..diffusers.utils import expand_prompt
|
||||||
from ..params import DeviceParams, ImageParams
|
from ..params import DeviceParams, ImageParams
|
||||||
from ..server import ServerContext
|
from ..server import ModelTypes, ServerContext
|
||||||
from ..utils import run_gc
|
from ..utils import run_gc
|
||||||
from .patches.unet import UNetWrapper
|
from .patches.unet import UNetWrapper
|
||||||
from .patches.vae import VAEWrapper
|
from .patches.vae import VAEWrapper
|
||||||
|
@ -119,14 +119,14 @@ def load_pipeline(
|
||||||
scheduler_key = (params.scheduler, model)
|
scheduler_key = (params.scheduler, model)
|
||||||
scheduler_type = pipeline_schedulers[params.scheduler]
|
scheduler_type = pipeline_schedulers[params.scheduler]
|
||||||
|
|
||||||
cache_pipe = server.cache.get("diffusion", pipe_key)
|
cache_pipe = server.cache.get(ModelTypes.diffusion, pipe_key)
|
||||||
|
|
||||||
if cache_pipe is not None:
|
if cache_pipe is not None:
|
||||||
logger.debug("reusing existing diffusion pipeline")
|
logger.debug("reusing existing diffusion pipeline")
|
||||||
pipe = cache_pipe
|
pipe = cache_pipe
|
||||||
|
|
||||||
# update scheduler
|
# update scheduler
|
||||||
cache_scheduler = server.cache.get("scheduler", scheduler_key)
|
cache_scheduler = server.cache.get(ModelTypes.scheduler, scheduler_key)
|
||||||
if cache_scheduler is None:
|
if cache_scheduler is None:
|
||||||
logger.debug("loading new diffusion scheduler")
|
logger.debug("loading new diffusion scheduler")
|
||||||
scheduler = scheduler_type.from_pretrained(
|
scheduler = scheduler_type.from_pretrained(
|
||||||
|
@ -141,7 +141,7 @@ def load_pipeline(
|
||||||
scheduler = scheduler.to(device.torch_str())
|
scheduler = scheduler.to(device.torch_str())
|
||||||
|
|
||||||
pipe.scheduler = scheduler
|
pipe.scheduler = scheduler
|
||||||
server.cache.set("scheduler", scheduler_key, scheduler)
|
server.cache.set(ModelTypes.scheduler, scheduler_key, scheduler)
|
||||||
run_gc([device])
|
run_gc([device])
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
@ -342,8 +342,8 @@ def load_pipeline(
|
||||||
optimize_pipeline(server, pipe)
|
optimize_pipeline(server, pipe)
|
||||||
patch_pipeline(server, pipe, pipeline, pipeline_class, params)
|
patch_pipeline(server, pipe, pipeline, pipeline_class, params)
|
||||||
|
|
||||||
server.cache.set("diffusion", pipe_key, pipe)
|
server.cache.set(ModelTypes.diffusion, pipe_key, pipe)
|
||||||
server.cache.set("scheduler", scheduler_key, components["scheduler"])
|
server.cache.set(ModelTypes.scheduler, scheduler_key, components["scheduler"])
|
||||||
|
|
||||||
if hasattr(pipe, "vae_decoder"):
|
if hasattr(pipe, "vae_decoder"):
|
||||||
pipe.vae_decoder.set_tiled(tiled=params.tiled_vae)
|
pipe.vae_decoder.set_tiled(tiled=params.tiled_vae)
|
||||||
|
|
|
@ -4,5 +4,5 @@ from .hacks import (
|
||||||
apply_patch_facexlib,
|
apply_patch_facexlib,
|
||||||
apply_patches,
|
apply_patches,
|
||||||
)
|
)
|
||||||
from .model_cache import ModelCache
|
from .model_cache import ModelCache, ModelTypes
|
||||||
from .context import ServerContext
|
from .context import ServerContext
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
from enum import Enum
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import Any, List, Tuple
|
from typing import Any, List, Tuple
|
||||||
|
|
||||||
|
@ -6,6 +7,13 @@ logger = getLogger(__name__)
|
||||||
cache: List[Tuple[str, Any, Any]] = []
|
cache: List[Tuple[str, Any, Any]] = []
|
||||||
|
|
||||||
|
|
||||||
|
class ModelTypes(str, Enum):
|
||||||
|
correction = "correction"
|
||||||
|
diffusion = "diffusion"
|
||||||
|
scheduler = "scheduler"
|
||||||
|
upscaling = "upscaling"
|
||||||
|
|
||||||
|
|
||||||
class ModelCache:
|
class ModelCache:
|
||||||
# cache: List[Tuple[str, Any, Any]]
|
# cache: List[Tuple[str, Any, Any]]
|
||||||
limit: int
|
limit: int
|
||||||
|
|
Loading…
Reference in New Issue