feat(api): enable optimizations for SD pipelines based on env vars (#155)
commit ab6462d095
parent ff57527274
@@ -5,6 +5,7 @@ import torch
 from diffusers import StableDiffusionUpscalePipeline
 from PIL import Image
 
+from ..diffusion.load import optimize_pipeline
 from ..diffusion.pipeline_onnx_stable_diffusion_upscale import (
     OnnxStableDiffusionUpscalePipeline,
 )
@@ -52,6 +53,8 @@ def load_stable_diffusion(
         if not server.show_progress:
             pipe.set_progress_bar_config(disable=True)
 
+        optimize_pipeline(server, pipe)
+
         server.cache.set("diffusion", cache_key, pipe)
         run_gc([device])
 
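With this change the Stable Diffusion upscale loader goes through the same optimize_pipeline helper as the main pipelines, so a single ONNX_WEB_OPTIMIZATIONS setting covers both. A minimal sketch of the equivalent direct diffusers calls for one flag (the model ID here is illustrative, not part of the commit):

    from diffusers import StableDiffusionUpscalePipeline

    pipe = StableDiffusionUpscalePipeline.from_pretrained(
        "stabilityai/stable-diffusion-x4-upscaler"  # illustrative model ID
    )
    # What optimize_pipeline does when "attention-slicing" is configured:
    pipe.enable_attention_slicing()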
@@ -17,6 +17,7 @@ from diffusers import (
     KDPM2DiscreteScheduler,
     LMSDiscreteScheduler,
     PNDMScheduler,
+    StableDiffusionPipeline,
 )
 
 try:
@@ -87,6 +88,32 @@ def get_tile_latents(
     return full_latents[:, :, y:yt, x:xt]
 
 
+def optimize_pipeline(
+    server: ServerContext,
+    pipe: StableDiffusionPipeline,
+) -> None:
+    if "attention-slicing" in server.optimizations:
+        logger.debug("enabling attention slicing on SD pipeline")
+        pipe.enable_attention_slicing()
+
+    if "vae-slicing" in server.optimizations:
+        logger.debug("enabling VAE slicing on SD pipeline")
+        pipe.enable_vae_slicing()
+
+    if "sequential-cpu-offload" in server.optimizations:
+        logger.debug("enabling sequential CPU offload on SD pipeline")
+        pipe.enable_sequential_cpu_offload()
+    elif "model-cpu-offload" in server.optimizations:
+        # TODO: check for accelerate
+        logger.debug("enabling model CPU offload on SD pipeline")
+        pipe.enable_model_cpu_offload()
+
+    if "memory-efficient-attention" in server.optimizations:
+        # TODO: check for xformers
+        logger.debug("enabling memory efficient attention for SD pipeline")
+        pipe.enable_xformers_memory_efficient_attention()
+
+
 def load_pipeline(
     server: ServerContext,
     pipeline: DiffusionPipeline,
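The helper is a thin dispatch from flag names to diffusers' built-in memory optimizations. A minimal sketch of driving it directly, assuming the package root is onnx_web; the model ID and the stub server object are illustrative (the annotation is not enforced, so any object with an optimizations list works):

    from diffusers import StableDiffusionPipeline

    from onnx_web.diffusion.load import optimize_pipeline


    class StubServer:
        # Stand-in for ServerContext; only the attribute optimize_pipeline reads.
        optimizations = ["attention-slicing", "vae-slicing"]


    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5"  # illustrative model ID
    )
    optimize_pipeline(StubServer(), pipe)
    # Attention slicing and VAE slicing are now enabled; the CPU offload and
    # xformers branches are skipped because their flags are not in the list.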
@@ -151,6 +178,8 @@ def load_pipeline(
         if not server.show_progress:
             pipe.set_progress_bar_config(disable=True)
 
+        optimize_pipeline(server, pipe)
+
         if device is not None and hasattr(pipe, "to"):
             pipe = pipe.to(device.torch_str())
 
@@ -28,6 +28,7 @@ class ServerContext:
         cache: ModelCache = None,
         cache_path: str = None,
         show_progress: bool = True,
+        optimizations: List[str] = [],
     ) -> None:
         self.bundle_path = bundle_path
         self.model_path = model_path
@@ -42,6 +43,7 @@
         self.cache = cache or ModelCache(num_workers)
         self.cache_path = cache_path or path.join(model_path, ".cache")
         self.show_progress = show_progress
+        self.optimizations = optimizations
 
     @classmethod
     def from_environ(cls):
@@ -64,6 +66,7 @@
             image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"),
             cache=ModelCache(limit=cache_limit),
             show_progress=get_boolean(environ, "ONNX_WEB_SHOW_PROGRESS", True),
+            optimizations=environ.get("ONNX_WEB_OPTIMIZATIONS", "").split(","),
         )
 
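The environment parsing is a plain comma split, so multiple optimizations are passed as a comma-separated list with no spaces. A short sketch of the resulting values (the variable contents are illustrative):

    from os import environ

    environ["ONNX_WEB_OPTIMIZATIONS"] = "attention-slicing,model-cpu-offload"
    flags = environ.get("ONNX_WEB_OPTIMIZATIONS", "").split(",")
    assert flags == ["attention-slicing", "model-cpu-offload"]

    # When the variable is unset, "".split(",") yields [""], a single empty
    # string that matches none of the flag names, so no optimization is applied.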