1
0
Fork 0

apply lint to tests, test highres

This commit is contained in:
Sean Sube 2023-11-19 23:18:57 -06:00
parent 4691e80744
commit 65912c5a4a
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
26 changed files with 1506 additions and 1127 deletions

View File

@ -33,13 +33,17 @@ package-upload:
lint-check: lint-check:
black --check onnx_web/ black --check onnx_web/
black --check tests/
flake8 onnx_web flake8 onnx_web
flake8 tests
isort --check-only --skip __init__.py --filter-files onnx_web isort --check-only --skip __init__.py --filter-files onnx_web
isort --check-only --skip __init__.py --filter-files tests isort --check-only --skip __init__.py --filter-files tests
lint-fix: lint-fix:
black onnx_web/ black onnx_web/
black tests/
flake8 onnx_web flake8 onnx_web
flake8 tests
isort --skip __init__.py --filter-files onnx_web isort --skip __init__.py --filter-files onnx_web
isort --skip __init__.py --filter-files tests isort --skip __init__.py --filter-files tests

View File

@ -19,6 +19,14 @@ class StageResult:
def empty(): def empty():
return StageResult(images=[]) return StageResult(images=[])
@staticmethod
def from_arrays(arrays: List[np.ndarray]):
return StageResult(arrays=arrays)
@staticmethod
def from_images(images: List[Image.Image]):
return StageResult(images=images)
def __init__(self, arrays=None, images=None) -> None: def __init__(self, arrays=None, images=None) -> None:
if arrays is not None and images is not None: if arrays is not None and images is not None:
raise ValueError("stages must only return one type of result") raise ValueError("stages must only return one type of result")

View File

