diff --git a/api/onnx_web/pipeline.py b/api/onnx_web/pipeline.py index 7bf2c384..0d7701e2 100644 --- a/api/onnx_web/pipeline.py +++ b/api/onnx_web/pipeline.py @@ -37,7 +37,7 @@ def get_latents_from_seed(seed: int, size: Size) -> np.ndarray: From https://www.travelneil.com/stable-diffusion-updates.html ''' # 1 is batch size - latents_shape = (1, 4, size.height // 8, size.width // 8) + latents_shape = (1, 4, size.width // 8, size.height // 8) # Gotta use numpy instead of torch, because torch's randn() doesn't support DML rng = np.random.default_rng(seed) image_latents = rng.standard_normal(latents_shape).astype(np.float32)