From e41fa04fe9b8a39680c574b3d4b2159600f89432 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 21 Jan 2024 21:53:45 -0600 Subject: [PATCH] convert mirrored latents back to torch --- api/onnx_web/diffusers/patches/scheduler.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/api/onnx_web/diffusers/patches/scheduler.py b/api/onnx_web/diffusers/patches/scheduler.py index bc6e747a..2f9a358f 100644 --- a/api/onnx_web/diffusers/patches/scheduler.py +++ b/api/onnx_web/diffusers/patches/scheduler.py @@ -1,6 +1,7 @@ from typing import Any, Literal import numpy as np +import torch from diffusers.schedulers.scheduling_utils import SchedulerOutput from torch import FloatTensor, Tensor @@ -19,17 +20,17 @@ class SchedulerPatch: ) -> SchedulerOutput: result = self.wrapped.step(model_output, timestep, sample) - white_point = 0 - black_point = 8 + white_point = result.prev_sample.shape[2] // 8 + black_point = result.prev_sample.shape[2] // 4 center_line = result.prev_sample.shape[2] // 2 direction = "horizontal" mirrored_latents = mirror_latents( - result.prev_sample, white_point, black_point, center_line, direction + result.prev_sample.numpy(), white_point, black_point, center_line, direction ) return SchedulerOutput( - prev_sample=mirrored_latents, + prev_sample=torch.from_numpy(mirrored_latents), )