@ -25,9 +25,7 @@ class TileCallback(Protocol):
Definition for a tile job function. Definition for a tile job function.
""" """
def __call__( def __call__(self, image: Image.Image, dims: Tuple[int, int, int]) -> StageResult:
self, image: Image.Image, dims: Tuple[int, int, int]
) -> StageResult:
""" """
Run this stage against a single tile. Run this stage against a single tile.
""" """
@ -319,6 +317,9 @@ def process_tile_stack(
if mask: if mask:
tile_mask = mask.crop((left, top, right, bottom)) tile_mask = mask.crop((left, top, right, bottom))
if isinstance(tile_stack, list):
tile_stack = StageResult.from_images(tile_stack)
for image_filter in filters: for image_filter in filters:
tile_stack = image_filter(tile_stack, tile_mask, (left, top, tile)) tile_stack = image_filter(tile_stack, tile_mask, (left, top, tile))

View File

@ -48,6 +48,7 @@ def main():
# debug options # debug options
if server.debug: if server.debug:
import debugpy import debugpy
debugpy.listen(5678) debugpy.listen(5678)
logger.warning("waiting for debugger") logger.warning("waiting for debugger")
debugpy.wait_for_client() debugpy.wait_for_client()

View File

@ -9,13 +9,15 @@ from onnx_web.chain.result import StageResult
class BlendGridStageTests(unittest.TestCase): class BlendGridStageTests(unittest.TestCase):
def test_stage(self): def test_stage(self):
stage = BlendGridStage() stage = BlendGridStage()
sources = StageResult(images=[ sources = StageResult(
Image.new("RGB", (64, 64), "black"), images=[
Image.new("RGB", (64, 64), "white"), Image.new("RGB", (64, 64), "black"),
Image.new("RGB", (64, 64), "black"), Image.new("RGB", (64, 64), "white"),
Image.new("RGB", (64, 64), "white"), Image.new("RGB", (64, 64), "black"),
]) Image.new("RGB", (64, 64), "white"),
]
)
result = stage.run(None, None, None, None, sources, height=2, width=2) result = stage.run(None, None, None, None, sources, height=2, width=2)
self.assertEqual(len(result), 5) self.assertEqual(len(result), 5)
self.assertEqual(result.as_image()[-1].getpixel((0,0)), (0, 0, 0)) self.assertEqual(result.as_image()[-1].getpixel((0, 0)), (0, 0, 0))

View File

@ -6,21 +6,38 @@ from onnx_web.chain.blend_img2img import BlendImg2ImgStage
from onnx_web.params import DeviceParams, ImageParams from onnx_web.params import DeviceParams, ImageParams
from onnx_web.server.context import ServerContext from onnx_web.server.context import ServerContext
from onnx_web.worker.context import WorkerContext from onnx_web.worker.context import WorkerContext
from tests.helpers import TEST_MODEL_DIFFUSION_SD15, test_needs_models
class BlendImg2ImgStageTests(unittest.TestCase): class BlendImg2ImgStageTests(unittest.TestCase):
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
def test_stage(self): def test_stage(self):
"""
stage = BlendImg2ImgStage() stage = BlendImg2ImgStage()
params = ImageParams("runwayml/stable-diffusion-v1-5", "txt2img", "euler-a", "an astronaut eating a hamburger", 3.0, 1, 1) params = ImageParams(
server = ServerContext() TEST_MODEL_DIFFUSION_SD15,
worker = WorkerContext("test", DeviceParams("cpu", "CPUProvider"), None, None, None, None, None, None, 0) "txt2img",
"euler-a",
"an astronaut eating a hamburger",
3.0,
1,
1,
)
server = ServerContext(model_path="../models", output_path="../outputs")
worker = WorkerContext(
"test",
DeviceParams("cpu", "CPUProvider"),
None,
None,
None,
None,
None,
None,
0,
)
sources = [ sources = [
Image.new("RGB", (64, 64), "black"), Image.new("RGB", (64, 64), "black"),
] ]
result = stage.run(worker, server, None, params, sources, strength=0.5, steps=1) result = stage.run(worker, server, None, params, sources, strength=0.5, steps=1)
self.assertEqual(len(result), 1) self.assertEqual(len(result), 1)
self.assertEqual(result[0].getpixel((0,0)), (127, 127, 127)) self.assertEqual(result[0].getpixel((0, 0)), (127, 127, 127))
"""
pass

View File

@ -9,11 +9,15 @@ from onnx_web.chain.result import StageResult
class BlendLinearStageTests(unittest.TestCase): class BlendLinearStageTests(unittest.TestCase):
def test_stage(self): def test_stage(self):
stage = BlendLinearStage() stage = BlendLinearStage()
sources = StageResult(images=[ sources = StageResult(
Image.new("RGB", (64, 64), "black"), images=[
]) Image.new("RGB", (64, 64), "black"),
]
)
stage_source = Image.new("RGB", (64, 64), "white") stage_source = Image.new("RGB", (64, 64), "white")
result = stage.run(None, None, None, None, sources, alpha=0.5, stage_source=stage_source) result = stage.run(
None, None, None, None, sources, alpha=0.5, stage_source=stage_source
)
self.assertEqual(len(result), 1) self.assertEqual(len(result), 1)
self.assertEqual(result.as_image()[0].getpixel((0,0)), (127, 127, 127)) self.assertEqual(result.as_image()[0].getpixel((0, 0)), (127, 127, 127))

View File

@ -2,6 +2,7 @@ import unittest
from PIL import Image from PIL import Image
from onnx_web.chain.result import StageResult
from onnx_web.chain.tile import ( from onnx_web.chain.tile import (
complete_tile, complete_tile,
generate_tile_grid, generate_tile_grid,
@ -14,122 +15,126 @@ from onnx_web.params import Size
class TestCompleteTile(unittest.TestCase): class TestCompleteTile(unittest.TestCase):
def test_with_complete_tile(self): def test_with_complete_tile(self):
partial = Image.new("RGB", (64, 64)) partial = Image.new("RGB", (64, 64))
output = complete_tile(partial, 64) output = complete_tile(partial, 64)
self.assertEqual(output.size, (64, 64)) self.assertEqual(output.size, (64, 64))
def test_with_partial_tile(self): def test_with_partial_tile(self):
partial = Image.new("RGB", (64, 32)) partial = Image.new("RGB", (64, 32))
output = complete_tile(partial, 64) output = complete_tile(partial, 64)
self.assertEqual(output.size, (64, 64)) self.assertEqual(output.size, (64, 64))
def test_with_nothing(self): def test_with_nothing(self):
output = complete_tile(None, 64) output = complete_tile(None, 64)
self.assertIsNone(output) self.assertIsNone(output)
class TestNeedsTile(unittest.TestCase): class TestNeedsTile(unittest.TestCase):
def test_with_undersized_source(self): def test_with_undersized_source(self):
small = Image.new("RGB", (32, 32)) small = Image.new("RGB", (32, 32))
self.assertFalse(needs_tile(64, 64, source=small)) self.assertFalse(needs_tile(64, 64, source=small))
def test_with_oversized_source(self): def test_with_oversized_source(self):
large = Image.new("RGB", (64, 64)) large = Image.new("RGB", (64, 64))
self.assertTrue(needs_tile(32, 32, source=large)) self.assertTrue(needs_tile(32, 32, source=large))
def test_with_undersized_size(self): def test_with_undersized_size(self):
small = Size(32, 32) small = Size(32, 32)
self.assertFalse(needs_tile(64, 64, size=small)) self.assertFalse(needs_tile(64, 64, size=small))
def test_with_oversized_source(self): def test_with_oversized_size(self):
large = Size(64, 64) large = Size(64, 64)
self.assertTrue(needs_tile(32, 32, size=large)) self.assertTrue(needs_tile(32, 32, size=large))
def test_with_nothing(self): def test_with_nothing(self):
self.assertFalse(needs_tile(32, 32)) self.assertFalse(needs_tile(32, 32))
class TestTileGrads(unittest.TestCase): class TestTileGrads(unittest.TestCase):
def test_center_tile(self): def test_center_tile(self):
grad_x, grad_y = make_tile_grads(32, 32, 8, 64, 64) grad_x, grad_y = make_tile_grads(32, 32, 8, 64, 64)
self.assertEqual(grad_x, [0, 1, 1, 0]) self.assertEqual(grad_x, [0, 1, 1, 0])
self.assertEqual(grad_y, [0, 1, 1, 0]) self.assertEqual(grad_y, [0, 1, 1, 0])
def test_vertical_edge_tile(self): def test_vertical_edge_tile(self):
grad_x, grad_y = make_tile_grads(32, 0, 8, 64, 8) grad_x, grad_y = make_tile_grads(32, 0, 8, 64, 8)
self.assertEqual(grad_x, [0, 1, 1, 0]) self.assertEqual(grad_x, [0, 1, 1, 0])
self.assertEqual(grad_y, [1, 1, 1, 1]) self.assertEqual(grad_y, [1, 1, 1, 1])
def test_horizontal_edge_tile(self): def test_horizontal_edge_tile(self):
grad_x, grad_y = make_tile_grads(0, 32, 8, 8, 64) grad_x, grad_y = make_tile_grads(0, 32, 8, 8, 64)
self.assertEqual(grad_x, [1, 1, 1, 1]) self.assertEqual(grad_x, [1, 1, 1, 1])
self.assertEqual(grad_y, [0, 1, 1, 0]) self.assertEqual(grad_y, [0, 1, 1, 0])
class TestGenerateTileGrid(unittest.TestCase): class TestGenerateTileGrid(unittest.TestCase):
def test_grid_complete(self): def test_grid_complete(self):
tiles = generate_tile_grid(16, 16, 8, 0.0) tiles = generate_tile_grid(16, 16, 8, 0.0)
self.assertEqual(len(tiles), 4) self.assertEqual(len(tiles), 4)
self.assertEqual(tiles, [(0, 0), (8, 0), (8, 8), (0, 8)]) self.assertEqual(tiles, [(0, 0), (8, 0), (0, 8), (8, 8)])
def test_grid_no_overlap(self): def test_grid_no_overlap(self):
tiles = generate_tile_grid(64, 64, 8, 0.0) tiles = generate_tile_grid(64, 64, 8, 0.0)
self.assertEqual(len(tiles), 64) self.assertEqual(len(tiles), 64)
self.assertEqual(tiles[0:4], [(0, 0), (8, 0), (16, 0), (24, 0)]) self.assertEqual(tiles[0:4], [(0, 0), (8, 0), (16, 0), (24, 0)])
self.assertEqual(tiles[-5:-1], [(16, 24), (24, 24), (32, 24), (32, 32)]) self.assertEqual(tiles[-5:-1], [(24, 56), (32, 56), (40, 56), (48, 56)])
def test_grid_50_overlap(self): def test_grid_50_overlap(self):
tiles = generate_tile_grid(64, 64, 8, 0.5) tiles = generate_tile_grid(64, 64, 8, 0.5)
self.assertEqual(len(tiles), 225) self.assertEqual(len(tiles), 256)
self.assertEqual(tiles[0:4], [(0, 0), (4, 0), (8, 0), (12, 0)]) self.assertEqual(tiles[0:4], [(0, 0), (4, 0), (8, 0), (12, 0)])
self.assertEqual(tiles[-5:-1], [(32, 32), (28, 32), (24, 32), (24, 28)]) self.assertEqual(tiles[-5:-1], [(44, 60), (48, 60), (52, 60), (56, 60)])
class TestGenerateTileSpiral(unittest.TestCase): class TestGenerateTileSpiral(unittest.TestCase):
def test_spiral_complete(self): def test_spiral_complete(self):
tiles = generate_tile_spiral(16, 16, 8, 0.0) tiles = generate_tile_spiral(16, 16, 8, 0.0)
self.assertEqual(len(tiles), 4) self.assertEqual(len(tiles), 4)
self.assertEqual(tiles, [(0, 0), (8, 0), (8, 8), (0, 8)]) self.assertEqual(tiles, [(0, 0), (8, 0), (8, 8), (0, 8)])
def test_spiral_no_overlap(self): def test_spiral_no_overlap(self):
tiles = generate_tile_spiral(64, 64, 8, 0.0) tiles = generate_tile_spiral(64, 64, 8, 0.0)
self.assertEqual(len(tiles), 64) self.assertEqual(len(tiles), 64)
self.assertEqual(tiles[0:4], [(0, 0), (8, 0), (16, 0), (24, 0)]) self.assertEqual(tiles[0:4], [(0, 0), (8, 0), (16, 0), (24, 0)])
self.assertEqual(tiles[-5:-1], [(16, 24), (24, 24), (32, 24), (32, 32)]) self.assertEqual(tiles[-5:-1], [(16, 24), (24, 24), (32, 24), (32, 32)])
def test_spiral_50_overlap(self): def test_spiral_50_overlap(self):
tiles = generate_tile_spiral(64, 64, 8, 0.5) tiles = generate_tile_spiral(64, 64, 8, 0.5)
self.assertEqual(len(tiles), 225) self.assertEqual(len(tiles), 225)
self.assertEqual(tiles[0:4], [(0, 0), (4, 0), (8, 0), (12, 0)]) self.assertEqual(tiles[0:4], [(0, 0), (4, 0), (8, 0), (12, 0)])
self.assertEqual(tiles[-5:-1], [(32, 32), (28, 32), (24, 32), (24, 28)]) self.assertEqual(tiles[-5:-1], [(32, 32), (28, 32), (24, 32), (24, 28)])
class TestProcessTileStack(unittest.TestCase): class TestProcessTileStack(unittest.TestCase):
def test_grid_full(self): def test_grid_full(self):
source = Image.new("RGB", (64, 64)) source = Image.new("RGB", (64, 64))
blend = process_tile_stack(source, 32, 1, []) blend = process_tile_stack(
StageResult(images=[source]), 32, 1, [], generate_tile_grid
)
self.assertEqual(blend.size, (64, 64)) self.assertEqual(blend[0].size, (64, 64))
def test_grid_partial(self): def test_grid_partial(self):
source = Image.new("RGB", (72, 72)) source = Image.new("RGB", (72, 72))
blend = process_tile_stack(source, 32, 1, []) blend = process_tile_stack(
StageResult(images=[source]), 32, 1, [], generate_tile_grid
)
self.assertEqual(blend.size, (72, 72)) self.assertEqual(blend[0].size, (72, 72))

View File

@ -9,6 +9,14 @@ class UpscaleHighresStageTests(unittest.TestCase):
def test_empty(self): def test_empty(self):
stage = UpscaleHighresStage() stage = UpscaleHighresStage()
sources = StageResult.empty() sources = StageResult.empty()
result = stage.run(None, None, None, None, sources, highres=HighresParams(False,1, 0, 0), upscale=UpscaleParams("")) result = stage.run(
None,
None,
None,
None,
sources,
highres=HighresParams(False, 1, 0, 0),
upscale=UpscaleParams(""),
)
self.assertEqual(len(result), 0) self.assertEqual(len(result), 0)

View File

@ -6,7 +6,6 @@ from onnx import GraphProto, ModelProto, NodeProto
from onnx.numpy_helper import from_array from onnx.numpy_helper import from_array
from onnx_web.convert.diffusion.lora import ( from onnx_web.convert.diffusion.lora import (
blend_loras,
blend_node_conv_gemm, blend_node_conv_gemm,
blend_node_matmul, blend_node_matmul,
blend_weights_loha, blend_weights_loha,
@ -33,7 +32,6 @@ class SumWeightsTests(unittest.TestCase):
weights = sum_weights(np.zeros((4, 4)), np.ones((4, 4, 1, 1))) weights = sum_weights(np.zeros((4, 4)), np.ones((4, 4, 1, 1)))
self.assertEqual(weights.shape, (4, 4, 1, 1)) self.assertEqual(weights.shape, (4, 4, 1, 1))
def test_3x3_kernel(self): def test_3x3_kernel(self):
""" """
weights = sum_weights(np.zeros((4, 4, 3, 3)), np.ones((4, 4))) weights = sum_weights(np.zeros((4, 4, 3, 3)), np.ones((4, 4)))
@ -53,14 +51,20 @@ class BufferExternalDataTensorTests(unittest.TestCase):
) )
(slim_model, external_weights) = buffer_external_data_tensors(model) (slim_model, external_weights) = buffer_external_data_tensors(model)
self.assertEqual(len(slim_model.graph.initializer), len(model.graph.initializer)) self.assertEqual(
len(slim_model.graph.initializer), len(model.graph.initializer)
)
self.assertEqual(len(external_weights), 1) self.assertEqual(len(external_weights), 1)
class FixInitializerKeyTests(unittest.TestCase): class FixInitializerKeyTests(unittest.TestCase):
def test_fix_name(self): def test_fix_name(self):
inputs = ["lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0.lora_down.weight"] inputs = [
outputs = ["lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0_lora_down_weight"] "lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0.lora_down.weight"
]
outputs = [
"lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0_lora_down_weight"
]
for input, output in zip(inputs, outputs): for input, output in zip(inputs, outputs):
self.assertEqual(fix_initializer_name(input), output) self.assertEqual(fix_initializer_name(input), output)
@ -92,25 +96,37 @@ class FixXLNameTests(unittest.TestCase):
nodes = { nodes = {
"input_block_proj.lora_down.weight": {}, "input_block_proj.lora_down.weight": {},
} }
fixed = fix_xl_names(nodes, [ fixed = fix_xl_names(
NodeProto(name="/down_blocks_proj/MatMul"), nodes,
]) [
NodeProto(name="/down_blocks_proj/MatMul"),
],
)
self.assertEqual(fixed, { self.assertEqual(
"down_blocks_proj": nodes["input_block_proj.lora_down.weight"], fixed,
}) {
"down_blocks_proj": nodes["input_block_proj.lora_down.weight"],
},
)
def test_middle_block(self): def test_middle_block(self):
nodes = { nodes = {
"middle_block_proj.lora_down.weight": {}, "middle_block_proj.lora_down.weight": {},
} }
fixed = fix_xl_names(nodes, [ fixed = fix_xl_names(
NodeProto(name="/mid_blocks_proj/MatMul"), nodes,
]) [
NodeProto(name="/mid_blocks_proj/MatMul"),
],
)
self.assertEqual(fixed, { self.assertEqual(
"mid_blocks_proj": nodes["middle_block_proj.lora_down.weight"], fixed,
}) {
"mid_blocks_proj": nodes["middle_block_proj.lora_down.weight"],
},
)
def test_output_block(self): def test_output_block(self):
pass pass
@ -133,13 +149,19 @@ class FixXLNameTests(unittest.TestCase):
nodes = { nodes = {
"output_block_proj_out.lora_down.weight": {}, "output_block_proj_out.lora_down.weight": {},
} }
fixed = fix_xl_names(nodes, [ fixed = fix_xl_names(
NodeProto(name="/up_blocks_proj_out/MatMul"), nodes,
]) [
NodeProto(name="/up_blocks_proj_out/MatMul"),
],
)
self.assertEqual(fixed, { self.assertEqual(
"up_blocks_proj_out": nodes["output_block_proj_out.lora_down.weight"], fixed,
}) {
"up_blocks_proj_out": nodes["output_block_proj_out.lora_down.weight"],
},
)
class KernelSliceTests(unittest.TestCase): class KernelSliceTests(unittest.TestCase):
@ -250,6 +272,7 @@ class BlendWeightsLoHATests(unittest.TestCase):
self.assertEqual(result.shape, (4, 4)) self.assertEqual(result.shape, (4, 4))
""" """
class BlendWeightsLoRATests(unittest.TestCase): class BlendWeightsLoRATests(unittest.TestCase):
def test_blend_kernel_none(self): def test_blend_kernel_none(self):
model = { model = {
@ -260,7 +283,6 @@ class BlendWeightsLoRATests(unittest.TestCase):
key, result = blend_weights_lora("foo.lora_down", "", model, torch.float32) key, result = blend_weights_lora("foo.lora_down", "", model, torch.float32)
self.assertEqual(result.shape, (4, 4)) self.assertEqual(result.shape, (4, 4))
def test_blend_kernel_1x1(self): def test_blend_kernel_1x1(self):
model = { model = {
"foo.lora_down": torch.from_numpy(np.ones((1, 4, 1, 1))), "foo.lora_down": torch.from_numpy(np.ones((1, 4, 1, 1))),

View File

@ -10,7 +10,6 @@ from onnx_web.convert.diffusion.textual_inversion import (
blend_embedding_embeddings, blend_embedding_embeddings,
blend_embedding_node, blend_embedding_node,
blend_embedding_parameters, blend_embedding_parameters,
blend_textual_inversions,
detect_embedding_format, detect_embedding_format,
) )
@ -18,210 +17,267 @@ TEST_DIMS = (8, 8)
TEST_DIMS_EMBEDS = (1, *TEST_DIMS) TEST_DIMS_EMBEDS = (1, *TEST_DIMS)
TEST_MODEL_EMBEDS = { TEST_MODEL_EMBEDS = {
"string_to_token": { "string_to_token": {
"test": 1, "test": 1,
}, },
"string_to_param": { "string_to_param": {
"test": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)), "test": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)),
}, },
} }
class DetectEmbeddingFormatTests(unittest.TestCase): class DetectEmbeddingFormatTests(unittest.TestCase):
def test_concept(self): def test_concept(self):
embedding = { embedding = {
"<test>": "test", "<test>": "test",
} }
self.assertEqual(detect_embedding_format(embedding), "concept") self.assertEqual(detect_embedding_format(embedding), "concept")
def test_parameters(self): def test_parameters(self):
embedding = { embedding = {
"emb_params": "test", "emb_params": "test",
} }
self.assertEqual(detect_embedding_format(embedding), "parameters") self.assertEqual(detect_embedding_format(embedding), "parameters")
def test_embeddings(self): def test_embeddings(self):
embedding = { embedding = {
"string_to_token": "test", "string_to_token": "test",
"string_to_param": "test", "string_to_param": "test",
} }
self.assertEqual(detect_embedding_format(embedding), "embeddings") self.assertEqual(detect_embedding_format(embedding), "embeddings")
def test_unknown(self): def test_unknown(self):
embedding = { embedding = {
"what_is_this": "test", "what_is_this": "test",
} }
self.assertEqual(detect_embedding_format(embedding), None) self.assertEqual(detect_embedding_format(embedding), None)
class BlendEmbeddingConceptTests(unittest.TestCase): class BlendEmbeddingConceptTests(unittest.TestCase):
def test_existing_base_token(self): def test_existing_base_token(self):
embeds = { embeds = {
"test": np.ones(TEST_DIMS), "test": np.ones(TEST_DIMS),
} }
blend_embedding_concept(embeds, { blend_embedding_concept(
"<test>": torch.from_numpy(np.ones(TEST_DIMS)), embeds,
}, np.float32, "test", 1.0) {
"<test>": torch.from_numpy(np.ones(TEST_DIMS)),
},
np.float32,
"test",
1.0,
)
self.assertIn("test", embeds) self.assertIn("test", embeds)
self.assertEqual(embeds["test"].shape, TEST_DIMS) self.assertEqual(embeds["test"].shape, TEST_DIMS)
self.assertEqual(embeds["test"].mean(), 2) self.assertEqual(embeds["test"].mean(), 2)
def test_missing_base_token(self): def test_missing_base_token(self):
embeds = {} embeds = {}
blend_embedding_concept(embeds, { blend_embedding_concept(
"<test>": torch.from_numpy(np.ones(TEST_DIMS)), embeds,
}, np.float32, "test", 1.0) {
"<test>": torch.from_numpy(np.ones(TEST_DIMS)),
},
np.float32,
"test",
1.0,
)
self.assertIn("test", embeds) self.assertIn("test", embeds)
self.assertEqual(embeds["test"].shape, TEST_DIMS) self.assertEqual(embeds["test"].shape, TEST_DIMS)
def test_existing_token(self): def test_existing_token(self):
embeds = { embeds = {
"<test>": np.ones(TEST_DIMS), "<test>": np.ones(TEST_DIMS),
} }
blend_embedding_concept(embeds, { blend_embedding_concept(
"<test>": torch.from_numpy(np.ones(TEST_DIMS)), embeds,
}, np.float32, "test", 1.0) {
"<test>": torch.from_numpy(np.ones(TEST_DIMS)),
},
np.float32,
"test",
1.0,
)
keys = list(embeds.keys()) keys = list(embeds.keys())
keys.sort() keys.sort()
self.assertIn("test", embeds) self.assertIn("test", embeds)
self.assertEqual(keys, ["<test>", "test"]) self.assertEqual(keys, ["<test>", "test"])
def test_missing_token(self): def test_missing_token(self):
embeds = {} embeds = {}
blend_embedding_concept(embeds, { blend_embedding_concept(
"<test>": torch.from_numpy(np.ones(TEST_DIMS)), embeds,
}, np.float32, "test", 1.0) {
"<test>": torch.from_numpy(np.ones(TEST_DIMS)),
},
np.float32,
"test",
1.0,
)
keys = list(embeds.keys()) keys = list(embeds.keys())
keys.sort() keys.sort()
self.assertIn("test", embeds) self.assertIn("test", embeds)
self.assertEqual(keys, ["<test>", "test"]) self.assertEqual(keys, ["<test>", "test"])
class BlendEmbeddingParametersTests(unittest.TestCase): class BlendEmbeddingParametersTests(unittest.TestCase):
def test_existing_base_token(self): def test_existing_base_token(self):
embeds = { embeds = {
"test": np.ones(TEST_DIMS), "test": np.ones(TEST_DIMS),
} }
blend_embedding_parameters(embeds, { blend_embedding_parameters(
"emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)), embeds,
}, np.float32, "test", 1.0) {
"emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)),
},
np.float32,
"test",
1.0,
)
self.assertIn("test", embeds) self.assertIn("test", embeds)
self.assertEqual(embeds["test"].shape, TEST_DIMS) self.assertEqual(embeds["test"].shape, TEST_DIMS)
self.assertEqual(embeds["test"].mean(), 2) self.assertEqual(embeds["test"].mean(), 2)
def test_missing_base_token(self): def test_missing_base_token(self):
embeds = {} embeds = {}
blend_embedding_parameters(embeds, { blend_embedding_parameters(
"emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)), embeds,
}, np.float32, "test", 1.0) {
"emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)),
},
np.float32,
"test",
1.0,
)
self.assertIn("test", embeds) self.assertIn("test", embeds)
self.assertEqual(embeds["test"].shape, TEST_DIMS) self.assertEqual(embeds["test"].shape, TEST_DIMS)
def test_existing_token(self): def test_existing_token(self):
embeds = { embeds = {
"test": np.ones(TEST_DIMS_EMBEDS), "test": np.ones(TEST_DIMS_EMBEDS),
} }
blend_embedding_parameters(embeds, { blend_embedding_parameters(
"emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)), embeds,
}, np.float32, "test", 1.0) {
"emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)),
},
np.float32,
"test",
1.0,
)
keys = list(embeds.keys()) keys = list(embeds.keys())
keys.sort() keys.sort()
self.assertIn("test", embeds) self.assertIn("test", embeds)
self.assertEqual(keys, ["test", "test-0", "test-all"]) self.assertEqual(keys, ["test", "test-0", "test-all"])
def test_missing_token(self): def test_missing_token(self):
embeds = {} embeds = {}
blend_embedding_parameters(embeds, { blend_embedding_parameters(
"emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)), embeds,
}, np.float32, "test", 1.0) {
"emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)),
},
np.float32,
"test",
1.0,
)
keys = list(embeds.keys()) keys = list(embeds.keys())
keys.sort() keys.sort()
self.assertIn("test", embeds) self.assertIn("test", embeds)
self.assertEqual(keys, ["test", "test-0", "test-all"]) self.assertEqual(keys, ["test", "test-0", "test-all"])
class BlendEmbeddingEmbeddingsTests(unittest.TestCase): class BlendEmbeddingEmbeddingsTests(unittest.TestCase):
def test_existing_base_token(self): def test_existing_base_token(self):
embeds = { embeds = {
"test": np.ones(TEST_DIMS), "test": np.ones(TEST_DIMS),
} }
blend_embedding_embeddings(embeds, TEST_MODEL_EMBEDS, np.float32, "test", 1.0) blend_embedding_embeddings(embeds, TEST_MODEL_EMBEDS, np.float32, "test", 1.0)
self.assertIn("test", embeds) self.assertIn("test", embeds)
self.assertEqual(embeds["test"].shape, TEST_DIMS) self.assertEqual(embeds["test"].shape, TEST_DIMS)
self.assertEqual(embeds["test"].mean(), 2) self.assertEqual(embeds["test"].mean(), 2)
def test_missing_base_token(self): def test_missing_base_token(self):
embeds = {} embeds = {}
blend_embedding_embeddings(embeds, TEST_MODEL_EMBEDS, np.float32, "test", 1.0) blend_embedding_embeddings(embeds, TEST_MODEL_EMBEDS, np.float32, "test", 1.0)
self.assertIn("test", embeds) self.assertIn("test", embeds)
self.assertEqual(embeds["test"].shape, TEST_DIMS) self.assertEqual(embeds["test"].shape, TEST_DIMS)
def test_existing_token(self): def test_existing_token(self):
embeds = { embeds = {
"test": np.ones(TEST_DIMS), "test": np.ones(TEST_DIMS),
} }
blend_embedding_embeddings(embeds, TEST_MODEL_EMBEDS, np.float32, "test", 1.0) blend_embedding_embeddings(embeds, TEST_MODEL_EMBEDS, np.float32, "test", 1.0)
keys = list(embeds.keys()) keys = list(embeds.keys())
keys.sort() keys.sort()
self.assertIn("test", embeds) self.assertIn("test", embeds)
self.assertEqual(keys, ["test", "test-0", "test-all"]) self.assertEqual(keys, ["test", "test-0", "test-all"])
def test_missing_token(self): def test_missing_token(self):
embeds = {} embeds = {}
blend_embedding_embeddings(embeds, TEST_MODEL_EMBEDS, np.float32, "test", 1.0) blend_embedding_embeddings(embeds, TEST_MODEL_EMBEDS, np.float32, "test", 1.0)
keys = list(embeds.keys()) keys = list(embeds.keys())
keys.sort() keys.sort()
self.assertIn("test", embeds) self.assertIn("test", embeds)
self.assertEqual(keys, ["test", "test-0", "test-all"]) self.assertEqual(keys, ["test", "test-0", "test-all"])
class BlendEmbeddingNodeTests(unittest.TestCase): class BlendEmbeddingNodeTests(unittest.TestCase):
def test_expand_weights(self): def test_expand_weights(self):
weights = from_array(np.ones(TEST_DIMS)) weights = from_array(np.ones(TEST_DIMS))
weights.name = "text_model.embeddings.token_embedding.weight" weights.name = "text_model.embeddings.token_embedding.weight"
model = ModelProto(graph=GraphProto(initializer=[ model = ModelProto(
weights, graph=GraphProto(
])) initializer=[
weights,
]
)
)
embeds = {} embeds = {}
blend_embedding_node(model, { blend_embedding_node(
'convert_tokens_to_ids': lambda t: t, model,
}, embeds, 2) {
"convert_tokens_to_ids": lambda t: t,
},
embeds,
2,
)
result = to_array(model.graph.initializer[0]) result = to_array(model.graph.initializer[0])
self.assertEqual(len(model.graph.initializer), 1) self.assertEqual(len(model.graph.initializer), 1)
self.assertEqual(result.shape, (10, 8)) # (8 + 2, 8) self.assertEqual(result.shape, (10, 8)) # (8 + 2, 8)
class BlendTextualInversionsTests(unittest.TestCase): class BlendTextualInversionsTests(unittest.TestCase):
def test_blend_multi_concept(self): def test_blend_multi_concept(self):
pass pass
def test_blend_multi_parameters(self): def test_blend_multi_parameters(self):
pass pass
def test_blend_multi_embeddings(self): def test_blend_multi_embeddings(self):
pass pass
def test_blend_multi_mixed(self): def test_blend_multi_mixed(self):
pass pass

View File

@ -13,7 +13,6 @@ from onnx_web.convert.utils import (
tuple_to_upscaling, tuple_to_upscaling,
) )
from tests.helpers import ( from tests.helpers import (
TEST_MODEL_DIFFUSION_SD15,
TEST_MODEL_UPSCALING_SWINIR, TEST_MODEL_UPSCALING_SWINIR,
test_needs_models, test_needs_models,
) )
@ -21,220 +20,225 @@ from tests.helpers import (
class ConversionContextTests(unittest.TestCase): class ConversionContextTests(unittest.TestCase):
def test_from_environ(self): def test_from_environ(self):
context = ConversionContext.from_environ() context = ConversionContext.from_environ()
self.assertEqual(context.opset, DEFAULT_OPSET) self.assertEqual(context.opset, DEFAULT_OPSET)
def test_map_location(self): def test_map_location(self):
context = ConversionContext.from_environ() context = ConversionContext.from_environ()
self.assertEqual(context.map_location.type, "cpu") self.assertEqual(context.map_location.type, "cpu")
class DownloadProgressTests(unittest.TestCase): class DownloadProgressTests(unittest.TestCase):
def test_download_example(self): def test_download_example(self):
path = download_progress([("https://example.com", "/tmp/example-dot-com")]) path = download_progress([("https://example.com", "/tmp/example-dot-com")])
self.assertEqual(path, "/tmp/example-dot-com") self.assertEqual(path, "/tmp/example-dot-com")
class TupleToSourceTests(unittest.TestCase): class TupleToSourceTests(unittest.TestCase):
def test_basic_tuple(self): def test_basic_tuple(self):
source = tuple_to_source(("foo", "bar")) source = tuple_to_source(("foo", "bar"))
self.assertEqual(source["name"], "foo") self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar") self.assertEqual(source["source"], "bar")
def test_basic_list(self): def test_basic_list(self):
source = tuple_to_source(["foo", "bar"]) source = tuple_to_source(["foo", "bar"])
self.assertEqual(source["name"], "foo") self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar") self.assertEqual(source["source"], "bar")
def test_basic_dict(self): def test_basic_dict(self):
source = tuple_to_source(["foo", "bar"]) source = tuple_to_source(["foo", "bar"])
source["bin"] = "bin" source["bin"] = "bin"
# make sure this is returned as-is with extra fields # make sure this is returned as-is with extra fields
second = tuple_to_source(source) second = tuple_to_source(source)
self.assertEqual(source, second) self.assertEqual(source, second)
self.assertIn("bin", second) self.assertIn("bin", second)
class TupleToCorrectionTests(unittest.TestCase): class TupleToCorrectionTests(unittest.TestCase):
def test_basic_tuple(self): def test_basic_tuple(self):
source = tuple_to_correction(("foo", "bar")) source = tuple_to_correction(("foo", "bar"))
self.assertEqual(source["name"], "foo") self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar") self.assertEqual(source["source"], "bar")
def test_basic_list(self): def test_basic_list(self):
source = tuple_to_correction(["foo", "bar"]) source = tuple_to_correction(["foo", "bar"])
self.assertEqual(source["name"], "foo") self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar") self.assertEqual(source["source"], "bar")
def test_basic_dict(self): def test_basic_dict(self):
source = tuple_to_correction(["foo", "bar"]) source = tuple_to_correction(["foo", "bar"])
source["bin"] = "bin" source["bin"] = "bin"
# make sure this is returned with extra fields # make sure this is returned with extra fields
second = tuple_to_source(source) second = tuple_to_source(source)
self.assertEqual(source, second) self.assertEqual(source, second)
self.assertIn("bin", second) self.assertIn("bin", second)
def test_scale_tuple(self): def test_scale_tuple(self):
source = tuple_to_correction(["foo", "bar", 2]) source = tuple_to_correction(["foo", "bar", 2])
self.assertEqual(source["name"], "foo") self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar") self.assertEqual(source["source"], "bar")
def test_half_tuple(self): def test_half_tuple(self):
source = tuple_to_correction(["foo", "bar", True]) source = tuple_to_correction(["foo", "bar", True])
self.assertEqual(source["name"], "foo") self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar") self.assertEqual(source["source"], "bar")
def test_opset_tuple(self): def test_opset_tuple(self):
source = tuple_to_correction(["foo", "bar", 14]) source = tuple_to_correction(["foo", "bar", 14])
self.assertEqual(source["name"], "foo") self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar") self.assertEqual(source["source"], "bar")
def test_all_tuple(self): def test_all_tuple(self):
source = tuple_to_correction(["foo", "bar", 2, True, 14]) source = tuple_to_correction(["foo", "bar", 2, True, 14])
self.assertEqual(source["name"], "foo") self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar") self.assertEqual(source["source"], "bar")
self.assertEqual(source["scale"], 2) self.assertEqual(source["scale"], 2)
self.assertEqual(source["half"], True) self.assertEqual(source["half"], True)
self.assertEqual(source["opset"], 14) self.assertEqual(source["opset"], 14)
class TupleToDiffusionTests(unittest.TestCase): class TupleToDiffusionTests(unittest.TestCase):
def test_basic_tuple(self): def test_basic_tuple(self):
source = tuple_to_diffusion(("foo", "bar")) source = tuple_to_diffusion(("foo", "bar"))
self.assertEqual(source["name"], "foo") self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar") self.assertEqual(source["source"], "bar")
def test_basic_list(self): def test_basic_list(self):
source = tuple_to_diffusion(["foo", "bar"]) source = tuple_to_diffusion(["foo", "bar"])
self.assertEqual(source["name"], "foo") self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar") self.assertEqual(source["source"], "bar")
def test_basic_dict(self): def test_basic_dict(self):
source = tuple_to_diffusion(["foo", "bar"]) source = tuple_to_diffusion(["foo", "bar"])
source["bin"] = "bin" source["bin"] = "bin"
# make sure this is returned with extra fields # make sure this is returned with extra fields
second = tuple_to_diffusion(source) second = tuple_to_diffusion(source)
self.assertEqual(source, second) self.assertEqual(source, second)
self.assertIn("bin", second) self.assertIn("bin", second)
def test_single_vae_tuple(self): def test_single_vae_tuple(self):
source = tuple_to_diffusion(["foo", "bar", True]) source = tuple_to_diffusion(["foo", "bar", True])
self.assertEqual(source["name"], "foo") self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar") self.assertEqual(source["source"], "bar")
def test_half_tuple(self): def test_half_tuple(self):
source = tuple_to_diffusion(["foo", "bar", True]) source = tuple_to_diffusion(["foo", "bar", True])
self.assertEqual(source["name"], "foo") self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar") self.assertEqual(source["source"], "bar")
def test_opset_tuple(self): def test_opset_tuple(self):
source = tuple_to_diffusion(["foo", "bar", 14]) source = tuple_to_diffusion(["foo", "bar", 14])
self.assertEqual(source["name"], "foo") self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar") self.assertEqual(source["source"], "bar")
def test_all_tuple(self): def test_all_tuple(self):
source = tuple_to_diffusion(["foo", "bar", True, True, 14]) source = tuple_to_diffusion(["foo", "bar", True, True, 14])
self.assertEqual(source["name"], "foo") self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar") self.assertEqual(source["source"], "bar")
self.assertEqual(source["single_vae"], True) self.assertEqual(source["single_vae"], True)
self.assertEqual(source["half"], True) self.assertEqual(source["half"], True)
self.assertEqual(source["opset"], 14) self.assertEqual(source["opset"], 14)
class TupleToUpscalingTests(unittest.TestCase): class TupleToUpscalingTests(unittest.TestCase):
def test_basic_tuple(self): def test_basic_tuple(self):
source = tuple_to_upscaling(("foo", "bar")) source = tuple_to_upscaling(("foo", "bar"))
self.assertEqual(source["name"], "foo") self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar") self.assertEqual(source["source"], "bar")
def test_basic_list(self): def test_basic_list(self):
source = tuple_to_upscaling(["foo", "bar"]) source = tuple_to_upscaling(["foo", "bar"])
self.assertEqual(source["name"], "foo") self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar") self.assertEqual(source["source"], "bar")
def test_basic_dict(self): def test_basic_dict(self):
source = tuple_to_upscaling(["foo", "bar"]) source = tuple_to_upscaling(["foo", "bar"])
source["bin"] = "bin" source["bin"] = "bin"
# make sure this is returned with extra fields # make sure this is returned with extra fields
second = tuple_to_source(source) second = tuple_to_source(source)
self.assertEqual(source, second) self.assertEqual(source, second)
self.assertIn("bin", second) self.assertIn("bin", second)
def test_scale_tuple(self): def test_scale_tuple(self):
source = tuple_to_upscaling(["foo", "bar", 2]) source = tuple_to_upscaling(["foo", "bar", 2])
self.assertEqual(source["name"], "foo") self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar") self.assertEqual(source["source"], "bar")
def test_half_tuple(self): def test_half_tuple(self):
source = tuple_to_upscaling(["foo", "bar", True]) source = tuple_to_upscaling(["foo", "bar", True])
self.assertEqual(source["name"], "foo") self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar") self.assertEqual(source["source"], "bar")
def test_opset_tuple(self): def test_opset_tuple(self):
source = tuple_to_upscaling(["foo", "bar", 14]) source = tuple_to_upscaling(["foo", "bar", 14])
self.assertEqual(source["name"], "foo") self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar") self.assertEqual(source["source"], "bar")
def test_all_tuple(self): def test_all_tuple(self):
source = tuple_to_upscaling(["foo", "bar", 2, True, 14]) source = tuple_to_upscaling(["foo", "bar", 2, True, 14])
self.assertEqual(source["name"], "foo") self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar") self.assertEqual(source["source"], "bar")
self.assertEqual(source["scale"], 2) self.assertEqual(source["scale"], 2)
self.assertEqual(source["half"], True) self.assertEqual(source["half"], True)
self.assertEqual(source["opset"], 14) self.assertEqual(source["opset"], 14)
class SourceFormatTests(unittest.TestCase): class SourceFormatTests(unittest.TestCase):
def test_with_format(self): def test_with_format(self):
result = source_format({ result = source_format(
"format": "foo", {
}) "format": "foo",
self.assertEqual(result, "foo") }
)
self.assertEqual(result, "foo")
def test_source_known_extension(self): def test_source_known_extension(self):
result = source_format({ result = source_format(
"source": "foo.safetensors", {
}) "source": "foo.safetensors",
self.assertEqual(result, "safetensors") }
)
self.assertEqual(result, "safetensors")
def test_source_unknown_extension(self): def test_source_unknown_extension(self):
result = source_format({ result = source_format({"source": "foo.none"})
"source": "foo.none" self.assertEqual(result, None)
})
self.assertEqual(result, None)
def test_incomplete_model(self): def test_incomplete_model(self):
self.assertIsNone(source_format({})) self.assertIsNone(source_format({}))
class RemovePrefixTests(unittest.TestCase): class RemovePrefixTests(unittest.TestCase):
def test_with_prefix(self): def test_with_prefix(self):
self.assertEqual(remove_prefix("foo.bar", "foo"), ".bar") self.assertEqual(remove_prefix("foo.bar", "foo"), ".bar")
def test_without_prefix(self): def test_without_prefix(self):
self.assertEqual(remove_prefix("foo.bar", "bin"), "foo.bar") self.assertEqual(remove_prefix("foo.bar", "bin"), "foo.bar")
class LoadTorchTests(unittest.TestCase): class LoadTorchTests(unittest.TestCase):
pass pass
class LoadTensorTests(unittest.TestCase): class LoadTensorTests(unittest.TestCase):
pass pass
class ResolveTensorTests(unittest.TestCase): class ResolveTensorTests(unittest.TestCase):
@test_needs_models([TEST_MODEL_UPSCALING_SWINIR]) @test_needs_models([TEST_MODEL_UPSCALING_SWINIR])
def test_resolve_existing(self): def test_resolve_existing(self):
self.assertEqual(resolve_tensor("../models/.cache/upscaling-swinir"), TEST_MODEL_UPSCALING_SWINIR) self.assertEqual(
resolve_tensor("../models/.cache/upscaling-swinir"),
TEST_MODEL_UPSCALING_SWINIR,
)
def test_resolve_missing(self): def test_resolve_missing(self):
self.assertIsNone(resolve_tensor("missing")) self.assertIsNone(resolve_tensor("missing"))

View File

@ -6,11 +6,13 @@ from onnx_web.params import DeviceParams
def test_needs_models(models: List[str]): def test_needs_models(models: List[str]):
return skipUnless(all([path.exists(model) for model in models]), "model does not exist") return skipUnless(
all([path.exists(model) for model in models]), "model does not exist"
)
def test_device() -> DeviceParams: def test_device() -> DeviceParams:
return DeviceParams("cpu", "CPUExecutionProvider") return DeviceParams("cpu", "CPUExecutionProvider")
TEST_MODEL_DIFFUSION_SD15 = "../models/stable-diffusion-onnx-v1-5" TEST_MODEL_DIFFUSION_SD15 = "../models/stable-diffusion-onnx-v1-5"

View File

@ -10,24 +10,24 @@ from onnx_web.image.mask_filter import (
class MaskFilterNoneTests(unittest.TestCase): class MaskFilterNoneTests(unittest.TestCase):
def test_basic(self): def test_basic(self):
dims = (64, 64) dims = (64, 64)
mask = Image.new("RGB", dims) mask = Image.new("RGB", dims)
result = mask_filter_none(mask, dims, (0, 0)) result = mask_filter_none(mask, dims, (0, 0))
self.assertEqual(result.size, dims) self.assertEqual(result.size, dims)
class MaskFilterGaussianMultiplyTests(unittest.TestCase): class MaskFilterGaussianMultiplyTests(unittest.TestCase):
def test_basic(self): def test_basic(self):
dims = (64, 64) dims = (64, 64)
mask = Image.new("RGB", dims) mask = Image.new("RGB", dims)
result = mask_filter_gaussian_multiply(mask, dims, (0, 0)) result = mask_filter_gaussian_multiply(mask, dims, (0, 0))
self.assertEqual(result.size, dims) self.assertEqual(result.size, dims)
class MaskFilterGaussianScreenTests(unittest.TestCase): class MaskFilterGaussianScreenTests(unittest.TestCase):
def test_basic(self): def test_basic(self):
dims = (64, 64) dims = (64, 64)
mask = Image.new("RGB", dims) mask = Image.new("RGB", dims)
result = mask_filter_gaussian_screen(mask, dims, (0, 0)) result = mask_filter_gaussian_screen(mask, dims, (0, 0))
self.assertEqual(result.size, dims) self.assertEqual(result.size, dims)

View File

@ -11,27 +11,27 @@ from onnx_web.server.context import ServerContext
class SourceFilterNoneTests(unittest.TestCase): class SourceFilterNoneTests(unittest.TestCase):
def test_basic(self): def test_basic(self):
dims = (64, 64) dims = (64, 64)
server = ServerContext() server = ServerContext()
source = Image.new("RGB", dims) source = Image.new("RGB", dims)
result = source_filter_none(server, source) result = source_filter_none(server, source)
self.assertEqual(result.size, dims) self.assertEqual(result.size, dims)
class SourceFilterGaussianTests(unittest.TestCase): class SourceFilterGaussianTests(unittest.TestCase):
def test_basic(self): def test_basic(self):
dims = (64, 64) dims = (64, 64)
server = ServerContext() server = ServerContext()
source = Image.new("RGB", dims) source = Image.new("RGB", dims)
result = source_filter_gaussian(server, source) result = source_filter_gaussian(server, source)
self.assertEqual(result.size, dims) self.assertEqual(result.size, dims)
class SourceFilterNoiseTests(unittest.TestCase): class SourceFilterNoiseTests(unittest.TestCase):
def test_basic(self): def test_basic(self):
dims = (64, 64) dims = (64, 64)
server = ServerContext() server = ServerContext()
source = Image.new("RGB", dims) source = Image.new("RGB", dims)
result = source_filter_noise(server, source) result = source_filter_noise(server, source)
self.assertEqual(result.size, dims) self.assertEqual(result.size, dims)

View File

@ -7,18 +7,18 @@ from onnx_web.params import Border
class ExpandImageTests(unittest.TestCase): class ExpandImageTests(unittest.TestCase):
def test_expand(self): def test_expand(self):
result = expand_image( result = expand_image(
Image.new("RGB", (8, 8)), Image.new("RGB", (8, 8)),
Image.new("RGB", (8, 8), "white"), Image.new("RGB", (8, 8), "white"),
Border.even(4), Border.even(4),
) )
self.assertEqual(result[0].size, (16, 16)) self.assertEqual(result[0].size, (16, 16))
def test_masked(self): def test_masked(self):
result = expand_image( result = expand_image(
Image.new("RGB", (8, 8), "red"), Image.new("RGB", (8, 8), "red"),
Image.new("RGB", (8, 8), "white"), Image.new("RGB", (8, 8), "white"),
Border.even(4), Border.even(4),
) )
self.assertEqual(result[0].getpixel((8, 8)), (255, 0, 0)) self.assertEqual(result[0].getpixel((8, 8)), (255, 0, 0))

View File

@ -1,43 +1,43 @@
from typing import Any, Optional from typing import Any, Optional
class MockPipeline(): class MockPipeline:
# flags # flags
slice_size: Optional[str] slice_size: Optional[str]
vae_slicing: Optional[bool] vae_slicing: Optional[bool]
sequential_offload: Optional[bool] sequential_offload: Optional[bool]
model_offload: Optional[bool] model_offload: Optional[bool]
xformers: Optional[bool] xformers: Optional[bool]
# stubs # stubs
_encode_prompt: Optional[Any] _encode_prompt: Optional[Any]
unet: Optional[Any] unet: Optional[Any]
vae_decoder: Optional[Any] vae_decoder: Optional[Any]
vae_encoder: Optional[Any] vae_encoder: Optional[Any]
def __init__(self) -> None: def __init__(self) -> None:
self.slice_size = None self.slice_size = None
self.vae_slicing = None self.vae_slicing = None
self.sequential_offload = None self.sequential_offload = None
self.model_offload = None self.model_offload = None
self.xformers = None self.xformers = None
self._encode_prompt = None self._encode_prompt = None
self.unet = None self.unet = None
self.vae_decoder = None self.vae_decoder = None
self.vae_encoder = None self.vae_encoder = None
def enable_attention_slicing(self, slice_size: str = None): def enable_attention_slicing(self, slice_size: str = None):
self.slice_size = slice_size self.slice_size = slice_size
def enable_vae_slicing(self): def enable_vae_slicing(self):
self.vae_slicing = True self.vae_slicing = True
def enable_sequential_cpu_offload(self): def enable_sequential_cpu_offload(self):
self.sequential_offload = True self.sequential_offload = True
def enable_model_cpu_offload(self): def enable_model_cpu_offload(self):
self.model_offload = True self.model_offload = True
def enable_xformers_memory_efficient_attention(self): def enable_xformers_memory_efficient_attention(self):
self.xformers = True self.xformers = True

View File

@ -13,7 +13,7 @@ class ParserTests(unittest.TestCase):
str(["foo"]), str(["foo"]),
str(PromptPhrase(["bar"], weight=1.5)), str(PromptPhrase(["bar"], weight=1.5)),
str(["bin"]), str(["bin"]),
] ],
) )
def test_multi_word_phrase(self): def test_multi_word_phrase(self):
@ -24,7 +24,7 @@ class ParserTests(unittest.TestCase):
str(["foo", "bar"]), str(["foo", "bar"]),
str(PromptPhrase(["middle", "words"], weight=1.5)), str(PromptPhrase(["middle", "words"], weight=1.5)),
str(["bin", "bun"]), str(["bin", "bun"]),
] ],
) )
def test_nested_phrase(self): def test_nested_phrase(self):
@ -33,7 +33,7 @@ class ParserTests(unittest.TestCase):
[str(i) for i in res], [str(i) for i in res],
[ [
str(["foo"]), str(["foo"]),
str(PromptPhrase(["bar"], weight=(1.5 ** 3))), str(PromptPhrase(["bar"], weight=(1.5**3))),
str(["bin"]), str(["bin"]),
] ],
) )

