fix output type
This commit is contained in:
parent
40dac93e18
commit
285c672718
|
@ -1,7 +1,7 @@
|
|||
from typing import Any, Literal
|
||||
|
||||
import numpy as np
|
||||
from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerOutput
|
||||
from torch import FloatTensor, Tensor
|
||||
|
||||
|
||||
|
@ -16,7 +16,7 @@ class SchedulerPatch:
|
|||
|
||||
def step(
|
||||
self, model_output: FloatTensor, timestep: Tensor, sample: FloatTensor
|
||||
) -> DDIMSchedulerOutput:
|
||||
) -> SchedulerOutput:
|
||||
result = self.wrapped.step(model_output, timestep, sample)
|
||||
|
||||
white_point = 0
|
||||
|
@ -28,9 +28,8 @@ class SchedulerPatch:
|
|||
result.prev_sample, white_point, black_point, center_line, direction
|
||||
)
|
||||
|
||||
return DDIMSchedulerOutput(
|
||||
return SchedulerOutput(
|
||||
prev_sample=mirrored_latents,
|
||||
pred_original_sample=result.pred_original_sample,
|
||||
)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue