1
0
Fork 0

tests for prompt slicing, latent tiling

This commit is contained in:
Sean Sube 2023-09-16 12:44:54 -05:00
parent 466a28c13b
commit 8a97fca6d0
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 44 additions and 3 deletions

View File

@ -9,7 +9,10 @@ from onnx_web.diffusers.utils import (
get_latents_from_seed, get_latents_from_seed,
get_loras_from_prompt, get_loras_from_prompt,
get_scaled_latents, get_scaled_latents,
get_tile_latents,
get_tokens_from_prompt, get_tokens_from_prompt,
pop_random,
slice_prompt,
) )
from onnx_web.params import Size from onnx_web.params import Size
@ -64,10 +67,19 @@ class TestLatentsFromSeed(unittest.TestCase):
class TestTileLatents(unittest.TestCase): class TestTileLatents(unittest.TestCase):
def test_full_tile(self): def test_full_tile(self):
pass partial = np.zeros((1, 1, 64, 64))
full = get_tile_latents(partial, 1, (64, 64), (0, 0, 64))
self.assertEqual(full.shape, (1, 1, 8, 8))
def test_partial_tile(self): def test_contract_tile(self):
pass partial = np.zeros((1, 1, 64, 64))
full = get_tile_latents(partial, 1, (32, 32), (0, 0, 32))
self.assertEqual(full.shape, (1, 1, 4, 4))
def test_expand_tile(self):
partial = np.zeros((1, 1, 32, 32))
full = get_tile_latents(partial, 1, (64, 64), (0, 0, 64))
self.assertEqual(full.shape, (1, 1, 8, 8))
class TestScaledLatents(unittest.TestCase): class TestScaledLatents(unittest.TestCase):
def test_scale_up(self): def test_scale_up(self):
@ -84,3 +96,32 @@ class TestScaledLatents(unittest.TestCase):
latents[0, 0, 1, 0] + latents[0, 0, 1, 0] +
latents[0, 0, 1, 1] latents[0, 0, 1, 1]
) / 4, scaled[0, 0, 0, 0]) ) / 4, scaled[0, 0, 0, 0])
class TestReplaceWildcards(unittest.TestCase):
pass
class TestPopRandom(unittest.TestCase):
def test_pop(self):
items = ["1", "2", "3"]
pop_random(items)
self.assertEqual(len(items), 2)
class TestRepairNaN(unittest.TestCase):
def test_unchanged(self):
pass
def test_missing(self):
pass
class TestSlicePrompt(unittest.TestCase):
def test_slice_no_delimiter(self):
slice = slice_prompt("foo", 1)
self.assertEqual(slice, "foo")
def test_slice_within_range(self):
slice = slice_prompt("foo || bar", 1)
self.assertEqual(slice, " bar")
def test_slice_outside_range(self):
slice = slice_prompt("foo || bar", 9)
self.assertEqual(slice, " bar")