View File

@ -25,71 +25,85 @@ class ConfigParamTests(unittest.TestCase):
params = get_config_params() params = get_config_params()
self.assertIsNotNone(params) self.assertIsNotNone(params)
class AvailablePlatformTests(unittest.TestCase): class AvailablePlatformTests(unittest.TestCase):
def test_before_setup(self): def test_before_setup(self):
platforms = get_available_platforms() platforms = get_available_platforms()
self.assertIsNotNone(platforms) self.assertIsNotNone(platforms)
class CorrectModelTests(unittest.TestCase): class CorrectModelTests(unittest.TestCase):
def test_before_setup(self): def test_before_setup(self):
models = get_correction_models() models = get_correction_models()
self.assertIsNotNone(models) self.assertIsNotNone(models)
class DiffusionModelTests(unittest.TestCase): class DiffusionModelTests(unittest.TestCase):
def test_before_setup(self): def test_before_setup(self):
models = get_diffusion_models() models = get_diffusion_models()
self.assertIsNotNone(models) self.assertIsNotNone(models)
class NetworkModelTests(unittest.TestCase): class NetworkModelTests(unittest.TestCase):
def test_before_setup(self): def test_before_setup(self):
models = get_network_models() models = get_network_models()
self.assertIsNotNone(models) self.assertIsNotNone(models)
class UpscalingModelTests(unittest.TestCase): class UpscalingModelTests(unittest.TestCase):
def test_before_setup(self): def test_before_setup(self):
models = get_upscaling_models() models = get_upscaling_models()
self.assertIsNotNone(models) self.assertIsNotNone(models)
class WildcardDataTests(unittest.TestCase): class WildcardDataTests(unittest.TestCase):
def test_before_setup(self): def test_before_setup(self):
wildcards = get_wildcard_data() wildcards = get_wildcard_data()
self.assertIsNotNone(wildcards) self.assertIsNotNone(wildcards)
class ExtraStringsTests(unittest.TestCase): class ExtraStringsTests(unittest.TestCase):
def test_before_setup(self): def test_before_setup(self):
strings = get_extra_strings() strings = get_extra_strings()
self.assertIsNotNone(strings) self.assertIsNotNone(strings)
class ExtraHashesTests(unittest.TestCase): class ExtraHashesTests(unittest.TestCase):
def test_before_setup(self): def test_before_setup(self):
hashes = get_extra_hashes() hashes = get_extra_hashes()
self.assertIsNotNone(hashes) self.assertIsNotNone(hashes)
class HighresMethodTests(unittest.TestCase): class HighresMethodTests(unittest.TestCase):
def test_before_setup(self): def test_before_setup(self):
methods = get_highres_methods() methods = get_highres_methods()
self.assertIsNotNone(methods) self.assertIsNotNone(methods)
class MaskFilterTests(unittest.TestCase): class MaskFilterTests(unittest.TestCase):
def test_before_setup(self): def test_before_setup(self):
filters = get_mask_filters() filters = get_mask_filters()
self.assertIsNotNone(filters) self.assertIsNotNone(filters)
class NoiseSourceTests(unittest.TestCase): class NoiseSourceTests(unittest.TestCase):
def test_before_setup(self): def test_before_setup(self):
sources = get_noise_sources() sources = get_noise_sources()
self.assertIsNotNone(sources) self.assertIsNotNone(sources)
class SourceFilterTests(unittest.TestCase): class SourceFilterTests(unittest.TestCase):
def test_before_setup(self): def test_before_setup(self):
filters = get_source_filters() filters = get_source_filters()
self.assertIsNotNone(filters) self.assertIsNotNone(filters)
class LoadExtrasTests(unittest.TestCase): class LoadExtrasTests(unittest.TestCase):
def test_default_extras(self): def test_default_extras(self):
server = ServerContext(extra_models=["../models/extras.json"]) server = ServerContext(extra_models=["../models/extras.json"])
load_extras(server) load_extras(server)
class LoadModelsTests(unittest.TestCase): class LoadModelsTests(unittest.TestCase):
def test_default_models(self): def test_default_models(self):
server = ServerContext(model_path="../models") server = ServerContext(model_path="../models")

