fix(api): allow random seed in reseed regions
This commit is contained in:
parent
c0a4fb6cad
commit
798fa5fc6d
|
@@ -7,6 +7,7 @@ from PIL import Image
|
||||||
|
|
||||||
from ..diffusers.load import load_pipeline
|
from ..diffusers.load import load_pipeline
|
||||||
from ..diffusers.utils import (
|
from ..diffusers.utils import (
|
||||||
|
LATENT_FACTOR,
|
||||||
encode_prompt,
|
encode_prompt,
|
||||||
get_latents_from_seed,
|
get_latents_from_seed,
|
||||||
get_tile_latents,
|
get_tile_latents,
|
||||||
|
@@ -78,8 +79,12 @@ class SourceTxt2ImgStage(BaseStage):
|
||||||
latents = get_tile_latents(latents, int(params.seed), latent_size, dims)
|
latents = get_tile_latents(latents, int(params.seed), latent_size, dims)
|
||||||
|
|
||||||
# reseed latents as needed
|
# reseed latents as needed
|
||||||
|
reseed_rng = np.random.default_rng(params.seed)
|
||||||
prompt, reseed = parse_reseed(prompt)
|
prompt, reseed = parse_reseed(prompt)
|
||||||
for top, left, bottom, right, region_seed in reseed:
|
for top, left, bottom, right, region_seed in reseed:
|
||||||
|
if region_seed == -1:
|
||||||
|
region_seed = reseed_rng.integers(2**32)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"reseed latent region: [:, :, %s:%s, %s:%s] with %s",
|
"reseed latent region: [:, :, %s:%s, %s:%s] with %s",
|
||||||
top,
|
top,
|
||||||
|
@@ -89,8 +94,13 @@ class SourceTxt2ImgStage(BaseStage):
|
||||||
region_seed,
|
region_seed,
|
||||||
)
|
)
|
||||||
latents[
|
latents[
|
||||||
:, :, top // 8 : bottom // 8, left // 8 : right // 8
|
:,
|
||||||
] = get_latents_from_seed(region_seed, Size(right - left, bottom - top), params.batch)
|
:,
|
||||||
|
top // LATENT_FACTOR : bottom // LATENT_FACTOR,
|
||||||
|
left // LATENT_FACTOR : right // LATENT_FACTOR,
|
||||||
|
] = get_latents_from_seed(
|
||||||
|
region_seed, Size(right - left, bottom - top), params.batch
|
||||||
|
)
|
||||||
|
|
||||||
pipe_type = params.get_valid_pipeline("txt2img")
|
pipe_type = params.get_valid_pipeline("txt2img")
|
||||||
pipe = load_pipeline(
|
pipe = load_pipeline(
|
||||||
|
|
|
@@ -19,7 +19,7 @@ from ..constants import ONNX_MODEL
|
||||||
from ..convert.diffusion.lora import blend_loras, buffer_external_data_tensors
|
from ..convert.diffusion.lora import blend_loras, buffer_external_data_tensors
|
||||||
from ..convert.diffusion.textual_inversion import blend_textual_inversions
|
from ..convert.diffusion.textual_inversion import blend_textual_inversions
|
||||||
from ..diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
|
from ..diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
|
||||||
from ..diffusers.utils import expand_prompt
|
from ..diffusers.utils import LATENT_FACTOR, expand_prompt
|
||||||
from ..params import DeviceParams, ImageParams
|
from ..params import DeviceParams, ImageParams
|
||||||
from ..server import ModelTypes, ServerContext
|
from ..server import ModelTypes, ServerContext
|
||||||
from ..torch_before_ort import InferenceSession
|
from ..torch_before_ort import InferenceSession
|
||||||
|
@@ -264,11 +264,13 @@ def load_pipeline(
|
||||||
if hasattr(pipe, vae):
|
if hasattr(pipe, vae):
|
||||||
vae_model = getattr(pipe, vae)
|
vae_model = getattr(pipe, vae)
|
||||||
vae_model.set_tiled(tiled=params.tiled_vae)
|
vae_model.set_tiled(tiled=params.tiled_vae)
|
||||||
vae_model.set_window_size(params.vae_tile // 8, params.vae_overlap)
|
vae_model.set_window_size(
|
||||||
|
params.vae_tile // LATENT_FACTOR, params.vae_overlap
|
||||||
|
)
|
||||||
|
|
||||||
# update panorama params
|
# update panorama params
|
||||||
if params.is_panorama():
|
if params.is_panorama():
|
||||||
unet_stride = (params.unet_tile * (1 - params.unet_overlap)) // 8
|
unet_stride = (params.unet_tile * (1 - params.unet_overlap)) // LATENT_FACTOR
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"setting panorama window parameters: %s/%s for UNet, %s/%s for VAE",
|
"setting panorama window parameters: %s/%s for UNet, %s/%s for VAE",
|
||||||
params.unet_tile,
|
params.unet_tile,
|
||||||
|
@@ -276,7 +278,7 @@ def load_pipeline(
|
||||||
params.vae_tile,
|
params.vae_tile,
|
||||||
params.vae_overlap,
|
params.vae_overlap,
|
||||||
)
|
)
|
||||||
pipe.set_window_size(params.unet_tile // 8, unet_stride)
|
pipe.set_window_size(params.unet_tile // LATENT_FACTOR, unet_stride)
|
||||||
|
|
||||||
run_gc([device])
|
run_gc([device])
|
||||||
|
|
||||||
|
|
|
@@ -28,7 +28,7 @@ from transformers import CLIPImageProcessor, CLIPTokenizer
|
||||||
|
|
||||||
from onnx_web.chain.tile import make_tile_mask
|
from onnx_web.chain.tile import make_tile_mask
|
||||||
|
|
||||||
from ..utils import parse_regions
|
from ..utils import LATENT_CHANNELS, LATENT_FACTOR, parse_regions
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
@@ -512,7 +512,12 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
||||||
|
|
||||||
# get the initial random noise unless the user supplied it
|
# get the initial random noise unless the user supplied it
|
||||||
latents_dtype = prompt_embeds.dtype
|
latents_dtype = prompt_embeds.dtype
|
||||||
latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
|
latents_shape = (
|
||||||
|
batch_size * num_images_per_prompt,
|
||||||
|
LATENT_CHANNELS,
|
||||||
|
height // LATENT_FACTOR,
|
||||||
|
width // LATENT_FACTOR,
|
||||||
|
)
|
||||||
if latents is None:
|
if latents is None:
|
||||||
latents = generator.randn(*latents_shape).astype(latents_dtype)
|
latents = generator.randn(*latents_shape).astype(latents_dtype)
|
||||||
elif latents.shape != latents_shape:
|
elif latents.shape != latents_shape:
|
||||||
|
@@ -612,10 +617,10 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
||||||
)
|
)
|
||||||
|
|
||||||
# convert coordinates to latent space
|
# convert coordinates to latent space
|
||||||
h_start = top // 8
|
h_start = top // LATENT_FACTOR
|
||||||
h_end = bottom // 8
|
h_end = bottom // LATENT_FACTOR
|
||||||
w_start = left // 8
|
w_start = left // LATENT_FACTOR
|
||||||
w_end = right // 8
|
w_end = right // LATENT_FACTOR
|
||||||
|
|
||||||
# get the latents corresponding to the current view coordinates
|
# get the latents corresponding to the current view coordinates
|
||||||
latents_for_region = latents[:, :, h_start:h_end, w_start:w_end]
|
latents_for_region = latents[:, :, h_start:h_end, w_start:w_end]
|
||||||
|
@@ -1170,8 +1175,8 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
||||||
latents_shape = (
|
latents_shape = (
|
||||||
batch_size * num_images_per_prompt,
|
batch_size * num_images_per_prompt,
|
||||||
num_channels_latents,
|
num_channels_latents,
|
||||||
height // 8,
|
height // LATENT_FACTOR,
|
||||||
width // 8,
|
width // LATENT_FACTOR,
|
||||||
)
|
)
|
||||||
latents_dtype = prompt_embeds.dtype
|
latents_dtype = prompt_embeds.dtype
|
||||||
if latents is None:
|
if latents is None:
|
||||||
|
|
|
@@ -14,7 +14,7 @@ from optimum.pipelines.diffusers.pipeline_utils import preprocess, rescale_noise
|
||||||
|
|
||||||
from onnx_web.chain.tile import make_tile_mask
|
from onnx_web.chain.tile import make_tile_mask
|
||||||
|
|
||||||
from ..utils import parse_regions
|
from ..utils import LATENT_FACTOR, parse_regions
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@@ -457,10 +457,10 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
)
|
)
|
||||||
|
|
||||||
# convert coordinates to latent space
|
# convert coordinates to latent space
|
||||||
h_start = top // 8
|
h_start = top // LATENT_FACTOR
|
||||||
h_end = bottom // 8
|
h_end = bottom // LATENT_FACTOR
|
||||||
w_start = left // 8
|
w_start = left // LATENT_FACTOR
|
||||||
w_end = right // 8
|
w_end = right // LATENT_FACTOR
|
||||||
|
|
||||||
# get the latents corresponding to the current view coordinates
|
# get the latents corresponding to the current view coordinates
|
||||||
latents_for_region = latents[:, :, h_start:h_end, w_start:w_end]
|
latents_for_region = latents[:, :, h_start:h_end, w_start:w_end]
|
||||||
|
|
|
@@ -24,7 +24,7 @@ WILDCARD_TOKEN = compile(r"__([-/\\\w]+)__")
|
||||||
REGION_TOKEN = compile(
|
REGION_TOKEN = compile(
|
||||||
r"\<region:(\d+):(\d+):(\d+):(\d+):(-?[\.|\d]+):(-?[\.|\d]+):([^\>]+)\>"
|
r"\<region:(\d+):(\d+):(\d+):(\d+):(-?[\.|\d]+):(-?[\.|\d]+):([^\>]+)\>"
|
||||||
)
|
)
|
||||||
RESEED_TOKEN = compile(r"\<reseed:(\d+):(\d+):(\d+):(\d+):(\d+)\>")
|
RESEED_TOKEN = compile(r"\<reseed:(\d+):(\d+):(\d+):(\d+):(-?\d+)\>")
|
||||||
|
|
||||||
INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}")
|
INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}")
|
||||||
ALTERNATIVE_RANGE = compile(r"\(([^\)]+)\)")
|
ALTERNATIVE_RANGE = compile(r"\(([^\)]+)\)")
|
||||||
|
|
Loading…
Reference in New Issue