tests for prompt slicing, latent tiling
This commit is contained in:
parent
466a28c13b
commit
8a97fca6d0
|
@ -9,7 +9,10 @@ from onnx_web.diffusers.utils import (
|
|||
get_latents_from_seed,
|
||||
get_loras_from_prompt,
|
||||
get_scaled_latents,
|
||||
get_tile_latents,
|
||||
get_tokens_from_prompt,
|
||||
pop_random,
|
||||
slice_prompt,
|
||||
)
|
||||
from onnx_web.params import Size
|
||||
|
||||
|
@ -64,10 +67,19 @@ class TestLatentsFromSeed(unittest.TestCase):
|
|||
|
||||
class TestTileLatents(unittest.TestCase):
|
||||
def test_full_tile(self):
|
||||
pass
|
||||
partial = np.zeros((1, 1, 64, 64))
|
||||
full = get_tile_latents(partial, 1, (64, 64), (0, 0, 64))
|
||||
self.assertEqual(full.shape, (1, 1, 8, 8))
|
||||
|
||||
def test_partial_tile(self):
|
||||
pass
|
||||
def test_contract_tile(self):
|
||||
partial = np.zeros((1, 1, 64, 64))
|
||||
full = get_tile_latents(partial, 1, (32, 32), (0, 0, 32))
|
||||
self.assertEqual(full.shape, (1, 1, 4, 4))
|
||||
|
||||
def test_expand_tile(self):
|
||||
partial = np.zeros((1, 1, 32, 32))
|
||||
full = get_tile_latents(partial, 1, (64, 64), (0, 0, 64))
|
||||
self.assertEqual(full.shape, (1, 1, 8, 8))
|
||||
|
||||
class TestScaledLatents(unittest.TestCase):
|
||||
def test_scale_up(self):
|
||||
|
@ -84,3 +96,32 @@ class TestScaledLatents(unittest.TestCase):
|
|||
latents[0, 0, 1, 0] +
|
||||
latents[0, 0, 1, 1]
|
||||
) / 4, scaled[0, 0, 0, 0])
|
||||
|
||||
class TestReplaceWildcards(unittest.TestCase):
|
||||
pass
|
||||
|
||||
class TestPopRandom(unittest.TestCase):
|
||||
def test_pop(self):
|
||||
items = ["1", "2", "3"]
|
||||
pop_random(items)
|
||||
self.assertEqual(len(items), 2)
|
||||
|
||||
class TestRepairNaN(unittest.TestCase):
|
||||
def test_unchanged(self):
|
||||
pass
|
||||
|
||||
def test_missing(self):
|
||||
pass
|
||||
|
||||
class TestSlicePrompt(unittest.TestCase):
|
||||
def test_slice_no_delimiter(self):
|
||||
slice = slice_prompt("foo", 1)
|
||||
self.assertEqual(slice, "foo")
|
||||
|
||||
def test_slice_within_range(self):
|
||||
slice = slice_prompt("foo || bar", 1)
|
||||
self.assertEqual(slice, " bar")
|
||||
|
||||
def test_slice_outside_range(self):
|
||||
slice = slice_prompt("foo || bar", 9)
|
||||
self.assertEqual(slice, " bar")
|
Loading…
Reference in New Issue