
feat(api): add model cache for diffusion models

Sean Sube 2023-02-13 18:04:46 -06:00
parent 7fa1783be4
commit e9472bc005
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
24 changed files with 111 additions and 66 deletions

View File

@@ -19,7 +19,7 @@ from .image import (
     noise_source_uniform,
 )
 from .params import Border, ImageParams, Param, Point, Size, StageParams, UpscaleParams
-from .upscale import run_upscale_correction
+from .server.upscale import run_upscale_correction
 from .utils import (
     ServerContext,
     base_join,

View File

@@ -5,9 +5,9 @@ from typing import Any, List, Optional, Protocol, Tuple
 from PIL import Image

-from ..device_pool import JobContext, ProgressCallback
 from ..output import save_image
 from ..params import ImageParams, StageParams
+from ..server.device_pool import JobContext, ProgressCallback
 from ..utils import ServerContext, is_debug
 from .utils import process_tile_order

View File

@@ -6,9 +6,9 @@ import torch
 from diffusers import OnnxStableDiffusionImg2ImgPipeline
 from PIL import Image

-from ..device_pool import JobContext, ProgressCallback
 from ..diffusion.load import load_pipeline
 from ..params import ImageParams, StageParams
+from ..server.device_pool import JobContext, ProgressCallback
 from ..utils import ServerContext

 logger = getLogger(__name__)
@@ -16,7 +16,7 @@ logger = getLogger(__name__)

 def blend_img2img(
     job: JobContext,
-    _server: ServerContext,
+    server: ServerContext,
     _stage: StageParams,
     params: ImageParams,
     source_image: Image.Image,
@@ -30,6 +30,7 @@ def blend_img2img(
     logger.info("blending image using img2img, %s steps: %s", params.steps, prompt)

     pipe = load_pipeline(
+        server,
         OnnxStableDiffusionImg2ImgPipeline,
         params.model,
         params.scheduler,

View File

@@ -6,11 +6,11 @@ import torch
 from diffusers import OnnxStableDiffusionInpaintPipeline
 from PIL import Image

-from ..device_pool import JobContext, ProgressCallback
 from ..diffusion.load import get_latents_from_seed, load_pipeline
 from ..image import expand_image, mask_filter_none, noise_source_histogram
 from ..output import save_image
 from ..params import Border, ImageParams, Size, SizeChart, StageParams
+from ..server.device_pool import JobContext, ProgressCallback
 from ..utils import ServerContext, is_debug
 from .utils import process_tile_order
@@ -65,6 +65,7 @@ def blend_inpaint(
         latents = get_latents_from_seed(params.seed, size)

         pipe = load_pipeline(
+            server,
             OnnxStableDiffusionInpaintPipeline,
             params.model,
             params.scheduler,

View File

@@ -5,8 +5,8 @@ from PIL import Image
 from onnx_web.output import save_image

-from ..device_pool import JobContext, ProgressCallback
 from ..params import ImageParams, StageParams
+from ..server.device_pool import JobContext, ProgressCallback
 from ..utils import ServerContext, is_debug

 logger = getLogger(__name__)

View File

@@ -2,8 +2,8 @@ from logging import getLogger
 from PIL import Image

-from ..device_pool import JobContext
 from ..params import ImageParams, StageParams, UpscaleParams
+from ..server.device_pool import JobContext
 from ..utils import ServerContext

 logger = getLogger(__name__)

View File

@@ -6,8 +6,8 @@ import numpy as np
 from gfpgan import GFPGANer
 from PIL import Image

-from ..device_pool import JobContext
 from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
+from ..server.device_pool import JobContext
 from ..utils import ServerContext, run_gc

 logger = getLogger(__name__)

View File

@@ -2,9 +2,9 @@ from logging import getLogger
 from PIL import Image

-from ..device_pool import JobContext
 from ..output import save_image
 from ..params import ImageParams, StageParams
+from ..server.device_pool import JobContext
 from ..utils import ServerContext

 logger = getLogger(__name__)

View File

@@ -4,8 +4,8 @@ from logging import getLogger
 from boto3 import Session
 from PIL import Image

-from ..device_pool import JobContext
 from ..params import ImageParams, StageParams
+from ..server.device_pool import JobContext
 from ..utils import ServerContext

 logger = getLogger(__name__)

View File

@@ -2,8 +2,8 @@ from logging import getLogger
 from PIL import Image

-from ..device_pool import JobContext
 from ..params import ImageParams, Size, StageParams
+from ..server.device_pool import JobContext
 from ..utils import ServerContext

 logger = getLogger(__name__)

View File

@@ -2,8 +2,8 @@ from logging import getLogger
 from PIL import Image

-from ..device_pool import JobContext
 from ..params import ImageParams, Size, StageParams
+from ..server.device_pool import JobContext
 from ..utils import ServerContext

 logger = getLogger(__name__)

View File

@@ -3,8 +3,8 @@ from typing import Callable
 from PIL import Image

-from ..device_pool import JobContext
 from ..params import ImageParams, Size, StageParams
+from ..server.device_pool import JobContext
 from ..utils import ServerContext

 logger = getLogger(__name__)

View File

@@ -5,9 +5,9 @@ import torch
 from diffusers import OnnxStableDiffusionPipeline
 from PIL import Image

-from ..device_pool import JobContext, ProgressCallback
 from ..diffusion.load import get_latents_from_seed, load_pipeline
 from ..params import ImageParams, Size, StageParams
+from ..server.device_pool import JobContext, ProgressCallback
 from ..utils import ServerContext

 logger = getLogger(__name__)
@@ -16,7 +16,7 @@ logger = getLogger(__name__)
 def source_txt2img(
     job: JobContext,
     server: ServerContext,
-    stage: StageParams,
+    _stage: StageParams,
     params: ImageParams,
     source_image: Image.Image,
     *,
@@ -35,6 +35,7 @@ def source_txt2img(
     latents = get_latents_from_seed(params.seed, size)

     pipe = load_pipeline(
+        server,
         OnnxStableDiffusionPipeline,
         params.model,
         params.scheduler,

View File

@@ -6,11 +6,11 @@ import torch
 from diffusers import OnnxStableDiffusionInpaintPipeline
 from PIL import Image, ImageDraw

-from ..device_pool import JobContext, ProgressCallback
 from ..diffusion.load import get_latents_from_seed, get_tile_latents, load_pipeline
 from ..image import expand_image, mask_filter_none, noise_source_histogram
 from ..output import save_image
 from ..params import Border, ImageParams, Size, SizeChart, StageParams
+from ..server.device_pool import JobContext, ProgressCallback
 from ..utils import ServerContext, is_debug
 from .utils import process_tile_grid, process_tile_order
@@ -73,6 +73,7 @@ def upscale_outpaint(
         latents = get_tile_latents(full_latents, dims)

         pipe = load_pipeline(
+            server,
             OnnxStableDiffusionInpaintPipeline,
             params.model,
             params.scheduler,

View File

@@ -7,9 +7,9 @@ from PIL import Image
 from realesrgan import RealESRGANer
 from realesrgan.archs.srvgg_arch import SRVGGNetCompact

-from ..device_pool import JobContext
 from ..onnx import OnnxNet
 from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
+from ..server.device_pool import JobContext
 from ..utils import ServerContext, run_gc

 logger = getLogger(__name__)

View File

@@ -5,11 +5,11 @@ import torch
 from diffusers import StableDiffusionUpscalePipeline
 from PIL import Image

-from ..device_pool import JobContext, ProgressCallback
 from ..diffusion.pipeline_onnx_stable_diffusion_upscale import (
     OnnxStableDiffusionUpscalePipeline,
 )
 from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
+from ..server.device_pool import JobContext, ProgressCallback
 from ..utils import ServerContext, run_gc

 logger = getLogger(__name__)

View File

@@ -19,20 +19,10 @@ from diffusers import (
 )

 from ..params import DeviceParams, Size
-from ..utils import run_gc
+from ..utils import ServerContext, run_gc

 logger = getLogger(__name__)

-last_pipeline_instance: Any = None
-last_pipeline_options: Tuple[
-    Optional[DiffusionPipeline],
-    Optional[str],
-    Optional[str],
-    Optional[str],
-    Optional[bool],
-] = (None, None, None, None, None)
-last_pipeline_scheduler: Any = None
-
 latent_channels = 4
 latent_factor = 8
@@ -90,24 +80,42 @@ def get_tile_latents(
 def load_pipeline(
+    server: ServerContext,
     pipeline: DiffusionPipeline,
     model: str,
     scheduler_type: Any,
     device: DeviceParams,
     lpw: bool,
 ):
-    global last_pipeline_instance
-    global last_pipeline_scheduler
-    global last_pipeline_options
+    pipe_key = (pipeline, model, device.device, device.provider, lpw)
+    scheduler_key = (scheduler_type,)

-    options = (pipeline, model, device.device, device.provider, lpw)
-    if last_pipeline_instance is not None and last_pipeline_options == options:
+    cache_pipe = server.cache.get("diffusion", pipe_key)
+
+    if cache_pipe is not None:
         logger.debug("reusing existing diffusion pipeline")
-        pipe = last_pipeline_instance
+        pipe = cache_pipe
+
+        cache_scheduler = server.cache.get("scheduler", scheduler_key)
+        if cache_scheduler is None:
+            logger.debug("loading new diffusion scheduler")
+            scheduler = scheduler_type.from_pretrained(
+                model,
+                provider=device.provider,
+                provider_options=device.options,
+                subfolder="scheduler",
+            )
+
+            if device is not None and hasattr(scheduler, "to"):
+                scheduler = scheduler.to(device.torch_device())
+
+            pipe.scheduler = scheduler
+            server.cache.set("scheduler", scheduler_key, scheduler)
+            run_gc()
     else:
         logger.debug("unloading previous diffusion pipeline")
-        last_pipeline_instance = None
-        last_pipeline_scheduler = None
+        server.cache.drop("diffusion", pipe_key)
         run_gc()

     if lpw:
@@ -135,24 +143,7 @@ def load_pipeline(
         if device is not None and hasattr(pipe, "to"):
             pipe = pipe.to(device.torch_device())

-        last_pipeline_instance = pipe
-        last_pipeline_options = options
-        last_pipeline_scheduler = scheduler_type
-
-    if last_pipeline_scheduler != scheduler_type:
-        logger.debug("loading new diffusion scheduler")
-        scheduler = scheduler_type.from_pretrained(
-            model,
-            provider=device.provider,
-            provider_options=device.options,
-            subfolder="scheduler",
-        )
-
-        if device is not None and hasattr(scheduler, "to"):
-            scheduler = scheduler.to(device.torch_device())
-
-        pipe.scheduler = scheduler
-        last_pipeline_scheduler = scheduler_type
-        run_gc()
-
+        server.cache.set("diffusion", pipe_key, pipe)
+        server.cache.set("scheduler", scheduler_key, scheduler)

     return pipe
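
The hunks above replace the module-level last_pipeline_* globals with keyed cache lookups. A condensed sketch of the new flow, assuming a ServerContext wired with the ModelCache added in this commit; make_pipeline is a hypothetical stand-in for the from_pretrained setup elided from the hunk:

def load_pipeline_sketch(server, pipeline, model, device, lpw):
    # the options that used to be compared against the globals now form the cache key
    pipe_key = (pipeline, model, device.device, device.provider, lpw)

    pipe = server.cache.get("diffusion", pipe_key)
    if pipe is not None:
        return pipe  # cache hit: skip the expensive ONNX load

    pipe = make_pipeline(pipeline, model, device, lpw)  # hypothetical loader
    server.cache.set("diffusion", pipe_key, pipe)
    return pipe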

View File

@@ -10,10 +10,10 @@ from onnx_web.chain import blend_mask
 from onnx_web.chain.base import ChainProgress

 from ..chain import upscale_outpaint
-from ..device_pool import JobContext
 from ..output import save_image, save_params
 from ..params import Border, ImageParams, Size, StageParams
-from ..upscale import UpscaleParams, run_upscale_correction
+from ..server.device_pool import JobContext
+from ..server.upscale import UpscaleParams, run_upscale_correction
 from ..utils import ServerContext, run_gc
 from .load import get_latents_from_seed, load_pipeline
@@ -30,6 +30,7 @@ def run_txt2img_pipeline(
 ) -> None:
     latents = get_latents_from_seed(params.seed, size)
     pipe = load_pipeline(
+        server,
         OnnxStableDiffusionPipeline,
         params.model,
         params.scheduler,
@@ -97,6 +98,7 @@ def run_img2img_pipeline(
     strength: float,
 ) -> None:
     pipe = load_pipeline(
+        server,
         OnnxStableDiffusionImg2ImgPipeline,
         params.model,
         params.scheduler,

View File

@@ -31,7 +31,6 @@ from .chain import (
     upscale_resrgan,
     upscale_stable_diffusion,
 )
-from .device_pool import DevicePoolExecutor
 from .diffusion.load import pipeline_schedulers
 from .diffusion.run import (
     run_blend_pipeline,
@@ -40,7 +39,6 @@ from .diffusion.run import (
     run_txt2img_pipeline,
     run_upscale_pipeline,
 )
-from .hacks import apply_patches
 from .image import (  # mask filters; noise sources
     mask_filter_gaussian_multiply,
     mask_filter_gaussian_screen,
@@ -62,6 +60,8 @@ from .params import (
     TileOrder,
     UpscaleParams,
 )
+from .server.device_pool import DevicePoolExecutor
+from .server.hacks import apply_patches
 from .utils import (
     ServerContext,
     base_join,

View File

@@ -5,8 +5,8 @@ from multiprocessing import Value
 from traceback import format_exception
 from typing import Any, Callable, List, Optional, Tuple, Union

-from .params import DeviceParams
-from .utils import run_gc
+from ..params import DeviceParams
+from ..utils import run_gc

 logger = getLogger(__name__)

View File

@@ -7,7 +7,7 @@ from urllib.parse import urlparse
 import basicsr.utils.download_util
 import codeformer.facelib.utils.misc

-from .utils import ServerContext
+from ..utils import ServerContext

 logger = getLogger(__name__)

View File

@@ -0,0 +1,48 @@
+from logging import getLogger
+from typing import Any, List, Tuple
+
+logger = getLogger(__name__)
+
+
+class ModelCache:
+    cache: List[Tuple[str, Any, Any]]
+    limit: int
+
+    def __init__(self, limit: int) -> None:
+        self.cache = []
+        self.limit = limit
+
+    def drop(self, tag: str, key: Any) -> None:
+        # drop only the entry that matches both the tag and the key
+        self.cache[:] = [
+            model for model in self.cache if model[0] != tag or model[1] != key
+        ]
+
+    def get(self, tag: str, key: Any) -> Any:
+        for t, k, v in self.cache:
+            if tag == t and key == k:
+                return v
+
+        return None
+
+    def set(self, tag: str, key: Any, value: Any) -> None:
+        for i in range(len(self.cache)):
+            t, k, v = self.cache[i]
+            if tag == t:
+                if key != k:
+                    logger.debug("Updating model cache: %s", tag)
+                    self.cache[i] = (tag, key, value)
+
+                return
+
+        logger.debug("Adding new model to cache: %s", tag)
+        self.cache.append((tag, key, value))
+        self.prune()
+
+    def prune(self):
+        total = len(self.cache)
+        if total > self.limit:
+            logger.info("Removing models from cache, %s of %s", (total - self.limit), total)
+            self.cache[:] = self.cache[: self.limit]
+        else:
+            logger.debug("Model cache below limit, %s of %s", total, self.limit)

View File

@@ -2,16 +2,16 @@ from logging import getLogger
 from PIL import Image

-from .chain import (
+from ..chain import (
     ChainPipeline,
     correct_codeformer,
     correct_gfpgan,
     upscale_resrgan,
     upscale_stable_diffusion,
 )
+from ..params import ImageParams, SizeChart, StageParams, UpscaleParams
+from ..utils import ServerContext
 from .device_pool import JobContext, ProgressCallback
-from .params import ImageParams, SizeChart, StageParams, UpscaleParams
-from .utils import ServerContext

 logger = getLogger(__name__)

View File

@@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional, Union
 import torch

 from .params import SizeChart
+from .server.model_cache import ModelCache

 logger = getLogger(__name__)
@@ -23,6 +24,7 @@ class ServerContext:
         block_platforms: List[str] = [],
         default_platform: str = None,
         image_format: str = "png",
+        cache: ModelCache = None,
     ) -> None:
         self.bundle_path = bundle_path
         self.model_path = model_path
@@ -34,6 +36,7 @@ class ServerContext:
         self.block_platforms = block_platforms
         self.default_platform = default_platform
         self.image_format = image_format
+        self.cache = cache or ModelCache(limit=1)  # limit is required; default to one cached model

     @classmethod
     def from_environ(cls):
@@ -51,6 +54,7 @@ class ServerContext:
             block_platforms=environ.get("ONNX_WEB_BLOCK_PLATFORMS", "").split(","),
             default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None),
             image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"),
+            cache=ModelCache(limit=int(environ.get("ONNX_WEB_CACHE_MODELS", 3))),
         )
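
With this hunk, the cache size becomes a deployment setting. A brief sketch of configuring it through the environment, assuming the defaults shown in the hunk above:

from os import environ

environ["ONNX_WEB_CACHE_MODELS"] = "5"  # keep up to five models resident

context = ServerContext.from_environ()
assert context.cache.limit == 5  # left unset, the limit defaults to 3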