"""ONNX Runtime variant of the Stable Diffusion x4 upscale pipeline."""

from typing import Any, Callable, List, Optional, Union

import PIL
import torch

from diffusers import (
    DDPMScheduler,
    OnnxRuntimeModel,
    StableDiffusionUpscalePipeline,
)
from diffusers.utils import logging

logger = logging.get_logger(__name__)

class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
    def __init__(
        self,
        vae: OnnxRuntimeModel,
        text_encoder: OnnxRuntimeModel,
        tokenizer: Any,
        unet: OnnxRuntimeModel,
        low_res_scheduler: DDPMScheduler,
        scheduler: Any,
        max_noise_level: int = 350,
    ):
        super().__init__(vae, text_encoder, tokenizer, unet, low_res_scheduler, scheduler, max_noise_level)

    def __call__(
        self,
        *args,
        **kwargs,
    ):
        # forward everything to the parent pipeline and return its output
        return super().__call__(*args, **kwargs)
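
    # Overridden from the parent pipeline: the ONNX text encoder expects int32 input_ids
    # and returns NumPy arrays rather than torch tensors, so prompt encoding is redone here.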
    def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
        batch_size = len(prompt) if isinstance(prompt, list) else 1

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        # the ONNX text encoder returns NumPy arrays; convert back to a torch tensor
        # so the repeat/view/cat operations below work
        text_embeddings = self.text_encoder(input_ids=text_input_ids.int().to(device))
        text_embeddings = torch.from_numpy(text_embeddings[0])

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = text_input_ids.shape[-1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            # same int32-in / NumPy-out round trip as for the positive prompt above
            uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.int().to(device))
            uncond_embeddings = torch.from_numpy(uncond_embeddings[0])

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = uncond_embeddings.shape[1]
            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        return text_embeddings
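

# A minimal usage sketch, not part of the pipeline itself. Everything below is an
# assumption for illustration: "./stable-diffusion-x4-upscaler-onnx" stands in for a
# hypothetical local ONNX export of stabilityai/stable-diffusion-x4-upscaler laid out in
# the usual diffusers subfolders, and the scheduler choices mirror the torch pipeline.
if __name__ == "__main__":
    from PIL import Image
    from transformers import CLIPTokenizer

    from diffusers import DDIMScheduler

    model_dir = "./stable-diffusion-x4-upscaler-onnx"  # hypothetical export location

    pipe = OnnxStableDiffusionUpscalePipeline(
        vae=OnnxRuntimeModel.from_pretrained(f"{model_dir}/vae"),
        text_encoder=OnnxRuntimeModel.from_pretrained(f"{model_dir}/text_encoder"),
        tokenizer=CLIPTokenizer.from_pretrained(model_dir, subfolder="tokenizer"),
        unet=OnnxRuntimeModel.from_pretrained(f"{model_dir}/unet"),
        low_res_scheduler=DDPMScheduler.from_pretrained(model_dir, subfolder="low_res_scheduler"),
        scheduler=DDIMScheduler.from_pretrained(model_dir, subfolder="scheduler"),
    )

    low_res = Image.open("low_res.png").convert("RGB")  # any small RGB input image
    result = pipe(prompt="a white cat", image=low_res, num_inference_steps=75)
    result.images[0].save("upscaled.png")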