1
0
Fork 0

lint(api): use constant for latents scale factor

This commit is contained in:
Sean Sube 2023-02-01 22:21:22 -06:00
parent 6697c2eb6a
commit 0557ab9a2e
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 9 additions and 7 deletions

View File

@ -19,14 +19,16 @@ last_pipeline_instance = None
last_pipeline_options = (None, None, None) last_pipeline_options = (None, None, None)
last_pipeline_scheduler = None last_pipeline_scheduler = None
latent_channels = 4
latent_factor = 8
def get_latents_from_seed(seed: int, size: Size) -> np.ndarray:
def get_latents_from_seed(seed: int, size: Size, batch: int = 1) -> np.ndarray:
''' '''
From https://www.travelneil.com/stable-diffusion-updates.html From https://www.travelneil.com/stable-diffusion-updates.html
''' '''
# 1 is batch size latents_shape = (batch, latent_channels, size.height // latent_factor,
latents_shape = (1, 4, size.height // 8, size.width // 8) size.width // latent_factor)
# Gotta use numpy instead of torch, because torch's randn() doesn't support DML
rng = np.random.default_rng(seed) rng = np.random.default_rng(seed)
image_latents = rng.standard_normal(latents_shape).astype(np.float32) image_latents = rng.standard_normal(latents_shape).astype(np.float32)
return image_latents return image_latents
@ -34,9 +36,9 @@ def get_latents_from_seed(seed: int, size: Size) -> np.ndarray:
def get_tile_latents(full_latents: np.ndarray, dims: Tuple[int, int, int]) -> np.ndarray: def get_tile_latents(full_latents: np.ndarray, dims: Tuple[int, int, int]) -> np.ndarray:
x, y, tile = dims x, y, tile = dims
t = tile // 8 t = tile // latent_factor
x = x // 8 x = x // latent_factor
y = y // 8 y = y // latent_factor
xt = x + t xt = x + t
yt = y + t yt = y + t