1
0
Fork 0

apply lint to tests, test highres

This commit is contained in:
Sean Sube 2023-11-19 23:18:57 -06:00
parent 4691e80744
commit 65912c5a4a
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
26 changed files with 1506 additions and 1127 deletions

View File

@ -33,13 +33,17 @@ package-upload:
lint-check:
black --check onnx_web/
black --check tests/
flake8 onnx_web
flake8 tests
isort --check-only --skip __init__.py --filter-files onnx_web
isort --check-only --skip __init__.py --filter-files tests
lint-fix:
black onnx_web/
black tests/
flake8 onnx_web
flake8 tests
isort --skip __init__.py --filter-files onnx_web
isort --skip __init__.py --filter-files tests

View File

@ -19,6 +19,14 @@ class StageResult:
def empty():
return StageResult(images=[])
@staticmethod
def from_arrays(arrays: List[np.ndarray]):
return StageResult(arrays=arrays)
@staticmethod
def from_images(images: List[Image.Image]):
return StageResult(images=images)
def __init__(self, arrays=None, images=None) -> None:
if arrays is not None and images is not None:
raise ValueError("stages must only return one type of result")

View File

@ -25,9 +25,7 @@ class TileCallback(Protocol):
Definition for a tile job function.
"""
def __call__(
self, image: Image.Image, dims: Tuple[int, int, int]
) -> StageResult:
def __call__(self, image: Image.Image, dims: Tuple[int, int, int]) -> StageResult:
"""
Run this stage against a single tile.
"""
@ -319,6 +317,9 @@ def process_tile_stack(
if mask:
tile_mask = mask.crop((left, top, right, bottom))
if isinstance(tile_stack, list):
tile_stack = StageResult.from_images(tile_stack)
for image_filter in filters:
tile_stack = image_filter(tile_stack, tile_mask, (left, top, tile))

View File

@ -48,6 +48,7 @@ def main():
# debug options
if server.debug:
import debugpy
debugpy.listen(5678)
logger.warning("waiting for debugger")
debugpy.wait_for_client()

View File

@ -9,13 +9,15 @@ from onnx_web.chain.result import StageResult
class BlendGridStageTests(unittest.TestCase):
def test_stage(self):
stage = BlendGridStage()
sources = StageResult(images=[
Image.new("RGB", (64, 64), "black"),
Image.new("RGB", (64, 64), "white"),
Image.new("RGB", (64, 64), "black"),
Image.new("RGB", (64, 64), "white"),
])
sources = StageResult(
images=[
Image.new("RGB", (64, 64), "black"),
Image.new("RGB", (64, 64), "white"),
Image.new("RGB", (64, 64), "black"),
Image.new("RGB", (64, 64), "white"),
]
)
result = stage.run(None, None, None, None, sources, height=2, width=2)
self.assertEqual(len(result), 5)
self.assertEqual(result.as_image()[-1].getpixel((0,0)), (0, 0, 0))
self.assertEqual(result.as_image()[-1].getpixel((0, 0)), (0, 0, 0))

View File

@ -6,21 +6,38 @@ from onnx_web.chain.blend_img2img import BlendImg2ImgStage
from onnx_web.params import DeviceParams, ImageParams
from onnx_web.server.context import ServerContext
from onnx_web.worker.context import WorkerContext
from tests.helpers import TEST_MODEL_DIFFUSION_SD15, test_needs_models
class BlendImg2ImgStageTests(unittest.TestCase):
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
def test_stage(self):
"""
stage = BlendImg2ImgStage()
params = ImageParams("runwayml/stable-diffusion-v1-5", "txt2img", "euler-a", "an astronaut eating a hamburger", 3.0, 1, 1)
server = ServerContext()
worker = WorkerContext("test", DeviceParams("cpu", "CPUProvider"), None, None, None, None, None, None, 0)
params = ImageParams(
TEST_MODEL_DIFFUSION_SD15,
"txt2img",
"euler-a",
"an astronaut eating a hamburger",
3.0,
1,
1,
)
server = ServerContext(model_path="../models", output_path="../outputs")
worker = WorkerContext(
"test",
DeviceParams("cpu", "CPUProvider"),
None,
None,
None,
None,
None,
None,
0,
)
sources = [
Image.new("RGB", (64, 64), "black"),
]
result = stage.run(worker, server, None, params, sources, strength=0.5, steps=1)
self.assertEqual(len(result), 1)
self.assertEqual(result[0].getpixel((0,0)), (127, 127, 127))
"""
pass
self.assertEqual(result[0].getpixel((0, 0)), (127, 127, 127))

View File

@ -9,11 +9,15 @@ from onnx_web.chain.result import StageResult
class BlendLinearStageTests(unittest.TestCase):
def test_stage(self):
stage = BlendLinearStage()
sources = StageResult(images=[
Image.new("RGB", (64, 64), "black"),
])
sources = StageResult(
images=[
Image.new("RGB", (64, 64), "black"),
]
)
stage_source = Image.new("RGB", (64, 64), "white")
result = stage.run(None, None, None, None, sources, alpha=0.5, stage_source=stage_source)
result = stage.run(
None, None, None, None, sources, alpha=0.5, stage_source=stage_source
)
self.assertEqual(len(result), 1)
self.assertEqual(result.as_image()[0].getpixel((0,0)), (127, 127, 127))
self.assertEqual(result.as_image()[0].getpixel((0, 0)), (127, 127, 127))

View File

@ -30,4 +30,4 @@ class CorrectCodeformerStageTests(unittest.TestCase):
self.assertEqual(len(result), 0)
"""
pass
pass

View File

