2023-09-16 00:16:47 +00:00
|
|
|
import unittest
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
from onnx_web.diffusers.utils import (
|
2023-09-16 03:06:53 +00:00
|
|
|
expand_alternative_ranges,
|
|
|
|
expand_interval_ranges,
|
|
|
|
get_inversions_from_prompt,
|
|
|
|
get_latents_from_seed,
|
|
|
|
get_loras_from_prompt,
|
|
|
|
get_scaled_latents,
|
2023-09-16 17:44:54 +00:00
|
|
|
get_tile_latents,
|
|
|
|
pop_random,
|
|
|
|
slice_prompt,
|
2023-09-16 00:16:47 +00:00
|
|
|
)
|
|
|
|
from onnx_web.params import Size
|
|
|
|
|
2023-09-16 03:06:53 +00:00
|
|
|
|
2023-09-16 00:16:47 +00:00
|
|
|
class TestExpandIntervalRanges(unittest.TestCase):
    """Checks for ``expand_interval_ranges`` prompt expansion."""

    def test_prompt_with_no_ranges(self):
        # A prompt containing no ``-{a,b}`` range is expected back unchanged.
        original = "an astronaut eating a hamburger"
        self.assertEqual(original, expand_interval_ranges(original))

    def test_prompt_with_range(self):
        # ``astronaut-{1,4}`` is expected to expand to the suffixes 1, 2 and 3;
        # the expected string below does not include the upper bound 4.
        expected = "an astronaut-1 astronaut-2 astronaut-3 eating a hamburger"
        expanded = expand_interval_ranges("an astronaut-{1,4} eating a hamburger")
        self.assertEqual(expanded, expected)
|
2023-09-16 00:16:47 +00:00
|
|
|
|
|
|
|
|
|
|
|
class TestExpandAlternativeRanges(unittest.TestCase):
    """Checks for ``expand_alternative_ranges`` prompt expansion."""

    def test_prompt_with_no_ranges(self):
        # Without any ``(a|b)`` group, the result is expected to be a
        # one-element list holding the untouched prompt.
        original = "an astronaut eating a hamburger"
        self.assertEqual([original], expand_alternative_ranges(original))

    def test_ranges_match(self):
        # The expected output pairs alternatives by position (first with
        # first, second with second) rather than producing every combination.
        expected = [
            "an astronaut eating a hamburger",
            "a squirrel eating an acorn",
        ]
        expanded = expand_alternative_ranges(
            "(an astronaut|a squirrel) eating (a hamburger|an acorn)"
        )
        self.assertEqual(expanded, expected)
|
2023-09-16 00:16:47 +00:00
|
|
|
|
|
|
|
|
|
|
|
class TestInversionsFromPrompt(unittest.TestCase):
    """Checks for ``get_inversions_from_prompt`` token extraction."""

    def test_get_inversions(self):
        # The ``<inversion:name:weight>`` token is expected to be removed from
        # the prompt (the space that followed it stays) and reported as a
        # ``(name, weight)`` tuple.
        remaining, inversions = get_inversions_from_prompt(
            "<inversion:test:1.0> an astronaut eating an embedding"
        )

        self.assertEqual(remaining, " an astronaut eating an embedding")
        self.assertEqual(inversions, [("test", 1.0)])
|
2023-09-16 00:16:47 +00:00
|
|
|
|
|
|
|
|
|
|
|
class TestLoRAsFromPrompt(unittest.TestCase):
    """Checks for ``get_loras_from_prompt`` token extraction."""

    def test_get_loras(self):
        # The ``<lora:name:weight>`` token is expected to be removed from the
        # prompt (the space that followed it stays) and reported as a
        # ``(name, weight)`` tuple.
        remaining, loras = get_loras_from_prompt(
            "<lora:test:1.0> an astronaut eating a LoRA"
        )

        self.assertEqual(remaining, " an astronaut eating a LoRA")
        self.assertEqual(loras, [("test", 1.0)])
|
2023-09-16 00:16:47 +00:00
|
|
|
|
|
|
|
|
|
|
|
class TestLatentsFromSeed(unittest.TestCase):
    """Checks for ``get_latents_from_seed``."""

    def test_batch_size(self):
        # With ``batch=4`` and a 64x64 ``Size`` the expected shape is
        # (4, 4, 8, 8).
        batched = get_latents_from_seed(1, Size(64, 64), batch=4)
        self.assertEqual(batched.shape, (4, 4, 8, 8))

    def test_consistency(self):
        # Two calls with the same seed and an equal size must yield equal
        # arrays. A fresh ``Size`` is built for each call.
        first = get_latents_from_seed(1, Size(64, 64))
        second = get_latents_from_seed(1, Size(64, 64))

        identical = np.array_equal(first, second)
        self.assertTrue(identical)
