From e9472bc005c2bda9c652595cfc0122ac23fff36b Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Mon, 13 Feb 2023 18:04:46 -0600 Subject: [PATCH] feat(api): add model cache for diffusion models --- api/onnx_web/__init__.py | 2 +- api/onnx_web/chain/base.py | 2 +- api/onnx_web/chain/blend_img2img.py | 5 +- api/onnx_web/chain/blend_inpaint.py | 3 +- api/onnx_web/chain/blend_mask.py | 2 +- api/onnx_web/chain/correct_codeformer.py | 2 +- api/onnx_web/chain/correct_gfpgan.py | 2 +- api/onnx_web/chain/persist_disk.py | 2 +- api/onnx_web/chain/persist_s3.py | 2 +- api/onnx_web/chain/reduce_crop.py | 2 +- api/onnx_web/chain/reduce_thumbnail.py | 2 +- api/onnx_web/chain/source_noise.py | 2 +- api/onnx_web/chain/source_txt2img.py | 5 +- api/onnx_web/chain/upscale_outpaint.py | 3 +- api/onnx_web/chain/upscale_resrgan.py | 2 +- .../chain/upscale_stable_diffusion.py | 2 +- api/onnx_web/diffusion/load.py | 67 ++++++++----------- api/onnx_web/diffusion/run.py | 6 +- api/onnx_web/serve.py | 4 +- api/onnx_web/{ => server}/device_pool.py | 4 +- api/onnx_web/{ => server}/hacks.py | 2 +- api/onnx_web/server/model_cache.py | 44 ++++++++++++ api/onnx_web/{ => server}/upscale.py | 6 +- api/onnx_web/utils.py | 4 ++ 24 files changed, 111 insertions(+), 66 deletions(-) rename api/onnx_web/{ => server}/device_pool.py (99%) rename api/onnx_web/{ => server}/hacks.py (99%) create mode 100644 api/onnx_web/server/model_cache.py rename api/onnx_web/{ => server}/upscale.py (93%) diff --git a/api/onnx_web/__init__.py b/api/onnx_web/__init__.py index f3e3aabe..4eb18225 100644 --- a/api/onnx_web/__init__.py +++ b/api/onnx_web/__init__.py @@ -19,7 +19,7 @@ from .image import ( noise_source_uniform, ) from .params import Border, ImageParams, Param, Point, Size, StageParams, UpscaleParams -from .upscale import run_upscale_correction +from .server.upscale import run_upscale_correction from .utils import ( ServerContext, base_join, diff --git a/api/onnx_web/chain/base.py b/api/onnx_web/chain/base.py index 
dd661f30..32d919ec 100644 --- a/api/onnx_web/chain/base.py +++ b/api/onnx_web/chain/base.py @@ -5,9 +5,9 @@ from typing import Any, List, Optional, Protocol, Tuple from PIL import Image -from ..device_pool import JobContext, ProgressCallback from ..output import save_image from ..params import ImageParams, StageParams +from ..server.device_pool import JobContext, ProgressCallback from ..utils import ServerContext, is_debug from .utils import process_tile_order diff --git a/api/onnx_web/chain/blend_img2img.py b/api/onnx_web/chain/blend_img2img.py index b101125c..99591b31 100644 --- a/api/onnx_web/chain/blend_img2img.py +++ b/api/onnx_web/chain/blend_img2img.py @@ -6,9 +6,9 @@ import torch from diffusers import OnnxStableDiffusionImg2ImgPipeline from PIL import Image -from ..device_pool import JobContext, ProgressCallback from ..diffusion.load import load_pipeline from ..params import ImageParams, StageParams +from ..server.device_pool import JobContext, ProgressCallback from ..utils import ServerContext logger = getLogger(__name__) @@ -16,7 +16,7 @@ logger = getLogger(__name__) def blend_img2img( job: JobContext, - _server: ServerContext, + server: ServerContext, _stage: StageParams, params: ImageParams, source_image: Image.Image, @@ -30,6 +30,7 @@ def blend_img2img( logger.info("blending image using img2img, %s steps: %s", params.steps, prompt) pipe = load_pipeline( + server, OnnxStableDiffusionImg2ImgPipeline, params.model, params.scheduler, diff --git a/api/onnx_web/chain/blend_inpaint.py b/api/onnx_web/chain/blend_inpaint.py index 114f095a..7e517e91 100644 --- a/api/onnx_web/chain/blend_inpaint.py +++ b/api/onnx_web/chain/blend_inpaint.py @@ -6,11 +6,11 @@ import torch from diffusers import OnnxStableDiffusionInpaintPipeline from PIL import Image -from ..device_pool import JobContext, ProgressCallback from ..diffusion.load import get_latents_from_seed, load_pipeline from ..image import expand_image, mask_filter_none, noise_source_histogram from ..output import 
save_image from ..params import Border, ImageParams, Size, SizeChart, StageParams +from ..server.device_pool import JobContext, ProgressCallback from ..utils import ServerContext, is_debug from .utils import process_tile_order @@ -65,6 +65,7 @@ def blend_inpaint( latents = get_latents_from_seed(params.seed, size) pipe = load_pipeline( + server, OnnxStableDiffusionInpaintPipeline, params.model, params.scheduler, diff --git a/api/onnx_web/chain/blend_mask.py b/api/onnx_web/chain/blend_mask.py index b8f0d390..6d247d1d 100644 --- a/api/onnx_web/chain/blend_mask.py +++ b/api/onnx_web/chain/blend_mask.py @@ -5,8 +5,8 @@ from PIL import Image from onnx_web.output import save_image -from ..device_pool import JobContext, ProgressCallback from ..params import ImageParams, StageParams +from ..server.device_pool import JobContext, ProgressCallback from ..utils import ServerContext, is_debug logger = getLogger(__name__) diff --git a/api/onnx_web/chain/correct_codeformer.py b/api/onnx_web/chain/correct_codeformer.py index e8d74a1a..640ad829 100644 --- a/api/onnx_web/chain/correct_codeformer.py +++ b/api/onnx_web/chain/correct_codeformer.py @@ -2,8 +2,8 @@ from logging import getLogger from PIL import Image -from ..device_pool import JobContext from ..params import ImageParams, StageParams, UpscaleParams +from ..server.device_pool import JobContext from ..utils import ServerContext logger = getLogger(__name__) diff --git a/api/onnx_web/chain/correct_gfpgan.py b/api/onnx_web/chain/correct_gfpgan.py index a02d87ae..db56f762 100644 --- a/api/onnx_web/chain/correct_gfpgan.py +++ b/api/onnx_web/chain/correct_gfpgan.py @@ -6,8 +6,8 @@ import numpy as np from gfpgan import GFPGANer from PIL import Image -from ..device_pool import JobContext from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams +from ..server.device_pool import JobContext from ..utils import ServerContext, run_gc logger = getLogger(__name__) diff --git a/api/onnx_web/chain/persist_disk.py 
b/api/onnx_web/chain/persist_disk.py index abef41fb..17d86bde 100644 --- a/api/onnx_web/chain/persist_disk.py +++ b/api/onnx_web/chain/persist_disk.py @@ -2,9 +2,9 @@ from logging import getLogger from PIL import Image -from ..device_pool import JobContext from ..output import save_image from ..params import ImageParams, StageParams +from ..server.device_pool import JobContext from ..utils import ServerContext logger = getLogger(__name__) diff --git a/api/onnx_web/chain/persist_s3.py b/api/onnx_web/chain/persist_s3.py index 31be69f3..4dd614e3 100644 --- a/api/onnx_web/chain/persist_s3.py +++ b/api/onnx_web/chain/persist_s3.py @@ -4,8 +4,8 @@ from logging import getLogger from boto3 import Session from PIL import Image -from ..device_pool import JobContext from ..params import ImageParams, StageParams +from ..server.device_pool import JobContext from ..utils import ServerContext logger = getLogger(__name__) diff --git a/api/onnx_web/chain/reduce_crop.py b/api/onnx_web/chain/reduce_crop.py index e8f70f7d..8107ac38 100644 --- a/api/onnx_web/chain/reduce_crop.py +++ b/api/onnx_web/chain/reduce_crop.py @@ -2,8 +2,8 @@ from logging import getLogger from PIL import Image -from ..device_pool import JobContext from ..params import ImageParams, Size, StageParams +from ..server.device_pool import JobContext from ..utils import ServerContext logger = getLogger(__name__) diff --git a/api/onnx_web/chain/reduce_thumbnail.py b/api/onnx_web/chain/reduce_thumbnail.py index b6beb9c6..50cccc1d 100644 --- a/api/onnx_web/chain/reduce_thumbnail.py +++ b/api/onnx_web/chain/reduce_thumbnail.py @@ -2,8 +2,8 @@ from logging import getLogger from PIL import Image -from ..device_pool import JobContext from ..params import ImageParams, Size, StageParams +from ..server.device_pool import JobContext from ..utils import ServerContext logger = getLogger(__name__) diff --git a/api/onnx_web/chain/source_noise.py b/api/onnx_web/chain/source_noise.py index 9b01ccb6..fcaf5d5b 100644 --- 
a/api/onnx_web/chain/source_noise.py +++ b/api/onnx_web/chain/source_noise.py @@ -3,8 +3,8 @@ from typing import Callable from PIL import Image -from ..device_pool import JobContext from ..params import ImageParams, Size, StageParams +from ..server.device_pool import JobContext from ..utils import ServerContext logger = getLogger(__name__) diff --git a/api/onnx_web/chain/source_txt2img.py b/api/onnx_web/chain/source_txt2img.py index af32da0a..b5d763eb 100644 --- a/api/onnx_web/chain/source_txt2img.py +++ b/api/onnx_web/chain/source_txt2img.py @@ -5,9 +5,9 @@ import torch from diffusers import OnnxStableDiffusionPipeline from PIL import Image -from ..device_pool import JobContext, ProgressCallback from ..diffusion.load import get_latents_from_seed, load_pipeline from ..params import ImageParams, Size, StageParams +from ..server.device_pool import JobContext, ProgressCallback from ..utils import ServerContext logger = getLogger(__name__) @@ -16,7 +16,7 @@ logger = getLogger(__name__) def source_txt2img( job: JobContext, server: ServerContext, - stage: StageParams, + _stage: StageParams, params: ImageParams, source_image: Image.Image, *, @@ -35,6 +35,7 @@ def source_txt2img( latents = get_latents_from_seed(params.seed, size) pipe = load_pipeline( + server, OnnxStableDiffusionPipeline, params.model, params.scheduler, diff --git a/api/onnx_web/chain/upscale_outpaint.py b/api/onnx_web/chain/upscale_outpaint.py index 7872930c..d9b73d70 100644 --- a/api/onnx_web/chain/upscale_outpaint.py +++ b/api/onnx_web/chain/upscale_outpaint.py @@ -6,11 +6,11 @@ import torch from diffusers import OnnxStableDiffusionInpaintPipeline from PIL import Image, ImageDraw -from ..device_pool import JobContext, ProgressCallback from ..diffusion.load import get_latents_from_seed, get_tile_latents, load_pipeline from ..image import expand_image, mask_filter_none, noise_source_histogram from ..output import save_image from ..params import Border, ImageParams, Size, SizeChart, StageParams +from 
..server.device_pool import JobContext, ProgressCallback from ..utils import ServerContext, is_debug from .utils import process_tile_grid, process_tile_order @@ -73,6 +73,7 @@ def upscale_outpaint( latents = get_tile_latents(full_latents, dims) pipe = load_pipeline( + server, OnnxStableDiffusionInpaintPipeline, params.model, params.scheduler, diff --git a/api/onnx_web/chain/upscale_resrgan.py b/api/onnx_web/chain/upscale_resrgan.py index f543086a..cbb90241 100644 --- a/api/onnx_web/chain/upscale_resrgan.py +++ b/api/onnx_web/chain/upscale_resrgan.py @@ -7,9 +7,9 @@ from PIL import Image from realesrgan import RealESRGANer from realesrgan.archs.srvgg_arch import SRVGGNetCompact -from ..device_pool import JobContext from ..onnx import OnnxNet from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams +from ..server.device_pool import JobContext from ..utils import ServerContext, run_gc logger = getLogger(__name__) diff --git a/api/onnx_web/chain/upscale_stable_diffusion.py b/api/onnx_web/chain/upscale_stable_diffusion.py index e947717d..85c23957 100644 --- a/api/onnx_web/chain/upscale_stable_diffusion.py +++ b/api/onnx_web/chain/upscale_stable_diffusion.py @@ -5,11 +5,11 @@ import torch from diffusers import StableDiffusionUpscalePipeline from PIL import Image -from ..device_pool import JobContext, ProgressCallback from ..diffusion.pipeline_onnx_stable_diffusion_upscale import ( OnnxStableDiffusionUpscalePipeline, ) from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams +from ..server.device_pool import JobContext, ProgressCallback from ..utils import ServerContext, run_gc logger = getLogger(__name__) diff --git a/api/onnx_web/diffusion/load.py b/api/onnx_web/diffusion/load.py index da0bc13a..6e116b55 100644 --- a/api/onnx_web/diffusion/load.py +++ b/api/onnx_web/diffusion/load.py @@ -19,20 +19,10 @@ from diffusers import ( ) from ..params import DeviceParams, Size -from ..utils import run_gc +from ..utils import ServerContext, 
run_gc logger = getLogger(__name__) -last_pipeline_instance: Any = None -last_pipeline_options: Tuple[ - Optional[DiffusionPipeline], - Optional[str], - Optional[str], - Optional[str], - Optional[bool], -] = (None, None, None, None, None) -last_pipeline_scheduler: Any = None - latent_channels = 4 latent_factor = 8 @@ -90,24 +80,42 @@ def get_tile_latents( def load_pipeline( + server: ServerContext, pipeline: DiffusionPipeline, model: str, scheduler_type: Any, device: DeviceParams, lpw: bool, ): - global last_pipeline_instance - global last_pipeline_scheduler - global last_pipeline_options + pipe_key = (pipeline, model, device.device, device.provider, lpw) + scheduler_key = (scheduler_type,) - options = (pipeline, model, device.device, device.provider, lpw) - if last_pipeline_instance is not None and last_pipeline_options == options: + cache_pipe = server.cache.get("diffusion", pipe_key) + + if cache_pipe is not None: logger.debug("reusing existing diffusion pipeline") - pipe = last_pipeline_instance + pipe = cache_pipe + + cache_scheduler = server.cache.get("scheduler", scheduler_key) + if cache_scheduler is None: + logger.debug("loading new diffusion scheduler") + scheduler = scheduler_type.from_pretrained( + model, + provider=device.provider, + provider_options=device.options, + subfolder="scheduler", + ) + + if device is not None and hasattr(scheduler, "to"): + scheduler = scheduler.to(device.torch_device()) + + pipe.scheduler = scheduler + server.cache.set("scheduler", scheduler_key, scheduler) + run_gc() + else: logger.debug("unloading previous diffusion pipeline") - last_pipeline_instance = None - last_pipeline_scheduler = None + server.cache.drop("diffusion", pipe_key) run_gc() if lpw: @@ -135,24 +143,7 @@ def load_pipeline( if device is not None and hasattr(pipe, "to"): pipe = pipe.to(device.torch_device()) - last_pipeline_instance = pipe - last_pipeline_options = options - last_pipeline_scheduler = scheduler_type - - if last_pipeline_scheduler != 
scheduler_type: - logger.debug("loading new diffusion scheduler") - scheduler = scheduler_type.from_pretrained( - model, - provider=device.provider, - provider_options=device.options, - subfolder="scheduler", - ) - - if device is not None and hasattr(scheduler, "to"): - scheduler = scheduler.to(device.torch_device()) - - pipe.scheduler = scheduler - last_pipeline_scheduler = scheduler_type - run_gc() + server.cache.set("diffusion", pipe_key, pipe) + server.cache.set("scheduler", scheduler_key, scheduler) return pipe diff --git a/api/onnx_web/diffusion/run.py b/api/onnx_web/diffusion/run.py index e86ec9c4..4785a23e 100644 --- a/api/onnx_web/diffusion/run.py +++ b/api/onnx_web/diffusion/run.py @@ -10,10 +10,10 @@ from onnx_web.chain import blend_mask from onnx_web.chain.base import ChainProgress from ..chain import upscale_outpaint -from ..device_pool import JobContext from ..output import save_image, save_params from ..params import Border, ImageParams, Size, StageParams -from ..upscale import UpscaleParams, run_upscale_correction +from ..server.device_pool import JobContext +from ..server.upscale import UpscaleParams, run_upscale_correction from ..utils import ServerContext, run_gc from .load import get_latents_from_seed, load_pipeline @@ -30,6 +30,7 @@ def run_txt2img_pipeline( ) -> None: latents = get_latents_from_seed(params.seed, size) pipe = load_pipeline( + server, OnnxStableDiffusionPipeline, params.model, params.scheduler, @@ -97,6 +98,7 @@ def run_img2img_pipeline( strength: float, ) -> None: pipe = load_pipeline( + server, OnnxStableDiffusionImg2ImgPipeline, params.model, params.scheduler, diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py index 12e9046d..cd35921d 100644 --- a/api/onnx_web/serve.py +++ b/api/onnx_web/serve.py @@ -31,7 +31,6 @@ from .chain import ( upscale_resrgan, upscale_stable_diffusion, ) -from .device_pool import DevicePoolExecutor from .diffusion.load import pipeline_schedulers from .diffusion.run import ( run_blend_pipeline, 
from logging import getLogger
from typing import Any, List, Tuple

logger = getLogger(__name__)


class ModelCache:
    """Bounded in-memory cache of loaded models.

    Entries are `(tag, key, model)` tuples stored oldest-first; `tag` names
    the model family (e.g. "diffusion", "scheduler") and `key` captures the
    load options. Pruning evicts the oldest entries once the cache exceeds
    `limit`.
    """

    # each entry is (tag, key, model) — the original annotation
    # List[(str, Any, Any)] is not valid typing syntax
    cache: List[Tuple[str, Any, Any]]
    limit: int

    def __init__(self, limit: int) -> None:
        # the original never initialized `cache`, so the first get/set/drop
        # raised AttributeError
        self.cache = []
        self.limit = limit

    def drop(self, tag: str, key: Any) -> None:
        """Remove the entry matching BOTH tag and key, if present."""
        # keep an entry unless both fields match; the original used `and`
        # here, which also evicted unrelated entries that merely shared a key
        self.cache[:] = [
            model
            for model in self.cache
            if model[0] != tag or model[1] != key
        ]

    def get(self, tag: str, key: Any) -> Any:
        """Return the cached model for (tag, key), or None on a miss."""
        for t, k, v in self.cache:
            if tag == t and key == k:
                return v

        return None

    def set(self, tag: str, key: Any, value: Any) -> None:
        """Store a model, replacing any existing entry with the same tag.

        If the tag is present with the same key the entry is left untouched
        (cache hit); with a different key it is overwritten in place.
        """
        for i in range(len(self.cache)):
            t, k, _v = self.cache[i]
            if tag == t:
                if key != k:
                    logger.debug("Updating model cache: %s", tag)
                    # store the full entry tuple; the original assigned the
                    # bare old value, corrupting the cache slot
                    self.cache[i] = (tag, key, value)

                return

        logger.debug("Adding new model to cache: %s", tag)
        self.cache.append((tag, key, value))
        self.prune()

    def prune(self) -> None:
        """Trim the cache to `limit` entries, evicting the oldest first."""
        total = len(self.cache)
        if total > self.limit:
            logger.info(
                "Removing models from cache, %s of %s", (total - self.limit), total
            )
            # evict from the front (oldest entries); the original kept the
            # oldest slice and threw away every newly added model
            del self.cache[: total - self.limit]
        else:
            logger.debug("Model cache below limit, %s of %s", total, self.limit)
List[str] = [], default_platform: str = None, image_format: str = "png", + cache: ModelCache = None, ) -> None: self.bundle_path = bundle_path self.model_path = model_path @@ -34,6 +36,7 @@ class ServerContext: self.block_platforms = block_platforms self.default_platform = default_platform self.image_format = image_format + self.cache = cache or ModelCache() @classmethod def from_environ(cls): @@ -51,6 +54,7 @@ class ServerContext: block_platforms=environ.get("ONNX_WEB_BLOCK_PLATFORMS", "").split(","), default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None), image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"), + cache=ModelCache(limit=int(environ.get("ONNX_WEB_CACHE_MODELS", 3))), )