1
0
Fork 0

feat(api): add tokens to reseed region

This commit is contained in:
Sean Sube 2023-11-10 18:37:42 -06:00
parent 01d8aabc42
commit 8a94cdb385
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 35 additions and 0 deletions

View File

@ -11,6 +11,7 @@ from ..diffusers.utils import (
get_latents_from_seed, get_latents_from_seed,
get_tile_latents, get_tile_latents,
parse_prompt, parse_prompt,
parse_reseed,
slice_prompt, slice_prompt,
) )
from ..params import ImageParams, Size, SizeChart, StageParams from ..params import ImageParams, Size, SizeChart, StageParams
@ -76,6 +77,21 @@ class SourceTxt2ImgStage(BaseStage):
else: else:
latents = get_tile_latents(latents, int(params.seed), latent_size, dims) latents = get_tile_latents(latents, int(params.seed), latent_size, dims)
# reseed latents as needed
prompt, reseed = parse_reseed(prompt)
for top, left, bottom, right, region_seed in reseed:
logger.debug(
"reseed latent region: [:, :, %s:%s, %s:%s] with %s",
top,
left,
bottom,
right,
region_seed,
)
latents[
:, :, top // 8 : bottom // 8, left // 8 : right // 8
] = get_latents_from_seed(region_seed, latent_size, params.batch)
pipe_type = params.get_valid_pipeline("txt2img") pipe_type = params.get_valid_pipeline("txt2img")
pipe = load_pipeline( pipe = load_pipeline(
server, server,

View File

@ -24,6 +24,7 @@ WILDCARD_TOKEN = compile(r"__([-/\\\w]+)__")
REGION_TOKEN = compile( REGION_TOKEN = compile(
r"\<region:(\d+):(\d+):(\d+):(\d+):(-?[\.|\d]+):(-?[\.|\d]+):([^\>]+)\>" r"\<region:(\d+):(\d+):(\d+):(\d+):(-?[\.|\d]+):(-?[\.|\d]+):([^\>]+)\>"
) )
RESEED_TOKEN = compile(r"\<reseed:(\d+):(\d+):(\d+):(\d+):(\d+)\>")
INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}") INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}")
ALTERNATIVE_RANGE = compile(r"\(([^\)]+)\)") ALTERNATIVE_RANGE = compile(r"\(([^\)]+)\)")
@ -475,3 +476,21 @@ def parse_region_group(group) -> Region:
def parse_regions(prompt: str) -> Tuple[str, List[Region]]: def parse_regions(prompt: str) -> Tuple[str, List[Region]]:
return get_tokens_from_prompt(prompt, REGION_TOKEN, parser=parse_region_group) return get_tokens_from_prompt(prompt, REGION_TOKEN, parser=parse_region_group)
Reseed = Tuple[int, int, int, int, int]
def parse_reseed_group(group) -> Region:
top, left, bottom, right, seed = group
return (
int(top),
int(left),
int(bottom),
int(right),
int(seed),
)
def parse_reseed(prompt: str) -> Tuple[str, List[Reseed]]:
return get_tokens_from_prompt(prompt, RESEED_TOKEN, parser=parse_reseed_group)