feat(api): add model cache for diffusion models
This commit is contained in:
parent
7fa1783be4
commit
e9472bc005
|
@ -19,7 +19,7 @@ from .image import (
|
|||
noise_source_uniform,
|
||||
)
|
||||
from .params import Border, ImageParams, Param, Point, Size, StageParams, UpscaleParams
|
||||
from .upscale import run_upscale_correction
|
||||
from .server.upscale import run_upscale_correction
|
||||
from .utils import (
|
||||
ServerContext,
|
||||
base_join,
|
||||
|
|
|
@ -5,9 +5,9 @@ from typing import Any, List, Optional, Protocol, Tuple
|
|||
|
||||
from PIL import Image
|
||||
|
||||
from ..device_pool import JobContext, ProgressCallback
|
||||
from ..output import save_image
|
||||
from ..params import ImageParams, StageParams
|
||||
from ..server.device_pool import JobContext, ProgressCallback
|
||||
from ..utils import ServerContext, is_debug
|
||||
from .utils import process_tile_order
|
||||
|
||||
|
|
|
@ -6,9 +6,9 @@ import torch
|
|||
from diffusers import OnnxStableDiffusionImg2ImgPipeline
|
||||
from PIL import Image
|
||||
|
||||
from ..device_pool import JobContext, ProgressCallback
|
||||
from ..diffusion.load import load_pipeline
|
||||
from ..params import ImageParams, StageParams
|
||||
from ..server.device_pool import JobContext, ProgressCallback
|
||||
from ..utils import ServerContext
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
@ -16,7 +16,7 @@ logger = getLogger(__name__)
|
|||
|
||||
def blend_img2img(
|
||||
job: JobContext,
|
||||
_server: ServerContext,
|
||||
server: ServerContext,
|
||||
_stage: StageParams,
|
||||
params: ImageParams,
|
||||
source_image: Image.Image,
|
||||
|
@ -30,6 +30,7 @@ def blend_img2img(
|
|||
logger.info("blending image using img2img, %s steps: %s", params.steps, prompt)
|
||||
|
||||
pipe = load_pipeline(
|
||||
server,
|
||||
OnnxStableDiffusionImg2ImgPipeline,
|
||||
params.model,
|
||||
params.scheduler,
|
||||
|
|
|
@ -6,11 +6,11 @@ import torch
|
|||
from diffusers import OnnxStableDiffusionInpaintPipeline
|
||||
from PIL import Image
|
||||
|
||||
from ..device_pool import JobContext, ProgressCallback
|
||||
from ..diffusion.load import get_latents_from_seed, load_pipeline
|
||||
from ..image import expand_image, mask_filter_none, noise_source_histogram
|
||||
from ..output import save_image
|
||||
from ..params import Border, ImageParams, Size, SizeChart, StageParams
|
||||
from ..server.device_pool import JobContext, ProgressCallback
|
||||
from ..utils import ServerContext, is_debug
|
||||
from .utils import process_tile_order
|
||||
|
||||
|
@ -65,6 +65,7 @@ def blend_inpaint(
|
|||
|
||||
latents = get_latents_from_seed(params.seed, size)
|
||||
pipe = load_pipeline(
|
||||
server,
|
||||
OnnxStableDiffusionInpaintPipeline,
|
||||
params.model,
|
||||
params.scheduler,
|
||||
|
|
|
@ -5,8 +5,8 @@ from PIL import Image
|
|||
|
||||
from onnx_web.output import save_image
|
||||
|
||||
from ..device_pool import JobContext, ProgressCallback
|
||||
from ..params import ImageParams, StageParams
|
||||
from ..server.device_pool import JobContext, ProgressCallback
|
||||
from ..utils import ServerContext, is_debug
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
|
|
@ -2,8 +2,8 @@ from logging import getLogger
|
|||
|
||||
from PIL import Image
|
||||
|
||||
from ..device_pool import JobContext
|
||||
from ..params import ImageParams, StageParams, UpscaleParams
|
||||
from ..server.device_pool import JobContext
|
||||
from ..utils import ServerContext
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
|
|
@ -6,8 +6,8 @@ import numpy as np
|
|||
from gfpgan import GFPGANer
|
||||
from PIL import Image
|
||||
|
||||
from ..device_pool import JobContext
|
||||
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
||||
from ..server.device_pool import JobContext
|
||||
from ..utils import ServerContext, run_gc
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
|
|
@ -2,9 +2,9 @@ from logging import getLogger
|
|||
|
||||
from PIL import Image
|
||||
|
||||
from ..device_pool import JobContext
|
||||
from ..output import save_image
|
||||
from ..params import ImageParams, StageParams
|
||||
from ..server.device_pool import JobContext
|
||||
from ..utils import ServerContext
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
|
|
@ -4,8 +4,8 @@ from logging import getLogger
|
|||
from boto3 import Session
|
||||
from PIL import Image
|
||||
|
||||
from ..device_pool import JobContext
|
||||
from ..params import ImageParams, StageParams
|
||||
from ..server.device_pool import JobContext
|
||||
from ..utils import ServerContext
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
|
|
@ -2,8 +2,8 @@ from logging import getLogger
|
|||
|
||||
from PIL import Image
|
||||
|
||||
from ..device_pool import JobContext
|
||||
from ..params import ImageParams, Size, StageParams
|
||||
from ..server.device_pool import JobContext
|
||||
from ..utils import ServerContext
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
|
|
@ -2,8 +2,8 @@ from logging import getLogger
|
|||
|
||||
from PIL import Image
|
||||
|
||||
from ..device_pool import JobContext
|
||||
from ..params import ImageParams, Size, StageParams
|
||||
from ..server.device_pool import JobContext
|
||||
from ..utils import ServerContext
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
|
|
@ -3,8 +3,8 @@ from typing import Callable
|
|||
|
||||
from PIL import Image
|
||||
|
||||
from ..device_pool import JobContext
|
||||
from ..params import ImageParams, Size, StageParams
|
||||
from ..server.device_pool import JobContext
|
||||
from ..utils import ServerContext
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
|
|
@ -5,9 +5,9 @@ import torch
|
|||
from diffusers import OnnxStableDiffusionPipeline
|
||||
from PIL import Image
|
||||
|
||||
from ..device_pool import JobContext, ProgressCallback
|
||||
from ..diffusion.load import get_latents_from_seed, load_pipeline
|
||||
from ..params import ImageParams, Size, StageParams
|
||||
from ..server.device_pool import JobContext, ProgressCallback
|
||||
from ..utils import ServerContext
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
@ -16,7 +16,7 @@ logger = getLogger(__name__)
|
|||
def source_txt2img(
|
||||
job: JobContext,
|
||||
server: ServerContext,
|
||||
stage: StageParams,
|
||||
_stage: StageParams,
|
||||
params: ImageParams,
|
||||
source_image: Image.Image,
|
||||
*,
|
||||
|
@ -35,6 +35,7 @@ def source_txt2img(
|
|||
|
||||
latents = get_latents_from_seed(params.seed, size)
|
||||
pipe = load_pipeline(
|
||||
server,
|
||||
OnnxStableDiffusionPipeline,
|
||||
params.model,
|
||||
params.scheduler,
|
||||
|
|
|
@ -6,11 +6,11 @@ import torch
|
|||
from diffusers import OnnxStableDiffusionInpaintPipeline
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
from ..device_pool import JobContext, ProgressCallback
|
||||
from ..diffusion.load import get_latents_from_seed, get_tile_latents, load_pipeline
|
||||
from ..image import expand_image, mask_filter_none, noise_source_histogram
|
||||
from ..output import save_image
|
||||
from ..params import Border, ImageParams, Size, SizeChart, StageParams
|
||||
from ..server.device_pool import JobContext, ProgressCallback
|
||||
from ..utils import ServerContext, is_debug
|
||||
from .utils import process_tile_grid, process_tile_order
|
||||
|
||||
|
@ -73,6 +73,7 @@ def upscale_outpaint(
|
|||
|
||||
latents = get_tile_latents(full_latents, dims)
|
||||
pipe = load_pipeline(
|
||||
server,
|
||||
OnnxStableDiffusionInpaintPipeline,
|
||||
params.model,
|
||||
params.scheduler,
|
||||
|
|
|
@ -7,9 +7,9 @@ from PIL import Image
|
|||
from realesrgan import RealESRGANer
|
||||
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
||||
|
||||
from ..device_pool import JobContext
|
||||
from ..onnx import OnnxNet
|
||||
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
||||
from ..server.device_pool import JobContext
|
||||
from ..utils import ServerContext, run_gc
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
|
|
@ -5,11 +5,11 @@ import torch
|
|||
from diffusers import StableDiffusionUpscalePipeline
|
||||
from PIL import Image
|
||||
|
||||
from ..device_pool import JobContext, ProgressCallback
|
||||
from ..diffusion.pipeline_onnx_stable_diffusion_upscale import (
|
||||
OnnxStableDiffusionUpscalePipeline,
|
||||
)
|
||||
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
||||
from ..server.device_pool import JobContext, ProgressCallback
|
||||
from ..utils import ServerContext, run_gc
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
|
|
@ -19,20 +19,10 @@ from diffusers import (
|
|||
)
|
||||
|
||||
from ..params import DeviceParams, Size
|
||||
from ..utils import run_gc
|
||||
from ..utils import ServerContext, run_gc
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
last_pipeline_instance: Any = None
|
||||
last_pipeline_options: Tuple[
|
||||
Optional[DiffusionPipeline],
|
||||
Optional[str],
|
||||
Optional[str],
|
||||
Optional[str],
|
||||
Optional[bool],
|
||||
] = (None, None, None, None, None)
|
||||
last_pipeline_scheduler: Any = None
|
||||
|
||||
latent_channels = 4
|
||||
latent_factor = 8
|
||||
|
||||
|
@ -90,24 +80,42 @@ def get_tile_latents(
|
|||
|
||||
|
||||
def load_pipeline(
|
||||
server: ServerContext,
|
||||
pipeline: DiffusionPipeline,
|
||||
model: str,
|
||||
scheduler_type: Any,
|
||||
device: DeviceParams,
|
||||
lpw: bool,
|
||||
):
|
||||
global last_pipeline_instance
|
||||
global last_pipeline_scheduler
|
||||
global last_pipeline_options
|
||||
pipe_key = (pipeline, model, device.device, device.provider, lpw)
|
||||
scheduler_key = (scheduler_type,)
|
||||
|
||||
options = (pipeline, model, device.device, device.provider, lpw)
|
||||
if last_pipeline_instance is not None and last_pipeline_options == options:
|
||||
cache_pipe = server.cache.get("diffusion", pipe_key)
|
||||
|
||||
if cache_pipe is not None:
|
||||
logger.debug("reusing existing diffusion pipeline")
|
||||
pipe = last_pipeline_instance
|
||||
pipe = cache_pipe
|
||||
|
||||
cache_scheduler = server.cache.get("scheduler", scheduler_key)
|
||||
if cache_scheduler is None:
|
||||
logger.debug("loading new diffusion scheduler")
|
||||
scheduler = scheduler_type.from_pretrained(
|
||||
model,
|
||||
provider=device.provider,
|
||||
provider_options=device.options,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
|
||||
if device is not None and hasattr(scheduler, "to"):
|
||||
scheduler = scheduler.to(device.torch_device())
|
||||
|
||||
pipe.scheduler = scheduler
|
||||
server.cache.set("scheduler", scheduler_key, scheduler)
|
||||
run_gc()
|
||||
|
||||
else:
|
||||
logger.debug("unloading previous diffusion pipeline")
|
||||
last_pipeline_instance = None
|
||||
last_pipeline_scheduler = None
|
||||
server.cache.drop("diffusion", pipe_key)
|
||||
run_gc()
|
||||
|
||||
if lpw:
|
||||
|
@ -135,24 +143,7 @@ def load_pipeline(
|
|||
if device is not None and hasattr(pipe, "to"):
|
||||
pipe = pipe.to(device.torch_device())
|
||||
|
||||
last_pipeline_instance = pipe
|
||||
last_pipeline_options = options
|
||||
last_pipeline_scheduler = scheduler_type
|
||||
|
||||
if last_pipeline_scheduler != scheduler_type:
|
||||
logger.debug("loading new diffusion scheduler")
|
||||
scheduler = scheduler_type.from_pretrained(
|
||||
model,
|
||||
provider=device.provider,
|
||||
provider_options=device.options,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
|
||||
if device is not None and hasattr(scheduler, "to"):
|
||||
scheduler = scheduler.to(device.torch_device())
|
||||
|
||||
pipe.scheduler = scheduler
|
||||
last_pipeline_scheduler = scheduler_type
|
||||
run_gc()
|
||||
server.cache.set("diffusion", pipe_key, pipe)
|
||||
server.cache.set("scheduler", scheduler_key, scheduler)
|
||||
|
||||
return pipe
|
||||
|
|
|
@ -10,10 +10,10 @@ from onnx_web.chain import blend_mask
|
|||
from onnx_web.chain.base import ChainProgress
|
||||
|
||||
from ..chain import upscale_outpaint
|
||||
from ..device_pool import JobContext
|
||||
from ..output import save_image, save_params
|
||||
from ..params import Border, ImageParams, Size, StageParams
|
||||
from ..upscale import UpscaleParams, run_upscale_correction
|
||||
from ..server.device_pool import JobContext
|
||||
from ..server.upscale import UpscaleParams, run_upscale_correction
|
||||
from ..utils import ServerContext, run_gc
|
||||
from .load import get_latents_from_seed, load_pipeline
|
||||
|
||||
|
@ -30,6 +30,7 @@ def run_txt2img_pipeline(
|
|||
) -> None:
|
||||
latents = get_latents_from_seed(params.seed, size)
|
||||
pipe = load_pipeline(
|
||||
server,
|
||||
OnnxStableDiffusionPipeline,
|
||||
params.model,
|
||||
params.scheduler,
|
||||
|
@ -97,6 +98,7 @@ def run_img2img_pipeline(
|
|||
strength: float,
|
||||
) -> None:
|
||||
pipe = load_pipeline(
|
||||
server,
|
||||
OnnxStableDiffusionImg2ImgPipeline,
|
||||
params.model,
|
||||
params.scheduler,
|
||||
|
|
|
@ -31,7 +31,6 @@ from .chain import (
|
|||
upscale_resrgan,
|
||||
upscale_stable_diffusion,
|
||||
)
|
||||
from .device_pool import DevicePoolExecutor
|
||||
from .diffusion.load import pipeline_schedulers
|
||||
from .diffusion.run import (
|
||||
run_blend_pipeline,
|
||||
|
@ -40,7 +39,6 @@ from .diffusion.run import (
|
|||
run_txt2img_pipeline,
|
||||
run_upscale_pipeline,
|
||||
)
|
||||
from .hacks import apply_patches
|
||||
from .image import ( # mask filters; noise sources
|
||||
mask_filter_gaussian_multiply,
|
||||
mask_filter_gaussian_screen,
|
||||
|
@ -62,6 +60,8 @@ from .params import (
|
|||
TileOrder,
|
||||
UpscaleParams,
|
||||
)
|
||||
from .server.device_pool import DevicePoolExecutor
|
||||
from .server.hacks import apply_patches
|
||||
from .utils import (
|
||||
ServerContext,
|
||||
base_join,
|
||||
|
|
|
@ -5,8 +5,8 @@ from multiprocessing import Value
|
|||
from traceback import format_exception
|
||||
from typing import Any, Callable, List, Optional, Tuple, Union
|
||||
|
||||
from .params import DeviceParams
|
||||
from .utils import run_gc
|
||||
from ..params import DeviceParams
|
||||
from ..utils import run_gc
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
|
@ -7,7 +7,7 @@ from urllib.parse import urlparse
|
|||
import basicsr.utils.download_util
|
||||
import codeformer.facelib.utils.misc
|
||||
|
||||
from .utils import ServerContext
|
||||
from ..utils import ServerContext
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
|
@ -0,0 +1,44 @@
|
|||
from logging import getLogger
|
||||
from typing import Any, List
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
class ModelCache:
|
||||
cache: List[(str, Any, Any)]
|
||||
limit: int
|
||||
|
||||
def __init__(self, limit: int) -> None:
|
||||
self.limit = limit
|
||||
|
||||
def drop(self, tag: str, key: Any) -> None:
|
||||
self.cache = [model for model in self.cache if model[0] != tag and model[1] != key]
|
||||
|
||||
|
||||
def get(self, tag: str, key: Any) -> Any:
|
||||
for t, k, v in self.cache:
|
||||
if tag == t and key == k:
|
||||
return v
|
||||
|
||||
return None
|
||||
|
||||
def set(self, tag: str, key: Any, value: Any) -> None:
|
||||
for i in range(len(self.cache)):
|
||||
t, k, v = self.cache[i]
|
||||
if tag == t:
|
||||
if key != k:
|
||||
logger.debug("Updating model cache: %s", tag)
|
||||
self.cache[i] = v
|
||||
return
|
||||
|
||||
logger.debug("Adding new model to cache: %s", tag)
|
||||
self.cache.append((tag, key, value))
|
||||
self.prune()
|
||||
|
||||
def prune(self):
|
||||
total = len(self.cache)
|
||||
if total > self.limit:
|
||||
logger.info("Removing models from cache, %s of %s", (total - self.limit), total)
|
||||
self.cache[:] = self.cache[: self.limit]
|
||||
else:
|
||||
logger.debug("Model cache below limit, %s of %s", total, self.limit)
|
|
@ -2,16 +2,16 @@ from logging import getLogger
|
|||
|
||||
from PIL import Image
|
||||
|
||||
from .chain import (
|
||||
from ..chain import (
|
||||
ChainPipeline,
|
||||
correct_codeformer,
|
||||
correct_gfpgan,
|
||||
upscale_resrgan,
|
||||
upscale_stable_diffusion,
|
||||
)
|
||||
from ..params import ImageParams, SizeChart, StageParams, UpscaleParams
|
||||
from ..utils import ServerContext
|
||||
from .device_pool import JobContext, ProgressCallback
|
||||
from .params import ImageParams, SizeChart, StageParams, UpscaleParams
|
||||
from .utils import ServerContext
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
|
@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional, Union
|
|||
import torch
|
||||
|
||||
from .params import SizeChart
|
||||
from .server.model_cache import ModelCache
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -23,6 +24,7 @@ class ServerContext:
|
|||
block_platforms: List[str] = [],
|
||||
default_platform: str = None,
|
||||
image_format: str = "png",
|
||||
cache: ModelCache = None,
|
||||
) -> None:
|
||||
self.bundle_path = bundle_path
|
||||
self.model_path = model_path
|
||||
|
@ -34,6 +36,7 @@ class ServerContext:
|
|||
self.block_platforms = block_platforms
|
||||
self.default_platform = default_platform
|
||||
self.image_format = image_format
|
||||
self.cache = cache or ModelCache()
|
||||
|
||||
@classmethod
|
||||
def from_environ(cls):
|
||||
|
@ -51,6 +54,7 @@ class ServerContext:
|
|||
block_platforms=environ.get("ONNX_WEB_BLOCK_PLATFORMS", "").split(","),
|
||||
default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None),
|
||||
image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"),
|
||||
cache=ModelCache(limit=int(environ.get("ONNX_WEB_CACHE_MODELS", 3))),
|
||||
)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue