1
0
Fork 0

lint(api): use constant for latents scale factor

This commit is contained in:
Sean Sube 2023-02-01 22:21:22 -06:00
parent 6697c2eb6a
commit 0557ab9a2e
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 9 additions and 7 deletions

View File

@ -19,14 +19,16 @@ last_pipeline_instance = None
last_pipeline_options = (None, None, None)
last_pipeline_scheduler = None
latent_channels = 4
latent_factor = 8
def get_latents_from_seed(seed: int, size: Size) -> np.ndarray:
def get_latents_from_seed(seed: int, size: Size, batch: int = 1) -> np.ndarray:
'''
From https://www.travelneil.com/stable-diffusion-updates.html
'''
# 1 is batch size
latents_shape = (1, 4, size.height // 8, size.width // 8)
# Gotta use numpy instead of torch, because torch's randn() doesn't support DML
latents_shape = (batch, latent_channels, size.height // latent_factor,
size.width // latent_factor)
rng = np.random.default_rng(seed)
image_latents = rng.standard_normal(latents_shape).astype(np.float32)
return image_latents
@ -34,9 +36,9 @@ def get_latents_from_seed(seed: int, size: Size) -> np.ndarray:
def get_tile_latents(full_latents: np.ndarray, dims: Tuple[int, int, int]) -> np.ndarray:
x, y, tile = dims
t = tile // 8
x = x // 8
y = y // 8
t = tile // latent_factor
x = x // latent_factor
y = y // latent_factor
xt = x + t
yt = y + t