from logging import getLogger
from os import path
from typing import Any, List, Literal, Optional, Tuple

from onnx import load_model
from optimum.onnxruntime import (  # ORTStableDiffusionXLInpaintPipeline,
    ORTStableDiffusionXLImg2ImgPipeline,
    ORTStableDiffusionXLPipeline,
)
from transformers import CLIPTokenizer

from ..constants import LATENT_FACTOR, ONNX_MODEL
from ..convert.diffusion.lora import blend_loras, buffer_external_data_tensors
from ..convert.diffusion.textual_inversion import blend_textual_inversions
from ..diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
from ..diffusers.utils import expand_prompt
from ..params import DeviceParams, ImageParams
from ..server import ModelTypes, ServerContext
from ..torch_before_ort import InferenceSession
from ..utils import run_gc
from .patches.unet import UNetWrapper
from .patches.vae import VAEWrapper
from .pipelines.controlnet import OnnxStableDiffusionControlNetPipeline
from .pipelines.lpw import OnnxStableDiffusionLongPromptWeightingPipeline
from .pipelines.panorama import OnnxStableDiffusionPanoramaPipeline
from .pipelines.panorama_xl import ORTStableDiffusionXLPanoramaPipeline
from .pipelines.pix2pix import OnnxStableDiffusionInstructPix2PixPipeline
from .version_safe_diffusers import (
    DDIMScheduler,
    DDPMScheduler,
    DEISMultistepScheduler,
    DPMSolverMultistepScheduler,
    DPMSolverSDEScheduler,
    DPMSolverSinglestepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    HeunDiscreteScheduler,
    IPNDMScheduler,
    KDPM2AncestralDiscreteScheduler,
    KDPM2DiscreteScheduler,
    LCMScheduler,
    LMSDiscreteScheduler,
    OnnxRuntimeModel,
    OnnxStableDiffusionImg2ImgPipeline,
    OnnxStableDiffusionInpaintPipeline,
    OnnxStableDiffusionPipeline,
    PNDMScheduler,
    StableDiffusionPipeline,
    UniPCMultistepScheduler,
)

logger = getLogger(__name__)


available_pipelines = {
    "controlnet": OnnxStableDiffusionControlNetPipeline,
    "img2img": OnnxStableDiffusionImg2ImgPipeline,
    "img2img-sdxl": ORTStableDiffusionXLImg2ImgPipeline,
    "inpaint": OnnxStableDiffusionInpaintPipeline,
    # "inpaint-sdxl": ORTStableDiffusionXLInpaintPipeline,
    "lpw": OnnxStableDiffusionLongPromptWeightingPipeline,
    "panorama": OnnxStableDiffusionPanoramaPipeline,
    "panorama-sdxl": ORTStableDiffusionXLPanoramaPipeline,
    "pix2pix": OnnxStableDiffusionInstructPix2PixPipeline,
    "txt2img-sdxl": ORTStableDiffusionXLPipeline,
    "txt2img": OnnxStableDiffusionPipeline,
    "upscale": OnnxStableDiffusionUpscalePipeline,
}


pipeline_schedulers = {
    "ddim": DDIMScheduler,
    "ddpm": DDPMScheduler,
    "deis-multi": DEISMultistepScheduler,
    "dpm-multi": DPMSolverMultistepScheduler,
    "dpm-sde": DPMSolverSDEScheduler,
    "dpm-single": DPMSolverSinglestepScheduler,
    "euler": EulerDiscreteScheduler,
    "euler-a": EulerAncestralDiscreteScheduler,
    "heun": HeunDiscreteScheduler,
    "ipndm": IPNDMScheduler,
    "k-dpm-2-a": KDPM2AncestralDiscreteScheduler,
    "k-dpm-2": KDPM2DiscreteScheduler,
    "lcm": LCMScheduler,
    "lms-discrete": LMSDiscreteScheduler,
    "pndm": PNDMScheduler,
    "unipc-multi": UniPCMultistepScheduler,
}


def add_pipeline(name: str, pipeline: Any) -> bool:
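    """Register a new pipeline class under the given name.

    Existing entries are not replaced: returns False and logs a warning if
    the name is already registered, True after adding the pipeline.

    Usage sketch (``MyCustomPipeline`` is a hypothetical pipeline class):

        >>> add_pipeline("custom", MyCustomPipeline)  # doctest: +SKIP
        True
    """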
    global available_pipelines

    if name in available_pipelines:
        # TODO: decide if this should be allowed or not
        logger.warning("cannot replace existing pipeline: %s", name)
        return False
    else:
        available_pipelines[name] = pipeline
        return True


def get_available_pipelines() -> List[str]:
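    """Return the names of all registered pipelines."""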
    return list(available_pipelines.keys())


def get_pipeline_schedulers() -> List[str]:
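    """Return the names of all available schedulers."""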
    return list(pipeline_schedulers.keys())


def get_scheduler_name(scheduler: Any) -> Optional[str]:
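    """Look up the short name for a scheduler class or class name.

    Accepts either the class itself or its ``__name__``, so both
    ``DDIMScheduler`` and ``"DDIMScheduler"`` map back to ``"ddim"``.
    Returns None for schedulers that are not registered.
    """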
    for k, v in pipeline_schedulers.items():
        if scheduler == v or scheduler == v.__name__:
            return k

    return None


VAE_COMPONENTS = ["vae", "vae_decoder", "vae_encoder"]


def load_pipeline(
    server: ServerContext,
    params: ImageParams,
    pipeline: str,
    device: DeviceParams,
    embeddings: Optional[List[Tuple[str, float]]] = None,
    loras: Optional[List[Tuple[str, float]]] = None,
    model: Optional[str] = None,
):
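    """Load or assemble a diffusion pipeline for the given parameters.

    A cached pipeline is reused when the (pipeline, model, device, control,
    embeddings, LoRAs) key matches, with only the scheduler swapped out.
    Otherwise the scheduler, text encoders, UNet, and VAE components are
    loaded, blended, and assembled into a new pipeline, which is then
    optimized, patched, and cached.

    Usage sketch, assuming a configured server, params, and device:

        >>> pipe = load_pipeline(server, params, "txt2img", device)  # doctest: +SKIP
    """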
    embeddings = embeddings or []
    loras = loras or []
    model = model or params.model

    torch_dtype = server.torch_dtype()
    logger.debug("using Torch dtype %s for pipeline", torch_dtype)

    control_key = params.control.name if params.control is not None else None
    pipe_key = (
        pipeline,
        model,
        device.device,
        device.provider,
        control_key,
        embeddings,
        loras,
    )
    scheduler_key = (params.scheduler, model)
    scheduler_type = pipeline_schedulers[params.scheduler]

    cache_pipe = server.cache.get(ModelTypes.diffusion, pipe_key)

    if cache_pipe is not None:
        logger.debug("reusing existing diffusion pipeline")
        pipe = cache_pipe

        # update scheduler
        cache_scheduler = server.cache.get(ModelTypes.scheduler, scheduler_key)
        if cache_scheduler is None:
            logger.debug("loading new diffusion scheduler")
            scheduler = scheduler_type.from_pretrained(
                model,
                provider=device.ort_provider("scheduler"),
                sess_options=device.sess_options(),
                subfolder="scheduler",
                torch_dtype=torch_dtype,
            )

            if device is not None and hasattr(scheduler, "to"):
                scheduler = scheduler.to(device.torch_str())

            pipe.scheduler = scheduler
            server.cache.set(ModelTypes.scheduler, scheduler_key, scheduler)
            run_gc([device])
    else:
        if server.cache.drop("diffusion", pipe_key) > 0:
            logger.debug("unloading previous diffusion pipeline")
            run_gc([device])

        logger.debug("loading new diffusion pipeline from %s", model)
        scheduler = scheduler_type.from_pretrained(
            model,
            provider=device.ort_provider("scheduler"),
            sess_options=device.sess_options(),
            subfolder="scheduler",
            torch_dtype=torch_dtype,
        )
        components = {
            "scheduler": scheduler,
        }

        # shared components
        unet_type = "unet"

        # ControlNet component
        if params.is_control() and params.control is not None:
            logger.debug("loading ControlNet components")
            control_components = load_controlnet(server, device, params)
            components.update(control_components)
            unet_type = "cnet"

        # load various pipeline components
        encoder_components = load_text_encoders(
            server, device, model, embeddings, loras, torch_dtype, params
        )
        components.update(encoder_components)

        unet_components = load_unet(server, device, model, loras, unet_type, params)
        components.update(unet_components)

        vae_components = load_vae(server, device, model, params)
        components.update(vae_components)

        pipeline_class = available_pipelines.get(pipeline, OnnxStableDiffusionPipeline)

        if params.is_xl():
            logger.debug("assembling SDXL pipeline for %s", pipeline_class.__name__)
            pipe = pipeline_class(
                components["vae_decoder_session"],
                components["text_encoder_session"],
                components["unet_session"],
                {
                    "force_zeros_for_empty_prompt": True,
                    "requires_aesthetics_score": False,
                },
                components["tokenizer"],
                scheduler,
                vae_encoder_session=components.get("vae_encoder_session", None),
                text_encoder_2_session=components.get("text_encoder_2_session", None),
                tokenizer_2=components.get("tokenizer_2", None),
                add_watermarker=False,  # not so invisible: https://github.com/ssube/onnx-web/issues/438
            )
        else:
            if params.is_control():
                if "controlnet" not in components or components["controlnet"] is None:
                    raise ValueError("ControlNet is required for control pipelines")

                logger.debug(
                    "assembling SD pipeline for %s with ControlNet",
                    pipeline_class.__name__,
                )
                pipe = pipeline_class(
                    components["vae_encoder"],
                    components["vae_decoder"],
                    components["text_encoder"],
                    components["tokenizer"],
                    components["unet"],
                    components["controlnet"],
                    scheduler,
                    None,
                    None,
                    requires_safety_checker=False,
                )
            elif "vae" in components:
                # upscale uses a single VAE
                logger.debug(
                    "assembling SD pipeline for %s with single VAE",
                    pipeline_class.__name__,
                )
                pipe = pipeline_class(
                    components["vae"],
                    components["text_encoder"],
                    components["tokenizer"],
                    components["unet"],
                    scheduler,
                    scheduler,
                )
            else:
                logger.debug(
                    "assembling SD pipeline for %s with VAE codec",
                    pipeline_class.__name__,
                )
                pipe = pipeline_class(
                    components["vae_encoder"],
                    components["vae_decoder"],
                    components["text_encoder"],
                    components["tokenizer"],
                    components["unet"],
                    scheduler,
                    None,
                    None,
                    requires_safety_checker=False,
                )

        if not server.show_progress:
            pipe.set_progress_bar_config(disable=True)

        optimize_pipeline(server, pipe)
        patch_pipeline(server, pipe, pipeline_class, params)

        server.cache.set(ModelTypes.diffusion, pipe_key, pipe)
        server.cache.set(ModelTypes.scheduler, scheduler_key, scheduler)

    for vae in VAE_COMPONENTS:
        if hasattr(pipe, vae):
            vae_model = getattr(pipe, vae)
            if isinstance(vae_model, VAEWrapper):
                vae_model.set_tiled(tiled=params.tiled_vae)
                vae_model.set_window_size(
                    params.vae_tile // LATENT_FACTOR, params.vae_overlap
                )

    # update panorama params
    if params.is_panorama():
        unet_stride = (params.unet_tile * (1 - params.unet_overlap)) // LATENT_FACTOR
        logger.debug(
            "setting panorama window parameters: %s/%s for UNet, %s/%s for VAE",
            params.unet_tile,
            unet_stride,
            params.vae_tile,
            params.vae_overlap,
        )
        pipe.set_window_size(params.unet_tile // LATENT_FACTOR, unet_stride)

    run_gc([device])

    return pipe


def load_controlnet(server: ServerContext, device: DeviceParams, params: ImageParams):
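    """Load the ControlNet model named by the params from the server's model path."""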
    cnet_path = path.join(server.model_path, "control", f"{params.control.name}.onnx")
    logger.debug("loading ControlNet weights from %s", cnet_path)
    components = {}
    components["controlnet"] = OnnxRuntimeModel(
        OnnxRuntimeModel.load_model(
            cnet_path,
            provider=device.ort_provider("controlnet"),
            sess_options=device.sess_options(),
        )
    )
    return components


def load_text_encoders(
    server: ServerContext,
    device: DeviceParams,
    model: str,
    embeddings: Optional[List[Tuple[str, float]]],
    loras: Optional[List[Tuple[str, float]]],
    torch_dtype,
    params: ImageParams,
):
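    """Load the text encoder(s) and tokenizer(s) for the model.

    Textual Inversion embeddings and LoRA networks are blended into the
    encoder weights before the ONNX sessions are created. SDXL models get a
    second encoder/tokenizer pair and raw InferenceSessions; SD models get a
    single OnnxRuntimeModel for the text encoder.
    """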
    text_encoder = load_model(path.join(model, "text_encoder", ONNX_MODEL))
    tokenizer = CLIPTokenizer.from_pretrained(
        model,
        subfolder="tokenizer",
        torch_dtype=torch_dtype,
    )

    components = {
        "tokenizer": tokenizer,
    }

    if params.is_xl():
        text_encoder_2 = load_model(path.join(model, "text_encoder_2", ONNX_MODEL))
        tokenizer_2 = CLIPTokenizer.from_pretrained(
            model,
            subfolder="tokenizer_2",
            torch_dtype=torch_dtype,
        )
        components["tokenizer_2"] = tokenizer_2

    # blend embeddings, if any
    if embeddings is not None and len(embeddings) > 0:
        embedding_names, embedding_weights = zip(*embeddings)
        embedding_models = [
            path.join(server.model_path, "inversion", name) for name in embedding_names
        ]
        logger.debug(
            "blending base model %s with embeddings from %s", model, embedding_models
        )

        # TODO: blend text_encoder_2 as well
        text_encoder, tokenizer = blend_textual_inversions(
            server,
            text_encoder,
            tokenizer,
            list(
                zip(
                    embedding_models,
                    embedding_weights,
                    embedding_names,
                    [None] * len(embedding_models),
                )
            ),
        )
        components["tokenizer"] = tokenizer

        if params.is_xl():
            text_encoder_2, tokenizer_2 = blend_textual_inversions(
                server,
                text_encoder_2,
                tokenizer_2,
                list(
                    zip(
                        embedding_models,
                        embedding_weights,
                        embedding_names,
                        [None] * len(embedding_models),
                    )
                ),
            )
            components["tokenizer_2"] = tokenizer_2

    # blend LoRAs, if any
    if loras is not None and len(loras) > 0:
        lora_names, lora_weights = zip(*loras)
        lora_models = [
            path.join(server.model_path, "lora", name) for name in lora_names
        ]
        logger.info("blending base model %s with LoRAs from %s", model, lora_models)

        # blend and load text encoder
        text_encoder = blend_loras(
            server,
            text_encoder,
            list(zip(lora_models, lora_weights)),
            "text_encoder",
            1 if params.is_xl() else None,
            params.is_xl(),
        )

        if params.is_xl():
            text_encoder_2 = blend_loras(
                server,
                text_encoder_2,
                list(zip(lora_models, lora_weights)),
                "text_encoder",
                2,
                params.is_xl(),
            )

    # prepare external data for sessions
    (text_encoder, text_encoder_data) = buffer_external_data_tensors(text_encoder)
    text_encoder_names, text_encoder_values = zip(*text_encoder_data)
    text_encoder_opts = device.sess_options(cache=False)
    text_encoder_opts.add_external_initializers(
        list(text_encoder_names), list(text_encoder_values)
    )

    if params.is_xl():
        # encoder 2 only exists in XL
        (text_encoder_2, text_encoder_2_data) = buffer_external_data_tensors(
            text_encoder_2
        )
        text_encoder_2_names, text_encoder_2_values = zip(*text_encoder_2_data)
        text_encoder_2_opts = device.sess_options(cache=False)
        text_encoder_2_opts.add_external_initializers(
            list(text_encoder_2_names), list(text_encoder_2_values)
        )

        # session for te1
        text_encoder_session = InferenceSession(
            text_encoder.SerializeToString(),
            providers=[device.ort_provider("text-encoder", "sdxl")],
            sess_options=text_encoder_opts,
        )
        text_encoder_session._model_path = path.join(model, "text_encoder")
        components["text_encoder_session"] = text_encoder_session

        # session for te2
        text_encoder_2_session = InferenceSession(
            text_encoder_2.SerializeToString(),
            providers=[device.ort_provider("text-encoder", "sdxl")],
            sess_options=text_encoder_2_opts,
        )
        text_encoder_2_session._model_path = path.join(model, "text_encoder_2")
        components["text_encoder_2_session"] = text_encoder_2_session
    else:
        # session for te
        components["text_encoder"] = OnnxRuntimeModel(
            OnnxRuntimeModel.load_model(
                text_encoder.SerializeToString(),
                provider=device.ort_provider("text-encoder"),
                sess_options=text_encoder_opts,
            )
        )

    return components


def load_unet(
    server: ServerContext,
    device: DeviceParams,
    model: str,
    loras: List[Tuple[str, float]],
    unet_type: Literal["cnet", "unet"],
    params: ImageParams,
):
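    """Load the UNet for the model, using the "cnet" variant for ControlNet.

    LoRA networks, if any, are blended into the weights before the ONNX
    session is created.
    """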
    components = {}
    unet = load_model(path.join(model, unet_type, ONNX_MODEL))

    # LoRA blending
    if loras is not None and len(loras) > 0:
        lora_names, lora_weights = zip(*loras)
        lora_models = [
            path.join(server.model_path, "lora", name) for name in lora_names
        ]
        logger.info("blending base model %s with LoRA models: %s", model, lora_models)

        # blend and load unet
        unet = blend_loras(
            server,
            unet,
            list(zip(lora_models, lora_weights)),
            "unet",
            xl=params.is_xl(),
        )

    (unet_model, unet_data) = buffer_external_data_tensors(unet)
    unet_names, unet_values = zip(*unet_data)
    unet_opts = device.sess_options(cache=False)
    unet_opts.add_external_initializers(list(unet_names), list(unet_values))

    if params.is_xl():
        unet_session = InferenceSession(
            unet_model.SerializeToString(),
            providers=[device.ort_provider("unet", "sdxl")],
            sess_options=unet_opts,
        )
        unet_session._model_path = path.join(model, "unet")
        components["unet_session"] = unet_session
    else:
        components["unet"] = OnnxRuntimeModel(
            OnnxRuntimeModel.load_model(
                unet_model.SerializeToString(),
                provider=device.ort_provider("unet"),
                sess_options=unet_opts,
            )
        )

    return components


def load_vae(
    _server: ServerContext, device: DeviceParams, model: str, params: ImageParams
):
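    """Load the VAE component(s) for the model.

    Prefers a single combined VAE when one exists (used by the upscale
    pipeline), otherwise loads separate encoder and decoder models.
    """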
    # one or more VAE models need to be loaded
    vae = path.join(model, "vae", ONNX_MODEL)
    vae_decoder = path.join(model, "vae_decoder", ONNX_MODEL)
    vae_encoder = path.join(model, "vae_encoder", ONNX_MODEL)

    components = {}
    if not params.is_xl() and path.exists(vae):
        logger.debug("loading VAE from %s", vae)
        components["vae"] = OnnxRuntimeModel(
            OnnxRuntimeModel.load_model(
                vae,
                provider=device.ort_provider("vae"),
                sess_options=device.sess_options(),
            )
        )
    elif path.exists(vae_decoder) and path.exists(vae_encoder):
        if params.is_xl():
            logger.debug("loading VAE decoder from %s", vae_decoder)
            components["vae_decoder_session"] = OnnxRuntimeModel.load_model(
                vae_decoder,
                provider=device.ort_provider("vae", "sdxl"),
                sess_options=device.sess_options(),
            )
            components["vae_decoder_session"]._model_path = vae_decoder

            logger.debug("loading VAE encoder from %s", vae_encoder)
            components["vae_encoder_session"] = OnnxRuntimeModel.load_model(
                vae_encoder,
                provider=device.ort_provider("vae", "sdxl"),
                sess_options=device.sess_options(),
            )
            components["vae_encoder_session"]._model_path = vae_encoder
        else:
            logger.debug("loading VAE decoder from %s", vae_decoder)
            components["vae_decoder"] = OnnxRuntimeModel(
                OnnxRuntimeModel.load_model(
                    vae_decoder,
                    provider=device.ort_provider("vae"),
                    sess_options=device.sess_options(),
                )
            )

            logger.debug("loading VAE encoder from %s", vae_encoder)
            components["vae_encoder"] = OnnxRuntimeModel(
                OnnxRuntimeModel.load_model(
                    vae_encoder,
                    provider=device.ort_provider("vae"),
                    sess_options=device.sess_options(),
                )
            )

    return components


def optimize_pipeline(
    server: ServerContext,
    pipe: StableDiffusionPipeline,
) -> None:
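    """Apply any optimizations enabled in the server config to the pipeline.

    Each optimization is attempted independently; failures are logged as
    warnings rather than raised, since not every pipeline supports every
    optimization.
    """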
    if server.has_optimization(
        "diffusers-attention-slicing"
    ) or server.has_optimization("diffusers-attention-slicing-auto"):
        logger.debug("enabling auto attention slicing on SD pipeline")
        try:
            pipe.enable_attention_slicing(slice_size="auto")
        except Exception as e:
            logger.warning("error while enabling auto attention slicing: %s", e)

    if server.has_optimization("diffusers-attention-slicing-max"):
        logger.debug("enabling max attention slicing on SD pipeline")
        try:
            pipe.enable_attention_slicing(slice_size="max")
        except Exception as e:
            logger.warning("error while enabling max attention slicing: %s", e)

    if server.has_optimization("diffusers-vae-slicing"):
        logger.debug("enabling VAE slicing on SD pipeline")
        try:
            pipe.enable_vae_slicing()
        except Exception as e:
            logger.warning("error while enabling VAE slicing: %s", e)

    if server.has_optimization("diffusers-cpu-offload-sequential"):
        logger.debug("enabling sequential CPU offload on SD pipeline")
        try:
            pipe.enable_sequential_cpu_offload()
        except Exception as e:
            logger.warning("error while enabling sequential CPU offload: %s", e)
    elif server.has_optimization("diffusers-cpu-offload-model"):
        # TODO: check for accelerate
        logger.debug("enabling model CPU offload on SD pipeline")
        try:
            pipe.enable_model_cpu_offload()
        except Exception as e:
            logger.warning("error while enabling model CPU offload: %s", e)

    if server.has_optimization("diffusers-memory-efficient-attention"):
        # TODO: check for xformers
        logger.debug("enabling memory efficient attention for SD pipeline")
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except Exception as e:
            logger.warning("error while enabling memory efficient attention: %s", e)


def patch_pipeline(
    server: ServerContext,
    pipe: StableDiffusionPipeline,
    pipeline: Any,
    params: ImageParams,
) -> None:
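    """Patch the pipeline with UNet and VAE wrappers for tiling support.

    The prompt encoder is also replaced with the expand_prompt helper for
    non-LPW, non-XL pipelines. Pipelines with a single combined VAE are
    left unpatched, so tiled VAE may not work for them.
    """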
    logger.debug("patching SD pipeline")

    if not params.is_lpw() and not params.is_xl():
        pipe._encode_prompt = expand_prompt.__get__(pipe, pipeline)

    original_unet = pipe.unet
    pipe.unet = UNetWrapper(server, original_unet, params.is_xl())
    logger.debug("patched UNet with wrapper")

    if hasattr(pipe, "vae_decoder"):
        original_decoder = pipe.vae_decoder
        pipe.vae_decoder = VAEWrapper(
            server,
            original_decoder,
            decoder=True,
            window=params.unet_tile,
            overlap=params.vae_overlap,
        )
        logger.debug("patched VAE decoder with wrapper")

    if hasattr(pipe, "vae_encoder"):
        original_encoder = pipe.vae_encoder
        pipe.vae_encoder = VAEWrapper(
            server,
            original_encoder,
            decoder=False,
            window=params.unet_tile,
            overlap=params.vae_overlap,
        )
        logger.debug("patched VAE encoder with wrapper")

    if hasattr(pipe, "vae"):
        logger.warning("not patching single VAE, tiled VAE may not work")