1
0
Fork 0

fix(api): allow random seed in reseed regions

This commit is contained in:
Sean Sube 2023-11-11 14:37:23 -06:00
parent c0a4fb6cad
commit 798fa5fc6d
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
5 changed files with 37 additions and 20 deletions

View File

@ -7,6 +7,7 @@ from PIL import Image
from ..diffusers.load import load_pipeline from ..diffusers.load import load_pipeline
from ..diffusers.utils import ( from ..diffusers.utils import (
LATENT_FACTOR,
encode_prompt, encode_prompt,
get_latents_from_seed, get_latents_from_seed,
get_tile_latents, get_tile_latents,
@ -78,8 +79,12 @@ class SourceTxt2ImgStage(BaseStage):
latents = get_tile_latents(latents, int(params.seed), latent_size, dims) latents = get_tile_latents(latents, int(params.seed), latent_size, dims)
# reseed latents as needed # reseed latents as needed
reseed_rng = np.random.default_rng(params.seed)
prompt, reseed = parse_reseed(prompt) prompt, reseed = parse_reseed(prompt)
for top, left, bottom, right, region_seed in reseed: for top, left, bottom, right, region_seed in reseed:
if region_seed == -1:
region_seed = reseed_rng.integers(2**32)
logger.debug( logger.debug(
"reseed latent region: [:, :, %s:%s, %s:%s] with %s", "reseed latent region: [:, :, %s:%s, %s:%s] with %s",
top, top,
@ -89,8 +94,13 @@ class SourceTxt2ImgStage(BaseStage):
region_seed, region_seed,
) )
latents[ latents[
:, :, top // 8 : bottom // 8, left // 8 : right // 8 :,
] = get_latents_from_seed(region_seed, Size(right - left, bottom - top), params.batch) :,
top // LATENT_FACTOR : bottom // LATENT_FACTOR,
left // LATENT_FACTOR : right // LATENT_FACTOR,
] = get_latents_from_seed(
region_seed, Size(right - left, bottom - top), params.batch
)
pipe_type = params.get_valid_pipeline("txt2img") pipe_type = params.get_valid_pipeline("txt2img")
pipe = load_pipeline( pipe = load_pipeline(

View File

@ -19,7 +19,7 @@ from ..constants import ONNX_MODEL
from ..convert.diffusion.lora import blend_loras, buffer_external_data_tensors from ..convert.diffusion.lora import blend_loras, buffer_external_data_tensors
from ..convert.diffusion.textual_inversion import blend_textual_inversions from ..convert.diffusion.textual_inversion import blend_textual_inversions
from ..diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline from ..diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
from ..diffusers.utils import expand_prompt from ..diffusers.utils import LATENT_FACTOR, expand_prompt
from ..params import DeviceParams, ImageParams from ..params import DeviceParams, ImageParams
from ..server import ModelTypes, ServerContext from ..server import ModelTypes, ServerContext
from ..torch_before_ort import InferenceSession from ..torch_before_ort import InferenceSession
@ -264,11 +264,13 @@ def load_pipeline(
if hasattr(pipe, vae): if hasattr(pipe, vae):
vae_model = getattr(pipe, vae) vae_model = getattr(pipe, vae)
vae_model.set_tiled(tiled=params.tiled_vae) vae_model.set_tiled(tiled=params.tiled_vae)
vae_model.set_window_size(params.vae_tile // 8, params.vae_overlap) vae_model.set_window_size(
params.vae_tile // LATENT_FACTOR, params.vae_overlap
)
# update panorama params # update panorama params
if params.is_panorama(): if params.is_panorama():
unet_stride = (params.unet_tile * (1 - params.unet_overlap)) // 8 unet_stride = (params.unet_tile * (1 - params.unet_overlap)) // LATENT_FACTOR
logger.debug( logger.debug(
"setting panorama window parameters: %s/%s for UNet, %s/%s for VAE", "setting panorama window parameters: %s/%s for UNet, %s/%s for VAE",
params.unet_tile, params.unet_tile,
@ -276,7 +278,7 @@ def load_pipeline(
params.vae_tile, params.vae_tile,
params.vae_overlap, params.vae_overlap,
) )
pipe.set_window_size(params.unet_tile // 8, unet_stride) pipe.set_window_size(params.unet_tile // LATENT_FACTOR, unet_stride)
run_gc([device]) run_gc([device])

View File

@ -28,7 +28,7 @@ from transformers import CLIPImageProcessor, CLIPTokenizer
from onnx_web.chain.tile import make_tile_mask from onnx_web.chain.tile import make_tile_mask
from ..utils import parse_regions from ..utils import LATENT_CHANNELS, LATENT_FACTOR, parse_regions
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@ -512,7 +512,12 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
# get the initial random noise unless the user supplied it # get the initial random noise unless the user supplied it
latents_dtype = prompt_embeds.dtype latents_dtype = prompt_embeds.dtype
latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8) latents_shape = (
batch_size * num_images_per_prompt,
LATENT_CHANNELS,
height // LATENT_FACTOR,
width // LATENT_FACTOR,
)
if latents is None: if latents is None:
latents = generator.randn(*latents_shape).astype(latents_dtype) latents = generator.randn(*latents_shape).astype(latents_dtype)
elif latents.shape != latents_shape: elif latents.shape != latents_shape:
@ -612,10 +617,10 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
) )
# convert coordinates to latent space # convert coordinates to latent space
h_start = top // 8 h_start = top // LATENT_FACTOR
h_end = bottom // 8 h_end = bottom // LATENT_FACTOR
w_start = left // 8 w_start = left // LATENT_FACTOR
w_end = right // 8 w_end = right // LATENT_FACTOR
# get the latents corresponding to the current view coordinates # get the latents corresponding to the current view coordinates
latents_for_region = latents[:, :, h_start:h_end, w_start:w_end] latents_for_region = latents[:, :, h_start:h_end, w_start:w_end]
@ -1170,8 +1175,8 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
latents_shape = ( latents_shape = (
batch_size * num_images_per_prompt, batch_size * num_images_per_prompt,
num_channels_latents, num_channels_latents,
height // 8, height // LATENT_FACTOR,
width // 8, width // LATENT_FACTOR,
) )
latents_dtype = prompt_embeds.dtype latents_dtype = prompt_embeds.dtype
if latents is None: if latents is None:

View File

@ -14,7 +14,7 @@ from optimum.pipelines.diffusers.pipeline_utils import preprocess, rescale_noise
from onnx_web.chain.tile import make_tile_mask from onnx_web.chain.tile import make_tile_mask
from ..utils import parse_regions from ..utils import LATENT_FACTOR, parse_regions
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -457,10 +457,10 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
) )
# convert coordinates to latent space # convert coordinates to latent space
h_start = top // 8 h_start = top // LATENT_FACTOR
h_end = bottom // 8 h_end = bottom // LATENT_FACTOR
w_start = left // 8 w_start = left // LATENT_FACTOR
w_end = right // 8 w_end = right // LATENT_FACTOR
# get the latents corresponding to the current view coordinates # get the latents corresponding to the current view coordinates
latents_for_region = latents[:, :, h_start:h_end, w_start:w_end] latents_for_region = latents[:, :, h_start:h_end, w_start:w_end]

View File

@ -24,7 +24,7 @@ WILDCARD_TOKEN = compile(r"__([-/\\\w]+)__")
REGION_TOKEN = compile( REGION_TOKEN = compile(
r"\<region:(\d+):(\d+):(\d+):(\d+):(-?[\.|\d]+):(-?[\.|\d]+):([^\>]+)\>" r"\<region:(\d+):(\d+):(\d+):(\d+):(-?[\.|\d]+):(-?[\.|\d]+):([^\>]+)\>"
) )
RESEED_TOKEN = compile(r"\<reseed:(\d+):(\d+):(\d+):(\d+):(\d+)\>") RESEED_TOKEN = compile(r"\<reseed:(\d+):(\d+):(\d+):(\d+):(-?\d+)\>")
INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}") INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}")
ALTERNATIVE_RANGE = compile(r"\(([^\)]+)\)") ALTERNATIVE_RANGE = compile(r"\(([^\)]+)\)")