View File

@ -4,37 +4,37 @@ from onnx_web.server.model_cache import ModelCache
class TestModelCache(unittest.TestCase): class TestModelCache(unittest.TestCase):
def test_drop_existing(self): def test_drop_existing(self):
cache = ModelCache(10) cache = ModelCache(10)
cache.clear() cache.clear()
cache.set("foo", ("bar",), {}) cache.set("foo", ("bar",), {})
self.assertGreater(cache.size, 0) self.assertGreater(cache.size, 0)
self.assertEqual(cache.drop("foo", ("bar",)), 1) self.assertEqual(cache.drop("foo", ("bar",)), 1)
def test_drop_missing(self): def test_drop_missing(self):
cache = ModelCache(10) cache = ModelCache(10)
cache.clear() cache.clear()
cache.set("foo", ("bar",), {}) cache.set("foo", ("bar",), {})
self.assertGreater(cache.size, 0) self.assertGreater(cache.size, 0)
self.assertEqual(cache.drop("foo", ("bin",)), 0) self.assertEqual(cache.drop("foo", ("bin",)), 0)
def test_get_existing(self): def test_get_existing(self):
cache = ModelCache(10) cache = ModelCache(10)
cache.clear() cache.clear()
value = {} value = {}
cache.set("foo", ("bar",), value) cache.set("foo", ("bar",), value)
self.assertGreater(cache.size, 0) self.assertGreater(cache.size, 0)
self.assertIs(cache.get("foo", ("bar",)), value) self.assertIs(cache.get("foo", ("bar",)), value)
def test_get_missing(self): def test_get_missing(self):
cache = ModelCache(10) cache = ModelCache(10)
cache.clear() cache.clear()
value = {} value = {}
cache.set("foo", ("bar",), value) cache.set("foo", ("bar",), value)
self.assertGreater(cache.size, 0) self.assertGreater(cache.size, 0)
self.assertIs(cache.get("foo", ("bin",)), None) self.assertIs(cache.get("foo", ("bin",)), None)
""" """
def test_set_existing(self): def test_set_existing(self):
cache = ModelCache(10) cache = ModelCache(10)
cache.clear() cache.clear()
@ -48,16 +48,16 @@ class TestModelCache(unittest.TestCase):
self.assertIs(cache.get("foo", ("bar",)), value) self.assertIs(cache.get("foo", ("bar",)), value)
""" """
def test_set_missing(self): def test_set_missing(self):
cache = ModelCache(10) cache = ModelCache(10)
cache.clear() cache.clear()
value = {} value = {}
cache.set("foo", ("bar",), value) cache.set("foo", ("bar",), value)
self.assertIs(cache.get("foo", ("bar",)), value) self.assertIs(cache.get("foo", ("bar",)), value)
def test_set_zero(self): def test_set_zero(self):
cache = ModelCache(0) cache = ModelCache(0)
cache.clear() cache.clear()
value = {} value = {}
cache.set("foo", ("bar",), value) cache.set("foo", ("bar",), value)
self.assertEqual(cache.size, 0) self.assertEqual(cache.size, 0)

