1
0
Fork 0

feat(api): add model cache for diffusion models

This commit is contained in:
Sean Sube 2023-02-13 18:04:46 -06:00
parent 7fa1783be4
commit e9472bc005
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
24 changed files with 111 additions and 66 deletions

View File

@ -19,7 +19,7 @@ from .image import (
noise_source_uniform, noise_source_uniform,
) )
from .params import Border, ImageParams, Param, Point, Size, StageParams, UpscaleParams from .params import Border, ImageParams, Param, Point, Size, StageParams, UpscaleParams
from .upscale import run_upscale_correction from .server.upscale import run_upscale_correction
from .utils import ( from .utils import (
ServerContext, ServerContext,
base_join, base_join,

View File

@ -5,9 +5,9 @@ from typing import Any, List, Optional, Protocol, Tuple
from PIL import Image from PIL import Image
from ..device_pool import JobContext, ProgressCallback
from ..output import save_image from ..output import save_image
from ..params import ImageParams, StageParams from ..params import ImageParams, StageParams
from ..server.device_pool import JobContext, ProgressCallback
from ..utils import ServerContext, is_debug from ..utils import ServerContext, is_debug
from .utils import process_tile_order from .utils import process_tile_order

View File

@ -6,9 +6,9 @@ import torch
from diffusers import OnnxStableDiffusionImg2ImgPipeline from diffusers import OnnxStableDiffusionImg2ImgPipeline
from PIL import Image from PIL import Image
from ..device_pool import JobContext, ProgressCallback
from ..diffusion.load import load_pipeline from ..diffusion.load import load_pipeline
from ..params import ImageParams, StageParams from ..params import ImageParams, StageParams
from ..server.device_pool import JobContext, ProgressCallback
from ..utils import ServerContext from ..utils import ServerContext
logger = getLogger(__name__) logger = getLogger(__name__)
@ -16,7 +16,7 @@ logger = getLogger(__name__)
def blend_img2img( def blend_img2img(
job: JobContext, job: JobContext,
_server: ServerContext, server: ServerContext,
_stage: StageParams, _stage: StageParams,
params: ImageParams, params: ImageParams,
source_image: Image.Image, source_image: Image.Image,
@ -30,6 +30,7 @@ def blend_img2img(
logger.info("blending image using img2img, %s steps: %s", params.steps, prompt) logger.info("blending image using img2img, %s steps: %s", params.steps, prompt)
pipe = load_pipeline( pipe = load_pipeline(
server,
OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionImg2ImgPipeline,
params.model, params.model,
params.scheduler, params.scheduler,

View File

@ -6,11 +6,11 @@ import torch
from diffusers import OnnxStableDiffusionInpaintPipeline from diffusers import OnnxStableDiffusionInpaintPipeline
from PIL import Image from PIL import Image
from ..device_pool import JobContext, ProgressCallback
from ..diffusion.load import get_latents_from_seed, load_pipeline from ..diffusion.load import get_latents_from_seed, load_pipeline
from ..image import expand_image, mask_filter_none, noise_source_histogram from ..image import expand_image, mask_filter_none, noise_source_histogram
from ..output import save_image from ..output import save_image
from ..params import Border, ImageParams, Size, SizeChart, StageParams from ..params import Border, ImageParams, Size, SizeChart, StageParams
from ..server.device_pool import JobContext, ProgressCallback
from ..utils import ServerContext, is_debug from ..utils import ServerContext, is_debug
from .utils import process_tile_order from .utils import process_tile_order
@ -65,6 +65,7 @@ def blend_inpaint(
latents = get_latents_from_seed(params.seed, size) latents = get_latents_from_seed(params.seed, size)
pipe = load_pipeline( pipe = load_pipeline(
server,
OnnxStableDiffusionInpaintPipeline, OnnxStableDiffusionInpaintPipeline,
params.model, params.model,
params.scheduler, params.scheduler,

View File

@ -5,8 +5,8 @@ from PIL import Image
from onnx_web.output import save_image from onnx_web.output import save_image
from ..device_pool import JobContext, ProgressCallback
from ..params import ImageParams, StageParams from ..params import ImageParams, StageParams
from ..server.device_pool import JobContext, ProgressCallback
from ..utils import ServerContext, is_debug from ..utils import ServerContext, is_debug
logger = getLogger(__name__) logger = getLogger(__name__)

View File

@ -2,8 +2,8 @@ from logging import getLogger
from PIL import Image from PIL import Image
from ..device_pool import JobContext
from ..params import ImageParams, StageParams, UpscaleParams from ..params import ImageParams, StageParams, UpscaleParams
from ..server.device_pool import JobContext
from ..utils import ServerContext from ..utils import ServerContext
logger = getLogger(__name__) logger = getLogger(__name__)

View File

@ -6,8 +6,8 @@ import numpy as np
from gfpgan import GFPGANer from gfpgan import GFPGANer
from PIL import Image from PIL import Image
from ..device_pool import JobContext
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..server.device_pool import JobContext
from ..utils import ServerContext, run_gc from ..utils import ServerContext, run_gc
logger = getLogger(__name__) logger = getLogger(__name__)

View File

@ -2,9 +2,9 @@ from logging import getLogger
from PIL import Image from PIL import Image
from ..device_pool import JobContext
from ..output import save_image from ..output import save_image
from ..params import ImageParams, StageParams from ..params import ImageParams, StageParams
from ..server.device_pool import JobContext
from ..utils import ServerContext from ..utils import ServerContext
logger = getLogger(__name__) logger = getLogger(__name__)

View File

@ -4,8 +4,8 @@ from logging import getLogger
from boto3 import Session from boto3 import Session
from PIL import Image from PIL import Image
from ..device_pool import JobContext
from ..params import ImageParams, StageParams from ..params import ImageParams, StageParams
from ..server.device_pool import JobContext
from ..utils import ServerContext from ..utils import ServerContext
logger = getLogger(__name__) logger = getLogger(__name__)

View File

@ -2,8 +2,8 @@ from logging import getLogger
from PIL import Image from PIL import Image
from ..device_pool import JobContext
from ..params import ImageParams, Size, StageParams from ..params import ImageParams, Size, StageParams
from ..server.device_pool import JobContext
from ..utils import ServerContext from ..utils import ServerContext
logger = getLogger(__name__) logger = getLogger(__name__)

View File

@ -2,8 +2,8 @@ from logging import getLogger
from PIL import Image from PIL import Image
from ..device_pool import JobContext
from ..params import ImageParams, Size, StageParams from ..params import ImageParams, Size, StageParams
from ..server.device_pool import JobContext
from ..utils import ServerContext from ..utils import ServerContext
logger = getLogger(__name__) logger = getLogger(__name__)

View File

@ -3,8 +3,8 @@ from typing import Callable
from PIL import Image from PIL import Image
from ..device_pool import JobContext
from ..params import ImageParams, Size, StageParams from ..params import ImageParams, Size, StageParams
from ..server.device_pool import JobContext
from ..utils import ServerContext from ..utils import ServerContext
logger = getLogger(__name__) logger = getLogger(__name__)

View File

@ -5,9 +5,9 @@ import torch
from diffusers import OnnxStableDiffusionPipeline from diffusers import OnnxStableDiffusionPipeline
from PIL import Image from PIL import Image
from ..device_pool import JobContext, ProgressCallback
from ..diffusion.load import get_latents_from_seed, load_pipeline from ..diffusion.load import get_latents_from_seed, load_pipeline
from ..params import ImageParams, Size, StageParams from ..params import ImageParams, Size, StageParams
from ..server.device_pool import JobContext, ProgressCallback
from ..utils import ServerContext from ..utils import ServerContext
logger = getLogger(__name__) logger = getLogger(__name__)
@ -16,7 +16,7 @@ logger = getLogger(__name__)
def source_txt2img( def source_txt2img(
job: JobContext, job: JobContext,
server: ServerContext, server: ServerContext,
stage: StageParams, _stage: StageParams,
params: ImageParams, params: ImageParams,
source_image: Image.Image, source_image: Image.Image,
*, *,
@ -35,6 +35,7 @@ def source_txt2img(
latents = get_latents_from_seed(params.seed, size) latents = get_latents_from_seed(params.seed, size)
pipe = load_pipeline( pipe = load_pipeline(
server,
OnnxStableDiffusionPipeline, OnnxStableDiffusionPipeline,
params.model, params.model,
params.scheduler, params.scheduler,

View File

@ -6,11 +6,11 @@ import torch
from diffusers import OnnxStableDiffusionInpaintPipeline from diffusers import OnnxStableDiffusionInpaintPipeline
from PIL import Image, ImageDraw from PIL import Image, ImageDraw
from ..device_pool import JobContext, ProgressCallback
from ..diffusion.load import get_latents_from_seed, get_tile_latents, load_pipeline from ..diffusion.load import get_latents_from_seed, get_tile_latents, load_pipeline
from ..image import expand_image, mask_filter_none, noise_source_histogram from ..image import expand_image, mask_filter_none, noise_source_histogram
from ..output import save_image from ..output import save_image
from ..params import Border, ImageParams, Size, SizeChart, StageParams from ..params import Border, ImageParams, Size, SizeChart, StageParams
from ..server.device_pool import JobContext, ProgressCallback
from ..utils import ServerContext, is_debug from ..utils import ServerContext, is_debug
from .utils import process_tile_grid, process_tile_order from .utils import process_tile_grid, process_tile_order
@ -73,6 +73,7 @@ def upscale_outpaint(
latents = get_tile_latents(full_latents, dims) latents = get_tile_latents(full_latents, dims)
pipe = load_pipeline( pipe = load_pipeline(
server,
OnnxStableDiffusionInpaintPipeline, OnnxStableDiffusionInpaintPipeline,
params.model, params.model,
params.scheduler, params.scheduler,

View File

@ -7,9 +7,9 @@ from PIL import Image
from realesrgan import RealESRGANer from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact from realesrgan.archs.srvgg_arch import SRVGGNetCompact
from ..device_pool import JobContext
from ..onnx import OnnxNet from ..onnx import OnnxNet
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..server.device_pool import JobContext
from ..utils import ServerContext, run_gc from ..utils import ServerContext, run_gc
logger = getLogger(__name__) logger = getLogger(__name__)

View File

@ -5,11 +5,11 @@ import torch
from diffusers import StableDiffusionUpscalePipeline from diffusers import StableDiffusionUpscalePipeline
from PIL import Image from PIL import Image
from ..device_pool import JobContext, ProgressCallback
from ..diffusion.pipeline_onnx_stable_diffusion_upscale import ( from ..diffusion.pipeline_onnx_stable_diffusion_upscale import (
OnnxStableDiffusionUpscalePipeline, OnnxStableDiffusionUpscalePipeline,
) )
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..server.device_pool import JobContext, ProgressCallback
from ..utils import ServerContext, run_gc from ..utils import ServerContext, run_gc
logger = getLogger(__name__) logger = getLogger(__name__)

View File

@ -19,20 +19,10 @@ from diffusers import (
) )
from ..params import DeviceParams, Size from ..params import DeviceParams, Size
from ..utils import run_gc from ..utils import ServerContext, run_gc
logger = getLogger(__name__) logger = getLogger(__name__)
last_pipeline_instance: Any = None
last_pipeline_options: Tuple[
Optional[DiffusionPipeline],
Optional[str],
Optional[str],
Optional[str],
Optional[bool],
] = (None, None, None, None, None)
last_pipeline_scheduler: Any = None
latent_channels = 4 latent_channels = 4
latent_factor = 8 latent_factor = 8
@ -90,24 +80,42 @@ def get_tile_latents(
def load_pipeline( def load_pipeline(
server: ServerContext,
pipeline: DiffusionPipeline, pipeline: DiffusionPipeline,
model: str, model: str,
scheduler_type: Any, scheduler_type: Any,
device: DeviceParams, device: DeviceParams,
lpw: bool, lpw: bool,
): ):
global last_pipeline_instance pipe_key = (pipeline, model, device.device, device.provider, lpw)
global last_pipeline_scheduler scheduler_key = (scheduler_type,)
global last_pipeline_options
options = (pipeline, model, device.device, device.provider, lpw) cache_pipe = server.cache.get("diffusion", pipe_key)
if last_pipeline_instance is not None and last_pipeline_options == options:
if cache_pipe is not None:
logger.debug("reusing existing diffusion pipeline") logger.debug("reusing existing diffusion pipeline")
pipe = last_pipeline_instance pipe = cache_pipe
cache_scheduler = server.cache.get("scheduler", scheduler_key)
if cache_scheduler is None:
logger.debug("loading new diffusion scheduler")
scheduler = scheduler_type.from_pretrained(
model,
provider=device.provider,
provider_options=device.options,
subfolder="scheduler",
)
if device is not None and hasattr(scheduler, "to"):
scheduler = scheduler.to(device.torch_device())
pipe.scheduler = scheduler
server.cache.set("scheduler", scheduler_key, scheduler)
run_gc()
else: else:
logger.debug("unloading previous diffusion pipeline") logger.debug("unloading previous diffusion pipeline")
last_pipeline_instance = None server.cache.drop("diffusion", pipe_key)
last_pipeline_scheduler = None
run_gc() run_gc()
if lpw: if lpw:
@ -135,24 +143,7 @@ def load_pipeline(
if device is not None and hasattr(pipe, "to"): if device is not None and hasattr(pipe, "to"):
pipe = pipe.to(device.torch_device()) pipe = pipe.to(device.torch_device())
last_pipeline_instance = pipe server.cache.set("diffusion", pipe_key, pipe)
last_pipeline_options = options server.cache.set("scheduler", scheduler_key, scheduler)
last_pipeline_scheduler = scheduler_type
if last_pipeline_scheduler != scheduler_type:
logger.debug("loading new diffusion scheduler")
scheduler = scheduler_type.from_pretrained(
model,
provider=device.provider,
provider_options=device.options,
subfolder="scheduler",
)
if device is not None and hasattr(scheduler, "to"):
scheduler = scheduler.to(device.torch_device())
pipe.scheduler = scheduler
last_pipeline_scheduler = scheduler_type
run_gc()
return pipe return pipe

View File

@ -10,10 +10,10 @@ from onnx_web.chain import blend_mask
from onnx_web.chain.base import ChainProgress from onnx_web.chain.base import ChainProgress
from ..chain import upscale_outpaint from ..chain import upscale_outpaint
from ..device_pool import JobContext
from ..output import save_image, save_params from ..output import save_image, save_params
from ..params import Border, ImageParams, Size, StageParams from ..params import Border, ImageParams, Size, StageParams
from ..upscale import UpscaleParams, run_upscale_correction from ..server.device_pool import JobContext
from ..server.upscale import UpscaleParams, run_upscale_correction
from ..utils import ServerContext, run_gc from ..utils import ServerContext, run_gc
from .load import get_latents_from_seed, load_pipeline from .load import get_latents_from_seed, load_pipeline
@ -30,6 +30,7 @@ def run_txt2img_pipeline(
) -> None: ) -> None:
latents = get_latents_from_seed(params.seed, size) latents = get_latents_from_seed(params.seed, size)
pipe = load_pipeline( pipe = load_pipeline(
server,
OnnxStableDiffusionPipeline, OnnxStableDiffusionPipeline,
params.model, params.model,
params.scheduler, params.scheduler,
@ -97,6 +98,7 @@ def run_img2img_pipeline(
strength: float, strength: float,
) -> None: ) -> None:
pipe = load_pipeline( pipe = load_pipeline(
server,
OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionImg2ImgPipeline,
params.model, params.model,
params.scheduler, params.scheduler,

View File

@ -31,7 +31,6 @@ from .chain import (
upscale_resrgan, upscale_resrgan,
upscale_stable_diffusion, upscale_stable_diffusion,
) )
from .device_pool import DevicePoolExecutor
from .diffusion.load import pipeline_schedulers from .diffusion.load import pipeline_schedulers
from .diffusion.run import ( from .diffusion.run import (
run_blend_pipeline, run_blend_pipeline,
@ -40,7 +39,6 @@ from .diffusion.run import (
run_txt2img_pipeline, run_txt2img_pipeline,
run_upscale_pipeline, run_upscale_pipeline,
) )
from .hacks import apply_patches
from .image import ( # mask filters; noise sources from .image import ( # mask filters; noise sources
mask_filter_gaussian_multiply, mask_filter_gaussian_multiply,
mask_filter_gaussian_screen, mask_filter_gaussian_screen,
@ -62,6 +60,8 @@ from .params import (
TileOrder, TileOrder,
UpscaleParams, UpscaleParams,
) )
from .server.device_pool import DevicePoolExecutor
from .server.hacks import apply_patches
from .utils import ( from .utils import (
ServerContext, ServerContext,
base_join, base_join,

View File

@ -5,8 +5,8 @@ from multiprocessing import Value
from traceback import format_exception from traceback import format_exception
from typing import Any, Callable, List, Optional, Tuple, Union from typing import Any, Callable, List, Optional, Tuple, Union
from .params import DeviceParams from ..params import DeviceParams
from .utils import run_gc from ..utils import run_gc
logger = getLogger(__name__) logger = getLogger(__name__)

View File

@ -7,7 +7,7 @@ from urllib.parse import urlparse
import basicsr.utils.download_util import basicsr.utils.download_util
import codeformer.facelib.utils.misc import codeformer.facelib.utils.misc
from .utils import ServerContext from ..utils import ServerContext
logger = getLogger(__name__) logger = getLogger(__name__)

View File

@ -0,0 +1,44 @@
from logging import getLogger
from typing import Any, List
logger = getLogger(__name__)
class ModelCache:
cache: List[(str, Any, Any)]
limit: int
def __init__(self, limit: int) -> None:
self.limit = limit
def drop(self, tag: str, key: Any) -> None:
self.cache = [model for model in self.cache if model[0] != tag and model[1] != key]
def get(self, tag: str, key: Any) -> Any:
for t, k, v in self.cache:
if tag == t and key == k:
return v
return None
def set(self, tag: str, key: Any, value: Any) -> None:
for i in range(len(self.cache)):
t, k, v = self.cache[i]
if tag == t:
if key != k:
logger.debug("Updating model cache: %s", tag)
self.cache[i] = v
return
logger.debug("Adding new model to cache: %s", tag)
self.cache.append((tag, key, value))
self.prune()
def prune(self):
total = len(self.cache)
if total > self.limit:
logger.info("Removing models from cache, %s of %s", (total - self.limit), total)
self.cache[:] = self.cache[: self.limit]
else:
logger.debug("Model cache below limit, %s of %s", total, self.limit)

View File

@ -2,16 +2,16 @@ from logging import getLogger
from PIL import Image from PIL import Image
from .chain import ( from ..chain import (
ChainPipeline, ChainPipeline,
correct_codeformer, correct_codeformer,
correct_gfpgan, correct_gfpgan,
upscale_resrgan, upscale_resrgan,
upscale_stable_diffusion, upscale_stable_diffusion,
) )
from ..params import ImageParams, SizeChart, StageParams, UpscaleParams
from ..utils import ServerContext
from .device_pool import JobContext, ProgressCallback from .device_pool import JobContext, ProgressCallback
from .params import ImageParams, SizeChart, StageParams, UpscaleParams
from .utils import ServerContext
logger = getLogger(__name__) logger = getLogger(__name__)

View File

@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional, Union
import torch import torch
from .params import SizeChart from .params import SizeChart
from .server.model_cache import ModelCache
logger = getLogger(__name__) logger = getLogger(__name__)
@ -23,6 +24,7 @@ class ServerContext:
block_platforms: List[str] = [], block_platforms: List[str] = [],
default_platform: str = None, default_platform: str = None,
image_format: str = "png", image_format: str = "png",
cache: ModelCache = None,
) -> None: ) -> None:
self.bundle_path = bundle_path self.bundle_path = bundle_path
self.model_path = model_path self.model_path = model_path
@ -34,6 +36,7 @@ class ServerContext:
self.block_platforms = block_platforms self.block_platforms = block_platforms
self.default_platform = default_platform self.default_platform = default_platform
self.image_format = image_format self.image_format = image_format
self.cache = cache or ModelCache()
@classmethod @classmethod
def from_environ(cls): def from_environ(cls):
@ -51,6 +54,7 @@ class ServerContext:
block_platforms=environ.get("ONNX_WEB_BLOCK_PLATFORMS", "").split(","), block_platforms=environ.get("ONNX_WEB_BLOCK_PLATFORMS", "").split(","),
default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None), default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None),
image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"), image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"),
cache=ModelCache(limit=int(environ.get("ONNX_WEB_CACHE_MODELS", 3))),
) )