@ -2,6 +2,7 @@ import unittest
from PIL import Image
from onnx_web.chain.result import StageResult
from onnx_web.chain.tile import (
complete_tile,
generate_tile_grid,
@ -14,122 +15,126 @@ from onnx_web.params import Size
class TestCompleteTile(unittest.TestCase):
def test_with_complete_tile(self):
partial = Image.new("RGB", (64, 64))
output = complete_tile(partial, 64)
def test_with_complete_tile(self):
partial = Image.new("RGB", (64, 64))
output = complete_tile(partial, 64)
self.assertEqual(output.size, (64, 64))
self.assertEqual(output.size, (64, 64))
def test_with_partial_tile(self):
partial = Image.new("RGB", (64, 32))
output = complete_tile(partial, 64)
def test_with_partial_tile(self):
partial = Image.new("RGB", (64, 32))
output = complete_tile(partial, 64)
self.assertEqual(output.size, (64, 64))
self.assertEqual(output.size, (64, 64))
def test_with_nothing(self):
output = complete_tile(None, 64)
def test_with_nothing(self):
output = complete_tile(None, 64)
self.assertIsNone(output)
self.assertIsNone(output)
class TestNeedsTile(unittest.TestCase):
def test_with_undersized_source(self):
small = Image.new("RGB", (32, 32))
def test_with_undersized_source(self):
small = Image.new("RGB", (32, 32))
self.assertFalse(needs_tile(64, 64, source=small))
self.assertFalse(needs_tile(64, 64, source=small))
def test_with_oversized_source(self):
large = Image.new("RGB", (64, 64))
def test_with_oversized_source(self):
large = Image.new("RGB", (64, 64))
self.assertTrue(needs_tile(32, 32, source=large))
self.assertTrue(needs_tile(32, 32, source=large))
def test_with_undersized_size(self):
small = Size(32, 32)
def test_with_undersized_size(self):
small = Size(32, 32)
self.assertFalse(needs_tile(64, 64, size=small))
self.assertFalse(needs_tile(64, 64, size=small))
def test_with_oversized_source(self):
large = Size(64, 64)
def test_with_oversized_size(self):
large = Size(64, 64)
self.assertTrue(needs_tile(32, 32, size=large))
self.assertTrue(needs_tile(32, 32, size=large))
def test_with_nothing(self):
self.assertFalse(needs_tile(32, 32))
def test_with_nothing(self):
self.assertFalse(needs_tile(32, 32))
class TestTileGrads(unittest.TestCase):
def test_center_tile(self):
grad_x, grad_y = make_tile_grads(32, 32, 8, 64, 64)
def test_center_tile(self):
grad_x, grad_y = make_tile_grads(32, 32, 8, 64, 64)
self.assertEqual(grad_x, [0, 1, 1, 0])
self.assertEqual(grad_y, [0, 1, 1, 0])
self.assertEqual(grad_x, [0, 1, 1, 0])
self.assertEqual(grad_y, [0, 1, 1, 0])
def test_vertical_edge_tile(self):
grad_x, grad_y = make_tile_grads(32, 0, 8, 64, 8)
def test_vertical_edge_tile(self):
grad_x, grad_y = make_tile_grads(32, 0, 8, 64, 8)
self.assertEqual(grad_x, [0, 1, 1, 0])
self.assertEqual(grad_y, [1, 1, 1, 1])
self.assertEqual(grad_x, [0, 1, 1, 0])
self.assertEqual(grad_y, [1, 1, 1, 1])
def test_horizontal_edge_tile(self):
grad_x, grad_y = make_tile_grads(0, 32, 8, 8, 64)
def test_horizontal_edge_tile(self):
grad_x, grad_y = make_tile_grads(0, 32, 8, 8, 64)
self.assertEqual(grad_x, [1, 1, 1, 1])
self.assertEqual(grad_y, [0, 1, 1, 0])
self.assertEqual(grad_x, [1, 1, 1, 1])
self.assertEqual(grad_y, [0, 1, 1, 0])
class TestGenerateTileGrid(unittest.TestCase):
def test_grid_complete(self):
tiles = generate_tile_grid(16, 16, 8, 0.0)
def test_grid_complete(self):
tiles = generate_tile_grid(16, 16, 8, 0.0)
self.assertEqual(len(tiles), 4)
self.assertEqual(tiles, [(0, 0), (8, 0), (8, 8), (0, 8)])
self.assertEqual(len(tiles), 4)
self.assertEqual(tiles, [(0, 0), (8, 0), (0, 8), (8, 8)])
def test_grid_no_overlap(self):
tiles = generate_tile_grid(64, 64, 8, 0.0)
def test_grid_no_overlap(self):
tiles = generate_tile_grid(64, 64, 8, 0.0)
self.assertEqual(len(tiles), 64)
self.assertEqual(tiles[0:4], [(0, 0), (8, 0), (16, 0), (24, 0)])
self.assertEqual(tiles[-5:-1], [(16, 24), (24, 24), (32, 24), (32, 32)])
self.assertEqual(len(tiles), 64)
self.assertEqual(tiles[0:4], [(0, 0), (8, 0), (16, 0), (24, 0)])
self.assertEqual(tiles[-5:-1], [(24, 56), (32, 56), (40, 56), (48, 56)])
def test_grid_50_overlap(self):
tiles = generate_tile_grid(64, 64, 8, 0.5)
def test_grid_50_overlap(self):
tiles = generate_tile_grid(64, 64, 8, 0.5)
self.assertEqual(len(tiles), 225)
self.assertEqual(tiles[0:4], [(0, 0), (4, 0), (8, 0), (12, 0)])
self.assertEqual(tiles[-5:-1], [(32, 32), (28, 32), (24, 32), (24, 28)])
self.assertEqual(len(tiles), 256)
self.assertEqual(tiles[0:4], [(0, 0), (4, 0), (8, 0), (12, 0)])
self.assertEqual(tiles[-5:-1], [(44, 60), (48, 60), (52, 60), (56, 60)])
class TestGenerateTileSpiral(unittest.TestCase):
def test_spiral_complete(self):
tiles = generate_tile_spiral(16, 16, 8, 0.0)
def test_spiral_complete(self):
tiles = generate_tile_spiral(16, 16, 8, 0.0)
self.assertEqual(len(tiles), 4)
self.assertEqual(tiles, [(0, 0), (8, 0), (8, 8), (0, 8)])
self.assertEqual(len(tiles), 4)
self.assertEqual(tiles, [(0, 0), (8, 0), (8, 8), (0, 8)])
def test_spiral_no_overlap(self):
tiles = generate_tile_spiral(64, 64, 8, 0.0)
def test_spiral_no_overlap(self):
tiles = generate_tile_spiral(64, 64, 8, 0.0)
self.assertEqual(len(tiles), 64)
self.assertEqual(tiles[0:4], [(0, 0), (8, 0), (16, 0), (24, 0)])
self.assertEqual(tiles[-5:-1], [(16, 24), (24, 24), (32, 24), (32, 32)])
self.assertEqual(len(tiles), 64)
self.assertEqual(tiles[0:4], [(0, 0), (8, 0), (16, 0), (24, 0)])
self.assertEqual(tiles[-5:-1], [(16, 24), (24, 24), (32, 24), (32, 32)])
def test_spiral_50_overlap(self):
tiles = generate_tile_spiral(64, 64, 8, 0.5)
def test_spiral_50_overlap(self):
tiles = generate_tile_spiral(64, 64, 8, 0.5)
self.assertEqual(len(tiles), 225)
self.assertEqual(tiles[0:4], [(0, 0), (4, 0), (8, 0), (12, 0)])
self.assertEqual(tiles[-5:-1], [(32, 32), (28, 32), (24, 32), (24, 28)])
self.assertEqual(len(tiles), 225)
self.assertEqual(tiles[0:4], [(0, 0), (4, 0), (8, 0), (12, 0)])
self.assertEqual(tiles[-5:-1], [(32, 32), (28, 32), (24, 32), (24, 28)])
class TestProcessTileStack(unittest.TestCase):
def test_grid_full(self):
source = Image.new("RGB", (64, 64))
blend = process_tile_stack(source, 32, 1, [])
def test_grid_full(self):
source = Image.new("RGB", (64, 64))
blend = process_tile_stack(
StageResult(images=[source]), 32, 1, [], generate_tile_grid
)
self.assertEqual(blend.size, (64, 64))
self.assertEqual(blend[0].size, (64, 64))
def test_grid_partial(self):
source = Image.new("RGB", (72, 72))
blend = process_tile_stack(source, 32, 1, [])
def test_grid_partial(self):
source = Image.new("RGB", (72, 72))
blend = process_tile_stack(
StageResult(images=[source]), 32, 1, [], generate_tile_grid
)
self.assertEqual(blend.size, (72, 72))
self.assertEqual(blend[0].size, (72, 72))

View File

@ -9,6 +9,14 @@ class UpscaleHighresStageTests(unittest.TestCase):
def test_empty(self):
stage = UpscaleHighresStage()
sources = StageResult.empty()
result = stage.run(None, None, None, None, sources, highres=HighresParams(False,1, 0, 0), upscale=UpscaleParams(""))
result = stage.run(
None,
None,
None,
None,
sources,
highres=HighresParams(False, 1, 0, 0),
upscale=UpscaleParams(""),
)
self.assertEqual(len(result), 0)

View File

@ -6,7 +6,6 @@ from onnx import GraphProto, ModelProto, NodeProto
from onnx.numpy_helper import from_array
from onnx_web.convert.diffusion.lora import (
blend_loras,
blend_node_conv_gemm,
blend_node_matmul,
blend_weights_loha,
@ -33,7 +32,6 @@ class SumWeightsTests(unittest.TestCase):
weights = sum_weights(np.zeros((4, 4)), np.ones((4, 4, 1, 1)))
self.assertEqual(weights.shape, (4, 4, 1, 1))
def test_3x3_kernel(self):
"""
weights = sum_weights(np.zeros((4, 4, 3, 3)), np.ones((4, 4)))
@ -53,14 +51,20 @@ class BufferExternalDataTensorTests(unittest.TestCase):
)
(slim_model, external_weights) = buffer_external_data_tensors(model)
self.assertEqual(len(slim_model.graph.initializer), len(model.graph.initializer))
self.assertEqual(
len(slim_model.graph.initializer), len(model.graph.initializer)
)
self.assertEqual(len(external_weights), 1)
class FixInitializerKeyTests(unittest.TestCase):
def test_fix_name(self):
inputs = ["lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0.lora_down.weight"]
outputs = ["lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0_lora_down_weight"]
inputs = [
"lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0.lora_down.weight"
]
outputs = [
"lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0_lora_down_weight"
]
for input, output in zip(inputs, outputs):
self.assertEqual(fix_initializer_name(input), output)
@ -92,25 +96,37 @@ class FixXLNameTests(unittest.TestCase):
nodes = {
"input_block_proj.lora_down.weight": {},
}
fixed = fix_xl_names(nodes, [
NodeProto(name="/down_blocks_proj/MatMul"),
])
fixed = fix_xl_names(
nodes,
[
NodeProto(name="/down_blocks_proj/MatMul"),
],
)
self.assertEqual(fixed, {
"down_blocks_proj": nodes["input_block_proj.lora_down.weight"],
})
self.assertEqual(
fixed,
{
"down_blocks_proj": nodes["input_block_proj.lora_down.weight"],
},
)
def test_middle_block(self):
nodes = {
"middle_block_proj.lora_down.weight": {},
}
fixed = fix_xl_names(nodes, [
NodeProto(name="/mid_blocks_proj/MatMul"),
])
fixed = fix_xl_names(
nodes,
[
NodeProto(name="/mid_blocks_proj/MatMul"),
],
)
self.assertEqual(fixed, {
"mid_blocks_proj": nodes["middle_block_proj.lora_down.weight"],
})
self.assertEqual(
fixed,
{
"mid_blocks_proj": nodes["middle_block_proj.lora_down.weight"],
},
)
def test_output_block(self):
pass
@ -133,13 +149,19 @@ class FixXLNameTests(unittest.TestCase):
nodes = {
"output_block_proj_out.lora_down.weight": {},
}
fixed = fix_xl_names(nodes, [
NodeProto(name="/up_blocks_proj_out/MatMul"),
])
fixed = fix_xl_names(
nodes,
[
NodeProto(name="/up_blocks_proj_out/MatMul"),
],
)
self.assertEqual(fixed, {
"up_blocks_proj_out": nodes["output_block_proj_out.lora_down.weight"],
})
self.assertEqual(
fixed,
{
"up_blocks_proj_out": nodes["output_block_proj_out.lora_down.weight"],
},
)
class KernelSliceTests(unittest.TestCase):
@ -250,6 +272,7 @@ class BlendWeightsLoHATests(unittest.TestCase):
self.assertEqual(result.shape, (4, 4))
"""
class BlendWeightsLoRATests(unittest.TestCase):
def test_blend_kernel_none(self):
model = {
@ -260,7 +283,6 @@ class BlendWeightsLoRATests(unittest.TestCase):
key, result = blend_weights_lora("foo.lora_down", "", model, torch.float32)
self.assertEqual(result.shape, (4, 4))
def test_blend_kernel_1x1(self):
model = {
"foo.lora_down": torch.from_numpy(np.ones((1, 4, 1, 1))),

View File

@ -10,7 +10,6 @@ from onnx_web.convert.diffusion.textual_inversion import (
blend_embedding_embeddings,
blend_embedding_node,
blend_embedding_parameters,
blend_textual_inversions,
detect_embedding_format,
)
@ -18,210 +17,267 @@ TEST_DIMS = (8, 8)
TEST_DIMS_EMBEDS = (1, *TEST_DIMS)
TEST_MODEL_EMBEDS = {
"string_to_token": {
"string_to_token": {
"test": 1,
},
"string_to_param": {
},
"string_to_param": {
"test": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)),
},
},
}
class DetectEmbeddingFormatTests(unittest.TestCase):
def test_concept(self):
embedding = {
"<test>": "test",
}
self.assertEqual(detect_embedding_format(embedding), "concept")
def test_concept(self):
embedding = {
"<test>": "test",
}
self.assertEqual(detect_embedding_format(embedding), "concept")
def test_parameters(self):
embedding = {
"emb_params": "test",
}
self.assertEqual(detect_embedding_format(embedding), "parameters")
def test_parameters(self):
embedding = {
"emb_params": "test",
}
self.assertEqual(detect_embedding_format(embedding), "parameters")
def test_embeddings(self):
embedding = {
"string_to_token": "test",
"string_to_param": "test",
}
self.assertEqual(detect_embedding_format(embedding), "embeddings")
def test_embeddings(self):
embedding = {
"string_to_token": "test",
"string_to_param": "test",
}
self.assertEqual(detect_embedding_format(embedding), "embeddings")
def test_unknown(self):
embedding = {
"what_is_this": "test",
}
self.assertEqual(detect_embedding_format(embedding), None)
def test_unknown(self):
embedding = {
"what_is_this": "test",
}
self.assertEqual(detect_embedding_format(embedding), None)
class BlendEmbeddingConceptTests(unittest.TestCase):
def test_existing_base_token(self):
embeds = {
"test": np.ones(TEST_DIMS),
}
blend_embedding_concept(embeds, {
"<test>": torch.from_numpy(np.ones(TEST_DIMS)),
}, np.float32, "test", 1.0)
def test_existing_base_token(self):
embeds = {
"test": np.ones(TEST_DIMS),
}
blend_embedding_concept(
embeds,
{
"<test>": torch.from_numpy(np.ones(TEST_DIMS)),
},
np.float32,
"test",
1.0,
)
self.assertIn("test", embeds)
self.assertEqual(embeds["test"].shape, TEST_DIMS)
self.assertEqual(embeds["test"].mean(), 2)
self.assertIn("test", embeds)
self.assertEqual(embeds["test"].shape, TEST_DIMS)
self.assertEqual(embeds["test"].mean(), 2)
def test_missing_base_token(self):
embeds = {}
blend_embedding_concept(embeds, {
"<test>": torch.from_numpy(np.ones(TEST_DIMS)),
}, np.float32, "test", 1.0)
def test_missing_base_token(self):
embeds = {}
blend_embedding_concept(
embeds,
{
"<test>": torch.from_numpy(np.ones(TEST_DIMS)),
},
np.float32,
"test",
1.0,
)
self.assertIn("test", embeds)
self.assertEqual(embeds["test"].shape, TEST_DIMS)
self.assertIn("test", embeds)
self.assertEqual(embeds["test"].shape, TEST_DIMS)
def test_existing_token(self):
embeds = {
"<test>": np.ones(TEST_DIMS),
}
blend_embedding_concept(embeds, {
"<test>": torch.from_numpy(np.ones(TEST_DIMS)),
}, np.float32, "test", 1.0)
def test_existing_token(self):
embeds = {
"<test>": np.ones(TEST_DIMS),
}
blend_embedding_concept(
embeds,
{
"<test>": torch.from_numpy(np.ones(TEST_DIMS)),
},
np.float32,
"test",
1.0,
)
keys = list(embeds.keys())
keys.sort()
keys = list(embeds.keys())
keys.sort()
self.assertIn("test", embeds)
self.assertEqual(keys, ["<test>", "test"])
self.assertIn("test", embeds)
self.assertEqual(keys, ["<test>", "test"])
def test_missing_token(self):
embeds = {}
blend_embedding_concept(embeds, {
"<test>": torch.from_numpy(np.ones(TEST_DIMS)),
}, np.float32, "test", 1.0)
def test_missing_token(self):
embeds = {}
blend_embedding_concept(
embeds,
{
"<test>": torch.from_numpy(np.ones(TEST_DIMS)),
},
np.float32,
"test",
1.0,
)
keys = list(embeds.keys())
keys.sort()
keys = list(embeds.keys())
keys.sort()
self.assertIn("test", embeds)
self.assertEqual(keys, ["<test>", "test"])
self.assertIn("test", embeds)
self.assertEqual(keys, ["<test>", "test"])
class BlendEmbeddingParametersTests(unittest.TestCase):
def test_existing_base_token(self):
embeds = {
"test": np.ones(TEST_DIMS),
}
blend_embedding_parameters(embeds, {
"emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)),
}, np.float32, "test", 1.0)
def test_existing_base_token(self):
embeds = {
"test": np.ones(TEST_DIMS),
}
blend_embedding_parameters(
embeds,
{
"emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)),
},
np.float32,
"test",
1.0,
)
self.assertIn("test", embeds)
self.assertEqual(embeds["test"].shape, TEST_DIMS)
self.assertEqual(embeds["test"].mean(), 2)
self.assertIn("test", embeds)
self.assertEqual(embeds["test"].shape, TEST_DIMS)
self.assertEqual(embeds["test"].mean(), 2)
def test_missing_base_token(self):
embeds = {}
blend_embedding_parameters(embeds, {
"emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)),
}, np.float32, "test", 1.0)
def test_missing_base_token(self):
embeds = {}
blend_embedding_parameters(
embeds,
{
"emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)),
},
np.float32,
"test",
1.0,
)
self.assertIn("test", embeds)
self.assertEqual(embeds["test"].shape, TEST_DIMS)
self.assertIn("test", embeds)
self.assertEqual(embeds["test"].shape, TEST_DIMS)
def test_existing_token(self):
embeds = {
"test": np.ones(TEST_DIMS_EMBEDS),
}
blend_embedding_parameters(embeds, {
"emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)),
}, np.float32, "test", 1.0)
def test_existing_token(self):
embeds = {
"test": np.ones(TEST_DIMS_EMBEDS),
}
blend_embedding_parameters(
embeds,
{
"emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)),
},
np.float32,
"test",
1.0,
)
keys = list(embeds.keys())
keys.sort()
keys = list(embeds.keys())
keys.sort()
self.assertIn("test", embeds)
self.assertEqual(keys, ["test", "test-0", "test-all"])
self.assertIn("test", embeds)
self.assertEqual(keys, ["test", "test-0", "test-all"])
def test_missing_token(self):
embeds = {}
blend_embedding_parameters(embeds, {
"emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)),
}, np.float32, "test", 1.0)
def test_missing_token(self):
embeds = {}
blend_embedding_parameters(
embeds,
{
"emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)),
},
np.float32,
"test",
1.0,
)
keys = list(embeds.keys())
keys.sort()
keys = list(embeds.keys())
keys.sort()
self.assertIn("test", embeds)
self.assertEqual(keys, ["test", "test-0", "test-all"])
self.assertIn("test", embeds)
self.assertEqual(keys, ["test", "test-0", "test-all"])
class BlendEmbeddingEmbeddingsTests(unittest.TestCase):
def test_existing_base_token(self):
embeds = {
"test": np.ones(TEST_DIMS),
}
blend_embedding_embeddings(embeds, TEST_MODEL_EMBEDS, np.float32, "test", 1.0)
def test_existing_base_token(self):
embeds = {
"test": np.ones(TEST_DIMS),
}
blend_embedding_embeddings(embeds, TEST_MODEL_EMBEDS, np.float32, "test", 1.0)
self.assertIn("test", embeds)
self.assertEqual(embeds["test"].shape, TEST_DIMS)
self.assertEqual(embeds["test"].mean(), 2)
self.assertIn("test", embeds)
self.assertEqual(embeds["test"].shape, TEST_DIMS)
self.assertEqual(embeds["test"].mean(), 2)
def test_missing_base_token(self):
embeds = {}
blend_embedding_embeddings(embeds, TEST_MODEL_EMBEDS, np.float32, "test", 1.0)
def test_missing_base_token(self):
embeds = {}
blend_embedding_embeddings(embeds, TEST_MODEL_EMBEDS, np.float32, "test", 1.0)
self.assertIn("test", embeds)
self.assertEqual(embeds["test"].shape, TEST_DIMS)
self.assertIn("test", embeds)
self.assertEqual(embeds["test"].shape, TEST_DIMS)
def test_existing_token(self):
embeds = {
"test": np.ones(TEST_DIMS),
}
blend_embedding_embeddings(embeds, TEST_MODEL_EMBEDS, np.float32, "test", 1.0)
def test_existing_token(self):
embeds = {
"test": np.ones(TEST_DIMS),
}
blend_embedding_embeddings(embeds, TEST_MODEL_EMBEDS, np.float32, "test", 1.0)
keys = list(embeds.keys())
keys.sort()
keys = list(embeds.keys())
keys.sort()
self.assertIn("test", embeds)
self.assertEqual(keys, ["test", "test-0", "test-all"])
self.assertIn("test", embeds)
self.assertEqual(keys, ["test", "test-0", "test-all"])
def test_missing_token(self):
embeds = {}
blend_embedding_embeddings(embeds, TEST_MODEL_EMBEDS, np.float32, "test", 1.0)
def test_missing_token(self):
embeds = {}
blend_embedding_embeddings(embeds, TEST_MODEL_EMBEDS, np.float32, "test", 1.0)
keys = list(embeds.keys())
keys.sort()
keys = list(embeds.keys())
keys.sort()
self.assertIn("test", embeds)
self.assertEqual(keys, ["test", "test-0", "test-all"])
self.assertIn("test", embeds)
self.assertEqual(keys, ["test", "test-0", "test-all"])
class BlendEmbeddingNodeTests(unittest.TestCase):
def test_expand_weights(self):
weights = from_array(np.ones(TEST_DIMS))
weights.name = "text_model.embeddings.token_embedding.weight"
def test_expand_weights(self):
weights = from_array(np.ones(TEST_DIMS))
weights.name = "text_model.embeddings.token_embedding.weight"
model = ModelProto(graph=GraphProto(initializer=[
weights,
]))
model = ModelProto(
graph=GraphProto(
initializer=[
weights,
]
)
)
embeds = {}
blend_embedding_node(model, {
'convert_tokens_to_ids': lambda t: t,
}, embeds, 2)
embeds = {}
blend_embedding_node(
model,
{
"convert_tokens_to_ids": lambda t: t,
},
embeds,
2,
)
result = to_array(model.graph.initializer[0])
result = to_array(model.graph.initializer[0])
self.assertEqual(len(model.graph.initializer), 1)
self.assertEqual(result.shape, (10, 8)) # (8 + 2, 8)
self.assertEqual(len(model.graph.initializer), 1)
self.assertEqual(result.shape, (10, 8)) # (8 + 2, 8)
class BlendTextualInversionsTests(unittest.TestCase):
def test_blend_multi_concept(self):
pass
def test_blend_multi_concept(self):
pass
def test_blend_multi_parameters(self):
pass
def test_blend_multi_parameters(self):
pass
def test_blend_multi_embeddings(self):
pass
def test_blend_multi_embeddings(self):
pass
def test_blend_multi_mixed(self):
pass
def test_blend_multi_mixed(self):
pass

View File

@ -13,7 +13,6 @@ from onnx_web.convert.utils import (
tuple_to_upscaling,
)
from tests.helpers import (
TEST_MODEL_DIFFUSION_SD15,
TEST_MODEL_UPSCALING_SWINIR,
test_needs_models,
)
@ -21,220 +20,225 @@ from tests.helpers import (
class ConversionContextTests(unittest.TestCase):
def test_from_environ(self):
context = ConversionContext.from_environ()
self.assertEqual(context.opset, DEFAULT_OPSET)
context = ConversionContext.from_environ()
self.assertEqual(context.opset, DEFAULT_OPSET)
def test_map_location(self):
context = ConversionContext.from_environ()
self.assertEqual(context.map_location.type, "cpu")
context = ConversionContext.from_environ()
self.assertEqual(context.map_location.type, "cpu")
class DownloadProgressTests(unittest.TestCase):
def test_download_example(self):
path = download_progress([("https://example.com", "/tmp/example-dot-com")])
self.assertEqual(path, "/tmp/example-dot-com")
def test_download_example(self):
path = download_progress([("https://example.com", "/tmp/example-dot-com")])
self.assertEqual(path, "/tmp/example-dot-com")
class TupleToSourceTests(unittest.TestCase):
def test_basic_tuple(self):
source = tuple_to_source(("foo", "bar"))
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_basic_tuple(self):
source = tuple_to_source(("foo", "bar"))
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_basic_list(self):
source = tuple_to_source(["foo", "bar"])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_basic_list(self):
source = tuple_to_source(["foo", "bar"])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_basic_dict(self):
source = tuple_to_source(["foo", "bar"])
source["bin"] = "bin"
def test_basic_dict(self):
source = tuple_to_source(["foo", "bar"])
source["bin"] = "bin"
# make sure this is returned as-is with extra fields
second = tuple_to_source(source)
# make sure this is returned as-is with extra fields
second = tuple_to_source(source)
self.assertEqual(source, second)
self.assertIn("bin", second)
self.assertEqual(source, second)
self.assertIn("bin", second)
class TupleToCorrectionTests(unittest.TestCase):
def test_basic_tuple(self):
source = tuple_to_correction(("foo", "bar"))
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_basic_tuple(self):
source = tuple_to_correction(("foo", "bar"))
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_basic_list(self):
source = tuple_to_correction(["foo", "bar"])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_basic_list(self):
source = tuple_to_correction(["foo", "bar"])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_basic_dict(self):
source = tuple_to_correction(["foo", "bar"])
source["bin"] = "bin"
def test_basic_dict(self):
source = tuple_to_correction(["foo", "bar"])
source["bin"] = "bin"
# make sure this is returned with extra fields
second = tuple_to_source(source)
# make sure this is returned with extra fields
second = tuple_to_source(source)
self.assertEqual(source, second)
self.assertIn("bin", second)
self.assertEqual(source, second)
self.assertIn("bin", second)
def test_scale_tuple(self):
source = tuple_to_correction(["foo", "bar", 2])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_scale_tuple(self):
source = tuple_to_correction(["foo", "bar", 2])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_half_tuple(self):
source = tuple_to_correction(["foo", "bar", True])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_half_tuple(self):
source = tuple_to_correction(["foo", "bar", True])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_opset_tuple(self):
source = tuple_to_correction(["foo", "bar", 14])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_opset_tuple(self):
source = tuple_to_correction(["foo", "bar", 14])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_all_tuple(self):
source = tuple_to_correction(["foo", "bar", 2, True, 14])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
self.assertEqual(source["scale"], 2)
self.assertEqual(source["half"], True)
self.assertEqual(source["opset"], 14)
def test_all_tuple(self):
source = tuple_to_correction(["foo", "bar", 2, True, 14])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
self.assertEqual(source["scale"], 2)
self.assertEqual(source["half"], True)
self.assertEqual(source["opset"], 14)
class TupleToDiffusionTests(unittest.TestCase):
def test_basic_tuple(self):
source = tuple_to_diffusion(("foo", "bar"))
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_basic_tuple(self):
source = tuple_to_diffusion(("foo", "bar"))
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_basic_list(self):
source = tuple_to_diffusion(["foo", "bar"])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_basic_list(self):
source = tuple_to_diffusion(["foo", "bar"])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_basic_dict(self):
source = tuple_to_diffusion(["foo", "bar"])
source["bin"] = "bin"
def test_basic_dict(self):
source = tuple_to_diffusion(["foo", "bar"])
source["bin"] = "bin"
# make sure this is returned with extra fields
second = tuple_to_diffusion(source)
# make sure this is returned with extra fields
second = tuple_to_diffusion(source)
self.assertEqual(source, second)
self.assertIn("bin", second)
self.assertEqual(source, second)
self.assertIn("bin", second)
def test_single_vae_tuple(self):
source = tuple_to_diffusion(["foo", "bar", True])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_single_vae_tuple(self):
source = tuple_to_diffusion(["foo", "bar", True])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_half_tuple(self):
source = tuple_to_diffusion(["foo", "bar", True])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_half_tuple(self):
source = tuple_to_diffusion(["foo", "bar", True])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_opset_tuple(self):
source = tuple_to_diffusion(["foo", "bar", 14])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_opset_tuple(self):
source = tuple_to_diffusion(["foo", "bar", 14])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_all_tuple(self):
source = tuple_to_diffusion(["foo", "bar", True, True, 14])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
self.assertEqual(source["single_vae"], True)
self.assertEqual(source["half"], True)
self.assertEqual(source["opset"], 14)
def test_all_tuple(self):
source = tuple_to_diffusion(["foo", "bar", True, True, 14])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
self.assertEqual(source["single_vae"], True)
self.assertEqual(source["half"], True)
self.assertEqual(source["opset"], 14)
class TupleToUpscalingTests(unittest.TestCase):
def test_basic_tuple(self):
source = tuple_to_upscaling(("foo", "bar"))
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_basic_tuple(self):
source = tuple_to_upscaling(("foo", "bar"))
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_basic_list(self):
source = tuple_to_upscaling(["foo", "bar"])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_basic_list(self):
source = tuple_to_upscaling(["foo", "bar"])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_basic_dict(self):
source = tuple_to_upscaling(["foo", "bar"])
source["bin"] = "bin"
def test_basic_dict(self):
source = tuple_to_upscaling(["foo", "bar"])
source["bin"] = "bin"
# make sure this is returned with extra fields
second = tuple_to_source(source)
# make sure this is returned with extra fields
second = tuple_to_source(source)
self.assertEqual(source, second)
self.assertIn("bin", second)
self.assertEqual(source, second)
self.assertIn("bin", second)
def test_scale_tuple(self):
source = tuple_to_upscaling(["foo", "bar", 2])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_scale_tuple(self):
source = tuple_to_upscaling(["foo", "bar", 2])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_half_tuple(self):
source = tuple_to_upscaling(["foo", "bar", True])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_half_tuple(self):
source = tuple_to_upscaling(["foo", "bar", True])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_opset_tuple(self):
source = tuple_to_upscaling(["foo", "bar", 14])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_opset_tuple(self):
source = tuple_to_upscaling(["foo", "bar", 14])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_all_tuple(self):
source = tuple_to_upscaling(["foo", "bar", 2, True, 14])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
self.assertEqual(source["scale"], 2)
self.assertEqual(source["half"], True)
self.assertEqual(source["opset"], 14)
def test_all_tuple(self):
source = tuple_to_upscaling(["foo", "bar", 2, True, 14])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
self.assertEqual(source["scale"], 2)
self.assertEqual(source["half"], True)
self.assertEqual(source["opset"], 14)
class SourceFormatTests(unittest.TestCase):
def test_with_format(self):
result = source_format({
"format": "foo",
})
self.assertEqual(result, "foo")
def test_with_format(self):
result = source_format(
{
"format": "foo",
}
)
self.assertEqual(result, "foo")
def test_source_known_extension(self):
result = source_format({
"source": "foo.safetensors",
})
self.assertEqual(result, "safetensors")
def test_source_known_extension(self):
result = source_format(
{
"source": "foo.safetensors",
}
)
self.assertEqual(result, "safetensors")
def test_source_unknown_extension(self):
result = source_format({
"source": "foo.none"
})
self.assertEqual(result, None)
def test_source_unknown_extension(self):
result = source_format({"source": "foo.none"})
self.assertEqual(result, None)
def test_incomplete_model(self):
self.assertIsNone(source_format({}))
def test_incomplete_model(self):
self.assertIsNone(source_format({}))
class RemovePrefixTests(unittest.TestCase):
def test_with_prefix(self):
self.assertEqual(remove_prefix("foo.bar", "foo"), ".bar")
def test_with_prefix(self):
self.assertEqual(remove_prefix("foo.bar", "foo"), ".bar")
def test_without_prefix(self):
self.assertEqual(remove_prefix("foo.bar", "bin"), "foo.bar")
def test_without_prefix(self):
self.assertEqual(remove_prefix("foo.bar", "bin"), "foo.bar")
class LoadTorchTests(unittest.TestCase):
pass
pass
class LoadTensorTests(unittest.TestCase):
pass
pass
class ResolveTensorTests(unittest.TestCase):
@test_needs_models([TEST_MODEL_UPSCALING_SWINIR])
def test_resolve_existing(self):
self.assertEqual(resolve_tensor("../models/.cache/upscaling-swinir"), TEST_MODEL_UPSCALING_SWINIR)
@test_needs_models([TEST_MODEL_UPSCALING_SWINIR])
def test_resolve_existing(self):
self.assertEqual(
resolve_tensor("../models/.cache/upscaling-swinir"),
TEST_MODEL_UPSCALING_SWINIR,
)
def test_resolve_missing(self):
self.assertIsNone(resolve_tensor("missing"))
def test_resolve_missing(self):
self.assertIsNone(resolve_tensor("missing"))

View File

@ -6,11 +6,13 @@ from onnx_web.params import DeviceParams
def test_needs_models(models: List[str]):
return skipUnless(all([path.exists(model) for model in models]), "model does not exist")
return skipUnless(
all([path.exists(model) for model in models]), "model does not exist"
)
def test_device() -> DeviceParams:
return DeviceParams("cpu", "CPUExecutionProvider")
return DeviceParams("cpu", "CPUExecutionProvider")
TEST_MODEL_DIFFUSION_SD15 = "../models/stable-diffusion-onnx-v1-5"

View File

@ -10,24 +10,24 @@ from onnx_web.image.mask_filter import (
class MaskFilterNoneTests(unittest.TestCase):
def test_basic(self):
dims = (64, 64)
mask = Image.new("RGB", dims)
result = mask_filter_none(mask, dims, (0, 0))
self.assertEqual(result.size, dims)
def test_basic(self):
dims = (64, 64)
mask = Image.new("RGB", dims)
result = mask_filter_none(mask, dims, (0, 0))
self.assertEqual(result.size, dims)
class MaskFilterGaussianMultiplyTests(unittest.TestCase):
def test_basic(self):
dims = (64, 64)
mask = Image.new("RGB", dims)
result = mask_filter_gaussian_multiply(mask, dims, (0, 0))
self.assertEqual(result.size, dims)
def test_basic(self):
dims = (64, 64)
mask = Image.new("RGB", dims)
result = mask_filter_gaussian_multiply(mask, dims, (0, 0))
self.assertEqual(result.size, dims)
class MaskFilterGaussianScreenTests(unittest.TestCase):
def test_basic(self):
dims = (64, 64)
mask = Image.new("RGB", dims)
result = mask_filter_gaussian_screen(mask, dims, (0, 0))
self.assertEqual(result.size, dims)
def test_basic(self):
dims = (64, 64)
mask = Image.new("RGB", dims)
result = mask_filter_gaussian_screen(mask, dims, (0, 0))
self.assertEqual(result.size, dims)

View File

@ -11,27 +11,27 @@ from onnx_web.server.context import ServerContext
class SourceFilterNoneTests(unittest.TestCase):
def test_basic(self):
dims = (64, 64)
server = ServerContext()
source = Image.new("RGB", dims)
result = source_filter_none(server, source)
self.assertEqual(result.size, dims)
def test_basic(self):
dims = (64, 64)
server = ServerContext()
source = Image.new("RGB", dims)
result = source_filter_none(server, source)
self.assertEqual(result.size, dims)
class SourceFilterGaussianTests(unittest.TestCase):
def test_basic(self):
dims = (64, 64)
server = ServerContext()
source = Image.new("RGB", dims)
result = source_filter_gaussian(server, source)
self.assertEqual(result.size, dims)
def test_basic(self):
dims = (64, 64)
server = ServerContext()
source = Image.new("RGB", dims)
result = source_filter_gaussian(server, source)
self.assertEqual(result.size, dims)
class SourceFilterNoiseTests(unittest.TestCase):
def test_basic(self):
dims = (64, 64)
server = ServerContext()
source = Image.new("RGB", dims)
result = source_filter_noise(server, source)
self.assertEqual(result.size, dims)
def test_basic(self):
dims = (64, 64)
server = ServerContext()
source = Image.new("RGB", dims)
result = source_filter_noise(server, source)
self.assertEqual(result.size, dims)

View File

@ -7,18 +7,18 @@ from onnx_web.params import Border
class ExpandImageTests(unittest.TestCase):
def test_expand(self):
result = expand_image(
Image.new("RGB", (8, 8)),
Image.new("RGB", (8, 8), "white"),
Border.even(4),
)
self.assertEqual(result[0].size, (16, 16))
def test_expand(self):
result = expand_image(
Image.new("RGB", (8, 8)),
Image.new("RGB", (8, 8), "white"),
Border.even(4),
)
self.assertEqual(result[0].size, (16, 16))
def test_masked(self):
result = expand_image(
Image.new("RGB", (8, 8), "red"),
Image.new("RGB", (8, 8), "white"),
Border.even(4),
)
self.assertEqual(result[0].getpixel((8, 8)), (255, 0, 0))
def test_masked(self):
result = expand_image(
Image.new("RGB", (8, 8), "red"),
Image.new("RGB", (8, 8), "white"),
Border.even(4),
)
self.assertEqual(result[0].getpixel((8, 8)), (255, 0, 0))

View File

@ -1,43 +1,43 @@
from typing import Any, Optional
class MockPipeline():
# flags
slice_size: Optional[str]
vae_slicing: Optional[bool]
sequential_offload: Optional[bool]
model_offload: Optional[bool]
xformers: Optional[bool]
class MockPipeline:
# flags
slice_size: Optional[str]
vae_slicing: Optional[bool]
sequential_offload: Optional[bool]
model_offload: Optional[bool]
xformers: Optional[bool]
# stubs
_encode_prompt: Optional[Any]
unet: Optional[Any]
vae_decoder: Optional[Any]
vae_encoder: Optional[Any]
# stubs
_encode_prompt: Optional[Any]
unet: Optional[Any]
vae_decoder: Optional[Any]
vae_encoder: Optional[Any]
def __init__(self) -> None:
self.slice_size = None
self.vae_slicing = None
self.sequential_offload = None
self.model_offload = None
self.xformers = None
def __init__(self) -> None:
self.slice_size = None
self.vae_slicing = None
self.sequential_offload = None
self.model_offload = None
self.xformers = None
self._encode_prompt = None
self.unet = None
self.vae_decoder = None
self.vae_encoder = None
self._encode_prompt = None
self.unet = None
self.vae_decoder = None
self.vae_encoder = None
def enable_attention_slicing(self, slice_size: str = None):
self.slice_size = slice_size
def enable_attention_slicing(self, slice_size: str = None):
self.slice_size = slice_size
def enable_vae_slicing(self):
self.vae_slicing = True
def enable_vae_slicing(self):
self.vae_slicing = True
def enable_sequential_cpu_offload(self):
self.sequential_offload = True
def enable_sequential_cpu_offload(self):
self.sequential_offload = True
def enable_model_cpu_offload(self):
self.model_offload = True
def enable_model_cpu_offload(self):
self.model_offload = True
def enable_xformers_memory_efficient_attention(self):
self.xformers = True
def enable_xformers_memory_efficient_attention(self):
self.xformers = True

View File

@ -13,7 +13,7 @@ class ParserTests(unittest.TestCase):
str(["foo"]),
str(PromptPhrase(["bar"], weight=1.5)),
str(["bin"]),
]
],
)
def test_multi_word_phrase(self):
@ -24,7 +24,7 @@ class ParserTests(unittest.TestCase):
str(["foo", "bar"]),
str(PromptPhrase(["middle", "words"], weight=1.5)),
str(["bin", "bun"]),
]
],
)
def test_nested_phrase(self):
@ -33,7 +33,7 @@ class ParserTests(unittest.TestCase):
[str(i) for i in res],
[
str(["foo"]),
str(PromptPhrase(["bar"], weight=(1.5 ** 3))),
str(PromptPhrase(["bar"], weight=(1.5**3))),
str(["bin"]),
]
],
)

