apply lint to tests, test highres
This commit is contained in:
parent
4691e80744
commit
65912c5a4a
|
@ -33,13 +33,17 @@ package-upload:
|
|||
|
||||
lint-check:
|
||||
black --check onnx_web/
|
||||
black --check tests/
|
||||
flake8 onnx_web
|
||||
flake8 tests
|
||||
isort --check-only --skip __init__.py --filter-files onnx_web
|
||||
isort --check-only --skip __init__.py --filter-files tests
|
||||
|
||||
lint-fix:
|
||||
black onnx_web/
|
||||
black tests/
|
||||
flake8 onnx_web
|
||||
flake8 tests
|
||||
isort --skip __init__.py --filter-files onnx_web
|
||||
isort --skip __init__.py --filter-files tests
|
||||
|
||||
|
|
|
@ -19,6 +19,14 @@ class StageResult:
|
|||
def empty():
|
||||
return StageResult(images=[])
|
||||
|
||||
@staticmethod
|
||||
def from_arrays(arrays: List[np.ndarray]):
|
||||
return StageResult(arrays=arrays)
|
||||
|
||||
@staticmethod
|
||||
def from_images(images: List[Image.Image]):
|
||||
return StageResult(images=images)
|
||||
|
||||
def __init__(self, arrays=None, images=None) -> None:
|
||||
if arrays is not None and images is not None:
|
||||
raise ValueError("stages must only return one type of result")
|
||||
|
|
|
@ -25,9 +25,7 @@ class TileCallback(Protocol):
|
|||
Definition for a tile job function.
|
||||
"""
|
||||
|
||||
def __call__(
|
||||
self, image: Image.Image, dims: Tuple[int, int, int]
|
||||
) -> StageResult:
|
||||
def __call__(self, image: Image.Image, dims: Tuple[int, int, int]) -> StageResult:
|
||||
"""
|
||||
Run this stage against a single tile.
|
||||
"""
|
||||
|
@ -319,6 +317,9 @@ def process_tile_stack(
|
|||
if mask:
|
||||
tile_mask = mask.crop((left, top, right, bottom))
|
||||
|
||||
if isinstance(tile_stack, list):
|
||||
tile_stack = StageResult.from_images(tile_stack)
|
||||
|
||||
for image_filter in filters:
|
||||
tile_stack = image_filter(tile_stack, tile_mask, (left, top, tile))
|
||||
|
||||
|
|
|
@ -48,6 +48,7 @@ def main():
|
|||
# debug options
|
||||
if server.debug:
|
||||
import debugpy
|
||||
|
||||
debugpy.listen(5678)
|
||||
logger.warning("waiting for debugger")
|
||||
debugpy.wait_for_client()
|
||||
|
|
|
@ -9,13 +9,15 @@ from onnx_web.chain.result import StageResult
|
|||
class BlendGridStageTests(unittest.TestCase):
|
||||
def test_stage(self):
|
||||
stage = BlendGridStage()
|
||||
sources = StageResult(images=[
|
||||
sources = StageResult(
|
||||
images=[
|
||||
Image.new("RGB", (64, 64), "black"),
|
||||
Image.new("RGB", (64, 64), "white"),
|
||||
Image.new("RGB", (64, 64), "black"),
|
||||
Image.new("RGB", (64, 64), "white"),
|
||||
])
|
||||
]
|
||||
)
|
||||
result = stage.run(None, None, None, None, sources, height=2, width=2)
|
||||
|
||||
self.assertEqual(len(result), 5)
|
||||
self.assertEqual(result.as_image()[-1].getpixel((0,0)), (0, 0, 0))
|
||||
self.assertEqual(result.as_image()[-1].getpixel((0, 0)), (0, 0, 0))
|
||||
|
|
|
@ -6,21 +6,38 @@ from onnx_web.chain.blend_img2img import BlendImg2ImgStage
|
|||
from onnx_web.params import DeviceParams, ImageParams
|
||||
from onnx_web.server.context import ServerContext
|
||||
from onnx_web.worker.context import WorkerContext
|
||||
from tests.helpers import TEST_MODEL_DIFFUSION_SD15, test_needs_models
|
||||
|
||||
|
||||
class BlendImg2ImgStageTests(unittest.TestCase):
|
||||
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
|
||||
def test_stage(self):
|
||||
"""
|
||||
stage = BlendImg2ImgStage()
|
||||
params = ImageParams("runwayml/stable-diffusion-v1-5", "txt2img", "euler-a", "an astronaut eating a hamburger", 3.0, 1, 1)
|
||||
server = ServerContext()
|
||||
worker = WorkerContext("test", DeviceParams("cpu", "CPUProvider"), None, None, None, None, None, None, 0)
|
||||
params = ImageParams(
|
||||
TEST_MODEL_DIFFUSION_SD15,
|
||||
"txt2img",
|
||||
"euler-a",
|
||||
"an astronaut eating a hamburger",
|
||||
3.0,
|
||||
1,
|
||||
1,
|
||||
)
|
||||
server = ServerContext(model_path="../models", output_path="../outputs")
|
||||
worker = WorkerContext(
|
||||
"test",
|
||||
DeviceParams("cpu", "CPUProvider"),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
0,
|
||||
)
|
||||
sources = [
|
||||
Image.new("RGB", (64, 64), "black"),
|
||||
]
|
||||
result = stage.run(worker, server, None, params, sources, strength=0.5, steps=1)
|
||||
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertEqual(result[0].getpixel((0,0)), (127, 127, 127))
|
||||
"""
|
||||
pass
|
||||
self.assertEqual(result[0].getpixel((0, 0)), (127, 127, 127))
|
||||
|
|
|
@ -9,11 +9,15 @@ from onnx_web.chain.result import StageResult
|
|||
class BlendLinearStageTests(unittest.TestCase):
|
||||
def test_stage(self):
|
||||
stage = BlendLinearStage()
|
||||
sources = StageResult(images=[
|
||||
sources = StageResult(
|
||||
images=[
|
||||
Image.new("RGB", (64, 64), "black"),
|
||||
])
|
||||
]
|
||||
)
|
||||
stage_source = Image.new("RGB", (64, 64), "white")
|
||||
result = stage.run(None, None, None, None, sources, alpha=0.5, stage_source=stage_source)
|
||||
result = stage.run(
|
||||
None, None, None, None, sources, alpha=0.5, stage_source=stage_source
|
||||
)
|
||||
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertEqual(result.as_image()[0].getpixel((0,0)), (127, 127, 127))
|
||||
self.assertEqual(result.as_image()[0].getpixel((0, 0)), (127, 127, 127))
|
||||
|
|
|
@ -2,6 +2,7 @@ import unittest
|
|||
|
||||
from PIL import Image
|
||||
|
||||
from onnx_web.chain.result import StageResult
|
||||
from onnx_web.chain.tile import (
|
||||
complete_tile,
|
||||
generate_tile_grid,
|
||||
|
@ -48,7 +49,7 @@ class TestNeedsTile(unittest.TestCase):
|
|||
|
||||
self.assertFalse(needs_tile(64, 64, size=small))
|
||||
|
||||
def test_with_oversized_source(self):
|
||||
def test_with_oversized_size(self):
|
||||
large = Size(64, 64)
|
||||
|
||||
self.assertTrue(needs_tile(32, 32, size=large))
|
||||
|
@ -82,21 +83,21 @@ class TestGenerateTileGrid(unittest.TestCase):
|
|||
tiles = generate_tile_grid(16, 16, 8, 0.0)
|
||||
|
||||
self.assertEqual(len(tiles), 4)
|
||||
self.assertEqual(tiles, [(0, 0), (8, 0), (8, 8), (0, 8)])
|
||||
self.assertEqual(tiles, [(0, 0), (8, 0), (0, 8), (8, 8)])
|
||||
|
||||
def test_grid_no_overlap(self):
|
||||
tiles = generate_tile_grid(64, 64, 8, 0.0)
|
||||
|
||||
self.assertEqual(len(tiles), 64)
|
||||
self.assertEqual(tiles[0:4], [(0, 0), (8, 0), (16, 0), (24, 0)])
|
||||
self.assertEqual(tiles[-5:-1], [(16, 24), (24, 24), (32, 24), (32, 32)])
|
||||
self.assertEqual(tiles[-5:-1], [(24, 56), (32, 56), (40, 56), (48, 56)])
|
||||
|
||||
def test_grid_50_overlap(self):
|
||||
tiles = generate_tile_grid(64, 64, 8, 0.5)
|
||||
|
||||
self.assertEqual(len(tiles), 225)
|
||||
self.assertEqual(len(tiles), 256)
|
||||
self.assertEqual(tiles[0:4], [(0, 0), (4, 0), (8, 0), (12, 0)])
|
||||
self.assertEqual(tiles[-5:-1], [(32, 32), (28, 32), (24, 32), (24, 28)])
|
||||
self.assertEqual(tiles[-5:-1], [(44, 60), (48, 60), (52, 60), (56, 60)])
|
||||
|
||||
|
||||
class TestGenerateTileSpiral(unittest.TestCase):
|
||||
|
@ -124,12 +125,16 @@ class TestGenerateTileSpiral(unittest.TestCase):
|
|||
class TestProcessTileStack(unittest.TestCase):
|
||||
def test_grid_full(self):
|
||||
source = Image.new("RGB", (64, 64))
|
||||
blend = process_tile_stack(source, 32, 1, [])
|
||||
blend = process_tile_stack(
|
||||
StageResult(images=[source]), 32, 1, [], generate_tile_grid
|
||||
)
|
||||
|
||||
self.assertEqual(blend.size, (64, 64))
|
||||
self.assertEqual(blend[0].size, (64, 64))
|
||||
|
||||
def test_grid_partial(self):
|
||||
source = Image.new("RGB", (72, 72))
|
||||
blend = process_tile_stack(source, 32, 1, [])
|
||||
blend = process_tile_stack(
|
||||
StageResult(images=[source]), 32, 1, [], generate_tile_grid
|
||||
)
|
||||
|
||||
self.assertEqual(blend.size, (72, 72))
|
||||
self.assertEqual(blend[0].size, (72, 72))
|
||||
|
|
|
@ -9,6 +9,14 @@ class UpscaleHighresStageTests(unittest.TestCase):
|
|||
def test_empty(self):
|
||||
stage = UpscaleHighresStage()
|
||||
sources = StageResult.empty()
|
||||
result = stage.run(None, None, None, None, sources, highres=HighresParams(False,1, 0, 0), upscale=UpscaleParams(""))
|
||||
result = stage.run(
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
sources,
|
||||
highres=HighresParams(False, 1, 0, 0),
|
||||
upscale=UpscaleParams(""),
|
||||
)
|
||||
|
||||
self.assertEqual(len(result), 0)
|
||||
|
|
|
@ -6,7 +6,6 @@ from onnx import GraphProto, ModelProto, NodeProto
|
|||
from onnx.numpy_helper import from_array
|
||||
|
||||
from onnx_web.convert.diffusion.lora import (
|
||||
blend_loras,
|
||||
blend_node_conv_gemm,
|
||||
blend_node_matmul,
|
||||
blend_weights_loha,
|
||||
|
@ -33,7 +32,6 @@ class SumWeightsTests(unittest.TestCase):
|
|||
weights = sum_weights(np.zeros((4, 4)), np.ones((4, 4, 1, 1)))
|
||||
self.assertEqual(weights.shape, (4, 4, 1, 1))
|
||||
|
||||
|
||||
def test_3x3_kernel(self):
|
||||
"""
|
||||
weights = sum_weights(np.zeros((4, 4, 3, 3)), np.ones((4, 4)))
|
||||
|
@ -53,14 +51,20 @@ class BufferExternalDataTensorTests(unittest.TestCase):
|
|||
)
|
||||
(slim_model, external_weights) = buffer_external_data_tensors(model)
|
||||
|
||||
self.assertEqual(len(slim_model.graph.initializer), len(model.graph.initializer))
|
||||
self.assertEqual(
|
||||
len(slim_model.graph.initializer), len(model.graph.initializer)
|
||||
)
|
||||
self.assertEqual(len(external_weights), 1)
|
||||
|
||||
|
||||
class FixInitializerKeyTests(unittest.TestCase):
|
||||
def test_fix_name(self):
|
||||
inputs = ["lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0.lora_down.weight"]
|
||||
outputs = ["lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0_lora_down_weight"]
|
||||
inputs = [
|
||||
"lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0.lora_down.weight"
|
||||
]
|
||||
outputs = [
|
||||
"lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0_lora_down_weight"
|
||||
]
|
||||
|
||||
for input, output in zip(inputs, outputs):
|
||||
self.assertEqual(fix_initializer_name(input), output)
|
||||
|
@ -92,25 +96,37 @@ class FixXLNameTests(unittest.TestCase):
|
|||
nodes = {
|
||||
"input_block_proj.lora_down.weight": {},
|
||||
}
|
||||
fixed = fix_xl_names(nodes, [
|
||||
fixed = fix_xl_names(
|
||||
nodes,
|
||||
[
|
||||
NodeProto(name="/down_blocks_proj/MatMul"),
|
||||
])
|
||||
],
|
||||
)
|
||||
|
||||
self.assertEqual(fixed, {
|
||||
self.assertEqual(
|
||||
fixed,
|
||||
{
|
||||
"down_blocks_proj": nodes["input_block_proj.lora_down.weight"],
|
||||
})
|
||||
},
|
||||
)
|
||||
|
||||
def test_middle_block(self):
|
||||
nodes = {
|
||||
"middle_block_proj.lora_down.weight": {},
|
||||
}
|
||||
fixed = fix_xl_names(nodes, [
|
||||
fixed = fix_xl_names(
|
||||
nodes,
|
||||
[
|
||||
NodeProto(name="/mid_blocks_proj/MatMul"),
|
||||
])
|
||||
],
|
||||
)
|
||||
|
||||
self.assertEqual(fixed, {
|
||||
self.assertEqual(
|
||||
fixed,
|
||||
{
|
||||
"mid_blocks_proj": nodes["middle_block_proj.lora_down.weight"],
|
||||
})
|
||||
},
|
||||
)
|
||||
|
||||
def test_output_block(self):
|
||||
pass
|
||||
|
@ -133,13 +149,19 @@ class FixXLNameTests(unittest.TestCase):
|
|||
nodes = {
|
||||
"output_block_proj_out.lora_down.weight": {},
|
||||
}
|
||||
fixed = fix_xl_names(nodes, [
|
||||
fixed = fix_xl_names(
|
||||
nodes,
|
||||
[
|
||||
NodeProto(name="/up_blocks_proj_out/MatMul"),
|
||||
])
|
||||
],
|
||||
)
|
||||
|
||||
self.assertEqual(fixed, {
|
||||
self.assertEqual(
|
||||
fixed,
|
||||
{
|
||||
"up_blocks_proj_out": nodes["output_block_proj_out.lora_down.weight"],
|
||||
})
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class KernelSliceTests(unittest.TestCase):
|
||||
|
@ -250,6 +272,7 @@ class BlendWeightsLoHATests(unittest.TestCase):
|
|||
self.assertEqual(result.shape, (4, 4))
|
||||
"""
|
||||
|
||||
|
||||
class BlendWeightsLoRATests(unittest.TestCase):
|
||||
def test_blend_kernel_none(self):
|
||||
model = {
|
||||
|
@ -260,7 +283,6 @@ class BlendWeightsLoRATests(unittest.TestCase):
|
|||
key, result = blend_weights_lora("foo.lora_down", "", model, torch.float32)
|
||||
self.assertEqual(result.shape, (4, 4))
|
||||
|
||||
|
||||
def test_blend_kernel_1x1(self):
|
||||
model = {
|
||||
"foo.lora_down": torch.from_numpy(np.ones((1, 4, 1, 1))),
|
||||
|
|
|
@ -10,7 +10,6 @@ from onnx_web.convert.diffusion.textual_inversion import (
|
|||
blend_embedding_embeddings,
|
||||
blend_embedding_node,
|
||||
blend_embedding_parameters,
|
||||
blend_textual_inversions,
|
||||
detect_embedding_format,
|
||||
)
|
||||
|
||||
|
@ -59,9 +58,15 @@ class BlendEmbeddingConceptTests(unittest.TestCase):
|
|||
embeds = {
|
||||
"test": np.ones(TEST_DIMS),
|
||||
}
|
||||
blend_embedding_concept(embeds, {
|
||||
blend_embedding_concept(
|
||||
embeds,
|
||||
{
|
||||
"<test>": torch.from_numpy(np.ones(TEST_DIMS)),
|
||||
}, np.float32, "test", 1.0)
|
||||
},
|
||||
np.float32,
|
||||
"test",
|
||||
1.0,
|
||||
)
|
||||
|
||||
self.assertIn("test", embeds)
|
||||
self.assertEqual(embeds["test"].shape, TEST_DIMS)
|
||||
|
@ -69,9 +74,15 @@ class BlendEmbeddingConceptTests(unittest.TestCase):
|
|||
|
||||
def test_missing_base_token(self):
|
||||
embeds = {}
|
||||
blend_embedding_concept(embeds, {
|
||||
blend_embedding_concept(
|
||||
embeds,
|
||||
{
|
||||
"<test>": torch.from_numpy(np.ones(TEST_DIMS)),
|
||||
}, np.float32, "test", 1.0)
|
||||
},
|
||||
np.float32,
|
||||
"test",
|
||||
1.0,
|
||||
)
|
||||
|
||||
self.assertIn("test", embeds)
|
||||
self.assertEqual(embeds["test"].shape, TEST_DIMS)
|
||||
|
@ -80,9 +91,15 @@ class BlendEmbeddingConceptTests(unittest.TestCase):
|
|||
embeds = {
|
||||
"<test>": np.ones(TEST_DIMS),
|
||||
}
|
||||
blend_embedding_concept(embeds, {
|
||||
blend_embedding_concept(
|
||||
embeds,
|
||||
{
|
||||
"<test>": torch.from_numpy(np.ones(TEST_DIMS)),
|
||||
}, np.float32, "test", 1.0)
|
||||
},
|
||||
np.float32,
|
||||
"test",
|
||||
1.0,
|
||||
)
|
||||
|
||||
keys = list(embeds.keys())
|
||||
keys.sort()
|
||||
|
@ -92,9 +109,15 @@ class BlendEmbeddingConceptTests(unittest.TestCase):
|
|||
|
||||
def test_missing_token(self):
|
||||
embeds = {}
|
||||
blend_embedding_concept(embeds, {
|
||||
blend_embedding_concept(
|
||||
embeds,
|
||||
{
|
||||
"<test>": torch.from_numpy(np.ones(TEST_DIMS)),
|
||||
}, np.float32, "test", 1.0)
|
||||
},
|
||||
np.float32,
|
||||
"test",
|
||||
1.0,
|
||||
)
|
||||
|
||||
keys = list(embeds.keys())
|
||||
keys.sort()
|
||||
|
@ -108,9 +131,15 @@ class BlendEmbeddingParametersTests(unittest.TestCase):
|
|||
embeds = {
|
||||
"test": np.ones(TEST_DIMS),
|
||||
}
|
||||
blend_embedding_parameters(embeds, {
|
||||
blend_embedding_parameters(
|
||||
embeds,
|
||||
{
|
||||
"emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)),
|
||||
}, np.float32, "test", 1.0)
|
||||
},
|
||||
np.float32,
|
||||
"test",
|
||||
1.0,
|
||||
)
|
||||
|
||||
self.assertIn("test", embeds)
|
||||
self.assertEqual(embeds["test"].shape, TEST_DIMS)
|
||||
|
@ -118,9 +147,15 @@ class BlendEmbeddingParametersTests(unittest.TestCase):
|
|||
|
||||
def test_missing_base_token(self):
|
||||
embeds = {}
|
||||
blend_embedding_parameters(embeds, {
|
||||
blend_embedding_parameters(
|
||||
embeds,
|
||||
{
|
||||
"emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)),
|
||||
}, np.float32, "test", 1.0)
|
||||
},
|
||||
np.float32,
|
||||
"test",
|
||||
1.0,
|
||||
)
|
||||
|
||||
self.assertIn("test", embeds)
|
||||
self.assertEqual(embeds["test"].shape, TEST_DIMS)
|
||||
|
@ -129,9 +164,15 @@ class BlendEmbeddingParametersTests(unittest.TestCase):
|
|||
embeds = {
|
||||
"test": np.ones(TEST_DIMS_EMBEDS),
|
||||
}
|
||||
blend_embedding_parameters(embeds, {
|
||||
blend_embedding_parameters(
|
||||
embeds,
|
||||
{
|
||||
"emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)),
|
||||
}, np.float32, "test", 1.0)
|
||||
},
|
||||
np.float32,
|
||||
"test",
|
||||
1.0,
|
||||
)
|
||||
|
||||
keys = list(embeds.keys())
|
||||
keys.sort()
|
||||
|
@ -141,9 +182,15 @@ class BlendEmbeddingParametersTests(unittest.TestCase):
|
|||
|
||||
def test_missing_token(self):
|
||||
embeds = {}
|
||||
blend_embedding_parameters(embeds, {
|
||||
blend_embedding_parameters(
|
||||
embeds,
|
||||
{
|
||||
"emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)),
|
||||
}, np.float32, "test", 1.0)
|
||||
},
|
||||
np.float32,
|
||||
"test",
|
||||
1.0,
|
||||
)
|
||||
|
||||
keys = list(embeds.keys())
|
||||
keys.sort()
|
||||
|
@ -198,14 +245,23 @@ class BlendEmbeddingNodeTests(unittest.TestCase):
|
|||
weights = from_array(np.ones(TEST_DIMS))
|
||||
weights.name = "text_model.embeddings.token_embedding.weight"
|
||||
|
||||
model = ModelProto(graph=GraphProto(initializer=[
|
||||
model = ModelProto(
|
||||
graph=GraphProto(
|
||||
initializer=[
|
||||
weights,
|
||||
]))
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
embeds = {}
|
||||
blend_embedding_node(model, {
|
||||
'convert_tokens_to_ids': lambda t: t,
|
||||
}, embeds, 2)
|
||||
blend_embedding_node(
|
||||
model,
|
||||
{
|
||||
"convert_tokens_to_ids": lambda t: t,
|
||||
},
|
||||
embeds,
|
||||
2,
|
||||
)
|
||||
|
||||
result = to_array(model.graph.initializer[0])
|
||||
|
||||
|
|
|
@ -13,7 +13,6 @@ from onnx_web.convert.utils import (
|
|||
tuple_to_upscaling,
|
||||
)
|
||||
from tests.helpers import (
|
||||
TEST_MODEL_DIFFUSION_SD15,
|
||||
TEST_MODEL_UPSCALING_SWINIR,
|
||||
test_needs_models,
|
||||
)
|
||||
|
@ -194,21 +193,23 @@ class TupleToUpscalingTests(unittest.TestCase):
|
|||
|
||||
class SourceFormatTests(unittest.TestCase):
|
||||
def test_with_format(self):
|
||||
result = source_format({
|
||||
result = source_format(
|
||||
{
|
||||
"format": "foo",
|
||||
})
|
||||
}
|
||||
)
|
||||
self.assertEqual(result, "foo")
|
||||
|
||||
def test_source_known_extension(self):
|
||||
result = source_format({
|
||||
result = source_format(
|
||||
{
|
||||
"source": "foo.safetensors",
|
||||
})
|
||||
}
|
||||
)
|
||||
self.assertEqual(result, "safetensors")
|
||||
|
||||
def test_source_unknown_extension(self):
|
||||
result = source_format({
|
||||
"source": "foo.none"
|
||||
})
|
||||
result = source_format({"source": "foo.none"})
|
||||
self.assertEqual(result, None)
|
||||
|
||||
def test_incomplete_model(self):
|
||||
|
@ -234,7 +235,10 @@ class LoadTensorTests(unittest.TestCase):
|
|||
class ResolveTensorTests(unittest.TestCase):
|
||||
@test_needs_models([TEST_MODEL_UPSCALING_SWINIR])
|
||||
def test_resolve_existing(self):
|
||||
self.assertEqual(resolve_tensor("../models/.cache/upscaling-swinir"), TEST_MODEL_UPSCALING_SWINIR)
|
||||
self.assertEqual(
|
||||
resolve_tensor("../models/.cache/upscaling-swinir"),
|
||||
TEST_MODEL_UPSCALING_SWINIR,
|
||||
)
|
||||
|
||||
def test_resolve_missing(self):
|
||||
self.assertIsNone(resolve_tensor("missing"))
|
||||
|
|
|
@ -6,7 +6,9 @@ from onnx_web.params import DeviceParams
|
|||
|
||||
|
||||
def test_needs_models(models: List[str]):
|
||||
return skipUnless(all([path.exists(model) for model in models]), "model does not exist")
|
||||
return skipUnless(
|
||||
all([path.exists(model) for model in models]), "model does not exist"
|
||||
)
|
||||
|
||||
|
||||
def test_device() -> DeviceParams:
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from typing import Any, Optional
|
||||
|
||||
|
||||
class MockPipeline():
|
||||
class MockPipeline:
|
||||
# flags
|
||||
slice_size: Optional[str]
|
||||
vae_slicing: Optional[bool]
|
||||
|
|
|
@ -13,7 +13,7 @@ class ParserTests(unittest.TestCase):
|
|||
str(["foo"]),
|
||||
str(PromptPhrase(["bar"], weight=1.5)),
|
||||
str(["bin"]),
|
||||
]
|
||||
],
|
||||
)
|
||||
|
||||
def test_multi_word_phrase(self):
|
||||
|
@ -24,7 +24,7 @@ class ParserTests(unittest.TestCase):
|
|||
str(["foo", "bar"]),
|
||||
str(PromptPhrase(["middle", "words"], weight=1.5)),
|
||||
str(["bin", "bun"]),
|
||||
]
|
||||
],
|
||||
)
|
||||
|
||||
def test_nested_phrase(self):
|
||||
|
@ -33,7 +33,7 @@ class ParserTests(unittest.TestCase):
|
|||
[str(i) for i in res],
|
||||
[
|
||||
str(["foo"]),
|
||||
str(PromptPhrase(["bar"], weight=(1.5 ** 3))),
|
||||
str(PromptPhrase(["bar"], weight=(1.5**3))),
|
||||
str(["bin"]),
|
||||
]
|
||||
],
|
||||
)
|
||||
|
|
|
@ -25,71 +25,85 @@ class ConfigParamTests(unittest.TestCase):
|
|||
params = get_config_params()
|
||||
self.assertIsNotNone(params)
|
||||
|
||||
|
||||
class AvailablePlatformTests(unittest.TestCase):
|
||||
def test_before_setup(self):
|
||||
platforms = get_available_platforms()
|
||||
self.assertIsNotNone(platforms)
|
||||
|
||||
|
||||
class CorrectModelTests(unittest.TestCase):
|
||||
def test_before_setup(self):
|
||||
models = get_correction_models()
|
||||
self.assertIsNotNone(models)
|
||||
|
||||
|
||||
class DiffusionModelTests(unittest.TestCase):
|
||||
def test_before_setup(self):
|
||||
models = get_diffusion_models()
|
||||
self.assertIsNotNone(models)
|
||||
|
||||
|
||||
class NetworkModelTests(unittest.TestCase):
|
||||
def test_before_setup(self):
|
||||
models = get_network_models()
|
||||
self.assertIsNotNone(models)
|
||||
|
||||
|
||||
class UpscalingModelTests(unittest.TestCase):
|
||||
def test_before_setup(self):
|
||||
models = get_upscaling_models()
|
||||
self.assertIsNotNone(models)
|
||||
|
||||
|
||||
class WildcardDataTests(unittest.TestCase):
|
||||
def test_before_setup(self):
|
||||
wildcards = get_wildcard_data()
|
||||
self.assertIsNotNone(wildcards)
|
||||
|
||||
|
||||
class ExtraStringsTests(unittest.TestCase):
|
||||
def test_before_setup(self):
|
||||
strings = get_extra_strings()
|
||||
self.assertIsNotNone(strings)
|
||||
|
||||
|
||||
class ExtraHashesTests(unittest.TestCase):
|
||||
def test_before_setup(self):
|
||||
hashes = get_extra_hashes()
|
||||
self.assertIsNotNone(hashes)
|
||||
|
||||
|
||||
class HighresMethodTests(unittest.TestCase):
|
||||
def test_before_setup(self):
|
||||
methods = get_highres_methods()
|
||||
self.assertIsNotNone(methods)
|
||||
|
||||
|
||||
class MaskFilterTests(unittest.TestCase):
|
||||
def test_before_setup(self):
|
||||
filters = get_mask_filters()
|
||||
self.assertIsNotNone(filters)
|
||||
|
||||
|
||||
class NoiseSourceTests(unittest.TestCase):
|
||||
def test_before_setup(self):
|
||||
sources = get_noise_sources()
|
||||
self.assertIsNotNone(sources)
|
||||
|
||||
|
||||
class SourceFilterTests(unittest.TestCase):
|
||||
def test_before_setup(self):
|
||||
filters = get_source_filters()
|
||||
self.assertIsNotNone(filters)
|
||||
|
||||
|
||||
class LoadExtrasTests(unittest.TestCase):
|
||||
def test_default_extras(self):
|
||||
server = ServerContext(extra_models=["../models/extras.json"])
|
||||
load_extras(server)
|
||||
|
||||
|
||||
class LoadModelsTests(unittest.TestCase):
|
||||
def test_default_models(self):
|
||||
server = ServerContext(model_path="../models")
|
||||
|
|
|
@ -122,25 +122,42 @@ class TestPatchPipeline(unittest.TestCase):
|
|||
def test_unet_wrapper_not_xl(self):
|
||||
server = ServerContext()
|
||||
pipeline = MockPipeline()
|
||||
patch_pipeline(server, pipeline, None, ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1))
|
||||
patch_pipeline(
|
||||
server,
|
||||
pipeline,
|
||||
None,
|
||||
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
|
||||
)
|
||||
self.assertTrue(isinstance(pipeline.unet, UNetWrapper))
|
||||
|
||||
def test_unet_wrapper_xl(self):
|
||||
server = ServerContext()
|
||||
pipeline = MockPipeline()
|
||||
patch_pipeline(server, pipeline, None, ImageParams("test", "txt2img-sdxl", "ddim", "test", 1.0, 10, 1))
|
||||
patch_pipeline(
|
||||
server,
|
||||
pipeline,
|
||||
None,
|
||||
ImageParams("test", "txt2img-sdxl", "ddim", "test", 1.0, 10, 1),
|
||||
)
|
||||
self.assertTrue(isinstance(pipeline.unet, UNetWrapper))
|
||||
|
||||
def test_vae_wrapper(self):
|
||||
server = ServerContext()
|
||||
pipeline = MockPipeline()
|
||||
patch_pipeline(server, pipeline, None, ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1))
|
||||
patch_pipeline(
|
||||
server,
|
||||
pipeline,
|
||||
None,
|
||||
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
|
||||
)
|
||||
self.assertTrue(isinstance(pipeline.vae_decoder, VAEWrapper))
|
||||
self.assertTrue(isinstance(pipeline.vae_encoder, VAEWrapper))
|
||||
|
||||
|
||||
class TestLoadControlNet(unittest.TestCase):
|
||||
@unittest.skipUnless(path.exists("../models/control/canny.onnx"), "model does not exist")
|
||||
@unittest.skipUnless(
|
||||
path.exists("../models/control/canny.onnx"), "model does not exist"
|
||||
)
|
||||
def test_load_existing(self):
|
||||
"""
|
||||
Should load a model
|
||||
|
@ -148,7 +165,16 @@ class TestLoadControlNet(unittest.TestCase):
|
|||
components = load_controlnet(
|
||||
ServerContext(model_path="../models"),
|
||||
DeviceParams("cpu", "CPUExecutionProvider"),
|
||||
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1, control=NetworkModel("canny", "control")),
|
||||
ImageParams(
|
||||
"test",
|
||||
"txt2img",
|
||||
"ddim",
|
||||
"test",
|
||||
1.0,
|
||||
10,
|
||||
1,
|
||||
control=NetworkModel("canny", "control"),
|
||||
),
|
||||
)
|
||||
self.assertIn("controlnet", components)
|
||||
|
||||
|
@ -161,9 +187,18 @@ class TestLoadControlNet(unittest.TestCase):
|
|||
components = load_controlnet(
|
||||
ServerContext(),
|
||||
DeviceParams("cpu", "CPUExecutionProvider"),
|
||||
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1, control=NetworkModel("missing", "control")),
|
||||
ImageParams(
|
||||
"test",
|
||||
"txt2img",
|
||||
"ddim",
|
||||
"test",
|
||||
1.0,
|
||||
10,
|
||||
1,
|
||||
control=NetworkModel("missing", "control"),
|
||||
),
|
||||
)
|
||||
except:
|
||||
except Exception:
|
||||
self.assertNotIn("controlnet", components)
|
||||
return
|
||||
|
||||
|
@ -171,7 +206,10 @@ class TestLoadControlNet(unittest.TestCase):
|
|||
|
||||
|
||||
class TestLoadTextEncoders(unittest.TestCase):
|
||||
@unittest.skipUnless(path.exists("../models/stable-diffusion-onnx-v1-5/text_encoder/model.onnx"), "model does not exist")
|
||||
@unittest.skipUnless(
|
||||
path.exists("../models/stable-diffusion-onnx-v1-5/text_encoder/model.onnx"),
|
||||
"model does not exist",
|
||||
)
|
||||
def test_load_embeddings(self):
|
||||
"""
|
||||
Should add the token to tokenizer
|
||||
|
@ -193,7 +231,10 @@ class TestLoadTextEncoders(unittest.TestCase):
|
|||
def test_load_embeddings_xl(self):
|
||||
pass
|
||||
|
||||
@unittest.skipUnless(path.exists("../models/stable-diffusion-onnx-v1-5/text_encoder/model.onnx"), "model does not exist")
|
||||
@unittest.skipUnless(
|
||||
path.exists("../models/stable-diffusion-onnx-v1-5/text_encoder/model.onnx"),
|
||||
"model does not exist",
|
||||
)
|
||||
def test_load_loras(self):
|
||||
components = load_text_encoders(
|
||||
ServerContext(model_path="../models"),
|
||||
|
@ -211,8 +252,12 @@ class TestLoadTextEncoders(unittest.TestCase):
|
|||
def test_load_loras_xl(self):
|
||||
pass
|
||||
|
||||
|
||||
class TestLoadUnet(unittest.TestCase):
|
||||
@unittest.skipUnless(path.exists("../models/stable-diffusion-onnx-v1-5/unet/model.onnx"), "model does not exist")
|
||||
@unittest.skipUnless(
|
||||
path.exists("../models/stable-diffusion-onnx-v1-5/unet/model.onnx"),
|
||||
"model does not exist",
|
||||
)
|
||||
def test_load_unet_loras(self):
|
||||
components = load_unet(
|
||||
ServerContext(model_path="../models"),
|
||||
|
@ -229,7 +274,10 @@ class TestLoadUnet(unittest.TestCase):
|
|||
def test_load_unet_loras_xl(self):
|
||||
pass
|
||||
|
||||
@unittest.skipUnless(path.exists("../models/stable-diffusion-onnx-v1-5/cnet/model.onnx"), "model does not exist")
|
||||
@unittest.skipUnless(
|
||||
path.exists("../models/stable-diffusion-onnx-v1-5/cnet/model.onnx"),
|
||||
"model does not exist",
|
||||
)
|
||||
def test_load_cnet_loras(self):
|
||||
components = load_unet(
|
||||
ServerContext(model_path="../models"),
|
||||
|
@ -245,7 +293,10 @@ class TestLoadUnet(unittest.TestCase):
|
|||
|
||||
|
||||
class TestLoadVae(unittest.TestCase):
|
||||
@unittest.skipUnless(path.exists("../models/upscaling-stable-diffusion-x4/vae/model.onnx"), "model does not exist")
|
||||
@unittest.skipUnless(
|
||||
path.exists("../models/upscaling-stable-diffusion-x4/vae/model.onnx"),
|
||||
"model does not exist",
|
||||
)
|
||||
def test_load_single(self):
|
||||
"""
|
||||
Should return single component
|
||||
|
@ -260,7 +311,10 @@ class TestLoadVae(unittest.TestCase):
|
|||
self.assertNotIn("vae_decoder", components)
|
||||
self.assertNotIn("vae_encoder", components)
|
||||
|
||||
@unittest.skipUnless(path.exists("../models/stable-diffusion-onnx-v1-5/vae_encoder/model.onnx"), "model does not exist")
|
||||
@unittest.skipUnless(
|
||||
path.exists("../models/stable-diffusion-onnx-v1-5/vae_encoder/model.onnx"),
|
||||
"model does not exist",
|
||||
)
|
||||
def test_load_split(self):
|
||||
"""
|
||||
Should return split encoder/decoder
|
||||
|
|
|
@ -44,14 +44,70 @@ class TestTxt2ImgPipeline(unittest.TestCase):
|
|||
worker,
|
||||
ServerContext(model_path="../models", output_path="../outputs"),
|
||||
ImageParams(
|
||||
TEST_MODEL_DIFFUSION_SD15, "txt2img", "ddim", "an astronaut eating a hamburger", 3.0, 1, 1),
|
||||
TEST_MODEL_DIFFUSION_SD15,
|
||||
"txt2img",
|
||||
"ddim",
|
||||
"an astronaut eating a hamburger",
|
||||
3.0,
|
||||
1,
|
||||
1,
|
||||
),
|
||||
Size(256, 256),
|
||||
["test-txt2img.png"],
|
||||
["test-txt2img-basic.png"],
|
||||
UpscaleParams("test"),
|
||||
HighresParams(False, 1, 0, 0),
|
||||
)
|
||||
|
||||
self.assertTrue(path.exists("../outputs/test-txt2img.png"))
|
||||
self.assertTrue(path.exists("../outputs/test-txt2img-basic.png"))
|
||||
output = Image.open("../outputs/test-txt2img-basic.png")
|
||||
self.assertEqual(output.size, (256, 256))
|
||||
# TODO: test contents of image
|
||||
|
||||
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
|
||||
def test_highres(self):
|
||||
cancel = Value("L", 0)
|
||||
logs = Queue()
|
||||
pending = Queue()
|
||||
progress = Queue()
|
||||
active = Value("L", 0)
|
||||
idle = Value("L", 0)
|
||||
|
||||
worker = WorkerContext(
|
||||
"test",
|
||||
test_device(),
|
||||
cancel,
|
||||
logs,
|
||||
pending,
|
||||
progress,
|
||||
active,
|
||||
idle,
|
||||
3,
|
||||
0.1,
|
||||
)
|
||||
worker.start("test")
|
||||
|
||||
run_txt2img_pipeline(
|
||||
worker,
|
||||
ServerContext(model_path="../models", output_path="../outputs"),
|
||||
ImageParams(
|
||||
TEST_MODEL_DIFFUSION_SD15,
|
||||
"txt2img",
|
||||
"ddim",
|
||||
"an astronaut eating a hamburger",
|
||||
3.0,
|
||||
1,
|
||||
1,
|
||||
),
|
||||
Size(256, 256),
|
||||
["test-txt2img-highres.png"],
|
||||
UpscaleParams("test"),
|
||||
HighresParams(True, 2, 0, 0),
|
||||
)
|
||||
|
||||
self.assertTrue(path.exists("../outputs/test-txt2img-highres.png"))
|
||||
output = Image.open("../outputs/test-txt2img-highres.png")
|
||||
self.assertEqual(output.size, (512, 512))
|
||||
|
||||
|
||||
class TestImg2ImgPipeline(unittest.TestCase):
|
||||
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
|
||||
|
@ -82,7 +138,14 @@ class TestImg2ImgPipeline(unittest.TestCase):
|
|||
worker,
|
||||
ServerContext(model_path="../models", output_path="../outputs"),
|
||||
ImageParams(
|
||||
TEST_MODEL_DIFFUSION_SD15, "txt2img", "ddim", "an astronaut eating a hamburger", 3.0, 1, 1),
|
||||
TEST_MODEL_DIFFUSION_SD15,
|
||||
"txt2img",
|
||||
"ddim",
|
||||
"an astronaut eating a hamburger",
|
||||
3.0,
|
||||
1,
|
||||
1,
|
||||
),
|
||||
["test-img2img.png"],
|
||||
UpscaleParams("test"),
|
||||
HighresParams(False, 1, 0, 0),
|
||||
|
@ -92,6 +155,7 @@ class TestImg2ImgPipeline(unittest.TestCase):
|
|||
|
||||
self.assertTrue(path.exists("../outputs/test-img2img.png"))
|
||||
|
||||
|
||||
class TestUpscalePipeline(unittest.TestCase):
|
||||
@test_needs_models(["../models/upscaling-stable-diffusion-x4"])
|
||||
def test_basic(self):
|
||||
|
@ -121,7 +185,14 @@ class TestUpscalePipeline(unittest.TestCase):
|
|||
worker,
|
||||
ServerContext(model_path="../models", output_path="../outputs"),
|
||||
ImageParams(
|
||||
"../models/upscaling-stable-diffusion-x4", "txt2img", "ddim", "an astronaut eating a hamburger", 3.0, 1, 1),
|
||||
"../models/upscaling-stable-diffusion-x4",
|
||||
"txt2img",
|
||||
"ddim",
|
||||
"an astronaut eating a hamburger",
|
||||
3.0,
|
||||
1,
|
||||
1,
|
||||
),
|
||||
Size(256, 256),
|
||||
["test-upscale.png"],
|
||||
UpscaleParams("test"),
|
||||
|
@ -131,6 +202,7 @@ class TestUpscalePipeline(unittest.TestCase):
|
|||
|
||||
self.assertTrue(path.exists("../outputs/test-upscale.png"))
|
||||
|
||||
|
||||
class TestBlendPipeline(unittest.TestCase):
|
||||
def test_basic(self):
|
||||
cancel = Value("L", 0)
|
||||
|
@ -160,7 +232,14 @@ class TestBlendPipeline(unittest.TestCase):
|
|||
worker,
|
||||
ServerContext(model_path="../models", output_path="../outputs"),
|
||||
ImageParams(
|
||||
TEST_MODEL_DIFFUSION_SD15, "txt2img", "ddim", "an astronaut eating a hamburger", 3.0, 1, 1),
|
||||
TEST_MODEL_DIFFUSION_SD15,
|
||||
"txt2img",
|
||||
"ddim",
|
||||
"an astronaut eating a hamburger",
|
||||
3.0,
|
||||
1,
|
||||
1,
|
||||
),
|
||||
Size(64, 64),
|
||||
["test-blend.png"],
|
||||
UpscaleParams("test"),
|
||||
|
|
|
@ -10,7 +10,6 @@ from onnx_web.diffusers.utils import (
|
|||
get_loras_from_prompt,
|
||||
get_scaled_latents,
|
||||
get_tile_latents,
|
||||
get_tokens_from_prompt,
|
||||
pop_random,
|
||||
slice_prompt,
|
||||
)
|
||||
|
@ -26,7 +25,10 @@ class TestExpandIntervalRanges(unittest.TestCase):
|
|||
def test_prompt_with_range(self):
|
||||
prompt = "an astronaut-{1,4} eating a hamburger"
|
||||
result = expand_interval_ranges(prompt)
|
||||
self.assertEqual(result, "an astronaut-1 astronaut-2 astronaut-3 eating a hamburger")
|
||||
self.assertEqual(
|
||||
result, "an astronaut-1 astronaut-2 astronaut-3 eating a hamburger"
|
||||
)
|
||||
|
||||
|
||||
class TestExpandAlternativeRanges(unittest.TestCase):
|
||||
def test_prompt_with_no_ranges(self):
|
||||
|
@ -37,7 +39,10 @@ class TestExpandAlternativeRanges(unittest.TestCase):
|
|||
def test_ranges_match(self):
|
||||
prompt = "(an astronaut|a squirrel) eating (a hamburger|an acorn)"
|
||||
result = expand_alternative_ranges(prompt)
|
||||
self.assertEqual(result, ["an astronaut eating a hamburger", "a squirrel eating an acorn"])
|
||||
self.assertEqual(
|
||||
result, ["an astronaut eating a hamburger", "a squirrel eating an acorn"]
|
||||
)
|
||||
|
||||
|
||||
class TestInversionsFromPrompt(unittest.TestCase):
|
||||
def test_get_inversions(self):
|
||||
|
@ -47,6 +52,7 @@ class TestInversionsFromPrompt(unittest.TestCase):
|
|||
self.assertEqual(result, " an astronaut eating an embedding")
|
||||
self.assertEqual(tokens, [("test", 1.0)])
|
||||
|
||||
|
||||
class TestLoRAsFromPrompt(unittest.TestCase):
|
||||
def test_get_loras(self):
|
||||
prompt = "<lora:test:1.0> an astronaut eating a LoRA"
|
||||
|
@ -55,6 +61,7 @@ class TestLoRAsFromPrompt(unittest.TestCase):
|
|||
self.assertEqual(result, " an astronaut eating a LoRA")
|
||||
self.assertEqual(tokens, [("test", 1.0)])
|
||||
|
||||
|
||||
class TestLatentsFromSeed(unittest.TestCase):
|
||||
def test_batch_size(self):
|
||||
latents = get_latents_from_seed(1, Size(64, 64), batch=4)
|
||||
|
@ -65,6 +72,7 @@ class TestLatentsFromSeed(unittest.TestCase):
|
|||
latents2 = get_latents_from_seed(1, Size(64, 64))
|
||||
self.assertTrue(np.array_equal(latents1, latents2))
|
||||
|
||||
|
||||
class TestTileLatents(unittest.TestCase):
|
||||
def test_full_tile(self):
|
||||
partial = np.zeros((1, 1, 64, 64))
|
||||
|
@ -81,6 +89,7 @@ class TestTileLatents(unittest.TestCase):
|
|||
full = get_tile_latents(partial, 1, (64, 64), (0, 0, 64))
|
||||
self.assertEqual(full.shape, (1, 1, 8, 8))
|
||||
|
||||
|
||||
class TestScaledLatents(unittest.TestCase):
|
||||
def test_scale_up(self):
|
||||
latents = get_latents_from_seed(1, Size(16, 16))
|
||||
|
@ -90,22 +99,29 @@ class TestScaledLatents(unittest.TestCase):
|
|||
def test_scale_down(self):
|
||||
latents = get_latents_from_seed(1, Size(16, 16))
|
||||
scaled = get_scaled_latents(1, Size(16, 16), scale=0.5)
|
||||
self.assertEqual((
|
||||
latents[0, 0, 0, 0] +
|
||||
latents[0, 0, 0, 1] +
|
||||
latents[0, 0, 1, 0] +
|
||||
latents[0, 0, 1, 1]
|
||||
) / 4, scaled[0, 0, 0, 0])
|
||||
self.assertEqual(
|
||||
(
|
||||
latents[0, 0, 0, 0]
|
||||
+ latents[0, 0, 0, 1]
|
||||
+ latents[0, 0, 1, 0]
|
||||
+ latents[0, 0, 1, 1]
|
||||
)
|
||||
/ 4,
|
||||
scaled[0, 0, 0, 0],
|
||||
)
|
||||
|
||||
|
||||
class TestReplaceWildcards(unittest.TestCase):
|
||||
pass
|
||||
|
||||
|
||||
class TestPopRandom(unittest.TestCase):
|
||||
def test_pop(self):
|
||||
items = ["1", "2", "3"]
|
||||
pop_random(items)
|
||||
self.assertEqual(len(items), 2)
|
||||
|
||||
|
||||
class TestRepairNaN(unittest.TestCase):
|
||||
def test_unchanged(self):
|
||||
pass
|
||||
|
@ -113,6 +129,7 @@ class TestRepairNaN(unittest.TestCase):
|
|||
def test_missing(self):
|
||||
pass
|
||||
|
||||
|
||||
class TestSlicePrompt(unittest.TestCase):
|
||||
def test_slice_no_delimiter(self):
|
||||
slice = slice_prompt("foo", 1)
|
||||
|
|
|
@ -71,7 +71,9 @@ class TestWorkerPool(unittest.TestCase):
|
|||
device1 = DeviceParams("cpu1", "CPUProvider")
|
||||
device2 = DeviceParams("cpu2", "CPUProvider")
|
||||
server = ServerContext()
|
||||
self.pool = DevicePoolExecutor(server, [device1, device2], join_timeout=TEST_JOIN_TIMEOUT)
|
||||
self.pool = DevicePoolExecutor(
|
||||
server, [device1, device2], join_timeout=TEST_JOIN_TIMEOUT
|
||||
)
|
||||
self.pool.start()
|
||||
|
||||
self.assertEqual(self.pool.get_next_device(needs_device=device2), 1)
|
||||
|
@ -83,7 +85,9 @@ class TestWorkerPool(unittest.TestCase):
|
|||
device = DeviceParams("cpu", "CPUProvider")
|
||||
server = ServerContext()
|
||||
|
||||
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1)
|
||||
self.pool = DevicePoolExecutor(
|
||||
server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1
|
||||
)
|
||||
self.pool.start(lock)
|
||||
sleep(2.0)
|
||||
|
||||
|
@ -113,7 +117,9 @@ class TestWorkerPool(unittest.TestCase):
|
|||
device = DeviceParams("cpu", "CPUProvider")
|
||||
server = ServerContext()
|
||||
|
||||
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1)
|
||||
self.pool = DevicePoolExecutor(
|
||||
server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1
|
||||
)
|
||||
self.pool.start()
|
||||
sleep(2.0)
|
||||
|
||||
|
|
|
@ -20,9 +20,11 @@ from tests.helpers import test_device
|
|||
def main_memory(_worker):
|
||||
raise Exception(MEMORY_ERRORS[0])
|
||||
|
||||
|
||||
def main_retry(_worker):
|
||||
raise RetryException()
|
||||
|
||||
|
||||
def main_interrupt(_worker):
|
||||
raise KeyboardInterrupt()
|
||||
|
||||
|
@ -47,7 +49,22 @@ class WorkerMainTests(unittest.TestCase):
|
|||
idle = Value("L", False)
|
||||
|
||||
pending.put(job)
|
||||
worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit)
|
||||
worker_main(
|
||||
WorkerContext(
|
||||
"test",
|
||||
test_device(),
|
||||
cancel,
|
||||
logs,
|
||||
pending,
|
||||
progress,
|
||||
pid,
|
||||
idle,
|
||||
0,
|
||||
0.0,
|
||||
),
|
||||
ServerContext(),
|
||||
exit=exit,
|
||||
)
|
||||
|
||||
self.assertEqual(status, EXIT_INTERRUPT)
|
||||
pass
|
||||
|
@ -68,7 +85,22 @@ class WorkerMainTests(unittest.TestCase):
|
|||
idle = Value("L", False)
|
||||
|
||||
pending.put(job)
|
||||
worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit)
|
||||
worker_main(
|
||||
WorkerContext(
|
||||
"test",
|
||||
test_device(),
|
||||
cancel,
|
||||
logs,
|
||||
pending,
|
||||
progress,
|
||||
pid,
|
||||
idle,
|
||||
0,
|
||||
0.0,
|
||||
),
|
||||
ServerContext(),
|
||||
exit=exit,
|
||||
)
|
||||
|
||||
self.assertEqual(status, EXIT_ERROR)
|
||||
pass
|
||||
|
@ -88,7 +120,22 @@ class WorkerMainTests(unittest.TestCase):
|
|||
idle = Value("L", False)
|
||||
|
||||
pending.close()
|
||||
worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit)
|
||||
worker_main(
|
||||
WorkerContext(
|
||||
"test",
|
||||
test_device(),
|
||||
cancel,
|
||||
logs,
|
||||
pending,
|
||||
progress,
|
||||
pid,
|
||||
idle,
|
||||
0,
|
||||
0.0,
|
||||
),
|
||||
ServerContext(),
|
||||
exit=exit,
|
||||
)
|
||||
|
||||
self.assertEqual(status, EXIT_ERROR)
|
||||
|
||||
|
@ -108,11 +155,25 @@ class WorkerMainTests(unittest.TestCase):
|
|||
idle = Value("L", False)
|
||||
|
||||
pending.put(job)
|
||||
worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit)
|
||||
worker_main(
|
||||
WorkerContext(
|
||||
"test",
|
||||
test_device(),
|
||||
cancel,
|
||||
logs,
|
||||
pending,
|
||||
progress,
|
||||
pid,
|
||||
idle,
|
||||
0,
|
||||
0.0,
|
||||
),
|
||||
ServerContext(),
|
||||
exit=exit,
|
||||
)
|
||||
|
||||
self.assertEqual(status, EXIT_MEMORY)
|
||||
|
||||
|
||||
def test_pending_exception_other_unknown(self):
|
||||
pass
|
||||
|
||||
|
@ -130,7 +191,21 @@ class WorkerMainTests(unittest.TestCase):
|
|||
pid = Value("L", 0)
|
||||
idle = Value("L", False)
|
||||
|
||||
worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit)
|
||||
worker_main(
|
||||
WorkerContext(
|
||||
"test",
|
||||
test_device(),
|
||||
cancel,
|
||||
logs,
|
||||
pending,
|
||||
progress,
|
||||
pid,
|
||||
idle,
|
||||
0,
|
||||
0.0,
|
||||
),
|
||||
ServerContext(),
|
||||
exit=exit,
|
||||
)
|
||||
|
||||
self.assertEqual(status, EXIT_REPLACED)
|
||||
|
||||
|
|
Loading…
Reference in New Issue