fix(api): allow random seed in reseed regions
This commit is contained in:
parent
c0a4fb6cad
commit
798fa5fc6d
|
@ -7,6 +7,7 @@ from PIL import Image
|
|||
|
||||
from ..diffusers.load import load_pipeline
|
||||
from ..diffusers.utils import (
|
||||
LATENT_FACTOR,
|
||||
encode_prompt,
|
||||
get_latents_from_seed,
|
||||
get_tile_latents,
|
||||
|
@ -78,8 +79,12 @@ class SourceTxt2ImgStage(BaseStage):
|
|||
latents = get_tile_latents(latents, int(params.seed), latent_size, dims)
|
||||
|
||||
# reseed latents as needed
|
||||
reseed_rng = np.random.default_rng(params.seed)
|
||||
prompt, reseed = parse_reseed(prompt)
|
||||
for top, left, bottom, right, region_seed in reseed:
|
||||
if region_seed == -1:
|
||||
region_seed = reseed_rng.integers(2**32)
|
||||
|
||||
logger.debug(
|
||||
"reseed latent region: [:, :, %s:%s, %s:%s] with %s",
|
||||
top,
|
||||
|
@ -89,8 +94,13 @@ class SourceTxt2ImgStage(BaseStage):
|
|||
region_seed,
|
||||
)
|
||||
latents[
|
||||
:, :, top // 8 : bottom // 8, left // 8 : right // 8
|
||||
] = get_latents_from_seed(region_seed, Size(right - left, bottom - top), params.batch)
|
||||
:,
|
||||
:,
|
||||
top // LATENT_FACTOR : bottom // LATENT_FACTOR,
|
||||
left // LATENT_FACTOR : right // LATENT_FACTOR,
|
||||
] = get_latents_from_seed(
|
||||
region_seed, Size(right - left, bottom - top), params.batch
|
||||
)
|
||||
|
||||
pipe_type = params.get_valid_pipeline("txt2img")
|
||||
pipe = load_pipeline(
|
||||
|
|
|
@ -19,7 +19,7 @@ from ..constants import ONNX_MODEL
|
|||
from ..convert.diffusion.lora import blend_loras, buffer_external_data_tensors
|
||||
from ..convert.diffusion.textual_inversion import blend_textual_inversions
|
||||
from ..diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
|
||||
from ..diffusers.utils import expand_prompt
|
||||
from ..diffusers.utils import LATENT_FACTOR, expand_prompt
|
||||
from ..params import DeviceParams, ImageParams
|
||||
from ..server import ModelTypes, ServerContext
|
||||
from ..torch_before_ort import InferenceSession
|
||||
|
@ -264,11 +264,13 @@ def load_pipeline(
|
|||
if hasattr(pipe, vae):
|
||||
vae_model = getattr(pipe, vae)
|
||||
vae_model.set_tiled(tiled=params.tiled_vae)
|
||||
vae_model.set_window_size(params.vae_tile // 8, params.vae_overlap)
|
||||
vae_model.set_window_size(
|
||||
params.vae_tile // LATENT_FACTOR, params.vae_overlap
|
||||
)
|
||||
|
||||
# update panorama params
|
||||
if params.is_panorama():
|
||||
unet_stride = (params.unet_tile * (1 - params.unet_overlap)) // 8
|
||||
unet_stride = (params.unet_tile * (1 - params.unet_overlap)) // LATENT_FACTOR
|
||||
logger.debug(
|
||||
"setting panorama window parameters: %s/%s for UNet, %s/%s for VAE",
|
||||
params.unet_tile,
|
||||
|
@ -276,7 +278,7 @@ def load_pipeline(
|
|||
params.vae_tile,
|
||||
params.vae_overlap,
|
||||
)
|
||||
pipe.set_window_size(params.unet_tile // 8, unet_stride)
|
||||
pipe.set_window_size(params.unet_tile // LATENT_FACTOR, unet_stride)
|
||||
|
||||
run_gc([device])
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@ from transformers import CLIPImageProcessor, CLIPTokenizer
|
|||
|
||||
from onnx_web.chain.tile import make_tile_mask
|
||||
|
||||
from ..utils import parse_regions
|
||||
from ..utils import LATENT_CHANNELS, LATENT_FACTOR, parse_regions
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
@ -512,7 +512,12 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
|||
|
||||
# get the initial random noise unless the user supplied it
|
||||
latents_dtype = prompt_embeds.dtype
|
||||
latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
|
||||
latents_shape = (
|
||||
batch_size * num_images_per_prompt,
|
||||
LATENT_CHANNELS,
|
||||
height // LATENT_FACTOR,
|
||||
width // LATENT_FACTOR,
|
||||
)
|
||||
if latents is None:
|
||||
latents = generator.randn(*latents_shape).astype(latents_dtype)
|
||||
elif latents.shape != latents_shape:
|
||||
|
@ -612,10 +617,10 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
|||
)
|
||||
|
||||
# convert coordinates to latent space
|
||||
h_start = top // 8
|
||||
h_end = bottom // 8
|
||||
w_start = left // 8
|
||||
w_end = right // 8
|
||||
h_start = top // LATENT_FACTOR
|
||||
h_end = bottom // LATENT_FACTOR
|
||||
w_start = left // LATENT_FACTOR
|
||||
w_end = right // LATENT_FACTOR
|
||||
|
||||
# get the latents corresponding to the current view coordinates
|
||||
latents_for_region = latents[:, :, h_start:h_end, w_start:w_end]
|
||||
|
@ -1170,8 +1175,8 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
|||
latents_shape = (
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height // 8,
|
||||
width // 8,
|
||||
height // LATENT_FACTOR,
|
||||
width // LATENT_FACTOR,
|
||||
)
|
||||
latents_dtype = prompt_embeds.dtype
|
||||
if latents is None:
|
||||
|
|
|
@ -14,7 +14,7 @@ from optimum.pipelines.diffusers.pipeline_utils import preprocess, rescale_noise
|
|||
|
||||
from onnx_web.chain.tile import make_tile_mask
|
||||
|
||||
from ..utils import parse_regions
|
||||
from ..utils import LATENT_FACTOR, parse_regions
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -457,10 +457,10 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
|||
)
|
||||
|
||||
# convert coordinates to latent space
|
||||
h_start = top // 8
|
||||
h_end = bottom // 8
|
||||
w_start = left // 8
|
||||
w_end = right // 8
|
||||
h_start = top // LATENT_FACTOR
|
||||
h_end = bottom // LATENT_FACTOR
|
||||
w_start = left // LATENT_FACTOR
|
||||
w_end = right // LATENT_FACTOR
|
||||
|
||||
# get the latents corresponding to the current view coordinates
|
||||
latents_for_region = latents[:, :, h_start:h_end, w_start:w_end]
|
||||
|
|
|
@ -24,7 +24,7 @@ WILDCARD_TOKEN = compile(r"__([-/\\\w]+)__")
|
|||
REGION_TOKEN = compile(
|
||||
r"\<region:(\d+):(\d+):(\d+):(\d+):(-?[\.|\d]+):(-?[\.|\d]+):([^\>]+)\>"
|
||||
)
|
||||
RESEED_TOKEN = compile(r"\<reseed:(\d+):(\d+):(\d+):(\d+):(\d+)\>")
|
||||
RESEED_TOKEN = compile(r"\<reseed:(\d+):(\d+):(\d+):(\d+):(-?\d+)\>")
|
||||
|
||||
INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}")
|
||||
ALTERNATIVE_RANGE = compile(r"\(([^\)]+)\)")
|
||||
|
|
Loading…
Reference in New Issue