fix(api): generate new latents for partial tiles
This commit is contained in:
parent
4b8358b0c9
commit
3d4c77d5d0
|
@ -60,7 +60,7 @@ class SourceTxt2ImgStage(BaseStage):
|
|||
if latents is None:
|
||||
latents = get_latents_from_seed(params.seed, latent_size, params.batch)
|
||||
else:
|
||||
latents = get_tile_latents(latents, dims, latent_size)
|
||||
latents = get_tile_latents(latents, params.seed, latent_size, dims)
|
||||
|
||||
pipe_type = params.get_valid_pipeline("txt2img")
|
||||
pipe = load_pipeline(
|
||||
|
|
|
@ -79,7 +79,7 @@ class UpscaleOutpaintStage(BaseStage):
|
|||
if latents is None:
|
||||
latents = get_latents_from_seed(params.seed, latent_size, params.batch)
|
||||
else:
|
||||
latents = get_tile_latents(latents, dims, latent_size)
|
||||
latents = get_tile_latents(latents, params.seed, latent_size, dims)
|
||||
|
||||
if params.lpw():
|
||||
logger.debug("using LPW pipeline for inpaint")
|
||||
|
|
|
@ -268,8 +268,9 @@ def get_latents_from_seed(seed: int, size: Size, batch: int = 1) -> np.ndarray:
|
|||
|
||||
def get_tile_latents(
|
||||
full_latents: np.ndarray,
|
||||
dims: Tuple[int, int, int],
|
||||
seed: int,
|
||||
size: Size,
|
||||
dims: Tuple[int, int, int],
|
||||
) -> np.ndarray:
|
||||
x, y, tile = dims
|
||||
t = tile // LATENT_FACTOR
|
||||
|
@ -284,12 +285,9 @@ def get_tile_latents(
|
|||
tile_latents = full_latents[:, :, y:yt, x:xt]
|
||||
|
||||
if tile_latents.shape[2] < t or tile_latents.shape[3] < t:
|
||||
px = mx - tile_latents.shape[3]
|
||||
py = my - tile_latents.shape[2]
|
||||
|
||||
tile_latents = np.pad(
|
||||
tile_latents, ((0, 0), (0, 0), (0, py), (0, px)), mode="reflect"
|
||||
)
|
||||
extra_latents = get_latents_from_seed(seed, size, batch=tile_latents.shape[0])
|
||||
extra_latents[:, :, 0:t, 0:t] = tile_latents
|
||||
tile_latents = extra_latents
|
||||
|
||||
return tile_latents
|
||||
|
||||
|
|
Loading…
Reference in New Issue