fix(api): update LPW pipeline (#298)
This commit is contained in:
parent
c0ece2453d
commit
93fcfd1422
|
@ -30,6 +30,7 @@ This is an incomplete list of new and interesting features, with links to the us
|
||||||
- [hosted on Github Pages](https://ssube.github.io/onnx-web), from your CDN, or locally
|
- [hosted on Github Pages](https://ssube.github.io/onnx-web), from your CDN, or locally
|
||||||
- [persists your recent images and progress as you change tabs](docs/user-guide.md#image-history)
|
- [persists your recent images and progress as you change tabs](docs/user-guide.md#image-history)
|
||||||
- queue up multiple images and retry errors
|
- queue up multiple images and retry errors
|
||||||
|
- translations available for English, French, German, and Spanish (please open an issue for more)
|
||||||
- supports many `diffusers` pipelines
|
- supports many `diffusers` pipelines
|
||||||
- [txt2img](docs/user-guide.md#txt2img-tab)
|
- [txt2img](docs/user-guide.md#txt2img-tab)
|
||||||
- [img2img](docs/user-guide.md#img2img-tab)
|
- [img2img](docs/user-guide.md#img2img-tab)
|
||||||
|
|
|
@ -8,15 +8,15 @@ import re
|
||||||
from typing import Callable, List, Optional, Union
|
from typing import Callable, List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import PIL
|
||||||
import torch
|
import torch
|
||||||
|
from packaging import version
|
||||||
|
from transformers import CLIPImageProcessor, CLIPTokenizer
|
||||||
|
|
||||||
import diffusers
|
import diffusers
|
||||||
import PIL
|
|
||||||
from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline, SchedulerMixin
|
from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline, SchedulerMixin
|
||||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||||
from diffusers.utils import deprecate, logging
|
from diffusers.utils import logging
|
||||||
from packaging import version
|
|
||||||
from transformers import CLIPFeatureExtractor, CLIPTokenizer
|
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -201,14 +201,14 @@ def get_prompts_with_weights(pipe, prompt: List[str], max_length: int):
|
||||||
return tokens, weights
|
return tokens, weights
|
||||||
|
|
||||||
|
|
||||||
def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
|
def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77):
|
||||||
r"""
|
r"""
|
||||||
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
|
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
|
||||||
"""
|
"""
|
||||||
max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
|
max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
|
||||||
weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
|
weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
|
||||||
for i in range(len(tokens)):
|
for i in range(len(tokens)):
|
||||||
tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
|
tokens[i] = [bos] + tokens[i] + [pad] * (max_length - 1 - len(tokens[i]) - 1) + [eos]
|
||||||
if no_boseos_middle:
|
if no_boseos_middle:
|
||||||
weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
|
weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
|
||||||
else:
|
else:
|
||||||
|
@ -347,12 +347,14 @@ def get_weighted_text_embeddings(
|
||||||
# pad the length of tokens and weights
|
# pad the length of tokens and weights
|
||||||
bos = pipe.tokenizer.bos_token_id
|
bos = pipe.tokenizer.bos_token_id
|
||||||
eos = pipe.tokenizer.eos_token_id
|
eos = pipe.tokenizer.eos_token_id
|
||||||
|
pad = getattr(pipe.tokenizer, "pad_token_id", eos)
|
||||||
prompt_tokens, prompt_weights = pad_tokens_and_weights(
|
prompt_tokens, prompt_weights = pad_tokens_and_weights(
|
||||||
prompt_tokens,
|
prompt_tokens,
|
||||||
prompt_weights,
|
prompt_weights,
|
||||||
max_length,
|
max_length,
|
||||||
bos,
|
bos,
|
||||||
eos,
|
eos,
|
||||||
|
pad,
|
||||||
no_boseos_middle=no_boseos_middle,
|
no_boseos_middle=no_boseos_middle,
|
||||||
chunk_length=pipe.tokenizer.model_max_length,
|
chunk_length=pipe.tokenizer.model_max_length,
|
||||||
)
|
)
|
||||||
|
@ -364,6 +366,7 @@ def get_weighted_text_embeddings(
|
||||||
max_length,
|
max_length,
|
||||||
bos,
|
bos,
|
||||||
eos,
|
eos,
|
||||||
|
pad,
|
||||||
no_boseos_middle=no_boseos_middle,
|
no_boseos_middle=no_boseos_middle,
|
||||||
chunk_length=pipe.tokenizer.model_max_length,
|
chunk_length=pipe.tokenizer.model_max_length,
|
||||||
)
|
)
|
||||||
|
@ -408,7 +411,7 @@ def get_weighted_text_embeddings(
|
||||||
|
|
||||||
def preprocess_image(image):
|
def preprocess_image(image):
|
||||||
w, h = image.size
|
w, h = image.size
|
||||||
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32
|
||||||
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
|
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
|
||||||
image = np.array(image).astype(np.float32) / 255.0
|
image = np.array(image).astype(np.float32) / 255.0
|
||||||
image = image[None].transpose(0, 3, 1, 2)
|
image = image[None].transpose(0, 3, 1, 2)
|
||||||
|
@ -418,7 +421,7 @@ def preprocess_image(image):
|
||||||
def preprocess_mask(mask, scale_factor=8):
|
def preprocess_mask(mask, scale_factor=8):
|
||||||
mask = mask.convert("L")
|
mask = mask.convert("L")
|
||||||
w, h = mask.size
|
w, h = mask.size
|
||||||
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32
|
||||||
mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
|
mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
|
||||||
mask = np.array(mask).astype(np.float32) / 255.0
|
mask = np.array(mask).astype(np.float32) / 255.0
|
||||||
mask = np.tile(mask, (4, 1, 1))
|
mask = np.tile(mask, (4, 1, 1))
|
||||||
|
@ -446,7 +449,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
|
||||||
unet: OnnxRuntimeModel,
|
unet: OnnxRuntimeModel,
|
||||||
scheduler: SchedulerMixin,
|
scheduler: SchedulerMixin,
|
||||||
safety_checker: OnnxRuntimeModel,
|
safety_checker: OnnxRuntimeModel,
|
||||||
feature_extractor: CLIPFeatureExtractor,
|
feature_extractor: CLIPImageProcessor,
|
||||||
requires_safety_checker: bool = True,
|
requires_safety_checker: bool = True,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
|
@ -473,7 +476,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
|
||||||
unet: OnnxRuntimeModel,
|
unet: OnnxRuntimeModel,
|
||||||
scheduler: SchedulerMixin,
|
scheduler: SchedulerMixin,
|
||||||
safety_checker: OnnxRuntimeModel,
|
safety_checker: OnnxRuntimeModel,
|
||||||
feature_extractor: CLIPFeatureExtractor,
|
feature_extractor: CLIPImageProcessor,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
vae_encoder=vae_encoder,
|
vae_encoder=vae_encoder,
|
||||||
|
@ -672,7 +675,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
|
||||||
return_dict: bool = True,
|
return_dict: bool = True,
|
||||||
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
|
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
|
||||||
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
||||||
callback_steps: Optional[int] = 1,
|
callback_steps: int = 1,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
|
@ -749,10 +752,6 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
|
||||||
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
||||||
(nsfw) content, according to the `safety_checker`.
|
(nsfw) content, according to the `safety_checker`.
|
||||||
"""
|
"""
|
||||||
message = "Please use `image` instead of `init_image`."
|
|
||||||
init_image = deprecate("init_image", "0.14.0", message, take_from=kwargs)
|
|
||||||
image = init_image or image
|
|
||||||
|
|
||||||
# 0. Default height and width to unet
|
# 0. Default height and width to unet
|
||||||
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
||||||
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
||||||
|
@ -887,7 +886,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
|
||||||
output_type: Optional[str] = "pil",
|
output_type: Optional[str] = "pil",
|
||||||
return_dict: bool = True,
|
return_dict: bool = True,
|
||||||
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
|
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
|
||||||
callback_steps: Optional[int] = 1,
|
callback_steps: int = 1,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
|
@ -978,7 +977,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
|
||||||
output_type: Optional[str] = "pil",
|
output_type: Optional[str] = "pil",
|
||||||
return_dict: bool = True,
|
return_dict: bool = True,
|
||||||
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
|
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
|
||||||
callback_steps: Optional[int] = 1,
|
callback_steps: int = 1,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
|
@ -1070,7 +1069,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
|
||||||
output_type: Optional[str] = "pil",
|
output_type: Optional[str] = "pil",
|
||||||
return_dict: bool = True,
|
return_dict: bool = True,
|
||||||
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
|
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
|
||||||
callback_steps: Optional[int] = 1,
|
callback_steps: int = 1,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
|
|
Loading…
Reference in New Issue