View File

@ -24,253 +24,307 @@ from tests.mocks import MockPipeline
class TestAvailablePipelines(unittest.TestCase): class TestAvailablePipelines(unittest.TestCase):
def test_available_pipelines(self): def test_available_pipelines(self):
pipelines = get_available_pipelines() pipelines = get_available_pipelines()
self.assertIn("txt2img", pipelines) self.assertIn("txt2img", pipelines)
class TestPipelineSchedulers(unittest.TestCase): class TestPipelineSchedulers(unittest.TestCase):
def test_pipeline_schedulers(self): def test_pipeline_schedulers(self):
schedulers = get_pipeline_schedulers() schedulers = get_pipeline_schedulers()
self.assertIn("euler-a", schedulers) self.assertIn("euler-a", schedulers)
class TestSchedulerNames(unittest.TestCase): class TestSchedulerNames(unittest.TestCase):
def test_valid_name(self): def test_valid_name(self):
scheduler = get_scheduler_name(DDIMScheduler) scheduler = get_scheduler_name(DDIMScheduler)
self.assertEqual("ddim", scheduler) self.assertEqual("ddim", scheduler)
def test_missing_names(self): def test_missing_names(self):
self.assertIsNone(get_scheduler_name("test")) self.assertIsNone(get_scheduler_name("test"))
class TestOptimizePipeline(unittest.TestCase): class TestOptimizePipeline(unittest.TestCase):
def test_auto_attention_slicing(self): def test_auto_attention_slicing(self):
server = ServerContext( server = ServerContext(
optimizations=[ optimizations=[
"diffusers-attention-slicing-auto", "diffusers-attention-slicing-auto",
], ],
) )
pipeline = MockPipeline() pipeline = MockPipeline()
optimize_pipeline(server, pipeline) optimize_pipeline(server, pipeline)
self.assertEqual(pipeline.slice_size, "auto") self.assertEqual(pipeline.slice_size, "auto")
def test_max_attention_slicing(self): def test_max_attention_slicing(self):
server = ServerContext( server = ServerContext(
optimizations=[ optimizations=[
"diffusers-attention-slicing-max", "diffusers-attention-slicing-max",
] ]
) )
pipeline = MockPipeline() pipeline = MockPipeline()
optimize_pipeline(server, pipeline) optimize_pipeline(server, pipeline)
self.assertEqual(pipeline.slice_size, "max") self.assertEqual(pipeline.slice_size, "max")
def test_vae_slicing(self): def test_vae_slicing(self):
server = ServerContext( server = ServerContext(
optimizations=[ optimizations=[
"diffusers-vae-slicing", "diffusers-vae-slicing",
] ]
) )
pipeline = MockPipeline() pipeline = MockPipeline()
optimize_pipeline(server, pipeline) optimize_pipeline(server, pipeline)
self.assertEqual(pipeline.vae_slicing, True) self.assertEqual(pipeline.vae_slicing, True)
def test_cpu_offload_sequential(self): def test_cpu_offload_sequential(self):
server = ServerContext( server = ServerContext(
optimizations=[ optimizations=[
"diffusers-cpu-offload-sequential", "diffusers-cpu-offload-sequential",
] ]
) )
pipeline = MockPipeline() pipeline = MockPipeline()
optimize_pipeline(server, pipeline) optimize_pipeline(server, pipeline)
self.assertEqual(pipeline.sequential_offload, True) self.assertEqual(pipeline.sequential_offload, True)
def test_cpu_offload_model(self): def test_cpu_offload_model(self):
server = ServerContext( server = ServerContext(
optimizations=[ optimizations=[
"diffusers-cpu-offload-model", "diffusers-cpu-offload-model",
] ]
) )
pipeline = MockPipeline() pipeline = MockPipeline()
optimize_pipeline(server, pipeline) optimize_pipeline(server, pipeline)
self.assertEqual(pipeline.model_offload, True) self.assertEqual(pipeline.model_offload, True)
def test_memory_efficient_attention(self): def test_memory_efficient_attention(self):
server = ServerContext( server = ServerContext(
optimizations=[ optimizations=[
"diffusers-memory-efficient-attention", "diffusers-memory-efficient-attention",
] ]
) )
pipeline = MockPipeline() pipeline = MockPipeline()
optimize_pipeline(server, pipeline) optimize_pipeline(server, pipeline)
self.assertEqual(pipeline.xformers, True) self.assertEqual(pipeline.xformers, True)
class TestPatchPipeline(unittest.TestCase): class TestPatchPipeline(unittest.TestCase):
def test_expand_not_lpw(self): def test_expand_not_lpw(self):
""" """
server = ServerContext() server = ServerContext()
pipeline = MockPipeline() pipeline = MockPipeline()
patch_pipeline(server, pipeline, None, ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1)) patch_pipeline(server, pipeline, None, ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1))
self.assertEqual(pipeline._encode_prompt, expand_prompt) self.assertEqual(pipeline._encode_prompt, expand_prompt)
""" """
pass pass
def test_unet_wrapper_not_xl(self): def test_unet_wrapper_not_xl(self):
server = ServerContext() server = ServerContext()
pipeline = MockPipeline() pipeline = MockPipeline()
patch_pipeline(server, pipeline, None, ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1)) patch_pipeline(
self.assertTrue(isinstance(pipeline.unet, UNetWrapper)) server,
pipeline,
None,
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
)
self.assertTrue(isinstance(pipeline.unet, UNetWrapper))
def test_unet_wrapper_xl(self): def test_unet_wrapper_xl(self):
server = ServerContext() server = ServerContext()
pipeline = MockPipeline() pipeline = MockPipeline()
patch_pipeline(server, pipeline, None, ImageParams("test", "txt2img-sdxl", "ddim", "test", 1.0, 10, 1)) patch_pipeline(
self.assertTrue(isinstance(pipeline.unet, UNetWrapper)) server,
pipeline,
None,
ImageParams("test", "txt2img-sdxl", "ddim", "test", 1.0, 10, 1),
)
self.assertTrue(isinstance(pipeline.unet, UNetWrapper))
def test_vae_wrapper(self): def test_vae_wrapper(self):
server = ServerContext() server = ServerContext()
pipeline = MockPipeline() pipeline = MockPipeline()
patch_pipeline(server, pipeline, None, ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1)) patch_pipeline(
self.assertTrue(isinstance(pipeline.vae_decoder, VAEWrapper)) server,
self.assertTrue(isinstance(pipeline.vae_encoder, VAEWrapper)) pipeline,
None,
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
)
self.assertTrue(isinstance(pipeline.vae_decoder, VAEWrapper))
self.assertTrue(isinstance(pipeline.vae_encoder, VAEWrapper))
class TestLoadControlNet(unittest.TestCase): class TestLoadControlNet(unittest.TestCase):
@unittest.skipUnless(path.exists("../models/control/canny.onnx"), "model does not exist") @unittest.skipUnless(
def test_load_existing(self): path.exists("../models/control/canny.onnx"), "model does not exist"
"""
Should load a model
"""
components = load_controlnet(
ServerContext(model_path="../models"),
DeviceParams("cpu", "CPUExecutionProvider"),
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1, control=NetworkModel("canny", "control")),
) )
self.assertIn("controlnet", components) def test_load_existing(self):
"""
Should load a model
"""
components = load_controlnet(
ServerContext(model_path="../models"),
DeviceParams("cpu", "CPUExecutionProvider"),
ImageParams(
"test",
"txt2img",
"ddim",
"test",
1.0,
10,
1,
control=NetworkModel("canny", "control"),
),
)
self.assertIn("controlnet", components)
def test_load_missing(self): def test_load_missing(self):
""" """
Should throw Should throw
""" """
components = {} components = {}
try: try:
components = load_controlnet( components = load_controlnet(
ServerContext(), ServerContext(),
DeviceParams("cpu", "CPUExecutionProvider"), DeviceParams("cpu", "CPUExecutionProvider"),
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1, control=NetworkModel("missing", "control")), ImageParams(
) "test",
except: "txt2img",
self.assertNotIn("controlnet", components) "ddim",
return "test",
1.0,
10,
1,
control=NetworkModel("missing", "control"),
),
)
except Exception:
self.assertNotIn("controlnet", components)
return
self.fail() self.fail()
class TestLoadTextEncoders(unittest.TestCase): class TestLoadTextEncoders(unittest.TestCase):
@unittest.skipUnless(path.exists("../models/stable-diffusion-onnx-v1-5/text_encoder/model.onnx"), "model does not exist") @unittest.skipUnless(
def test_load_embeddings(self): path.exists("../models/stable-diffusion-onnx-v1-5/text_encoder/model.onnx"),
""" "model does not exist",
Should add the token to tokenizer
Should increase the encoder dims
"""
components = load_text_encoders(
ServerContext(model_path="../models"),
DeviceParams("cpu", "CPUExecutionProvider"),
"../models/stable-diffusion-onnx-v1-5",
[
# TODO: add some embeddings
],
[],
torch.float32,
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
) )
self.assertIn("text_encoder", components) def test_load_embeddings(self):
"""
Should add the token to tokenizer
Should increase the encoder dims
"""
components = load_text_encoders(
ServerContext(model_path="../models"),
DeviceParams("cpu", "CPUExecutionProvider"),
"../models/stable-diffusion-onnx-v1-5",
[
# TODO: add some embeddings
],
[],
torch.float32,
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
)
self.assertIn("text_encoder", components)
def test_load_embeddings_xl(self): def test_load_embeddings_xl(self):
pass pass
@unittest.skipUnless(path.exists("../models/stable-diffusion-onnx-v1-5/text_encoder/model.onnx"), "model does not exist") @unittest.skipUnless(
def test_load_loras(self): path.exists("../models/stable-diffusion-onnx-v1-5/text_encoder/model.onnx"),
components = load_text_encoders( "model does not exist",
ServerContext(model_path="../models"),
DeviceParams("cpu", "CPUExecutionProvider"),
"../models/stable-diffusion-onnx-v1-5",
[],
[
# TODO: add some loras
],
torch.float32,
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
) )
self.assertIn("text_encoder", components) def test_load_loras(self):
components = load_text_encoders(
ServerContext(model_path="../models"),
DeviceParams("cpu", "CPUExecutionProvider"),
"../models/stable-diffusion-onnx-v1-5",
[],
[
# TODO: add some loras
],
torch.float32,
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
)
self.assertIn("text_encoder", components)
def test_load_loras_xl(self):
pass
def test_load_loras_xl(self):
pass
class TestLoadUnet(unittest.TestCase): class TestLoadUnet(unittest.TestCase):
@unittest.skipUnless(path.exists("../models/stable-diffusion-onnx-v1-5/unet/model.onnx"), "model does not exist") @unittest.skipUnless(
def test_load_unet_loras(self): path.exists("../models/stable-diffusion-onnx-v1-5/unet/model.onnx"),
components = load_unet( "model does not exist",
ServerContext(model_path="../models"),
DeviceParams("cpu", "CPUExecutionProvider"),
"../models/stable-diffusion-onnx-v1-5",
[
# TODO: add some loras
],
"unet",
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
) )
self.assertIn("unet", components) def test_load_unet_loras(self):
components = load_unet(
ServerContext(model_path="../models"),
DeviceParams("cpu", "CPUExecutionProvider"),
"../models/stable-diffusion-onnx-v1-5",
[
# TODO: add some loras
],
"unet",
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
)
self.assertIn("unet", components)
def test_load_unet_loras_xl(self): def test_load_unet_loras_xl(self):
pass pass
@unittest.skipUnless(path.exists("../models/stable-diffusion-onnx-v1-5/cnet/model.onnx"), "model does not exist") @unittest.skipUnless(
def test_load_cnet_loras(self): path.exists("../models/stable-diffusion-onnx-v1-5/cnet/model.onnx"),
components = load_unet( "model does not exist",
ServerContext(model_path="../models"),
DeviceParams("cpu", "CPUExecutionProvider"),
"../models/stable-diffusion-onnx-v1-5",
[
# TODO: add some loras
],
"cnet",
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
) )
self.assertIn("unet", components) def test_load_cnet_loras(self):
components = load_unet(
ServerContext(model_path="../models"),
DeviceParams("cpu", "CPUExecutionProvider"),
"../models/stable-diffusion-onnx-v1-5",
[
# TODO: add some loras
],
"cnet",
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
)
self.assertIn("unet", components)
class TestLoadVae(unittest.TestCase): class TestLoadVae(unittest.TestCase):
@unittest.skipUnless(path.exists("../models/upscaling-stable-diffusion-x4/vae/model.onnx"), "model does not exist") @unittest.skipUnless(
def test_load_single(self): path.exists("../models/upscaling-stable-diffusion-x4/vae/model.onnx"),
""" "model does not exist",
Should return single component
"""
components = load_vae(
ServerContext(model_path="../models"),
DeviceParams("cpu", "CPUExecutionProvider"),
"../models/upscaling-stable-diffusion-x4",
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
) )
self.assertIn("vae", components) def test_load_single(self):
self.assertNotIn("vae_decoder", components) """
self.assertNotIn("vae_encoder", components) Should return single component
"""
components = load_vae(
ServerContext(model_path="../models"),
DeviceParams("cpu", "CPUExecutionProvider"),
"../models/upscaling-stable-diffusion-x4",
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
)
self.assertIn("vae", components)
self.assertNotIn("vae_decoder", components)
self.assertNotIn("vae_encoder", components)
@unittest.skipUnless(path.exists("../models/stable-diffusion-onnx-v1-5/vae_encoder/model.onnx"), "model does not exist") @unittest.skipUnless(
def test_load_split(self): path.exists("../models/stable-diffusion-onnx-v1-5/vae_encoder/model.onnx"),
""" "model does not exist",
Should return split encoder/decoder
"""
components = load_vae(
ServerContext(model_path="../models"),
DeviceParams("cpu", "CPUExecutionProvider"),
"../models/stable-diffusion-onnx-v1-5",
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
) )
self.assertNotIn("vae", components) def test_load_split(self):
self.assertIn("vae_decoder", components) """
self.assertIn("vae_encoder", components) Should return split encoder/decoder
"""
components = load_vae(
ServerContext(model_path="../models"),
DeviceParams("cpu", "CPUExecutionProvider"),
"../models/stable-diffusion-onnx-v1-5",
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
)
self.assertNotIn("vae", components)
self.assertIn("vae_decoder", components)
self.assertIn("vae_encoder", components)

