1
0
Fork 0

fix(api): generate new latents for partial tiles

This commit is contained in:
Sean Sube 2023-07-12 21:28:07 -05:00
parent 4b8358b0c9
commit 3d4c77d5d0
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 7 additions and 9 deletions

View File

@ -60,7 +60,7 @@ class SourceTxt2ImgStage(BaseStage):
if latents is None:
latents = get_latents_from_seed(params.seed, latent_size, params.batch)
else:
latents = get_tile_latents(latents, dims, latent_size)
latents = get_tile_latents(latents, params.seed, latent_size, dims)
pipe_type = params.get_valid_pipeline("txt2img")
pipe = load_pipeline(

View File

@ -79,7 +79,7 @@ class UpscaleOutpaintStage(BaseStage):
if latents is None:
latents = get_latents_from_seed(params.seed, latent_size, params.batch)
else:
latents = get_tile_latents(latents, dims, latent_size)
latents = get_tile_latents(latents, params.seed, latent_size, dims)
if params.lpw():
logger.debug("using LPW pipeline for inpaint")

View File

@ -268,8 +268,9 @@ def get_latents_from_seed(seed: int, size: Size, batch: int = 1) -> np.ndarray:
def get_tile_latents(
full_latents: np.ndarray,
dims: Tuple[int, int, int],
seed: int,
size: Size,
dims: Tuple[int, int, int],
) -> np.ndarray:
x, y, tile = dims
t = tile // LATENT_FACTOR
@ -284,12 +285,9 @@ def get_tile_latents(
tile_latents = full_latents[:, :, y:yt, x:xt]
if tile_latents.shape[2] < t or tile_latents.shape[3] < t:
px = mx - tile_latents.shape[3]
py = my - tile_latents.shape[2]
tile_latents = np.pad(
tile_latents, ((0, 0), (0, 0), (0, py), (0, px)), mode="reflect"
)
extra_latents = get_latents_from_seed(seed, size, batch=tile_latents.shape[0])
extra_latents[:, :, 0:t, 0:t] = tile_latents
tile_latents = extra_latents
return tile_latents