initial integration of controlnet pipeline
This commit is contained in:
parent
fcc04982b1
commit
a3daaf0112
|
@ -1,8 +1,10 @@
|
|||
from .base import ChainPipeline, PipelineStage, StageCallback, StageParams
|
||||
from .blend_controlnet import blend_controlnet
|
||||
from .blend_img2img import blend_img2img
|
||||
from .blend_inpaint import blend_inpaint
|
||||
from .blend_linear import blend_linear
|
||||
from .blend_mask import blend_mask
|
||||
from .blend_pix2pix import blend_pix2pix
|
||||
from .correct_codeformer import correct_codeformer
|
||||
from .correct_gfpgan import correct_gfpgan
|
||||
from .persist_disk import persist_disk
|
||||
|
@ -20,10 +22,12 @@ from .upscale_stable_diffusion import upscale_stable_diffusion
|
|||
from .upscale_swinir import upscale_swinir
|
||||
|
||||
CHAIN_STAGES = {
|
||||
"blend-controlnet": blend_controlnet,
|
||||
"blend-img2img": blend_img2img,
|
||||
"blend-inpaint": blend_inpaint,
|
||||
"blend-linear": blend_linear,
|
||||
"blend-mask": blend_mask,
|
||||
"blend-pix2pix": blend_pix2pix,
|
||||
"correct-codeformer": correct_codeformer,
|
||||
"correct-gfpgan": correct_gfpgan,
|
||||
"persist-disk": persist_disk,
|
||||
|
|
|
@ -0,0 +1,71 @@
|
|||
from logging import getLogger
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from ..diffusers.load import load_pipeline
|
||||
from ..diffusers.pipelines.controlnet import OnnxStableDiffusionControlNetPipeline
|
||||
from ..params import ImageParams, StageParams
|
||||
from ..server import ServerContext
|
||||
from ..worker import ProgressCallback, WorkerContext
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def blend_controlnet(
|
||||
job: WorkerContext,
|
||||
server: ServerContext,
|
||||
_stage: StageParams,
|
||||
params: ImageParams,
|
||||
source: Image.Image,
|
||||
*,
|
||||
callback: Optional[ProgressCallback] = None,
|
||||
stage_source: Image.Image,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
params = params.with_args(**kwargs)
|
||||
source = stage_source or source
|
||||
logger.info(
|
||||
"blending image using controlnet, %s steps: %s", params.steps, params.prompt
|
||||
)
|
||||
|
||||
pipe = load_pipeline(
|
||||
server,
|
||||
OnnxStableDiffusionControlNetPipeline,
|
||||
params.model,
|
||||
params.scheduler,
|
||||
job.get_device(),
|
||||
params.lpw,
|
||||
)
|
||||
if params.lpw:
|
||||
logger.debug("using LPW pipeline for img2img")
|
||||
rng = torch.manual_seed(params.seed)
|
||||
result = pipe.img2img(
|
||||
params.prompt,
|
||||
generator=rng,
|
||||
guidance_scale=params.cfg,
|
||||
image=source,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
strength=params.strength,
|
||||
callback=callback,
|
||||
)
|
||||
else:
|
||||
rng = np.random.RandomState(params.seed)
|
||||
result = pipe(
|
||||
params.prompt,
|
||||
generator=rng,
|
||||
guidance_scale=params.cfg,
|
||||
image=source,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
strength=params.strength,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
output = result.images[0]
|
||||
|
||||
logger.info("final output image size: %sx%s", output.width, output.height)
|
||||
return output
|
|
@ -0,0 +1,73 @@
|
|||
from logging import getLogger
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from ..diffusers.load import load_pipeline
|
||||
from ..diffusers.pipelines.pix2pix import OnnxStableDiffusionInstructPix2PixPipeline
|
||||
from ..params import ImageParams, StageParams
|
||||
from ..server import ServerContext
|
||||
from ..worker import ProgressCallback, WorkerContext
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def blend_pix2pix(
|
||||
job: WorkerContext,
|
||||
server: ServerContext,
|
||||
_stage: StageParams,
|
||||
params: ImageParams,
|
||||
source: Image.Image,
|
||||
*,
|
||||
callback: Optional[ProgressCallback] = None,
|
||||
stage_source: Image.Image,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
params = params.with_args(**kwargs)
|
||||
source = stage_source or source
|
||||
logger.info(
|
||||
"blending image using instruct pix2pix, %s steps: %s",
|
||||
params.steps,
|
||||
params.prompt,
|
||||
)
|
||||
|
||||
pipe = load_pipeline(
|
||||
server,
|
||||
OnnxStableDiffusionInstructPix2PixPipeline,
|
||||
params.model,
|
||||
params.scheduler,
|
||||
job.get_device(),
|
||||
params.lpw,
|
||||
)
|
||||
if params.lpw:
|
||||
logger.debug("using LPW pipeline for img2img")
|
||||
rng = torch.manual_seed(params.seed)
|
||||
result = pipe.img2img(
|
||||
params.prompt,
|
||||
generator=rng,
|
||||
guidance_scale=params.cfg,
|
||||
image=source,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
strength=params.strength,
|
||||
callback=callback,
|
||||
)
|
||||
else:
|
||||
rng = np.random.RandomState(params.seed)
|
||||
result = pipe(
|
||||
params.prompt,
|
||||
generator=rng,
|
||||
guidance_scale=params.cfg,
|
||||
image=source,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
strength=params.strength,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
output = result.images[0]
|
||||
|
||||
logger.info("final output image size: %sx%s", output.width, output.height)
|
||||
return output
|
|
@ -7,9 +7,7 @@ from diffusers import StableDiffusionUpscalePipeline
|
|||
from PIL import Image
|
||||
|
||||
from ..diffusers.load import optimize_pipeline, patch_pipeline
|
||||
from ..diffusers.pipeline_onnx_stable_diffusion_upscale import (
|
||||
OnnxStableDiffusionUpscalePipeline,
|
||||
)
|
||||
from ..diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
|
||||
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
||||
from ..server import ServerContext
|
||||
from ..utils import run_gc
|
||||
|
|
|
@ -12,6 +12,8 @@ from onnx import load_model, save_model
|
|||
from transformers import CLIPTokenizer
|
||||
from yaml import safe_load
|
||||
|
||||
from onnx_web.convert.diffusion.control import convert_diffusion_control
|
||||
|
||||
from ..constants import ONNX_MODEL, ONNX_WEIGHTS
|
||||
from .correction.gfpgan import convert_correction_gfpgan
|
||||
from .diffusion.diffusers import convert_diffusion_diffusers
|
||||
|
@ -258,6 +260,19 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
|||
source = network["source"]
|
||||
|
||||
try:
|
||||
if network_type == "control":
|
||||
dest = fetch_model(
|
||||
conversion,
|
||||
name,
|
||||
source,
|
||||
format=network_format,
|
||||
)
|
||||
|
||||
convert_diffusion_control(
|
||||
conversion,
|
||||
network,
|
||||
dest,
|
||||
)
|
||||
if network_type == "inversion" and network_model == "concept":
|
||||
dest = fetch_model(
|
||||
conversion,
|
||||
|
|
|
@ -0,0 +1,78 @@
|
|||
from logging import getLogger
|
||||
from os import path
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
from diffusers.models.controlnet import ControlNetModel
|
||||
from diffusers.models.cross_attention import CrossAttnProcessor
|
||||
|
||||
from ...constants import ONNX_MODEL
|
||||
from ..utils import ConversionContext, is_torch_2_0, onnx_export
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_diffusion_control(
|
||||
conversion: ConversionContext,
|
||||
model: Dict,
|
||||
source: str,
|
||||
model_path: str,
|
||||
output_path: str,
|
||||
opset: int,
|
||||
attention_slicing: str,
|
||||
):
|
||||
name = model.get("name")
|
||||
source = source or model.get("source")
|
||||
|
||||
device = conversion.training_device
|
||||
dtype = conversion.torch_dtype()
|
||||
logger.debug("using Torch dtype %s for ControlNet", dtype)
|
||||
|
||||
output_path = Path(output_path)
|
||||
logger.info("converting ControlNet model %s: %s -> %s", name, source, output_path)
|
||||
if path.exists(output_path):
|
||||
logger.info("ONNX model already exists, skipping")
|
||||
return
|
||||
|
||||
controlnet = ControlNetModel.from_pretrained(model_path, torch_dtype=dtype)
|
||||
if attention_slicing is not None:
|
||||
logger.info("enabling attention slicing for ControlNet")
|
||||
controlnet.set_attention_slice(attention_slicing)
|
||||
|
||||
# UNET
|
||||
if is_torch_2_0:
|
||||
controlnet.set_attn_processor(CrossAttnProcessor())
|
||||
|
||||
cnet_path = output_path / "cnet" / ONNX_MODEL
|
||||
onnx_export(
|
||||
controlnet,
|
||||
model_args=(
|
||||
torch.randn(2, 4, 64, 64).to(device=device, dtype=dtype),
|
||||
torch.randn(2).to(device=device, dtype=dtype),
|
||||
torch.randn(2, 77, 768).to(device=device, dtype=dtype),
|
||||
torch.randn(2, 3, 512, 512).to(device=device, dtype=dtype),
|
||||
),
|
||||
output_path=cnet_path,
|
||||
ordered_input_names=[
|
||||
"sample",
|
||||
"timestep",
|
||||
"encoder_hidden_states",
|
||||
"controlnet_cond",
|
||||
"return_dict",
|
||||
],
|
||||
output_names=[
|
||||
"down_block_res_samples",
|
||||
"mid_block_res_sample",
|
||||
], # has to be different from "sample" for correct tracing
|
||||
dynamic_axes={
|
||||
"sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
|
||||
"timestep": {0: "batch"},
|
||||
"encoder_hidden_states": {0: "batch", 1: "sequence"},
|
||||
"controlnet_cond": {0: "batch", 2: "height", 3: "width"},
|
||||
},
|
||||
opset=opset,
|
||||
)
|
||||
|
||||
logger.info("ONNX ControlNet saved to %s", output_path)
|
|
@ -4,7 +4,7 @@
|
|||
#
|
||||
# Originally by https://github.com/huggingface
|
||||
# Those portions *are not* covered by the MIT licensed used for the rest of the onnx-web project.
|
||||
#
|
||||
# ...diffusers.pipelines.pipeline_onnx_stable_diffusion_upscale
|
||||
# HuggingFace code used under the Apache License, Version 2.0
|
||||
# https://github.com/huggingface/diffusers/blob/main/LICENSE
|
||||
###
|
||||
|
@ -24,71 +24,15 @@ from diffusers import (
|
|||
)
|
||||
from diffusers.models.cross_attention import CrossAttnProcessor
|
||||
from onnx import load_model, save_model
|
||||
from onnx.shape_inference import infer_shapes_path
|
||||
from onnxruntime.transformers.float16 import convert_float_to_float16
|
||||
from packaging import version
|
||||
from torch.onnx import export
|
||||
|
||||
from ...constants import ONNX_MODEL, ONNX_WEIGHTS
|
||||
from ...diffusers.load import optimize_pipeline
|
||||
from ...diffusers.pipeline_onnx_stable_diffusion_upscale import (
|
||||
OnnxStableDiffusionUpscalePipeline,
|
||||
)
|
||||
from ..utils import ConversionContext
|
||||
from ...diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
|
||||
from ...models.cnet import UNet2DConditionModel_CNet
|
||||
from ..utils import ConversionContext, is_torch_2_0, onnx_export
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
is_torch_2_0 = version.parse(
|
||||
version.parse(torch.__version__).base_version
|
||||
) >= version.parse("2.0")
|
||||
|
||||
|
||||
def onnx_export(
|
||||
model,
|
||||
model_args: tuple,
|
||||
output_path: Path,
|
||||
ordered_input_names,
|
||||
output_names,
|
||||
dynamic_axes,
|
||||
opset,
|
||||
half=False,
|
||||
external_data=False,
|
||||
):
|
||||
"""
|
||||
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
||||
"""
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
output_file = output_path.absolute().as_posix()
|
||||
|
||||
export(
|
||||
model,
|
||||
model_args,
|
||||
f=output_file,
|
||||
input_names=ordered_input_names,
|
||||
output_names=output_names,
|
||||
dynamic_axes=dynamic_axes,
|
||||
do_constant_folding=True,
|
||||
opset_version=opset,
|
||||
)
|
||||
|
||||
if half:
|
||||
logger.info("converting model to fp16 internally: %s", output_file)
|
||||
infer_shapes_path(output_file)
|
||||
base_model = load_model(output_file)
|
||||
opt_model = convert_float_to_float16(
|
||||
base_model,
|
||||
disable_shape_infer=True,
|
||||
keep_io_types=True,
|
||||
force_fp16_initializers=True,
|
||||
)
|
||||
save_model(
|
||||
opt_model,
|
||||
f"{output_file}",
|
||||
save_as_external_data=external_data,
|
||||
all_tensors_to_one_file=True,
|
||||
location=ONNX_WEIGHTS,
|
||||
)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_diffusion_diffusers(
|
||||
|
@ -104,6 +48,7 @@ def convert_diffusion_diffusers(
|
|||
single_vae = model.get("single_vae")
|
||||
replace_vae = model.get("vae")
|
||||
|
||||
device = conversion.training_device
|
||||
dtype = conversion.torch_dtype()
|
||||
logger.debug("using Torch dtype %s for pipeline", dtype)
|
||||
|
||||
|
@ -119,6 +64,7 @@ def convert_diffusion_diffusers(
|
|||
logger.info("converting model with single VAE")
|
||||
|
||||
if path.exists(dest_path) and path.exists(model_index):
|
||||
# TODO: check if CNet has been converted
|
||||
logger.info("ONNX model already exists, skipping")
|
||||
return (False, dest_path)
|
||||
|
||||
|
@ -126,7 +72,7 @@ def convert_diffusion_diffusers(
|
|||
source,
|
||||
torch_dtype=dtype,
|
||||
use_auth_token=conversion.token,
|
||||
).to(conversion.training_device)
|
||||
).to(device)
|
||||
output_path = Path(dest_path)
|
||||
|
||||
optimize_pipeline(conversion, pipeline)
|
||||
|
@ -145,13 +91,11 @@ def convert_diffusion_diffusers(
|
|||
pipeline.text_encoder,
|
||||
# casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
|
||||
model_args=(
|
||||
text_input.input_ids.to(
|
||||
device=conversion.training_device, dtype=torch.int32
|
||||
),
|
||||
text_input.input_ids.to(device=device, dtype=torch.int32),
|
||||
None, # attention mask
|
||||
None, # position ids
|
||||
None, # output attentions
|
||||
torch.tensor(True).to(device=conversion.training_device, dtype=torch.bool),
|
||||
torch.tensor(True).to(device=device, dtype=torch.bool),
|
||||
),
|
||||
output_path=output_path / "text_encoder" / ONNX_MODEL,
|
||||
ordered_input_names=["input_ids"],
|
||||
|
@ -169,14 +113,10 @@ def convert_diffusion_diffusers(
|
|||
# UNET
|
||||
if single_vae:
|
||||
unet_inputs = ["sample", "timestep", "encoder_hidden_states", "class_labels"]
|
||||
unet_scale = torch.tensor(4).to(
|
||||
device=conversion.training_device, dtype=torch.long
|
||||
)
|
||||
unet_scale = torch.tensor(4).to(device=device, dtype=torch.long)
|
||||
else:
|
||||
unet_inputs = ["sample", "timestep", "encoder_hidden_states", "return_dict"]
|
||||
unet_scale = torch.tensor(False).to(
|
||||
device=conversion.training_device, dtype=torch.bool
|
||||
)
|
||||
unet_scale = torch.tensor(False).to(device=device, dtype=torch.bool)
|
||||
|
||||
if is_torch_2_0:
|
||||
pipeline.unet.set_attn_processor(CrossAttnProcessor())
|
||||
|
@ -188,12 +128,10 @@ def convert_diffusion_diffusers(
|
|||
pipeline.unet,
|
||||
model_args=(
|
||||
torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(
|
||||
device=conversion.training_device, dtype=dtype
|
||||
),
|
||||
torch.randn(2).to(device=conversion.training_device, dtype=dtype),
|
||||
torch.randn(2, num_tokens, text_hidden_size).to(
|
||||
device=conversion.training_device, dtype=dtype
|
||||
device=device, dtype=dtype
|
||||
),
|
||||
torch.randn(2).to(device=device, dtype=dtype),
|
||||
torch.randn(2, num_tokens, text_hidden_size).to(device=device, dtype=dtype),
|
||||
unet_scale,
|
||||
),
|
||||
output_path=unet_path,
|
||||
|
@ -212,9 +150,11 @@ def convert_diffusion_diffusers(
|
|||
unet_model_path = str(unet_path.absolute().as_posix())
|
||||
unet_dir = path.dirname(unet_model_path)
|
||||
unet = load_model(unet_model_path)
|
||||
|
||||
# clean up existing tensor files
|
||||
rmtree(unet_dir)
|
||||
mkdir(unet_dir)
|
||||
|
||||
# collate external tensor files into one
|
||||
save_model(
|
||||
unet,
|
||||
|
@ -226,6 +166,124 @@ def convert_diffusion_diffusers(
|
|||
)
|
||||
del pipeline.unet
|
||||
|
||||
# CNet
|
||||
pipe_cnet = UNet2DConditionModel_CNet.from_pretrained(source, subfolder="unet")
|
||||
|
||||
cnet_path = output_path / "cnet" / ONNX_MODEL
|
||||
onnx_export(
|
||||
pipe_cnet,
|
||||
model_args=(
|
||||
torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(
|
||||
device=device, dtype=dtype
|
||||
),
|
||||
torch.randn(2).to(device=device, dtype=dtype),
|
||||
torch.randn(2, num_tokens, text_hidden_size).to(device=device, dtype=dtype),
|
||||
torch.randn(2, 320, unet_sample_size, unet_sample_size).to(
|
||||
device=device, dtype=dtype
|
||||
),
|
||||
torch.randn(2, 320, unet_sample_size, unet_sample_size).to(
|
||||
device=device, dtype=dtype
|
||||
),
|
||||
torch.randn(2, 320, unet_sample_size, unet_sample_size).to(
|
||||
device=device, dtype=dtype
|
||||
),
|
||||
torch.randn(2, 320, unet_sample_size // 2, unet_sample_size // 2).to(
|
||||
device=device, dtype=dtype
|
||||
),
|
||||
torch.randn(2, 640, unet_sample_size // 2, unet_sample_size // 2).to(
|
||||
device=device, dtype=dtype
|
||||
),
|
||||
torch.randn(2, 640, unet_sample_size // 2, unet_sample_size // 2).to(
|
||||
device=device, dtype=dtype
|
||||
),
|
||||
torch.randn(2, 640, unet_sample_size // 4, unet_sample_size // 4).to(
|
||||
device=device, dtype=dtype
|
||||
),
|
||||
torch.randn(2, 1280, unet_sample_size // 4, unet_sample_size // 4).to(
|
||||
device=device, dtype=dtype
|
||||
),
|
||||
torch.randn(2, 1280, unet_sample_size // 4, unet_sample_size // 4).to(
|
||||
device=device, dtype=dtype
|
||||
),
|
||||
torch.randn(2, 1280, unet_sample_size // 8, unet_sample_size // 8).to(
|
||||
device=device, dtype=dtype
|
||||
),
|
||||
torch.randn(2, 1280, unet_sample_size // 8, unet_sample_size // 8).to(
|
||||
device=device, dtype=dtype
|
||||
),
|
||||
torch.randn(2, 1280, unet_sample_size // 8, unet_sample_size // 8).to(
|
||||
device=device, dtype=dtype
|
||||
),
|
||||
torch.randn(2, 1280, unet_sample_size // 8, unet_sample_size // 8).to(
|
||||
device=device, dtype=dtype
|
||||
),
|
||||
False,
|
||||
),
|
||||
output_path=cnet_path,
|
||||
ordered_input_names=[
|
||||
"sample",
|
||||
"timestep",
|
||||
"encoder_hidden_states",
|
||||
"down_block_0",
|
||||
"down_block_1",
|
||||
"down_block_2",
|
||||
"down_block_3",
|
||||
"down_block_4",
|
||||
"down_block_5",
|
||||
"down_block_6",
|
||||
"down_block_7",
|
||||
"down_block_8",
|
||||
"down_block_9",
|
||||
"down_block_10",
|
||||
"down_block_11",
|
||||
"mid_block_additional_residual",
|
||||
"return_dict",
|
||||
],
|
||||
output_names=[
|
||||
"out_sample"
|
||||
], # has to be different from "sample" for correct tracing
|
||||
dynamic_axes={
|
||||
"sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
|
||||
"timestep": {0: "batch"},
|
||||
"encoder_hidden_states": {0: "batch", 1: "sequence"},
|
||||
"down_block_0": {0: "batch", 2: "height", 3: "width"},
|
||||
"down_block_1": {0: "batch", 2: "height", 3: "width"},
|
||||
"down_block_2": {0: "batch", 2: "height", 3: "width"},
|
||||
"down_block_3": {0: "batch", 2: "height2", 3: "width2"},
|
||||
"down_block_4": {0: "batch", 2: "height2", 3: "width2"},
|
||||
"down_block_5": {0: "batch", 2: "height2", 3: "width2"},
|
||||
"down_block_6": {0: "batch", 2: "height4", 3: "width4"},
|
||||
"down_block_7": {0: "batch", 2: "height4", 3: "width4"},
|
||||
"down_block_8": {0: "batch", 2: "height4", 3: "width4"},
|
||||
"down_block_9": {0: "batch", 2: "height8", 3: "width8"},
|
||||
"down_block_10": {0: "batch", 2: "height8", 3: "width8"},
|
||||
"down_block_11": {0: "batch", 2: "height8", 3: "width8"},
|
||||
"mid_block_additional_residual": {0: "batch", 2: "height8", 3: "width8"},
|
||||
},
|
||||
opset=conversion.opset,
|
||||
half=conversion.half,
|
||||
external_data=True, # UNet is > 2GB, so the weights need to be split
|
||||
)
|
||||
cnet_model_path = str(cnet_path.absolute().as_posix())
|
||||
cnet_dir = path.dirname(cnet_model_path)
|
||||
cnet = load_model(cnet_model_path)
|
||||
|
||||
# clean up existing tensor files
|
||||
rmtree(cnet_dir)
|
||||
mkdir(cnet_dir)
|
||||
|
||||
# collate external tensor files into one
|
||||
save_model(
|
||||
cnet,
|
||||
cnet_model_path,
|
||||
save_as_external_data=True,
|
||||
all_tensors_to_one_file=True,
|
||||
location=ONNX_WEIGHTS,
|
||||
convert_attribute=False,
|
||||
)
|
||||
del pipe_cnet
|
||||
|
||||
# VAE
|
||||
if replace_vae is not None:
|
||||
logger.debug("loading custom VAE: %s", replace_vae)
|
||||
vae = AutoencoderKL.from_pretrained(replace_vae)
|
||||
|
@ -244,7 +302,7 @@ def convert_diffusion_diffusers(
|
|||
model_args=(
|
||||
torch.randn(
|
||||
1, vae_latent_channels, unet_sample_size, unet_sample_size
|
||||
).to(device=conversion.training_device, dtype=dtype),
|
||||
).to(device=device, dtype=dtype),
|
||||
False,
|
||||
),
|
||||
output_path=output_path / "vae" / ONNX_MODEL,
|
||||
|
@ -269,7 +327,7 @@ def convert_diffusion_diffusers(
|
|||
vae_encoder,
|
||||
model_args=(
|
||||
torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(
|
||||
device=conversion.training_device, dtype=dtype
|
||||
device=device, dtype=dtype
|
||||
),
|
||||
False,
|
||||
),
|
||||
|
@ -293,7 +351,7 @@ def convert_diffusion_diffusers(
|
|||
model_args=(
|
||||
torch.randn(
|
||||
1, vae_latent_channels, unet_sample_size, unet_sample_size
|
||||
).to(device=conversion.training_device, dtype=dtype),
|
||||
).to(device=device, dtype=dtype),
|
||||
False,
|
||||
),
|
||||
output_path=output_path / "vae_decoder" / ONNX_MODEL,
|
||||
|
|
|
@ -9,12 +9,22 @@ import requests
|
|||
import safetensors
|
||||
import torch
|
||||
from huggingface_hub.utils.tqdm import tqdm
|
||||
from onnx import load_model, save_model
|
||||
from onnx.shape_inference import infer_shapes_path
|
||||
from onnxruntime.transformers.float16 import convert_float_to_float16
|
||||
from packaging import version
|
||||
from torch.onnx import export
|
||||
from yaml import safe_load
|
||||
|
||||
from ..constants import ONNX_WEIGHTS
|
||||
from ..server import ServerContext
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
is_torch_2_0 = version.parse(
|
||||
version.parse(torch.__version__).base_version
|
||||
) >= version.parse("2.0")
|
||||
|
||||
|
||||
ModelDict = Dict[str, Union[str, int]]
|
||||
LegacyModel = Tuple[str, str, Optional[bool], Optional[bool], Optional[int]]
|
||||
|
@ -263,3 +273,50 @@ def load_tensor(name: str, map_location=None) -> Optional[Dict]:
|
|||
checkpoint = checkpoint["state_dict"]
|
||||
|
||||
return checkpoint
|
||||
|
||||
|
||||
def onnx_export(
|
||||
model,
|
||||
model_args: tuple,
|
||||
output_path: Path,
|
||||
ordered_input_names,
|
||||
output_names,
|
||||
dynamic_axes,
|
||||
opset,
|
||||
half=False,
|
||||
external_data=False,
|
||||
):
|
||||
"""
|
||||
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
||||
"""
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
output_file = output_path.absolute().as_posix()
|
||||
|
||||
export(
|
||||
model,
|
||||
model_args,
|
||||
f=output_file,
|
||||
input_names=ordered_input_names,
|
||||
output_names=output_names,
|
||||
dynamic_axes=dynamic_axes,
|
||||
do_constant_folding=True,
|
||||
opset_version=opset,
|
||||
)
|
||||
|
||||
if half:
|
||||
logger.info("converting model to fp16 internally: %s", output_file)
|
||||
infer_shapes_path(output_file)
|
||||
base_model = load_model(output_file)
|
||||
opt_model = convert_float_to_float16(
|
||||
base_model,
|
||||
disable_shape_infer=True,
|
||||
keep_io_types=True,
|
||||
force_fp16_initializers=True,
|
||||
)
|
||||
save_model(
|
||||
opt_model,
|
||||
f"{output_file}",
|
||||
save_as_external_data=external_data,
|
||||
all_tensors_to_one_file=True,
|
||||
location=ONNX_WEIGHTS,
|
||||
)
|
||||
|
|
|
@ -0,0 +1,532 @@
|
|||
# Copyright 2023 The HuggingFace Team.
|
||||
# Converted for use with ONNX as part of https://github.com/Amblyopius/Stable-Diffusion-ONNX-FP16
|
||||
# Special thanks to https://github.com/uchuusen for the initial conversion effort
|
||||
|
||||
import inspect
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
from diffusers.configuration_utils import FrozenDict
|
||||
from diffusers.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||
from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
|
||||
from transformers import CLIPFeatureExtractor, CLIPTokenizer
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class OnnxStableDiffusionControlNetPipeline(DiffusionPipeline):
|
||||
vae_encoder: OnnxRuntimeModel
|
||||
vae_decoder: OnnxRuntimeModel
|
||||
text_encoder: OnnxRuntimeModel
|
||||
tokenizer: CLIPTokenizer
|
||||
unet: OnnxRuntimeModel
|
||||
controlnet: OnnxRuntimeModel
|
||||
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
|
||||
safety_checker: OnnxRuntimeModel
|
||||
feature_extractor: CLIPFeatureExtractor
|
||||
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae_encoder: OnnxRuntimeModel,
|
||||
vae_decoder: OnnxRuntimeModel,
|
||||
text_encoder: OnnxRuntimeModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: OnnxRuntimeModel,
|
||||
controlnet: OnnxRuntimeModel,
|
||||
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
|
||||
safety_checker: OnnxRuntimeModel,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if (
|
||||
hasattr(scheduler.config, "steps_offset")
|
||||
and scheduler.config.steps_offset != 1
|
||||
):
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
||||
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
||||
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
|
||||
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
|
||||
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
|
||||
" file"
|
||||
)
|
||||
deprecate(
|
||||
"steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False
|
||||
)
|
||||
new_config = dict(scheduler.config)
|
||||
new_config["steps_offset"] = 1
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if (
|
||||
hasattr(scheduler.config, "clip_sample")
|
||||
and scheduler.config.clip_sample is True
|
||||
):
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
||||
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
||||
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
|
||||
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
|
||||
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
|
||||
)
|
||||
deprecate(
|
||||
"clip_sample not set", "1.0.0", deprecation_message, standard_warn=False
|
||||
)
|
||||
new_config = dict(scheduler.config)
|
||||
new_config["clip_sample"] = False
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None and requires_safety_checker:
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
||||
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
||||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
if safety_checker is not None and feature_extractor is None:
|
||||
raise ValueError(
|
||||
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
self.register_modules(
|
||||
vae_encoder=vae_encoder,
|
||||
vae_decoder=vae_decoder,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
controlnet=controlnet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
def _default_height_width(self, height, width, image):
|
||||
if isinstance(image, list):
|
||||
image = image[0]
|
||||
|
||||
if height is None:
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
height = image.height
|
||||
elif isinstance(image, np.ndarray):
|
||||
height = image.shape[3]
|
||||
|
||||
height = (height // 8) * 8 # round down to nearest multiple of 8
|
||||
|
||||
if width is None:
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
width = image.width
|
||||
elif isinstance(image, np.ndarray):
|
||||
width = image.shape[2]
|
||||
|
||||
width = (width // 8) * 8 # round down to nearest multiple of 8
|
||||
|
||||
return height, width
|
||||
|
||||
def prepare_image(
|
||||
self, image, width, height, batch_size, num_images_per_prompt, dtype
|
||||
):
|
||||
if not isinstance(image, np.ndarray):
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
image = [image]
|
||||
|
||||
if isinstance(image[0], PIL.Image.Image):
|
||||
image = [
|
||||
np.array(
|
||||
i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
|
||||
)[None, :]
|
||||
for i in image
|
||||
]
|
||||
image = np.concatenate(image, axis=0)
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = image.transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image)
|
||||
elif isinstance(image[0], np.ndarray):
|
||||
image = np.concatenate(image, axis=0)
|
||||
image = torch.from_numpy(image)
|
||||
|
||||
image_batch_size = image.shape[0]
|
||||
|
||||
if image_batch_size == 1:
|
||||
repeat_by = batch_size
|
||||
else:
|
||||
# image batch size is the same as prompt batch size
|
||||
repeat_by = num_images_per_prompt
|
||||
|
||||
image = image.repeat_interleave(repeat_by, dim=0)
|
||||
|
||||
return image
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
shape = (batch_size, num_channels_latents, height // 8, width // 8)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = generator.randn(*shape).astype(dtype)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta, torch_gen):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(
|
||||
inspect.signature(self.scheduler.step).parameters.keys()
|
||||
)
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(
|
||||
inspect.signature(self.scheduler.step).parameters.keys()
|
||||
)
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = torch_gen
|
||||
return extra_step_kwargs
|
||||
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`):
|
||||
prompt to be encoded
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`):
|
||||
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
||||
if `guidance_scale` is less than `1`).
|
||||
"""
|
||||
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
||||
|
||||
# get prompt text embeddings
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="np",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(
|
||||
prompt, padding="max_length", return_tensors="np"
|
||||
).input_ids
|
||||
|
||||
if not np.array_equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(
|
||||
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
||||
)
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
|
||||
prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt] * batch_size
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
max_length = text_input_ids.shape[-1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="np",
|
||||
)
|
||||
negative_prompt_embeds = self.text_encoder(
|
||||
input_ids=uncond_input.input_ids.astype(np.int32)
|
||||
)[0]
|
||||
negative_prompt_embeds = np.repeat(
|
||||
negative_prompt_embeds, num_images_per_prompt, axis=0
|
||||
)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
image: Union[np.ndarray, PIL.Image.Image] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: Optional[int] = 50,
|
||||
guidance_scale: Optional[float] = 7.5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: Optional[float] = 0.0,
|
||||
generator: Optional[np.random.RandomState] = None,
|
||||
latents: Optional[np.ndarray] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
controlnet_conditioning_scale: float = 1.0,
|
||||
):
|
||||
if isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
|
||||
)
|
||||
|
||||
if generator:
|
||||
torch_seed = generator.randint(2147483647)
|
||||
torch_gen = torch.Generator().manual_seed(torch_seed)
|
||||
else:
|
||||
generator = np.random
|
||||
torch_gen = None
|
||||
|
||||
height, width = self._default_height_width(height, width, image)
|
||||
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(
|
||||
f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
|
||||
)
|
||||
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None
|
||||
and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
|
||||
)
|
||||
|
||||
# 4. Prepare image
|
||||
image = self.prepare_image(
|
||||
image,
|
||||
width,
|
||||
height,
|
||||
batch_size * num_images_per_prompt,
|
||||
num_images_per_prompt,
|
||||
np.float32,
|
||||
).numpy()
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
image = np.concatenate([image] * 2)
|
||||
|
||||
# get the initial random noise unless the user supplied it
|
||||
latents_dtype = prompt_embeds.dtype
|
||||
num_channels_latents = 4
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
latents_dtype,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta, torch_gen)
|
||||
|
||||
timestep_dtype = next(
|
||||
(
|
||||
input.type
|
||||
for input in self.unet.model.get_inputs()
|
||||
if input.name == "timestep"
|
||||
),
|
||||
"tensor(float)",
|
||||
)
|
||||
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
|
||||
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = (
|
||||
np.concatenate([latents] * 2)
|
||||
if do_classifier_free_guidance
|
||||
else latents
|
||||
)
|
||||
latent_model_input = self.scheduler.scale_model_input(
|
||||
torch.from_numpy(latent_model_input), t
|
||||
)
|
||||
latent_model_input = latent_model_input.cpu().numpy()
|
||||
|
||||
timestep = np.array([t], dtype=timestep_dtype)
|
||||
|
||||
blocksamples = self.controlnet(
|
||||
sample=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
controlnet_cond=image,
|
||||
)
|
||||
|
||||
mid_block_res_sample = blocksamples[12]
|
||||
down_block_res_samples = blocksamples[0:12]
|
||||
|
||||
down_block_res_samples = [
|
||||
down_block_res_sample * controlnet_conditioning_scale
|
||||
for down_block_res_sample in down_block_res_samples
|
||||
]
|
||||
mid_block_res_sample *= controlnet_conditioning_scale
|
||||
|
||||
# predict the noise residual
|
||||
|
||||
noise_pred = self.unet(
|
||||
sample=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
down_block_0=down_block_res_samples[0],
|
||||
down_block_1=down_block_res_samples[1],
|
||||
down_block_2=down_block_res_samples[2],
|
||||
down_block_3=down_block_res_samples[3],
|
||||
down_block_4=down_block_res_samples[4],
|
||||
down_block_5=down_block_res_samples[5],
|
||||
down_block_6=down_block_res_samples[6],
|
||||
down_block_7=down_block_res_samples[7],
|
||||
down_block_8=down_block_res_samples[8],
|
||||
down_block_9=down_block_res_samples[9],
|
||||
down_block_10=down_block_res_samples[10],
|
||||
down_block_11=down_block_res_samples[11],
|
||||
mid_block_additional_residual=mid_block_res_sample,
|
||||
)
|
||||
noise_pred = noise_pred[0]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||
noise_pred_text - noise_pred_uncond
|
||||
)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
scheduler_output = self.scheduler.step(
|
||||
torch.from_numpy(noise_pred),
|
||||
t,
|
||||
torch.from_numpy(latents),
|
||||
**extra_step_kwargs,
|
||||
)
|
||||
latents = scheduler_output.prev_sample.numpy()
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or (
|
||||
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
||||
):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
latents = 1 / 0.18215 * latents
|
||||
# image = self.vae_decoder(latent_sample=latents)[0]
|
||||
# it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
|
||||
image = np.concatenate(
|
||||
[
|
||||
self.vae_decoder(latent_sample=latents[i : i + 1])[0]
|
||||
for i in range(latents.shape[0])
|
||||
]
|
||||
)
|
||||
|
||||
image = np.clip(image / 2 + 0.5, 0, 1)
|
||||
image = image.transpose((0, 2, 3, 1))
|
||||
|
||||
if self.safety_checker is not None:
|
||||
safety_checker_input = self.feature_extractor(
|
||||
self.numpy_to_pil(image), return_tensors="np"
|
||||
).pixel_values.astype(image.dtype)
|
||||
|
||||
images, has_nsfw_concept = [], []
|
||||
for i in range(image.shape[0]):
|
||||
image_i, has_nsfw_concept_i = self.safety_checker(
|
||||
clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
|
||||
)
|
||||
images.append(image_i)
|
||||
has_nsfw_concept.append(has_nsfw_concept_i[0])
|
||||
image = np.concatenate(images)
|
||||
else:
|
||||
has_nsfw_concept = None
|
||||
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
if not return_dict:
|
||||
return (image, has_nsfw_concept)
|
||||
|
||||
return StableDiffusionPipelineOutput(
|
||||
images=image, nsfw_content_detected=has_nsfw_concept
|
||||
)
|
|
@ -0,0 +1,606 @@
|
|||
# Copyright 2023 The InstructPix2Pix Authors and The HuggingFace Team.
|
||||
# Converted for use with ONNX as part of https://github.com/Amblyopius/Stable-Diffusion-ONNX-FP16
|
||||
|
||||
import inspect
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
from transformers import CLIPFeatureExtractor, CLIPTokenizer
|
||||
|
||||
try:
|
||||
from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE
|
||||
except ImportError:
|
||||
ORT_TO_NP_TYPE = {
|
||||
"tensor(bool)": np.bool_,
|
||||
"tensor(int8)": np.int8,
|
||||
"tensor(uint8)": np.uint8,
|
||||
"tensor(int16)": np.int16,
|
||||
"tensor(uint16)": np.uint16,
|
||||
"tensor(int32)": np.int32,
|
||||
"tensor(uint32)": np.uint32,
|
||||
"tensor(int64)": np.int64,
|
||||
"tensor(uint64)": np.uint64,
|
||||
"tensor(float16)": np.float16,
|
||||
"tensor(float)": np.float32,
|
||||
"tensor(double)": np.float64,
|
||||
}
|
||||
|
||||
from diffusers import OnnxRuntimeModel
|
||||
from diffusers.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.schedulers import (
|
||||
DDIMScheduler,
|
||||
KarrasDiffusionSchedulers,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
)
|
||||
from diffusers.utils import PIL_INTERPOLATION, logging
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# Simplified and ONNX specific version (only allows 1 image, np over torch)
|
||||
def preprocess(image):
|
||||
if isinstance(image, np.ndarray):
|
||||
return image
|
||||
|
||||
w, h = image.size
|
||||
w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8
|
||||
image = np.array(image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[
|
||||
None, :
|
||||
]
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = image.transpose(0, 3, 1, 2)
|
||||
image = 2.0 * image - 1.0
|
||||
return image
|
||||
|
||||
|
||||
class OnnxStableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for pixel-level image editing by following text instructions. Based on Stable Diffusion.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`CLIPTextModel`]):
|
||||
Frozen text-encoder. Stable Diffusion uses the text portion of
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
||||
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
vae_encoder: OnnxRuntimeModel
|
||||
vae_decoder: OnnxRuntimeModel
|
||||
text_encoder: OnnxRuntimeModel
|
||||
tokenizer: CLIPTokenizer
|
||||
unet: OnnxRuntimeModel
|
||||
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
|
||||
safety_checker: OnnxRuntimeModel
|
||||
feature_extractor: CLIPFeatureExtractor
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae_encoder: OnnxRuntimeModel,
|
||||
vae_decoder: OnnxRuntimeModel,
|
||||
text_encoder: OnnxRuntimeModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: OnnxRuntimeModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
safety_checker: OnnxRuntimeModel,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.unet_in_channels = 8
|
||||
self.vae_scale_factor = 8
|
||||
|
||||
if safety_checker is None and requires_safety_checker:
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
||||
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
||||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
if safety_checker is not None and feature_extractor is None:
|
||||
raise ValueError(
|
||||
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
self.register_modules(
|
||||
vae_encoder=vae_encoder,
|
||||
vae_decoder=vae_decoder,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
# self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
image: Union[np.ndarray, PIL.Image.Image] = None,
|
||||
num_inference_steps: int = 100,
|
||||
guidance_scale: float = 7.5,
|
||||
image_guidance_scale: float = 1.5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[np.random.RandomState] = None,
|
||||
latents: Optional[np.ndarray] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
image (`PIL.Image.Image`):
|
||||
`Image`, or tensor representing an image batch which will be repainted according to `prompt`.
|
||||
num_inference_steps (`int`, *optional*, defaults to 100):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality. This pipeline requires a value of at least `1`.
|
||||
image_guidance_scale (`float`, *optional*, defaults to 1.5):
|
||||
Image guidance scale is to push the generated image towards the inital image `image`. Image guidance
|
||||
scale is enabled by setting `image_guidance_scale > 1`. Higher image guidance scale encourages to
|
||||
generate images that are closely linked to the source image `image`, usually at the expense of lower
|
||||
image quality. This pipeline requires a value of at least `1`.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
|
||||
is less than `1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
>>> import PIL
|
||||
>>> import requests
|
||||
>>> import torch
|
||||
>>> from io import BytesIO
|
||||
|
||||
>>> from diffusers import StableDiffusionInstructPix2PixPipeline
|
||||
|
||||
|
||||
>>> def download_image(url):
|
||||
... response = requests.get(url)
|
||||
... return PIL.Image.open(BytesIO(response.content)).convert("RGB")
|
||||
|
||||
|
||||
>>> img_url = "https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png"
|
||||
|
||||
>>> image = download_image(img_url).resize((512, 512))
|
||||
|
||||
>>> pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
|
||||
... "timbrooks/instruct-pix2pix", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe = pipe.to("cuda")
|
||||
|
||||
>>> prompt = "make the mountains snowy"
|
||||
>>> image = pipe(prompt=prompt, image=image).images[0]
|
||||
```
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
||||
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
||||
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
||||
(nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
|
||||
# We need a deterministic torch generator for schedulers if a (likely seeded) generator was provided
|
||||
|
||||
if generator:
|
||||
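# draw a 32-bit seed from the numpy generator so schedulers that expect a torch.Generator stay tied to the same seed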
torch_seed = generator.randint(2147483647)
|
||||
torch_gen = torch.Generator().manual_seed(torch_seed)
|
||||
else:
|
||||
generator = np.random
|
||||
torch_gen = None
|
||||
|
||||
# 0. Check inputs
|
||||
self.check_inputs(prompt, callback_steps)
|
||||
|
||||
if image is None:
|
||||
raise ValueError("`image` input cannot be undefined.")
|
||||
|
||||
# 1. Define call parameters
|
||||
if isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
|
||||
)
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = (
|
||||
guidance_scale > 1.0 and image_guidance_scale >= 1.0
|
||||
)
|
||||
# check if scheduler is in sigmas space
|
||||
scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas")
|
||||
|
||||
# 2. Encode input prompt
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
)
|
||||
|
||||
# 3. Preprocess image
|
||||
image = preprocess(image)
|
||||
height, width = image.shape[-2:]
|
||||
|
||||
# 4. set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 5. Prepare Image latents
|
||||
latents_dtype = prompt_embeds.dtype
|
||||
image = image.astype(latents_dtype)
|
||||
# encode the init image into latents and scale the latents
|
||||
image_latents = self.vae_encoder(sample=image)[0]
|
||||
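# pix2pix guidance uses three batches (text + image, image-only, unconditional), so duplicate the image latents and add zeros for the unconditional branch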
if do_classifier_free_guidance:
|
||||
uncond_image_latents = np.zeros_like(image_latents)
|
||||
image_latents = np.concatenate(
|
||||
(image_latents, image_latents, uncond_image_latents), axis=0
|
||||
)
|
||||
|
||||
# 6. Prepare latent variables
|
||||
latents_dtype = prompt_embeds.dtype
|
||||
latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
|
||||
if latents is None:
|
||||
latents = generator.randn(*latents_shape).astype(latents_dtype)
|
||||
elif latents.shape != latents_shape:
|
||||
raise ValueError(
|
||||
f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}"
|
||||
)
|
||||
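# scale the initial noise by the scheduler's expected starting sigma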
latents = latents * self.scheduler.init_noise_sigma.numpy()
|
||||
|
||||
# 7. Check that shapes of latents and image match the UNet channels
|
||||
num_channels_image = image_latents.shape[1]
|
||||
if 4 + num_channels_image != self.unet_in_channels:
|
||||
raise ValueError(
|
||||
f"Incorrect configuration settings! The config of `pipeline.unet`: expects"
|
||||
f" {self.unet_in_channels} but received `num_channels_latents`: 4 +"
|
||||
f" `num_channels_image`: {num_channels_image} "
|
||||
f" = {4+num_channels_image}. Please verify the config of"
|
||||
" `pipeline.unet` or your `image` input."
|
||||
)
|
||||
|
||||
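# read the dtype of the UNet's `timestep` input from the ONNX graph, defaulting to float32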
timestep_dtype = next(
|
||||
(
|
||||
input.type
|
||||
for input in self.unet.model.get_inputs()
|
||||
if input.name == "timestep"
|
||||
),
|
||||
"tensor(float)",
|
||||
)
|
||||
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
|
||||
|
||||
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta, torch_gen)
|
||||
|
||||
# 9. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# Expand the latents if we are doing classifier free guidance.
|
||||
# The latents are expanded 3 times because for pix2pix the guidance
|
||||
# is applied for both the text and the input image.
|
||||
latent_model_input = (
|
||||
np.concatenate([latents] * 3)
|
||||
if do_classifier_free_guidance
|
||||
else latents
|
||||
)
|
||||
|
||||
scaled_latent_model_input = self.scheduler.scale_model_input(
|
||||
torch.from_numpy(latent_model_input), t
|
||||
)
|
||||
scaled_latent_model_input = scaled_latent_model_input.cpu().numpy()
|
||||
|
||||
scaled_latent_model_input = np.concatenate(
|
||||
[scaled_latent_model_input, image_latents], axis=1
|
||||
)
|
||||
|
||||
# predict the noise residual
|
||||
|
||||
noise_pred = self.unet(
|
||||
sample=scaled_latent_model_input,
|
||||
timestep=np.array([t], dtype=timestep_dtype),
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
)[0]
|
||||
|
||||
# Hack:
|
||||
# For karras style schedulers the model does classifier free guidance using the
|
||||
# predicted_original_sample instead of the noise_pred. So we need to compute the
|
||||
# predicted_original_sample here if we are using a karras style scheduler.
|
||||
if scheduler_is_in_sigma_space:
|
||||
step_index = (self.scheduler.timesteps == t).nonzero().item()
|
||||
sigma = self.scheduler.sigmas[step_index]
|
||||
noise_pred = latent_model_input - sigma.numpy() * noise_pred
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_text, noise_pred_image, noise_pred_uncond = np.split(
|
||||
noise_pred, 3
|
||||
)
|
||||
noise_pred = (
|
||||
noise_pred_uncond
|
||||
+ guidance_scale * (noise_pred_text - noise_pred_image)
|
||||
+ image_guidance_scale * (noise_pred_image - noise_pred_uncond)
|
||||
)
|
||||
|
||||
# Hack:
|
||||
# For karras style schedulers the model does classifier free guidance using the
|
||||
# predicted_original_sample instead of the noise_pred. But the scheduler.step function
|
||||
# expects the noise_pred and computes the predicted_original_sample internally. So we
|
||||
# need to overwrite the noise_pred here such that the value of the computed
|
||||
# predicted_original_sample is correct.
|
||||
if scheduler_is_in_sigma_space:
|
||||
noise_pred = (noise_pred - latents) / (-sigma)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
scheduler_output = self.scheduler.step(
|
||||
noise_pred, t, torch.from_numpy(latents), **extra_step_kwargs
|
||||
)
|
||||
latents = scheduler_output.prev_sample.numpy()
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or (
|
||||
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
||||
):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
# 10. Post-processing
|
||||
image = self.decode_latents(latents)
|
||||
|
||||
# 11. Run safety checker
|
||||
image, has_nsfw_concept = self.run_safety_checker(image)
|
||||
|
||||
# 12. Convert to PIL
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
if not return_dict:
|
||||
return (image, has_nsfw_concept)
|
||||
|
||||
return StableDiffusionPipelineOutput(
|
||||
images=image, nsfw_content_detected=has_nsfw_concept
|
||||
)
|
||||
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`):
|
||||
prompt to be encoded
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`):
|
||||
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
||||
if `guidance_scale` is less than `1`).
|
||||
"""
|
||||
negative_prompt_embeds = None
|
||||
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
||||
|
||||
# get prompt text embeddings
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="np",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(
|
||||
prompt, padding="max_length", return_tensors="np"
|
||||
).input_ids
|
||||
|
||||
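# warn when the prompt was truncated to the tokenizer's maximum sequence length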
if not np.array_equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(
|
||||
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
||||
)
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
|
||||
prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt] * batch_size
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
max_length = text_input_ids.shape[-1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="np",
|
||||
)
|
||||
negative_prompt_embeds = self.text_encoder(
|
||||
input_ids=uncond_input.input_ids.astype(np.int32)
|
||||
)[0]
|
||||
negative_prompt_embeds = np.repeat(
|
||||
negative_prompt_embeds, num_images_per_prompt, axis=0
|
||||
)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
# pix2pix has two negative embeddings, and unlike in other pipelines latents are ordered
|
||||
# [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]
|
||||
|
||||
prompt_embeds = np.concatenate(
|
||||
(prompt_embeds, negative_prompt_embeds, negative_prompt_embeds)
|
||||
)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
def run_safety_checker(self, image):
|
||||
if self.safety_checker is not None:
|
||||
safety_checker_input = self.feature_extractor(
|
||||
self.numpy_to_pil(image), return_tensors="np"
|
||||
).pixel_values.astype(image.dtype)
|
||||
# safety_checker does not support batched inputs yet
|
||||
images, has_nsfw_concept = [], []
|
||||
for i in range(image.shape[0]):
|
||||
image_i, has_nsfw_concept_i = self.safety_checker(
|
||||
clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
|
||||
)
|
||||
images.append(image_i)
|
||||
has_nsfw_concept.append(has_nsfw_concept_i[0])
|
||||
image = np.concatenate(images)
|
||||
else:
|
||||
has_nsfw_concept = None
|
||||
return image, has_nsfw_concept
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta, torch_gen):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(
|
||||
inspect.signature(self.scheduler.step).parameters.keys()
|
||||
)
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(
|
||||
inspect.signature(self.scheduler.step).parameters.keys()
|
||||
)
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = torch_gen
|
||||
return extra_step_kwargs
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
|
||||
def decode_latents(self, latents):
|
||||
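# 0.18215 is the Stable Diffusion VAE scaling factor; undo it before decoding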
latents = 1 / 0.18215 * latents
|
||||
image = np.concatenate(
|
||||
[
|
||||
self.vae_decoder(latent_sample=latents[i : i + 1])[0]
|
||||
for i in range(latents.shape[0])
|
||||
]
|
||||
)
|
||||
image = np.clip(image / 2 + 0.5, 0, 1)
|
||||
image = image.transpose((0, 2, 3, 1))
|
||||
return image
|
||||
|
||||
def check_inputs(self, prompt, callback_steps):
|
||||
if not isinstance(prompt, str) and not isinstance(prompt, list):
|
||||
raise ValueError(
|
||||
f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
|
||||
)
|
||||
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None
|
||||
and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
|
@ -9,13 +9,12 @@ from logging import getLogger
|
|||
from typing import Any, Callable, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import PIL
|
||||
|
||||
from diffusers import DDPMScheduler, OnnxRuntimeModel, StableDiffusionUpscalePipeline
|
||||
import torch
|
||||
from diffusers.pipeline_utils import ImagePipelineOutput
|
||||
|
||||
from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionUpscalePipeline
|
||||
from diffusers.schedulers import DDPMScheduler
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -23,21 +22,6 @@ logger = getLogger(__name__)
|
|||
NUM_LATENT_CHANNELS = 4
|
||||
NUM_UNET_INPUT_CHANNELS = 7
|
||||
|
||||
ORT_TO_NP_TYPE = {
|
||||
"tensor(bool)": np.bool_,
|
||||
"tensor(int8)": np.int8,
|
||||
"tensor(uint8)": np.uint8,
|
||||
"tensor(int16)": np.int16,
|
||||
"tensor(uint16)": np.uint16,
|
||||
"tensor(int32)": np.int32,
|
||||
"tensor(uint32)": np.uint32,
|
||||
"tensor(int64)": np.int64,
|
||||
"tensor(uint64)": np.uint64,
|
||||
"tensor(float16)": np.float16,
|
||||
"tensor(float)": np.float32,
|
||||
"tensor(double)": np.float64,
|
||||
}
|
||||
|
||||
ORT_TO_PT_TYPE = {
|
||||
"float16": torch.float16,
|
||||
"float32": torch.float32,
|
||||
|
@ -65,7 +49,8 @@ def preprocess(image):
|
|||
|
||||
return image
|
||||
|
||||
class FakeConfig():
|
||||
|
||||
class FakeConfig:
|
||||
scaling_factor: float
|
||||
|
||||
def __init__(self) -> None:
|
||||
|
@ -83,10 +68,18 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
|||
scheduler: Any,
|
||||
max_noise_level: int = 350,
|
||||
):
|
||||
if hasattr(vae, "config") == False:
|
||||
if not hasattr(vae, "config"):
|
||||
setattr(vae, "config", FakeConfig())
|
||||
|
||||
super().__init__(vae, text_encoder, tokenizer, unet, low_res_scheduler, scheduler, max_noise_level)
|
||||
super().__init__(
|
||||
vae,
|
||||
text_encoder,
|
||||
tokenizer,
|
||||
unet,
|
||||
low_res_scheduler,
|
||||
scheduler,
|
||||
max_noise_level,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
|
@ -118,7 +111,11 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
|||
|
||||
# 3. Encode input prompt
|
||||
text_embeddings = self._encode_prompt(
|
||||
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
)
|
||||
|
||||
latents_dtype = ORT_TO_PT_TYPE[str(text_embeddings.dtype)]
|
||||
|
@ -133,7 +130,9 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
|||
|
||||
# 5. Add noise to image
|
||||
noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
|
||||
noise = torch.randn(image.shape, generator=generator, device=device, dtype=latents_dtype)
|
||||
noise = torch.randn(
|
||||
image.shape, generator=generator, device=device, dtype=latents_dtype
|
||||
)
|
||||
image = self.low_res_scheduler.add_noise(image, noise, noise_level)
|
||||
|
||||
batch_multiplier = 2 if do_classifier_free_guidance else 1
|
||||
|
@ -168,7 +167,12 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
|||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
timestep_dtype = next(
|
||||
(input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
|
||||
(
|
||||
input.type
|
||||
for input in self.unet.model.get_inputs()
|
||||
if input.name == "timestep"
|
||||
),
|
||||
"tensor(float)",
|
||||
)
|
||||
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
|
||||
|
||||
|
@ -177,10 +181,16 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
|||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = (
|
||||
np.concatenate([latents] * 2)
|
||||
if do_classifier_free_guidance
|
||||
else latents
|
||||
)
|
||||
|
||||
# concat latents, mask, masked_image_latents in the channel dimension
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
latent_model_input = self.scheduler.scale_model_input(
|
||||
latent_model_input, t
|
||||
)
|
||||
latent_model_input = np.concatenate([latent_model_input, image], axis=1)
|
||||
|
||||
# timestep to tensor
|
||||
|
@ -197,7 +207,9 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
|||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||
noise_pred_text - noise_pred_uncond
|
||||
)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(
|
||||
|
@ -205,7 +217,9 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
|||
).prev_sample
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
if i == len(timesteps) - 1 or (
|
||||
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
||||
):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
@ -229,7 +243,14 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
|||
image = image.transpose((0, 2, 3, 1))
|
||||
return image
|
||||
|
||||
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
):
|
||||
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
|
@ -240,10 +261,16 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
|||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
untruncated_ids = self.tokenizer(
|
||||
prompt, padding="longest", return_tensors="pt"
|
||||
).input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||
text_input_ids, untruncated_ids
|
||||
):
|
||||
removed_text = self.tokenizer.batch_decode(
|
||||
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
||||
)
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
|
@ -258,7 +285,9 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
|||
bs_embed, seq_len, _ = text_embeddings.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt)
|
||||
text_embeddings = text_embeddings.reshape(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
text_embeddings = text_embeddings.reshape(
|
||||
bs_embed * num_images_per_prompt, seq_len, -1
|
||||
)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance:
|
||||
|
@ -298,7 +327,9 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
|||
seq_len = uncond_embeddings.shape[1]
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt)
|
||||
uncond_embeddings = uncond_embeddings.reshape(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
uncond_embeddings = uncond_embeddings.reshape(
|
||||
batch_size * num_images_per_prompt, seq_len, -1
|
||||
)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
|
@ -329,7 +329,13 @@ def run_inpaint_pipeline(
|
|||
)
|
||||
|
||||
image = run_upscale_correction(
|
||||
job, server, stage, params, image, upscale=upscale, callback=progress
|
||||
job,
|
||||
server,
|
||||
stage,
|
||||
params,
|
||||
image,
|
||||
upscale=upscale,
|
||||
callback=progress,
|
||||
)
|
||||
|
||||
dest = save_image(server, outputs[0], image)
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
from logging import getLogger
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from ..chain import (
|
||||
ChainPipeline,
|
||||
PipelineStage,
|
||||
correct_codeformer,
|
||||
correct_gfpgan,
|
||||
upscale_bsrgan,
|
||||
|
@ -28,6 +29,8 @@ def run_upscale_correction(
|
|||
*,
|
||||
upscale: UpscaleParams,
|
||||
callback: Optional[ProgressCallback] = None,
|
||||
pre_stages: List[PipelineStage] = None,
|
||||
post_stages: List[PipelineStage] = None,
|
||||
) -> Image.Image:
|
||||
"""
|
||||
This is a convenience method for a chain pipeline that will run upscaling and
|
||||
|
@ -40,6 +43,9 @@ def run_upscale_correction(
|
|||
)
|
||||
|
||||
chain = ChainPipeline()
|
||||
if pre_stages is not None:
|
||||
for stage, params in pre_stages:
|
||||
chain.append((stage, params))
|
||||
|
||||
upscale_stage = None
|
||||
if upscale.scale > 1:
|
||||
|
@ -93,6 +99,10 @@ def run_upscale_correction(
|
|||
else:
|
||||
logger.warn("unknown upscaling order: %s", upscale.upscale_order)
|
||||
|
||||
if post_stages is not None:
|
||||
for stage, params in post_stages:
|
||||
chain.append((stage, params))
|
||||
|
||||
return chain(
|
||||
job,
|
||||
server,
|
||||
|
|
|
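A rough usage sketch for the new `pre_stages`/`post_stages` hooks follows. It assumes the same `job`, `server`, `stage`, `params`, `image`, `upscale`, and `progress` values used by `run_inpaint_pipeline` above, and simply reuses the caller's `stage` params for the extra correction stage (that reuse is an assumption, not part of this commit):

# Sketch only: run GFPGAN correction before the upscale/correction chain.
# Each entry is a (callable, StageParams) tuple, matching how pre_stages and
# post_stages entries are appended to the ChainPipeline.
image = run_upscale_correction(
    job,
    server,
    stage,
    params,
    image,
    upscale=upscale,
    callback=progress,
    pre_stages=[(correct_gfpgan, stage)],
)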
@ -0,0 +1,735 @@
|
|||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.loaders import UNet2DConditionLoadersMixin
|
||||
from diffusers.models.cross_attention import AttnProcessor
|
||||
from diffusers.models.embeddings import (
|
||||
GaussianFourierProjection,
|
||||
TimestepEmbedding,
|
||||
Timesteps,
|
||||
)
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
from diffusers.models.unet_2d_blocks import (
|
||||
CrossAttnDownBlock2D,
|
||||
CrossAttnUpBlock2D,
|
||||
DownBlock2D,
|
||||
UNetMidBlock2DCrossAttn,
|
||||
UNetMidBlock2DSimpleCrossAttn,
|
||||
UpBlock2D,
|
||||
get_down_block,
|
||||
get_up_block,
|
||||
)
|
||||
from diffusers.utils import BaseOutput, logging
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@dataclass
|
||||
class UNet2DConditionOutput(BaseOutput):
|
||||
"""
|
||||
Args:
|
||||
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
|
||||
"""
|
||||
|
||||
sample: torch.FloatTensor
|
||||
|
||||
|
||||
class UNet2DConditionModel_CNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
||||
r"""
|
||||
UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
|
||||
and returns a sample-shaped output.
|
||||
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
|
||||
implements for all the models (such as downloading or saving, etc.)
|
||||
Parameters:
|
||||
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
|
||||
Height and width of input/output sample.
|
||||
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
|
||||
out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
|
||||
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
|
||||
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
|
||||
Whether to flip the sin to cos in the time embedding.
|
||||
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
||||
The tuple of downsample blocks to use.
|
||||
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
|
||||
The mid block type. Choose from `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`, will skip the
|
||||
mid block layer if `None`.
|
||||
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
|
||||
The tuple of upsample blocks to use.
|
||||
only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
|
||||
Whether to include self-attention in the basic transformer blocks, see
|
||||
[`~models.attention.BasicTransformerBlock`].
|
||||
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
||||
The tuple of output channels for each block.
|
||||
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
|
||||
downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
|
||||
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
|
||||
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
||||
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
|
||||
If `None`, it will skip the normalization and activation layers in post-processing
|
||||
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
|
||||
cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
|
||||
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
|
||||
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
||||
for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
|
||||
class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately
|
||||
summed with the time embeddings. Choose from `None`, `"timestep"`, `"identity"`, or `"projection"`.
|
||||
num_class_embeds (`int`, *optional*, defaults to None):
|
||||
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
||||
class conditioning with `class_embed_type` equal to `None`.
|
||||
time_embedding_type (`str`, *optional*, default to `positional`):
|
||||
The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
|
||||
timestep_post_act (`str, *optional*, default to `None`):
|
||||
The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
|
||||
time_cond_proj_dim (`int`, *optional*, default to `None`):
|
||||
The dimension of `cond_proj` layer in timestep embedding.
|
||||
conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
|
||||
conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
|
||||
projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
|
||||
using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
sample_size: Optional[int] = None,
|
||||
in_channels: int = 4,
|
||||
out_channels: int = 4,
|
||||
center_input_sample: bool = False,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
),
|
||||
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
||||
up_block_types: Tuple[str] = (
|
||||
"UpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
),
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
||||
layers_per_block: int = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
act_fn: str = "silu",
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: int = 1280,
|
||||
attention_head_dim: Union[int, Tuple[int]] = 8,
|
||||
dual_cross_attention: bool = False,
|
||||
use_linear_projection: bool = False,
|
||||
class_embed_type: Optional[str] = None,
|
||||
num_class_embeds: Optional[int] = None,
|
||||
upcast_attention: bool = False,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
time_embedding_type: str = "positional",
|
||||
timestep_post_act: Optional[str] = None,
|
||||
time_cond_proj_dim: Optional[int] = None,
|
||||
conv_in_kernel: int = 3,
|
||||
conv_out_kernel: int = 3,
|
||||
projection_class_embeddings_input_dim: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.sample_size = sample_size
|
||||
|
||||
# Check inputs
|
||||
if len(down_block_types) != len(up_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
||||
)
|
||||
|
||||
if len(block_out_channels) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
if not isinstance(only_cross_attention, bool) and len(
|
||||
only_cross_attention
|
||||
) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(
|
||||
down_block_types
|
||||
):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
# input
|
||||
conv_in_padding = (conv_in_kernel - 1) // 2
|
||||
self.conv_in = nn.Conv2d(
|
||||
in_channels,
|
||||
block_out_channels[0],
|
||||
kernel_size=conv_in_kernel,
|
||||
padding=conv_in_padding,
|
||||
)
|
||||
|
||||
# time
|
||||
if time_embedding_type == "fourier":
|
||||
time_embed_dim = block_out_channels[0] * 2
|
||||
if time_embed_dim % 2 != 0:
|
||||
raise ValueError(
|
||||
f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}."
|
||||
)
|
||||
self.time_proj = GaussianFourierProjection(
|
||||
time_embed_dim // 2,
|
||||
set_W_to_weight=False,
|
||||
log=False,
|
||||
flip_sin_to_cos=flip_sin_to_cos,
|
||||
)
|
||||
timestep_input_dim = time_embed_dim
|
||||
elif time_embedding_type == "positional":
|
||||
time_embed_dim = block_out_channels[0] * 4
|
||||
|
||||
self.time_proj = Timesteps(
|
||||
block_out_channels[0], flip_sin_to_cos, freq_shift
|
||||
)
|
||||
timestep_input_dim = block_out_channels[0]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{time_embedding_type} does not exist. Pleaes make sure to use one of `fourier` or `positional`."
|
||||
)
|
||||
|
||||
self.time_embedding = TimestepEmbedding(
|
||||
timestep_input_dim,
|
||||
time_embed_dim,
|
||||
act_fn=act_fn,
|
||||
post_act_fn=timestep_post_act,
|
||||
cond_proj_dim=time_cond_proj_dim,
|
||||
)
|
||||
|
||||
# class embedding
|
||||
if class_embed_type is None and num_class_embeds is not None:
|
||||
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
||||
elif class_embed_type == "timestep":
|
||||
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||
elif class_embed_type == "identity":
|
||||
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
||||
elif class_embed_type == "projection":
|
||||
if projection_class_embeddings_input_dim is None:
|
||||
raise ValueError(
|
||||
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
|
||||
)
|
||||
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
|
||||
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
|
||||
# 2. it projects from an arbitrary input dimension.
|
||||
#
|
||||
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
|
||||
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
|
||||
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
|
||||
self.class_embedding = TimestepEmbedding(
|
||||
projection_class_embeddings_input_dim, time_embed_dim
|
||||
)
|
||||
else:
|
||||
self.class_embedding = None
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
|
||||
if isinstance(only_cross_attention, bool):
|
||||
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
||||
|
||||
if isinstance(attention_head_dim, int):
|
||||
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
||||
|
||||
# down
|
||||
output_channel = block_out_channels[0]
|
||||
for i, down_block_type in enumerate(down_block_types):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
down_block = get_down_block(
|
||||
down_block_type,
|
||||
num_layers=layers_per_block,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
add_downsample=not is_final_block,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=attention_head_dim[i],
|
||||
downsample_padding=downsample_padding,
|
||||
dual_cross_attention=dual_cross_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention[i],
|
||||
upcast_attention=upcast_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
# mid
|
||||
if mid_block_type == "UNetMidBlock2DCrossAttn":
|
||||
self.mid_block = UNetMidBlock2DCrossAttn(
|
||||
in_channels=block_out_channels[-1],
|
||||
temb_channels=time_embed_dim,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=attention_head_dim[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
dual_cross_attention=dual_cross_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
upcast_attention=upcast_attention,
|
||||
)
|
||||
elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
|
||||
self.mid_block = UNetMidBlock2DSimpleCrossAttn(
|
||||
in_channels=block_out_channels[-1],
|
||||
temb_channels=time_embed_dim,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=attention_head_dim[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
elif mid_block_type is None:
|
||||
self.mid_block = None
|
||||
else:
|
||||
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
|
||||
|
||||
# count how many layers upsample the images
|
||||
self.num_upsamplers = 0
|
||||
|
||||
# up
|
||||
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||
reversed_attention_head_dim = list(reversed(attention_head_dim))
|
||||
only_cross_attention = list(reversed(only_cross_attention))
|
||||
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
for i, up_block_type in enumerate(up_block_types):
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
prev_output_channel = output_channel
|
||||
output_channel = reversed_block_out_channels[i]
|
||||
input_channel = reversed_block_out_channels[
|
||||
min(i + 1, len(block_out_channels) - 1)
|
||||
]
|
||||
|
||||
# add upsample block for all BUT final layer
|
||||
if not is_final_block:
|
||||
add_upsample = True
|
||||
self.num_upsamplers += 1
|
||||
else:
|
||||
add_upsample = False
|
||||
|
||||
up_block = get_up_block(
|
||||
up_block_type,
|
||||
num_layers=layers_per_block + 1,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
add_upsample=add_upsample,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=reversed_attention_head_dim[i],
|
||||
dual_cross_attention=dual_cross_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention[i],
|
||||
upcast_attention=upcast_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
|
||||
# out
|
||||
if norm_num_groups is not None:
|
||||
self.conv_norm_out = nn.GroupNorm(
|
||||
num_channels=block_out_channels[0],
|
||||
num_groups=norm_num_groups,
|
||||
eps=norm_eps,
|
||||
)
|
||||
self.conv_act = nn.SiLU()
|
||||
else:
|
||||
self.conv_norm_out = None
|
||||
self.conv_act = None
|
||||
|
||||
conv_out_padding = (conv_out_kernel - 1) // 2
|
||||
self.conv_out = nn.Conv2d(
|
||||
block_out_channels[0],
|
||||
out_channels,
|
||||
kernel_size=conv_out_kernel,
|
||||
padding=conv_out_padding,
|
||||
)
|
||||
|
||||
@property
|
||||
def attn_processors(self) -> Dict[str, AttnProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model,
|
||||
indexed by their weight names.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(
|
||||
name: str, module: torch.nn.Module, processors: Dict[str, AttnProcessor]
|
||||
):
|
||||
if hasattr(module, "set_processor"):
|
||||
processors[f"{name}.processor"] = module.processor
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
def set_attn_processor(
|
||||
self, processor: Union[AttnProcessor, Dict[str, AttnProcessor]]
|
||||
):
|
||||
r"""
|
||||
Parameters:
|
||||
processor (`dict` of `AttnProcessor` or `AttnProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
of **all** `CrossAttention` layers.
|
||||
In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
def set_attention_slice(self, slice_size):
|
||||
r"""
|
||||
Enable sliced attention computation.
|
||||
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
|
||||
in several steps. This is useful to save some memory in exchange for a small speed decrease.
|
||||
Args:
|
||||
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
||||
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
|
||||
`"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
|
||||
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
||||
must be a multiple of `slice_size`.
|
||||
"""
|
||||
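# walk the module tree and collect the sliceable head dimension of every attention block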
sliceable_head_dims = []
|
||||
|
||||
def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
|
||||
if hasattr(module, "set_attention_slice"):
|
||||
sliceable_head_dims.append(module.sliceable_head_dim)
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_retrieve_slicable_dims(child)
|
||||
|
||||
# retrieve number of attention layers
|
||||
for module in self.children():
|
||||
fn_recursive_retrieve_slicable_dims(module)
|
||||
|
||||
num_slicable_layers = len(sliceable_head_dims)
|
||||
|
||||
if slice_size == "auto":
|
||||
# half the attention head size is usually a good trade-off between
|
||||
# speed and memory
|
||||
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
||||
elif slice_size == "max":
|
||||
# make smallest slice possible
|
||||
slice_size = num_slicable_layers * [1]
|
||||
|
||||
slice_size = (
|
||||
num_slicable_layers * [slice_size]
|
||||
if not isinstance(slice_size, list)
|
||||
else slice_size
|
||||
)
|
||||
|
||||
if len(slice_size) != len(sliceable_head_dims):
|
||||
raise ValueError(
|
||||
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
||||
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
||||
)
|
||||
|
||||
for i in range(len(slice_size)):
|
||||
size = slice_size[i]
|
||||
dim = sliceable_head_dims[i]
|
||||
if size is not None and size > dim:
|
||||
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
||||
|
||||
# Recursively walk through all the children.
|
||||
# Any children which exposes the set_attention_slice method
|
||||
# gets the message
|
||||
def fn_recursive_set_attention_slice(
|
||||
module: torch.nn.Module, slice_size: List[int]
|
||||
):
|
||||
if hasattr(module, "set_attention_slice"):
|
||||
module.set_attention_slice(slice_size.pop())
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_set_attention_slice(child, slice_size)
|
||||
|
||||
reversed_slice_size = list(reversed(slice_size))
|
||||
for module in self.children():
|
||||
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(
|
||||
module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)
|
||||
):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
# class_labels: Optional[torch.Tensor] = None,
|
||||
# timestep_cond: Optional[torch.Tensor] = None,
|
||||
# attention_mask: Optional[torch.Tensor] = None,
|
||||
# cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
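# ControlNet residuals arrive as twelve separate optional tensors (rather than a tuple) and are
# repacked below; keeping them as flat named arguments is likely meant to ease ONNX export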
down_block_additional_residuals0: Optional[torch.Tensor] = None,
|
||||
down_block_additional_residuals1: Optional[torch.Tensor] = None,
|
||||
down_block_additional_residuals2: Optional[torch.Tensor] = None,
|
||||
down_block_additional_residuals3: Optional[torch.Tensor] = None,
|
||||
down_block_additional_residuals4: Optional[torch.Tensor] = None,
|
||||
down_block_additional_residuals5: Optional[torch.Tensor] = None,
|
||||
down_block_additional_residuals6: Optional[torch.Tensor] = None,
|
||||
down_block_additional_residuals7: Optional[torch.Tensor] = None,
|
||||
down_block_additional_residuals8: Optional[torch.Tensor] = None,
|
||||
down_block_additional_residuals9: Optional[torch.Tensor] = None,
|
||||
down_block_additional_residuals10: Optional[torch.Tensor] = None,
|
||||
down_block_additional_residuals11: Optional[torch.Tensor] = None,
|
||||
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = False,
|
||||
) -> Union[UNet2DConditionOutput, Tuple]:
|
||||
r"""
|
||||
Args:
|
||||
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
|
||||
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
|
||||
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
|
||||
Returns:
|
||||
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
||||
[`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
||||
returning a tuple, the first element is the sample tensor.
|
||||
"""
|
||||
# By default samples have to be at least a multiple of the overall upsampling factor.
|
||||
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
|
||||
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
||||
# on the fly if necessary.
|
||||
|
||||
class_labels = None
|
||||
timestep_cond = None
|
||||
attention_mask = None
|
||||
cross_attention_kwargs = None
|
||||
|
||||
down_block_additional_residuals = (
|
||||
down_block_additional_residuals0,
|
||||
down_block_additional_residuals1,
|
||||
down_block_additional_residuals2,
|
||||
down_block_additional_residuals3,
|
||||
down_block_additional_residuals4,
|
||||
down_block_additional_residuals5,
|
||||
down_block_additional_residuals6,
|
||||
down_block_additional_residuals7,
|
||||
down_block_additional_residuals8,
|
||||
down_block_additional_residuals9,
|
||||
down_block_additional_residuals10,
|
||||
down_block_additional_residuals11,
|
||||
)
|
||||
default_overall_up_factor = 2**self.num_upsamplers
|
||||
|
||||
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
||||
forward_upsample_size = False
|
||||
upsample_size = None
|
||||
|
||||
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
||||
logger.info("Forward upsample size to force interpolation output size.")
|
||||
forward_upsample_size = True
|
||||
|
||||
# prepare attention_mask
|
||||
if attention_mask is not None:
|
||||
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
|
||||
# 0. center input if necessary
|
||||
if self.config.center_input_sample:
|
||||
sample = 2 * sample - 1.0
|
||||
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
||||
# This would be a good case for the `match` statement (Python 3.10+)
|
||||
is_mps = sample.device.type == "mps"
|
||||
if isinstance(timestep, float):
|
||||
dtype = torch.float32 if is_mps else torch.float64
|
||||
else:
|
||||
dtype = torch.int32 if is_mps else torch.int64
|
||||
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
||||
elif len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timesteps = timesteps.expand(sample.shape[0])
|
||||
|
||||
t_emb = self.time_proj(timesteps)
|
||||
|
||||
# timesteps does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||
# there might be better ways to encapsulate this.
|
||||
t_emb = t_emb.to(dtype=self.dtype)
|
||||
|
||||
emb = self.time_embedding(t_emb, timestep_cond)
|
||||
|
||||
if self.class_embedding is not None:
|
||||
if class_labels is None:
|
||||
raise ValueError(
|
||||
"class_labels should be provided when num_class_embeds > 0"
|
||||
)
|
||||
|
||||
if self.config.class_embed_type == "timestep":
|
||||
class_labels = self.time_proj(class_labels)
|
||||
|
||||
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
||||
emb = emb + class_emb
|
||||
|
||||
# 2. pre-process
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
# 3. down
|
||||
down_block_res_samples = (sample,)
|
||||
for downsample_block in self.down_blocks:
|
||||
if (
|
||||
hasattr(downsample_block, "has_cross_attention")
|
||||
and downsample_block.has_cross_attention
|
||||
):
|
||||
sample, res_samples = downsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
)
|
||||
else:
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
||||
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
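# add the ControlNet down-block residuals to the matching UNet skip connections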
if down_block_additional_residuals is not None:
|
||||
new_down_block_res_samples = ()
|
||||
|
||||
for down_block_res_sample, down_block_additional_residual in zip(
|
||||
down_block_res_samples, down_block_additional_residuals
|
||||
):
|
||||
down_block_res_sample += down_block_additional_residual
|
||||
new_down_block_res_samples += (down_block_res_sample,)
|
||||
|
||||
down_block_res_samples = new_down_block_res_samples
|
||||
|
||||
# 4. mid
|
||||
if self.mid_block is not None:
|
||||
sample = self.mid_block(
|
||||
sample,
|
||||
emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
)
|
||||
|
||||
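# add the ControlNet mid-block residual to the mid-block output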
if mid_block_additional_residual is not None:
|
||||
sample += mid_block_additional_residual
|
||||
|
||||
# 5. up
|
||||
for i, upsample_block in enumerate(self.up_blocks):
|
||||
is_final_block = i == len(self.up_blocks) - 1
|
||||
|
||||
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
||||
down_block_res_samples = down_block_res_samples[
|
||||
: -len(upsample_block.resnets)
|
||||
]
|
||||
|
||||
# if we have not reached the final block and need to forward the
|
||||
# upsample size, we do it here
|
||||
if not is_final_block and forward_upsample_size:
|
||||
upsample_size = down_block_res_samples[-1].shape[2:]
|
||||
|
||||
if (
|
||||
hasattr(upsample_block, "has_cross_attention")
|
||||
and upsample_block.has_cross_attention
|
||||
):
|
||||
sample = upsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
res_hidden_states_tuple=res_samples,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
upsample_size=upsample_size,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
else:
|
||||
sample = upsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
res_hidden_states_tuple=res_samples,
|
||||
upsample_size=upsample_size,
|
||||
)
|
||||
|
||||
# 6. post-process
|
||||
if self.conv_norm_out:
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample)
|
||||
|
||||
if not return_dict:
|
||||
return (sample,)
|
||||
|
||||
return UNet2DConditionOutput(sample=sample)
|
|
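As a hedged sketch of how the flattened residual signature above might be fed (all names here are hypothetical; a real ControlNet produces twelve down-block residuals and one mid-block residual for the default block layout):

# Hypothetical call: unpack a ControlNet's residuals into the flattened arguments.
# `controlnet`, `unet_cnet`, `sample`, `t`, and `hidden_states` are placeholders.
down_residuals, mid_residual = controlnet(sample, t, encoder_hidden_states=hidden_states)
noise_pred = unet_cnet(
    sample,
    t,
    hidden_states,
    *down_residuals,  # fills down_block_additional_residuals0 .. 11 in order
    mid_block_additional_residual=mid_residual,
)[0]  # return_dict defaults to False, so the sample comes back in a tuple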
@ -240,6 +240,13 @@ def load_models(server: ServerContext) -> None:
|
|||
"stable-diffusion-*",
|
||||
],
|
||||
)
|
||||
diffusion_models.extend(
|
||||
list_model_globs(
|
||||
server,
|
||||
["*"],
|
||||
base_path=path.join(server.model_path, "diffusion"),
|
||||
)
|
||||
)
|
||||
logger.debug("loaded diffusion models from disk: %s", diffusion_models)
|
||||
|
||||
correction_models = list_model_globs(
|
||||
|
@ -248,6 +255,13 @@ def load_models(server: ServerContext) -> None:
|
|||
"correction-*",
|
||||
],
|
||||
)
|
||||
correction_models.extend(
|
||||
list_model_globs(
|
||||
server,
|
||||
["*"],
|
||||
base_path=path.join(server.model_path, "correction"),
|
||||
)
|
||||
)
|
||||
logger.debug("loaded correction models from disk: %s", correction_models)
|
||||
|
||||
upscaling_models = list_model_globs(
|
||||
|
@ -256,9 +270,26 @@ def load_models(server: ServerContext) -> None:
|
|||
"upscaling-*",
|
||||
],
|
||||
)
|
||||
upscaling_models.extend(
|
||||
list_model_globs(
|
||||
server,
|
||||
["*"],
|
||||
base_path=path.join(server.model_path, "upscaling"),
|
||||
)
|
||||
)
|
||||
logger.debug("loaded upscaling models from disk: %s", upscaling_models)
|
||||
|
||||
# additional networks
|
||||
control_models = list_model_globs(
|
||||
server,
|
||||
[
|
||||
"*",
|
||||
],
|
||||
base_path=path.join(server.model_path, "control"),
|
||||
)
|
||||
logger.debug("loaded ControlNet models from disk: %s", control_models)
|
||||
network_models.extend([NetworkModel(model, "control") for model in control_models])
|
||||
|
||||
inversion_models = list_model_globs(
|
||||
server,
|
||||
[
|
||||
|
|