View File

@ -17,155 +17,234 @@ from tests.helpers import TEST_MODEL_DIFFUSION_SD15, test_device, test_needs_mod
class TestTxt2ImgPipeline(unittest.TestCase): class TestTxt2ImgPipeline(unittest.TestCase):
@test_needs_models([TEST_MODEL_DIFFUSION_SD15]) @test_needs_models([TEST_MODEL_DIFFUSION_SD15])
def test_basic(self): def test_basic(self):
cancel = Value("L", 0) cancel = Value("L", 0)
logs = Queue() logs = Queue()
pending = Queue() pending = Queue()
progress = Queue() progress = Queue()
active = Value("L", 0) active = Value("L", 0)
idle = Value("L", 0) idle = Value("L", 0)
worker = WorkerContext( worker = WorkerContext(
"test", "test",
test_device(), test_device(),
cancel, cancel,
logs, logs,
pending, pending,
progress, progress,
active, active,
idle, idle,
3, 3,
0.1, 0.1,
) )
worker.start("test") worker.start("test")
run_txt2img_pipeline( run_txt2img_pipeline(
worker, worker,
ServerContext(model_path="../models", output_path="../outputs"), ServerContext(model_path="../models", output_path="../outputs"),
ImageParams( ImageParams(
TEST_MODEL_DIFFUSION_SD15, "txt2img", "ddim", "an astronaut eating a hamburger", 3.0, 1, 1), TEST_MODEL_DIFFUSION_SD15,
Size(256, 256), "txt2img",
["test-txt2img.png"], "ddim",
UpscaleParams("test"), "an astronaut eating a hamburger",
HighresParams(False, 1, 0, 0), 3.0,
) 1,
1,
),
Size(256, 256),
["test-txt2img-basic.png"],
UpscaleParams("test"),
HighresParams(False, 1, 0, 0),
)
self.assertTrue(path.exists("../outputs/test-txt2img-basic.png"))
output = Image.open("../outputs/test-txt2img-basic.png")
self.assertEqual(output.size, (256, 256))
# TODO: test contents of image
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
def test_highres(self):
cancel = Value("L", 0)
logs = Queue()
pending = Queue()
progress = Queue()
active = Value("L", 0)
idle = Value("L", 0)
worker = WorkerContext(
"test",
test_device(),
cancel,
logs,
pending,
progress,
active,
idle,
3,
0.1,
)
worker.start("test")
run_txt2img_pipeline(
worker,
ServerContext(model_path="../models", output_path="../outputs"),
ImageParams(
TEST_MODEL_DIFFUSION_SD15,
"txt2img",
"ddim",
"an astronaut eating a hamburger",
3.0,
1,
1,
),
Size(256, 256),
["test-txt2img-highres.png"],
UpscaleParams("test"),
HighresParams(True, 2, 0, 0),
)
self.assertTrue(path.exists("../outputs/test-txt2img-highres.png"))
output = Image.open("../outputs/test-txt2img-highres.png")
self.assertEqual(output.size, (512, 512))
self.assertTrue(path.exists("../outputs/test-txt2img.png"))
class TestImg2ImgPipeline(unittest.TestCase): class TestImg2ImgPipeline(unittest.TestCase):
@test_needs_models([TEST_MODEL_DIFFUSION_SD15]) @test_needs_models([TEST_MODEL_DIFFUSION_SD15])
def test_basic(self): def test_basic(self):
cancel = Value("L", 0) cancel = Value("L", 0)
logs = Queue() logs = Queue()
pending = Queue() pending = Queue()
progress = Queue() progress = Queue()
active = Value("L", 0) active = Value("L", 0)
idle = Value("L", 0) idle = Value("L", 0)
worker = WorkerContext( worker = WorkerContext(
"test", "test",
test_device(), test_device(),
cancel, cancel,
logs, logs,
pending, pending,
progress, progress,
active, active,
idle, idle,
3, 3,
0.1, 0.1,
) )
worker.start("test") worker.start("test")
source = Image.new("RGB", (64, 64), "black") source = Image.new("RGB", (64, 64), "black")
run_img2img_pipeline( run_img2img_pipeline(
worker, worker,
ServerContext(model_path="../models", output_path="../outputs"), ServerContext(model_path="../models", output_path="../outputs"),
ImageParams( ImageParams(
TEST_MODEL_DIFFUSION_SD15, "txt2img", "ddim", "an astronaut eating a hamburger", 3.0, 1, 1), TEST_MODEL_DIFFUSION_SD15,
["test-img2img.png"], "txt2img",
UpscaleParams("test"), "ddim",
HighresParams(False, 1, 0, 0), "an astronaut eating a hamburger",
source, 3.0,
1.0, 1,
) 1,
),
["test-img2img.png"],
UpscaleParams("test"),
HighresParams(False, 1, 0, 0),
source,
1.0,
)
self.assertTrue(path.exists("../outputs/test-img2img.png"))
self.assertTrue(path.exists("../outputs/test-img2img.png"))
class TestUpscalePipeline(unittest.TestCase): class TestUpscalePipeline(unittest.TestCase):
@test_needs_models(["../models/upscaling-stable-diffusion-x4"]) @test_needs_models(["../models/upscaling-stable-diffusion-x4"])
def test_basic(self): def test_basic(self):
cancel = Value("L", 0) cancel = Value("L", 0)
logs = Queue() logs = Queue()
pending = Queue() pending = Queue()
progress = Queue() progress = Queue()
active = Value("L", 0) active = Value("L", 0)
idle = Value("L", 0) idle = Value("L", 0)
worker = WorkerContext( worker = WorkerContext(
"test", "test",
test_device(), test_device(),
cancel, cancel,
logs, logs,
pending, pending,
progress, progress,
active, active,
idle, idle,
3, 3,
0.1, 0.1,
) )
worker.start("test") worker.start("test")
source = Image.new("RGB", (64, 64), "black") source = Image.new("RGB", (64, 64), "black")
run_upscale_pipeline( run_upscale_pipeline(
worker, worker,
ServerContext(model_path="../models", output_path="../outputs"), ServerContext(model_path="../models", output_path="../outputs"),
ImageParams( ImageParams(
"../models/upscaling-stable-diffusion-x4", "txt2img", "ddim", "an astronaut eating a hamburger", 3.0, 1, 1), "../models/upscaling-stable-diffusion-x4",
Size(256, 256), "txt2img",
["test-upscale.png"], "ddim",
UpscaleParams("test"), "an astronaut eating a hamburger",
HighresParams(False, 1, 0, 0), 3.0,
source, 1,
) 1,
),
Size(256, 256),
["test-upscale.png"],
UpscaleParams("test"),
HighresParams(False, 1, 0, 0),
source,
)
self.assertTrue(path.exists("../outputs/test-upscale.png"))
self.assertTrue(path.exists("../outputs/test-upscale.png"))
class TestBlendPipeline(unittest.TestCase): class TestBlendPipeline(unittest.TestCase):
def test_basic(self): def test_basic(self):
cancel = Value("L", 0) cancel = Value("L", 0)
logs = Queue() logs = Queue()
pending = Queue() pending = Queue()
progress = Queue() progress = Queue()
active = Value("L", 0) active = Value("L", 0)
idle = Value("L", 0) idle = Value("L", 0)
worker = WorkerContext( worker = WorkerContext(
"test", "test",
test_device(), test_device(),
cancel, cancel,
logs, logs,
pending, pending,
progress, progress,
active, active,
idle, idle,
3, 3,
0.1, 0.1,
) )
worker.start("test") worker.start("test")
source = Image.new("RGBA", (64, 64), "black") source = Image.new("RGBA", (64, 64), "black")
mask = Image.new("RGBA", (64, 64), "white") mask = Image.new("RGBA", (64, 64), "white")
run_blend_pipeline( run_blend_pipeline(
worker, worker,
ServerContext(model_path="../models", output_path="../outputs"), ServerContext(model_path="../models", output_path="../outputs"),
ImageParams( ImageParams(
TEST_MODEL_DIFFUSION_SD15, "txt2img", "ddim", "an astronaut eating a hamburger", 3.0, 1, 1), TEST_MODEL_DIFFUSION_SD15,
Size(64, 64), "txt2img",
["test-blend.png"], "ddim",
UpscaleParams("test"), "an astronaut eating a hamburger",
[source, source], 3.0,
mask, 1,
) 1,
),
Size(64, 64),
["test-blend.png"],
UpscaleParams("test"),
[source, source],
mask,
)
self.assertTrue(path.exists("../outputs/test-blend.png")) self.assertTrue(path.exists("../outputs/test-blend.png"))

