2024-01-22 03:36:39 +00:00
|
|
|
import unittest
|
2024-01-28 22:12:24 +00:00
|
|
|
from unittest.mock import MagicMock
|
2024-01-22 03:36:39 +00:00
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
2024-01-28 22:12:24 +00:00
|
|
|
from diffusers.schedulers.scheduling_utils import SchedulerOutput
|
|
|
|
from numpy.testing import assert_array_equal
|
2024-01-22 03:36:39 +00:00
|
|
|
|
2024-01-28 22:12:24 +00:00
|
|
|
from onnx_web.diffusers.patches.scheduler import (
|
|
|
|
SchedulerPatch,
|
|
|
|
linear_gradient,
|
|
|
|
mirror_latents,
|
|
|
|
)
|
2024-01-22 03:36:39 +00:00
|
|
|
|
|
|
|
|
|
|
|
class SchedulerPatchTests(unittest.TestCase):
|
|
|
|
def test_scheduler_step(self):
    """Stepping the patch calls the wrapped scheduler and yields a SchedulerOutput.

    The patch is built with ``None`` for its first and third constructor
    arguments, so only the wrapped (mocked) scheduler is exercised here.
    """
    wrapped_scheduler = MagicMock()
    wrapped_scheduler.step.return_value = SchedulerOutput(None)
    scheduler = SchedulerPatch(None, wrapped_scheduler, None)

    model_output = torch.FloatTensor([1.0, 2.0, 3.0])
    timestep = torch.Tensor([0.1])
    sample = torch.FloatTensor([0.5, 0.6, 0.7])

    output = scheduler.step(model_output, timestep, sample)

    self.assertIsInstance(output, SchedulerOutput)
    # The type check alone would also pass if the patch constructed its own
    # SchedulerOutput without consulting the wrapped scheduler, so pin the
    # delegation explicitly.
    wrapped_scheduler.step.assert_called_once()
|
2024-01-22 03:36:39 +00:00
|
|
|
|
|
|
|
def test_mirror_latents_horizontal(self):
    """Horizontal mirroring with ``center_line=2`` returns the input unchanged.

    The fixture is a single batch holding a single channel whose 4x3 grid
    contains the values 1..12 in order, and the gradient runs from 0 to 1.
    """
    # batch -> channels -> [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
    source = np.arange(1, 13).reshape(1, 1, 4, 3)
    white_point, black_point = 0, 1
    center_line = 2

    result = mirror_latents(
        source,
        linear_gradient(white_point, black_point, center_line),
        center_line,
        "horizontal",
    )

    assert_array_equal(result, source)
|
2024-01-22 03:36:39 +00:00
|
|
|
|
|
|
|
def test_mirror_latents_vertical(self):
    """Vertical mirroring with ``center_line=3`` produces the pinned output.

    The fixture is a single batch holding a single channel whose 4x3 grid
    contains the values 1..12 in order, and the gradient runs from 0 to 1.
    In the expected result, the first two rows come back zeroed and the last
    two rows come back in swapped order.
    """
    # batch -> channels -> [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
    source = np.arange(1, 13).reshape(1, 1, 4, 3)
    white_point, black_point = 0, 1
    center_line = 3

    gradient = linear_gradient(white_point, black_point, center_line)
    expected = [
        [
            [[0, 0, 0], [0, 0, 0], [10, 11, 12], [7, 8, 9]],
        ],
    ]

    assert_array_equal(
        mirror_latents(source, gradient, center_line, "vertical"),
        expected,
    )
|