convert mirrored latents back to torch
This commit is contained in:
parent
285c672718
commit
e41fa04fe9
|
@ -1,6 +1,7 @@
|
|||
from typing import Any, Literal
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerOutput
|
||||
from torch import FloatTensor, Tensor
|
||||
|
||||
|
@ -19,17 +20,17 @@ class SchedulerPatch:
|
|||
) -> SchedulerOutput:
|
||||
result = self.wrapped.step(model_output, timestep, sample)
|
||||
|
||||
white_point = 0
|
||||
black_point = 8
|
||||
white_point = result.prev_sample.shape[2] // 8
|
||||
black_point = result.prev_sample.shape[2] // 4
|
||||
center_line = result.prev_sample.shape[2] // 2
|
||||
direction = "horizontal"
|
||||
|
||||
mirrored_latents = mirror_latents(
|
||||
result.prev_sample, white_point, black_point, center_line, direction
|
||||
result.prev_sample.numpy(), white_point, black_point, center_line, direction
|
||||
)
|
||||
|
||||
return SchedulerOutput(
|
||||
prev_sample=mirrored_latents,
|
||||
prev_sample=torch.from_numpy(mirrored_latents),
|
||||
)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue