fix output type
This commit is contained in:
parent
40dac93e18
commit
285c672718
|
@ -1,7 +1,7 @@
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput
|
from diffusers.schedulers.scheduling_utils import SchedulerOutput
|
||||||
from torch import FloatTensor, Tensor
|
from torch import FloatTensor, Tensor
|
||||||
|
|
||||||
|
|
||||||
|
@ -16,7 +16,7 @@ class SchedulerPatch:
|
||||||
|
|
||||||
def step(
|
def step(
|
||||||
self, model_output: FloatTensor, timestep: Tensor, sample: FloatTensor
|
self, model_output: FloatTensor, timestep: Tensor, sample: FloatTensor
|
||||||
) -> DDIMSchedulerOutput:
|
) -> SchedulerOutput:
|
||||||
result = self.wrapped.step(model_output, timestep, sample)
|
result = self.wrapped.step(model_output, timestep, sample)
|
||||||
|
|
||||||
white_point = 0
|
white_point = 0
|
||||||
|
@ -28,9 +28,8 @@ class SchedulerPatch:
|
||||||
result.prev_sample, white_point, black_point, center_line, direction
|
result.prev_sample, white_point, black_point, center_line, direction
|
||||||
)
|
)
|
||||||
|
|
||||||
return DDIMSchedulerOutput(
|
return SchedulerOutput(
|
||||||
prev_sample=mirrored_latents,
|
prev_sample=mirrored_latents,
|
||||||
pred_original_sample=result.pred_original_sample,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue