1
0
Fork 0

convert mirrored latents back to torch

This commit is contained in:
Sean Sube 2024-01-21 21:53:45 -06:00
parent 285c672718
commit e41fa04fe9
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 5 additions and 4 deletions

View File

@ -1,6 +1,7 @@
from typing import Any, Literal
import numpy as np
import torch
from diffusers.schedulers.scheduling_utils import SchedulerOutput
from torch import FloatTensor, Tensor
@ -19,17 +20,17 @@ class SchedulerPatch:
) -> SchedulerOutput:
result = self.wrapped.step(model_output, timestep, sample)
white_point = 0
black_point = 8
white_point = result.prev_sample.shape[2] // 8
black_point = result.prev_sample.shape[2] // 4
center_line = result.prev_sample.shape[2] // 2
direction = "horizontal"
mirrored_latents = mirror_latents(
result.prev_sample, white_point, black_point, center_line, direction
result.prev_sample.numpy(), white_point, black_point, center_line, direction
)
return SchedulerOutput(
prev_sample=mirrored_latents,
prev_sample=torch.from_numpy(mirrored_latents),
)