
feat(api): enable optimizations for SD pipelines based on env vars (#155)

Sean Sube 2023-02-18 11:53:13 -06:00
parent ff57527274
commit ab6462d095
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 35 additions and 0 deletions

Changed file 1 of 3

@@ -5,6 +5,7 @@ import torch
 from diffusers import StableDiffusionUpscalePipeline
 from PIL import Image
 
+from ..diffusion.load import optimize_pipeline
 from ..diffusion.pipeline_onnx_stable_diffusion_upscale import (
     OnnxStableDiffusionUpscalePipeline,
 )
@@ -52,6 +53,8 @@ def load_stable_diffusion(
     if not server.show_progress:
         pipe.set_progress_bar_config(disable=True)
 
+    optimize_pipeline(server, pipe)
+
     server.cache.set("diffusion", cache_key, pipe)
     run_gc([device])

Changed file 2 of 3

@@ -17,6 +17,7 @@ from diffusers import (
     KDPM2DiscreteScheduler,
     LMSDiscreteScheduler,
     PNDMScheduler,
+    StableDiffusionPipeline,
 )
 
 try:
@@ -87,6 +88,32 @@ def get_tile_latents(
     return full_latents[:, :, y:yt, x:xt]
 
 
+def optimize_pipeline(
+    server: ServerContext,
+    pipe: StableDiffusionPipeline,
+) -> None:
+    if "attention-slicing" in server.optimizations:
+        logger.debug("enabling attention slicing on SD pipeline")
+        pipe.enable_attention_slicing()
+
+    if "vae-slicing" in server.optimizations:
+        logger.debug("enabling VAE slicing on SD pipeline")
+        pipe.enable_vae_slicing()
+
+    if "sequential-cpu-offload" in server.optimizations:
+        logger.debug("enabling sequential CPU offload on SD pipeline")
+        pipe.enable_sequential_cpu_offload()
+    elif "model-cpu-offload" in server.optimizations:
+        # TODO: check for accelerate
+        logger.debug("enabling model CPU offload on SD pipeline")
+        pipe.enable_model_cpu_offload()
+
+    if "memory-efficient-attention" in server.optimizations:
+        # TODO: check for xformers
+        logger.debug("enabling memory efficient attention for SD pipeline")
+        pipe.enable_xformers_memory_efficient_attention()
+
+
 def load_pipeline(
     server: ServerContext,
     pipeline: DiffusionPipeline,
@@ -151,6 +178,8 @@ def load_pipeline(
     if not server.show_progress:
         pipe.set_progress_bar_config(disable=True)
 
+    optimize_pipeline(server, pipe)
+
     if device is not None and hasattr(pipe, "to"):
         pipe = pipe.to(device.torch_str())
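For context, the new helper applies each requested optimization flag to an already-loaded diffusers pipeline, so it can also be exercised by hand. A minimal sketch, assuming the other ServerContext constructor arguments have usable defaults (the model ID below is only an example, not one from this repo):

    from diffusers import StableDiffusionPipeline

    # hypothetical standalone usage; in the server this runs inside load_pipeline
    server = ServerContext(optimizations=["attention-slicing", "vae-slicing"])

    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    optimize_pipeline(server, pipe)  # enables attention slicing and VAE slicing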

Changed file 3 of 3

@@ -28,6 +28,7 @@ class ServerContext:
         cache: ModelCache = None,
         cache_path: str = None,
         show_progress: bool = True,
+        optimizations: List[str] = [],
     ) -> None:
         self.bundle_path = bundle_path
         self.model_path = model_path
@@ -42,6 +43,7 @@ class ServerContext:
         self.cache = cache or ModelCache(num_workers)
         self.cache_path = cache_path or path.join(model_path, ".cache")
         self.show_progress = show_progress
+        self.optimizations = optimizations
 
     @classmethod
     def from_environ(cls):
@@ -64,6 +66,7 @@ class ServerContext:
             image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"),
             cache=ModelCache(limit=cache_limit),
             show_progress=get_boolean(environ, "ONNX_WEB_SHOW_PROGRESS", True),
+            optimizations=environ.get("ONNX_WEB_OPTIMIZATIONS", "").split(","),
         )
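The flag list comes straight from the environment as a comma-separated string. A minimal sketch of the parsing behavior, using only the standard library (the variable value here is chosen for illustration):

    from os import environ

    # simulate the server environment before from_environ runs
    environ["ONNX_WEB_OPTIMIZATIONS"] = "attention-slicing,model-cpu-offload"

    flags = environ.get("ONNX_WEB_OPTIMIZATIONS", "").split(",")
    print(flags)  # ['attention-slicing', 'model-cpu-offload']

    # when the variable is unset, "".split(",") yields [""] rather than [];
    # the empty string matches none of the flag names checked in
    # optimize_pipeline, so all optimizations are simply skipped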