View File

@ -10,7 +10,6 @@ from onnx_web.diffusers.utils import (
get_loras_from_prompt, get_loras_from_prompt,
get_scaled_latents, get_scaled_latents,
get_tile_latents, get_tile_latents,
get_tokens_from_prompt,
pop_random, pop_random,
slice_prompt, slice_prompt,
) )
@ -18,110 +17,128 @@ from onnx_web.params import Size
class TestExpandIntervalRanges(unittest.TestCase): class TestExpandIntervalRanges(unittest.TestCase):
def test_prompt_with_no_ranges(self): def test_prompt_with_no_ranges(self):
prompt = "an astronaut eating a hamburger" prompt = "an astronaut eating a hamburger"
result = expand_interval_ranges(prompt) result = expand_interval_ranges(prompt)
self.assertEqual(prompt, result) self.assertEqual(prompt, result)
def test_prompt_with_range(self):
prompt = "an astronaut-{1,4} eating a hamburger"
result = expand_interval_ranges(prompt)
self.assertEqual(
result, "an astronaut-1 astronaut-2 astronaut-3 eating a hamburger"
)
def test_prompt_with_range(self):
prompt = "an astronaut-{1,4} eating a hamburger"
result = expand_interval_ranges(prompt)
self.assertEqual(result, "an astronaut-1 astronaut-2 astronaut-3 eating a hamburger")
class TestExpandAlternativeRanges(unittest.TestCase): class TestExpandAlternativeRanges(unittest.TestCase):
def test_prompt_with_no_ranges(self): def test_prompt_with_no_ranges(self):
prompt = "an astronaut eating a hamburger" prompt = "an astronaut eating a hamburger"
result = expand_alternative_ranges(prompt) result = expand_alternative_ranges(prompt)
self.assertEqual([prompt], result) self.assertEqual([prompt], result)
def test_ranges_match(self):
prompt = "(an astronaut|a squirrel) eating (a hamburger|an acorn)"
result = expand_alternative_ranges(prompt)
self.assertEqual(
result, ["an astronaut eating a hamburger", "a squirrel eating an acorn"]
)
def test_ranges_match(self):
prompt = "(an astronaut|a squirrel) eating (a hamburger|an acorn)"
result = expand_alternative_ranges(prompt)
self.assertEqual(result, ["an astronaut eating a hamburger", "a squirrel eating an acorn"])
class TestInversionsFromPrompt(unittest.TestCase): class TestInversionsFromPrompt(unittest.TestCase):
def test_get_inversions(self): def test_get_inversions(self):
prompt = "<inversion:test:1.0> an astronaut eating an embedding" prompt = "<inversion:test:1.0> an astronaut eating an embedding"
result, tokens = get_inversions_from_prompt(prompt) result, tokens = get_inversions_from_prompt(prompt)
self.assertEqual(result, " an astronaut eating an embedding")
self.assertEqual(tokens, [("test", 1.0)])
self.assertEqual(result, " an astronaut eating an embedding")
self.assertEqual(tokens, [("test", 1.0)])
class TestLoRAsFromPrompt(unittest.TestCase): class TestLoRAsFromPrompt(unittest.TestCase):
def test_get_loras(self): def test_get_loras(self):
prompt = "<lora:test:1.0> an astronaut eating a LoRA" prompt = "<lora:test:1.0> an astronaut eating a LoRA"
result, tokens = get_loras_from_prompt(prompt) result, tokens = get_loras_from_prompt(prompt)
self.assertEqual(result, " an astronaut eating a LoRA")
self.assertEqual(tokens, [("test", 1.0)])
self.assertEqual(result, " an astronaut eating a LoRA")
self.assertEqual(tokens, [("test", 1.0)])
class TestLatentsFromSeed(unittest.TestCase): class TestLatentsFromSeed(unittest.TestCase):
def test_batch_size(self): def test_batch_size(self):
latents = get_latents_from_seed(1, Size(64, 64), batch=4) latents = get_latents_from_seed(1, Size(64, 64), batch=4)
self.assertEqual(latents.shape, (4, 4, 8, 8)) self.assertEqual(latents.shape, (4, 4, 8, 8))
def test_consistency(self):
latents1 = get_latents_from_seed(1, Size(64, 64))
latents2 = get_latents_from_seed(1, Size(64, 64))
self.assertTrue(np.array_equal(latents1, latents2))
def test_consistency(self):
latents1 = get_latents_from_seed(1, Size(64, 64))
latents2 = get_latents_from_seed(1, Size(64, 64))
self.assertTrue(np.array_equal(latents1, latents2))
class TestTileLatents(unittest.TestCase): class TestTileLatents(unittest.TestCase):
def test_full_tile(self): def test_full_tile(self):
partial = np.zeros((1, 1, 64, 64)) partial = np.zeros((1, 1, 64, 64))
full = get_tile_latents(partial, 1, (64, 64), (0, 0, 64)) full = get_tile_latents(partial, 1, (64, 64), (0, 0, 64))
self.assertEqual(full.shape, (1, 1, 8, 8)) self.assertEqual(full.shape, (1, 1, 8, 8))
def test_contract_tile(self): def test_contract_tile(self):
partial = np.zeros((1, 1, 64, 64)) partial = np.zeros((1, 1, 64, 64))
full = get_tile_latents(partial, 1, (32, 32), (0, 0, 32)) full = get_tile_latents(partial, 1, (32, 32), (0, 0, 32))
self.assertEqual(full.shape, (1, 1, 4, 4)) self.assertEqual(full.shape, (1, 1, 4, 4))
def test_expand_tile(self):
partial = np.zeros((1, 1, 32, 32))
full = get_tile_latents(partial, 1, (64, 64), (0, 0, 64))
self.assertEqual(full.shape, (1, 1, 8, 8))
def test_expand_tile(self):
partial = np.zeros((1, 1, 32, 32))
full = get_tile_latents(partial, 1, (64, 64), (0, 0, 64))
self.assertEqual(full.shape, (1, 1, 8, 8))
class TestScaledLatents(unittest.TestCase): class TestScaledLatents(unittest.TestCase):
def test_scale_up(self): def test_scale_up(self):
latents = get_latents_from_seed(1, Size(16, 16)) latents = get_latents_from_seed(1, Size(16, 16))
scaled = get_scaled_latents(1, Size(16, 16), scale=2) scaled = get_scaled_latents(1, Size(16, 16), scale=2)
self.assertEqual(latents[0, 0, 0, 0], scaled[0, 0, 0, 0]) self.assertEqual(latents[0, 0, 0, 0], scaled[0, 0, 0, 0])
def test_scale_down(self):
latents = get_latents_from_seed(1, Size(16, 16))
scaled = get_scaled_latents(1, Size(16, 16), scale=0.5)
self.assertEqual(
(
latents[0, 0, 0, 0]
+ latents[0, 0, 0, 1]
+ latents[0, 0, 1, 0]
+ latents[0, 0, 1, 1]
)
/ 4,
scaled[0, 0, 0, 0],
)
def test_scale_down(self):
latents = get_latents_from_seed(1, Size(16, 16))
scaled = get_scaled_latents(1, Size(16, 16), scale=0.5)
self.assertEqual((
latents[0, 0, 0, 0] +
latents[0, 0, 0, 1] +
latents[0, 0, 1, 0] +
latents[0, 0, 1, 1]
) / 4, scaled[0, 0, 0, 0])
class TestReplaceWildcards(unittest.TestCase): class TestReplaceWildcards(unittest.TestCase):
pass pass
class TestPopRandom(unittest.TestCase): class TestPopRandom(unittest.TestCase):
def test_pop(self): def test_pop(self):
items = ["1", "2", "3"] items = ["1", "2", "3"]
pop_random(items) pop_random(items)
self.assertEqual(len(items), 2) self.assertEqual(len(items), 2)
class TestRepairNaN(unittest.TestCase): class TestRepairNaN(unittest.TestCase):
def test_unchanged(self): def test_unchanged(self):
pass pass
def test_missing(self):
pass
def test_missing(self):
pass
class TestSlicePrompt(unittest.TestCase): class TestSlicePrompt(unittest.TestCase):
def test_slice_no_delimiter(self): def test_slice_no_delimiter(self):
slice = slice_prompt("foo", 1) slice = slice_prompt("foo", 1)
self.assertEqual(slice, "foo") self.assertEqual(slice, "foo")
def test_slice_within_range(self): def test_slice_within_range(self):
slice = slice_prompt("foo || bar", 1) slice = slice_prompt("foo || bar", 1)
self.assertEqual(slice, " bar") self.assertEqual(slice, " bar")
def test_slice_outside_range(self): def test_slice_outside_range(self):
slice = slice_prompt("foo || bar", 9) slice = slice_prompt("foo || bar", 9)
self.assertEqual(slice, " bar") self.assertEqual(slice, " bar")

