From 8a94cdb385a1808ba172bdb95cffb7c913524340 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Fri, 10 Nov 2023 18:37:42 -0600 Subject: [PATCH] feat(api): add tokens to reseed region --- api/onnx_web/chain/source_txt2img.py | 16 ++++++++++++++++ api/onnx_web/diffusers/utils.py | 19 +++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/api/onnx_web/chain/source_txt2img.py b/api/onnx_web/chain/source_txt2img.py index 4ad3885c..8a4dd61b 100644 --- a/api/onnx_web/chain/source_txt2img.py +++ b/api/onnx_web/chain/source_txt2img.py @@ -11,6 +11,7 @@ from ..diffusers.utils import ( get_latents_from_seed, get_tile_latents, parse_prompt, + parse_reseed, slice_prompt, ) from ..params import ImageParams, Size, SizeChart, StageParams @@ -76,6 +77,21 @@ class SourceTxt2ImgStage(BaseStage): else: latents = get_tile_latents(latents, int(params.seed), latent_size, dims) + # reseed latents as needed + prompt, reseed = parse_reseed(prompt) + for top, left, bottom, right, region_seed in reseed: + logger.debug( + "reseed latent region: [:, :, %s:%s, %s:%s] with %s", + top, + left, + bottom, + right, + region_seed, + ) + latents[ + :, :, top // 8 : bottom // 8, left // 8 : right // 8 + ] = get_latents_from_seed(region_seed, latent_size, params.batch) + pipe_type = params.get_valid_pipeline("txt2img") pipe = load_pipeline( server, diff --git a/api/onnx_web/diffusers/utils.py b/api/onnx_web/diffusers/utils.py index 2e644a77..b9a43900 100644 --- a/api/onnx_web/diffusers/utils.py +++ b/api/onnx_web/diffusers/utils.py @@ -24,6 +24,7 @@ WILDCARD_TOKEN = compile(r"__([-/\\\w]+)__") REGION_TOKEN = compile( r"\]+)\>" ) +RESEED_TOKEN = compile(r"\") INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}") ALTERNATIVE_RANGE = compile(r"\(([^\)]+)\)") @@ -475,3 +476,21 @@ def parse_region_group(group) -> Region: def parse_regions(prompt: str) -> Tuple[str, List[Region]]: return get_tokens_from_prompt(prompt, REGION_TOKEN, parser=parse_region_group) + + +Reseed = Tuple[int, int, int, int, int] + + +def parse_reseed_group(group) -> Region: + top, left, bottom, right, seed = group + return ( + int(top), + int(left), + int(bottom), + int(right), + int(seed), + ) + + +def parse_reseed(prompt: str) -> Tuple[str, List[Reseed]]: + return get_tokens_from_prompt(prompt, RESEED_TOKEN, parser=parse_reseed_group)