from logging import getLogger from os import path from typing import Any, List, Literal, Optional, Tuple from onnx import load_model from optimum.onnxruntime import ( # ORTStableDiffusionXLInpaintPipeline, ORTStableDiffusionXLImg2ImgPipeline, ORTStableDiffusionXLPipeline, ) from transformers import CLIPTokenizer from ..constants import LATENT_FACTOR, ONNX_MODEL from ..convert.diffusion.lora import blend_loras, buffer_external_data_tensors from ..convert.diffusion.textual_inversion import blend_textual_inversions from ..diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline from ..diffusers.utils import expand_prompt from ..params import DeviceParams, ImageParams from ..server import ModelTypes, ServerContext from ..torch_before_ort import InferenceSession from ..utils import run_gc from .patches.unet import UNetWrapper from .patches.vae import VAEWrapper from .pipelines.controlnet import OnnxStableDiffusionControlNetPipeline from .pipelines.lpw import OnnxStableDiffusionLongPromptWeightingPipeline from .pipelines.panorama import OnnxStableDiffusionPanoramaPipeline from .pipelines.panorama_xl import ORTStableDiffusionXLPanoramaPipeline from .pipelines.pix2pix import OnnxStableDiffusionInstructPix2PixPipeline from .version_safe_diffusers import ( DDIMScheduler, DDPMScheduler, DEISMultistepScheduler, DPMSolverMultistepScheduler, DPMSolverSDEScheduler, DPMSolverSinglestepScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, HeunDiscreteScheduler, IPNDMScheduler, KDPM2AncestralDiscreteScheduler, KDPM2DiscreteScheduler, LCMScheduler, LMSDiscreteScheduler, OnnxRuntimeModel, OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionInpaintPipeline, OnnxStableDiffusionPipeline, PNDMScheduler, StableDiffusionPipeline, UniPCMultistepScheduler, ) logger = getLogger(__name__) available_pipelines = { "controlnet": OnnxStableDiffusionControlNetPipeline, "img2img": OnnxStableDiffusionImg2ImgPipeline, "img2img-sdxl": ORTStableDiffusionXLImg2ImgPipeline, "inpaint": OnnxStableDiffusionInpaintPipeline, # "inpaint-sdxl": ORTStableDiffusionXLInpaintPipeline, "lpw": OnnxStableDiffusionLongPromptWeightingPipeline, "panorama": OnnxStableDiffusionPanoramaPipeline, "panorama-sdxl": ORTStableDiffusionXLPanoramaPipeline, "pix2pix": OnnxStableDiffusionInstructPix2PixPipeline, "txt2img-sdxl": ORTStableDiffusionXLPipeline, "txt2img": OnnxStableDiffusionPipeline, "upscale": OnnxStableDiffusionUpscalePipeline, } pipeline_schedulers = { "ddim": DDIMScheduler, "ddpm": DDPMScheduler, "deis-multi": DEISMultistepScheduler, "dpm-multi": DPMSolverMultistepScheduler, "dpm-sde": DPMSolverSDEScheduler, "dpm-single": DPMSolverSinglestepScheduler, "euler": EulerDiscreteScheduler, "euler-a": EulerAncestralDiscreteScheduler, "heun": HeunDiscreteScheduler, "ipndm": IPNDMScheduler, "k-dpm-2-a": KDPM2AncestralDiscreteScheduler, "k-dpm-2": KDPM2DiscreteScheduler, "lcm": LCMScheduler, "lms-discrete": LMSDiscreteScheduler, "pndm": PNDMScheduler, "unipc-multi": UniPCMultistepScheduler, } def add_pipeline(name: str, pipeline: Any) -> bool: global available_pipelines if name in available_pipelines: # TODO: decide if this should be allowed or not logger.warning("cannot replace existing pipeline: %s", name) return False else: available_pipelines[name] = pipeline return True def get_available_pipelines() -> List[str]: return list(available_pipelines.keys()) def get_pipeline_schedulers() -> List[str]: return list(pipeline_schedulers.keys()) def get_scheduler_name(scheduler: Any) -> Optional[str]: for k, v in pipeline_schedulers.items(): if scheduler == v or scheduler == v.__name__: return k return None VAE_COMPONENTS = ["vae", "vae_decoder", "vae_encoder"] def load_pipeline( server: ServerContext, params: ImageParams, pipeline: str, device: DeviceParams, embeddings: Optional[List[Tuple[str, float]]] = None, loras: Optional[List[Tuple[str, float]]] = None, model: Optional[str] = None, ): embeddings = embeddings or [] loras = loras or [] model = model or params.model torch_dtype = server.torch_dtype() logger.debug("using Torch dtype %s for pipeline", torch_dtype) control_key = params.control.name if params.control is not None else None pipe_key = ( pipeline, model, device.device, device.provider, control_key, embeddings, loras, ) scheduler_key = (params.scheduler, model) scheduler_type = pipeline_schedulers[params.scheduler] cache_pipe = server.cache.get(ModelTypes.diffusion, pipe_key) if cache_pipe is not None: logger.debug("reusing existing diffusion pipeline") pipe = cache_pipe # update scheduler cache_scheduler = server.cache.get(ModelTypes.scheduler, scheduler_key) if cache_scheduler is None: logger.debug("loading new diffusion scheduler") scheduler = scheduler_type.from_pretrained( model, provider=device.ort_provider(), sess_options=device.sess_options(), subfolder="scheduler", torch_dtype=torch_dtype, ) if device is not None and hasattr(scheduler, "to"): scheduler = scheduler.to(device.torch_str()) pipe.scheduler = scheduler server.cache.set(ModelTypes.scheduler, scheduler_key, scheduler) run_gc([device]) else: if server.cache.drop("diffusion", pipe_key) > 0: logger.debug("unloading previous diffusion pipeline") run_gc([device]) logger.debug("loading new diffusion pipeline from %s", model) scheduler = scheduler_type.from_pretrained( model, provider=device.ort_provider(), sess_options=device.sess_options(), subfolder="scheduler", torch_dtype=torch_dtype, ) components = { "scheduler": scheduler, } # shared components unet_type = "unet" # ControlNet component if params.is_control() and params.control is not None: logger.debug("loading ControlNet components") control_components = load_controlnet(server, device, params) components.update(control_components) unet_type = "cnet" # load various pipeline components encoder_components = load_text_encoders( server, device, model, embeddings, loras, torch_dtype, params ) components.update(encoder_components) unet_components = load_unet(server, device, model, loras, unet_type, params) components.update(unet_components) vae_components = load_vae(server, device, model, params) components.update(vae_components) pipeline_class = available_pipelines.get(pipeline, OnnxStableDiffusionPipeline) if params.is_xl(): logger.debug("assembling SDXL pipeline for %s", pipeline_class.__name__) pipe = pipeline_class( components["vae_decoder_session"], components["text_encoder_session"], components["unet_session"], { "force_zeros_for_empty_prompt": True, "requires_aesthetics_score": False, }, components["tokenizer"], scheduler, vae_encoder_session=components.get("vae_encoder_session", None), text_encoder_2_session=components.get("text_encoder_2_session", None), tokenizer_2=components.get("tokenizer_2", None), ) else: if "controlnet" in components: logger.debug( "assembling SD pipeline for %s with ControlNet", pipeline_class.__name__, ) pipe = pipeline_class( components["vae_encoder"], components["vae_decoder"], components["text_encoder"], components["tokenizer"], components["unet"], components["controlnet"], scheduler, None, None, requires_safety_checker=False, ) elif "vae" in components: # upscale uses a single VAE logger.debug( "assembling SD pipeline for %s with single VAE", pipeline_class.__name__, ) pipe = pipeline_class( components["vae"], components["text_encoder"], components["tokenizer"], components["unet"], scheduler, scheduler, ) else: logger.debug( "assembling SD pipeline for %s with VAE codec", pipeline_class.__name__, ) pipe = pipeline_class( components["vae_encoder"], components["vae_decoder"], components["text_encoder"], components["tokenizer"], components["unet"], scheduler, None, None, requires_safety_checker=False, ) if not server.show_progress: pipe.set_progress_bar_config(disable=True) optimize_pipeline(server, pipe) patch_pipeline(server, pipe, pipeline_class, params) server.cache.set(ModelTypes.diffusion, pipe_key, pipe) server.cache.set(ModelTypes.scheduler, scheduler_key, scheduler) for vae in VAE_COMPONENTS: if hasattr(pipe, vae): vae_model = getattr(pipe, vae) if isinstance(vae_model, VAEWrapper): vae_model.set_tiled(tiled=params.tiled_vae) vae_model.set_window_size( params.vae_tile // LATENT_FACTOR, params.vae_overlap ) # update panorama params if params.is_panorama(): unet_stride = (params.unet_tile * (1 - params.unet_overlap)) // LATENT_FACTOR logger.debug( "setting panorama window parameters: %s/%s for UNet, %s/%s for VAE", params.unet_tile, unet_stride, params.vae_tile, params.vae_overlap, ) pipe.set_window_size(params.unet_tile // LATENT_FACTOR, unet_stride) run_gc([device]) return pipe def load_controlnet(server: ServerContext, device: DeviceParams, params: ImageParams): cnet_path = path.join(server.model_path, "control", f"{params.control.name}.onnx") logger.debug("loading ControlNet weights from %s", cnet_path) components = {} components["controlnet"] = OnnxRuntimeModel( OnnxRuntimeModel.load_model( cnet_path, provider=device.ort_provider(), sess_options=device.sess_options(), ) ) return components def load_text_encoders( server: ServerContext, device: DeviceParams, model: str, embeddings: Optional[List[Tuple[str, float]]], loras: Optional[List[Tuple[str, float]]], torch_dtype, params: ImageParams, ): text_encoder = load_model(path.join(model, "text_encoder", ONNX_MODEL)) tokenizer = CLIPTokenizer.from_pretrained( model, subfolder="tokenizer", torch_dtype=torch_dtype, ) components = { "tokenizer": tokenizer, } if params.is_xl(): text_encoder_2 = load_model(path.join(model, "text_encoder_2", ONNX_MODEL)) tokenizer_2 = CLIPTokenizer.from_pretrained( model, subfolder="tokenizer_2", torch_dtype=torch_dtype, ) components["tokenizer_2"] = tokenizer_2 # blend embeddings, if any if embeddings is not None and len(embeddings) > 0: embedding_names, embedding_weights = zip(*embeddings) embedding_models = [ path.join(server.model_path, "inversion", name) for name in embedding_names ] logger.debug( "blending base model %s with embeddings from %s", model, embedding_models ) # TODO: blend text_encoder_2 as well text_encoder, tokenizer = blend_textual_inversions( server, text_encoder, tokenizer, list( zip( embedding_models, embedding_weights, embedding_names, [None] * len(embedding_models), ) ), ) components["tokenizer"] = tokenizer if params.is_xl(): text_encoder_2, tokenizer_2 = blend_textual_inversions( server, text_encoder_2, tokenizer_2, list( zip( embedding_models, embedding_weights, embedding_names, [None] * len(embedding_models), ) ), ) components["tokenizer_2"] = tokenizer_2 # blend LoRAs, if any if loras is not None and len(loras) > 0: lora_names, lora_weights = zip(*loras) lora_models = [ path.join(server.model_path, "lora", name) for name in lora_names ] logger.info("blending base model %s with LoRAs from %s", model, lora_models) # blend and load text encoder text_encoder = blend_loras( server, text_encoder, list(zip(lora_models, lora_weights)), "text_encoder", 1 if params.is_xl() else None, params.is_xl(), ) if params.is_xl(): text_encoder_2 = blend_loras( server, text_encoder_2, list(zip(lora_models, lora_weights)), "text_encoder", 2, params.is_xl(), ) # prepare external data for sessions (text_encoder, text_encoder_data) = buffer_external_data_tensors(text_encoder) text_encoder_names, text_encoder_values = zip(*text_encoder_data) text_encoder_opts = device.sess_options(cache=False) text_encoder_opts.add_external_initializers( list(text_encoder_names), list(text_encoder_values) ) if params.is_xl(): # encoder 2 only exists in XL (text_encoder_2, text_encoder_2_data) = buffer_external_data_tensors( text_encoder_2 ) text_encoder_2_names, text_encoder_2_values = zip(*text_encoder_2_data) text_encoder_2_opts = device.sess_options(cache=False) text_encoder_2_opts.add_external_initializers( list(text_encoder_2_names), list(text_encoder_2_values) ) # session for te1 text_encoder_session = InferenceSession( text_encoder.SerializeToString(), providers=[device.ort_provider("text-encoder")], sess_options=text_encoder_opts, ) text_encoder_session._model_path = path.join(model, "text_encoder") components["text_encoder_session"] = text_encoder_session # session for te2 text_encoder_2_session = InferenceSession( text_encoder_2.SerializeToString(), providers=[device.ort_provider("text-encoder")], sess_options=text_encoder_2_opts, ) text_encoder_2_session._model_path = path.join(model, "text_encoder_2") components["text_encoder_2_session"] = text_encoder_2_session else: # session for te components["text_encoder"] = OnnxRuntimeModel( OnnxRuntimeModel.load_model( text_encoder.SerializeToString(), provider=device.ort_provider("text-encoder"), sess_options=text_encoder_opts, ) ) return components def load_unet( server: ServerContext, device: DeviceParams, model: str, loras: List[Tuple[str, float]], unet_type: Literal["cnet", "unet"], params: ImageParams, ): components = {} unet = load_model(path.join(model, unet_type, ONNX_MODEL)) # LoRA blending if loras is not None and len(loras) > 0: lora_names, lora_weights = zip(*loras) lora_models = [ path.join(server.model_path, "lora", name) for name in lora_names ] logger.info("blending base model %s with LoRA models: %s", model, lora_models) # blend and load unet unet = blend_loras( server, unet, list(zip(lora_models, lora_weights)), "unet", xl=params.is_xl(), ) (unet_model, unet_data) = buffer_external_data_tensors(unet) unet_names, unet_values = zip(*unet_data) unet_opts = device.sess_options(cache=False) unet_opts.add_external_initializers(list(unet_names), list(unet_values)) if params.is_xl(): unet_session = InferenceSession( unet_model.SerializeToString(), providers=[device.ort_provider("unet")], sess_options=unet_opts, ) unet_session._model_path = path.join(model, "unet") components["unet_session"] = unet_session else: components["unet"] = OnnxRuntimeModel( OnnxRuntimeModel.load_model( unet_model.SerializeToString(), provider=device.ort_provider("unet"), sess_options=unet_opts, ) ) return components def load_vae( _server: ServerContext, device: DeviceParams, model: str, params: ImageParams ): # one or more VAE models need to be loaded vae = path.join(model, "vae", ONNX_MODEL) vae_decoder = path.join(model, "vae_decoder", ONNX_MODEL) vae_encoder = path.join(model, "vae_encoder", ONNX_MODEL) components = {} if not params.is_xl() and path.exists(vae): logger.debug("loading VAE from %s", vae) components["vae"] = OnnxRuntimeModel( OnnxRuntimeModel.load_model( vae, provider=device.ort_provider("vae"), sess_options=device.sess_options(), ) ) elif path.exists(vae_decoder) and path.exists(vae_encoder): if params.is_xl(): logger.debug("loading VAE decoder from %s", vae_decoder) components["vae_decoder_session"] = OnnxRuntimeModel.load_model( vae_decoder, provider=device.ort_provider("vae"), sess_options=device.sess_options(), ) components["vae_decoder_session"]._model_path = vae_decoder logger.debug("loading VAE encoder from %s", vae_encoder) components["vae_encoder_session"] = OnnxRuntimeModel.load_model( vae_encoder, provider=device.ort_provider("vae"), sess_options=device.sess_options(), ) components["vae_encoder_session"]._model_path = vae_encoder else: logger.debug("loading VAE decoder from %s", vae_decoder) components["vae_decoder"] = OnnxRuntimeModel( OnnxRuntimeModel.load_model( vae_decoder, provider=device.ort_provider("vae"), sess_options=device.sess_options(), ) ) logger.debug("loading VAE encoder from %s", vae_encoder) components["vae_encoder"] = OnnxRuntimeModel( OnnxRuntimeModel.load_model( vae_encoder, provider=device.ort_provider("vae"), sess_options=device.sess_options(), ) ) return components def optimize_pipeline( server: ServerContext, pipe: StableDiffusionPipeline, ) -> None: if server.has_optimization( "diffusers-attention-slicing" ) or server.has_optimization("diffusers-attention-slicing-auto"): logger.debug("enabling auto attention slicing on SD pipeline") try: pipe.enable_attention_slicing(slice_size="auto") except Exception as e: logger.warning("error while enabling auto attention slicing: %s", e) if server.has_optimization("diffusers-attention-slicing-max"): logger.debug("enabling max attention slicing on SD pipeline") try: pipe.enable_attention_slicing(slice_size="max") except Exception as e: logger.warning("error while enabling max attention slicing: %s", e) if server.has_optimization("diffusers-vae-slicing"): logger.debug("enabling VAE slicing on SD pipeline") try: pipe.enable_vae_slicing() except Exception as e: logger.warning("error while enabling VAE slicing: %s", e) if server.has_optimization("diffusers-cpu-offload-sequential"): logger.debug("enabling sequential CPU offload on SD pipeline") try: pipe.enable_sequential_cpu_offload() except Exception as e: logger.warning("error while enabling sequential CPU offload: %s", e) elif server.has_optimization("diffusers-cpu-offload-model"): # TODO: check for accelerate logger.debug("enabling model CPU offload on SD pipeline") try: pipe.enable_model_cpu_offload() except Exception as e: logger.warning("error while enabling model CPU offload: %s", e) if server.has_optimization("diffusers-memory-efficient-attention"): # TODO: check for xformers logger.debug("enabling memory efficient attention for SD pipeline") try: pipe.enable_xformers_memory_efficient_attention() except Exception as e: logger.warning("error while enabling memory efficient attention: %s", e) def patch_pipeline( server: ServerContext, pipe: StableDiffusionPipeline, pipeline: Any, params: ImageParams, ) -> None: logger.debug("patching SD pipeline") if not params.is_lpw() and not params.is_xl(): pipe._encode_prompt = expand_prompt.__get__(pipe, pipeline) original_unet = pipe.unet pipe.unet = UNetWrapper(server, original_unet, params.is_xl()) logger.debug("patched UNet with wrapper") if hasattr(pipe, "vae_decoder"): original_decoder = pipe.vae_decoder pipe.vae_decoder = VAEWrapper( server, original_decoder, decoder=True, window=params.unet_tile, overlap=params.vae_overlap, ) logger.debug("patched VAE decoder with wrapper") if hasattr(pipe, "vae_encoder"): original_encoder = pipe.vae_encoder pipe.vae_encoder = VAEWrapper( server, original_encoder, decoder=False, window=params.unet_tile, overlap=params.vae_overlap, ) logger.debug("patched VAE encoder with wrapper") if hasattr(pipe, "vae"): logger.warning("not patching single VAE, tiled VAE may not work")