|
2023-09-16 00:16:47 +00:00
|
|
|
|
|
|
|
|
|
|
|
class TestTileLatents(unittest.TestCase):
    """Checks the output shape of ``get_tile_latents`` for several tile sizes."""

    def test_full_tile(self):
        # 64x64 input with a 64-sized tile: expected result shape is 8x8.
        source = np.zeros((1, 1, 64, 64))
        tile = get_tile_latents(source, 1, (64, 64), (0, 0, 64))
        self.assertEqual(tile.shape, (1, 1, 8, 8))

    def test_contract_tile(self):
        # 64x64 input with a smaller 32-sized tile: expected shape is 4x4.
        source = np.zeros((1, 1, 64, 64))
        tile = get_tile_latents(source, 1, (32, 32), (0, 0, 32))
        self.assertEqual(tile.shape, (1, 1, 4, 4))

    def test_expand_tile(self):
        # 32x32 input with a larger 64-sized tile: expected shape is 8x8.
        source = np.zeros((1, 1, 32, 32))
        tile = get_tile_latents(source, 1, (64, 64), (0, 0, 64))
        self.assertEqual(tile.shape, (1, 1, 8, 8))
|
2023-09-16 17:44:54 +00:00
|
|
|
|
2023-09-16 00:16:47 +00:00
|
|
|
|
|
|
|
class TestScaledLatents(unittest.TestCase):
    """Checks ``get_scaled_latents`` against ``get_latents_from_seed``.

    Both helpers are called with the same seed and ``Size`` so that the scaled
    result can be compared with the unscaled latents.
    """

    def test_scale_up(self):
        # With scale=2 the first element of the scaled latents is expected to
        # equal the first element of the unscaled latents.
        latents = get_latents_from_seed(1, Size(16, 16))
        scaled = get_scaled_latents(1, Size(16, 16), scale=2)
        self.assertEqual(latents[0, 0, 0, 0], scaled[0, 0, 0, 0])

    def test_scale_down(self):
        # With scale=0.5 the first scaled element is expected to be the mean
        # of the top-left 2x2 block of the unscaled latents.
        latents = get_latents_from_seed(1, Size(16, 16))
        scaled = get_scaled_latents(1, Size(16, 16), scale=0.5)

        block_mean = (
            latents[0, 0, 0, 0]
            + latents[0, 0, 0, 1]
            + latents[0, 0, 1, 0]
            + latents[0, 0, 1, 1]
        ) / 4

        # The mean is recomputed here with a possibly different order of
        # floating-point operations than the library uses, so compare with a
        # tolerance instead of exact equality. float() normalizes whatever
        # scalar type the arrays yield before the comparison.
        self.assertAlmostEqual(
            float(block_mean),
            float(scaled[0, 0, 0, 0]),
            places=6,
        )
|
|
|
|
|
2023-09-16 17:44:54 +00:00
|
|
|
|
|
|
|
class TestReplaceWildcards(unittest.TestCase):
    """Placeholder test case: no test methods are defined yet."""
|
|
|
|
|
2023-09-16 17:44:54 +00:00
|
|
|
|
|
|
|
class TestPopRandom(unittest.TestCase):
    """Checks for ``pop_random``."""

    def test_pop(self):
        # After one call on a three-item list, two items are expected to
        # remain; which item was removed is not checked.
        remaining = ["1", "2", "3"]
        pop_random(remaining)

        self.assertEqual(len(remaining), 2)
|
|
|
|
|
2023-09-16 17:44:54 +00:00
|
|
|
|
|
|
|
class TestRepairNaN(unittest.TestCase):
    """Placeholder test case: both tests below have empty bodies."""

    def test_unchanged(self):
        # Not implemented yet; passes without asserting anything.
        pass

    def test_missing(self):
        # Not implemented yet; passes without asserting anything.
        pass
|
2023-09-16 17:44:54 +00:00
|
|
|
|
|
|
|
|
|
|
|
class TestSlicePrompt(unittest.TestCase):
    """Checks for ``slice_prompt`` with ``||``-delimited prompts."""

    def test_slice_no_delimiter(self):
        # A prompt without ``||`` is expected back whole, even for index 1.
        # (The local is not called ``slice`` so the builtin is not shadowed.)
        selected = slice_prompt("foo", 1)
        self.assertEqual(selected, "foo")

    def test_slice_within_range(self):
        # Index 1 selects the text after the delimiter; the expected value
        # keeps its leading space.
        selected = slice_prompt("foo || bar", 1)
        self.assertEqual(selected, " bar")

    def test_slice_outside_range(self):
        # An index past the last segment is expected to yield the last
        # segment rather than raise.
        selected = slice_prompt("foo || bar", 9)
        self.assertEqual(selected, " bar")
|