From 285c672718edfff2a5b3863390637d35b67ebbaa Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 21 Jan 2024 21:49:24 -0600 Subject: [PATCH] fix output type --- api/onnx_web/diffusers/patches/scheduler.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/api/onnx_web/diffusers/patches/scheduler.py b/api/onnx_web/diffusers/patches/scheduler.py index a7c10957..bc6e747a 100644 --- a/api/onnx_web/diffusers/patches/scheduler.py +++ b/api/onnx_web/diffusers/patches/scheduler.py @@ -1,7 +1,7 @@ from typing import Any, Literal import numpy as np -from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput +from diffusers.schedulers.scheduling_utils import SchedulerOutput from torch import FloatTensor, Tensor @@ -16,7 +16,7 @@ class SchedulerPatch: def step( self, model_output: FloatTensor, timestep: Tensor, sample: FloatTensor - ) -> DDIMSchedulerOutput: + ) -> SchedulerOutput: result = self.wrapped.step(model_output, timestep, sample) white_point = 0 @@ -28,9 +28,8 @@ class SchedulerPatch: result.prev_sample, white_point, black_point, center_line, direction ) - return DDIMSchedulerOutput( + return SchedulerOutput( prev_sample=mirrored_latents, - pred_original_sample=result.pred_original_sample, )