1
0
Fork 0
onnx-web/api/tests/test_diffusers/test_scheduler.py

65 lines
2.0 KiB
Python
Raw Permalink Normal View History

2024-01-22 03:36:39 +00:00
import unittest
from unittest.mock import MagicMock
2024-01-22 03:36:39 +00:00
import numpy as np
import torch
from diffusers.schedulers.scheduling_utils import SchedulerOutput
from numpy.testing import assert_array_equal
2024-01-22 03:36:39 +00:00
from onnx_web.diffusers.patches.scheduler import (
SchedulerPatch,
linear_gradient,
mirror_latents,
)
2024-01-22 03:36:39 +00:00
class SchedulerPatchTests(unittest.TestCase):
def test_scheduler_step(self):
wrapped_scheduler = MagicMock()
wrapped_scheduler.step.return_value = SchedulerOutput(None)
scheduler = SchedulerPatch(None, wrapped_scheduler, None)
2024-01-22 03:36:39 +00:00
model_output = torch.FloatTensor([1.0, 2.0, 3.0])
timestep = torch.Tensor([0.1])
sample = torch.FloatTensor([0.5, 0.6, 0.7])
output = scheduler.step(model_output, timestep, sample)
self.assertIsInstance(output, SchedulerOutput)
2024-01-22 03:36:39 +00:00
def test_mirror_latents_horizontal(self):
latents = np.array(
[ # batch
[ # channels
[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]],
],
]
)
2024-01-22 03:45:51 +00:00
white_point = 0
black_point = 1
2024-01-22 03:36:39 +00:00
center_line = 2
direction = "horizontal"
gradient = linear_gradient(white_point, black_point, center_line)
mirrored_latents = mirror_latents(latents, gradient, center_line, direction)
assert_array_equal(mirrored_latents, latents)
2024-01-22 03:36:39 +00:00
def test_mirror_latents_vertical(self):
latents = np.array(
[ # batch
[ # channels
[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]],
],
]
)
2024-01-22 03:45:51 +00:00
white_point = 0
black_point = 1
2024-01-22 03:36:39 +00:00
center_line = 3
direction = "vertical"
gradient = linear_gradient(white_point, black_point, center_line)
mirrored_latents = mirror_latents(latents, gradient, center_line, direction)
assert_array_equal(
2024-01-22 03:36:39 +00:00
mirrored_latents,
[
[
[[0, 0, 0], [0, 0, 0], [10, 11, 12], [7, 8, 9]],
]
],
)