feat(api): enable optimizations for SD pipelines based on env vars (#155)
This commit is contained in:
parent
ff57527274
commit
ab6462d095
|
@ -5,6 +5,7 @@ import torch
|
||||||
from diffusers import StableDiffusionUpscalePipeline
|
from diffusers import StableDiffusionUpscalePipeline
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
from ..diffusion.load import optimize_pipeline
|
||||||
from ..diffusion.pipeline_onnx_stable_diffusion_upscale import (
|
from ..diffusion.pipeline_onnx_stable_diffusion_upscale import (
|
||||||
OnnxStableDiffusionUpscalePipeline,
|
OnnxStableDiffusionUpscalePipeline,
|
||||||
)
|
)
|
||||||
|
@ -52,6 +53,8 @@ def load_stable_diffusion(
|
||||||
if not server.show_progress:
|
if not server.show_progress:
|
||||||
pipe.set_progress_bar_config(disable=True)
|
pipe.set_progress_bar_config(disable=True)
|
||||||
|
|
||||||
|
optimize_pipeline(server, pipe)
|
||||||
|
|
||||||
server.cache.set("diffusion", cache_key, pipe)
|
server.cache.set("diffusion", cache_key, pipe)
|
||||||
run_gc([device])
|
run_gc([device])
|
||||||
|
|
||||||
|
|
|
@ -17,6 +17,7 @@ from diffusers import (
|
||||||
KDPM2DiscreteScheduler,
|
KDPM2DiscreteScheduler,
|
||||||
LMSDiscreteScheduler,
|
LMSDiscreteScheduler,
|
||||||
PNDMScheduler,
|
PNDMScheduler,
|
||||||
|
StableDiffusionPipeline,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -87,6 +88,32 @@ def get_tile_latents(
|
||||||
return full_latents[:, :, y:yt, x:xt]
|
return full_latents[:, :, y:yt, x:xt]
|
||||||
|
|
||||||
|
|
||||||
|
def _try_optimization(pipe, method_name: str, message: str) -> None:
    """Log ``message`` at debug level, then call ``pipe.<method_name>()``.

    Optimizations are optional, so any failure while enabling one (the
    pipeline not having the method, or the method itself raising) is logged
    as a warning instead of propagating and aborting the pipeline load.
    """
    logger.debug(message)
    try:
        # getattr is inside the try so a pipeline without this method is
        # handled the same way as a method that raises.
        getattr(pipe, method_name)()
    except Exception as err:
        logger.warning("error while %s: %s", message, err)


def optimize_pipeline(
    server: ServerContext,
    pipe: StableDiffusionPipeline,
) -> None:
    """Enable the optimizations named in ``server.optimizations`` on ``pipe``.

    Recognized names and the pipeline method each one calls:

    - ``attention-slicing``: ``enable_attention_slicing``
    - ``vae-slicing``: ``enable_vae_slicing``
    - ``sequential-cpu-offload``: ``enable_sequential_cpu_offload``
    - ``model-cpu-offload``: ``enable_model_cpu_offload`` (only used when
      ``sequential-cpu-offload`` was not requested)
    - ``memory-efficient-attention``:
      ``enable_xformers_memory_efficient_attention``

    Names that are not recognized are ignored. Each optimization is applied
    on a best-effort basis: if enabling it fails, a warning is logged and the
    remaining optimizations are still attempted.

    Args:
        server: Server context providing the ``optimizations`` collection.
        pipe: Pipeline to optimize in place.
    """
    requested = server.optimizations

    if "attention-slicing" in requested:
        _try_optimization(
            pipe,
            "enable_attention_slicing",
            "enabling attention slicing on SD pipeline",
        )

    if "vae-slicing" in requested:
        _try_optimization(
            pipe,
            "enable_vae_slicing",
            "enabling VAE slicing on SD pipeline",
        )

    # The two CPU offload modes are mutually exclusive; sequential offload
    # wins when both are requested.
    if "sequential-cpu-offload" in requested:
        _try_optimization(
            pipe,
            "enable_sequential_cpu_offload",
            "enabling sequential CPU offload on SD pipeline",
        )
    elif "model-cpu-offload" in requested:
        # TODO: check for accelerate
        _try_optimization(
            pipe,
            "enable_model_cpu_offload",
            "enabling model CPU offload on SD pipeline",
        )

    if "memory-efficient-attention" in requested:
        # TODO: check for xformers
        _try_optimization(
            pipe,
            "enable_xformers_memory_efficient_attention",
            "enabling memory efficient attention for SD pipeline",
        )
|
||||||
|
|
||||||
|
|
||||||
def load_pipeline(
|
def load_pipeline(
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
pipeline: DiffusionPipeline,
|
pipeline: DiffusionPipeline,
|
||||||
|
@ -151,6 +178,8 @@ def load_pipeline(
|
||||||
if not server.show_progress:
|
if not server.show_progress:
|
||||||
pipe.set_progress_bar_config(disable=True)
|
pipe.set_progress_bar_config(disable=True)
|
||||||
|
|
||||||
|
optimize_pipeline(server, pipe)
|
||||||
|
|
||||||
if device is not None and hasattr(pipe, "to"):
|
if device is not None and hasattr(pipe, "to"):
|
||||||
pipe = pipe.to(device.torch_str())
|
pipe = pipe.to(device.torch_str())
|
||||||
|
|
||||||
|
|
|
@ -28,6 +28,7 @@ class ServerContext:
|
||||||
cache: ModelCache = None,
|
cache: ModelCache = None,
|
||||||
cache_path: str = None,
|
cache_path: str = None,
|
||||||
show_progress: bool = True,
|
show_progress: bool = True,
|
||||||
|
optimizations: List[str] = [],
|
||||||
) -> None:
|
) -> None:
|
||||||
self.bundle_path = bundle_path
|
self.bundle_path = bundle_path
|
||||||
self.model_path = model_path
|
self.model_path = model_path
|
||||||
|
@ -42,6 +43,7 @@ class ServerContext:
|
||||||
self.cache = cache or ModelCache(num_workers)
|
self.cache = cache or ModelCache(num_workers)
|
||||||
self.cache_path = cache_path or path.join(model_path, ".cache")
|
self.cache_path = cache_path or path.join(model_path, ".cache")
|
||||||
self.show_progress = show_progress
|
self.show_progress = show_progress
|
||||||
|
self.optimizations = optimizations
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_environ(cls):
|
def from_environ(cls):
|
||||||
|
@ -64,6 +66,7 @@ class ServerContext:
|
||||||
image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"),
|
image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"),
|
||||||
cache=ModelCache(limit=cache_limit),
|
cache=ModelCache(limit=cache_limit),
|
||||||
show_progress=get_boolean(environ, "ONNX_WEB_SHOW_PROGRESS", True),
|
show_progress=get_boolean(environ, "ONNX_WEB_SHOW_PROGRESS", True),
|
||||||
|
optimizations=environ.get("ONNX_WEB_OPTIMIZATIONS", "").split(","),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue