
feat(api): enable optimizations for SD pipelines based on env vars (#155)

Sean Sube 2023-02-18 11:53:13 -06:00
parent ff57527274
commit ab6462d095
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 35 additions and 0 deletions
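The new ONNX_WEB_OPTIMIZATIONS variable is a comma-separated list of optimization names, parsed in ServerContext.from_environ in the third file below. A minimal sketch of that flow, using only names that appear in this diff:

from os import environ

# enable two of the new optimizations for this process (sketch, not part of the commit)
environ["ONNX_WEB_OPTIMIZATIONS"] = "attention-slicing,vae-slicing"

# same parsing expression as ServerContext.from_environ uses below
optimizations = environ.get("ONNX_WEB_OPTIMIZATIONS", "").split(",")
assert optimizations == ["attention-slicing", "vae-slicing"]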


@@ -5,6 +5,7 @@ import torch
 from diffusers import StableDiffusionUpscalePipeline
 from PIL import Image
 
+from ..diffusion.load import optimize_pipeline
 from ..diffusion.pipeline_onnx_stable_diffusion_upscale import (
     OnnxStableDiffusionUpscalePipeline,
 )
@@ -52,6 +53,8 @@ def load_stable_diffusion(
     if not server.show_progress:
         pipe.set_progress_bar_config(disable=True)
 
+    optimize_pipeline(server, pipe)
+
     server.cache.set("diffusion", cache_key, pipe)
     run_gc([device])


@@ -17,6 +17,7 @@ from diffusers import (
     KDPM2DiscreteScheduler,
     LMSDiscreteScheduler,
     PNDMScheduler,
+    StableDiffusionPipeline,
 )
 
 try:
@@ -87,6 +88,32 @@ def get_tile_latents(
     return full_latents[:, :, y:yt, x:xt]
 
 
+def optimize_pipeline(
+    server: ServerContext,
+    pipe: StableDiffusionPipeline,
+) -> None:
+    if "attention-slicing" in server.optimizations:
+        logger.debug("enabling attention slicing on SD pipeline")
+        pipe.enable_attention_slicing()
+
+    if "vae-slicing" in server.optimizations:
+        logger.debug("enabling VAE slicing on SD pipeline")
+        pipe.enable_vae_slicing()
+
+    if "sequential-cpu-offload" in server.optimizations:
+        logger.debug("enabling sequential CPU offload on SD pipeline")
+        pipe.enable_sequential_cpu_offload()
+    elif "model-cpu-offload" in server.optimizations:
+        # TODO: check for accelerate
+        logger.debug("enabling model CPU offload on SD pipeline")
+        pipe.enable_model_cpu_offload()
+
+    if "memory-efficient-attention" in server.optimizations:
+        # TODO: check for xformers
+        logger.debug("enabling memory efficient attention for SD pipeline")
+        pipe.enable_xformers_memory_efficient_attention()
+
+
 def load_pipeline(
     server: ServerContext,
     pipeline: DiffusionPipeline,
@@ -151,6 +178,8 @@ def load_pipeline(
         if not server.show_progress:
             pipe.set_progress_bar_config(disable=True)
 
+        optimize_pipeline(server, pipe)
+
         if device is not None and hasattr(pipe, "to"):
             pipe = pipe.to(device.torch_str())

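For reference, a hedged usage sketch of the new optimize_pipeline helper outside the server: the function only reads the optimizations attribute of its server argument, so a stand-in object works. The onnx_web.diffusion.load module path and the model ID are assumptions for illustration, not part of the commit.

from types import SimpleNamespace

from diffusers import StableDiffusionPipeline

# module path assumed from the relative import in the first file above
from onnx_web.diffusion.load import optimize_pipeline

# stand-in for ServerContext; only .optimizations is read
server = SimpleNamespace(optimizations=["attention-slicing", "vae-slicing"])

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
optimize_pipeline(server, pipe)  # enables each optimization named in the list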

@@ -28,6 +28,7 @@ class ServerContext:
         cache: ModelCache = None,
         cache_path: str = None,
         show_progress: bool = True,
+        optimizations: List[str] = [],
     ) -> None:
         self.bundle_path = bundle_path
         self.model_path = model_path
@@ -42,6 +43,7 @@ class ServerContext:
         self.cache = cache or ModelCache(num_workers)
         self.cache_path = cache_path or path.join(model_path, ".cache")
         self.show_progress = show_progress
+        self.optimizations = optimizations
 
     @classmethod
     def from_environ(cls):
@@ -64,6 +66,7 @@ class ServerContext:
             image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"),
             cache=ModelCache(limit=cache_limit),
             show_progress=get_boolean(environ, "ONNX_WEB_SHOW_PROGRESS", True),
+            optimizations=environ.get("ONNX_WEB_OPTIMIZATIONS", "").split(","),
        )
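One parsing detail worth noting: when ONNX_WEB_OPTIMIZATIONS is unset, split(",") on the empty default yields [""], not []; since the empty string matches no optimization name, every optimization stays disabled by default.

# behavior of the default value when the variable is unset
assert "".split(",") == [""]
assert "attention-slicing" not in "".split(",")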