1
0
Fork 0

initial integration of controlnet pipeline

This commit is contained in:
Sean Sube 2023-04-11 19:29:25 -05:00
parent fcc04982b1
commit a3daaf0112
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
16 changed files with 2425 additions and 120 deletions

View File

@ -1,8 +1,10 @@
from .base import ChainPipeline, PipelineStage, StageCallback, StageParams
from .blend_controlnet import blend_controlnet
from .blend_img2img import blend_img2img
from .blend_inpaint import blend_inpaint
from .blend_linear import blend_linear
from .blend_mask import blend_mask
from .blend_pix2pix import blend_pix2pix
from .correct_codeformer import correct_codeformer
from .correct_gfpgan import correct_gfpgan
from .persist_disk import persist_disk
@ -20,10 +22,12 @@ from .upscale_stable_diffusion import upscale_stable_diffusion
from .upscale_swinir import upscale_swinir
CHAIN_STAGES = {
"blend-controlnet": blend_controlnet,
"blend-img2img": blend_img2img,
"blend-inpaint": blend_inpaint,
"blend-linear": blend_linear,
"blend-mask": blend_mask,
"blend-pix2pix": blend_pix2pix,
"correct-codeformer": correct_codeformer,
"correct-gfpgan": correct_gfpgan,
"persist-disk": persist_disk,

View File

@ -0,0 +1,71 @@
from logging import getLogger
from typing import Optional
import numpy as np
import torch
from PIL import Image
from ..diffusers.load import load_pipeline
from ..diffusers.pipelines.controlnet import OnnxStableDiffusionControlNetPipeline
from ..params import ImageParams, StageParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
logger = getLogger(__name__)
def blend_controlnet(
job: WorkerContext,
server: ServerContext,
_stage: StageParams,
params: ImageParams,
source: Image.Image,
*,
callback: Optional[ProgressCallback] = None,
stage_source: Image.Image,
**kwargs,
) -> Image.Image:
params = params.with_args(**kwargs)
source = stage_source or source
logger.info(
"blending image using controlnet, %s steps: %s", params.steps, params.prompt
)
pipe = load_pipeline(
server,
OnnxStableDiffusionControlNetPipeline,
params.model,
params.scheduler,
job.get_device(),
params.lpw,
)
if params.lpw:
logger.debug("using LPW pipeline for img2img")
rng = torch.manual_seed(params.seed)
result = pipe.img2img(
params.prompt,
generator=rng,
guidance_scale=params.cfg,
image=source,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
strength=params.strength,
callback=callback,
)
else:
rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
generator=rng,
guidance_scale=params.cfg,
image=source,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
strength=params.strength,
callback=callback,
)
output = result.images[0]
logger.info("final output image size: %sx%s", output.width, output.height)
return output

View File

View File

@ -0,0 +1,73 @@
from logging import getLogger
from typing import Optional
import numpy as np
import torch
from PIL import Image
from ..diffusers.load import load_pipeline
from ..diffusers.pipelines.pix2pix import OnnxStableDiffusionInstructPix2PixPipeline
from ..params import ImageParams, StageParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
logger = getLogger(__name__)
def blend_pix2pix(
job: WorkerContext,
server: ServerContext,
_stage: StageParams,
params: ImageParams,
source: Image.Image,
*,
callback: Optional[ProgressCallback] = None,
stage_source: Image.Image,
**kwargs,
) -> Image.Image:
params = params.with_args(**kwargs)
source = stage_source or source
logger.info(
"blending image using instruct pix2pix, %s steps: %s",
params.steps,
params.prompt,
)
pipe = load_pipeline(
server,
OnnxStableDiffusionInstructPix2PixPipeline,
params.model,
params.scheduler,
job.get_device(),
params.lpw,
)
if params.lpw:
logger.debug("using LPW pipeline for img2img")
rng = torch.manual_seed(params.seed)
result = pipe.img2img(
params.prompt,
generator=rng,
guidance_scale=params.cfg,
image=source,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
strength=params.strength,
callback=callback,
)
else:
rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
generator=rng,
guidance_scale=params.cfg,
image=source,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
strength=params.strength,
callback=callback,
)
output = result.images[0]
logger.info("final output image size: %sx%s", output.width, output.height)
return output

View File

@ -7,9 +7,7 @@ from diffusers import StableDiffusionUpscalePipeline
from PIL import Image
from ..diffusers.load import optimize_pipeline, patch_pipeline
from ..diffusers.pipeline_onnx_stable_diffusion_upscale import (
OnnxStableDiffusionUpscalePipeline,
)
from ..diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..server import ServerContext
from ..utils import run_gc

View File