View File

@ -25,71 +25,85 @@ class ConfigParamTests(unittest.TestCase):
params = get_config_params()
self.assertIsNotNone(params)
class AvailablePlatformTests(unittest.TestCase):
def test_before_setup(self):
platforms = get_available_platforms()
self.assertIsNotNone(platforms)
class CorrectModelTests(unittest.TestCase):
def test_before_setup(self):
models = get_correction_models()
self.assertIsNotNone(models)
class DiffusionModelTests(unittest.TestCase):
def test_before_setup(self):
models = get_diffusion_models()
self.assertIsNotNone(models)
class NetworkModelTests(unittest.TestCase):
def test_before_setup(self):
models = get_network_models()
self.assertIsNotNone(models)
class UpscalingModelTests(unittest.TestCase):
def test_before_setup(self):
models = get_upscaling_models()
self.assertIsNotNone(models)
class WildcardDataTests(unittest.TestCase):
def test_before_setup(self):
wildcards = get_wildcard_data()
self.assertIsNotNone(wildcards)
class ExtraStringsTests(unittest.TestCase):
def test_before_setup(self):
strings = get_extra_strings()
self.assertIsNotNone(strings)
class ExtraHashesTests(unittest.TestCase):
def test_before_setup(self):
hashes = get_extra_hashes()
self.assertIsNotNone(hashes)
class HighresMethodTests(unittest.TestCase):
def test_before_setup(self):
methods = get_highres_methods()
self.assertIsNotNone(methods)
class MaskFilterTests(unittest.TestCase):
def test_before_setup(self):
filters = get_mask_filters()
self.assertIsNotNone(filters)
class NoiseSourceTests(unittest.TestCase):
def test_before_setup(self):
sources = get_noise_sources()
self.assertIsNotNone(sources)
class SourceFilterTests(unittest.TestCase):
def test_before_setup(self):
filters = get_source_filters()
self.assertIsNotNone(filters)
class LoadExtrasTests(unittest.TestCase):
def test_default_extras(self):
server = ServerContext(extra_models=["../models/extras.json"])
load_extras(server)
class LoadModelsTests(unittest.TestCase):
def test_default_models(self):
server = ServerContext(model_path="../models")

