From f5506b17f078aa14f7cce705a5f0c47e5af29823 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Fri, 12 Jan 2024 18:58:26 -0600 Subject: [PATCH] feat(api): support CPU model offloading specifically for SDXL --- api/onnx_web/chain/persist_disk.py | 6 +++--- api/onnx_web/chain/result.py | 2 +- api/onnx_web/chain/upscale_bsrgan.py | 2 +- api/onnx_web/chain/upscale_resrgan.py | 2 +- api/onnx_web/chain/upscale_swinir.py | 2 +- api/onnx_web/diffusers/load.py | 16 +++++++------- api/onnx_web/output.py | 2 +- api/onnx_web/params.py | 21 +++++++++++++------ api/onnx_web/prompt/base.py | 30 +++++++++++++++++++++++++++ 9 files changed, 61 insertions(+), 22 deletions(-) create mode 100644 api/onnx_web/prompt/base.py diff --git a/api/onnx_web/chain/persist_disk.py b/api/onnx_web/chain/persist_disk.py index 825fabf5..2afaed00 100644 --- a/api/onnx_web/chain/persist_disk.py +++ b/api/onnx_web/chain/persist_disk.py @@ -1,10 +1,10 @@ from logging import getLogger -from typing import List, Optional +from typing import Optional from PIL import Image -from ..output import save_image, save_result -from ..params import ImageParams, Size, SizeChart, StageParams +from ..output import save_result +from ..params import ImageParams, SizeChart, StageParams from ..server import ServerContext from ..worker import WorkerContext from .base import BaseStage diff --git a/api/onnx_web/chain/result.py b/api/onnx_web/chain/result.py index e8b40ea3..a07f82f9 100644 --- a/api/onnx_web/chain/result.py +++ b/api/onnx_web/chain/result.py @@ -229,7 +229,7 @@ class ImageMetadata: if self.models is not None: for name, weight in self.models: - name, hash = self.get_model_hash() + name, hash = self.get_model_hash(server) json["models"].append({"name": name, "weight": weight, "hash": hash}) return json diff --git a/api/onnx_web/chain/upscale_bsrgan.py b/api/onnx_web/chain/upscale_bsrgan.py index 80f0af5f..3d410992 100644 --- a/api/onnx_web/chain/upscale_bsrgan.py +++ b/api/onnx_web/chain/upscale_bsrgan.py @@ -47,7 +47,7 @@ class UpscaleBSRGANStage(BaseStage): pipe = OnnxModel( server, model_path, - provider=device.ort_provider(), + provider=device.ort_provider("bsrgan"), sess_options=device.sess_options(), ) diff --git a/api/onnx_web/chain/upscale_resrgan.py b/api/onnx_web/chain/upscale_resrgan.py index 4d165ab1..51f0a5ae 100644 --- a/api/onnx_web/chain/upscale_resrgan.py +++ b/api/onnx_web/chain/upscale_resrgan.py @@ -64,7 +64,7 @@ class UpscaleRealESRGANStage(BaseStage): model = OnnxRRDBNet( server, model_file, - provider=device.ort_provider(), + provider=device.ort_provider("esrgan"), sess_options=device.sess_options(), ) diff --git a/api/onnx_web/chain/upscale_swinir.py b/api/onnx_web/chain/upscale_swinir.py index 7d55d9b1..cab60078 100644 --- a/api/onnx_web/chain/upscale_swinir.py +++ b/api/onnx_web/chain/upscale_swinir.py @@ -40,7 +40,7 @@ class UpscaleSwinIRStage(BaseStage): pipe = OnnxModel( server, model_path, - provider=device.ort_provider(), + provider=device.ort_provider("swinir"), sess_options=device.sess_options(), ) diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py index 9e8a4e50..bafdd032 100644 --- a/api/onnx_web/diffusers/load.py +++ b/api/onnx_web/diffusers/load.py @@ -158,7 +158,7 @@ def load_pipeline( logger.debug("loading new diffusion scheduler") scheduler = scheduler_type.from_pretrained( model, - provider=device.ort_provider(), + provider=device.ort_provider("scheduler"), sess_options=device.sess_options(), subfolder="scheduler", torch_dtype=torch_dtype, @@ -179,7 +179,7 @@ def load_pipeline( logger.debug("loading new diffusion pipeline from %s", model) scheduler = scheduler_type.from_pretrained( model, - provider=device.ort_provider(), + provider=device.ort_provider("scheduler"), sess_options=device.sess_options(), subfolder="scheduler", torch_dtype=torch_dtype, @@ -320,7 +320,7 @@ def load_controlnet(server: ServerContext, device: DeviceParams, params: ImagePa components["controlnet"] = OnnxRuntimeModel( OnnxRuntimeModel.load_model( cnet_path, - provider=device.ort_provider(), + provider=device.ort_provider("controlnet"), sess_options=device.sess_options(), ) ) @@ -448,7 +448,7 @@ def load_text_encoders( # session for te1 text_encoder_session = InferenceSession( text_encoder.SerializeToString(), - providers=[device.ort_provider("text-encoder")], + providers=[device.ort_provider("text-encoder", "sdxl")], sess_options=text_encoder_opts, ) text_encoder_session._model_path = path.join(model, "text_encoder") @@ -457,7 +457,7 @@ def load_text_encoders( # session for te2 text_encoder_2_session = InferenceSession( text_encoder_2.SerializeToString(), - providers=[device.ort_provider("text-encoder")], + providers=[device.ort_provider("text-encoder", "sdxl")], sess_options=text_encoder_2_opts, ) text_encoder_2_session._model_path = path.join(model, "text_encoder_2") @@ -511,7 +511,7 @@ def load_unet( if params.is_xl(): unet_session = InferenceSession( unet_model.SerializeToString(), - providers=[device.ort_provider("unet")], + providers=[device.ort_provider("unet", "sdxl")], sess_options=unet_opts, ) unet_session._model_path = path.join(model, "unet") @@ -551,7 +551,7 @@ def load_vae( logger.debug("loading VAE decoder from %s", vae_decoder) components["vae_decoder_session"] = OnnxRuntimeModel.load_model( vae_decoder, - provider=device.ort_provider("vae"), + provider=device.ort_provider("vae", "sdxl"), sess_options=device.sess_options(), ) components["vae_decoder_session"]._model_path = vae_decoder @@ -559,7 +559,7 @@ def load_vae( logger.debug("loading VAE encoder from %s", vae_encoder) components["vae_encoder_session"] = OnnxRuntimeModel.load_model( vae_encoder, - provider=device.ort_provider("vae"), + provider=device.ort_provider("vae", "sdxl"), sess_options=device.sess_options(), ) components["vae_encoder_session"]._model_path = vae_encoder diff --git a/api/onnx_web/output.py b/api/onnx_web/output.py index 49106ac7..9404df16 100644 --- a/api/onnx_web/output.py +++ b/api/onnx_web/output.py @@ -84,7 +84,7 @@ def save_image( server: ServerContext, output: str, image: Image.Image, - metadata: ImageMetadata, + metadata: Optional[ImageMetadata] = None, ) -> str: path = base_join(server.output_path, output) diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index b78f2421..6662e067 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -136,13 +136,22 @@ class DeviceParams: return "%s - %s (%s)" % (self.device, self.provider, self.options) def ort_provider( - self, model_type: Optional[str] = None + self, + model_type: str, + suffix: Optional[str] = None, ) -> Union[str, Tuple[str, Any]]: - if model_type is not None: - # check if model has been pinned to CPU - # TODO: check whether the CPU device is allowed - if f"onnx-cpu-{model_type}" in self.optimizations: - return "CPUExecutionProvider" + # check if model has been pinned to CPU + # TODO: check whether the CPU device is allowed + if f"onnx-cpu-{model_type}" in self.optimizations: + logger.debug("pinning %s to CPU", model_type) + return "CPUExecutionProvider" + + if ( + suffix is not None + and f"onnx-cpu-{model_type}-{suffix}" in self.optimizations + ): + logger.debug("pinning %s-%s to CPU", model_type, suffix) + return "CPUExecutionProvider" if self.options is None: return self.provider diff --git a/api/onnx_web/prompt/base.py b/api/onnx_web/prompt/base.py new file mode 100644 index 00000000..42cb11f7 --- /dev/null +++ b/api/onnx_web/prompt/base.py @@ -0,0 +1,30 @@ +from typing import List, Optional + + +class NetworkWeight: + pass + + +class PromptRegion: + pass + + +class PromptSeed: + pass + + +class StructuredPrompt: + prompt: str + negative_prompt: Optional[str] + networks: List[NetworkWeight] + region_prompts: List[PromptRegion] + region_seeds: List[PromptSeed] + + def __init__( + self, prompt: str, negative_prompt: Optional[str], networks: List[NetworkWeight] + ) -> None: + self.prompt = prompt + self.negative_prompt = negative_prompt + self.networks = networks or [] + self.region_prompts = [] + self.region_seeds = []