lint(api): use constant for latents scale factor
This commit is contained in:
parent
6697c2eb6a
commit
0557ab9a2e
|
@ -19,14 +19,16 @@ last_pipeline_instance = None
|
||||||
last_pipeline_options = (None, None, None)
|
last_pipeline_options = (None, None, None)
|
||||||
last_pipeline_scheduler = None
|
last_pipeline_scheduler = None
|
||||||
|
|
||||||
|
latent_channels = 4
|
||||||
|
latent_factor = 8
|
||||||
|
|
||||||
def get_latents_from_seed(seed: int, size: Size) -> np.ndarray:
|
|
||||||
|
def get_latents_from_seed(seed: int, size: Size, batch: int = 1) -> np.ndarray:
|
||||||
'''
|
'''
|
||||||
From https://www.travelneil.com/stable-diffusion-updates.html
|
From https://www.travelneil.com/stable-diffusion-updates.html
|
||||||
'''
|
'''
|
||||||
# 1 is batch size
|
latents_shape = (batch, latent_channels, size.height // latent_factor,
|
||||||
latents_shape = (1, 4, size.height // 8, size.width // 8)
|
size.width // latent_factor)
|
||||||
# Gotta use numpy instead of torch, because torch's randn() doesn't support DML
|
|
||||||
rng = np.random.default_rng(seed)
|
rng = np.random.default_rng(seed)
|
||||||
image_latents = rng.standard_normal(latents_shape).astype(np.float32)
|
image_latents = rng.standard_normal(latents_shape).astype(np.float32)
|
||||||
return image_latents
|
return image_latents
|
||||||
|
@ -34,9 +36,9 @@ def get_latents_from_seed(seed: int, size: Size) -> np.ndarray:
|
||||||
|
|
||||||
def get_tile_latents(full_latents: np.ndarray, dims: Tuple[int, int, int]) -> np.ndarray:
|
def get_tile_latents(full_latents: np.ndarray, dims: Tuple[int, int, int]) -> np.ndarray:
|
||||||
x, y, tile = dims
|
x, y, tile = dims
|
||||||
t = tile // 8
|
t = tile // latent_factor
|
||||||
x = x // 8
|
x = x // latent_factor
|
||||||
y = y // 8
|
y = y // latent_factor
|
||||||
xt = x + t
|
xt = x + t
|
||||||
yt = y + t
|
yt = y + t
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue