From ab6462d095fe2305fd5a8e6ab9ac794662880a99 Mon Sep 17 00:00:00 2001
From: Sean Sube
Date: Sat, 18 Feb 2023 11:53:13 -0600
Subject: [PATCH] feat(api): enable optimizations for SD pipelines based on env vars (#155)

---
 .../chain/upscale_stable_diffusion.py |  3 ++
 api/onnx_web/diffusion/load.py        | 29 +++++++++++++++++++
 api/onnx_web/utils.py                 |  3 ++
 3 files changed, 35 insertions(+)

diff --git a/api/onnx_web/chain/upscale_stable_diffusion.py b/api/onnx_web/chain/upscale_stable_diffusion.py
index c2032bfb..ffdb0036 100644
--- a/api/onnx_web/chain/upscale_stable_diffusion.py
+++ b/api/onnx_web/chain/upscale_stable_diffusion.py
@@ -5,6 +5,7 @@ import torch
 from diffusers import StableDiffusionUpscalePipeline
 from PIL import Image

+from ..diffusion.load import optimize_pipeline
 from ..diffusion.pipeline_onnx_stable_diffusion_upscale import (
     OnnxStableDiffusionUpscalePipeline,
 )
@@ -52,6 +53,8 @@ def load_stable_diffusion(
         if not server.show_progress:
             pipe.set_progress_bar_config(disable=True)

+        optimize_pipeline(server, pipe)
+
         server.cache.set("diffusion", cache_key, pipe)
         run_gc([device])

diff --git a/api/onnx_web/diffusion/load.py b/api/onnx_web/diffusion/load.py
index dceb509f..53cd6af0 100644
--- a/api/onnx_web/diffusion/load.py
+++ b/api/onnx_web/diffusion/load.py
@@ -17,6 +17,7 @@ from diffusers import (
     KDPM2DiscreteScheduler,
     LMSDiscreteScheduler,
     PNDMScheduler,
+    StableDiffusionPipeline,
 )

 try:
@@ -87,6 +88,32 @@ def get_tile_latents(
     return full_latents[:, :, y:yt, x:xt]


+def optimize_pipeline(
+    server: ServerContext,
+    pipe: StableDiffusionPipeline,
+) -> None:
+    if "attention-slicing" in server.optimizations:
+        logger.debug("enabling attention slicing on SD pipeline")
+        pipe.enable_attention_slicing()
+
+    if "vae-slicing" in server.optimizations:
+        logger.debug("enabling VAE slicing on SD pipeline")
+        pipe.enable_vae_slicing()
+
+    if "sequential-cpu-offload" in server.optimizations:
+        logger.debug("enabling sequential CPU offload on SD pipeline")
+        pipe.enable_sequential_cpu_offload()
+    elif "model-cpu-offload" in server.optimizations:
+        # TODO: check for accelerate
+        logger.debug("enabling model CPU offload on SD pipeline")
+        pipe.enable_model_cpu_offload()
+
+    if "memory-efficient-attention" in server.optimizations:
+        # TODO: check for xformers
+        logger.debug("enabling memory efficient attention for SD pipeline")
+        pipe.enable_xformers_memory_efficient_attention()
+
+
 def load_pipeline(
     server: ServerContext,
     pipeline: DiffusionPipeline,
@@ -151,6 +178,8 @@ def load_pipeline(
         if not server.show_progress:
             pipe.set_progress_bar_config(disable=True)

+        optimize_pipeline(server, pipe)
+
         if device is not None and hasattr(pipe, "to"):
             pipe = pipe.to(device.torch_str())

diff --git a/api/onnx_web/utils.py b/api/onnx_web/utils.py
index 59860639..20d9d32c 100644
--- a/api/onnx_web/utils.py
+++ b/api/onnx_web/utils.py
@@ -28,6 +28,7 @@ class ServerContext:
         cache: ModelCache = None,
         cache_path: str = None,
         show_progress: bool = True,
+        optimizations: List[str] = [],
     ) -> None:
         self.bundle_path = bundle_path
         self.model_path = model_path
@@ -42,6 +43,7 @@
         self.cache = cache or ModelCache(num_workers)
         self.cache_path = cache_path or path.join(model_path, ".cache")
         self.show_progress = show_progress
+        self.optimizations = optimizations

     @classmethod
     def from_environ(cls):
@@ -64,6 +66,7 @@
             image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"),
             cache=ModelCache(limit=cache_limit),
             show_progress=get_boolean(environ, "ONNX_WEB_SHOW_PROGRESS", True),
+            optimizations=environ.get("ONNX_WEB_OPTIMIZATIONS", "").split(","),
         )
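
Usage note: the new flags are read from the ONNX_WEB_OPTIMIZATIONS environment
variable as a comma-separated list, stored on ServerContext.optimizations, and
checked by name in optimize_pipeline(). A minimal sketch of the new code path,
assuming the api/ directory is on the Python path so that onnx_web.utils
imports as laid out in the diff:

    import os

    # comma-separated list of the flag names checked in optimize_pipeline()
    os.environ["ONNX_WEB_OPTIMIZATIONS"] = "attention-slicing,vae-slicing,model-cpu-offload"

    from onnx_web.utils import ServerContext

    server = ServerContext.from_environ()
    print(server.optimizations)
    # ['attention-slicing', 'vae-slicing', 'model-cpu-offload']

When the variable is unset, "".split(",") yields [""], a single empty string
that matches none of the flag names, so the default behaves as "no
optimizations".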
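The two TODO comments note that enable_model_cpu_offload() depends on
accelerate and enable_xformers_memory_efficient_attention() depends on
xformers, both of which are optional installs. One way the missing checks
could look, as a sketch only (module_available is a hypothetical helper, not
part of this patch):

    from importlib.util import find_spec

    def module_available(name: str) -> bool:
        # find_spec returns None when the module cannot be imported
        return find_spec(name) is not None

    if "model-cpu-offload" in server.optimizations and module_available("accelerate"):
        pipe.enable_model_cpu_offload()

    if "memory-efficient-attention" in server.optimizations and module_available("xformers"):
        pipe.enable_xformers_memory_efficient_attention()

Guarding the calls this way would let a server without the optional packages
skip those optimizations instead of raising at pipeline load time.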