1
0
Fork 0

fix(api): update LPW pipeline (#298)

This commit is contained in:
Sean Sube 2023-03-28 17:34:23 -05:00
parent c0ece2453d
commit 93fcfd1422
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 18 additions and 18 deletions

View File

@ -30,6 +30,7 @@ This is an incomplete list of new and interesting features, with links to the us
- [hosted on Github Pages](https://ssube.github.io/onnx-web), from your CDN, or locally
- [persists your recent images and progress as you change tabs](docs/user-guide.md#image-history)
- queue up multiple images and retry errors
- translations available for English, French, German, and Spanish (please open an issue for more)
- supports many `diffusers` pipelines
- [txt2img](docs/user-guide.md#txt2img-tab)
- [img2img](docs/user-guide.md#img2img-tab)

View File

@ -8,15 +8,15 @@ import re
from typing import Callable, List, Optional, Union
import numpy as np
import PIL
import torch
from packaging import version
from transformers import CLIPImageProcessor, CLIPTokenizer
import diffusers
import PIL
from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline, SchedulerMixin
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.utils import deprecate, logging
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTokenizer
from diffusers.utils import logging
try:
@ -201,14 +201,14 @@ def get_prompts_with_weights(pipe, prompt: List[str], max_length: int):
return tokens, weights
def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77):
r"""
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
"""
max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
for i in range(len(tokens)):
tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
tokens[i] = [bos] + tokens[i] + [pad] * (max_length - 1 - len(tokens[i]) - 1) + [eos]
if no_boseos_middle:
weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
else:
@ -347,12 +347,14 @@ def get_weighted_text_embeddings(
# pad the length of tokens and weights
bos = pipe.tokenizer.bos_token_id
eos = pipe.tokenizer.eos_token_id
pad = getattr(pipe.tokenizer, "pad_token_id", eos)
prompt_tokens, prompt_weights = pad_tokens_and_weights(
prompt_tokens,
prompt_weights,
max_length,
bos,
eos,
pad,
no_boseos_middle=no_boseos_middle,
chunk_length=pipe.tokenizer.model_max_length,
)
@ -364,6 +366,7 @@ def get_weighted_text_embeddings(
max_length,
bos,
eos,
pad,
no_boseos_middle=no_boseos_middle,
chunk_length=pipe.tokenizer.model_max_length,
)
@ -408,7 +411,7 @@ def get_weighted_text_embeddings(
def preprocess_image(image):
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
@ -418,7 +421,7 @@ def preprocess_image(image):
def preprocess_mask(mask, scale_factor=8):
mask = mask.convert("L")
w, h = mask.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32
mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
mask = np.array(mask).astype(np.float32) / 255.0
mask = np.tile(mask, (4, 1, 1))
@ -446,7 +449,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
unet: OnnxRuntimeModel,
scheduler: SchedulerMixin,
safety_checker: OnnxRuntimeModel,
feature_extractor: CLIPFeatureExtractor,
feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True,
):
super().__init__(
@ -473,7 +476,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
unet: OnnxRuntimeModel,
scheduler: SchedulerMixin,
safety_checker: OnnxRuntimeModel,
feature_extractor: CLIPFeatureExtractor,
feature_extractor: CLIPImageProcessor,
):
super().__init__(
vae_encoder=vae_encoder,
@ -672,7 +675,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
return_dict: bool = True,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
is_cancelled_callback: Optional[Callable[[], bool]] = None,
callback_steps: Optional[int] = 1,
callback_steps: int = 1,
**kwargs,
):
r"""
@ -749,10 +752,6 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
message = "Please use `image` instead of `init_image`."
init_image = deprecate("init_image", "0.14.0", message, take_from=kwargs)
image = init_image or image
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
@ -887,7 +886,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
callback_steps: Optional[int] = 1,
callback_steps: int = 1,
**kwargs,
):
r"""
@ -978,7 +977,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
callback_steps: Optional[int] = 1,
callback_steps: int = 1,
**kwargs,
):
r"""
@ -1070,7 +1069,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
callback_steps: Optional[int] = 1,
callback_steps: int = 1,
**kwargs,
):
r"""