1
0
Fork 0
onnx-web/api/tests/test_diffusers/test_scheduler.py

59 lines
1.7 KiB
Python
Raw Permalink Normal View History

2024-01-22 03:36:39 +00:00
import unittest
import numpy as np
import torch
from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput
from onnx_web.diffusers.patches.scheduler import SchedulerPatch, mirror_latents
class SchedulerPatchTests(unittest.TestCase):
    """Tests for ``SchedulerPatch.step`` and the ``mirror_latents`` helper.

    The checks use ``unittest`` assertion methods rather than bare ``assert``
    statements so they still run when Python is started with ``-O`` and so
    failures are reported through the test framework.
    """

    @staticmethod
    def _make_latents():
        """Return a fresh integer array of shape (1, 1, 4, 3).

        The nesting is one batch entry holding one channel, which holds a
        4x3 grid of the values 1 through 12. A new array is built on every
        call so tests never share state.
        """
        return np.array(
            [  # batch
                [  # channels
                    [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]],
                ],
            ]
        )

    def test_scheduler_step(self):
        """``step`` on a patch wrapping ``None`` yields a ``DDIMSchedulerOutput``."""
        scheduler = SchedulerPatch(None)
        model_output = torch.FloatTensor([1.0, 2.0, 3.0])
        timestep = torch.Tensor([0.1])
        sample = torch.FloatTensor([0.5, 0.6, 0.7])

        output = scheduler.step(model_output, timestep, sample)

        self.assertIsInstance(output, DDIMSchedulerOutput)

    def test_mirror_latents_horizontal(self):
        """Horizontal mirroring at center line 2 returns values equal to the input."""
        latents = self._make_latents()
        white_point = 0
        black_point = 1
        center_line = 2
        direction = "horizontal"

        mirrored_latents = mirror_latents(
            latents, white_point, black_point, center_line, direction
        )

        # NOTE(review): the result is compared against the same array object
        # that was passed in; if mirror_latents ever mutates its argument in
        # place this check becomes vacuous -- confirm against the helper.
        self.assertTrue(np.array_equal(mirrored_latents, latents))

    def test_mirror_latents_vertical(self):
        """Vertical mirroring at center line 3 zeroes rows 0-1 and swaps rows 2-3."""
        latents = self._make_latents()
        white_point = 0
        black_point = 1
        center_line = 3
        direction = "vertical"

        mirrored_latents = mirror_latents(
            latents, white_point, black_point, center_line, direction
        )

        # Expected grid: the first two rows are all zero, followed by the
        # original last row and then the original third row.
        expected = [
            [
                [[0, 0, 0], [0, 0, 0], [10, 11, 12], [7, 8, 9]],
            ]
        ]
        self.assertTrue(np.array_equal(mirrored_latents, expected))