@ -12,6 +12,8 @@ from onnx import load_model, save_model
from transformers import CLIPTokenizer
from yaml import safe_load
from onnx_web.convert.diffusion.control import convert_diffusion_control
from ..constants import ONNX_MODEL, ONNX_WEIGHTS
from .correction.gfpgan import convert_correction_gfpgan
from .diffusion.diffusers import convert_diffusion_diffusers
@ -258,6 +260,19 @@ def convert_models(conversion: ConversionContext, args, models: Models):
source = network["source"]
try:
if network_type == "control":
dest = fetch_model(
conversion,
name,
source,
format=network_format,
)
convert_diffusion_control(
conversion,
network,
dest,
)
if network_type == "inversion" and network_model == "concept":
dest = fetch_model(
conversion,

View File

@ -0,0 +1,78 @@
from logging import getLogger
from os import path
from pathlib import Path
from typing import Dict
import torch
from diffusers.models.controlnet import ControlNetModel
from diffusers.models.cross_attention import CrossAttnProcessor
from ...constants import ONNX_MODEL
from ..utils import ConversionContext, is_torch_2_0, onnx_export
logger = getLogger(__name__)
@torch.no_grad()
def convert_diffusion_control(
conversion: ConversionContext,
model: Dict,
source: str,
model_path: str,
output_path: str,
opset: int,
attention_slicing: str,
):
name = model.get("name")
source = source or model.get("source")
device = conversion.training_device
dtype = conversion.torch_dtype()
logger.debug("using Torch dtype %s for ControlNet", dtype)
output_path = Path(output_path)
logger.info("converting ControlNet model %s: %s -> %s", name, source, output_path)
if path.exists(output_path):
logger.info("ONNX model already exists, skipping")
return
controlnet = ControlNetModel.from_pretrained(model_path, torch_dtype=dtype)
if attention_slicing is not None:
logger.info("enabling attention slicing for ControlNet")
controlnet.set_attention_slice(attention_slicing)
# UNET
if is_torch_2_0:
controlnet.set_attn_processor(CrossAttnProcessor())
cnet_path = output_path / "cnet" / ONNX_MODEL
onnx_export(
controlnet,
model_args=(
torch.randn(2, 4, 64, 64).to(device=device, dtype=dtype),
torch.randn(2).to(device=device, dtype=dtype),
torch.randn(2, 77, 768).to(device=device, dtype=dtype),
torch.randn(2, 3, 512, 512).to(device=device, dtype=dtype),
),
output_path=cnet_path,
ordered_input_names=[
"sample",
"timestep",
"encoder_hidden_states",
"controlnet_cond",
"return_dict",
],
output_names=[
"down_block_res_samples",
"mid_block_res_sample",
], # has to be different from "sample" for correct tracing
dynamic_axes={
"sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
"timestep": {0: "batch"},
"encoder_hidden_states": {0: "batch", 1: "sequence"},
"controlnet_cond": {0: "batch", 2: "height", 3: "width"},
},
opset=opset,
)
logger.info("ONNX ControlNet saved to %s", output_path)

View File

@ -4,7 +4,7 @@
#
# Originally by https://github.com/huggingface
# Those portions *are not* covered by the MIT licensed used for the rest of the onnx-web project.
#
# ...diffusers.pipelines.pipeline_onnx_stable_diffusion_upscale
# HuggingFace code used under the Apache License, Version 2.0
# https://github.com/huggingface/diffusers/blob/main/LICENSE
###
@ -24,71 +24,15 @@ from diffusers import (
)
from diffusers.models.cross_attention import CrossAttnProcessor
from onnx import load_model, save_model
from onnx.shape_inference import infer_shapes_path
from onnxruntime.transformers.float16 import convert_float_to_float16
from packaging import version
from torch.onnx import export
from ...constants import ONNX_MODEL, ONNX_WEIGHTS
from ...diffusers.load import optimize_pipeline
from ...diffusers.pipeline_onnx_stable_diffusion_upscale import (
OnnxStableDiffusionUpscalePipeline,
)
from ..utils import ConversionContext
from ...diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
from ...models.cnet import UNet2DConditionModel_CNet
from ..utils import ConversionContext, is_torch_2_0, onnx_export
logger = getLogger(__name__)
is_torch_2_0 = version.parse(
version.parse(torch.__version__).base_version
) >= version.parse("2.0")
def onnx_export(
model,
model_args: tuple,
output_path: Path,
ordered_input_names,
output_names,
dynamic_axes,
opset,
half=False,
external_data=False,
):
"""
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
"""
output_path.parent.mkdir(parents=True, exist_ok=True)
output_file = output_path.absolute().as_posix()
export(
model,
model_args,
f=output_file,
input_names=ordered_input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
do_constant_folding=True,
opset_version=opset,
)
if half:
logger.info("converting model to fp16 internally: %s", output_file)
infer_shapes_path(output_file)
base_model = load_model(output_file)
opt_model = convert_float_to_float16(
base_model,
disable_shape_infer=True,
keep_io_types=True,
force_fp16_initializers=True,
)
save_model(
opt_model,
f"{output_file}",
save_as_external_data=external_data,
all_tensors_to_one_file=True,
location=ONNX_WEIGHTS,
)
@torch.no_grad()
def convert_diffusion_diffusers(
@ -104,6 +48,7 @@ def convert_diffusion_diffusers(
single_vae = model.get("single_vae")
replace_vae = model.get("vae")
device = conversion.training_device
dtype = conversion.torch_dtype()
logger.debug("using Torch dtype %s for pipeline", dtype)
@ -119,6 +64,7 @@ def convert_diffusion_diffusers(
logger.info("converting model with single VAE")
if path.exists(dest_path) and path.exists(model_index):
# TODO: check if CNet has been converted
logger.info("ONNX model already exists, skipping")
return (False, dest_path)
@ -126,7 +72,7 @@ def convert_diffusion_diffusers(
source,
torch_dtype=dtype,
use_auth_token=conversion.token,
).to(conversion.training_device)
).to(device)
output_path = Path(dest_path)
optimize_pipeline(conversion, pipeline)
@ -145,13 +91,11 @@ def convert_diffusion_diffusers(
pipeline.text_encoder,
# casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
model_args=(
text_input.input_ids.to(
device=conversion.training_device, dtype=torch.int32
),
text_input.input_ids.to(device=device, dtype=torch.int32),
None, # attention mask
None, # position ids
None, # output attentions
torch.tensor(True).to(device=conversion.training_device, dtype=torch.bool),
torch.tensor(True).to(device=device, dtype=torch.bool),
),
output_path=output_path / "text_encoder" / ONNX_MODEL,
ordered_input_names=["input_ids"],
@ -169,14 +113,10 @@ def convert_diffusion_diffusers(
# UNET
if single_vae:
unet_inputs = ["sample", "timestep", "encoder_hidden_states", "class_labels"]
unet_scale = torch.tensor(4).to(
device=conversion.training_device, dtype=torch.long
)
unet_scale = torch.tensor(4).to(device=device, dtype=torch.long)
else:
unet_inputs = ["sample", "timestep", "encoder_hidden_states", "return_dict"]
unet_scale = torch.tensor(False).to(
device=conversion.training_device, dtype=torch.bool
)
unet_scale = torch.tensor(False).to(device=device, dtype=torch.bool)
if is_torch_2_0:
pipeline.unet.set_attn_processor(CrossAttnProcessor())
@ -188,12 +128,10 @@ def convert_diffusion_diffusers(
pipeline.unet,
model_args=(
torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(
device=conversion.training_device, dtype=dtype
),
torch.randn(2).to(device=conversion.training_device, dtype=dtype),
torch.randn(2, num_tokens, text_hidden_size).to(
device=conversion.training_device, dtype=dtype
device=device, dtype=dtype
),
torch.randn(2).to(device=device, dtype=dtype),
torch.randn(2, num_tokens, text_hidden_size).to(device=device, dtype=dtype),
unet_scale,
),
output_path=unet_path,
@ -212,9 +150,11 @@ def convert_diffusion_diffusers(
unet_model_path = str(unet_path.absolute().as_posix())
unet_dir = path.dirname(unet_model_path)
unet = load_model(unet_model_path)
# clean up existing tensor files
rmtree(unet_dir)
mkdir(unet_dir)
# collate external tensor files into one
save_model(
unet,
@ -226,6 +166,124 @@ def convert_diffusion_diffusers(
)
del pipeline.unet
# CNet
pipe_cnet = UNet2DConditionModel_CNet.from_pretrained(source, subfolder="unet")
cnet_path = output_path / "cnet" / ONNX_MODEL
onnx_export(
pipe_cnet,
model_args=(
torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(
device=device, dtype=dtype
),
torch.randn(2).to(device=device, dtype=dtype),
torch.randn(2, num_tokens, text_hidden_size).to(device=device, dtype=dtype),
torch.randn(2, 320, unet_sample_size, unet_sample_size).to(
device=device, dtype=dtype
),
torch.randn(2, 320, unet_sample_size, unet_sample_size).to(
device=device, dtype=dtype
),
torch.randn(2, 320, unet_sample_size, unet_sample_size).to(
device=device, dtype=dtype
),
torch.randn(2, 320, unet_sample_size // 2, unet_sample_size // 2).to(
device=device, dtype=dtype
),
torch.randn(2, 640, unet_sample_size // 2, unet_sample_size // 2).to(
device=device, dtype=dtype
),
torch.randn(2, 640, unet_sample_size // 2, unet_sample_size // 2).to(
device=device, dtype=dtype
),
torch.randn(2, 640, unet_sample_size // 4, unet_sample_size // 4).to(
device=device, dtype=dtype
),
torch.randn(2, 1280, unet_sample_size // 4, unet_sample_size // 4).to(
device=device, dtype=dtype
),
torch.randn(2, 1280, unet_sample_size // 4, unet_sample_size // 4).to(
device=device, dtype=dtype
),
torch.randn(2, 1280, unet_sample_size // 8, unet_sample_size // 8).to(
device=device, dtype=dtype
),
torch.randn(2, 1280, unet_sample_size // 8, unet_sample_size // 8).to(
device=device, dtype=dtype
),
torch.randn(2, 1280, unet_sample_size // 8, unet_sample_size // 8).to(
device=device, dtype=dtype
),
torch.randn(2, 1280, unet_sample_size // 8, unet_sample_size // 8).to(
device=device, dtype=dtype
),
False,
),
output_path=cnet_path,
ordered_input_names=[
"sample",
"timestep",
"encoder_hidden_states",
"down_block_0",
"down_block_1",
"down_block_2",
"down_block_3",
"down_block_4",
"down_block_5",
"down_block_6",
"down_block_7",
"down_block_8",
"down_block_9",
"down_block_10",
"down_block_11",
"mid_block_additional_residual",
"return_dict",
],
output_names=[
"out_sample"
], # has to be different from "sample" for correct tracing
dynamic_axes={
"sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
"timestep": {0: "batch"},
"encoder_hidden_states": {0: "batch", 1: "sequence"},
"down_block_0": {0: "batch", 2: "height", 3: "width"},
"down_block_1": {0: "batch", 2: "height", 3: "width"},
"down_block_2": {0: "batch", 2: "height", 3: "width"},
"down_block_3": {0: "batch", 2: "height2", 3: "width2"},
"down_block_4": {0: "batch", 2: "height2", 3: "width2"},
"down_block_5": {0: "batch", 2: "height2", 3: "width2"},
"down_block_6": {0: "batch", 2: "height4", 3: "width4"},
"down_block_7": {0: "batch", 2: "height4", 3: "width4"},
"down_block_8": {0: "batch", 2: "height4", 3: "width4"},
"down_block_9": {0: "batch", 2: "height8", 3: "width8"},
"down_block_10": {0: "batch", 2: "height8", 3: "width8"},
"down_block_11": {0: "batch", 2: "height8", 3: "width8"},
"mid_block_additional_residual": {0: "batch", 2: "height8", 3: "width8"},
},
opset=conversion.opset,
half=conversion.half,
external_data=True, # UNet is > 2GB, so the weights need to be split
)
cnet_model_path = str(cnet_path.absolute().as_posix())
cnet_dir = path.dirname(cnet_model_path)
cnet = load_model(cnet_model_path)
# clean up existing tensor files
rmtree(cnet_dir)
mkdir(cnet_dir)
# collate external tensor files into one
save_model(
cnet,
cnet_model_path,
save_as_external_data=True,
all_tensors_to_one_file=True,
location=ONNX_WEIGHTS,
convert_attribute=False,
)
del pipe_cnet
# VAE
if replace_vae is not None:
logger.debug("loading custom VAE: %s", replace_vae)
vae = AutoencoderKL.from_pretrained(replace_vae)
@ -244,7 +302,7 @@ def convert_diffusion_diffusers(
model_args=(
torch.randn(
1, vae_latent_channels, unet_sample_size, unet_sample_size
).to(device=conversion.training_device, dtype=dtype),
).to(device=device, dtype=dtype),
False,
),
output_path=output_path / "vae" / ONNX_MODEL,
@ -269,7 +327,7 @@ def convert_diffusion_diffusers(
vae_encoder,
model_args=(
torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(
device=conversion.training_device, dtype=dtype
device=device, dtype=dtype
),
False,
),
@ -293,7 +351,7 @@ def convert_diffusion_diffusers(
model_args=(
torch.randn(
1, vae_latent_channels, unet_sample_size, unet_sample_size
).to(device=conversion.training_device, dtype=dtype),
).to(device=device, dtype=dtype),
False,
),
output_path=output_path / "vae_decoder" / ONNX_MODEL,

View File

@ -9,12 +9,22 @@ import requests
import safetensors
import torch
from huggingface_hub.utils.tqdm import tqdm
from onnx import load_model, save_model
from onnx.shape_inference import infer_shapes_path
from onnxruntime.transformers.float16 import convert_float_to_float16
from packaging import version
from torch.onnx import export
from yaml import safe_load
from ..constants import ONNX_WEIGHTS
from ..server import ServerContext
logger = getLogger(__name__)
is_torch_2_0 = version.parse(
version.parse(torch.__version__).base_version
) >= version.parse("2.0")
ModelDict = Dict[str, Union[str, int]]
LegacyModel = Tuple[str, str, Optional[bool], Optional[bool], Optional[int]]
@ -263,3 +273,50 @@ def load_tensor(name: str, map_location=None) -> Optional[Dict]:
checkpoint = checkpoint["state_dict"]
return checkpoint
def onnx_export(
model,
model_args: tuple,
output_path: Path,
ordered_input_names,
output_names,
dynamic_axes,
opset,
half=False,
external_data=False,
):
"""
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
"""
output_path.parent.mkdir(parents=True, exist_ok=True)
output_file = output_path.absolute().as_posix()
export(
model,
model_args,
f=output_file,
input_names=ordered_input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
do_constant_folding=True,
opset_version=opset,
)
if half:
logger.info("converting model to fp16 internally: %s", output_file)
infer_shapes_path(output_file)
base_model = load_model(output_file)
opt_model = convert_float_to_float16(
base_model,
disable_shape_infer=True,
keep_io_types=True,
force_fp16_initializers=True,
)
save_model(
opt_model,
f"{output_file}",
save_as_external_data=external_data,
all_tensors_to_one_file=True,
location=ONNX_WEIGHTS,
)

View File

@ -0,0 +1,532 @@
# Copyright 2023 The HuggingFace Team.
# Converted for use with ONNX as part of https://github.com/Amblyopius/Stable-Diffusion-ONNX-FP16
# Special thanks to https://github.com/uchuusen for the initial conversion effort
import inspect
from typing import Callable, List, Optional, Union
import numpy as np
import PIL
import torch
from diffusers.configuration_utils import FrozenDict
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
from transformers import CLIPFeatureExtractor, CLIPTokenizer
logger = logging.get_logger(__name__)
class OnnxStableDiffusionControlNetPipeline(DiffusionPipeline):
vae_encoder: OnnxRuntimeModel
vae_decoder: OnnxRuntimeModel
text_encoder: OnnxRuntimeModel
tokenizer: CLIPTokenizer
unet: OnnxRuntimeModel
controlnet: OnnxRuntimeModel
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
safety_checker: OnnxRuntimeModel
feature_extractor: CLIPFeatureExtractor
_optional_components = ["safety_checker", "feature_extractor"]
def __init__(
self,
vae_encoder: OnnxRuntimeModel,
vae_decoder: OnnxRuntimeModel,
text_encoder: OnnxRuntimeModel,
tokenizer: CLIPTokenizer,
unet: OnnxRuntimeModel,
controlnet: OnnxRuntimeModel,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
safety_checker: OnnxRuntimeModel,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
):
super().__init__()
if (
hasattr(scheduler.config, "steps_offset")
and scheduler.config.steps_offset != 1
):
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
" file"
)
deprecate(
"steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False
)
new_config = dict(scheduler.config)
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if (
hasattr(scheduler.config, "clip_sample")
and scheduler.config.clip_sample is True
):
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
)
deprecate(
"clip_sample not set", "1.0.0", deprecation_message, standard_warn=False
)
new_config = dict(scheduler.config)
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None and requires_safety_checker:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
self.register_modules(
vae_encoder=vae_encoder,
vae_decoder=vae_decoder,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
controlnet=controlnet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.register_to_config(requires_safety_checker=requires_safety_checker)
def _default_height_width(self, height, width, image):
if isinstance(image, list):
image = image[0]
if height is None:
if isinstance(image, PIL.Image.Image):
height = image.height
elif isinstance(image, np.ndarray):
height = image.shape[3]
height = (height // 8) * 8 # round down to nearest multiple of 8
if width is None:
if isinstance(image, PIL.Image.Image):
width = image.width
elif isinstance(image, np.ndarray):
width = image.shape[2]
width = (width // 8) * 8 # round down to nearest multiple of 8
return height, width
def prepare_image(
self, image, width, height, batch_size, num_images_per_prompt, dtype
):
if not isinstance(image, np.ndarray):
if isinstance(image, PIL.Image.Image):
image = [image]
if isinstance(image[0], PIL.Image.Image):
image = [
np.array(
i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
)[None, :]
for i in image
]
image = np.concatenate(image, axis=0)
image = np.array(image).astype(np.float32) / 255.0
image = image.transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
elif isinstance(image[0], np.ndarray):
image = np.concatenate(image, axis=0)
image = torch.from_numpy(image)
image_batch_size = image.shape[0]
if image_batch_size == 1:
repeat_by = batch_size
else:
# image batch size is the same as prompt batch size
repeat_by = num_images_per_prompt
image = image.repeat_interleave(repeat_by, dim=0)
return image
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(
self,
batch_size,
num_channels_latents,
height,
width,
dtype,
generator,
latents=None,
):
shape = (batch_size, num_channels_latents, height // 8, width // 8)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = generator.randn(*shape).astype(dtype)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta, torch_gen):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(
inspect.signature(self.scheduler.step).parameters.keys()
)
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(
inspect.signature(self.scheduler.step).parameters.keys()
)
if accepts_generator:
extra_step_kwargs["generator"] = torch_gen
return extra_step_kwargs
def _encode_prompt(
self,
prompt,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`):
prompt to be encoded
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
"""
batch_size = len(prompt) if isinstance(prompt, list) else 1
# get prompt text embeddings
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="np",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(
prompt, padding="max_length", return_tensors="np"
).input_ids
if not np.array_equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt] * batch_size
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
max_length = text_input_ids.shape[-1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="np",
)
negative_prompt_embeds = self.text_encoder(
input_ids=uncond_input.input_ids.astype(np.int32)
)[0]
negative_prompt_embeds = np.repeat(
negative_prompt_embeds, num_images_per_prompt, axis=0
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds])
return prompt_embeds
def __call__(
self,
prompt: Union[str, List[str]],
image: Union[np.ndarray, PIL.Image.Image] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: Optional[float] = 0.0,
generator: Optional[np.random.RandomState] = None,
latents: Optional[np.ndarray] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
callback_steps: int = 1,
controlnet_conditioning_scale: float = 1.0,
):
if isinstance(prompt, str):
batch_size = 1
elif isinstance(prompt, list):
batch_size = len(prompt)
else:
raise ValueError(
f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
)
if generator:
torch_seed = generator.randint(2147483647)
torch_gen = torch.Generator().manual_seed(torch_seed)
else:
generator = np.random
torch_gen = None
height, width = self._default_height_width(height, width, image)
if height % 8 != 0 or width % 8 != 0:
raise ValueError(
f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
)
if (callback_steps is None) or (
callback_steps is not None
and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
prompt_embeds = self._encode_prompt(
prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
)
# 4. Prepare image
image = self.prepare_image(
image,
width,
height,
batch_size * num_images_per_prompt,
num_images_per_prompt,
np.float32,
).numpy()
if do_classifier_free_guidance:
image = np.concatenate([image] * 2)
# get the initial random noise unless the user supplied it
latents_dtype = prompt_embeds.dtype
num_channels_latents = 4
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
latents_dtype,
generator,
latents,
)
# set timesteps
self.scheduler.set_timesteps(num_inference_steps)
timesteps = self.scheduler.timesteps
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta, torch_gen)
timestep_dtype = next(
(
input.type
for input in self.unet.model.get_inputs()
if input.name == "timestep"
),
"tensor(float)",
)
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = (
np.concatenate([latents] * 2)
if do_classifier_free_guidance
else latents
)
latent_model_input = self.scheduler.scale_model_input(
torch.from_numpy(latent_model_input), t
)
latent_model_input = latent_model_input.cpu().numpy()
timestep = np.array([t], dtype=timestep_dtype)
blocksamples = self.controlnet(
sample=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
controlnet_cond=image,
)
mid_block_res_sample = blocksamples[12]
down_block_res_samples = blocksamples[0:12]
down_block_res_samples = [
down_block_res_sample * controlnet_conditioning_scale
for down_block_res_sample in down_block_res_samples
]
mid_block_res_sample *= controlnet_conditioning_scale
# predict the noise residual
noise_pred = self.unet(
sample=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
down_block_0=down_block_res_samples[0],
down_block_1=down_block_res_samples[1],
down_block_2=down_block_res_samples[2],
down_block_3=down_block_res_samples[3],
down_block_4=down_block_res_samples[4],
down_block_5=down_block_res_samples[5],
down_block_6=down_block_res_samples[6],
down_block_7=down_block_res_samples[7],
down_block_8=down_block_res_samples[8],
down_block_9=down_block_res_samples[9],
down_block_10=down_block_res_samples[10],
down_block_11=down_block_res_samples[11],
mid_block_additional_residual=mid_block_res_sample,
)
noise_pred = noise_pred[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
# compute the previous noisy sample x_t -> x_t-1
scheduler_output = self.scheduler.step(
torch.from_numpy(noise_pred),
t,
torch.from_numpy(latents),
**extra_step_kwargs,
)
latents = scheduler_output.prev_sample.numpy()
# call the callback, if provided
if i == len(timesteps) - 1 or (
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
latents = 1 / 0.18215 * latents
# image = self.vae_decoder(latent_sample=latents)[0]
# it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
image = np.concatenate(
[
self.vae_decoder(latent_sample=latents[i : i + 1])[0]
for i in range(latents.shape[0])
]
)
image = np.clip(image / 2 + 0.5, 0, 1)
image = image.transpose((0, 2, 3, 1))
if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(
self.numpy_to_pil(image), return_tensors="np"
).pixel_values.astype(image.dtype)
images, has_nsfw_concept = [], []
for i in range(image.shape[0]):
image_i, has_nsfw_concept_i = self.safety_checker(
clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
)
images.append(image_i)
has_nsfw_concept.append(has_nsfw_concept_i[0])
image = np.concatenate(images)
else:
has_nsfw_concept = None
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(
images=image, nsfw_content_detected=has_nsfw_concept
)

View File

@ -0,0 +1,606 @@
# Copyright 2023 The InstructPix2Pix Authors and The HuggingFace Team.
# Converted for use with ONNX as part of https://github.com/Amblyopius/Stable-Diffusion-ONNX-FP16
import inspect
from typing import Callable, List, Optional, Union
import numpy as np
import PIL
import torch
from transformers import CLIPFeatureExtractor, CLIPTokenizer
try:
from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE
except ImportError:
ORT_TO_NP_TYPE = {
"tensor(bool)": np.bool_,
"tensor(int8)": np.int8,
"tensor(uint8)": np.uint8,
"tensor(int16)": np.int16,
"tensor(uint16)": np.uint16,
"tensor(int32)": np.int32,
"tensor(uint32)": np.uint32,
"tensor(int64)": np.int64,
"tensor(uint64)": np.uint64,
"tensor(float16)": np.float16,
"tensor(float)": np.float32,
"tensor(double)": np.float64,
}
from diffusers import OnnxRuntimeModel
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.schedulers import (
DDIMScheduler,
KarrasDiffusionSchedulers,
LMSDiscreteScheduler,
PNDMScheduler,
)
from diffusers.utils import PIL_INTERPOLATION, logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Simplified and ONNX specific version (only allows 1 image, np over torch)
def preprocess(image):
if isinstance(image, np.ndarray):
return image
w, h = image.size
w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8
image = np.array(image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[
None, :
]
image = np.array(image).astype(np.float32) / 255.0
image = image.transpose(0, 3, 1, 2)
image = 2.0 * image - 1.0
return image
class OnnxStableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
r"""
Pipeline for pixel-level image editing by following text instructions. Based on Stable Diffusion.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
Frozen text-encoder. Stable Diffusion uses the text portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
vae_encoder: OnnxRuntimeModel
vae_decoder: OnnxRuntimeModel
text_encoder: OnnxRuntimeModel
tokenizer: CLIPTokenizer
unet: OnnxRuntimeModel
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
safety_checker: OnnxRuntimeModel
feature_extractor: CLIPFeatureExtractor
_optional_components = ["safety_checker", "feature_extractor"]
def __init__(
self,
vae_encoder: OnnxRuntimeModel,
vae_decoder: OnnxRuntimeModel,
text_encoder: OnnxRuntimeModel,
tokenizer: CLIPTokenizer,
unet: OnnxRuntimeModel,
scheduler: KarrasDiffusionSchedulers,
safety_checker: OnnxRuntimeModel,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
):
super().__init__()
self.unet_in_channels = 8
self.vae_scale_factor = 8
if safety_checker is None and requires_safety_checker:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
self.register_modules(
vae_encoder=vae_encoder,
vae_decoder=vae_decoder,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
# self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]] = None,
image: Union[np.ndarray, PIL.Image.Image] = None,
num_inference_steps: int = 100,
guidance_scale: float = 7.5,
image_guidance_scale: float = 1.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[np.random.RandomState] = None,
latents: Optional[np.ndarray] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
callback_steps: int = 1,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
image (`PIL.Image.Image`):
`Image`, or tensor representing an image batch which will be repainted according to `prompt`.
num_inference_steps (`int`, *optional*, defaults to 100):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality. This pipeline requires a value of at least `1`.
image_guidance_scale (`float`, *optional*, defaults to 1.5):
Image guidance scale is to push the generated image towards the inital image `image`. Image guidance
scale is enabled by setting `image_guidance_scale > 1`. Higher image guidance scale encourages to
generate images that are closely linked to the source image `image`, usually at the expense of lower
image quality. This pipeline requires a value of at least `1`.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Examples:
```py
>>> import PIL
>>> import requests
>>> import torch
>>> from io import BytesIO
>>> from diffusers import StableDiffusionInstructPix2PixPipeline
>>> def download_image(url):
... response = requests.get(url)
... return PIL.Image.open(BytesIO(response.content)).convert("RGB")
>>> img_url = "https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png"
>>> image = download_image(img_url).resize((512, 512))
>>> pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
... "timbrooks/instruct-pix2pix", torch_dtype=torch.float16
... )
>>> pipe = pipe.to("cuda")
>>> prompt = "make the mountains snowy"
>>> image = pipe(prompt=prompt, image=image).images[0]
```
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
# We need a deterministic torch generator for schedulers if a (likely seeded) generator was provided
if generator:
torch_seed = generator.randint(2147483647)
torch_gen = torch.Generator().manual_seed(torch_seed)
else:
generator = np.random
torch_gen = None
# 0. Check inputs
self.check_inputs(prompt, callback_steps)
if image is None:
raise ValueError("`image` input cannot be undefined.")
# 1. Define call parameters
if isinstance(prompt, str):
batch_size = 1
elif isinstance(prompt, list):
batch_size = len(prompt)
else:
raise ValueError(
f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
)
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = (
guidance_scale > 1.0 and image_guidance_scale >= 1.0
)
# check if scheduler is in sigmas space
scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas")
# 2. Encode input prompt
prompt_embeds = self._encode_prompt(
prompt,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
)
# 3. Preprocess image
image = preprocess(image)
height, width = image.shape[-2:]
# 4. set timesteps
self.scheduler.set_timesteps(num_inference_steps)
timesteps = self.scheduler.timesteps
# 5. Prepare Image latents
latents_dtype = prompt_embeds.dtype
image = image.astype(latents_dtype)
# encode the init image into latents and scale the latents
image_latents = self.vae_encoder(sample=image)[0]
if do_classifier_free_guidance:
uncond_image_latents = np.zeros_like(image_latents)
image_latents = np.concatenate(
(image_latents, image_latents, uncond_image_latents), axis=0
)
# 6. Prepare latent variables
latents_dtype = prompt_embeds.dtype
latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
if latents is None:
latents = generator.randn(*latents_shape).astype(latents_dtype)
elif latents.shape != latents_shape:
raise ValueError(
f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}"
)
latents = latents * self.scheduler.init_noise_sigma.numpy()
# 7. Check that shapes of latents and image match the UNet channels
num_channels_image = image_latents.shape[1]
if 4 + num_channels_image != self.unet_in_channels:
raise ValueError(
f"Incorrect configuration settings! The config of `pipeline.unet`: expects"
f" {self.unet_in_channels} but received `num_channels_latents`: 4 +"
f" `num_channels_image`: {num_channels_image} "
f" = {4+num_channels_image}. Please verify the config of"
" `pipeline.unet` or your `image` input."
)
timestep_dtype = next(
(
input.type
for input in self.unet.model.get_inputs()
if input.name == "timestep"
),
"tensor(float)",
)
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta, torch_gen)
# 9. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# Expand the latents if we are doing classifier free guidance.
# The latents are expanded 3 times because for pix2pix the guidance\
# is applied for both the text and the input image.
latent_model_input = (
np.concatenate([latents] * 3)
if do_classifier_free_guidance
else latents
)
scaled_latent_model_input = self.scheduler.scale_model_input(
torch.from_numpy(latent_model_input), t
)
scaled_latent_model_input = scaled_latent_model_input.cpu().numpy()
scaled_latent_model_input = np.concatenate(
[scaled_latent_model_input, image_latents], axis=1
)
# predict the noise residual
noise_pred = self.unet(
sample=scaled_latent_model_input,
timestep=np.array([t], dtype=timestep_dtype),
encoder_hidden_states=prompt_embeds,
)[0]
# Hack:
# For karras style schedulers the model does classifer free guidance using the
# predicted_original_sample instead of the noise_pred. So we need to compute the
# predicted_original_sample here if we are using a karras style scheduler.
if scheduler_is_in_sigma_space:
step_index = (self.scheduler.timesteps == t).nonzero().item()
sigma = self.scheduler.sigmas[step_index]
noise_pred = latent_model_input - sigma.numpy() * noise_pred
# perform guidance
if do_classifier_free_guidance:
noise_pred_text, noise_pred_image, noise_pred_uncond = np.split(
noise_pred, 3
)
noise_pred = (
noise_pred_uncond
+ guidance_scale * (noise_pred_text - noise_pred_image)
+ image_guidance_scale * (noise_pred_image - noise_pred_uncond)
)
# Hack:
# For karras style schedulers the model does classifer free guidance using the
# predicted_original_sample instead of the noise_pred. But the scheduler.step function
# expects the noise_pred and computes the predicted_original_sample internally. So we
# need to overwrite the noise_pred here such that the value of the computed
# predicted_original_sample is correct.
if scheduler_is_in_sigma_space:
noise_pred = (noise_pred - latents) / (-sigma)
# compute the previous noisy sample x_t -> x_t-1
scheduler_output = self.scheduler.step(
noise_pred, t, torch.from_numpy(latents), **extra_step_kwargs
)
latents = scheduler_output.prev_sample.numpy()
# call the callback, if provided
if i == len(timesteps) - 1 or (
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents.numpy())
# 10. Post-processing
image = self.decode_latents(latents)
# 11. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image)
# 12. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(
images=image, nsfw_content_detected=has_nsfw_concept
)
def _encode_prompt(
self,
prompt,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`):
prompt to be encoded
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
"""
negative_prompt_embeds = None
batch_size = len(prompt) if isinstance(prompt, list) else 1
# get prompt text embeddings
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="np",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(
prompt, padding="max_length", return_tensors="np"
).input_ids
if not np.array_equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt] * batch_size
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
max_length = text_input_ids.shape[-1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="np",
)
negative_prompt_embeds = self.text_encoder(
input_ids=uncond_input.input_ids.astype(np.int32)
)[0]
negative_prompt_embeds = np.repeat(
negative_prompt_embeds, num_images_per_prompt, axis=0
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
# pix2pix has two negative embeddings, and unlike in other pipelines latents are ordered
# [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]
prompt_embeds = np.concatenate(
(prompt_embeds, negative_prompt_embeds, negative_prompt_embeds)
)
return prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image):
if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(
self.numpy_to_pil(image), return_tensors="np"
).pixel_values.astype(image.dtype)
# safety_checker does not support batched inputs yet
images, has_nsfw_concept = [], []
for i in range(image.shape[0]):
image_i, has_nsfw_concept_i = self.safety_checker(
clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
)
images.append(image_i)
has_nsfw_concept.append(has_nsfw_concept_i[0])
image = np.concatenate(images)
else:
has_nsfw_concept = None
return image, has_nsfw_concept
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta, torch_gen):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(
inspect.signature(self.scheduler.step).parameters.keys()
)
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(
inspect.signature(self.scheduler.step).parameters.keys()
)
if accepts_generator:
extra_step_kwargs["generator"] = torch_gen
return extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
latents = 1 / 0.18215 * latents
image = np.concatenate(
[
self.vae_decoder(latent_sample=latents[i : i + 1])[0]
for i in range(latents.shape[0])
]
)
image = np.clip(image / 2 + 0.5, 0, 1)
image = image.transpose((0, 2, 3, 1))
return image
def check_inputs(self, prompt, callback_steps):
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(
f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
)
if (callback_steps is None) or (
callback_steps is not None
and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)

View File

@ -9,13 +9,12 @@ from logging import getLogger
from typing import Any, Callable, List, Optional, Union
import numpy as np
import torch
import PIL
from diffusers import DDPMScheduler, OnnxRuntimeModel, StableDiffusionUpscalePipeline
import torch
from diffusers.pipeline_utils import ImagePipelineOutput
from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
from diffusers.pipelines.stable_diffusion import StableDiffusionUpscalePipeline
from diffusers.schedulers import DDPMScheduler
logger = getLogger(__name__)
@ -23,21 +22,6 @@ logger = getLogger(__name__)
NUM_LATENT_CHANNELS = 4
NUM_UNET_INPUT_CHANNELS = 7
ORT_TO_NP_TYPE = {
"tensor(bool)": np.bool_,
"tensor(int8)": np.int8,
"tensor(uint8)": np.uint8,
"tensor(int16)": np.int16,
"tensor(uint16)": np.uint16,
"tensor(int32)": np.int32,
"tensor(uint32)": np.uint32,
"tensor(int64)": np.int64,
"tensor(uint64)": np.uint64,
"tensor(float16)": np.float16,
"tensor(float)": np.float32,
"tensor(double)": np.float64,
}
ORT_TO_PT_TYPE = {
"float16": torch.float16,
"float32": torch.float32,
@ -65,7 +49,8 @@ def preprocess(image):
return image
class FakeConfig():
class FakeConfig:
scaling_factor: float
def __init__(self) -> None:
@ -83,10 +68,18 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
scheduler: Any,
max_noise_level: int = 350,
):
if hasattr(vae, "config") == False:
if not hasattr(vae, "config"):
setattr(vae, "config", FakeConfig())
super().__init__(vae, text_encoder, tokenizer, unet, low_res_scheduler, scheduler, max_noise_level)
super().__init__(
vae,
text_encoder,
tokenizer,
unet,
low_res_scheduler,
scheduler,
max_noise_level,
)
def __call__(
self,
@ -118,7 +111,11 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
# 3. Encode input prompt
text_embeddings = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
)
latents_dtype = ORT_TO_PT_TYPE[str(text_embeddings.dtype)]
@ -133,7 +130,9 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
# 5. Add noise to image
noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
noise = torch.randn(image.shape, generator=generator, device=device, dtype=latents_dtype)
noise = torch.randn(
image.shape, generator=generator, device=device, dtype=latents_dtype
)
image = self.low_res_scheduler.add_noise(image, noise, noise_level)
batch_multiplier = 2 if do_classifier_free_guidance else 1
@ -168,7 +167,12 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
timestep_dtype = next(
(input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
(
input.type
for input in self.unet.model.get_inputs()
if input.name == "timestep"
),
"tensor(float)",
)
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
@ -177,10 +181,16 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = (
np.concatenate([latents] * 2)
if do_classifier_free_guidance
else latents
)
# concat latents, mask, masked_image_latents in the channel dimension
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
latent_model_input = self.scheduler.scale_model_input(
latent_model_input, t
)
latent_model_input = np.concatenate([latent_model_input, image], axis=1)
# timestep to tensor
@ -197,7 +207,9 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(
@ -205,7 +217,9 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
).prev_sample
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
if i == len(timesteps) - 1 or (
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
@ -229,7 +243,14 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
image = image.transpose((0, 2, 3, 1))
return image
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
):
batch_size = len(prompt) if isinstance(prompt, list) else 1
text_inputs = self.tokenizer(
@ -240,10 +261,16 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
untruncated_ids = self.tokenizer(
prompt, padding="longest", return_tensors="pt"
).input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
@ -258,7 +285,9 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
bs_embed, seq_len, _ = text_embeddings.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt)
text_embeddings = text_embeddings.reshape(bs_embed * num_images_per_prompt, seq_len, -1)
text_embeddings = text_embeddings.reshape(
bs_embed * num_images_per_prompt, seq_len, -1
)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
@ -298,7 +327,9 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
seq_len = uncond_embeddings.shape[1]
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt)
uncond_embeddings = uncond_embeddings.reshape(batch_size * num_images_per_prompt, seq_len, -1)
uncond_embeddings = uncond_embeddings.reshape(
batch_size * num_images_per_prompt, seq_len, -1
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch

View File

@ -329,7 +329,13 @@ def run_inpaint_pipeline(
)
image = run_upscale_correction(
job, server, stage, params, image, upscale=upscale, callback=progress
job,
server,
stage,
params,
image,
upscale=upscale,
callback=progress,
)
dest = save_image(server, outputs[0], image)

View File

@ -1,10 +1,11 @@
from logging import getLogger
from typing import Optional
from typing import List, Optional
from PIL import Image
from ..chain import (
ChainPipeline,
PipelineStage,
correct_codeformer,
correct_gfpgan,
upscale_bsrgan,
@ -28,6 +29,8 @@ def run_upscale_correction(
*,
upscale: UpscaleParams,
callback: Optional[ProgressCallback] = None,
pre_stages: List[PipelineStage] = None,
post_stages: List[PipelineStage] = None,
) -> Image.Image:
"""
This is a convenience method for a chain pipeline that will run upscaling and
@ -40,6 +43,9 @@ def run_upscale_correction(
)
chain = ChainPipeline()
if pre_stages is not None:
for stage, params in pre_stages:
chain.append((stage, params))
upscale_stage = None
if upscale.scale > 1:
@ -93,6 +99,10 @@ def run_upscale_correction(
else:
logger.warn("unknown upscaling order: %s", upscale.upscale_order)
if post_stages is not None:
for stage, params in post_stages:
chain.append((stage, params))
return chain(
job,
server,

735
api/onnx_web/models/cnet.py Normal file
View File

@ -0,0 +1,735 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.utils.checkpoint
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import UNet2DConditionLoadersMixin
from diffusers.models.cross_attention import AttnProcessor
from diffusers.models.embeddings import (
GaussianFourierProjection,
TimestepEmbedding,
Timesteps,
)
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unet_2d_blocks import (
CrossAttnDownBlock2D,
CrossAttnUpBlock2D,
DownBlock2D,
UNetMidBlock2DCrossAttn,
UNetMidBlock2DSimpleCrossAttn,
UpBlock2D,
get_down_block,
get_up_block,
)
from diffusers.utils import BaseOutput, logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class UNet2DConditionOutput(BaseOutput):
"""
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
"""
sample: torch.FloatTensor
class UNet2DConditionModel_CNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
r"""
UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
and returns sample shaped output.
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
implements for all the models (such as downloading or saving, etc.)
Parameters:
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
Height and width of input/output sample.
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
Whether to flip the sin to cos in the time embedding.
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
The tuple of downsample blocks to use.
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
The mid block type. Choose from `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`, will skip the
mid block layer if `None`.
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
The tuple of upsample blocks to use.
only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
Whether to include self-attention in the basic transformer blocks, see
[`~models.attention.BasicTransformerBlock`].
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
If `None`, it will skip the normalization and activation layers in post-processing
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately
summed with the time embeddings. Choose from `None`, `"timestep"`, `"identity"`, or `"projection"`.
num_class_embeds (`int`, *optional*, defaults to None):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
class conditioning with `class_embed_type` equal to `None`.
time_embedding_type (`str`, *optional*, default to `positional`):
The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
timestep_post_act (`str, *optional*, default to `None`):
The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
time_cond_proj_dim (`int`, *optional*, default to `None`):
The dimension of `cond_proj` layer in timestep embedding.
conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
sample_size: Optional[int] = None,
in_channels: int = 4,
out_channels: int = 4,
center_input_sample: bool = False,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
up_block_types: Tuple[str] = (
"UpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
act_fn: str = "silu",
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
attention_head_dim: Union[int, Tuple[int]] = 8,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
class_embed_type: Optional[str] = None,
num_class_embeds: Optional[int] = None,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
time_embedding_type: str = "positional",
timestep_post_act: Optional[str] = None,
time_cond_proj_dim: Optional[int] = None,
conv_in_kernel: int = 3,
conv_out_kernel: int = 3,
projection_class_embeddings_input_dim: Optional[int] = None,
):
super().__init__()
self.sample_size = sample_size
# Check inputs
if len(down_block_types) != len(up_block_types):
raise ValueError(
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
)
if len(block_out_channels) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
)
if not isinstance(only_cross_attention, bool) and len(
only_cross_attention
) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
)
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(
down_block_types
):
raise ValueError(
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
)
# input
conv_in_padding = (conv_in_kernel - 1) // 2
self.conv_in = nn.Conv2d(
in_channels,
block_out_channels[0],
kernel_size=conv_in_kernel,
padding=conv_in_padding,
)
# time
if time_embedding_type == "fourier":
time_embed_dim = block_out_channels[0] * 2
if time_embed_dim % 2 != 0:
raise ValueError(
f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}."
)
self.time_proj = GaussianFourierProjection(
time_embed_dim // 2,
set_W_to_weight=False,
log=False,
flip_sin_to_cos=flip_sin_to_cos,
)
timestep_input_dim = time_embed_dim
elif time_embedding_type == "positional":
time_embed_dim = block_out_channels[0] * 4
self.time_proj = Timesteps(
block_out_channels[0], flip_sin_to_cos, freq_shift
)
timestep_input_dim = block_out_channels[0]
else:
raise ValueError(
f"{time_embedding_type} does not exist. Pleaes make sure to use one of `fourier` or `positional`."
)
self.time_embedding = TimestepEmbedding(
timestep_input_dim,
time_embed_dim,
act_fn=act_fn,
post_act_fn=timestep_post_act,
cond_proj_dim=time_cond_proj_dim,
)
# class embedding
if class_embed_type is None and num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
elif class_embed_type == "timestep":
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
elif class_embed_type == "identity":
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
elif class_embed_type == "projection":
if projection_class_embeddings_input_dim is None:
raise ValueError(
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
)
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
# 2. it projects from an arbitrary input dimension.
#
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
self.class_embedding = TimestepEmbedding(
projection_class_embeddings_input_dim, time_embed_dim
)
else:
self.class_embedding = None
self.down_blocks = nn.ModuleList([])
self.up_blocks = nn.ModuleList([])
if isinstance(only_cross_attention, bool):
only_cross_attention = [only_cross_attention] * len(down_block_types)
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)
# down
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
down_block = get_down_block(
down_block_type,
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
add_downsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim[i],
downsample_padding=downsample_padding,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
)
self.down_blocks.append(down_block)
# mid
if mid_block_type == "UNetMidBlock2DCrossAttn":
self.mid_block = UNetMidBlock2DCrossAttn(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift=resnet_time_scale_shift,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim[-1],
resnet_groups=norm_num_groups,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
)
elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
self.mid_block = UNetMidBlock2DSimpleCrossAttn(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim[-1],
resnet_groups=norm_num_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
)
elif mid_block_type is None:
self.mid_block = None
else:
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
# count how many layers upsample the images
self.num_upsamplers = 0
# up
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_attention_head_dim = list(reversed(attention_head_dim))
only_cross_attention = list(reversed(only_cross_attention))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
is_final_block = i == len(block_out_channels) - 1
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
input_channel = reversed_block_out_channels[
min(i + 1, len(block_out_channels) - 1)
]
# add upsample block for all BUT final layer
if not is_final_block:
add_upsample = True
self.num_upsamplers += 1
else:
add_upsample = False
up_block = get_up_block(
up_block_type,
num_layers=layers_per_block + 1,
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=time_embed_dim,
add_upsample=add_upsample,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=reversed_attention_head_dim[i],
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
if norm_num_groups is not None:
self.conv_norm_out = nn.GroupNorm(
num_channels=block_out_channels[0],
num_groups=norm_num_groups,
eps=norm_eps,
)
self.conv_act = nn.SiLU()
else:
self.conv_norm_out = None
self.conv_act = None
conv_out_padding = (conv_out_kernel - 1) // 2
self.conv_out = nn.Conv2d(
block_out_channels[0],
out_channels,
kernel_size=conv_out_kernel,
padding=conv_out_padding,
)
@property
def attn_processors(self) -> Dict[str, AttnProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(
name: str, module: torch.nn.Module, processors: Dict[str, AttnProcessor]
):
if hasattr(module, "set_processor"):
processors[f"{name}.processor"] = module.processor
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
def set_attn_processor(
self, processor: Union[AttnProcessor, Dict[str, AttnProcessor]]
):
r"""
Parameters:
`processor (`dict` of `AttnProcessor` or `AttnProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
of **all** `CrossAttention` layers.
In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainablae attention processors.:
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def set_attention_slice(self, slice_size):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
`"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
must be a multiple of `slice_size`.
"""
sliceable_head_dims = []
def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
if hasattr(module, "set_attention_slice"):
sliceable_head_dims.append(module.sliceable_head_dim)
for child in module.children():
fn_recursive_retrieve_slicable_dims(child)
# retrieve number of attention layers
for module in self.children():
fn_recursive_retrieve_slicable_dims(module)
num_slicable_layers = len(sliceable_head_dims)
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = [dim // 2 for dim in sliceable_head_dims]
elif slice_size == "max":
# make smallest slice possible
slice_size = num_slicable_layers * [1]
slice_size = (
num_slicable_layers * [slice_size]
if not isinstance(slice_size, list)
else slice_size
)
if len(slice_size) != len(sliceable_head_dims):
raise ValueError(
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
)
for i in range(len(slice_size)):
size = slice_size[i]
dim = sliceable_head_dims[i]
if size is not None and size > dim:
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
# Recursively walk through all the children.
# Any children which exposes the set_attention_slice method
# gets the message
def fn_recursive_set_attention_slice(
module: torch.nn.Module, slice_size: List[int]
):
if hasattr(module, "set_attention_slice"):
module.set_attention_slice(slice_size.pop())
for child in module.children():
fn_recursive_set_attention_slice(child, slice_size)
reversed_slice_size = list(reversed(slice_size))
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(
module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)
):
module.gradient_checkpointing = value
def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
# class_labels: Optional[torch.Tensor] = None,
# timestep_cond: Optional[torch.Tensor] = None,
# attention_mask: Optional[torch.Tensor] = None,
# cross_attention_kwargs: Optional[Dict[str, Any]] = None,
down_block_additional_residuals0: Optional[torch.Tensor] = None,
down_block_additional_residuals1: Optional[torch.Tensor] = None,
down_block_additional_residuals2: Optional[torch.Tensor] = None,
down_block_additional_residuals3: Optional[torch.Tensor] = None,
down_block_additional_residuals4: Optional[torch.Tensor] = None,
down_block_additional_residuals5: Optional[torch.Tensor] = None,
down_block_additional_residuals6: Optional[torch.Tensor] = None,
down_block_additional_residuals7: Optional[torch.Tensor] = None,
down_block_additional_residuals8: Optional[torch.Tensor] = None,
down_block_additional_residuals9: Optional[torch.Tensor] = None,
down_block_additional_residuals10: Optional[torch.Tensor] = None,
down_block_additional_residuals11: Optional[torch.Tensor] = None,
mid_block_additional_residual: Optional[torch.Tensor] = None,
return_dict: bool = False,
) -> Union[UNet2DConditionOutput, Tuple]:
r"""
Args:
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
`self.processor` in
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
Returns:
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
[`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
# By default samples have to be AT least a multiple of the overall upsampling factor.
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
# However, the upsampling interpolation output size can be forced to fit any upsampling size
# on the fly if necessary.
class_labels = None
timestep_cond = None
attention_mask = None
cross_attention_kwargs = None
down_block_additional_residuals = (
down_block_additional_residuals0,
down_block_additional_residuals1,
down_block_additional_residuals2,
down_block_additional_residuals3,
down_block_additional_residuals4,
down_block_additional_residuals5,
down_block_additional_residuals6,
down_block_additional_residuals7,
down_block_additional_residuals8,
down_block_additional_residuals9,
down_block_additional_residuals10,
down_block_additional_residuals11,
)
default_overall_up_factor = 2**self.num_upsamplers
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
forward_upsample_size = False
upsample_size = None
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
logger.info("Forward upsample size to force interpolation output size.")
forward_upsample_size = True
# prepare attention_mask
if attention_mask is not None:
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# 0. center input if necessary
if self.config.center_input_sample:
sample = 2 * sample - 1.0
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
t_emb = self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
if self.class_embedding is not None:
if class_labels is None:
raise ValueError(
"class_labels should be provided when num_class_embeds > 0"
)
if self.config.class_embed_type == "timestep":
class_labels = self.time_proj(class_labels)
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb
# 2. pre-process
sample = self.conv_in(sample)
# 3. down
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if (
hasattr(downsample_block, "has_cross_attention")
and downsample_block.has_cross_attention
):
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
down_block_res_samples += res_samples
if down_block_additional_residuals is not None:
new_down_block_res_samples = ()
for down_block_res_sample, down_block_additional_residual in zip(
down_block_res_samples, down_block_additional_residuals
):
down_block_res_sample += down_block_additional_residual
new_down_block_res_samples += (down_block_res_sample,)
down_block_res_samples = new_down_block_res_samples
# 4. mid
if self.mid_block is not None:
sample = self.mid_block(
sample,
emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
)
if mid_block_additional_residual is not None:
sample += mid_block_additional_residual
# 5. up
for i, upsample_block in enumerate(self.up_blocks):
is_final_block = i == len(self.up_blocks) - 1
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[
: -len(upsample_block.resnets)
]
# if we have not reached the final block and need to forward the
# upsample size, we do it here
if not is_final_block and forward_upsample_size:
upsample_size = down_block_res_samples[-1].shape[2:]
if (
hasattr(upsample_block, "has_cross_attention")
and upsample_block.has_cross_attention
):
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
upsample_size=upsample_size,
attention_mask=attention_mask,
)
else:
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
upsample_size=upsample_size,
)
# 6. post-process
if self.conv_norm_out:
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
if not return_dict:
return (sample,)
return UNet2DConditionOutput(sample=sample)

View File

@ -240,6 +240,13 @@ def load_models(server: ServerContext) -> None:
"stable-diffusion-*",
],
)
diffusion_models.extend(
list_model_globs(
server,
["*"],
base_path=path.join(server.model_path, "diffusion"),
)
)
logger.debug("loaded diffusion models from disk: %s", diffusion_models)
correction_models = list_model_globs(
@ -248,6 +255,13 @@ def load_models(server: ServerContext) -> None:
"correction-*",
],
)
correction_models.extend(
list_model_globs(
server,
["*"],
base_path=path.join(server.model_path, "correction"),
)
)
logger.debug("loaded correction models from disk: %s", correction_models)
upscaling_models = list_model_globs(
@ -256,9 +270,26 @@ def load_models(server: ServerContext) -> None:
"upscaling-*",
],
)
upscaling_models.extend(
list_model_globs(
server,
["*"],
base_path=path.join(server.model_path, "upscaling"),
)
)
logger.debug("loaded upscaling models from disk: %s", upscaling_models)
# additional networks
control_models = list_model_globs(
server,
[
"*",
],
base_path=path.join(server.model_path, "control"),
)
logger.debug("loaded ControlNet models from disk: %s", control_models)
network_models.extend([NetworkModel(model, "control") for model in control_models])
inversion_models = list_model_globs(
server,
[