View File

@ -13,122 +13,128 @@ lock = Event()
def test_job(*args, **kwargs): def test_job(*args, **kwargs):
lock.wait() lock.wait()
def wait_job(*args, **kwargs): def wait_job(*args, **kwargs):
sleep(0.5) sleep(0.5)
class TestWorkerPool(unittest.TestCase): class TestWorkerPool(unittest.TestCase):
# lock: Optional[Event] # lock: Optional[Event]
pool: Optional[DevicePoolExecutor] pool: Optional[DevicePoolExecutor]
def setUp(self) -> None: def setUp(self) -> None:
self.pool = None self.pool = None
def tearDown(self) -> None: def tearDown(self) -> None:
if self.pool is not None: if self.pool is not None:
self.pool.join() self.pool.join()
def test_no_devices(self): def test_no_devices(self):
server = ServerContext() server = ServerContext()
self.pool = DevicePoolExecutor(server, [], join_timeout=TEST_JOIN_TIMEOUT) self.pool = DevicePoolExecutor(server, [], join_timeout=TEST_JOIN_TIMEOUT)
self.pool.start() self.pool.start()
def test_fake_worker(self): def test_fake_worker(self):
device = DeviceParams("cpu", "CPUProvider") device = DeviceParams("cpu", "CPUProvider")
server = ServerContext() server = ServerContext()
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT) self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
self.pool.start() self.pool.start()
self.assertEqual(len(self.pool.workers), 1) self.assertEqual(len(self.pool.workers), 1)
def test_cancel_pending(self): def test_cancel_pending(self):
device = DeviceParams("cpu", "CPUProvider") device = DeviceParams("cpu", "CPUProvider")
server = ServerContext() server = ServerContext()
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT) self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
self.pool.start() self.pool.start()
self.pool.submit("test", wait_job, lock=lock) self.pool.submit("test", wait_job, lock=lock)
self.assertEqual(self.pool.done("test"), (True, None)) self.assertEqual(self.pool.done("test"), (True, None))
self.assertTrue(self.pool.cancel("test")) self.assertTrue(self.pool.cancel("test"))
self.assertEqual(self.pool.done("test"), (False, None)) self.assertEqual(self.pool.done("test"), (False, None))
def test_cancel_running(self): def test_cancel_running(self):
pass pass
def test_next_device(self): def test_next_device(self):
device = DeviceParams("cpu", "CPUProvider") device = DeviceParams("cpu", "CPUProvider")
server = ServerContext() server = ServerContext()
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT) self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
self.pool.start() self.pool.start()
self.assertEqual(self.pool.get_next_device(), 0) self.assertEqual(self.pool.get_next_device(), 0)
def test_needs_device(self): def test_needs_device(self):
device1 = DeviceParams("cpu1", "CPUProvider") device1 = DeviceParams("cpu1", "CPUProvider")
device2 = DeviceParams("cpu2", "CPUProvider") device2 = DeviceParams("cpu2", "CPUProvider")
server = ServerContext() server = ServerContext()
self.pool = DevicePoolExecutor(server, [device1, device2], join_timeout=TEST_JOIN_TIMEOUT) self.pool = DevicePoolExecutor(
self.pool.start() server, [device1, device2], join_timeout=TEST_JOIN_TIMEOUT
)
self.pool.start()
self.assertEqual(self.pool.get_next_device(needs_device=device2), 1) self.assertEqual(self.pool.get_next_device(needs_device=device2), 1)
def test_done_running(self): def test_done_running(self):
""" """
TODO: flaky TODO: flaky
""" """
device = DeviceParams("cpu", "CPUProvider") device = DeviceParams("cpu", "CPUProvider")
server = ServerContext() server = ServerContext()
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1) self.pool = DevicePoolExecutor(
self.pool.start(lock) server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1
sleep(2.0) )
self.pool.start(lock)
sleep(2.0)
self.pool.submit("test", test_job) self.pool.submit("test", test_job)
sleep(2.0) sleep(2.0)
pending, _progress = self.pool.done("test") pending, _progress = self.pool.done("test")
self.assertFalse(pending) self.assertFalse(pending)
def test_done_pending(self): def test_done_pending(self):
device = DeviceParams("cpu", "CPUProvider") device = DeviceParams("cpu", "CPUProvider")
server = ServerContext() server = ServerContext()
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT) self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
self.pool.start(lock) self.pool.start(lock)
self.pool.submit("test1", test_job) self.pool.submit("test1", test_job)
self.pool.submit("test2", test_job) self.pool.submit("test2", test_job)
self.assertTrue(self.pool.done("test2"), (True, None)) self.assertTrue(self.pool.done("test2"), (True, None))
lock.set() lock.set()
def test_done_finished(self): def test_done_finished(self):
""" """
TODO: flaky TODO: flaky
""" """
device = DeviceParams("cpu", "CPUProvider") device = DeviceParams("cpu", "CPUProvider")
server = ServerContext() server = ServerContext()
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1) self.pool = DevicePoolExecutor(
self.pool.start() server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1
sleep(2.0) )
self.pool.start()
sleep(2.0)
self.pool.submit("test", wait_job) self.pool.submit("test", wait_job)
self.assertEqual(self.pool.done("test"), (True, None)) self.assertEqual(self.pool.done("test"), (True, None))
sleep(2.0) sleep(2.0)
pending, _progress = self.pool.done("test") pending, _progress = self.pool.done("test")
self.assertFalse(pending) self.assertFalse(pending)
def test_recycle_live(self): def test_recycle_live(self):
pass pass
def test_recycle_dead(self): def test_recycle_dead(self):
pass pass
def test_running_status(self): def test_running_status(self):
pass pass

View File

@ -18,119 +18,194 @@ from tests.helpers import test_device
def main_memory(_worker): def main_memory(_worker):
raise Exception(MEMORY_ERRORS[0]) raise Exception(MEMORY_ERRORS[0])
def main_retry(_worker): def main_retry(_worker):
raise RetryException() raise RetryException()
def main_interrupt(_worker): def main_interrupt(_worker):
raise KeyboardInterrupt() raise KeyboardInterrupt()
class WorkerMainTests(unittest.TestCase): class WorkerMainTests(unittest.TestCase):
def test_pending_exception_empty(self): def test_pending_exception_empty(self):
pass pass
def test_pending_exception_interrupt(self): def test_pending_exception_interrupt(self):
status = None status = None
def exit(exit_status): def exit(exit_status):
nonlocal status nonlocal status
status = exit_status status = exit_status
job = JobCommand("test", "test", main_interrupt, [], {}) job = JobCommand("test", "test", main_interrupt, [], {})
cancel = Value("L", False) cancel = Value("L", False)
logs = Queue() logs = Queue()
pending = Queue() pending = Queue()
progress = Queue() progress = Queue()
pid = Value("L", getpid()) pid = Value("L", getpid())
idle = Value("L", False) idle = Value("L", False)
pending.put(job) pending.put(job)
worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit) worker_main(
WorkerContext(
"test",
test_device(),
cancel,
logs,
pending,
progress,
pid,
idle,
0,
0.0,
),
ServerContext(),
exit=exit,
)
self.assertEqual(status, EXIT_INTERRUPT) self.assertEqual(status, EXIT_INTERRUPT)
pass pass
def test_pending_exception_retry(self): def test_pending_exception_retry(self):
status = None status = None
def exit(exit_status): def exit(exit_status):
nonlocal status nonlocal status
status = exit_status status = exit_status
job = JobCommand("test", "test", main_retry, [], {}) job = JobCommand("test", "test", main_retry, [], {})
cancel = Value("L", False) cancel = Value("L", False)
logs = Queue() logs = Queue()
pending = Queue() pending = Queue()
progress = Queue() progress = Queue()
pid = Value("L", getpid()) pid = Value("L", getpid())
idle = Value("L", False) idle = Value("L", False)
pending.put(job) pending.put(job)
worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit) worker_main(
WorkerContext(
"test",
test_device(),
cancel,
logs,
pending,
progress,
pid,
idle,
0,
0.0,
),
ServerContext(),
exit=exit,
)
self.assertEqual(status, EXIT_ERROR) self.assertEqual(status, EXIT_ERROR)
pass pass
def test_pending_exception_value(self): def test_pending_exception_value(self):
status = None status = None
def exit(exit_status): def exit(exit_status):
nonlocal status nonlocal status
status = exit_status status = exit_status
cancel = Value("L", False) cancel = Value("L", False)
logs = Queue() logs = Queue()
pending = Queue() pending = Queue()
progress = Queue() progress = Queue()
pid = Value("L", getpid()) pid = Value("L", getpid())
idle = Value("L", False) idle = Value("L", False)
pending.close() pending.close()
worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit) worker_main(
WorkerContext(
"test",
test_device(),
cancel,
logs,
pending,
progress,
pid,
idle,
0,
0.0,
),
ServerContext(),
exit=exit,
)
self.assertEqual(status, EXIT_ERROR) self.assertEqual(status, EXIT_ERROR)
def test_pending_exception_other_memory(self): def test_pending_exception_other_memory(self):
status = None status = None
def exit(exit_status): def exit(exit_status):
nonlocal status nonlocal status
status = exit_status status = exit_status
job = JobCommand("test", "test", main_memory, [], {}) job = JobCommand("test", "test", main_memory, [], {})
cancel = Value("L", False) cancel = Value("L", False)
logs = Queue() logs = Queue()
pending = Queue() pending = Queue()
progress = Queue() progress = Queue()
pid = Value("L", getpid()) pid = Value("L", getpid())
idle = Value("L", False) idle = Value("L", False)
pending.put(job) pending.put(job)
worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit) worker_main(
WorkerContext(
"test",
test_device(),
cancel,
logs,
pending,
progress,
pid,
idle,
0,
0.0,
),
ServerContext(),
exit=exit,
)
self.assertEqual(status, EXIT_MEMORY) self.assertEqual(status, EXIT_MEMORY)
def test_pending_exception_other_unknown(self):
pass
def test_pending_exception_other_unknown(self): def test_pending_replaced(self):
pass status = None
def test_pending_replaced(self): def exit(exit_status):
status = None nonlocal status
status = exit_status
def exit(exit_status): cancel = Value("L", False)
nonlocal status logs = Queue()
status = exit_status pending = Queue()
progress = Queue()
pid = Value("L", 0)
idle = Value("L", False)
cancel = Value("L", False) worker_main(
logs = Queue() WorkerContext(
pending = Queue() "test",
progress = Queue() test_device(),
pid = Value("L", 0) cancel,
idle = Value("L", False) logs,
pending,
worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit) progress,
pid,
self.assertEqual(status, EXIT_REPLACED) idle,
0,
0.0,
),
ServerContext(),
exit=exit,
)
self.assertEqual(status, EXIT_REPLACED)