fix(api): update LPW pipeline (#298)
This commit is contained in:
parent
c0ece2453d
commit
93fcfd1422
|
@ -30,6 +30,7 @@ This is an incomplete list of new and interesting features, with links to the us
|
|||
- [hosted on Github Pages](https://ssube.github.io/onnx-web), from your CDN, or locally
|
||||
- [persists your recent images and progress as you change tabs](docs/user-guide.md#image-history)
|
||||
- queue up multiple images and retry errors
|
||||
- translations available for English, French, German, and Spanish (please open an issue for more)
|
||||
- supports many `diffusers` pipelines
|
||||
- [txt2img](docs/user-guide.md#txt2img-tab)
|
||||
- [img2img](docs/user-guide.md#img2img-tab)
|
||||
|
|
|
@ -8,15 +8,15 @@ import re
|
|||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
from packaging import version
|
||||
from transformers import CLIPImageProcessor, CLIPTokenizer
|
||||
|
||||
import diffusers
|
||||
import PIL
|
||||
from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline, SchedulerMixin
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.utils import deprecate, logging
|
||||
from packaging import version
|
||||
from transformers import CLIPFeatureExtractor, CLIPTokenizer
|
||||
from diffusers.utils import logging
|
||||
|
||||
|
||||
try:
|
||||
|
@ -201,14 +201,14 @@ def get_prompts_with_weights(pipe, prompt: List[str], max_length: int):
|
|||
return tokens, weights
|
||||
|
||||
|
||||
def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
|
||||
def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77):
|
||||
r"""
|
||||
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
|
||||
"""
|
||||
max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
|
||||
weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
|
||||
for i in range(len(tokens)):
|
||||
tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
|
||||
tokens[i] = [bos] + tokens[i] + [pad] * (max_length - 1 - len(tokens[i]) - 1) + [eos]
|
||||
if no_boseos_middle:
|
||||
weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
|
||||
else:
|
||||
|
@ -347,12 +347,14 @@ def get_weighted_text_embeddings(
|
|||
# pad the length of tokens and weights
|
||||
bos = pipe.tokenizer.bos_token_id
|
||||
eos = pipe.tokenizer.eos_token_id
|
||||
pad = getattr(pipe.tokenizer, "pad_token_id", eos)
|
||||
prompt_tokens, prompt_weights = pad_tokens_and_weights(
|
||||
prompt_tokens,
|
||||
prompt_weights,
|
||||
max_length,
|
||||
bos,
|
||||
eos,
|
||||
pad,
|
||||
no_boseos_middle=no_boseos_middle,
|
||||
chunk_length=pipe.tokenizer.model_max_length,
|
||||
)
|
||||
|
@ -364,6 +366,7 @@ def get_weighted_text_embeddings(
|
|||
max_length,
|
||||
bos,
|
||||
eos,
|
||||
pad,
|
||||
no_boseos_middle=no_boseos_middle,
|
||||
chunk_length=pipe.tokenizer.model_max_length,
|
||||
)
|
||||
|
@ -408,7 +411,7 @@ def get_weighted_text_embeddings(
|
|||
|
||||
def preprocess_image(image):
|
||||
w, h = image.size
|
||||
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
||||
w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32
|
||||
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
|
@ -418,7 +421,7 @@ def preprocess_image(image):
|
|||
def preprocess_mask(mask, scale_factor=8):
|
||||
mask = mask.convert("L")
|
||||
w, h = mask.size
|
||||
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
||||
w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32
|
||||
mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
|
||||
mask = np.array(mask).astype(np.float32) / 255.0
|
||||
mask = np.tile(mask, (4, 1, 1))
|
||||
|
@ -446,7 +449,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
|
|||
unet: OnnxRuntimeModel,
|
||||
scheduler: SchedulerMixin,
|
||||
safety_checker: OnnxRuntimeModel,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__(
|
||||
|
@ -473,7 +476,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
|
|||
unet: OnnxRuntimeModel,
|
||||
scheduler: SchedulerMixin,
|
||||
safety_checker: OnnxRuntimeModel,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
):
|
||||
super().__init__(
|
||||
vae_encoder=vae_encoder,
|
||||
|
@ -672,7 +675,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
|
|||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
|
||||
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
callback_steps: int = 1,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
|
@ -749,10 +752,6 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
|
|||
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
||||
(nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
message = "Please use `image` instead of `init_image`."
|
||||
init_image = deprecate("init_image", "0.14.0", message, take_from=kwargs)
|
||||
image = init_image or image
|
||||
|
||||
# 0. Default height and width to unet
|
||||
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
||||
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
||||
|
@ -887,7 +886,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
|
|||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
callback_steps: int = 1,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
|
@ -978,7 +977,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
|
|||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
callback_steps: int = 1,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
|
@ -1070,7 +1069,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
|
|||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
callback_steps: int = 1,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
|
|
Loading…
Reference in New Issue