View File

@ -4,37 +4,37 @@ from onnx_web.server.model_cache import ModelCache
class TestModelCache(unittest.TestCase):
def test_drop_existing(self):
cache = ModelCache(10)
cache.clear()
cache.set("foo", ("bar",), {})
self.assertGreater(cache.size, 0)
self.assertEqual(cache.drop("foo", ("bar",)), 1)
def test_drop_existing(self):
cache = ModelCache(10)
cache.clear()
cache.set("foo", ("bar",), {})
self.assertGreater(cache.size, 0)
self.assertEqual(cache.drop("foo", ("bar",)), 1)
def test_drop_missing(self):
cache = ModelCache(10)
cache.clear()
cache.set("foo", ("bar",), {})
self.assertGreater(cache.size, 0)
self.assertEqual(cache.drop("foo", ("bin",)), 0)
def test_drop_missing(self):
cache = ModelCache(10)
cache.clear()
cache.set("foo", ("bar",), {})
self.assertGreater(cache.size, 0)
self.assertEqual(cache.drop("foo", ("bin",)), 0)
def test_get_existing(self):
cache = ModelCache(10)
cache.clear()
value = {}
cache.set("foo", ("bar",), value)
self.assertGreater(cache.size, 0)
self.assertIs(cache.get("foo", ("bar",)), value)
def test_get_existing(self):
cache = ModelCache(10)
cache.clear()
value = {}
cache.set("foo", ("bar",), value)
self.assertGreater(cache.size, 0)
self.assertIs(cache.get("foo", ("bar",)), value)
def test_get_missing(self):
cache = ModelCache(10)
cache.clear()
value = {}
cache.set("foo", ("bar",), value)
self.assertGreater(cache.size, 0)
self.assertIs(cache.get("foo", ("bin",)), None)
def test_get_missing(self):
cache = ModelCache(10)
cache.clear()
value = {}
cache.set("foo", ("bar",), value)
self.assertGreater(cache.size, 0)
self.assertIs(cache.get("foo", ("bin",)), None)
"""
"""
def test_set_existing(self):
cache = ModelCache(10)
cache.clear()
@ -48,16 +48,16 @@ class TestModelCache(unittest.TestCase):
self.assertIs(cache.get("foo", ("bar",)), value)
"""
def test_set_missing(self):
cache = ModelCache(10)
cache.clear()
value = {}
cache.set("foo", ("bar",), value)
self.assertIs(cache.get("foo", ("bar",)), value)
def test_set_missing(self):
cache = ModelCache(10)
cache.clear()
value = {}
cache.set("foo", ("bar",), value)
self.assertIs(cache.get("foo", ("bar",)), value)
def test_set_zero(self):
cache = ModelCache(0)
cache.clear()
value = {}
cache.set("foo", ("bar",), value)
self.assertEqual(cache.size, 0)
def test_set_zero(self):
cache = ModelCache(0)
cache.clear()
value = {}
cache.set("foo", ("bar",), value)
self.assertEqual(cache.size, 0)

View File

@ -24,253 +24,307 @@ from tests.mocks import MockPipeline
class TestAvailablePipelines(unittest.TestCase):
def test_available_pipelines(self):
pipelines = get_available_pipelines()
def test_available_pipelines(self):
pipelines = get_available_pipelines()
self.assertIn("txt2img", pipelines)
self.assertIn("txt2img", pipelines)
class TestPipelineSchedulers(unittest.TestCase):
def test_pipeline_schedulers(self):
schedulers = get_pipeline_schedulers()
def test_pipeline_schedulers(self):
schedulers = get_pipeline_schedulers()
self.assertIn("euler-a", schedulers)
self.assertIn("euler-a", schedulers)
class TestSchedulerNames(unittest.TestCase):
def test_valid_name(self):
scheduler = get_scheduler_name(DDIMScheduler)
def test_valid_name(self):
scheduler = get_scheduler_name(DDIMScheduler)
self.assertEqual("ddim", scheduler)
self.assertEqual("ddim", scheduler)
def test_missing_names(self):
self.assertIsNone(get_scheduler_name("test"))
def test_missing_names(self):
self.assertIsNone(get_scheduler_name("test"))
class TestOptimizePipeline(unittest.TestCase):
def test_auto_attention_slicing(self):
server = ServerContext(
optimizations=[
"diffusers-attention-slicing-auto",
],
)
pipeline = MockPipeline()
optimize_pipeline(server, pipeline)
self.assertEqual(pipeline.slice_size, "auto")
def test_auto_attention_slicing(self):
server = ServerContext(
optimizations=[
"diffusers-attention-slicing-auto",
],
)
pipeline = MockPipeline()
optimize_pipeline(server, pipeline)
self.assertEqual(pipeline.slice_size, "auto")
def test_max_attention_slicing(self):
server = ServerContext(
optimizations=[
"diffusers-attention-slicing-max",
]
)
pipeline = MockPipeline()
optimize_pipeline(server, pipeline)
self.assertEqual(pipeline.slice_size, "max")
def test_max_attention_slicing(self):
server = ServerContext(
optimizations=[
"diffusers-attention-slicing-max",
]
)
pipeline = MockPipeline()
optimize_pipeline(server, pipeline)
self.assertEqual(pipeline.slice_size, "max")
def test_vae_slicing(self):
server = ServerContext(
optimizations=[
"diffusers-vae-slicing",
]
)
pipeline = MockPipeline()
optimize_pipeline(server, pipeline)
self.assertEqual(pipeline.vae_slicing, True)
def test_vae_slicing(self):
server = ServerContext(
optimizations=[
"diffusers-vae-slicing",
]
)
pipeline = MockPipeline()
optimize_pipeline(server, pipeline)
self.assertEqual(pipeline.vae_slicing, True)
def test_cpu_offload_sequential(self):
server = ServerContext(
optimizations=[
"diffusers-cpu-offload-sequential",
]
)
pipeline = MockPipeline()
optimize_pipeline(server, pipeline)
self.assertEqual(pipeline.sequential_offload, True)
def test_cpu_offload_sequential(self):
server = ServerContext(
optimizations=[
"diffusers-cpu-offload-sequential",
]
)
pipeline = MockPipeline()
optimize_pipeline(server, pipeline)
self.assertEqual(pipeline.sequential_offload, True)
def test_cpu_offload_model(self):
server = ServerContext(
optimizations=[
"diffusers-cpu-offload-model",
]
)
pipeline = MockPipeline()
optimize_pipeline(server, pipeline)
self.assertEqual(pipeline.model_offload, True)
def test_cpu_offload_model(self):
server = ServerContext(
optimizations=[
"diffusers-cpu-offload-model",
]
)
pipeline = MockPipeline()
optimize_pipeline(server, pipeline)
self.assertEqual(pipeline.model_offload, True)
def test_memory_efficient_attention(self):
server = ServerContext(
optimizations=[
"diffusers-memory-efficient-attention",
]
)
pipeline = MockPipeline()
optimize_pipeline(server, pipeline)
self.assertEqual(pipeline.xformers, True)
def test_memory_efficient_attention(self):
server = ServerContext(
optimizations=[
"diffusers-memory-efficient-attention",
]
)
pipeline = MockPipeline()
optimize_pipeline(server, pipeline)
self.assertEqual(pipeline.xformers, True)
class TestPatchPipeline(unittest.TestCase):
def test_expand_not_lpw(self):
"""
server = ServerContext()
pipeline = MockPipeline()
patch_pipeline(server, pipeline, None, ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1))
self.assertEqual(pipeline._encode_prompt, expand_prompt)
"""
pass
def test_expand_not_lpw(self):
"""
server = ServerContext()
pipeline = MockPipeline()
patch_pipeline(server, pipeline, None, ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1))
self.assertEqual(pipeline._encode_prompt, expand_prompt)
"""
pass
def test_unet_wrapper_not_xl(self):
server = ServerContext()
pipeline = MockPipeline()
patch_pipeline(server, pipeline, None, ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1))
self.assertTrue(isinstance(pipeline.unet, UNetWrapper))
def test_unet_wrapper_not_xl(self):
server = ServerContext()
pipeline = MockPipeline()
patch_pipeline(
server,
pipeline,
None,
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
)
self.assertTrue(isinstance(pipeline.unet, UNetWrapper))
def test_unet_wrapper_xl(self):
server = ServerContext()
pipeline = MockPipeline()
patch_pipeline(server, pipeline, None, ImageParams("test", "txt2img-sdxl", "ddim", "test", 1.0, 10, 1))
self.assertTrue(isinstance(pipeline.unet, UNetWrapper))
def test_unet_wrapper_xl(self):
server = ServerContext()
pipeline = MockPipeline()
patch_pipeline(
server,
pipeline,
None,
ImageParams("test", "txt2img-sdxl", "ddim", "test", 1.0, 10, 1),
)
self.assertTrue(isinstance(pipeline.unet, UNetWrapper))
def test_vae_wrapper(self):
server = ServerContext()
pipeline = MockPipeline()
patch_pipeline(server, pipeline, None, ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1))
self.assertTrue(isinstance(pipeline.vae_decoder, VAEWrapper))
self.assertTrue(isinstance(pipeline.vae_encoder, VAEWrapper))
def test_vae_wrapper(self):
server = ServerContext()
pipeline = MockPipeline()
patch_pipeline(
server,
pipeline,
None,
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
)
self.assertTrue(isinstance(pipeline.vae_decoder, VAEWrapper))
self.assertTrue(isinstance(pipeline.vae_encoder, VAEWrapper))
class TestLoadControlNet(unittest.TestCase):
@unittest.skipUnless(path.exists("../models/control/canny.onnx"), "model does not exist")
def test_load_existing(self):
"""
Should load a model
"""
components = load_controlnet(
ServerContext(model_path="../models"),
DeviceParams("cpu", "CPUExecutionProvider"),
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1, control=NetworkModel("canny", "control")),
@unittest.skipUnless(
path.exists("../models/control/canny.onnx"), "model does not exist"
)
self.assertIn("controlnet", components)
def test_load_existing(self):
"""
Should load a model
"""
components = load_controlnet(
ServerContext(model_path="../models"),
DeviceParams("cpu", "CPUExecutionProvider"),
ImageParams(
"test",
"txt2img",
"ddim",
"test",
1.0,
10,
1,
control=NetworkModel("canny", "control"),
),
)
self.assertIn("controlnet", components)
def test_load_missing(self):
"""
Should throw
"""
components = {}
try:
components = load_controlnet(
ServerContext(),
DeviceParams("cpu", "CPUExecutionProvider"),
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1, control=NetworkModel("missing", "control")),
)
except:
self.assertNotIn("controlnet", components)
return
def test_load_missing(self):
"""
Should throw
"""
components = {}
try:
components = load_controlnet(
ServerContext(),
DeviceParams("cpu", "CPUExecutionProvider"),
ImageParams(
"test",
"txt2img",
"ddim",
"test",
1.0,
10,
1,
control=NetworkModel("missing", "control"),
),
)
except Exception:
self.assertNotIn("controlnet", components)
return
self.fail()
self.fail()
class TestLoadTextEncoders(unittest.TestCase):
@unittest.skipUnless(path.exists("../models/stable-diffusion-onnx-v1-5/text_encoder/model.onnx"), "model does not exist")
def test_load_embeddings(self):
"""
Should add the token to tokenizer
Should increase the encoder dims
"""
components = load_text_encoders(
ServerContext(model_path="../models"),
DeviceParams("cpu", "CPUExecutionProvider"),
"../models/stable-diffusion-onnx-v1-5",
[
# TODO: add some embeddings
],
[],
torch.float32,
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
@unittest.skipUnless(
path.exists("../models/stable-diffusion-onnx-v1-5/text_encoder/model.onnx"),
"model does not exist",
)
self.assertIn("text_encoder", components)
def test_load_embeddings(self):
"""
Should add the token to tokenizer
Should increase the encoder dims
"""
components = load_text_encoders(
ServerContext(model_path="../models"),
DeviceParams("cpu", "CPUExecutionProvider"),
"../models/stable-diffusion-onnx-v1-5",
[
# TODO: add some embeddings
],
[],
torch.float32,
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
)
self.assertIn("text_encoder", components)
def test_load_embeddings_xl(self):
pass
def test_load_embeddings_xl(self):
pass
@unittest.skipUnless(path.exists("../models/stable-diffusion-onnx-v1-5/text_encoder/model.onnx"), "model does not exist")
def test_load_loras(self):
components = load_text_encoders(
ServerContext(model_path="../models"),
DeviceParams("cpu", "CPUExecutionProvider"),
"../models/stable-diffusion-onnx-v1-5",
[],
[
# TODO: add some loras
],
torch.float32,
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
@unittest.skipUnless(
path.exists("../models/stable-diffusion-onnx-v1-5/text_encoder/model.onnx"),
"model does not exist",
)
self.assertIn("text_encoder", components)
def test_load_loras(self):
components = load_text_encoders(
ServerContext(model_path="../models"),
DeviceParams("cpu", "CPUExecutionProvider"),
"../models/stable-diffusion-onnx-v1-5",
[],
[
# TODO: add some loras
],
torch.float32,
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
)
self.assertIn("text_encoder", components)
def test_load_loras_xl(self):
pass
def test_load_loras_xl(self):
pass
class TestLoadUnet(unittest.TestCase):
@unittest.skipUnless(path.exists("../models/stable-diffusion-onnx-v1-5/unet/model.onnx"), "model does not exist")
def test_load_unet_loras(self):
components = load_unet(
ServerContext(model_path="../models"),
DeviceParams("cpu", "CPUExecutionProvider"),
"../models/stable-diffusion-onnx-v1-5",
[
# TODO: add some loras
],
"unet",
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
@unittest.skipUnless(
path.exists("../models/stable-diffusion-onnx-v1-5/unet/model.onnx"),
"model does not exist",
)
self.assertIn("unet", components)
def test_load_unet_loras(self):
components = load_unet(
ServerContext(model_path="../models"),
DeviceParams("cpu", "CPUExecutionProvider"),
"../models/stable-diffusion-onnx-v1-5",
[
# TODO: add some loras
],
"unet",
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
)
self.assertIn("unet", components)
def test_load_unet_loras_xl(self):
pass
def test_load_unet_loras_xl(self):
pass
@unittest.skipUnless(path.exists("../models/stable-diffusion-onnx-v1-5/cnet/model.onnx"), "model does not exist")
def test_load_cnet_loras(self):
components = load_unet(
ServerContext(model_path="../models"),
DeviceParams("cpu", "CPUExecutionProvider"),
"../models/stable-diffusion-onnx-v1-5",
[
# TODO: add some loras
],
"cnet",
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
@unittest.skipUnless(
path.exists("../models/stable-diffusion-onnx-v1-5/cnet/model.onnx"),
"model does not exist",
)
self.assertIn("unet", components)
def test_load_cnet_loras(self):
components = load_unet(
ServerContext(model_path="../models"),
DeviceParams("cpu", "CPUExecutionProvider"),
"../models/stable-diffusion-onnx-v1-5",
[
# TODO: add some loras
],
"cnet",
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
)
self.assertIn("unet", components)
class TestLoadVae(unittest.TestCase):
@unittest.skipUnless(path.exists("../models/upscaling-stable-diffusion-x4/vae/model.onnx"), "model does not exist")
def test_load_single(self):
"""
Should return single component
"""
components = load_vae(
ServerContext(model_path="../models"),
DeviceParams("cpu", "CPUExecutionProvider"),
"../models/upscaling-stable-diffusion-x4",
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
@unittest.skipUnless(
path.exists("../models/upscaling-stable-diffusion-x4/vae/model.onnx"),
"model does not exist",
)
self.assertIn("vae", components)
self.assertNotIn("vae_decoder", components)
self.assertNotIn("vae_encoder", components)
def test_load_single(self):
"""
Should return single component
"""
components = load_vae(
ServerContext(model_path="../models"),
DeviceParams("cpu", "CPUExecutionProvider"),
"../models/upscaling-stable-diffusion-x4",
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
)
self.assertIn("vae", components)
self.assertNotIn("vae_decoder", components)
self.assertNotIn("vae_encoder", components)
@unittest.skipUnless(path.exists("../models/stable-diffusion-onnx-v1-5/vae_encoder/model.onnx"), "model does not exist")
def test_load_split(self):
"""
Should return split encoder/decoder
"""
components = load_vae(
ServerContext(model_path="../models"),
DeviceParams("cpu", "CPUExecutionProvider"),
"../models/stable-diffusion-onnx-v1-5",
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
@unittest.skipUnless(
path.exists("../models/stable-diffusion-onnx-v1-5/vae_encoder/model.onnx"),
"model does not exist",
)
self.assertNotIn("vae", components)
self.assertIn("vae_decoder", components)
self.assertIn("vae_encoder", components)
def test_load_split(self):
"""
Should return split encoder/decoder
"""
components = load_vae(
ServerContext(model_path="../models"),
DeviceParams("cpu", "CPUExecutionProvider"),
"../models/stable-diffusion-onnx-v1-5",
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
)
self.assertNotIn("vae", components)
self.assertIn("vae_decoder", components)
self.assertIn("vae_encoder", components)

View File

@ -17,155 +17,234 @@ from tests.helpers import TEST_MODEL_DIFFUSION_SD15, test_device, test_needs_mod
class TestTxt2ImgPipeline(unittest.TestCase):
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
def test_basic(self):
cancel = Value("L", 0)
logs = Queue()
pending = Queue()
progress = Queue()
active = Value("L", 0)
idle = Value("L", 0)
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
def test_basic(self):
cancel = Value("L", 0)
logs = Queue()
pending = Queue()
progress = Queue()
active = Value("L", 0)
idle = Value("L", 0)
worker = WorkerContext(
"test",
test_device(),
cancel,
logs,
pending,
progress,
active,
idle,
3,
0.1,
)
worker.start("test")
worker = WorkerContext(
"test",
test_device(),
cancel,
logs,
pending,
progress,
active,
idle,
3,
0.1,
)
worker.start("test")
run_txt2img_pipeline(
worker,
ServerContext(model_path="../models", output_path="../outputs"),
ImageParams(
TEST_MODEL_DIFFUSION_SD15, "txt2img", "ddim", "an astronaut eating a hamburger", 3.0, 1, 1),
Size(256, 256),
["test-txt2img.png"],
UpscaleParams("test"),
HighresParams(False, 1, 0, 0),
)
run_txt2img_pipeline(
worker,
ServerContext(model_path="../models", output_path="../outputs"),
ImageParams(
TEST_MODEL_DIFFUSION_SD15,
"txt2img",
"ddim",
"an astronaut eating a hamburger",
3.0,
1,
1,
),
Size(256, 256),
["test-txt2img-basic.png"],
UpscaleParams("test"),
HighresParams(False, 1, 0, 0),
)
self.assertTrue(path.exists("../outputs/test-txt2img-basic.png"))
output = Image.open("../outputs/test-txt2img-basic.png")
self.assertEqual(output.size, (256, 256))
# TODO: test contents of image
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
def test_highres(self):
cancel = Value("L", 0)
logs = Queue()
pending = Queue()
progress = Queue()
active = Value("L", 0)
idle = Value("L", 0)
worker = WorkerContext(
"test",
test_device(),
cancel,
logs,
pending,
progress,
active,
idle,
3,
0.1,
)
worker.start("test")
run_txt2img_pipeline(
worker,
ServerContext(model_path="../models", output_path="../outputs"),
ImageParams(
TEST_MODEL_DIFFUSION_SD15,
"txt2img",
"ddim",
"an astronaut eating a hamburger",
3.0,
1,
1,
),
Size(256, 256),
["test-txt2img-highres.png"],
UpscaleParams("test"),
HighresParams(True, 2, 0, 0),
)
self.assertTrue(path.exists("../outputs/test-txt2img-highres.png"))
output = Image.open("../outputs/test-txt2img-highres.png")
self.assertEqual(output.size, (512, 512))
self.assertTrue(path.exists("../outputs/test-txt2img.png"))
class TestImg2ImgPipeline(unittest.TestCase):
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
def test_basic(self):
cancel = Value("L", 0)
logs = Queue()
pending = Queue()
progress = Queue()
active = Value("L", 0)
idle = Value("L", 0)
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
def test_basic(self):
cancel = Value("L", 0)
logs = Queue()
pending = Queue()
progress = Queue()
active = Value("L", 0)
idle = Value("L", 0)
worker = WorkerContext(
"test",
test_device(),
cancel,
logs,
pending,
progress,
active,
idle,
3,
0.1,
)
worker.start("test")
worker = WorkerContext(
"test",
test_device(),
cancel,
logs,
pending,
progress,
active,
idle,
3,
0.1,
)
worker.start("test")
source = Image.new("RGB", (64, 64), "black")
run_img2img_pipeline(
worker,
ServerContext(model_path="../models", output_path="../outputs"),
ImageParams(
TEST_MODEL_DIFFUSION_SD15, "txt2img", "ddim", "an astronaut eating a hamburger", 3.0, 1, 1),
["test-img2img.png"],
UpscaleParams("test"),
HighresParams(False, 1, 0, 0),
source,
1.0,
)
source = Image.new("RGB", (64, 64), "black")
run_img2img_pipeline(
worker,
ServerContext(model_path="../models", output_path="../outputs"),
ImageParams(
TEST_MODEL_DIFFUSION_SD15,
"txt2img",
"ddim",
"an astronaut eating a hamburger",
3.0,
1,
1,
),
["test-img2img.png"],
UpscaleParams("test"),
HighresParams(False, 1, 0, 0),
source,
1.0,
)
self.assertTrue(path.exists("../outputs/test-img2img.png"))
self.assertTrue(path.exists("../outputs/test-img2img.png"))
class TestUpscalePipeline(unittest.TestCase):
@test_needs_models(["../models/upscaling-stable-diffusion-x4"])
def test_basic(self):
cancel = Value("L", 0)
logs = Queue()
pending = Queue()
progress = Queue()
active = Value("L", 0)
idle = Value("L", 0)
@test_needs_models(["../models/upscaling-stable-diffusion-x4"])
def test_basic(self):
cancel = Value("L", 0)
logs = Queue()
pending = Queue()
progress = Queue()
active = Value("L", 0)
idle = Value("L", 0)
worker = WorkerContext(
"test",
test_device(),
cancel,
logs,
pending,
progress,
active,
idle,
3,
0.1,
)
worker.start("test")
worker = WorkerContext(
"test",
test_device(),
cancel,
logs,
pending,
progress,
active,
idle,
3,
0.1,
)
worker.start("test")
source = Image.new("RGB", (64, 64), "black")
run_upscale_pipeline(
worker,
ServerContext(model_path="../models", output_path="../outputs"),
ImageParams(
"../models/upscaling-stable-diffusion-x4", "txt2img", "ddim", "an astronaut eating a hamburger", 3.0, 1, 1),
Size(256, 256),
["test-upscale.png"],
UpscaleParams("test"),
HighresParams(False, 1, 0, 0),
source,
)
source = Image.new("RGB", (64, 64), "black")
run_upscale_pipeline(
worker,
ServerContext(model_path="../models", output_path="../outputs"),
ImageParams(
"../models/upscaling-stable-diffusion-x4",
"txt2img",
"ddim",
"an astronaut eating a hamburger",
3.0,
1,
1,
),
Size(256, 256),
["test-upscale.png"],
UpscaleParams("test"),
HighresParams(False, 1, 0, 0),
source,
)
self.assertTrue(path.exists("../outputs/test-upscale.png"))
self.assertTrue(path.exists("../outputs/test-upscale.png"))
class TestBlendPipeline(unittest.TestCase):
def test_basic(self):
cancel = Value("L", 0)
logs = Queue()
pending = Queue()
progress = Queue()
active = Value("L", 0)
idle = Value("L", 0)
def test_basic(self):
cancel = Value("L", 0)
logs = Queue()
pending = Queue()
progress = Queue()
active = Value("L", 0)
idle = Value("L", 0)
worker = WorkerContext(
"test",
test_device(),
cancel,
logs,
pending,
progress,
active,
idle,
3,
0.1,
)
worker.start("test")
worker = WorkerContext(
"test",
test_device(),
cancel,
logs,
pending,
progress,
active,
idle,
3,
0.1,
)
worker.start("test")
source = Image.new("RGBA", (64, 64), "black")
mask = Image.new("RGBA", (64, 64), "white")
run_blend_pipeline(
worker,
ServerContext(model_path="../models", output_path="../outputs"),
ImageParams(
TEST_MODEL_DIFFUSION_SD15, "txt2img", "ddim", "an astronaut eating a hamburger", 3.0, 1, 1),
Size(64, 64),
["test-blend.png"],
UpscaleParams("test"),
[source, source],
mask,
)
source = Image.new("RGBA", (64, 64), "black")
mask = Image.new("RGBA", (64, 64), "white")
run_blend_pipeline(
worker,
ServerContext(model_path="../models", output_path="../outputs"),
ImageParams(
TEST_MODEL_DIFFUSION_SD15,
"txt2img",
"ddim",
"an astronaut eating a hamburger",
3.0,
1,
1,
),
Size(64, 64),
["test-blend.png"],
UpscaleParams("test"),
[source, source],
mask,
)
self.assertTrue(path.exists("../outputs/test-blend.png"))
self.assertTrue(path.exists("../outputs/test-blend.png"))

View File

@ -10,7 +10,6 @@ from onnx_web.diffusers.utils import (
get_loras_from_prompt,
get_scaled_latents,
get_tile_latents,
get_tokens_from_prompt,
pop_random,
slice_prompt,
)
@ -18,110 +17,128 @@ from onnx_web.params import Size
class TestExpandIntervalRanges(unittest.TestCase):
def test_prompt_with_no_ranges(self):
prompt = "an astronaut eating a hamburger"
result = expand_interval_ranges(prompt)
self.assertEqual(prompt, result)
def test_prompt_with_no_ranges(self):
prompt = "an astronaut eating a hamburger"
result = expand_interval_ranges(prompt)
self.assertEqual(prompt, result)
def test_prompt_with_range(self):
prompt = "an astronaut-{1,4} eating a hamburger"
result = expand_interval_ranges(prompt)
self.assertEqual(
result, "an astronaut-1 astronaut-2 astronaut-3 eating a hamburger"
)
def test_prompt_with_range(self):
prompt = "an astronaut-{1,4} eating a hamburger"
result = expand_interval_ranges(prompt)
self.assertEqual(result, "an astronaut-1 astronaut-2 astronaut-3 eating a hamburger")
class TestExpandAlternativeRanges(unittest.TestCase):
def test_prompt_with_no_ranges(self):
prompt = "an astronaut eating a hamburger"
result = expand_alternative_ranges(prompt)
self.assertEqual([prompt], result)
def test_prompt_with_no_ranges(self):
prompt = "an astronaut eating a hamburger"
result = expand_alternative_ranges(prompt)
self.assertEqual([prompt], result)
def test_ranges_match(self):
prompt = "(an astronaut|a squirrel) eating (a hamburger|an acorn)"
result = expand_alternative_ranges(prompt)
self.assertEqual(
result, ["an astronaut eating a hamburger", "a squirrel eating an acorn"]
)
def test_ranges_match(self):
prompt = "(an astronaut|a squirrel) eating (a hamburger|an acorn)"
result = expand_alternative_ranges(prompt)
self.assertEqual(result, ["an astronaut eating a hamburger", "a squirrel eating an acorn"])
class TestInversionsFromPrompt(unittest.TestCase):
def test_get_inversions(self):
prompt = "<inversion:test:1.0> an astronaut eating an embedding"
result, tokens = get_inversions_from_prompt(prompt)
def test_get_inversions(self):
prompt = "<inversion:test:1.0> an astronaut eating an embedding"
result, tokens = get_inversions_from_prompt(prompt)
self.assertEqual(result, " an astronaut eating an embedding")
self.assertEqual(tokens, [("test", 1.0)])
self.assertEqual(result, " an astronaut eating an embedding")
self.assertEqual(tokens, [("test", 1.0)])
class TestLoRAsFromPrompt(unittest.TestCase):
def test_get_loras(self):
prompt = "<lora:test:1.0> an astronaut eating a LoRA"
result, tokens = get_loras_from_prompt(prompt)
def test_get_loras(self):
prompt = "<lora:test:1.0> an astronaut eating a LoRA"
result, tokens = get_loras_from_prompt(prompt)
self.assertEqual(result, " an astronaut eating a LoRA")
self.assertEqual(tokens, [("test", 1.0)])
self.assertEqual(result, " an astronaut eating a LoRA")
self.assertEqual(tokens, [("test", 1.0)])
class TestLatentsFromSeed(unittest.TestCase):
def test_batch_size(self):
latents = get_latents_from_seed(1, Size(64, 64), batch=4)
self.assertEqual(latents.shape, (4, 4, 8, 8))
def test_batch_size(self):
latents = get_latents_from_seed(1, Size(64, 64), batch=4)
self.assertEqual(latents.shape, (4, 4, 8, 8))
def test_consistency(self):
latents1 = get_latents_from_seed(1, Size(64, 64))
latents2 = get_latents_from_seed(1, Size(64, 64))
self.assertTrue(np.array_equal(latents1, latents2))
def test_consistency(self):
latents1 = get_latents_from_seed(1, Size(64, 64))
latents2 = get_latents_from_seed(1, Size(64, 64))
self.assertTrue(np.array_equal(latents1, latents2))
class TestTileLatents(unittest.TestCase):
def test_full_tile(self):
partial = np.zeros((1, 1, 64, 64))
full = get_tile_latents(partial, 1, (64, 64), (0, 0, 64))
self.assertEqual(full.shape, (1, 1, 8, 8))
def test_full_tile(self):
partial = np.zeros((1, 1, 64, 64))
full = get_tile_latents(partial, 1, (64, 64), (0, 0, 64))
self.assertEqual(full.shape, (1, 1, 8, 8))
def test_contract_tile(self):
partial = np.zeros((1, 1, 64, 64))
full = get_tile_latents(partial, 1, (32, 32), (0, 0, 32))
self.assertEqual(full.shape, (1, 1, 4, 4))
def test_contract_tile(self):
partial = np.zeros((1, 1, 64, 64))
full = get_tile_latents(partial, 1, (32, 32), (0, 0, 32))
self.assertEqual(full.shape, (1, 1, 4, 4))
def test_expand_tile(self):
partial = np.zeros((1, 1, 32, 32))
full = get_tile_latents(partial, 1, (64, 64), (0, 0, 64))
self.assertEqual(full.shape, (1, 1, 8, 8))
def test_expand_tile(self):
partial = np.zeros((1, 1, 32, 32))
full = get_tile_latents(partial, 1, (64, 64), (0, 0, 64))
self.assertEqual(full.shape, (1, 1, 8, 8))
class TestScaledLatents(unittest.TestCase):
def test_scale_up(self):
latents = get_latents_from_seed(1, Size(16, 16))
scaled = get_scaled_latents(1, Size(16, 16), scale=2)
self.assertEqual(latents[0, 0, 0, 0], scaled[0, 0, 0, 0])
def test_scale_up(self):
latents = get_latents_from_seed(1, Size(16, 16))
scaled = get_scaled_latents(1, Size(16, 16), scale=2)
self.assertEqual(latents[0, 0, 0, 0], scaled[0, 0, 0, 0])
def test_scale_down(self):
latents = get_latents_from_seed(1, Size(16, 16))
scaled = get_scaled_latents(1, Size(16, 16), scale=0.5)
self.assertEqual(
(
latents[0, 0, 0, 0]
+ latents[0, 0, 0, 1]
+ latents[0, 0, 1, 0]
+ latents[0, 0, 1, 1]
)
/ 4,
scaled[0, 0, 0, 0],
)
def test_scale_down(self):
latents = get_latents_from_seed(1, Size(16, 16))
scaled = get_scaled_latents(1, Size(16, 16), scale=0.5)
self.assertEqual((
latents[0, 0, 0, 0] +
latents[0, 0, 0, 1] +
latents[0, 0, 1, 0] +
latents[0, 0, 1, 1]
) / 4, scaled[0, 0, 0, 0])
class TestReplaceWildcards(unittest.TestCase):
pass
pass
class TestPopRandom(unittest.TestCase):
def test_pop(self):
items = ["1", "2", "3"]
pop_random(items)
self.assertEqual(len(items), 2)
def test_pop(self):
items = ["1", "2", "3"]
pop_random(items)
self.assertEqual(len(items), 2)
class TestRepairNaN(unittest.TestCase):
def test_unchanged(self):
pass
def test_unchanged(self):
pass
def test_missing(self):
pass
def test_missing(self):
pass
class TestSlicePrompt(unittest.TestCase):
def test_slice_no_delimiter(self):
slice = slice_prompt("foo", 1)
self.assertEqual(slice, "foo")
def test_slice_no_delimiter(self):
slice = slice_prompt("foo", 1)
self.assertEqual(slice, "foo")
def test_slice_within_range(self):
slice = slice_prompt("foo || bar", 1)
self.assertEqual(slice, " bar")
def test_slice_within_range(self):
slice = slice_prompt("foo || bar", 1)
self.assertEqual(slice, " bar")
def test_slice_outside_range(self):
slice = slice_prompt("foo || bar", 9)
self.assertEqual(slice, " bar")
def test_slice_outside_range(self):
slice = slice_prompt("foo || bar", 9)
self.assertEqual(slice, " bar")

View File

@ -13,122 +13,128 @@ lock = Event()
def test_job(*args, **kwargs):
lock.wait()
lock.wait()
def wait_job(*args, **kwargs):
sleep(0.5)
sleep(0.5)
class TestWorkerPool(unittest.TestCase):
# lock: Optional[Event]
pool: Optional[DevicePoolExecutor]
# lock: Optional[Event]
pool: Optional[DevicePoolExecutor]
def setUp(self) -> None:
self.pool = None
def setUp(self) -> None:
self.pool = None
def tearDown(self) -> None:
if self.pool is not None:
self.pool.join()
def tearDown(self) -> None:
if self.pool is not None:
self.pool.join()
def test_no_devices(self):
server = ServerContext()
self.pool = DevicePoolExecutor(server, [], join_timeout=TEST_JOIN_TIMEOUT)
self.pool.start()
def test_no_devices(self):
server = ServerContext()
self.pool = DevicePoolExecutor(server, [], join_timeout=TEST_JOIN_TIMEOUT)
self.pool.start()
def test_fake_worker(self):
device = DeviceParams("cpu", "CPUProvider")
server = ServerContext()
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
self.pool.start()
self.assertEqual(len(self.pool.workers), 1)
def test_fake_worker(self):
device = DeviceParams("cpu", "CPUProvider")
server = ServerContext()
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
self.pool.start()
self.assertEqual(len(self.pool.workers), 1)
def test_cancel_pending(self):
device = DeviceParams("cpu", "CPUProvider")
server = ServerContext()
def test_cancel_pending(self):
device = DeviceParams("cpu", "CPUProvider")
server = ServerContext()
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
self.pool.start()
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
self.pool.start()
self.pool.submit("test", wait_job, lock=lock)
self.assertEqual(self.pool.done("test"), (True, None))
self.pool.submit("test", wait_job, lock=lock)
self.assertEqual(self.pool.done("test"), (True, None))
self.assertTrue(self.pool.cancel("test"))
self.assertEqual(self.pool.done("test"), (False, None))
self.assertTrue(self.pool.cancel("test"))
self.assertEqual(self.pool.done("test"), (False, None))
def test_cancel_running(self):
pass
def test_cancel_running(self):
pass
def test_next_device(self):
device = DeviceParams("cpu", "CPUProvider")
server = ServerContext()
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
self.pool.start()
def test_next_device(self):
device = DeviceParams("cpu", "CPUProvider")
server = ServerContext()
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
self.pool.start()
self.assertEqual(self.pool.get_next_device(), 0)
self.assertEqual(self.pool.get_next_device(), 0)
def test_needs_device(self):
device1 = DeviceParams("cpu1", "CPUProvider")
device2 = DeviceParams("cpu2", "CPUProvider")
server = ServerContext()
self.pool = DevicePoolExecutor(server, [device1, device2], join_timeout=TEST_JOIN_TIMEOUT)
self.pool.start()
def test_needs_device(self):
device1 = DeviceParams("cpu1", "CPUProvider")
device2 = DeviceParams("cpu2", "CPUProvider")
server = ServerContext()
self.pool = DevicePoolExecutor(
server, [device1, device2], join_timeout=TEST_JOIN_TIMEOUT
)
self.pool.start()
self.assertEqual(self.pool.get_next_device(needs_device=device2), 1)
self.assertEqual(self.pool.get_next_device(needs_device=device2), 1)
def test_done_running(self):
"""
TODO: flaky
"""
device = DeviceParams("cpu", "CPUProvider")
server = ServerContext()
def test_done_running(self):
"""
TODO: flaky
"""
device = DeviceParams("cpu", "CPUProvider")
server = ServerContext()
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1)
self.pool.start(lock)
sleep(2.0)
self.pool = DevicePoolExecutor(
server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1
)
self.pool.start(lock)
sleep(2.0)
self.pool.submit("test", test_job)
sleep(2.0)
self.pool.submit("test", test_job)
sleep(2.0)
pending, _progress = self.pool.done("test")
self.assertFalse(pending)
pending, _progress = self.pool.done("test")
self.assertFalse(pending)
def test_done_pending(self):
device = DeviceParams("cpu", "CPUProvider")
server = ServerContext()
def test_done_pending(self):
device = DeviceParams("cpu", "CPUProvider")
server = ServerContext()
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
self.pool.start(lock)
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
self.pool.start(lock)
self.pool.submit("test1", test_job)
self.pool.submit("test2", test_job)
self.assertTrue(self.pool.done("test2"), (True, None))
self.pool.submit("test1", test_job)
self.pool.submit("test2", test_job)
self.assertTrue(self.pool.done("test2"), (True, None))
lock.set()
lock.set()
def test_done_finished(self):
"""
TODO: flaky
"""
device = DeviceParams("cpu", "CPUProvider")
server = ServerContext()
def test_done_finished(self):
"""
TODO: flaky
"""
device = DeviceParams("cpu", "CPUProvider")
server = ServerContext()
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1)
self.pool.start()
sleep(2.0)
self.pool = DevicePoolExecutor(
server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1
)
self.pool.start()
sleep(2.0)
self.pool.submit("test", wait_job)
self.assertEqual(self.pool.done("test"), (True, None))
self.pool.submit("test", wait_job)
self.assertEqual(self.pool.done("test"), (True, None))
sleep(2.0)
pending, _progress = self.pool.done("test")
self.assertFalse(pending)
sleep(2.0)
pending, _progress = self.pool.done("test")
self.assertFalse(pending)
def test_recycle_live(self):
pass
def test_recycle_live(self):
pass
def test_recycle_dead(self):
pass
def test_recycle_dead(self):
pass
def test_running_status(self):
pass
def test_running_status(self):
pass

View File

@ -18,119 +18,194 @@ from tests.helpers import test_device
def main_memory(_worker):
raise Exception(MEMORY_ERRORS[0])
raise Exception(MEMORY_ERRORS[0])
def main_retry(_worker):
raise RetryException()
raise RetryException()
def main_interrupt(_worker):
raise KeyboardInterrupt()
raise KeyboardInterrupt()
class WorkerMainTests(unittest.TestCase):
def test_pending_exception_empty(self):
pass
def test_pending_exception_empty(self):
pass
def test_pending_exception_interrupt(self):
status = None
def test_pending_exception_interrupt(self):
status = None
def exit(exit_status):
nonlocal status
status = exit_status
def exit(exit_status):
nonlocal status
status = exit_status
job = JobCommand("test", "test", main_interrupt, [], {})
cancel = Value("L", False)
logs = Queue()
pending = Queue()
progress = Queue()
pid = Value("L", getpid())
idle = Value("L", False)
job = JobCommand("test", "test", main_interrupt, [], {})
cancel = Value("L", False)
logs = Queue()
pending = Queue()
progress = Queue()
pid = Value("L", getpid())
idle = Value("L", False)
pending.put(job)
worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit)
pending.put(job)
worker_main(
WorkerContext(
"test",
test_device(),
cancel,
logs,
pending,
progress,
pid,
idle,
0,
0.0,
),
ServerContext(),
exit=exit,
)
self.assertEqual(status, EXIT_INTERRUPT)
pass
self.assertEqual(status, EXIT_INTERRUPT)
pass
def test_pending_exception_retry(self):
status = None
def test_pending_exception_retry(self):
status = None
def exit(exit_status):
nonlocal status
status = exit_status
def exit(exit_status):
nonlocal status
status = exit_status
job = JobCommand("test", "test", main_retry, [], {})
cancel = Value("L", False)
logs = Queue()
pending = Queue()
progress = Queue()
pid = Value("L", getpid())
idle = Value("L", False)
job = JobCommand("test", "test", main_retry, [], {})
cancel = Value("L", False)
logs = Queue()
pending = Queue()
progress = Queue()
pid = Value("L", getpid())
idle = Value("L", False)
pending.put(job)
worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit)
pending.put(job)
worker_main(
WorkerContext(
"test",
test_device(),
cancel,
logs,
pending,
progress,
pid,
idle,
0,
0.0,
),
ServerContext(),
exit=exit,
)
self.assertEqual(status, EXIT_ERROR)
pass
self.assertEqual(status, EXIT_ERROR)
pass
def test_pending_exception_value(self):
status = None
def test_pending_exception_value(self):
status = None
def exit(exit_status):
nonlocal status
status = exit_status
def exit(exit_status):
nonlocal status
status = exit_status
cancel = Value("L", False)
logs = Queue()
pending = Queue()
progress = Queue()
pid = Value("L", getpid())
idle = Value("L", False)
cancel = Value("L", False)
logs = Queue()
pending = Queue()
progress = Queue()
pid = Value("L", getpid())
idle = Value("L", False)
pending.close()
worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit)
pending.close()
worker_main(
WorkerContext(
"test",
test_device(),
cancel,
logs,
pending,
progress,
pid,
idle,
0,
0.0,
),
ServerContext(),
exit=exit,
)
self.assertEqual(status, EXIT_ERROR)
self.assertEqual(status, EXIT_ERROR)
def test_pending_exception_other_memory(self):
status = None
def test_pending_exception_other_memory(self):
status = None
def exit(exit_status):
nonlocal status
status = exit_status
def exit(exit_status):
nonlocal status
status = exit_status
job = JobCommand("test", "test", main_memory, [], {})
cancel = Value("L", False)
logs = Queue()
pending = Queue()
progress = Queue()
pid = Value("L", getpid())
idle = Value("L", False)
job = JobCommand("test", "test", main_memory, [], {})
cancel = Value("L", False)
logs = Queue()
pending = Queue()
progress = Queue()
pid = Value("L", getpid())
idle = Value("L", False)
pending.put(job)
worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit)
pending.put(job)
worker_main(
WorkerContext(
"test",
test_device(),
cancel,
logs,
pending,
progress,
pid,
idle,
0,
0.0,
),
ServerContext(),
exit=exit,
)
self.assertEqual(status, EXIT_MEMORY)
self.assertEqual(status, EXIT_MEMORY)
def test_pending_exception_other_unknown(self):
pass
def test_pending_exception_other_unknown(self):
pass
def test_pending_replaced(self):
status = None
def test_pending_replaced(self):
status = None
def exit(exit_status):
nonlocal status
status = exit_status
def exit(exit_status):
nonlocal status
status = exit_status
cancel = Value("L", False)
logs = Queue()
pending = Queue()
progress = Queue()
pid = Value("L", 0)
idle = Value("L", False)
cancel = Value("L", False)
logs = Queue()
pending = Queue()
progress = Queue()
pid = Value("L", 0)
idle = Value("L", False)
worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit)
self.assertEqual(status, EXIT_REPLACED)
worker_main(
WorkerContext(
"test",
test_device(),
cancel,
logs,
pending,
progress,
pid,
idle,
0,
0.0,
),
ServerContext(),
exit=exit,
)
self.assertEqual(status, EXIT_REPLACED)