apply lint to tests, test highres
This commit is contained in:
parent
4691e80744
commit
65912c5a4a
|
@ -33,13 +33,17 @@ package-upload:
|
||||||
|
|
||||||
lint-check:
|
lint-check:
|
||||||
black --check onnx_web/
|
black --check onnx_web/
|
||||||
|
black --check tests/
|
||||||
flake8 onnx_web
|
flake8 onnx_web
|
||||||
|
flake8 tests
|
||||||
isort --check-only --skip __init__.py --filter-files onnx_web
|
isort --check-only --skip __init__.py --filter-files onnx_web
|
||||||
isort --check-only --skip __init__.py --filter-files tests
|
isort --check-only --skip __init__.py --filter-files tests
|
||||||
|
|
||||||
lint-fix:
|
lint-fix:
|
||||||
black onnx_web/
|
black onnx_web/
|
||||||
|
black tests/
|
||||||
flake8 onnx_web
|
flake8 onnx_web
|
||||||
|
flake8 tests
|
||||||
isort --skip __init__.py --filter-files onnx_web
|
isort --skip __init__.py --filter-files onnx_web
|
||||||
isort --skip __init__.py --filter-files tests
|
isort --skip __init__.py --filter-files tests
|
||||||
|
|
||||||
|
|
|
@ -19,6 +19,14 @@ class StageResult:
|
||||||
def empty():
|
def empty():
|
||||||
return StageResult(images=[])
|
return StageResult(images=[])
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_arrays(arrays: List[np.ndarray]):
|
||||||
|
return StageResult(arrays=arrays)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_images(images: List[Image.Image]):
|
||||||
|
return StageResult(images=images)
|
||||||
|
|
||||||
def __init__(self, arrays=None, images=None) -> None:
|
def __init__(self, arrays=None, images=None) -> None:
|
||||||
if arrays is not None and images is not None:
|
if arrays is not None and images is not None:
|
||||||
raise ValueError("stages must only return one type of result")
|
raise ValueError("stages must only return one type of result")
|
||||||
|
|
|
@ -25,9 +25,7 @@ class TileCallback(Protocol):
|
||||||
Definition for a tile job function.
|
Definition for a tile job function.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __call__(
|
def __call__(self, image: Image.Image, dims: Tuple[int, int, int]) -> StageResult:
|
||||||
self, image: Image.Image, dims: Tuple[int, int, int]
|
|
||||||
) -> StageResult:
|
|
||||||
"""
|
"""
|
||||||
Run this stage against a single tile.
|
Run this stage against a single tile.
|
||||||
"""
|
"""
|
||||||
|
@ -319,6 +317,9 @@ def process_tile_stack(
|
||||||
if mask:
|
if mask:
|
||||||
tile_mask = mask.crop((left, top, right, bottom))
|
tile_mask = mask.crop((left, top, right, bottom))
|
||||||
|
|
||||||
|
if isinstance(tile_stack, list):
|
||||||
|
tile_stack = StageResult.from_images(tile_stack)
|
||||||
|
|
||||||
for image_filter in filters:
|
for image_filter in filters:
|
||||||
tile_stack = image_filter(tile_stack, tile_mask, (left, top, tile))
|
tile_stack = image_filter(tile_stack, tile_mask, (left, top, tile))
|
||||||
|
|
||||||
|
|
|
@ -48,6 +48,7 @@ def main():
|
||||||
# debug options
|
# debug options
|
||||||
if server.debug:
|
if server.debug:
|
||||||
import debugpy
|
import debugpy
|
||||||
|
|
||||||
debugpy.listen(5678)
|
debugpy.listen(5678)
|
||||||
logger.warning("waiting for debugger")
|
logger.warning("waiting for debugger")
|
||||||
debugpy.wait_for_client()
|
debugpy.wait_for_client()
|
||||||
|
|
|
@ -9,13 +9,15 @@ from onnx_web.chain.result import StageResult
|
||||||
class BlendGridStageTests(unittest.TestCase):
|
class BlendGridStageTests(unittest.TestCase):
|
||||||
def test_stage(self):
|
def test_stage(self):
|
||||||
stage = BlendGridStage()
|
stage = BlendGridStage()
|
||||||
sources = StageResult(images=[
|
sources = StageResult(
|
||||||
Image.new("RGB", (64, 64), "black"),
|
images=[
|
||||||
Image.new("RGB", (64, 64), "white"),
|
Image.new("RGB", (64, 64), "black"),
|
||||||
Image.new("RGB", (64, 64), "black"),
|
Image.new("RGB", (64, 64), "white"),
|
||||||
Image.new("RGB", (64, 64), "white"),
|
Image.new("RGB", (64, 64), "black"),
|
||||||
])
|
Image.new("RGB", (64, 64), "white"),
|
||||||
|
]
|
||||||
|
)
|
||||||
result = stage.run(None, None, None, None, sources, height=2, width=2)
|
result = stage.run(None, None, None, None, sources, height=2, width=2)
|
||||||
|
|
||||||
self.assertEqual(len(result), 5)
|
self.assertEqual(len(result), 5)
|
||||||
self.assertEqual(result.as_image()[-1].getpixel((0,0)), (0, 0, 0))
|
self.assertEqual(result.as_image()[-1].getpixel((0, 0)), (0, 0, 0))
|
||||||
|
|
|
@ -6,21 +6,38 @@ from onnx_web.chain.blend_img2img import BlendImg2ImgStage
|
||||||
from onnx_web.params import DeviceParams, ImageParams
|
from onnx_web.params import DeviceParams, ImageParams
|
||||||
from onnx_web.server.context import ServerContext
|
from onnx_web.server.context import ServerContext
|
||||||
from onnx_web.worker.context import WorkerContext
|
from onnx_web.worker.context import WorkerContext
|
||||||
|
from tests.helpers import TEST_MODEL_DIFFUSION_SD15, test_needs_models
|
||||||
|
|
||||||
|
|
||||||
class BlendImg2ImgStageTests(unittest.TestCase):
|
class BlendImg2ImgStageTests(unittest.TestCase):
|
||||||
|
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
|
||||||
def test_stage(self):
|
def test_stage(self):
|
||||||
"""
|
|
||||||
stage = BlendImg2ImgStage()
|
stage = BlendImg2ImgStage()
|
||||||
params = ImageParams("runwayml/stable-diffusion-v1-5", "txt2img", "euler-a", "an astronaut eating a hamburger", 3.0, 1, 1)
|
params = ImageParams(
|
||||||
server = ServerContext()
|
TEST_MODEL_DIFFUSION_SD15,
|
||||||
worker = WorkerContext("test", DeviceParams("cpu", "CPUProvider"), None, None, None, None, None, None, 0)
|
"txt2img",
|
||||||
|
"euler-a",
|
||||||
|
"an astronaut eating a hamburger",
|
||||||
|
3.0,
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
server = ServerContext(model_path="../models", output_path="../outputs")
|
||||||
|
worker = WorkerContext(
|
||||||
|
"test",
|
||||||
|
DeviceParams("cpu", "CPUProvider"),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
0,
|
||||||
|
)
|
||||||
sources = [
|
sources = [
|
||||||
Image.new("RGB", (64, 64), "black"),
|
Image.new("RGB", (64, 64), "black"),
|
||||||
]
|
]
|
||||||
result = stage.run(worker, server, None, params, sources, strength=0.5, steps=1)
|
result = stage.run(worker, server, None, params, sources, strength=0.5, steps=1)
|
||||||
|
|
||||||
self.assertEqual(len(result), 1)
|
self.assertEqual(len(result), 1)
|
||||||
self.assertEqual(result[0].getpixel((0,0)), (127, 127, 127))
|
self.assertEqual(result[0].getpixel((0, 0)), (127, 127, 127))
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
|
@ -9,11 +9,15 @@ from onnx_web.chain.result import StageResult
|
||||||
class BlendLinearStageTests(unittest.TestCase):
|
class BlendLinearStageTests(unittest.TestCase):
|
||||||
def test_stage(self):
|
def test_stage(self):
|
||||||
stage = BlendLinearStage()
|
stage = BlendLinearStage()
|
||||||
sources = StageResult(images=[
|
sources = StageResult(
|
||||||
Image.new("RGB", (64, 64), "black"),
|
images=[
|
||||||
])
|
Image.new("RGB", (64, 64), "black"),
|
||||||
|
]
|
||||||
|
)
|
||||||
stage_source = Image.new("RGB", (64, 64), "white")
|
stage_source = Image.new("RGB", (64, 64), "white")
|
||||||
result = stage.run(None, None, None, None, sources, alpha=0.5, stage_source=stage_source)
|
result = stage.run(
|
||||||
|
None, None, None, None, sources, alpha=0.5, stage_source=stage_source
|
||||||
|
)
|
||||||
|
|
||||||
self.assertEqual(len(result), 1)
|
self.assertEqual(len(result), 1)
|
||||||
self.assertEqual(result.as_image()[0].getpixel((0,0)), (127, 127, 127))
|
self.assertEqual(result.as_image()[0].getpixel((0, 0)), (127, 127, 127))
|
||||||
|
|
|
@ -2,6 +2,7 @@ import unittest
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
from onnx_web.chain.result import StageResult
|
||||||
from onnx_web.chain.tile import (
|
from onnx_web.chain.tile import (
|
||||||
complete_tile,
|
complete_tile,
|
||||||
generate_tile_grid,
|
generate_tile_grid,
|
||||||
|
@ -14,122 +15,126 @@ from onnx_web.params import Size
|
||||||
|
|
||||||
|
|
||||||
class TestCompleteTile(unittest.TestCase):
|
class TestCompleteTile(unittest.TestCase):
|
||||||
def test_with_complete_tile(self):
|
def test_with_complete_tile(self):
|
||||||
partial = Image.new("RGB", (64, 64))
|
partial = Image.new("RGB", (64, 64))
|
||||||
output = complete_tile(partial, 64)
|
output = complete_tile(partial, 64)
|
||||||
|
|
||||||
self.assertEqual(output.size, (64, 64))
|
self.assertEqual(output.size, (64, 64))
|
||||||
|
|
||||||
def test_with_partial_tile(self):
|
def test_with_partial_tile(self):
|
||||||
partial = Image.new("RGB", (64, 32))
|
partial = Image.new("RGB", (64, 32))
|
||||||
output = complete_tile(partial, 64)
|
output = complete_tile(partial, 64)
|
||||||
|
|
||||||
self.assertEqual(output.size, (64, 64))
|
self.assertEqual(output.size, (64, 64))
|
||||||
|
|
||||||
def test_with_nothing(self):
|
def test_with_nothing(self):
|
||||||
output = complete_tile(None, 64)
|
output = complete_tile(None, 64)
|
||||||
|
|
||||||
self.assertIsNone(output)
|
self.assertIsNone(output)
|
||||||
|
|
||||||
|
|
||||||
class TestNeedsTile(unittest.TestCase):
|
class TestNeedsTile(unittest.TestCase):
|
||||||
def test_with_undersized_source(self):
|
def test_with_undersized_source(self):
|
||||||
small = Image.new("RGB", (32, 32))
|
small = Image.new("RGB", (32, 32))
|
||||||
|
|
||||||
self.assertFalse(needs_tile(64, 64, source=small))
|
self.assertFalse(needs_tile(64, 64, source=small))
|
||||||
|
|
||||||
def test_with_oversized_source(self):
|
def test_with_oversized_source(self):
|
||||||
large = Image.new("RGB", (64, 64))
|
large = Image.new("RGB", (64, 64))
|
||||||
|
|
||||||
self.assertTrue(needs_tile(32, 32, source=large))
|
self.assertTrue(needs_tile(32, 32, source=large))
|
||||||
|
|
||||||
def test_with_undersized_size(self):
|
def test_with_undersized_size(self):
|
||||||
small = Size(32, 32)
|
small = Size(32, 32)
|
||||||
|
|
||||||
self.assertFalse(needs_tile(64, 64, size=small))
|
self.assertFalse(needs_tile(64, 64, size=small))
|
||||||
|
|
||||||
def test_with_oversized_source(self):
|
def test_with_oversized_size(self):
|
||||||
large = Size(64, 64)
|
large = Size(64, 64)
|
||||||
|
|
||||||
self.assertTrue(needs_tile(32, 32, size=large))
|
self.assertTrue(needs_tile(32, 32, size=large))
|
||||||
|
|
||||||
def test_with_nothing(self):
|
def test_with_nothing(self):
|
||||||
self.assertFalse(needs_tile(32, 32))
|
self.assertFalse(needs_tile(32, 32))
|
||||||
|
|
||||||
|
|
||||||
class TestTileGrads(unittest.TestCase):
|
class TestTileGrads(unittest.TestCase):
|
||||||
def test_center_tile(self):
|
def test_center_tile(self):
|
||||||
grad_x, grad_y = make_tile_grads(32, 32, 8, 64, 64)
|
grad_x, grad_y = make_tile_grads(32, 32, 8, 64, 64)
|
||||||
|
|
||||||
self.assertEqual(grad_x, [0, 1, 1, 0])
|
self.assertEqual(grad_x, [0, 1, 1, 0])
|
||||||
self.assertEqual(grad_y, [0, 1, 1, 0])
|
self.assertEqual(grad_y, [0, 1, 1, 0])
|
||||||
|
|
||||||
def test_vertical_edge_tile(self):
|
def test_vertical_edge_tile(self):
|
||||||
grad_x, grad_y = make_tile_grads(32, 0, 8, 64, 8)
|
grad_x, grad_y = make_tile_grads(32, 0, 8, 64, 8)
|
||||||
|
|
||||||
self.assertEqual(grad_x, [0, 1, 1, 0])
|
self.assertEqual(grad_x, [0, 1, 1, 0])
|
||||||
self.assertEqual(grad_y, [1, 1, 1, 1])
|
self.assertEqual(grad_y, [1, 1, 1, 1])
|
||||||
|
|
||||||
def test_horizontal_edge_tile(self):
|
def test_horizontal_edge_tile(self):
|
||||||
grad_x, grad_y = make_tile_grads(0, 32, 8, 8, 64)
|
grad_x, grad_y = make_tile_grads(0, 32, 8, 8, 64)
|
||||||
|
|
||||||
self.assertEqual(grad_x, [1, 1, 1, 1])
|
self.assertEqual(grad_x, [1, 1, 1, 1])
|
||||||
self.assertEqual(grad_y, [0, 1, 1, 0])
|
self.assertEqual(grad_y, [0, 1, 1, 0])
|
||||||
|
|
||||||
|
|
||||||
class TestGenerateTileGrid(unittest.TestCase):
|
class TestGenerateTileGrid(unittest.TestCase):
|
||||||
def test_grid_complete(self):
|
def test_grid_complete(self):
|
||||||
tiles = generate_tile_grid(16, 16, 8, 0.0)
|
tiles = generate_tile_grid(16, 16, 8, 0.0)
|
||||||
|
|
||||||
self.assertEqual(len(tiles), 4)
|
self.assertEqual(len(tiles), 4)
|
||||||
self.assertEqual(tiles, [(0, 0), (8, 0), (8, 8), (0, 8)])
|
self.assertEqual(tiles, [(0, 0), (8, 0), (0, 8), (8, 8)])
|
||||||
|
|
||||||
def test_grid_no_overlap(self):
|
def test_grid_no_overlap(self):
|
||||||
tiles = generate_tile_grid(64, 64, 8, 0.0)
|
tiles = generate_tile_grid(64, 64, 8, 0.0)
|
||||||
|
|
||||||
self.assertEqual(len(tiles), 64)
|
self.assertEqual(len(tiles), 64)
|
||||||
self.assertEqual(tiles[0:4], [(0, 0), (8, 0), (16, 0), (24, 0)])
|
self.assertEqual(tiles[0:4], [(0, 0), (8, 0), (16, 0), (24, 0)])
|
||||||
self.assertEqual(tiles[-5:-1], [(16, 24), (24, 24), (32, 24), (32, 32)])
|
self.assertEqual(tiles[-5:-1], [(24, 56), (32, 56), (40, 56), (48, 56)])
|
||||||
|
|
||||||
def test_grid_50_overlap(self):
|
def test_grid_50_overlap(self):
|
||||||
tiles = generate_tile_grid(64, 64, 8, 0.5)
|
tiles = generate_tile_grid(64, 64, 8, 0.5)
|
||||||
|
|
||||||
self.assertEqual(len(tiles), 225)
|
self.assertEqual(len(tiles), 256)
|
||||||
self.assertEqual(tiles[0:4], [(0, 0), (4, 0), (8, 0), (12, 0)])
|
self.assertEqual(tiles[0:4], [(0, 0), (4, 0), (8, 0), (12, 0)])
|
||||||
self.assertEqual(tiles[-5:-1], [(32, 32), (28, 32), (24, 32), (24, 28)])
|
self.assertEqual(tiles[-5:-1], [(44, 60), (48, 60), (52, 60), (56, 60)])
|
||||||
|
|
||||||
|
|
||||||
class TestGenerateTileSpiral(unittest.TestCase):
|
class TestGenerateTileSpiral(unittest.TestCase):
|
||||||
def test_spiral_complete(self):
|
def test_spiral_complete(self):
|
||||||
tiles = generate_tile_spiral(16, 16, 8, 0.0)
|
tiles = generate_tile_spiral(16, 16, 8, 0.0)
|
||||||
|
|
||||||
self.assertEqual(len(tiles), 4)
|
self.assertEqual(len(tiles), 4)
|
||||||
self.assertEqual(tiles, [(0, 0), (8, 0), (8, 8), (0, 8)])
|
self.assertEqual(tiles, [(0, 0), (8, 0), (8, 8), (0, 8)])
|
||||||
|
|
||||||
def test_spiral_no_overlap(self):
|
def test_spiral_no_overlap(self):
|
||||||
tiles = generate_tile_spiral(64, 64, 8, 0.0)
|
tiles = generate_tile_spiral(64, 64, 8, 0.0)
|
||||||
|
|
||||||
self.assertEqual(len(tiles), 64)
|
self.assertEqual(len(tiles), 64)
|
||||||
self.assertEqual(tiles[0:4], [(0, 0), (8, 0), (16, 0), (24, 0)])
|
self.assertEqual(tiles[0:4], [(0, 0), (8, 0), (16, 0), (24, 0)])
|
||||||
self.assertEqual(tiles[-5:-1], [(16, 24), (24, 24), (32, 24), (32, 32)])
|
self.assertEqual(tiles[-5:-1], [(16, 24), (24, 24), (32, 24), (32, 32)])
|
||||||
|
|
||||||
def test_spiral_50_overlap(self):
|
def test_spiral_50_overlap(self):
|
||||||
tiles = generate_tile_spiral(64, 64, 8, 0.5)
|
tiles = generate_tile_spiral(64, 64, 8, 0.5)
|
||||||
|
|
||||||
self.assertEqual(len(tiles), 225)
|
self.assertEqual(len(tiles), 225)
|
||||||
self.assertEqual(tiles[0:4], [(0, 0), (4, 0), (8, 0), (12, 0)])
|
self.assertEqual(tiles[0:4], [(0, 0), (4, 0), (8, 0), (12, 0)])
|
||||||
self.assertEqual(tiles[-5:-1], [(32, 32), (28, 32), (24, 32), (24, 28)])
|
self.assertEqual(tiles[-5:-1], [(32, 32), (28, 32), (24, 32), (24, 28)])
|
||||||
|
|
||||||
|
|
||||||
class TestProcessTileStack(unittest.TestCase):
|
class TestProcessTileStack(unittest.TestCase):
|
||||||
def test_grid_full(self):
|
def test_grid_full(self):
|
||||||
source = Image.new("RGB", (64, 64))
|
source = Image.new("RGB", (64, 64))
|
||||||
blend = process_tile_stack(source, 32, 1, [])
|
blend = process_tile_stack(
|
||||||
|
StageResult(images=[source]), 32, 1, [], generate_tile_grid
|
||||||
|
)
|
||||||
|
|
||||||
self.assertEqual(blend.size, (64, 64))
|
self.assertEqual(blend[0].size, (64, 64))
|
||||||
|
|
||||||
def test_grid_partial(self):
|
def test_grid_partial(self):
|
||||||
source = Image.new("RGB", (72, 72))
|
source = Image.new("RGB", (72, 72))
|
||||||
blend = process_tile_stack(source, 32, 1, [])
|
blend = process_tile_stack(
|
||||||
|
StageResult(images=[source]), 32, 1, [], generate_tile_grid
|
||||||
|
)
|
||||||
|
|
||||||
self.assertEqual(blend.size, (72, 72))
|
self.assertEqual(blend[0].size, (72, 72))
|
||||||
|
|
|
@ -9,6 +9,14 @@ class UpscaleHighresStageTests(unittest.TestCase):
|
||||||
def test_empty(self):
|
def test_empty(self):
|
||||||
stage = UpscaleHighresStage()
|
stage = UpscaleHighresStage()
|
||||||
sources = StageResult.empty()
|
sources = StageResult.empty()
|
||||||
result = stage.run(None, None, None, None, sources, highres=HighresParams(False,1, 0, 0), upscale=UpscaleParams(""))
|
result = stage.run(
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
sources,
|
||||||
|
highres=HighresParams(False, 1, 0, 0),
|
||||||
|
upscale=UpscaleParams(""),
|
||||||
|
)
|
||||||
|
|
||||||
self.assertEqual(len(result), 0)
|
self.assertEqual(len(result), 0)
|
||||||
|
|
|
@ -6,7 +6,6 @@ from onnx import GraphProto, ModelProto, NodeProto
|
||||||
from onnx.numpy_helper import from_array
|
from onnx.numpy_helper import from_array
|
||||||
|
|
||||||
from onnx_web.convert.diffusion.lora import (
|
from onnx_web.convert.diffusion.lora import (
|
||||||
blend_loras,
|
|
||||||
blend_node_conv_gemm,
|
blend_node_conv_gemm,
|
||||||
blend_node_matmul,
|
blend_node_matmul,
|
||||||
blend_weights_loha,
|
blend_weights_loha,
|
||||||
|
@ -33,7 +32,6 @@ class SumWeightsTests(unittest.TestCase):
|
||||||
weights = sum_weights(np.zeros((4, 4)), np.ones((4, 4, 1, 1)))
|
weights = sum_weights(np.zeros((4, 4)), np.ones((4, 4, 1, 1)))
|
||||||
self.assertEqual(weights.shape, (4, 4, 1, 1))
|
self.assertEqual(weights.shape, (4, 4, 1, 1))
|
||||||
|
|
||||||
|
|
||||||
def test_3x3_kernel(self):
|
def test_3x3_kernel(self):
|
||||||
"""
|
"""
|
||||||
weights = sum_weights(np.zeros((4, 4, 3, 3)), np.ones((4, 4)))
|
weights = sum_weights(np.zeros((4, 4, 3, 3)), np.ones((4, 4)))
|
||||||
|
@ -53,14 +51,20 @@ class BufferExternalDataTensorTests(unittest.TestCase):
|
||||||
)
|
)
|
||||||
(slim_model, external_weights) = buffer_external_data_tensors(model)
|
(slim_model, external_weights) = buffer_external_data_tensors(model)
|
||||||
|
|
||||||
self.assertEqual(len(slim_model.graph.initializer), len(model.graph.initializer))
|
self.assertEqual(
|
||||||
|
len(slim_model.graph.initializer), len(model.graph.initializer)
|
||||||
|
)
|
||||||
self.assertEqual(len(external_weights), 1)
|
self.assertEqual(len(external_weights), 1)
|
||||||
|
|
||||||
|
|
||||||
class FixInitializerKeyTests(unittest.TestCase):
|
class FixInitializerKeyTests(unittest.TestCase):
|
||||||
def test_fix_name(self):
|
def test_fix_name(self):
|
||||||
inputs = ["lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0.lora_down.weight"]
|
inputs = [
|
||||||
outputs = ["lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0_lora_down_weight"]
|
"lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0.lora_down.weight"
|
||||||
|
]
|
||||||
|
outputs = [
|
||||||
|
"lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0_lora_down_weight"
|
||||||
|
]
|
||||||
|
|
||||||
for input, output in zip(inputs, outputs):
|
for input, output in zip(inputs, outputs):
|
||||||
self.assertEqual(fix_initializer_name(input), output)
|
self.assertEqual(fix_initializer_name(input), output)
|
||||||
|
@ -92,25 +96,37 @@ class FixXLNameTests(unittest.TestCase):
|
||||||
nodes = {
|
nodes = {
|
||||||
"input_block_proj.lora_down.weight": {},
|
"input_block_proj.lora_down.weight": {},
|
||||||
}
|
}
|
||||||
fixed = fix_xl_names(nodes, [
|
fixed = fix_xl_names(
|
||||||
NodeProto(name="/down_blocks_proj/MatMul"),
|
nodes,
|
||||||
])
|
[
|
||||||
|
NodeProto(name="/down_blocks_proj/MatMul"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
self.assertEqual(fixed, {
|
self.assertEqual(
|
||||||
"down_blocks_proj": nodes["input_block_proj.lora_down.weight"],
|
fixed,
|
||||||
})
|
{
|
||||||
|
"down_blocks_proj": nodes["input_block_proj.lora_down.weight"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
def test_middle_block(self):
|
def test_middle_block(self):
|
||||||
nodes = {
|
nodes = {
|
||||||
"middle_block_proj.lora_down.weight": {},
|
"middle_block_proj.lora_down.weight": {},
|
||||||
}
|
}
|
||||||
fixed = fix_xl_names(nodes, [
|
fixed = fix_xl_names(
|
||||||
NodeProto(name="/mid_blocks_proj/MatMul"),
|
nodes,
|
||||||
])
|
[
|
||||||
|
NodeProto(name="/mid_blocks_proj/MatMul"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
self.assertEqual(fixed, {
|
self.assertEqual(
|
||||||
"mid_blocks_proj": nodes["middle_block_proj.lora_down.weight"],
|
fixed,
|
||||||
})
|
{
|
||||||
|
"mid_blocks_proj": nodes["middle_block_proj.lora_down.weight"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
def test_output_block(self):
|
def test_output_block(self):
|
||||||
pass
|
pass
|
||||||
|
@ -133,13 +149,19 @@ class FixXLNameTests(unittest.TestCase):
|
||||||
nodes = {
|
nodes = {
|
||||||
"output_block_proj_out.lora_down.weight": {},
|
"output_block_proj_out.lora_down.weight": {},
|
||||||
}
|
}
|
||||||
fixed = fix_xl_names(nodes, [
|
fixed = fix_xl_names(
|
||||||
NodeProto(name="/up_blocks_proj_out/MatMul"),
|
nodes,
|
||||||
])
|
[
|
||||||
|
NodeProto(name="/up_blocks_proj_out/MatMul"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
self.assertEqual(fixed, {
|
self.assertEqual(
|
||||||
"up_blocks_proj_out": nodes["output_block_proj_out.lora_down.weight"],
|
fixed,
|
||||||
})
|
{
|
||||||
|
"up_blocks_proj_out": nodes["output_block_proj_out.lora_down.weight"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class KernelSliceTests(unittest.TestCase):
|
class KernelSliceTests(unittest.TestCase):
|
||||||
|
@ -250,6 +272,7 @@ class BlendWeightsLoHATests(unittest.TestCase):
|
||||||
self.assertEqual(result.shape, (4, 4))
|
self.assertEqual(result.shape, (4, 4))
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class BlendWeightsLoRATests(unittest.TestCase):
|
class BlendWeightsLoRATests(unittest.TestCase):
|
||||||
def test_blend_kernel_none(self):
|
def test_blend_kernel_none(self):
|
||||||
model = {
|
model = {
|
||||||
|
@ -260,7 +283,6 @@ class BlendWeightsLoRATests(unittest.TestCase):
|
||||||
key, result = blend_weights_lora("foo.lora_down", "", model, torch.float32)
|
key, result = blend_weights_lora("foo.lora_down", "", model, torch.float32)
|
||||||
self.assertEqual(result.shape, (4, 4))
|
self.assertEqual(result.shape, (4, 4))
|
||||||
|
|
||||||
|
|
||||||
def test_blend_kernel_1x1(self):
|
def test_blend_kernel_1x1(self):
|
||||||
model = {
|
model = {
|
||||||
"foo.lora_down": torch.from_numpy(np.ones((1, 4, 1, 1))),
|
"foo.lora_down": torch.from_numpy(np.ones((1, 4, 1, 1))),
|
||||||
|
|
|
@ -10,7 +10,6 @@ from onnx_web.convert.diffusion.textual_inversion import (
|
||||||
blend_embedding_embeddings,
|
blend_embedding_embeddings,
|
||||||
blend_embedding_node,
|
blend_embedding_node,
|
||||||
blend_embedding_parameters,
|
blend_embedding_parameters,
|
||||||
blend_textual_inversions,
|
|
||||||
detect_embedding_format,
|
detect_embedding_format,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -18,210 +17,267 @@ TEST_DIMS = (8, 8)
|
||||||
TEST_DIMS_EMBEDS = (1, *TEST_DIMS)
|
TEST_DIMS_EMBEDS = (1, *TEST_DIMS)
|
||||||
|
|
||||||
TEST_MODEL_EMBEDS = {
|
TEST_MODEL_EMBEDS = {
|
||||||
"string_to_token": {
|
"string_to_token": {
|
||||||
"test": 1,
|
"test": 1,
|
||||||
},
|
},
|
||||||
"string_to_param": {
|
"string_to_param": {
|
||||||
"test": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)),
|
"test": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class DetectEmbeddingFormatTests(unittest.TestCase):
|
class DetectEmbeddingFormatTests(unittest.TestCase):
|
||||||
def test_concept(self):
|
def test_concept(self):
|
||||||
embedding = {
|
embedding = {
|
||||||
"<test>": "test",
|
"<test>": "test",
|
||||||
}
|
}
|
||||||
self.assertEqual(detect_embedding_format(embedding), "concept")
|
self.assertEqual(detect_embedding_format(embedding), "concept")
|
||||||
|
|
||||||
def test_parameters(self):
|
def test_parameters(self):
|
||||||
embedding = {
|
embedding = {
|
||||||
"emb_params": "test",
|
"emb_params": "test",
|
||||||
}
|
}
|
||||||
self.assertEqual(detect_embedding_format(embedding), "parameters")
|
self.assertEqual(detect_embedding_format(embedding), "parameters")
|
||||||
|
|
||||||
def test_embeddings(self):
|
def test_embeddings(self):
|
||||||
embedding = {
|
embedding = {
|
||||||
"string_to_token": "test",
|
"string_to_token": "test",
|
||||||
"string_to_param": "test",
|
"string_to_param": "test",
|
||||||
}
|
}
|
||||||
self.assertEqual(detect_embedding_format(embedding), "embeddings")
|
self.assertEqual(detect_embedding_format(embedding), "embeddings")
|
||||||
|
|
||||||
def test_unknown(self):
|
def test_unknown(self):
|
||||||
embedding = {
|
embedding = {
|
||||||
"what_is_this": "test",
|
"what_is_this": "test",
|
||||||
}
|
}
|
||||||
self.assertEqual(detect_embedding_format(embedding), None)
|
self.assertEqual(detect_embedding_format(embedding), None)
|
||||||
|
|
||||||
|
|
||||||
class BlendEmbeddingConceptTests(unittest.TestCase):
|
class BlendEmbeddingConceptTests(unittest.TestCase):
|
||||||
def test_existing_base_token(self):
|
def test_existing_base_token(self):
|
||||||
embeds = {
|
embeds = {
|
||||||
"test": np.ones(TEST_DIMS),
|
"test": np.ones(TEST_DIMS),
|
||||||
}
|
}
|
||||||
blend_embedding_concept(embeds, {
|
blend_embedding_concept(
|
||||||
"<test>": torch.from_numpy(np.ones(TEST_DIMS)),
|
embeds,
|
||||||
}, np.float32, "test", 1.0)
|
{
|
||||||
|
"<test>": torch.from_numpy(np.ones(TEST_DIMS)),
|
||||||
|
},
|
||||||
|
np.float32,
|
||||||
|
"test",
|
||||||
|
1.0,
|
||||||
|
)
|
||||||
|
|
||||||
self.assertIn("test", embeds)
|
self.assertIn("test", embeds)
|
||||||
self.assertEqual(embeds["test"].shape, TEST_DIMS)
|
self.assertEqual(embeds["test"].shape, TEST_DIMS)
|
||||||
self.assertEqual(embeds["test"].mean(), 2)
|
self.assertEqual(embeds["test"].mean(), 2)
|
||||||
|
|
||||||
def test_missing_base_token(self):
|
def test_missing_base_token(self):
|
||||||
embeds = {}
|
embeds = {}
|
||||||
blend_embedding_concept(embeds, {
|
blend_embedding_concept(
|
||||||
"<test>": torch.from_numpy(np.ones(TEST_DIMS)),
|
embeds,
|
||||||
}, np.float32, "test", 1.0)
|
{
|
||||||
|
"<test>": torch.from_numpy(np.ones(TEST_DIMS)),
|
||||||
|
},
|
||||||
|
np.float32,
|
||||||
|
"test",
|
||||||
|
1.0,
|
||||||
|
)
|
||||||
|
|
||||||
self.assertIn("test", embeds)
|
self.assertIn("test", embeds)
|
||||||
self.assertEqual(embeds["test"].shape, TEST_DIMS)
|
self.assertEqual(embeds["test"].shape, TEST_DIMS)
|
||||||
|
|
||||||
def test_existing_token(self):
|
def test_existing_token(self):
|
||||||
embeds = {
|
embeds = {
|
||||||
"<test>": np.ones(TEST_DIMS),
|
"<test>": np.ones(TEST_DIMS),
|
||||||
}
|
}
|
||||||
blend_embedding_concept(embeds, {
|
blend_embedding_concept(
|
||||||
"<test>": torch.from_numpy(np.ones(TEST_DIMS)),
|
embeds,
|
||||||
}, np.float32, "test", 1.0)
|
{
|
||||||
|
"<test>": torch.from_numpy(np.ones(TEST_DIMS)),
|
||||||
|
},
|
||||||
|
np.float32,
|
||||||
|
"test",
|
||||||
|
1.0,
|
||||||
|
)
|
||||||
|
|
||||||
keys = list(embeds.keys())
|
keys = list(embeds.keys())
|
||||||
keys.sort()
|
keys.sort()
|
||||||
|
|
||||||
self.assertIn("test", embeds)
|
self.assertIn("test", embeds)
|
||||||
self.assertEqual(keys, ["<test>", "test"])
|
self.assertEqual(keys, ["<test>", "test"])
|
||||||
|
|
||||||
def test_missing_token(self):
|
def test_missing_token(self):
|
||||||
embeds = {}
|
embeds = {}
|
||||||
blend_embedding_concept(embeds, {
|
blend_embedding_concept(
|
||||||
"<test>": torch.from_numpy(np.ones(TEST_DIMS)),
|
embeds,
|
||||||
}, np.float32, "test", 1.0)
|
{
|
||||||
|
"<test>": torch.from_numpy(np.ones(TEST_DIMS)),
|
||||||
|
},
|
||||||
|
np.float32,
|
||||||
|
"test",
|
||||||
|
1.0,
|
||||||
|
)
|
||||||
|
|
||||||
keys = list(embeds.keys())
|
keys = list(embeds.keys())
|
||||||
keys.sort()
|
keys.sort()
|
||||||
|
|
||||||
self.assertIn("test", embeds)
|
self.assertIn("test", embeds)
|
||||||
self.assertEqual(keys, ["<test>", "test"])
|
self.assertEqual(keys, ["<test>", "test"])
|
||||||
|
|
||||||
|
|
||||||
class BlendEmbeddingParametersTests(unittest.TestCase):
|
class BlendEmbeddingParametersTests(unittest.TestCase):
|
||||||
def test_existing_base_token(self):
|
def test_existing_base_token(self):
|
||||||
embeds = {
|
embeds = {
|
||||||
"test": np.ones(TEST_DIMS),
|
"test": np.ones(TEST_DIMS),
|
||||||
}
|
}
|
||||||
blend_embedding_parameters(embeds, {
|
blend_embedding_parameters(
|
||||||
"emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)),
|
embeds,
|
||||||
}, np.float32, "test", 1.0)
|
{
|
||||||
|
"emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)),
|
||||||
|
},
|
||||||
|
np.float32,
|
||||||
|
"test",
|
||||||
|
1.0,
|
||||||
|
)
|
||||||
|
|
||||||
self.assertIn("test", embeds)
|
self.assertIn("test", embeds)
|
||||||
self.assertEqual(embeds["test"].shape, TEST_DIMS)
|
self.assertEqual(embeds["test"].shape, TEST_DIMS)
|
||||||
self.assertEqual(embeds["test"].mean(), 2)
|
self.assertEqual(embeds["test"].mean(), 2)
|
||||||
|
|
||||||
def test_missing_base_token(self):
|
def test_missing_base_token(self):
|
||||||
embeds = {}
|
embeds = {}
|
||||||
blend_embedding_parameters(embeds, {
|
blend_embedding_parameters(
|
||||||
"emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)),
|
embeds,
|
||||||
}, np.float32, "test", 1.0)
|
{
|
||||||
|
"emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)),
|
||||||
|
},
|
||||||
|
np.float32,
|
||||||
|
"test",
|
||||||
|
1.0,
|
||||||
|
)
|
||||||
|
|
||||||
self.assertIn("test", embeds)
|
self.assertIn("test", embeds)
|
||||||
self.assertEqual(embeds["test"].shape, TEST_DIMS)
|
self.assertEqual(embeds["test"].shape, TEST_DIMS)
|
||||||
|
|
||||||
def test_existing_token(self):
|
def test_existing_token(self):
|
||||||
embeds = {
|
embeds = {
|
||||||
"test": np.ones(TEST_DIMS_EMBEDS),
|
"test": np.ones(TEST_DIMS_EMBEDS),
|
||||||
}
|
}
|
||||||
blend_embedding_parameters(embeds, {
|
blend_embedding_parameters(
|
||||||
"emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)),
|
embeds,
|
||||||
}, np.float32, "test", 1.0)
|
{
|
||||||
|
"emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)),
|
||||||
|
},
|
||||||
|
np.float32,
|
||||||
|
"test",
|
||||||
|
1.0,
|
||||||
|
)
|
||||||
|
|
||||||
keys = list(embeds.keys())
|
keys = list(embeds.keys())
|
||||||
keys.sort()
|
keys.sort()
|
||||||
|
|
||||||
self.assertIn("test", embeds)
|
self.assertIn("test", embeds)
|
||||||
self.assertEqual(keys, ["test", "test-0", "test-all"])
|
self.assertEqual(keys, ["test", "test-0", "test-all"])
|
||||||
|
|
||||||
def test_missing_token(self):
|
def test_missing_token(self):
|
||||||
embeds = {}
|
embeds = {}
|
||||||
blend_embedding_parameters(embeds, {
|
blend_embedding_parameters(
|
||||||
"emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)),
|
embeds,
|
||||||
}, np.float32, "test", 1.0)
|
{
|
||||||
|
"emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)),
|
||||||
|
},
|
||||||
|
np.float32,
|
||||||
|
"test",
|
||||||
|
1.0,
|
||||||
|
)
|
||||||
|
|
||||||
keys = list(embeds.keys())
|
keys = list(embeds.keys())
|
||||||
keys.sort()
|
keys.sort()
|
||||||
|
|
||||||
self.assertIn("test", embeds)
|
self.assertIn("test", embeds)
|
||||||
self.assertEqual(keys, ["test", "test-0", "test-all"])
|
self.assertEqual(keys, ["test", "test-0", "test-all"])
|
||||||
|
|
||||||
|
|
||||||
class BlendEmbeddingEmbeddingsTests(unittest.TestCase):
|
class BlendEmbeddingEmbeddingsTests(unittest.TestCase):
|
||||||
def test_existing_base_token(self):
|
def test_existing_base_token(self):
|
||||||
embeds = {
|
embeds = {
|
||||||
"test": np.ones(TEST_DIMS),
|
"test": np.ones(TEST_DIMS),
|
||||||
}
|
}
|
||||||
blend_embedding_embeddings(embeds, TEST_MODEL_EMBEDS, np.float32, "test", 1.0)
|
blend_embedding_embeddings(embeds, TEST_MODEL_EMBEDS, np.float32, "test", 1.0)
|
||||||
|
|
||||||
self.assertIn("test", embeds)
|
self.assertIn("test", embeds)
|
||||||
self.assertEqual(embeds["test"].shape, TEST_DIMS)
|
self.assertEqual(embeds["test"].shape, TEST_DIMS)
|
||||||
self.assertEqual(embeds["test"].mean(), 2)
|
self.assertEqual(embeds["test"].mean(), 2)
|
||||||
|
|
||||||
def test_missing_base_token(self):
|
def test_missing_base_token(self):
|
||||||
embeds = {}
|
embeds = {}
|
||||||
blend_embedding_embeddings(embeds, TEST_MODEL_EMBEDS, np.float32, "test", 1.0)
|
blend_embedding_embeddings(embeds, TEST_MODEL_EMBEDS, np.float32, "test", 1.0)
|
||||||
|
|
||||||
self.assertIn("test", embeds)
|
self.assertIn("test", embeds)
|
||||||
self.assertEqual(embeds["test"].shape, TEST_DIMS)
|
self.assertEqual(embeds["test"].shape, TEST_DIMS)
|
||||||
|
|
||||||
def test_existing_token(self):
|
def test_existing_token(self):
|
||||||
embeds = {
|
embeds = {
|
||||||
"test": np.ones(TEST_DIMS),
|
"test": np.ones(TEST_DIMS),
|
||||||
}
|
}
|
||||||
blend_embedding_embeddings(embeds, TEST_MODEL_EMBEDS, np.float32, "test", 1.0)
|
blend_embedding_embeddings(embeds, TEST_MODEL_EMBEDS, np.float32, "test", 1.0)
|
||||||
|
|
||||||
keys = list(embeds.keys())
|
keys = list(embeds.keys())
|
||||||
keys.sort()
|
keys.sort()
|
||||||
|
|
||||||
self.assertIn("test", embeds)
|
self.assertIn("test", embeds)
|
||||||
self.assertEqual(keys, ["test", "test-0", "test-all"])
|
self.assertEqual(keys, ["test", "test-0", "test-all"])
|
||||||
|
|
||||||
def test_missing_token(self):
|
def test_missing_token(self):
|
||||||
embeds = {}
|
embeds = {}
|
||||||
blend_embedding_embeddings(embeds, TEST_MODEL_EMBEDS, np.float32, "test", 1.0)
|
blend_embedding_embeddings(embeds, TEST_MODEL_EMBEDS, np.float32, "test", 1.0)
|
||||||
|
|
||||||
keys = list(embeds.keys())
|
keys = list(embeds.keys())
|
||||||
keys.sort()
|
keys.sort()
|
||||||
|
|
||||||
self.assertIn("test", embeds)
|
self.assertIn("test", embeds)
|
||||||
self.assertEqual(keys, ["test", "test-0", "test-all"])
|
self.assertEqual(keys, ["test", "test-0", "test-all"])
|
||||||
|
|
||||||
|
|
||||||
class BlendEmbeddingNodeTests(unittest.TestCase):
|
class BlendEmbeddingNodeTests(unittest.TestCase):
|
||||||
def test_expand_weights(self):
|
def test_expand_weights(self):
|
||||||
weights = from_array(np.ones(TEST_DIMS))
|
weights = from_array(np.ones(TEST_DIMS))
|
||||||
weights.name = "text_model.embeddings.token_embedding.weight"
|
weights.name = "text_model.embeddings.token_embedding.weight"
|
||||||
|
|
||||||
model = ModelProto(graph=GraphProto(initializer=[
|
model = ModelProto(
|
||||||
weights,
|
graph=GraphProto(
|
||||||
]))
|
initializer=[
|
||||||
|
weights,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
embeds = {}
|
embeds = {}
|
||||||
blend_embedding_node(model, {
|
blend_embedding_node(
|
||||||
'convert_tokens_to_ids': lambda t: t,
|
model,
|
||||||
}, embeds, 2)
|
{
|
||||||
|
"convert_tokens_to_ids": lambda t: t,
|
||||||
|
},
|
||||||
|
embeds,
|
||||||
|
2,
|
||||||
|
)
|
||||||
|
|
||||||
result = to_array(model.graph.initializer[0])
|
result = to_array(model.graph.initializer[0])
|
||||||
|
|
||||||
self.assertEqual(len(model.graph.initializer), 1)
|
self.assertEqual(len(model.graph.initializer), 1)
|
||||||
self.assertEqual(result.shape, (10, 8)) # (8 + 2, 8)
|
self.assertEqual(result.shape, (10, 8)) # (8 + 2, 8)
|
||||||
|
|
||||||
|
|
||||||
class BlendTextualInversionsTests(unittest.TestCase):
|
class BlendTextualInversionsTests(unittest.TestCase):
|
||||||
def test_blend_multi_concept(self):
|
def test_blend_multi_concept(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def test_blend_multi_parameters(self):
|
def test_blend_multi_parameters(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def test_blend_multi_embeddings(self):
|
def test_blend_multi_embeddings(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def test_blend_multi_mixed(self):
|
def test_blend_multi_mixed(self):
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -13,7 +13,6 @@ from onnx_web.convert.utils import (
|
||||||
tuple_to_upscaling,
|
tuple_to_upscaling,
|
||||||
)
|
)
|
||||||
from tests.helpers import (
|
from tests.helpers import (
|
||||||
TEST_MODEL_DIFFUSION_SD15,
|
|
||||||
TEST_MODEL_UPSCALING_SWINIR,
|
TEST_MODEL_UPSCALING_SWINIR,
|
||||||
test_needs_models,
|
test_needs_models,
|
||||||
)
|
)
|
||||||
|
@ -21,220 +20,225 @@ from tests.helpers import (
|
||||||
|
|
||||||
class ConversionContextTests(unittest.TestCase):
|
class ConversionContextTests(unittest.TestCase):
|
||||||
def test_from_environ(self):
|
def test_from_environ(self):
|
||||||
context = ConversionContext.from_environ()
|
context = ConversionContext.from_environ()
|
||||||
self.assertEqual(context.opset, DEFAULT_OPSET)
|
self.assertEqual(context.opset, DEFAULT_OPSET)
|
||||||
|
|
||||||
def test_map_location(self):
|
def test_map_location(self):
|
||||||
context = ConversionContext.from_environ()
|
context = ConversionContext.from_environ()
|
||||||
self.assertEqual(context.map_location.type, "cpu")
|
self.assertEqual(context.map_location.type, "cpu")
|
||||||
|
|
||||||
|
|
||||||
class DownloadProgressTests(unittest.TestCase):
|
class DownloadProgressTests(unittest.TestCase):
|
||||||
def test_download_example(self):
|
def test_download_example(self):
|
||||||
path = download_progress([("https://example.com", "/tmp/example-dot-com")])
|
path = download_progress([("https://example.com", "/tmp/example-dot-com")])
|
||||||
self.assertEqual(path, "/tmp/example-dot-com")
|
self.assertEqual(path, "/tmp/example-dot-com")
|
||||||
|
|
||||||
|
|
||||||
class TupleToSourceTests(unittest.TestCase):
|
class TupleToSourceTests(unittest.TestCase):
|
||||||
def test_basic_tuple(self):
|
def test_basic_tuple(self):
|
||||||
source = tuple_to_source(("foo", "bar"))
|
source = tuple_to_source(("foo", "bar"))
|
||||||
self.assertEqual(source["name"], "foo")
|
self.assertEqual(source["name"], "foo")
|
||||||
self.assertEqual(source["source"], "bar")
|
self.assertEqual(source["source"], "bar")
|
||||||
|
|
||||||
def test_basic_list(self):
|
def test_basic_list(self):
|
||||||
source = tuple_to_source(["foo", "bar"])
|
source = tuple_to_source(["foo", "bar"])
|
||||||
self.assertEqual(source["name"], "foo")
|
self.assertEqual(source["name"], "foo")
|
||||||
self.assertEqual(source["source"], "bar")
|
self.assertEqual(source["source"], "bar")
|
||||||
|
|
||||||
def test_basic_dict(self):
|
def test_basic_dict(self):
|
||||||
source = tuple_to_source(["foo", "bar"])
|
source = tuple_to_source(["foo", "bar"])
|
||||||
source["bin"] = "bin"
|
source["bin"] = "bin"
|
||||||
|
|
||||||
# make sure this is returned as-is with extra fields
|
# make sure this is returned as-is with extra fields
|
||||||
second = tuple_to_source(source)
|
second = tuple_to_source(source)
|
||||||
|
|
||||||
self.assertEqual(source, second)
|
self.assertEqual(source, second)
|
||||||
self.assertIn("bin", second)
|
self.assertIn("bin", second)
|
||||||
|
|
||||||
|
|
||||||
class TupleToCorrectionTests(unittest.TestCase):
|
class TupleToCorrectionTests(unittest.TestCase):
|
||||||
def test_basic_tuple(self):
|
def test_basic_tuple(self):
|
||||||
source = tuple_to_correction(("foo", "bar"))
|
source = tuple_to_correction(("foo", "bar"))
|
||||||
self.assertEqual(source["name"], "foo")
|
self.assertEqual(source["name"], "foo")
|
||||||
self.assertEqual(source["source"], "bar")
|
self.assertEqual(source["source"], "bar")
|
||||||
|
|
||||||
def test_basic_list(self):
|
def test_basic_list(self):
|
||||||
source = tuple_to_correction(["foo", "bar"])
|
source = tuple_to_correction(["foo", "bar"])
|
||||||
self.assertEqual(source["name"], "foo")
|
self.assertEqual(source["name"], "foo")
|
||||||
self.assertEqual(source["source"], "bar")
|
self.assertEqual(source["source"], "bar")
|
||||||
|
|
||||||
def test_basic_dict(self):
|
def test_basic_dict(self):
|
||||||
source = tuple_to_correction(["foo", "bar"])
|
source = tuple_to_correction(["foo", "bar"])
|
||||||
source["bin"] = "bin"
|
source["bin"] = "bin"
|
||||||
|
|
||||||
# make sure this is returned with extra fields
|
# make sure this is returned with extra fields
|
||||||
second = tuple_to_source(source)
|
second = tuple_to_source(source)
|
||||||
|
|
||||||
self.assertEqual(source, second)
|
self.assertEqual(source, second)
|
||||||
self.assertIn("bin", second)
|
self.assertIn("bin", second)
|
||||||
|
|
||||||
def test_scale_tuple(self):
|
def test_scale_tuple(self):
|
||||||
source = tuple_to_correction(["foo", "bar", 2])
|
source = tuple_to_correction(["foo", "bar", 2])
|
||||||
self.assertEqual(source["name"], "foo")
|
self.assertEqual(source["name"], "foo")
|
||||||
self.assertEqual(source["source"], "bar")
|
self.assertEqual(source["source"], "bar")
|
||||||
|
|
||||||
def test_half_tuple(self):
|
def test_half_tuple(self):
|
||||||
source = tuple_to_correction(["foo", "bar", True])
|
source = tuple_to_correction(["foo", "bar", True])
|
||||||
self.assertEqual(source["name"], "foo")
|
self.assertEqual(source["name"], "foo")
|
||||||
self.assertEqual(source["source"], "bar")
|
self.assertEqual(source["source"], "bar")
|
||||||
|
|
||||||
def test_opset_tuple(self):
|
def test_opset_tuple(self):
|
||||||
source = tuple_to_correction(["foo", "bar", 14])
|
source = tuple_to_correction(["foo", "bar", 14])
|
||||||
self.assertEqual(source["name"], "foo")
|
self.assertEqual(source["name"], "foo")
|
||||||
self.assertEqual(source["source"], "bar")
|
self.assertEqual(source["source"], "bar")
|
||||||
|
|
||||||
def test_all_tuple(self):
|
def test_all_tuple(self):
|
||||||
source = tuple_to_correction(["foo", "bar", 2, True, 14])
|
source = tuple_to_correction(["foo", "bar", 2, True, 14])
|
||||||
self.assertEqual(source["name"], "foo")
|
self.assertEqual(source["name"], "foo")
|
||||||
self.assertEqual(source["source"], "bar")
|
self.assertEqual(source["source"], "bar")
|
||||||
self.assertEqual(source["scale"], 2)
|
self.assertEqual(source["scale"], 2)
|
||||||
self.assertEqual(source["half"], True)
|
self.assertEqual(source["half"], True)
|
||||||
self.assertEqual(source["opset"], 14)
|
self.assertEqual(source["opset"], 14)
|
||||||
|
|
||||||
|
|
||||||
class TupleToDiffusionTests(unittest.TestCase):
|
class TupleToDiffusionTests(unittest.TestCase):
|
||||||
def test_basic_tuple(self):
|
def test_basic_tuple(self):
|
||||||
source = tuple_to_diffusion(("foo", "bar"))
|
source = tuple_to_diffusion(("foo", "bar"))
|
||||||
self.assertEqual(source["name"], "foo")
|
self.assertEqual(source["name"], "foo")
|
||||||
self.assertEqual(source["source"], "bar")
|
self.assertEqual(source["source"], "bar")
|
||||||
|
|
||||||
def test_basic_list(self):
|
def test_basic_list(self):
|
||||||
source = tuple_to_diffusion(["foo", "bar"])
|
source = tuple_to_diffusion(["foo", "bar"])
|
||||||
self.assertEqual(source["name"], "foo")
|
self.assertEqual(source["name"], "foo")
|
||||||
self.assertEqual(source["source"], "bar")
|
self.assertEqual(source["source"], "bar")
|
||||||
|
|
||||||
def test_basic_dict(self):
|
def test_basic_dict(self):
|
||||||
source = tuple_to_diffusion(["foo", "bar"])
|
source = tuple_to_diffusion(["foo", "bar"])
|
||||||
source["bin"] = "bin"
|
source["bin"] = "bin"
|
||||||
|
|
||||||
# make sure this is returned with extra fields
|
# make sure this is returned with extra fields
|
||||||
second = tuple_to_diffusion(source)
|
second = tuple_to_diffusion(source)
|
||||||
|
|
||||||
self.assertEqual(source, second)
|
self.assertEqual(source, second)
|
||||||
self.assertIn("bin", second)
|
self.assertIn("bin", second)
|
||||||
|
|
||||||
def test_single_vae_tuple(self):
|
def test_single_vae_tuple(self):
|
||||||
source = tuple_to_diffusion(["foo", "bar", True])
|
source = tuple_to_diffusion(["foo", "bar", True])
|
||||||
self.assertEqual(source["name"], "foo")
|
self.assertEqual(source["name"], "foo")
|
||||||
self.assertEqual(source["source"], "bar")
|
self.assertEqual(source["source"], "bar")
|
||||||
|
|
||||||
def test_half_tuple(self):
|
def test_half_tuple(self):
|
||||||
source = tuple_to_diffusion(["foo", "bar", True])
|
source = tuple_to_diffusion(["foo", "bar", True])
|
||||||
self.assertEqual(source["name"], "foo")
|
self.assertEqual(source["name"], "foo")
|
||||||
self.assertEqual(source["source"], "bar")
|
self.assertEqual(source["source"], "bar")
|
||||||
|
|
||||||
def test_opset_tuple(self):
|
def test_opset_tuple(self):
|
||||||
source = tuple_to_diffusion(["foo", "bar", 14])
|
source = tuple_to_diffusion(["foo", "bar", 14])
|
||||||
self.assertEqual(source["name"], "foo")
|
self.assertEqual(source["name"], "foo")
|
||||||
self.assertEqual(source["source"], "bar")
|
self.assertEqual(source["source"], "bar")
|
||||||
|
|
||||||
def test_all_tuple(self):
|
def test_all_tuple(self):
|
||||||
source = tuple_to_diffusion(["foo", "bar", True, True, 14])
|
source = tuple_to_diffusion(["foo", "bar", True, True, 14])
|
||||||
self.assertEqual(source["name"], "foo")
|
self.assertEqual(source["name"], "foo")
|
||||||
self.assertEqual(source["source"], "bar")
|
self.assertEqual(source["source"], "bar")
|
||||||
self.assertEqual(source["single_vae"], True)
|
self.assertEqual(source["single_vae"], True)
|
||||||
self.assertEqual(source["half"], True)
|
self.assertEqual(source["half"], True)
|
||||||
self.assertEqual(source["opset"], 14)
|
self.assertEqual(source["opset"], 14)
|
||||||
|
|
||||||
|
|
||||||
class TupleToUpscalingTests(unittest.TestCase):
|
class TupleToUpscalingTests(unittest.TestCase):
|
||||||
def test_basic_tuple(self):
|
def test_basic_tuple(self):
|
||||||
source = tuple_to_upscaling(("foo", "bar"))
|
source = tuple_to_upscaling(("foo", "bar"))
|
||||||
self.assertEqual(source["name"], "foo")
|
self.assertEqual(source["name"], "foo")
|
||||||
self.assertEqual(source["source"], "bar")
|
self.assertEqual(source["source"], "bar")
|
||||||
|
|
||||||
def test_basic_list(self):
|
def test_basic_list(self):
|
||||||
source = tuple_to_upscaling(["foo", "bar"])
|
source = tuple_to_upscaling(["foo", "bar"])
|
||||||
self.assertEqual(source["name"], "foo")
|
self.assertEqual(source["name"], "foo")
|
||||||
self.assertEqual(source["source"], "bar")
|
self.assertEqual(source["source"], "bar")
|
||||||
|
|
||||||
def test_basic_dict(self):
|
def test_basic_dict(self):
|
||||||
source = tuple_to_upscaling(["foo", "bar"])
|
source = tuple_to_upscaling(["foo", "bar"])
|
||||||
source["bin"] = "bin"
|
source["bin"] = "bin"
|
||||||
|
|
||||||
# make sure this is returned with extra fields
|
# make sure this is returned with extra fields
|
||||||
second = tuple_to_source(source)
|
second = tuple_to_source(source)
|
||||||
|
|
||||||
self.assertEqual(source, second)
|
self.assertEqual(source, second)
|
||||||
self.assertIn("bin", second)
|
self.assertIn("bin", second)
|
||||||
|
|
||||||
def test_scale_tuple(self):
|
def test_scale_tuple(self):
|
||||||
source = tuple_to_upscaling(["foo", "bar", 2])
|
source = tuple_to_upscaling(["foo", "bar", 2])
|
||||||
self.assertEqual(source["name"], "foo")
|
self.assertEqual(source["name"], "foo")
|
||||||
self.assertEqual(source["source"], "bar")
|
self.assertEqual(source["source"], "bar")
|
||||||
|
|
||||||
def test_half_tuple(self):
|
def test_half_tuple(self):
|
||||||
source = tuple_to_upscaling(["foo", "bar", True])
|
source = tuple_to_upscaling(["foo", "bar", True])
|
||||||
self.assertEqual(source["name"], "foo")
|
self.assertEqual(source["name"], "foo")
|
||||||
self.assertEqual(source["source"], "bar")
|
self.assertEqual(source["source"], "bar")
|
||||||
|
|
||||||
def test_opset_tuple(self):
|
def test_opset_tuple(self):
|
||||||
source = tuple_to_upscaling(["foo", "bar", 14])
|
source = tuple_to_upscaling(["foo", "bar", 14])
|
||||||
self.assertEqual(source["name"], "foo")
|
self.assertEqual(source["name"], "foo")
|
||||||
self.assertEqual(source["source"], "bar")
|
self.assertEqual(source["source"], "bar")
|
||||||
|
|
||||||
def test_all_tuple(self):
|
def test_all_tuple(self):
|
||||||
source = tuple_to_upscaling(["foo", "bar", 2, True, 14])
|
source = tuple_to_upscaling(["foo", "bar", 2, True, 14])
|
||||||
self.assertEqual(source["name"], "foo")
|
self.assertEqual(source["name"], "foo")
|
||||||
self.assertEqual(source["source"], "bar")
|
self.assertEqual(source["source"], "bar")
|
||||||
self.assertEqual(source["scale"], 2)
|
self.assertEqual(source["scale"], 2)
|
||||||
self.assertEqual(source["half"], True)
|
self.assertEqual(source["half"], True)
|
||||||
self.assertEqual(source["opset"], 14)
|
self.assertEqual(source["opset"], 14)
|
||||||
|
|
||||||
|
|
||||||
class SourceFormatTests(unittest.TestCase):
|
class SourceFormatTests(unittest.TestCase):
|
||||||
def test_with_format(self):
|
def test_with_format(self):
|
||||||
result = source_format({
|
result = source_format(
|
||||||
"format": "foo",
|
{
|
||||||
})
|
"format": "foo",
|
||||||
self.assertEqual(result, "foo")
|
}
|
||||||
|
)
|
||||||
|
self.assertEqual(result, "foo")
|
||||||
|
|
||||||
def test_source_known_extension(self):
|
def test_source_known_extension(self):
|
||||||
result = source_format({
|
result = source_format(
|
||||||
"source": "foo.safetensors",
|
{
|
||||||
})
|
"source": "foo.safetensors",
|
||||||
self.assertEqual(result, "safetensors")
|
}
|
||||||
|
)
|
||||||
|
self.assertEqual(result, "safetensors")
|
||||||
|
|
||||||
def test_source_unknown_extension(self):
|
def test_source_unknown_extension(self):
|
||||||
result = source_format({
|
result = source_format({"source": "foo.none"})
|
||||||
"source": "foo.none"
|
self.assertEqual(result, None)
|
||||||
})
|
|
||||||
self.assertEqual(result, None)
|
|
||||||
|
|
||||||
def test_incomplete_model(self):
|
def test_incomplete_model(self):
|
||||||
self.assertIsNone(source_format({}))
|
self.assertIsNone(source_format({}))
|
||||||
|
|
||||||
|
|
||||||
class RemovePrefixTests(unittest.TestCase):
|
class RemovePrefixTests(unittest.TestCase):
|
||||||
def test_with_prefix(self):
|
def test_with_prefix(self):
|
||||||
self.assertEqual(remove_prefix("foo.bar", "foo"), ".bar")
|
self.assertEqual(remove_prefix("foo.bar", "foo"), ".bar")
|
||||||
|
|
||||||
def test_without_prefix(self):
|
def test_without_prefix(self):
|
||||||
self.assertEqual(remove_prefix("foo.bar", "bin"), "foo.bar")
|
self.assertEqual(remove_prefix("foo.bar", "bin"), "foo.bar")
|
||||||
|
|
||||||
|
|
||||||
class LoadTorchTests(unittest.TestCase):
|
class LoadTorchTests(unittest.TestCase):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class LoadTensorTests(unittest.TestCase):
|
class LoadTensorTests(unittest.TestCase):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ResolveTensorTests(unittest.TestCase):
|
class ResolveTensorTests(unittest.TestCase):
|
||||||
@test_needs_models([TEST_MODEL_UPSCALING_SWINIR])
|
@test_needs_models([TEST_MODEL_UPSCALING_SWINIR])
|
||||||
def test_resolve_existing(self):
|
def test_resolve_existing(self):
|
||||||
self.assertEqual(resolve_tensor("../models/.cache/upscaling-swinir"), TEST_MODEL_UPSCALING_SWINIR)
|
self.assertEqual(
|
||||||
|
resolve_tensor("../models/.cache/upscaling-swinir"),
|
||||||
|
TEST_MODEL_UPSCALING_SWINIR,
|
||||||
|
)
|
||||||
|
|
||||||
def test_resolve_missing(self):
|
def test_resolve_missing(self):
|
||||||
self.assertIsNone(resolve_tensor("missing"))
|
self.assertIsNone(resolve_tensor("missing"))
|
||||||
|
|
|
@ -6,11 +6,13 @@ from onnx_web.params import DeviceParams
|
||||||
|
|
||||||
|
|
||||||
def test_needs_models(models: List[str]):
|
def test_needs_models(models: List[str]):
|
||||||
return skipUnless(all([path.exists(model) for model in models]), "model does not exist")
|
return skipUnless(
|
||||||
|
all([path.exists(model) for model in models]), "model does not exist"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_device() -> DeviceParams:
|
def test_device() -> DeviceParams:
|
||||||
return DeviceParams("cpu", "CPUExecutionProvider")
|
return DeviceParams("cpu", "CPUExecutionProvider")
|
||||||
|
|
||||||
|
|
||||||
TEST_MODEL_DIFFUSION_SD15 = "../models/stable-diffusion-onnx-v1-5"
|
TEST_MODEL_DIFFUSION_SD15 = "../models/stable-diffusion-onnx-v1-5"
|
||||||
|
|
|
@ -10,24 +10,24 @@ from onnx_web.image.mask_filter import (
|
||||||
|
|
||||||
|
|
||||||
class MaskFilterNoneTests(unittest.TestCase):
|
class MaskFilterNoneTests(unittest.TestCase):
|
||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
dims = (64, 64)
|
dims = (64, 64)
|
||||||
mask = Image.new("RGB", dims)
|
mask = Image.new("RGB", dims)
|
||||||
result = mask_filter_none(mask, dims, (0, 0))
|
result = mask_filter_none(mask, dims, (0, 0))
|
||||||
self.assertEqual(result.size, dims)
|
self.assertEqual(result.size, dims)
|
||||||
|
|
||||||
|
|
||||||
class MaskFilterGaussianMultiplyTests(unittest.TestCase):
|
class MaskFilterGaussianMultiplyTests(unittest.TestCase):
|
||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
dims = (64, 64)
|
dims = (64, 64)
|
||||||
mask = Image.new("RGB", dims)
|
mask = Image.new("RGB", dims)
|
||||||
result = mask_filter_gaussian_multiply(mask, dims, (0, 0))
|
result = mask_filter_gaussian_multiply(mask, dims, (0, 0))
|
||||||
self.assertEqual(result.size, dims)
|
self.assertEqual(result.size, dims)
|
||||||
|
|
||||||
|
|
||||||
class MaskFilterGaussianScreenTests(unittest.TestCase):
|
class MaskFilterGaussianScreenTests(unittest.TestCase):
|
||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
dims = (64, 64)
|
dims = (64, 64)
|
||||||
mask = Image.new("RGB", dims)
|
mask = Image.new("RGB", dims)
|
||||||
result = mask_filter_gaussian_screen(mask, dims, (0, 0))
|
result = mask_filter_gaussian_screen(mask, dims, (0, 0))
|
||||||
self.assertEqual(result.size, dims)
|
self.assertEqual(result.size, dims)
|
||||||
|
|
|
@ -11,27 +11,27 @@ from onnx_web.server.context import ServerContext
|
||||||
|
|
||||||
|
|
||||||
class SourceFilterNoneTests(unittest.TestCase):
|
class SourceFilterNoneTests(unittest.TestCase):
|
||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
dims = (64, 64)
|
dims = (64, 64)
|
||||||
server = ServerContext()
|
server = ServerContext()
|
||||||
source = Image.new("RGB", dims)
|
source = Image.new("RGB", dims)
|
||||||
result = source_filter_none(server, source)
|
result = source_filter_none(server, source)
|
||||||
self.assertEqual(result.size, dims)
|
self.assertEqual(result.size, dims)
|
||||||
|
|
||||||
|
|
||||||
class SourceFilterGaussianTests(unittest.TestCase):
|
class SourceFilterGaussianTests(unittest.TestCase):
|
||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
dims = (64, 64)
|
dims = (64, 64)
|
||||||
server = ServerContext()
|
server = ServerContext()
|
||||||
source = Image.new("RGB", dims)
|
source = Image.new("RGB", dims)
|
||||||
result = source_filter_gaussian(server, source)
|
result = source_filter_gaussian(server, source)
|
||||||
self.assertEqual(result.size, dims)
|
self.assertEqual(result.size, dims)
|
||||||
|
|
||||||
|
|
||||||
class SourceFilterNoiseTests(unittest.TestCase):
|
class SourceFilterNoiseTests(unittest.TestCase):
|
||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
dims = (64, 64)
|
dims = (64, 64)
|
||||||
server = ServerContext()
|
server = ServerContext()
|
||||||
source = Image.new("RGB", dims)
|
source = Image.new("RGB", dims)
|
||||||
result = source_filter_noise(server, source)
|
result = source_filter_noise(server, source)
|
||||||
self.assertEqual(result.size, dims)
|
self.assertEqual(result.size, dims)
|
||||||
|
|
|
@ -7,18 +7,18 @@ from onnx_web.params import Border
|
||||||
|
|
||||||
|
|
||||||
class ExpandImageTests(unittest.TestCase):
|
class ExpandImageTests(unittest.TestCase):
|
||||||
def test_expand(self):
|
def test_expand(self):
|
||||||
result = expand_image(
|
result = expand_image(
|
||||||
Image.new("RGB", (8, 8)),
|
Image.new("RGB", (8, 8)),
|
||||||
Image.new("RGB", (8, 8), "white"),
|
Image.new("RGB", (8, 8), "white"),
|
||||||
Border.even(4),
|
Border.even(4),
|
||||||
)
|
)
|
||||||
self.assertEqual(result[0].size, (16, 16))
|
self.assertEqual(result[0].size, (16, 16))
|
||||||
|
|
||||||
def test_masked(self):
|
def test_masked(self):
|
||||||
result = expand_image(
|
result = expand_image(
|
||||||
Image.new("RGB", (8, 8), "red"),
|
Image.new("RGB", (8, 8), "red"),
|
||||||
Image.new("RGB", (8, 8), "white"),
|
Image.new("RGB", (8, 8), "white"),
|
||||||
Border.even(4),
|
Border.even(4),
|
||||||
)
|
)
|
||||||
self.assertEqual(result[0].getpixel((8, 8)), (255, 0, 0))
|
self.assertEqual(result[0].getpixel((8, 8)), (255, 0, 0))
|
||||||
|
|
|
@ -1,43 +1,43 @@
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
|
||||||
class MockPipeline():
|
class MockPipeline:
|
||||||
# flags
|
# flags
|
||||||
slice_size: Optional[str]
|
slice_size: Optional[str]
|
||||||
vae_slicing: Optional[bool]
|
vae_slicing: Optional[bool]
|
||||||
sequential_offload: Optional[bool]
|
sequential_offload: Optional[bool]
|
||||||
model_offload: Optional[bool]
|
model_offload: Optional[bool]
|
||||||
xformers: Optional[bool]
|
xformers: Optional[bool]
|
||||||
|
|
||||||
# stubs
|
# stubs
|
||||||
_encode_prompt: Optional[Any]
|
_encode_prompt: Optional[Any]
|
||||||
unet: Optional[Any]
|
unet: Optional[Any]
|
||||||
vae_decoder: Optional[Any]
|
vae_decoder: Optional[Any]
|
||||||
vae_encoder: Optional[Any]
|
vae_encoder: Optional[Any]
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.slice_size = None
|
self.slice_size = None
|
||||||
self.vae_slicing = None
|
self.vae_slicing = None
|
||||||
self.sequential_offload = None
|
self.sequential_offload = None
|
||||||
self.model_offload = None
|
self.model_offload = None
|
||||||
self.xformers = None
|
self.xformers = None
|
||||||
|
|
||||||
self._encode_prompt = None
|
self._encode_prompt = None
|
||||||
self.unet = None
|
self.unet = None
|
||||||
self.vae_decoder = None
|
self.vae_decoder = None
|
||||||
self.vae_encoder = None
|
self.vae_encoder = None
|
||||||
|
|
||||||
def enable_attention_slicing(self, slice_size: str = None):
|
def enable_attention_slicing(self, slice_size: str = None):
|
||||||
self.slice_size = slice_size
|
self.slice_size = slice_size
|
||||||
|
|
||||||
def enable_vae_slicing(self):
|
def enable_vae_slicing(self):
|
||||||
self.vae_slicing = True
|
self.vae_slicing = True
|
||||||
|
|
||||||
def enable_sequential_cpu_offload(self):
|
def enable_sequential_cpu_offload(self):
|
||||||
self.sequential_offload = True
|
self.sequential_offload = True
|
||||||
|
|
||||||
def enable_model_cpu_offload(self):
|
def enable_model_cpu_offload(self):
|
||||||
self.model_offload = True
|
self.model_offload = True
|
||||||
|
|
||||||
def enable_xformers_memory_efficient_attention(self):
|
def enable_xformers_memory_efficient_attention(self):
|
||||||
self.xformers = True
|
self.xformers = True
|
||||||
|
|
|
@ -13,7 +13,7 @@ class ParserTests(unittest.TestCase):
|
||||||
str(["foo"]),
|
str(["foo"]),
|
||||||
str(PromptPhrase(["bar"], weight=1.5)),
|
str(PromptPhrase(["bar"], weight=1.5)),
|
||||||
str(["bin"]),
|
str(["bin"]),
|
||||||
]
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_multi_word_phrase(self):
|
def test_multi_word_phrase(self):
|
||||||
|
@ -24,7 +24,7 @@ class ParserTests(unittest.TestCase):
|
||||||
str(["foo", "bar"]),
|
str(["foo", "bar"]),
|
||||||
str(PromptPhrase(["middle", "words"], weight=1.5)),
|
str(PromptPhrase(["middle", "words"], weight=1.5)),
|
||||||
str(["bin", "bun"]),
|
str(["bin", "bun"]),
|
||||||
]
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_nested_phrase(self):
|
def test_nested_phrase(self):
|
||||||
|
@ -33,7 +33,7 @@ class ParserTests(unittest.TestCase):
|
||||||
[str(i) for i in res],
|
[str(i) for i in res],
|
||||||
[
|
[
|
||||||
str(["foo"]),
|
str(["foo"]),
|
||||||
str(PromptPhrase(["bar"], weight=(1.5 ** 3))),
|
str(PromptPhrase(["bar"], weight=(1.5**3))),
|
||||||
str(["bin"]),
|
str(["bin"]),
|
||||||
]
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -25,71 +25,85 @@ class ConfigParamTests(unittest.TestCase):
|
||||||
params = get_config_params()
|
params = get_config_params()
|
||||||
self.assertIsNotNone(params)
|
self.assertIsNotNone(params)
|
||||||
|
|
||||||
|
|
||||||
class AvailablePlatformTests(unittest.TestCase):
|
class AvailablePlatformTests(unittest.TestCase):
|
||||||
def test_before_setup(self):
|
def test_before_setup(self):
|
||||||
platforms = get_available_platforms()
|
platforms = get_available_platforms()
|
||||||
self.assertIsNotNone(platforms)
|
self.assertIsNotNone(platforms)
|
||||||
|
|
||||||
|
|
||||||
class CorrectModelTests(unittest.TestCase):
|
class CorrectModelTests(unittest.TestCase):
|
||||||
def test_before_setup(self):
|
def test_before_setup(self):
|
||||||
models = get_correction_models()
|
models = get_correction_models()
|
||||||
self.assertIsNotNone(models)
|
self.assertIsNotNone(models)
|
||||||
|
|
||||||
|
|
||||||
class DiffusionModelTests(unittest.TestCase):
|
class DiffusionModelTests(unittest.TestCase):
|
||||||
def test_before_setup(self):
|
def test_before_setup(self):
|
||||||
models = get_diffusion_models()
|
models = get_diffusion_models()
|
||||||
self.assertIsNotNone(models)
|
self.assertIsNotNone(models)
|
||||||
|
|
||||||
|
|
||||||
class NetworkModelTests(unittest.TestCase):
|
class NetworkModelTests(unittest.TestCase):
|
||||||
def test_before_setup(self):
|
def test_before_setup(self):
|
||||||
models = get_network_models()
|
models = get_network_models()
|
||||||
self.assertIsNotNone(models)
|
self.assertIsNotNone(models)
|
||||||
|
|
||||||
|
|
||||||
class UpscalingModelTests(unittest.TestCase):
|
class UpscalingModelTests(unittest.TestCase):
|
||||||
def test_before_setup(self):
|
def test_before_setup(self):
|
||||||
models = get_upscaling_models()
|
models = get_upscaling_models()
|
||||||
self.assertIsNotNone(models)
|
self.assertIsNotNone(models)
|
||||||
|
|
||||||
|
|
||||||
class WildcardDataTests(unittest.TestCase):
|
class WildcardDataTests(unittest.TestCase):
|
||||||
def test_before_setup(self):
|
def test_before_setup(self):
|
||||||
wildcards = get_wildcard_data()
|
wildcards = get_wildcard_data()
|
||||||
self.assertIsNotNone(wildcards)
|
self.assertIsNotNone(wildcards)
|
||||||
|
|
||||||
|
|
||||||
class ExtraStringsTests(unittest.TestCase):
|
class ExtraStringsTests(unittest.TestCase):
|
||||||
def test_before_setup(self):
|
def test_before_setup(self):
|
||||||
strings = get_extra_strings()
|
strings = get_extra_strings()
|
||||||
self.assertIsNotNone(strings)
|
self.assertIsNotNone(strings)
|
||||||
|
|
||||||
|
|
||||||
class ExtraHashesTests(unittest.TestCase):
|
class ExtraHashesTests(unittest.TestCase):
|
||||||
def test_before_setup(self):
|
def test_before_setup(self):
|
||||||
hashes = get_extra_hashes()
|
hashes = get_extra_hashes()
|
||||||
self.assertIsNotNone(hashes)
|
self.assertIsNotNone(hashes)
|
||||||
|
|
||||||
|
|
||||||
class HighresMethodTests(unittest.TestCase):
|
class HighresMethodTests(unittest.TestCase):
|
||||||
def test_before_setup(self):
|
def test_before_setup(self):
|
||||||
methods = get_highres_methods()
|
methods = get_highres_methods()
|
||||||
self.assertIsNotNone(methods)
|
self.assertIsNotNone(methods)
|
||||||
|
|
||||||
|
|
||||||
class MaskFilterTests(unittest.TestCase):
|
class MaskFilterTests(unittest.TestCase):
|
||||||
def test_before_setup(self):
|
def test_before_setup(self):
|
||||||
filters = get_mask_filters()
|
filters = get_mask_filters()
|
||||||
self.assertIsNotNone(filters)
|
self.assertIsNotNone(filters)
|
||||||
|
|
||||||
|
|
||||||
class NoiseSourceTests(unittest.TestCase):
|
class NoiseSourceTests(unittest.TestCase):
|
||||||
def test_before_setup(self):
|
def test_before_setup(self):
|
||||||
sources = get_noise_sources()
|
sources = get_noise_sources()
|
||||||
self.assertIsNotNone(sources)
|
self.assertIsNotNone(sources)
|
||||||
|
|
||||||
|
|
||||||
class SourceFilterTests(unittest.TestCase):
|
class SourceFilterTests(unittest.TestCase):
|
||||||
def test_before_setup(self):
|
def test_before_setup(self):
|
||||||
filters = get_source_filters()
|
filters = get_source_filters()
|
||||||
self.assertIsNotNone(filters)
|
self.assertIsNotNone(filters)
|
||||||
|
|
||||||
|
|
||||||
class LoadExtrasTests(unittest.TestCase):
|
class LoadExtrasTests(unittest.TestCase):
|
||||||
def test_default_extras(self):
|
def test_default_extras(self):
|
||||||
server = ServerContext(extra_models=["../models/extras.json"])
|
server = ServerContext(extra_models=["../models/extras.json"])
|
||||||
load_extras(server)
|
load_extras(server)
|
||||||
|
|
||||||
|
|
||||||
class LoadModelsTests(unittest.TestCase):
|
class LoadModelsTests(unittest.TestCase):
|
||||||
def test_default_models(self):
|
def test_default_models(self):
|
||||||
server = ServerContext(model_path="../models")
|
server = ServerContext(model_path="../models")
|
||||||
|
|
|
@ -4,37 +4,37 @@ from onnx_web.server.model_cache import ModelCache
|
||||||
|
|
||||||
|
|
||||||
class TestModelCache(unittest.TestCase):
|
class TestModelCache(unittest.TestCase):
|
||||||
def test_drop_existing(self):
|
def test_drop_existing(self):
|
||||||
cache = ModelCache(10)
|
cache = ModelCache(10)
|
||||||
cache.clear()
|
cache.clear()
|
||||||
cache.set("foo", ("bar",), {})
|
cache.set("foo", ("bar",), {})
|
||||||
self.assertGreater(cache.size, 0)
|
self.assertGreater(cache.size, 0)
|
||||||
self.assertEqual(cache.drop("foo", ("bar",)), 1)
|
self.assertEqual(cache.drop("foo", ("bar",)), 1)
|
||||||
|
|
||||||
def test_drop_missing(self):
|
def test_drop_missing(self):
|
||||||
cache = ModelCache(10)
|
cache = ModelCache(10)
|
||||||
cache.clear()
|
cache.clear()
|
||||||
cache.set("foo", ("bar",), {})
|
cache.set("foo", ("bar",), {})
|
||||||
self.assertGreater(cache.size, 0)
|
self.assertGreater(cache.size, 0)
|
||||||
self.assertEqual(cache.drop("foo", ("bin",)), 0)
|
self.assertEqual(cache.drop("foo", ("bin",)), 0)
|
||||||
|
|
||||||
def test_get_existing(self):
|
def test_get_existing(self):
|
||||||
cache = ModelCache(10)
|
cache = ModelCache(10)
|
||||||
cache.clear()
|
cache.clear()
|
||||||
value = {}
|
value = {}
|
||||||
cache.set("foo", ("bar",), value)
|
cache.set("foo", ("bar",), value)
|
||||||
self.assertGreater(cache.size, 0)
|
self.assertGreater(cache.size, 0)
|
||||||
self.assertIs(cache.get("foo", ("bar",)), value)
|
self.assertIs(cache.get("foo", ("bar",)), value)
|
||||||
|
|
||||||
def test_get_missing(self):
|
def test_get_missing(self):
|
||||||
cache = ModelCache(10)
|
cache = ModelCache(10)
|
||||||
cache.clear()
|
cache.clear()
|
||||||
value = {}
|
value = {}
|
||||||
cache.set("foo", ("bar",), value)
|
cache.set("foo", ("bar",), value)
|
||||||
self.assertGreater(cache.size, 0)
|
self.assertGreater(cache.size, 0)
|
||||||
self.assertIs(cache.get("foo", ("bin",)), None)
|
self.assertIs(cache.get("foo", ("bin",)), None)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
def test_set_existing(self):
|
def test_set_existing(self):
|
||||||
cache = ModelCache(10)
|
cache = ModelCache(10)
|
||||||
cache.clear()
|
cache.clear()
|
||||||
|
@ -48,16 +48,16 @@ class TestModelCache(unittest.TestCase):
|
||||||
self.assertIs(cache.get("foo", ("bar",)), value)
|
self.assertIs(cache.get("foo", ("bar",)), value)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def test_set_missing(self):
|
def test_set_missing(self):
|
||||||
cache = ModelCache(10)
|
cache = ModelCache(10)
|
||||||
cache.clear()
|
cache.clear()
|
||||||
value = {}
|
value = {}
|
||||||
cache.set("foo", ("bar",), value)
|
cache.set("foo", ("bar",), value)
|
||||||
self.assertIs(cache.get("foo", ("bar",)), value)
|
self.assertIs(cache.get("foo", ("bar",)), value)
|
||||||
|
|
||||||
def test_set_zero(self):
|
def test_set_zero(self):
|
||||||
cache = ModelCache(0)
|
cache = ModelCache(0)
|
||||||
cache.clear()
|
cache.clear()
|
||||||
value = {}
|
value = {}
|
||||||
cache.set("foo", ("bar",), value)
|
cache.set("foo", ("bar",), value)
|
||||||
self.assertEqual(cache.size, 0)
|
self.assertEqual(cache.size, 0)
|
||||||
|
|
|
@ -24,253 +24,307 @@ from tests.mocks import MockPipeline
|
||||||
|
|
||||||
|
|
||||||
class TestAvailablePipelines(unittest.TestCase):
|
class TestAvailablePipelines(unittest.TestCase):
|
||||||
def test_available_pipelines(self):
|
def test_available_pipelines(self):
|
||||||
pipelines = get_available_pipelines()
|
pipelines = get_available_pipelines()
|
||||||
|
|
||||||
self.assertIn("txt2img", pipelines)
|
self.assertIn("txt2img", pipelines)
|
||||||
|
|
||||||
|
|
||||||
class TestPipelineSchedulers(unittest.TestCase):
|
class TestPipelineSchedulers(unittest.TestCase):
|
||||||
def test_pipeline_schedulers(self):
|
def test_pipeline_schedulers(self):
|
||||||
schedulers = get_pipeline_schedulers()
|
schedulers = get_pipeline_schedulers()
|
||||||
|
|
||||||
self.assertIn("euler-a", schedulers)
|
self.assertIn("euler-a", schedulers)
|
||||||
|
|
||||||
|
|
||||||
class TestSchedulerNames(unittest.TestCase):
|
class TestSchedulerNames(unittest.TestCase):
|
||||||
def test_valid_name(self):
|
def test_valid_name(self):
|
||||||
scheduler = get_scheduler_name(DDIMScheduler)
|
scheduler = get_scheduler_name(DDIMScheduler)
|
||||||
|
|
||||||
self.assertEqual("ddim", scheduler)
|
self.assertEqual("ddim", scheduler)
|
||||||
|
|
||||||
def test_missing_names(self):
|
def test_missing_names(self):
|
||||||
self.assertIsNone(get_scheduler_name("test"))
|
self.assertIsNone(get_scheduler_name("test"))
|
||||||
|
|
||||||
|
|
||||||
class TestOptimizePipeline(unittest.TestCase):
|
class TestOptimizePipeline(unittest.TestCase):
|
||||||
def test_auto_attention_slicing(self):
|
def test_auto_attention_slicing(self):
|
||||||
server = ServerContext(
|
server = ServerContext(
|
||||||
optimizations=[
|
optimizations=[
|
||||||
"diffusers-attention-slicing-auto",
|
"diffusers-attention-slicing-auto",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
pipeline = MockPipeline()
|
pipeline = MockPipeline()
|
||||||
optimize_pipeline(server, pipeline)
|
optimize_pipeline(server, pipeline)
|
||||||
self.assertEqual(pipeline.slice_size, "auto")
|
self.assertEqual(pipeline.slice_size, "auto")
|
||||||
|
|
||||||
def test_max_attention_slicing(self):
|
def test_max_attention_slicing(self):
|
||||||
server = ServerContext(
|
server = ServerContext(
|
||||||
optimizations=[
|
optimizations=[
|
||||||
"diffusers-attention-slicing-max",
|
"diffusers-attention-slicing-max",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
pipeline = MockPipeline()
|
pipeline = MockPipeline()
|
||||||
optimize_pipeline(server, pipeline)
|
optimize_pipeline(server, pipeline)
|
||||||
self.assertEqual(pipeline.slice_size, "max")
|
self.assertEqual(pipeline.slice_size, "max")
|
||||||
|
|
||||||
def test_vae_slicing(self):
|
def test_vae_slicing(self):
|
||||||
server = ServerContext(
|
server = ServerContext(
|
||||||
optimizations=[
|
optimizations=[
|
||||||
"diffusers-vae-slicing",
|
"diffusers-vae-slicing",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
pipeline = MockPipeline()
|
pipeline = MockPipeline()
|
||||||
optimize_pipeline(server, pipeline)
|
optimize_pipeline(server, pipeline)
|
||||||
self.assertEqual(pipeline.vae_slicing, True)
|
self.assertEqual(pipeline.vae_slicing, True)
|
||||||
|
|
||||||
def test_cpu_offload_sequential(self):
|
def test_cpu_offload_sequential(self):
|
||||||
server = ServerContext(
|
server = ServerContext(
|
||||||
optimizations=[
|
optimizations=[
|
||||||
"diffusers-cpu-offload-sequential",
|
"diffusers-cpu-offload-sequential",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
pipeline = MockPipeline()
|
pipeline = MockPipeline()
|
||||||
optimize_pipeline(server, pipeline)
|
optimize_pipeline(server, pipeline)
|
||||||
self.assertEqual(pipeline.sequential_offload, True)
|
self.assertEqual(pipeline.sequential_offload, True)
|
||||||
|
|
||||||
def test_cpu_offload_model(self):
|
def test_cpu_offload_model(self):
|
||||||
server = ServerContext(
|
server = ServerContext(
|
||||||
optimizations=[
|
optimizations=[
|
||||||
"diffusers-cpu-offload-model",
|
"diffusers-cpu-offload-model",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
pipeline = MockPipeline()
|
pipeline = MockPipeline()
|
||||||
optimize_pipeline(server, pipeline)
|
optimize_pipeline(server, pipeline)
|
||||||
self.assertEqual(pipeline.model_offload, True)
|
self.assertEqual(pipeline.model_offload, True)
|
||||||
|
|
||||||
def test_memory_efficient_attention(self):
|
def test_memory_efficient_attention(self):
|
||||||
server = ServerContext(
|
server = ServerContext(
|
||||||
optimizations=[
|
optimizations=[
|
||||||
"diffusers-memory-efficient-attention",
|
"diffusers-memory-efficient-attention",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
pipeline = MockPipeline()
|
pipeline = MockPipeline()
|
||||||
optimize_pipeline(server, pipeline)
|
optimize_pipeline(server, pipeline)
|
||||||
self.assertEqual(pipeline.xformers, True)
|
self.assertEqual(pipeline.xformers, True)
|
||||||
|
|
||||||
|
|
||||||
class TestPatchPipeline(unittest.TestCase):
|
class TestPatchPipeline(unittest.TestCase):
|
||||||
def test_expand_not_lpw(self):
|
def test_expand_not_lpw(self):
|
||||||
"""
|
"""
|
||||||
server = ServerContext()
|
server = ServerContext()
|
||||||
pipeline = MockPipeline()
|
pipeline = MockPipeline()
|
||||||
patch_pipeline(server, pipeline, None, ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1))
|
patch_pipeline(server, pipeline, None, ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1))
|
||||||
self.assertEqual(pipeline._encode_prompt, expand_prompt)
|
self.assertEqual(pipeline._encode_prompt, expand_prompt)
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def test_unet_wrapper_not_xl(self):
|
def test_unet_wrapper_not_xl(self):
|
||||||
server = ServerContext()
|
server = ServerContext()
|
||||||
pipeline = MockPipeline()
|
pipeline = MockPipeline()
|
||||||
patch_pipeline(server, pipeline, None, ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1))
|
patch_pipeline(
|
||||||
self.assertTrue(isinstance(pipeline.unet, UNetWrapper))
|
server,
|
||||||
|
pipeline,
|
||||||
|
None,
|
||||||
|
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
|
||||||
|
)
|
||||||
|
self.assertTrue(isinstance(pipeline.unet, UNetWrapper))
|
||||||
|
|
||||||
def test_unet_wrapper_xl(self):
|
def test_unet_wrapper_xl(self):
|
||||||
server = ServerContext()
|
server = ServerContext()
|
||||||
pipeline = MockPipeline()
|
pipeline = MockPipeline()
|
||||||
patch_pipeline(server, pipeline, None, ImageParams("test", "txt2img-sdxl", "ddim", "test", 1.0, 10, 1))
|
patch_pipeline(
|
||||||
self.assertTrue(isinstance(pipeline.unet, UNetWrapper))
|
server,
|
||||||
|
pipeline,
|
||||||
|
None,
|
||||||
|
ImageParams("test", "txt2img-sdxl", "ddim", "test", 1.0, 10, 1),
|
||||||
|
)
|
||||||
|
self.assertTrue(isinstance(pipeline.unet, UNetWrapper))
|
||||||
|
|
||||||
def test_vae_wrapper(self):
|
def test_vae_wrapper(self):
|
||||||
server = ServerContext()
|
server = ServerContext()
|
||||||
pipeline = MockPipeline()
|
pipeline = MockPipeline()
|
||||||
patch_pipeline(server, pipeline, None, ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1))
|
patch_pipeline(
|
||||||
self.assertTrue(isinstance(pipeline.vae_decoder, VAEWrapper))
|
server,
|
||||||
self.assertTrue(isinstance(pipeline.vae_encoder, VAEWrapper))
|
pipeline,
|
||||||
|
None,
|
||||||
|
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
|
||||||
|
)
|
||||||
|
self.assertTrue(isinstance(pipeline.vae_decoder, VAEWrapper))
|
||||||
|
self.assertTrue(isinstance(pipeline.vae_encoder, VAEWrapper))
|
||||||
|
|
||||||
|
|
||||||
class TestLoadControlNet(unittest.TestCase):
|
class TestLoadControlNet(unittest.TestCase):
|
||||||
@unittest.skipUnless(path.exists("../models/control/canny.onnx"), "model does not exist")
|
@unittest.skipUnless(
|
||||||
def test_load_existing(self):
|
path.exists("../models/control/canny.onnx"), "model does not exist"
|
||||||
"""
|
|
||||||
Should load a model
|
|
||||||
"""
|
|
||||||
components = load_controlnet(
|
|
||||||
ServerContext(model_path="../models"),
|
|
||||||
DeviceParams("cpu", "CPUExecutionProvider"),
|
|
||||||
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1, control=NetworkModel("canny", "control")),
|
|
||||||
)
|
)
|
||||||
self.assertIn("controlnet", components)
|
def test_load_existing(self):
|
||||||
|
"""
|
||||||
|
Should load a model
|
||||||
|
"""
|
||||||
|
components = load_controlnet(
|
||||||
|
ServerContext(model_path="../models"),
|
||||||
|
DeviceParams("cpu", "CPUExecutionProvider"),
|
||||||
|
ImageParams(
|
||||||
|
"test",
|
||||||
|
"txt2img",
|
||||||
|
"ddim",
|
||||||
|
"test",
|
||||||
|
1.0,
|
||||||
|
10,
|
||||||
|
1,
|
||||||
|
control=NetworkModel("canny", "control"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.assertIn("controlnet", components)
|
||||||
|
|
||||||
def test_load_missing(self):
|
def test_load_missing(self):
|
||||||
"""
|
"""
|
||||||
Should throw
|
Should throw
|
||||||
"""
|
"""
|
||||||
components = {}
|
components = {}
|
||||||
try:
|
try:
|
||||||
components = load_controlnet(
|
components = load_controlnet(
|
||||||
ServerContext(),
|
ServerContext(),
|
||||||
DeviceParams("cpu", "CPUExecutionProvider"),
|
DeviceParams("cpu", "CPUExecutionProvider"),
|
||||||
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1, control=NetworkModel("missing", "control")),
|
ImageParams(
|
||||||
)
|
"test",
|
||||||
except:
|
"txt2img",
|
||||||
self.assertNotIn("controlnet", components)
|
"ddim",
|
||||||
return
|
"test",
|
||||||
|
1.0,
|
||||||
|
10,
|
||||||
|
1,
|
||||||
|
control=NetworkModel("missing", "control"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
self.assertNotIn("controlnet", components)
|
||||||
|
return
|
||||||
|
|
||||||
self.fail()
|
self.fail()
|
||||||
|
|
||||||
|
|
||||||
class TestLoadTextEncoders(unittest.TestCase):
|
class TestLoadTextEncoders(unittest.TestCase):
|
||||||
@unittest.skipUnless(path.exists("../models/stable-diffusion-onnx-v1-5/text_encoder/model.onnx"), "model does not exist")
|
@unittest.skipUnless(
|
||||||
def test_load_embeddings(self):
|
path.exists("../models/stable-diffusion-onnx-v1-5/text_encoder/model.onnx"),
|
||||||
"""
|
"model does not exist",
|
||||||
Should add the token to tokenizer
|
|
||||||
Should increase the encoder dims
|
|
||||||
"""
|
|
||||||
components = load_text_encoders(
|
|
||||||
ServerContext(model_path="../models"),
|
|
||||||
DeviceParams("cpu", "CPUExecutionProvider"),
|
|
||||||
"../models/stable-diffusion-onnx-v1-5",
|
|
||||||
[
|
|
||||||
# TODO: add some embeddings
|
|
||||||
],
|
|
||||||
[],
|
|
||||||
torch.float32,
|
|
||||||
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
|
|
||||||
)
|
)
|
||||||
self.assertIn("text_encoder", components)
|
def test_load_embeddings(self):
|
||||||
|
"""
|
||||||
|
Should add the token to tokenizer
|
||||||
|
Should increase the encoder dims
|
||||||
|
"""
|
||||||
|
components = load_text_encoders(
|
||||||
|
ServerContext(model_path="../models"),
|
||||||
|
DeviceParams("cpu", "CPUExecutionProvider"),
|
||||||
|
"../models/stable-diffusion-onnx-v1-5",
|
||||||
|
[
|
||||||
|
# TODO: add some embeddings
|
||||||
|
],
|
||||||
|
[],
|
||||||
|
torch.float32,
|
||||||
|
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
|
||||||
|
)
|
||||||
|
self.assertIn("text_encoder", components)
|
||||||
|
|
||||||
def test_load_embeddings_xl(self):
|
def test_load_embeddings_xl(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skipUnless(path.exists("../models/stable-diffusion-onnx-v1-5/text_encoder/model.onnx"), "model does not exist")
|
@unittest.skipUnless(
|
||||||
def test_load_loras(self):
|
path.exists("../models/stable-diffusion-onnx-v1-5/text_encoder/model.onnx"),
|
||||||
components = load_text_encoders(
|
"model does not exist",
|
||||||
ServerContext(model_path="../models"),
|
|
||||||
DeviceParams("cpu", "CPUExecutionProvider"),
|
|
||||||
"../models/stable-diffusion-onnx-v1-5",
|
|
||||||
[],
|
|
||||||
[
|
|
||||||
# TODO: add some loras
|
|
||||||
],
|
|
||||||
torch.float32,
|
|
||||||
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
|
|
||||||
)
|
)
|
||||||
self.assertIn("text_encoder", components)
|
def test_load_loras(self):
|
||||||
|
components = load_text_encoders(
|
||||||
|
ServerContext(model_path="../models"),
|
||||||
|
DeviceParams("cpu", "CPUExecutionProvider"),
|
||||||
|
"../models/stable-diffusion-onnx-v1-5",
|
||||||
|
[],
|
||||||
|
[
|
||||||
|
# TODO: add some loras
|
||||||
|
],
|
||||||
|
torch.float32,
|
||||||
|
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
|
||||||
|
)
|
||||||
|
self.assertIn("text_encoder", components)
|
||||||
|
|
||||||
|
def test_load_loras_xl(self):
|
||||||
|
pass
|
||||||
|
|
||||||
def test_load_loras_xl(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
class TestLoadUnet(unittest.TestCase):
|
class TestLoadUnet(unittest.TestCase):
|
||||||
@unittest.skipUnless(path.exists("../models/stable-diffusion-onnx-v1-5/unet/model.onnx"), "model does not exist")
|
@unittest.skipUnless(
|
||||||
def test_load_unet_loras(self):
|
path.exists("../models/stable-diffusion-onnx-v1-5/unet/model.onnx"),
|
||||||
components = load_unet(
|
"model does not exist",
|
||||||
ServerContext(model_path="../models"),
|
|
||||||
DeviceParams("cpu", "CPUExecutionProvider"),
|
|
||||||
"../models/stable-diffusion-onnx-v1-5",
|
|
||||||
[
|
|
||||||
# TODO: add some loras
|
|
||||||
],
|
|
||||||
"unet",
|
|
||||||
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
|
|
||||||
)
|
)
|
||||||
self.assertIn("unet", components)
|
def test_load_unet_loras(self):
|
||||||
|
components = load_unet(
|
||||||
|
ServerContext(model_path="../models"),
|
||||||
|
DeviceParams("cpu", "CPUExecutionProvider"),
|
||||||
|
"../models/stable-diffusion-onnx-v1-5",
|
||||||
|
[
|
||||||
|
# TODO: add some loras
|
||||||
|
],
|
||||||
|
"unet",
|
||||||
|
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
|
||||||
|
)
|
||||||
|
self.assertIn("unet", components)
|
||||||
|
|
||||||
def test_load_unet_loras_xl(self):
|
def test_load_unet_loras_xl(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skipUnless(path.exists("../models/stable-diffusion-onnx-v1-5/cnet/model.onnx"), "model does not exist")
|
@unittest.skipUnless(
|
||||||
def test_load_cnet_loras(self):
|
path.exists("../models/stable-diffusion-onnx-v1-5/cnet/model.onnx"),
|
||||||
components = load_unet(
|
"model does not exist",
|
||||||
ServerContext(model_path="../models"),
|
|
||||||
DeviceParams("cpu", "CPUExecutionProvider"),
|
|
||||||
"../models/stable-diffusion-onnx-v1-5",
|
|
||||||
[
|
|
||||||
# TODO: add some loras
|
|
||||||
],
|
|
||||||
"cnet",
|
|
||||||
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
|
|
||||||
)
|
)
|
||||||
self.assertIn("unet", components)
|
def test_load_cnet_loras(self):
|
||||||
|
components = load_unet(
|
||||||
|
ServerContext(model_path="../models"),
|
||||||
|
DeviceParams("cpu", "CPUExecutionProvider"),
|
||||||
|
"../models/stable-diffusion-onnx-v1-5",
|
||||||
|
[
|
||||||
|
# TODO: add some loras
|
||||||
|
],
|
||||||
|
"cnet",
|
||||||
|
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
|
||||||
|
)
|
||||||
|
self.assertIn("unet", components)
|
||||||
|
|
||||||
|
|
||||||
class TestLoadVae(unittest.TestCase):
|
class TestLoadVae(unittest.TestCase):
|
||||||
@unittest.skipUnless(path.exists("../models/upscaling-stable-diffusion-x4/vae/model.onnx"), "model does not exist")
|
@unittest.skipUnless(
|
||||||
def test_load_single(self):
|
path.exists("../models/upscaling-stable-diffusion-x4/vae/model.onnx"),
|
||||||
"""
|
"model does not exist",
|
||||||
Should return single component
|
|
||||||
"""
|
|
||||||
components = load_vae(
|
|
||||||
ServerContext(model_path="../models"),
|
|
||||||
DeviceParams("cpu", "CPUExecutionProvider"),
|
|
||||||
"../models/upscaling-stable-diffusion-x4",
|
|
||||||
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
|
|
||||||
)
|
)
|
||||||
self.assertIn("vae", components)
|
def test_load_single(self):
|
||||||
self.assertNotIn("vae_decoder", components)
|
"""
|
||||||
self.assertNotIn("vae_encoder", components)
|
Should return single component
|
||||||
|
"""
|
||||||
|
components = load_vae(
|
||||||
|
ServerContext(model_path="../models"),
|
||||||
|
DeviceParams("cpu", "CPUExecutionProvider"),
|
||||||
|
"../models/upscaling-stable-diffusion-x4",
|
||||||
|
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
|
||||||
|
)
|
||||||
|
self.assertIn("vae", components)
|
||||||
|
self.assertNotIn("vae_decoder", components)
|
||||||
|
self.assertNotIn("vae_encoder", components)
|
||||||
|
|
||||||
@unittest.skipUnless(path.exists("../models/stable-diffusion-onnx-v1-5/vae_encoder/model.onnx"), "model does not exist")
|
@unittest.skipUnless(
|
||||||
def test_load_split(self):
|
path.exists("../models/stable-diffusion-onnx-v1-5/vae_encoder/model.onnx"),
|
||||||
"""
|
"model does not exist",
|
||||||
Should return split encoder/decoder
|
|
||||||
"""
|
|
||||||
components = load_vae(
|
|
||||||
ServerContext(model_path="../models"),
|
|
||||||
DeviceParams("cpu", "CPUExecutionProvider"),
|
|
||||||
"../models/stable-diffusion-onnx-v1-5",
|
|
||||||
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
|
|
||||||
)
|
)
|
||||||
self.assertNotIn("vae", components)
|
def test_load_split(self):
|
||||||
self.assertIn("vae_decoder", components)
|
"""
|
||||||
self.assertIn("vae_encoder", components)
|
Should return split encoder/decoder
|
||||||
|
"""
|
||||||
|
components = load_vae(
|
||||||
|
ServerContext(model_path="../models"),
|
||||||
|
DeviceParams("cpu", "CPUExecutionProvider"),
|
||||||
|
"../models/stable-diffusion-onnx-v1-5",
|
||||||
|
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
|
||||||
|
)
|
||||||
|
self.assertNotIn("vae", components)
|
||||||
|
self.assertIn("vae_decoder", components)
|
||||||
|
self.assertIn("vae_encoder", components)
|
||||||
|
|
|
@ -17,155 +17,234 @@ from tests.helpers import TEST_MODEL_DIFFUSION_SD15, test_device, test_needs_mod
|
||||||
|
|
||||||
|
|
||||||
class TestTxt2ImgPipeline(unittest.TestCase):
|
class TestTxt2ImgPipeline(unittest.TestCase):
|
||||||
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
|
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
|
||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
cancel = Value("L", 0)
|
cancel = Value("L", 0)
|
||||||
logs = Queue()
|
logs = Queue()
|
||||||
pending = Queue()
|
pending = Queue()
|
||||||
progress = Queue()
|
progress = Queue()
|
||||||
active = Value("L", 0)
|
active = Value("L", 0)
|
||||||
idle = Value("L", 0)
|
idle = Value("L", 0)
|
||||||
|
|
||||||
worker = WorkerContext(
|
worker = WorkerContext(
|
||||||
"test",
|
"test",
|
||||||
test_device(),
|
test_device(),
|
||||||
cancel,
|
cancel,
|
||||||
logs,
|
logs,
|
||||||
pending,
|
pending,
|
||||||
progress,
|
progress,
|
||||||
active,
|
active,
|
||||||
idle,
|
idle,
|
||||||
3,
|
3,
|
||||||
0.1,
|
0.1,
|
||||||
)
|
)
|
||||||
worker.start("test")
|
worker.start("test")
|
||||||
|
|
||||||
run_txt2img_pipeline(
|
run_txt2img_pipeline(
|
||||||
worker,
|
worker,
|
||||||
ServerContext(model_path="../models", output_path="../outputs"),
|
ServerContext(model_path="../models", output_path="../outputs"),
|
||||||
ImageParams(
|
ImageParams(
|
||||||
TEST_MODEL_DIFFUSION_SD15, "txt2img", "ddim", "an astronaut eating a hamburger", 3.0, 1, 1),
|
TEST_MODEL_DIFFUSION_SD15,
|
||||||
Size(256, 256),
|
"txt2img",
|
||||||
["test-txt2img.png"],
|
"ddim",
|
||||||
UpscaleParams("test"),
|
"an astronaut eating a hamburger",
|
||||||
HighresParams(False, 1, 0, 0),
|
3.0,
|
||||||
)
|
1,
|
||||||
|
1,
|
||||||
|
),
|
||||||
|
Size(256, 256),
|
||||||
|
["test-txt2img-basic.png"],
|
||||||
|
UpscaleParams("test"),
|
||||||
|
HighresParams(False, 1, 0, 0),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertTrue(path.exists("../outputs/test-txt2img-basic.png"))
|
||||||
|
output = Image.open("../outputs/test-txt2img-basic.png")
|
||||||
|
self.assertEqual(output.size, (256, 256))
|
||||||
|
# TODO: test contents of image
|
||||||
|
|
||||||
|
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
|
||||||
|
def test_highres(self):
|
||||||
|
cancel = Value("L", 0)
|
||||||
|
logs = Queue()
|
||||||
|
pending = Queue()
|
||||||
|
progress = Queue()
|
||||||
|
active = Value("L", 0)
|
||||||
|
idle = Value("L", 0)
|
||||||
|
|
||||||
|
worker = WorkerContext(
|
||||||
|
"test",
|
||||||
|
test_device(),
|
||||||
|
cancel,
|
||||||
|
logs,
|
||||||
|
pending,
|
||||||
|
progress,
|
||||||
|
active,
|
||||||
|
idle,
|
||||||
|
3,
|
||||||
|
0.1,
|
||||||
|
)
|
||||||
|
worker.start("test")
|
||||||
|
|
||||||
|
run_txt2img_pipeline(
|
||||||
|
worker,
|
||||||
|
ServerContext(model_path="../models", output_path="../outputs"),
|
||||||
|
ImageParams(
|
||||||
|
TEST_MODEL_DIFFUSION_SD15,
|
||||||
|
"txt2img",
|
||||||
|
"ddim",
|
||||||
|
"an astronaut eating a hamburger",
|
||||||
|
3.0,
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
),
|
||||||
|
Size(256, 256),
|
||||||
|
["test-txt2img-highres.png"],
|
||||||
|
UpscaleParams("test"),
|
||||||
|
HighresParams(True, 2, 0, 0),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertTrue(path.exists("../outputs/test-txt2img-highres.png"))
|
||||||
|
output = Image.open("../outputs/test-txt2img-highres.png")
|
||||||
|
self.assertEqual(output.size, (512, 512))
|
||||||
|
|
||||||
self.assertTrue(path.exists("../outputs/test-txt2img.png"))
|
|
||||||
|
|
||||||
class TestImg2ImgPipeline(unittest.TestCase):
|
class TestImg2ImgPipeline(unittest.TestCase):
|
||||||
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
|
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
|
||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
cancel = Value("L", 0)
|
cancel = Value("L", 0)
|
||||||
logs = Queue()
|
logs = Queue()
|
||||||
pending = Queue()
|
pending = Queue()
|
||||||
progress = Queue()
|
progress = Queue()
|
||||||
active = Value("L", 0)
|
active = Value("L", 0)
|
||||||
idle = Value("L", 0)
|
idle = Value("L", 0)
|
||||||
|
|
||||||
worker = WorkerContext(
|
worker = WorkerContext(
|
||||||
"test",
|
"test",
|
||||||
test_device(),
|
test_device(),
|
||||||
cancel,
|
cancel,
|
||||||
logs,
|
logs,
|
||||||
pending,
|
pending,
|
||||||
progress,
|
progress,
|
||||||
active,
|
active,
|
||||||
idle,
|
idle,
|
||||||
3,
|
3,
|
||||||
0.1,
|
0.1,
|
||||||
)
|
)
|
||||||
worker.start("test")
|
worker.start("test")
|
||||||
|
|
||||||
source = Image.new("RGB", (64, 64), "black")
|
source = Image.new("RGB", (64, 64), "black")
|
||||||
run_img2img_pipeline(
|
run_img2img_pipeline(
|
||||||
worker,
|
worker,
|
||||||
ServerContext(model_path="../models", output_path="../outputs"),
|
ServerContext(model_path="../models", output_path="../outputs"),
|
||||||
ImageParams(
|
ImageParams(
|
||||||
TEST_MODEL_DIFFUSION_SD15, "txt2img", "ddim", "an astronaut eating a hamburger", 3.0, 1, 1),
|
TEST_MODEL_DIFFUSION_SD15,
|
||||||
["test-img2img.png"],
|
"txt2img",
|
||||||
UpscaleParams("test"),
|
"ddim",
|
||||||
HighresParams(False, 1, 0, 0),
|
"an astronaut eating a hamburger",
|
||||||
source,
|
3.0,
|
||||||
1.0,
|
1,
|
||||||
)
|
1,
|
||||||
|
),
|
||||||
|
["test-img2img.png"],
|
||||||
|
UpscaleParams("test"),
|
||||||
|
HighresParams(False, 1, 0, 0),
|
||||||
|
source,
|
||||||
|
1.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertTrue(path.exists("../outputs/test-img2img.png"))
|
||||||
|
|
||||||
self.assertTrue(path.exists("../outputs/test-img2img.png"))
|
|
||||||
|
|
||||||
class TestUpscalePipeline(unittest.TestCase):
|
class TestUpscalePipeline(unittest.TestCase):
|
||||||
@test_needs_models(["../models/upscaling-stable-diffusion-x4"])
|
@test_needs_models(["../models/upscaling-stable-diffusion-x4"])
|
||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
cancel = Value("L", 0)
|
cancel = Value("L", 0)
|
||||||
logs = Queue()
|
logs = Queue()
|
||||||
pending = Queue()
|
pending = Queue()
|
||||||
progress = Queue()
|
progress = Queue()
|
||||||
active = Value("L", 0)
|
active = Value("L", 0)
|
||||||
idle = Value("L", 0)
|
idle = Value("L", 0)
|
||||||
|
|
||||||
worker = WorkerContext(
|
worker = WorkerContext(
|
||||||
"test",
|
"test",
|
||||||
test_device(),
|
test_device(),
|
||||||
cancel,
|
cancel,
|
||||||
logs,
|
logs,
|
||||||
pending,
|
pending,
|
||||||
progress,
|
progress,
|
||||||
active,
|
active,
|
||||||
idle,
|
idle,
|
||||||
3,
|
3,
|
||||||
0.1,
|
0.1,
|
||||||
)
|
)
|
||||||
worker.start("test")
|
worker.start("test")
|
||||||
|
|
||||||
source = Image.new("RGB", (64, 64), "black")
|
source = Image.new("RGB", (64, 64), "black")
|
||||||
run_upscale_pipeline(
|
run_upscale_pipeline(
|
||||||
worker,
|
worker,
|
||||||
ServerContext(model_path="../models", output_path="../outputs"),
|
ServerContext(model_path="../models", output_path="../outputs"),
|
||||||
ImageParams(
|
ImageParams(
|
||||||
"../models/upscaling-stable-diffusion-x4", "txt2img", "ddim", "an astronaut eating a hamburger", 3.0, 1, 1),
|
"../models/upscaling-stable-diffusion-x4",
|
||||||
Size(256, 256),
|
"txt2img",
|
||||||
["test-upscale.png"],
|
"ddim",
|
||||||
UpscaleParams("test"),
|
"an astronaut eating a hamburger",
|
||||||
HighresParams(False, 1, 0, 0),
|
3.0,
|
||||||
source,
|
1,
|
||||||
)
|
1,
|
||||||
|
),
|
||||||
|
Size(256, 256),
|
||||||
|
["test-upscale.png"],
|
||||||
|
UpscaleParams("test"),
|
||||||
|
HighresParams(False, 1, 0, 0),
|
||||||
|
source,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertTrue(path.exists("../outputs/test-upscale.png"))
|
||||||
|
|
||||||
self.assertTrue(path.exists("../outputs/test-upscale.png"))
|
|
||||||
|
|
||||||
class TestBlendPipeline(unittest.TestCase):
|
class TestBlendPipeline(unittest.TestCase):
|
||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
cancel = Value("L", 0)
|
cancel = Value("L", 0)
|
||||||
logs = Queue()
|
logs = Queue()
|
||||||
pending = Queue()
|
pending = Queue()
|
||||||
progress = Queue()
|
progress = Queue()
|
||||||
active = Value("L", 0)
|
active = Value("L", 0)
|
||||||
idle = Value("L", 0)
|
idle = Value("L", 0)
|
||||||
|
|
||||||
worker = WorkerContext(
|
worker = WorkerContext(
|
||||||
"test",
|
"test",
|
||||||
test_device(),
|
test_device(),
|
||||||
cancel,
|
cancel,
|
||||||
logs,
|
logs,
|
||||||
pending,
|
pending,
|
||||||
progress,
|
progress,
|
||||||
active,
|
active,
|
||||||
idle,
|
idle,
|
||||||
3,
|
3,
|
||||||
0.1,
|
0.1,
|
||||||
)
|
)
|
||||||
worker.start("test")
|
worker.start("test")
|
||||||
|
|
||||||
source = Image.new("RGBA", (64, 64), "black")
|
source = Image.new("RGBA", (64, 64), "black")
|
||||||
mask = Image.new("RGBA", (64, 64), "white")
|
mask = Image.new("RGBA", (64, 64), "white")
|
||||||
run_blend_pipeline(
|
run_blend_pipeline(
|
||||||
worker,
|
worker,
|
||||||
ServerContext(model_path="../models", output_path="../outputs"),
|
ServerContext(model_path="../models", output_path="../outputs"),
|
||||||
ImageParams(
|
ImageParams(
|
||||||
TEST_MODEL_DIFFUSION_SD15, "txt2img", "ddim", "an astronaut eating a hamburger", 3.0, 1, 1),
|
TEST_MODEL_DIFFUSION_SD15,
|
||||||
Size(64, 64),
|
"txt2img",
|
||||||
["test-blend.png"],
|
"ddim",
|
||||||
UpscaleParams("test"),
|
"an astronaut eating a hamburger",
|
||||||
[source, source],
|
3.0,
|
||||||
mask,
|
1,
|
||||||
)
|
1,
|
||||||
|
),
|
||||||
|
Size(64, 64),
|
||||||
|
["test-blend.png"],
|
||||||
|
UpscaleParams("test"),
|
||||||
|
[source, source],
|
||||||
|
mask,
|
||||||
|
)
|
||||||
|
|
||||||
self.assertTrue(path.exists("../outputs/test-blend.png"))
|
self.assertTrue(path.exists("../outputs/test-blend.png"))
|
||||||
|
|
|
@ -10,7 +10,6 @@ from onnx_web.diffusers.utils import (
|
||||||
get_loras_from_prompt,
|
get_loras_from_prompt,
|
||||||
get_scaled_latents,
|
get_scaled_latents,
|
||||||
get_tile_latents,
|
get_tile_latents,
|
||||||
get_tokens_from_prompt,
|
|
||||||
pop_random,
|
pop_random,
|
||||||
slice_prompt,
|
slice_prompt,
|
||||||
)
|
)
|
||||||
|
@ -18,110 +17,128 @@ from onnx_web.params import Size
|
||||||
|
|
||||||
|
|
||||||
class TestExpandIntervalRanges(unittest.TestCase):
|
class TestExpandIntervalRanges(unittest.TestCase):
|
||||||
def test_prompt_with_no_ranges(self):
|
def test_prompt_with_no_ranges(self):
|
||||||
prompt = "an astronaut eating a hamburger"
|
prompt = "an astronaut eating a hamburger"
|
||||||
result = expand_interval_ranges(prompt)
|
result = expand_interval_ranges(prompt)
|
||||||
self.assertEqual(prompt, result)
|
self.assertEqual(prompt, result)
|
||||||
|
|
||||||
|
def test_prompt_with_range(self):
|
||||||
|
prompt = "an astronaut-{1,4} eating a hamburger"
|
||||||
|
result = expand_interval_ranges(prompt)
|
||||||
|
self.assertEqual(
|
||||||
|
result, "an astronaut-1 astronaut-2 astronaut-3 eating a hamburger"
|
||||||
|
)
|
||||||
|
|
||||||
def test_prompt_with_range(self):
|
|
||||||
prompt = "an astronaut-{1,4} eating a hamburger"
|
|
||||||
result = expand_interval_ranges(prompt)
|
|
||||||
self.assertEqual(result, "an astronaut-1 astronaut-2 astronaut-3 eating a hamburger")
|
|
||||||
|
|
||||||
class TestExpandAlternativeRanges(unittest.TestCase):
|
class TestExpandAlternativeRanges(unittest.TestCase):
|
||||||
def test_prompt_with_no_ranges(self):
|
def test_prompt_with_no_ranges(self):
|
||||||
prompt = "an astronaut eating a hamburger"
|
prompt = "an astronaut eating a hamburger"
|
||||||
result = expand_alternative_ranges(prompt)
|
result = expand_alternative_ranges(prompt)
|
||||||
self.assertEqual([prompt], result)
|
self.assertEqual([prompt], result)
|
||||||
|
|
||||||
|
def test_ranges_match(self):
|
||||||
|
prompt = "(an astronaut|a squirrel) eating (a hamburger|an acorn)"
|
||||||
|
result = expand_alternative_ranges(prompt)
|
||||||
|
self.assertEqual(
|
||||||
|
result, ["an astronaut eating a hamburger", "a squirrel eating an acorn"]
|
||||||
|
)
|
||||||
|
|
||||||
def test_ranges_match(self):
|
|
||||||
prompt = "(an astronaut|a squirrel) eating (a hamburger|an acorn)"
|
|
||||||
result = expand_alternative_ranges(prompt)
|
|
||||||
self.assertEqual(result, ["an astronaut eating a hamburger", "a squirrel eating an acorn"])
|
|
||||||
|
|
||||||
class TestInversionsFromPrompt(unittest.TestCase):
|
class TestInversionsFromPrompt(unittest.TestCase):
|
||||||
def test_get_inversions(self):
|
def test_get_inversions(self):
|
||||||
prompt = "<inversion:test:1.0> an astronaut eating an embedding"
|
prompt = "<inversion:test:1.0> an astronaut eating an embedding"
|
||||||
result, tokens = get_inversions_from_prompt(prompt)
|
result, tokens = get_inversions_from_prompt(prompt)
|
||||||
|
|
||||||
|
self.assertEqual(result, " an astronaut eating an embedding")
|
||||||
|
self.assertEqual(tokens, [("test", 1.0)])
|
||||||
|
|
||||||
self.assertEqual(result, " an astronaut eating an embedding")
|
|
||||||
self.assertEqual(tokens, [("test", 1.0)])
|
|
||||||
|
|
||||||
class TestLoRAsFromPrompt(unittest.TestCase):
|
class TestLoRAsFromPrompt(unittest.TestCase):
|
||||||
def test_get_loras(self):
|
def test_get_loras(self):
|
||||||
prompt = "<lora:test:1.0> an astronaut eating a LoRA"
|
prompt = "<lora:test:1.0> an astronaut eating a LoRA"
|
||||||
result, tokens = get_loras_from_prompt(prompt)
|
result, tokens = get_loras_from_prompt(prompt)
|
||||||
|
|
||||||
|
self.assertEqual(result, " an astronaut eating a LoRA")
|
||||||
|
self.assertEqual(tokens, [("test", 1.0)])
|
||||||
|
|
||||||
self.assertEqual(result, " an astronaut eating a LoRA")
|
|
||||||
self.assertEqual(tokens, [("test", 1.0)])
|
|
||||||
|
|
||||||
class TestLatentsFromSeed(unittest.TestCase):
|
class TestLatentsFromSeed(unittest.TestCase):
|
||||||
def test_batch_size(self):
|
def test_batch_size(self):
|
||||||
latents = get_latents_from_seed(1, Size(64, 64), batch=4)
|
latents = get_latents_from_seed(1, Size(64, 64), batch=4)
|
||||||
self.assertEqual(latents.shape, (4, 4, 8, 8))
|
self.assertEqual(latents.shape, (4, 4, 8, 8))
|
||||||
|
|
||||||
|
def test_consistency(self):
|
||||||
|
latents1 = get_latents_from_seed(1, Size(64, 64))
|
||||||
|
latents2 = get_latents_from_seed(1, Size(64, 64))
|
||||||
|
self.assertTrue(np.array_equal(latents1, latents2))
|
||||||
|
|
||||||
def test_consistency(self):
|
|
||||||
latents1 = get_latents_from_seed(1, Size(64, 64))
|
|
||||||
latents2 = get_latents_from_seed(1, Size(64, 64))
|
|
||||||
self.assertTrue(np.array_equal(latents1, latents2))
|
|
||||||
|
|
||||||
class TestTileLatents(unittest.TestCase):
|
class TestTileLatents(unittest.TestCase):
|
||||||
def test_full_tile(self):
|
def test_full_tile(self):
|
||||||
partial = np.zeros((1, 1, 64, 64))
|
partial = np.zeros((1, 1, 64, 64))
|
||||||
full = get_tile_latents(partial, 1, (64, 64), (0, 0, 64))
|
full = get_tile_latents(partial, 1, (64, 64), (0, 0, 64))
|
||||||
self.assertEqual(full.shape, (1, 1, 8, 8))
|
self.assertEqual(full.shape, (1, 1, 8, 8))
|
||||||
|
|
||||||
def test_contract_tile(self):
|
def test_contract_tile(self):
|
||||||
partial = np.zeros((1, 1, 64, 64))
|
partial = np.zeros((1, 1, 64, 64))
|
||||||
full = get_tile_latents(partial, 1, (32, 32), (0, 0, 32))
|
full = get_tile_latents(partial, 1, (32, 32), (0, 0, 32))
|
||||||
self.assertEqual(full.shape, (1, 1, 4, 4))
|
self.assertEqual(full.shape, (1, 1, 4, 4))
|
||||||
|
|
||||||
|
def test_expand_tile(self):
|
||||||
|
partial = np.zeros((1, 1, 32, 32))
|
||||||
|
full = get_tile_latents(partial, 1, (64, 64), (0, 0, 64))
|
||||||
|
self.assertEqual(full.shape, (1, 1, 8, 8))
|
||||||
|
|
||||||
def test_expand_tile(self):
|
|
||||||
partial = np.zeros((1, 1, 32, 32))
|
|
||||||
full = get_tile_latents(partial, 1, (64, 64), (0, 0, 64))
|
|
||||||
self.assertEqual(full.shape, (1, 1, 8, 8))
|
|
||||||
|
|
||||||
class TestScaledLatents(unittest.TestCase):
|
class TestScaledLatents(unittest.TestCase):
|
||||||
def test_scale_up(self):
|
def test_scale_up(self):
|
||||||
latents = get_latents_from_seed(1, Size(16, 16))
|
latents = get_latents_from_seed(1, Size(16, 16))
|
||||||
scaled = get_scaled_latents(1, Size(16, 16), scale=2)
|
scaled = get_scaled_latents(1, Size(16, 16), scale=2)
|
||||||
self.assertEqual(latents[0, 0, 0, 0], scaled[0, 0, 0, 0])
|
self.assertEqual(latents[0, 0, 0, 0], scaled[0, 0, 0, 0])
|
||||||
|
|
||||||
|
def test_scale_down(self):
|
||||||
|
latents = get_latents_from_seed(1, Size(16, 16))
|
||||||
|
scaled = get_scaled_latents(1, Size(16, 16), scale=0.5)
|
||||||
|
self.assertEqual(
|
||||||
|
(
|
||||||
|
latents[0, 0, 0, 0]
|
||||||
|
+ latents[0, 0, 0, 1]
|
||||||
|
+ latents[0, 0, 1, 0]
|
||||||
|
+ latents[0, 0, 1, 1]
|
||||||
|
)
|
||||||
|
/ 4,
|
||||||
|
scaled[0, 0, 0, 0],
|
||||||
|
)
|
||||||
|
|
||||||
def test_scale_down(self):
|
|
||||||
latents = get_latents_from_seed(1, Size(16, 16))
|
|
||||||
scaled = get_scaled_latents(1, Size(16, 16), scale=0.5)
|
|
||||||
self.assertEqual((
|
|
||||||
latents[0, 0, 0, 0] +
|
|
||||||
latents[0, 0, 0, 1] +
|
|
||||||
latents[0, 0, 1, 0] +
|
|
||||||
latents[0, 0, 1, 1]
|
|
||||||
) / 4, scaled[0, 0, 0, 0])
|
|
||||||
|
|
||||||
class TestReplaceWildcards(unittest.TestCase):
|
class TestReplaceWildcards(unittest.TestCase):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class TestPopRandom(unittest.TestCase):
|
class TestPopRandom(unittest.TestCase):
|
||||||
def test_pop(self):
|
def test_pop(self):
|
||||||
items = ["1", "2", "3"]
|
items = ["1", "2", "3"]
|
||||||
pop_random(items)
|
pop_random(items)
|
||||||
self.assertEqual(len(items), 2)
|
self.assertEqual(len(items), 2)
|
||||||
|
|
||||||
|
|
||||||
class TestRepairNaN(unittest.TestCase):
|
class TestRepairNaN(unittest.TestCase):
|
||||||
def test_unchanged(self):
|
def test_unchanged(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def test_missing(self):
|
||||||
|
pass
|
||||||
|
|
||||||
def test_missing(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
class TestSlicePrompt(unittest.TestCase):
|
class TestSlicePrompt(unittest.TestCase):
|
||||||
def test_slice_no_delimiter(self):
|
def test_slice_no_delimiter(self):
|
||||||
slice = slice_prompt("foo", 1)
|
slice = slice_prompt("foo", 1)
|
||||||
self.assertEqual(slice, "foo")
|
self.assertEqual(slice, "foo")
|
||||||
|
|
||||||
def test_slice_within_range(self):
|
def test_slice_within_range(self):
|
||||||
slice = slice_prompt("foo || bar", 1)
|
slice = slice_prompt("foo || bar", 1)
|
||||||
self.assertEqual(slice, " bar")
|
self.assertEqual(slice, " bar")
|
||||||
|
|
||||||
def test_slice_outside_range(self):
|
def test_slice_outside_range(self):
|
||||||
slice = slice_prompt("foo || bar", 9)
|
slice = slice_prompt("foo || bar", 9)
|
||||||
self.assertEqual(slice, " bar")
|
self.assertEqual(slice, " bar")
|
||||||
|
|
|
@ -13,122 +13,128 @@ lock = Event()
|
||||||
|
|
||||||
|
|
||||||
def test_job(*args, **kwargs):
|
def test_job(*args, **kwargs):
|
||||||
lock.wait()
|
lock.wait()
|
||||||
|
|
||||||
|
|
||||||
def wait_job(*args, **kwargs):
|
def wait_job(*args, **kwargs):
|
||||||
sleep(0.5)
|
sleep(0.5)
|
||||||
|
|
||||||
|
|
||||||
class TestWorkerPool(unittest.TestCase):
|
class TestWorkerPool(unittest.TestCase):
|
||||||
# lock: Optional[Event]
|
# lock: Optional[Event]
|
||||||
pool: Optional[DevicePoolExecutor]
|
pool: Optional[DevicePoolExecutor]
|
||||||
|
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.pool = None
|
self.pool = None
|
||||||
|
|
||||||
def tearDown(self) -> None:
|
def tearDown(self) -> None:
|
||||||
if self.pool is not None:
|
if self.pool is not None:
|
||||||
self.pool.join()
|
self.pool.join()
|
||||||
|
|
||||||
def test_no_devices(self):
|
def test_no_devices(self):
|
||||||
server = ServerContext()
|
server = ServerContext()
|
||||||
self.pool = DevicePoolExecutor(server, [], join_timeout=TEST_JOIN_TIMEOUT)
|
self.pool = DevicePoolExecutor(server, [], join_timeout=TEST_JOIN_TIMEOUT)
|
||||||
self.pool.start()
|
self.pool.start()
|
||||||
|
|
||||||
def test_fake_worker(self):
|
def test_fake_worker(self):
|
||||||
device = DeviceParams("cpu", "CPUProvider")
|
device = DeviceParams("cpu", "CPUProvider")
|
||||||
server = ServerContext()
|
server = ServerContext()
|
||||||
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
|
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
|
||||||
self.pool.start()
|
self.pool.start()
|
||||||
self.assertEqual(len(self.pool.workers), 1)
|
self.assertEqual(len(self.pool.workers), 1)
|
||||||
|
|
||||||
def test_cancel_pending(self):
|
def test_cancel_pending(self):
|
||||||
device = DeviceParams("cpu", "CPUProvider")
|
device = DeviceParams("cpu", "CPUProvider")
|
||||||
server = ServerContext()
|
server = ServerContext()
|
||||||
|
|
||||||
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
|
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
|
||||||
self.pool.start()
|
self.pool.start()
|
||||||
|
|
||||||
self.pool.submit("test", wait_job, lock=lock)
|
self.pool.submit("test", wait_job, lock=lock)
|
||||||
self.assertEqual(self.pool.done("test"), (True, None))
|
self.assertEqual(self.pool.done("test"), (True, None))
|
||||||
|
|
||||||
self.assertTrue(self.pool.cancel("test"))
|
self.assertTrue(self.pool.cancel("test"))
|
||||||
self.assertEqual(self.pool.done("test"), (False, None))
|
self.assertEqual(self.pool.done("test"), (False, None))
|
||||||
|
|
||||||
def test_cancel_running(self):
|
def test_cancel_running(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def test_next_device(self):
|
def test_next_device(self):
|
||||||
device = DeviceParams("cpu", "CPUProvider")
|
device = DeviceParams("cpu", "CPUProvider")
|
||||||
server = ServerContext()
|
server = ServerContext()
|
||||||
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
|
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
|
||||||
self.pool.start()
|
self.pool.start()
|
||||||
|
|
||||||
self.assertEqual(self.pool.get_next_device(), 0)
|
self.assertEqual(self.pool.get_next_device(), 0)
|
||||||
|
|
||||||
def test_needs_device(self):
|
def test_needs_device(self):
|
||||||
device1 = DeviceParams("cpu1", "CPUProvider")
|
device1 = DeviceParams("cpu1", "CPUProvider")
|
||||||
device2 = DeviceParams("cpu2", "CPUProvider")
|
device2 = DeviceParams("cpu2", "CPUProvider")
|
||||||
server = ServerContext()
|
server = ServerContext()
|
||||||
self.pool = DevicePoolExecutor(server, [device1, device2], join_timeout=TEST_JOIN_TIMEOUT)
|
self.pool = DevicePoolExecutor(
|
||||||
self.pool.start()
|
server, [device1, device2], join_timeout=TEST_JOIN_TIMEOUT
|
||||||
|
)
|
||||||
|
self.pool.start()
|
||||||
|
|
||||||
self.assertEqual(self.pool.get_next_device(needs_device=device2), 1)
|
self.assertEqual(self.pool.get_next_device(needs_device=device2), 1)
|
||||||
|
|
||||||
def test_done_running(self):
|
def test_done_running(self):
|
||||||
"""
|
"""
|
||||||
TODO: flaky
|
TODO: flaky
|
||||||
"""
|
"""
|
||||||
device = DeviceParams("cpu", "CPUProvider")
|
device = DeviceParams("cpu", "CPUProvider")
|
||||||
server = ServerContext()
|
server = ServerContext()
|
||||||
|
|
||||||
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1)
|
self.pool = DevicePoolExecutor(
|
||||||
self.pool.start(lock)
|
server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1
|
||||||
sleep(2.0)
|
)
|
||||||
|
self.pool.start(lock)
|
||||||
|
sleep(2.0)
|
||||||
|
|
||||||
self.pool.submit("test", test_job)
|
self.pool.submit("test", test_job)
|
||||||
sleep(2.0)
|
sleep(2.0)
|
||||||
|
|
||||||
pending, _progress = self.pool.done("test")
|
pending, _progress = self.pool.done("test")
|
||||||
self.assertFalse(pending)
|
self.assertFalse(pending)
|
||||||
|
|
||||||
def test_done_pending(self):
|
def test_done_pending(self):
|
||||||
device = DeviceParams("cpu", "CPUProvider")
|
device = DeviceParams("cpu", "CPUProvider")
|
||||||
server = ServerContext()
|
server = ServerContext()
|
||||||
|
|
||||||
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
|
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
|
||||||
self.pool.start(lock)
|
self.pool.start(lock)
|
||||||
|
|
||||||
self.pool.submit("test1", test_job)
|
self.pool.submit("test1", test_job)
|
||||||
self.pool.submit("test2", test_job)
|
self.pool.submit("test2", test_job)
|
||||||
self.assertTrue(self.pool.done("test2"), (True, None))
|
self.assertTrue(self.pool.done("test2"), (True, None))
|
||||||
|
|
||||||
lock.set()
|
lock.set()
|
||||||
|
|
||||||
def test_done_finished(self):
|
def test_done_finished(self):
|
||||||
"""
|
"""
|
||||||
TODO: flaky
|
TODO: flaky
|
||||||
"""
|
"""
|
||||||
device = DeviceParams("cpu", "CPUProvider")
|
device = DeviceParams("cpu", "CPUProvider")
|
||||||
server = ServerContext()
|
server = ServerContext()
|
||||||
|
|
||||||
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1)
|
self.pool = DevicePoolExecutor(
|
||||||
self.pool.start()
|
server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1
|
||||||
sleep(2.0)
|
)
|
||||||
|
self.pool.start()
|
||||||
|
sleep(2.0)
|
||||||
|
|
||||||
self.pool.submit("test", wait_job)
|
self.pool.submit("test", wait_job)
|
||||||
self.assertEqual(self.pool.done("test"), (True, None))
|
self.assertEqual(self.pool.done("test"), (True, None))
|
||||||
|
|
||||||
sleep(2.0)
|
sleep(2.0)
|
||||||
pending, _progress = self.pool.done("test")
|
pending, _progress = self.pool.done("test")
|
||||||
self.assertFalse(pending)
|
self.assertFalse(pending)
|
||||||
|
|
||||||
def test_recycle_live(self):
|
def test_recycle_live(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def test_recycle_dead(self):
|
def test_recycle_dead(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def test_running_status(self):
|
def test_running_status(self):
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -18,119 +18,194 @@ from tests.helpers import test_device
|
||||||
|
|
||||||
|
|
||||||
def main_memory(_worker):
|
def main_memory(_worker):
|
||||||
raise Exception(MEMORY_ERRORS[0])
|
raise Exception(MEMORY_ERRORS[0])
|
||||||
|
|
||||||
|
|
||||||
def main_retry(_worker):
|
def main_retry(_worker):
|
||||||
raise RetryException()
|
raise RetryException()
|
||||||
|
|
||||||
|
|
||||||
def main_interrupt(_worker):
|
def main_interrupt(_worker):
|
||||||
raise KeyboardInterrupt()
|
raise KeyboardInterrupt()
|
||||||
|
|
||||||
|
|
||||||
class WorkerMainTests(unittest.TestCase):
|
class WorkerMainTests(unittest.TestCase):
|
||||||
def test_pending_exception_empty(self):
|
def test_pending_exception_empty(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def test_pending_exception_interrupt(self):
|
def test_pending_exception_interrupt(self):
|
||||||
status = None
|
status = None
|
||||||
|
|
||||||
def exit(exit_status):
|
def exit(exit_status):
|
||||||
nonlocal status
|
nonlocal status
|
||||||
status = exit_status
|
status = exit_status
|
||||||
|
|
||||||
job = JobCommand("test", "test", main_interrupt, [], {})
|
job = JobCommand("test", "test", main_interrupt, [], {})
|
||||||
cancel = Value("L", False)
|
cancel = Value("L", False)
|
||||||
logs = Queue()
|
logs = Queue()
|
||||||
pending = Queue()
|
pending = Queue()
|
||||||
progress = Queue()
|
progress = Queue()
|
||||||
pid = Value("L", getpid())
|
pid = Value("L", getpid())
|
||||||
idle = Value("L", False)
|
idle = Value("L", False)
|
||||||
|
|
||||||
pending.put(job)
|
pending.put(job)
|
||||||
worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit)
|
worker_main(
|
||||||
|
WorkerContext(
|
||||||
|
"test",
|
||||||
|
test_device(),
|
||||||
|
cancel,
|
||||||
|
logs,
|
||||||
|
pending,
|
||||||
|
progress,
|
||||||
|
pid,
|
||||||
|
idle,
|
||||||
|
0,
|
||||||
|
0.0,
|
||||||
|
),
|
||||||
|
ServerContext(),
|
||||||
|
exit=exit,
|
||||||
|
)
|
||||||
|
|
||||||
self.assertEqual(status, EXIT_INTERRUPT)
|
self.assertEqual(status, EXIT_INTERRUPT)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def test_pending_exception_retry(self):
|
def test_pending_exception_retry(self):
|
||||||
status = None
|
status = None
|
||||||
|
|
||||||
def exit(exit_status):
|
def exit(exit_status):
|
||||||
nonlocal status
|
nonlocal status
|
||||||
status = exit_status
|
status = exit_status
|
||||||
|
|
||||||
job = JobCommand("test", "test", main_retry, [], {})
|
job = JobCommand("test", "test", main_retry, [], {})
|
||||||
cancel = Value("L", False)
|
cancel = Value("L", False)
|
||||||
logs = Queue()
|
logs = Queue()
|
||||||
pending = Queue()
|
pending = Queue()
|
||||||
progress = Queue()
|
progress = Queue()
|
||||||
pid = Value("L", getpid())
|
pid = Value("L", getpid())
|
||||||
idle = Value("L", False)
|
idle = Value("L", False)
|
||||||
|
|
||||||
pending.put(job)
|
pending.put(job)
|
||||||
worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit)
|
worker_main(
|
||||||
|
WorkerContext(
|
||||||
|
"test",
|
||||||
|
test_device(),
|
||||||
|
cancel,
|
||||||
|
logs,
|
||||||
|
pending,
|
||||||
|
progress,
|
||||||
|
pid,
|
||||||
|
idle,
|
||||||
|
0,
|
||||||
|
0.0,
|
||||||
|
),
|
||||||
|
ServerContext(),
|
||||||
|
exit=exit,
|
||||||
|
)
|
||||||
|
|
||||||
self.assertEqual(status, EXIT_ERROR)
|
self.assertEqual(status, EXIT_ERROR)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def test_pending_exception_value(self):
|
def test_pending_exception_value(self):
|
||||||
status = None
|
status = None
|
||||||
|
|
||||||
def exit(exit_status):
|
def exit(exit_status):
|
||||||
nonlocal status
|
nonlocal status
|
||||||
status = exit_status
|
status = exit_status
|
||||||
|
|
||||||
cancel = Value("L", False)
|
cancel = Value("L", False)
|
||||||
logs = Queue()
|
logs = Queue()
|
||||||
pending = Queue()
|
pending = Queue()
|
||||||
progress = Queue()
|
progress = Queue()
|
||||||
pid = Value("L", getpid())
|
pid = Value("L", getpid())
|
||||||
idle = Value("L", False)
|
idle = Value("L", False)
|
||||||
|
|
||||||
pending.close()
|
pending.close()
|
||||||
worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit)
|
worker_main(
|
||||||
|
WorkerContext(
|
||||||
|
"test",
|
||||||
|
test_device(),
|
||||||
|
cancel,
|
||||||
|
logs,
|
||||||
|
pending,
|
||||||
|
progress,
|
||||||
|
pid,
|
||||||
|
idle,
|
||||||
|
0,
|
||||||
|
0.0,
|
||||||
|
),
|
||||||
|
ServerContext(),
|
||||||
|
exit=exit,
|
||||||
|
)
|
||||||
|
|
||||||
self.assertEqual(status, EXIT_ERROR)
|
self.assertEqual(status, EXIT_ERROR)
|
||||||
|
|
||||||
def test_pending_exception_other_memory(self):
|
def test_pending_exception_other_memory(self):
|
||||||
status = None
|
status = None
|
||||||
|
|
||||||
def exit(exit_status):
|
def exit(exit_status):
|
||||||
nonlocal status
|
nonlocal status
|
||||||
status = exit_status
|
status = exit_status
|
||||||
|
|
||||||
job = JobCommand("test", "test", main_memory, [], {})
|
job = JobCommand("test", "test", main_memory, [], {})
|
||||||
cancel = Value("L", False)
|
cancel = Value("L", False)
|
||||||
logs = Queue()
|
logs = Queue()
|
||||||
pending = Queue()
|
pending = Queue()
|
||||||
progress = Queue()
|
progress = Queue()
|
||||||
pid = Value("L", getpid())
|
pid = Value("L", getpid())
|
||||||
idle = Value("L", False)
|
idle = Value("L", False)
|
||||||
|
|
||||||
pending.put(job)
|
pending.put(job)
|
||||||
worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit)
|
worker_main(
|
||||||
|
WorkerContext(
|
||||||
|
"test",
|
||||||
|
test_device(),
|
||||||
|
cancel,
|
||||||
|
logs,
|
||||||
|
pending,
|
||||||
|
progress,
|
||||||
|
pid,
|
||||||
|
idle,
|
||||||
|
0,
|
||||||
|
0.0,
|
||||||
|
),
|
||||||
|
ServerContext(),
|
||||||
|
exit=exit,
|
||||||
|
)
|
||||||
|
|
||||||
self.assertEqual(status, EXIT_MEMORY)
|
self.assertEqual(status, EXIT_MEMORY)
|
||||||
|
|
||||||
|
def test_pending_exception_other_unknown(self):
|
||||||
|
pass
|
||||||
|
|
||||||
def test_pending_exception_other_unknown(self):
|
def test_pending_replaced(self):
|
||||||
pass
|
status = None
|
||||||
|
|
||||||
def test_pending_replaced(self):
|
def exit(exit_status):
|
||||||
status = None
|
nonlocal status
|
||||||
|
status = exit_status
|
||||||
|
|
||||||
def exit(exit_status):
|
cancel = Value("L", False)
|
||||||
nonlocal status
|
logs = Queue()
|
||||||
status = exit_status
|
pending = Queue()
|
||||||
|
progress = Queue()
|
||||||
|
pid = Value("L", 0)
|
||||||
|
idle = Value("L", False)
|
||||||
|
|
||||||
cancel = Value("L", False)
|
worker_main(
|
||||||
logs = Queue()
|
WorkerContext(
|
||||||
pending = Queue()
|
"test",
|
||||||
progress = Queue()
|
test_device(),
|
||||||
pid = Value("L", 0)
|
cancel,
|
||||||
idle = Value("L", False)
|
logs,
|
||||||
|
pending,
|
||||||
worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit)
|
progress,
|
||||||
|
pid,
|
||||||
self.assertEqual(status, EXIT_REPLACED)
|
idle,
|
||||||
|
0,
|
||||||
|
0.0,
|
||||||
|
),
|
||||||
|
ServerContext(),
|
||||||
|
exit=exit,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(status, EXIT_REPLACED)
|
||||||
|
|
Loading…
Reference in New Issue