1
0
Fork 0

fix output type

This commit is contained in:
Sean Sube 2024-01-21 21:49:24 -06:00
parent 40dac93e18
commit 285c672718
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 3 additions and 4 deletions

View File

@ -1,7 +1,7 @@
from typing import Any, Literal from typing import Any, Literal
import numpy as np import numpy as np
from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput from diffusers.schedulers.scheduling_utils import SchedulerOutput
from torch import FloatTensor, Tensor from torch import FloatTensor, Tensor
@ -16,7 +16,7 @@ class SchedulerPatch:
def step( def step(
self, model_output: FloatTensor, timestep: Tensor, sample: FloatTensor self, model_output: FloatTensor, timestep: Tensor, sample: FloatTensor
) -> DDIMSchedulerOutput: ) -> SchedulerOutput:
result = self.wrapped.step(model_output, timestep, sample) result = self.wrapped.step(model_output, timestep, sample)
white_point = 0 white_point = 0
@ -28,9 +28,8 @@ class SchedulerPatch:
result.prev_sample, white_point, black_point, center_line, direction result.prev_sample, white_point, black_point, center_line, direction
) )
return DDIMSchedulerOutput( return SchedulerOutput(
prev_sample=mirrored_latents, prev_sample=mirrored_latents,
pred_original_sample=result.pred_original_sample,
) )