1
0
Fork 0

dedupe new latent resizing code

This commit is contained in:
Sean Sube 2023-12-03 11:11:23 -06:00
parent c42ca9ca38
commit 10fab12cd0
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
8 changed files with 64 additions and 33 deletions

View File

@ -5,9 +5,9 @@ import numpy as np
import torch
from PIL import Image
from ..constants import LATENT_FACTOR
from ..diffusers.load import load_pipeline
from ..diffusers.utils import (
LATENT_FACTOR,
encode_prompt,
get_latents_from_seed,
get_tile_latents,

View File

@ -1,2 +1,5 @@
# NOTE(review): these look like the file names used for an exported ONNX graph
# and its external weights — confirm against the conversion code.
ONNX_MODEL = "model.onnx"
ONNX_WEIGHTS = "weights.pb"
# Divisor between image pixel dimensions and latent dimensions: the panorama
# pipelines crop latents to (height // LATENT_FACTOR, width // LATENT_FACTOR).
LATENT_FACTOR = 8
# NOTE(review): presumably the size of the channel axis of latent arrays — confirm.
LATENT_CHANNELS = 4

View File

@ -9,11 +9,11 @@ from optimum.onnxruntime import ( # ORTStableDiffusionXLInpaintPipeline,
)
from transformers import CLIPTokenizer
from ..constants import ONNX_MODEL
from ..constants import LATENT_FACTOR, ONNX_MODEL
from ..convert.diffusion.lora import blend_loras, buffer_external_data_tensors
from ..convert.diffusion.textual_inversion import blend_textual_inversions
from ..diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
from ..diffusers.utils import LATENT_FACTOR, expand_prompt
from ..diffusers.utils import expand_prompt
from ..params import DeviceParams, ImageParams
from ..server import ModelTypes, ServerContext
from ..torch_before_ort import InferenceSession

View File

@ -12,8 +12,6 @@ from ...server import ServerContext
logger = getLogger(__name__)
LATENT_CHANNELS = 4
class VAEWrapper(object):
def __init__(

View File

@ -28,13 +28,14 @@ from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
from transformers import CLIPImageProcessor, CLIPTokenizer
from ...chain.tile import make_tile_mask
from ...constants import LATENT_CHANNELS, LATENT_FACTOR
from ...params import Size
from ..utils import (
LATENT_CHANNELS,
LATENT_FACTOR,
expand_latents,
parse_regions,
random_seed,
repair_nan,
resize_latent_shape,
)
logger = logging.get_logger(__name__)
@ -563,13 +564,13 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
# panorama additions
views, resize = self.get_views(height, width, self.window, self.stride)
count = np.zeros((latents.shape[0], latents.shape[1], *resize))
value = np.zeros((latents.shape[0], latents.shape[1], *resize))
count = np.zeros(resize_latent_shape(latents, resize))
value = np.zeros(resize_latent_shape(latents, resize))
# adjust latents
latents = expand_latents(
latents,
generator.randint(np.iinfo(np.int32).max),
random_seed(generator),
Size(resize[1], resize[0]),
sigma=self.scheduler.init_noise_sigma,
)
@ -726,7 +727,9 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
callback(i, t, latents)
# remove extra margins
latents = latents[:, :, 0:(height // 8), 0:(width // 8)]
latents = latents[
:, :, 0 : (height // LATENT_FACTOR), 0 : (width // LATENT_FACTOR)
]
latents = np.clip(latents, -4, +4)
latents = 1 / 0.18215 * latents
@ -975,13 +978,13 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
# panorama additions
views, resize = self.get_views(height, width, self.window, self.stride)
count = np.zeros((latents.shape[0], latents.shape[1], *resize))
value = np.zeros((latents.shape[0], latents.shape[1], *resize))
count = np.zeros(resize_latent_shape(latents, resize))
value = np.zeros(resize_latent_shape(latents, resize))
# adjust latents
latents = expand_latents(
latents,
generator.randint(np.iinfo(np.int32).max),
random_seed(generator),
Size(resize[1], resize[0]),
sigma=self.scheduler.init_noise_sigma,
)
@ -1041,7 +1044,9 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
callback(i, t, latents)
# remove extra margins
latents = latents[:, :, 0:(height // 8), 0:(width // 8)]
latents = latents[
:, :, 0 : (height // LATENT_FACTOR), 0 : (width // LATENT_FACTOR)
]
latents = 1 / 0.18215 * latents
# image = self.vae_decoder(latent_sample=latents)[0]
@ -1294,13 +1299,13 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
# panorama additions
views, resize = self.get_views(height, width, self.window, self.stride)
count = np.zeros((latents.shape[0], latents.shape[1], *resize))
value = np.zeros((latents.shape[0], latents.shape[1], *resize))
count = np.zeros(resize_latent_shape(latents, resize))
value = np.zeros(resize_latent_shape(latents, resize))
# adjust latents
latents = expand_latents(
latents,
generator.randint(np.iinfo(np.int32).max),
random_seed(generator),
Size(resize[1], resize[0]),
sigma=self.scheduler.init_noise_sigma,
)
@ -1367,7 +1372,9 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
callback(i, t, latents)
# remove extra margins
latents = latents[:, :, 0:(height // 8), 0:(width // 8)]
latents = latents[
:, :, 0 : (height // LATENT_FACTOR), 0 : (width // LATENT_FACTOR)
]
latents = 1 / 0.18215 * latents
# image = self.vae_decoder(latent_sample=latents)[0]

View File

@ -15,8 +15,15 @@ from optimum.pipelines.diffusers.pipeline_stable_diffusion_xl_img2img import (
from optimum.pipelines.diffusers.pipeline_utils import rescale_noise_cfg
from ...chain.tile import make_tile_mask
from ...constants import LATENT_FACTOR
from ...params import Size
from ..utils import LATENT_FACTOR, expand_latents, parse_regions, repair_nan
from ..utils import (
expand_latents,
parse_regions,
random_seed,
repair_nan,
resize_latent_shape,
)
logger = logging.getLogger(__name__)
@ -387,13 +394,13 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
# 8. Panorama additions
views, resize = self.get_views(height, width, self.window, self.stride)
count = np.zeros((latents.shape[0], latents.shape[1], *resize))
value = np.zeros((latents.shape[0], latents.shape[1], *resize))
count = np.zeros(resize_latent_shape(latents, resize))
value = np.zeros(resize_latent_shape(latents, resize))
# adjust latents
latents = expand_latents(
latents,
generator.randint(np.iinfo(np.int32).max),
random_seed(generator),
Size(resize[1], resize[0]),
sigma=self.scheduler.init_noise_sigma,
)
@ -573,7 +580,9 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
callback(i, t, latents)
# remove extra margins
latents = latents[:, :, 0:(height // 8), 0:(width // 8)]
latents = latents[
:, :, 0 : (height // LATENT_FACTOR), 0 : (width // LATENT_FACTOR)
]
if output_type == "latent":
image = latents
@ -810,12 +819,12 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
# 8. Panorama additions
views, resize = self.get_views(height, width, self.window, self.stride)
count = np.zeros((latents.shape[0], latents.shape[1], *resize))
value = np.zeros((latents.shape[0], latents.shape[1], *resize))
count = np.zeros(resize_latent_shape(latents, resize))
value = np.zeros(resize_latent_shape(latents, resize))
latents = expand_latents(
latents,
generator.randint(np.iinfo(np.int32).max),
random_seed(generator),
Size(resize[1], resize[0]),
sigma=self.scheduler.init_noise_sigma,
)
@ -889,7 +898,9 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
callback(i, t, latents)
# remove extra margins
latents = latents[:, :, 0:(height // 8), 0:(width // 8)]
latents = latents[
:, :, 0 : (height // LATENT_FACTOR), 0 : (width // LATENT_FACTOR)
]
if output_type == "latent":
image = latents

View File

@ -9,12 +9,11 @@ import numpy as np
import torch
from diffusers import OnnxStableDiffusionPipeline
from ..constants import LATENT_CHANNELS, LATENT_FACTOR
from ..params import ImageParams, Size
logger = getLogger(__name__)
LATENT_CHANNELS = 4
LATENT_FACTOR = 8
MAX_TOKENS_PER_GROUP = 77
ANY_TOKEN = compile(r"\<([^\>]*)\>")
@ -261,6 +260,13 @@ def get_inversions_from_prompt(prompt: str) -> Tuple[str, List[Tuple[str, float]
return get_tokens_from_prompt(prompt, INVERSION_TOKEN)
def random_seed(generator=None) -> int:
    """
    Draw a fresh seed value from a random source.

    The source only needs a ``randint(high)`` method; when ``generator`` is
    None, the global ``np.random`` module is used. ``randint`` is called with
    the int32 maximum as its single argument.
    """
    source = np.random if generator is None else generator
    upper_bound = np.iinfo(np.int32).max
    return source.randint(upper_bound)
def get_latents_from_seed(seed: int, size: Size, batch: int = 1) -> np.ndarray:
"""
From https://www.travelneil.com/stable-diffusion-updates.html.
@ -288,6 +294,13 @@ def expand_latents(
return extra_latents * np.float64(sigma)
def resize_latent_shape(
    latents: np.ndarray,
    size: "Size | Tuple[int, int]",
) -> Tuple[int, int, int, int]:
    """
    Build a 4-element shape that keeps the first two dimensions of ``latents``
    and replaces the last two with the requested height and width.

    ``size`` may be an object exposing ``height`` and ``width`` attributes, or
    a two-item ``(height, width)`` sequence. The sequence order matches the
    ``(latents.shape[0], latents.shape[1], *resize)`` expression that this
    helper replaced in the panorama pipelines, which pass the sequence form.
    """
    try:
        height, width = size.height, size.width
    except AttributeError:
        # plain sequences have no named fields; unpack as (height, width)
        height, width = size

    return (latents.shape[0], latents.shape[1], height, width)
def get_tile_latents(
full_latents: np.ndarray,
seed: int,

View File

@ -1,10 +1,10 @@
from logging import getLogger
from typing import Dict, Optional, Tuple
import numpy as np
from flask import request
from ..diffusers.load import get_available_pipelines, get_pipeline_schedulers
from ..diffusers.utils import random_seed
from ..params import (
Border,
DeviceParams,
@ -149,8 +149,7 @@ def build_params(
seed = int(data.get("seed", -1))
if seed == -1:
# this one can safely use np.random because it produces a single value
seed = np.random.randint(np.iinfo(np.int32).max)
seed = random_seed()
params = ImageParams(
model_path,