convert mirrored latents back to torch
This commit is contained in:
parent
285c672718
commit
e41fa04fe9
|
@ -1,6 +1,7 @@
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
from diffusers.schedulers.scheduling_utils import SchedulerOutput
|
from diffusers.schedulers.scheduling_utils import SchedulerOutput
|
||||||
from torch import FloatTensor, Tensor
|
from torch import FloatTensor, Tensor
|
||||||
|
|
||||||
|
@ -19,17 +20,17 @@ class SchedulerPatch:
|
||||||
) -> SchedulerOutput:
|
) -> SchedulerOutput:
|
||||||
result = self.wrapped.step(model_output, timestep, sample)
|
result = self.wrapped.step(model_output, timestep, sample)
|
||||||
|
|
||||||
white_point = 0
|
white_point = result.prev_sample.shape[2] // 8
|
||||||
black_point = 8
|
black_point = result.prev_sample.shape[2] // 4
|
||||||
center_line = result.prev_sample.shape[2] // 2
|
center_line = result.prev_sample.shape[2] // 2
|
||||||
direction = "horizontal"
|
direction = "horizontal"
|
||||||
|
|
||||||
mirrored_latents = mirror_latents(
|
mirrored_latents = mirror_latents(
|
||||||
result.prev_sample, white_point, black_point, center_line, direction
|
result.prev_sample.numpy(), white_point, black_point, center_line, direction
|
||||||
)
|
)
|
||||||
|
|
||||||
return SchedulerOutput(
|
return SchedulerOutput(
|
||||||
prev_sample=mirrored_latents,
|
prev_sample=torch.from_numpy(mirrored_latents),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue