feat(api): add tokens to reseed region
This commit is contained in:
parent
01d8aabc42
commit
8a94cdb385
|
@ -11,6 +11,7 @@ from ..diffusers.utils import (
|
||||||
get_latents_from_seed,
|
get_latents_from_seed,
|
||||||
get_tile_latents,
|
get_tile_latents,
|
||||||
parse_prompt,
|
parse_prompt,
|
||||||
|
parse_reseed,
|
||||||
slice_prompt,
|
slice_prompt,
|
||||||
)
|
)
|
||||||
from ..params import ImageParams, Size, SizeChart, StageParams
|
from ..params import ImageParams, Size, SizeChart, StageParams
|
||||||
|
@ -76,6 +77,21 @@ class SourceTxt2ImgStage(BaseStage):
|
||||||
else:
|
else:
|
||||||
latents = get_tile_latents(latents, int(params.seed), latent_size, dims)
|
latents = get_tile_latents(latents, int(params.seed), latent_size, dims)
|
||||||
|
|
||||||
|
# reseed latents as needed
|
||||||
|
prompt, reseed = parse_reseed(prompt)
|
||||||
|
for top, left, bottom, right, region_seed in reseed:
|
||||||
|
logger.debug(
|
||||||
|
"reseed latent region: [:, :, %s:%s, %s:%s] with %s",
|
||||||
|
top,
|
||||||
|
left,
|
||||||
|
bottom,
|
||||||
|
right,
|
||||||
|
region_seed,
|
||||||
|
)
|
||||||
|
latents[
|
||||||
|
:, :, top // 8 : bottom // 8, left // 8 : right // 8
|
||||||
|
] = get_latents_from_seed(region_seed, latent_size, params.batch)
|
||||||
|
|
||||||
pipe_type = params.get_valid_pipeline("txt2img")
|
pipe_type = params.get_valid_pipeline("txt2img")
|
||||||
pipe = load_pipeline(
|
pipe = load_pipeline(
|
||||||
server,
|
server,
|
||||||
|
|
|
@ -24,6 +24,7 @@ WILDCARD_TOKEN = compile(r"__([-/\\\w]+)__")
|
||||||
REGION_TOKEN = compile(
|
REGION_TOKEN = compile(
|
||||||
r"\<region:(\d+):(\d+):(\d+):(\d+):(-?[\.|\d]+):(-?[\.|\d]+):([^\>]+)\>"
|
r"\<region:(\d+):(\d+):(\d+):(\d+):(-?[\.|\d]+):(-?[\.|\d]+):([^\>]+)\>"
|
||||||
)
|
)
|
||||||
|
RESEED_TOKEN = compile(r"\<reseed:(\d+):(\d+):(\d+):(\d+):(\d+)\>")
|
||||||
|
|
||||||
INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}")
|
INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}")
|
||||||
ALTERNATIVE_RANGE = compile(r"\(([^\)]+)\)")
|
ALTERNATIVE_RANGE = compile(r"\(([^\)]+)\)")
|
||||||
|
@ -475,3 +476,21 @@ def parse_region_group(group) -> Region:
|
||||||
|
|
||||||
def parse_regions(prompt: str) -> Tuple[str, List[Region]]:
|
def parse_regions(prompt: str) -> Tuple[str, List[Region]]:
|
||||||
return get_tokens_from_prompt(prompt, REGION_TOKEN, parser=parse_region_group)
|
return get_tokens_from_prompt(prompt, REGION_TOKEN, parser=parse_region_group)
|
||||||
|
|
||||||
|
|
||||||
|
Reseed = Tuple[int, int, int, int, int]
|
||||||
|
|
||||||
|
|
||||||
|
def parse_reseed_group(group) -> Region:
|
||||||
|
top, left, bottom, right, seed = group
|
||||||
|
return (
|
||||||
|
int(top),
|
||||||
|
int(left),
|
||||||
|
int(bottom),
|
||||||
|
int(right),
|
||||||
|
int(seed),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_reseed(prompt: str) -> Tuple[str, List[Reseed]]:
|
||||||
|
return get_tokens_from_prompt(prompt, RESEED_TOKEN, parser=parse_reseed_group)
|
||||||
|
|
Loading…
Reference in New Issue