1
0
Fork 0

apply lint to tests, test highres

This commit is contained in:
Sean Sube 2023-11-19 23:18:57 -06:00
parent 4691e80744
commit 65912c5a4a
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
26 changed files with 1506 additions and 1127 deletions

View File

@ -33,13 +33,17 @@ package-upload:
lint-check:
black --check onnx_web/
black --check tests/
flake8 onnx_web
flake8 tests
isort --check-only --skip __init__.py --filter-files onnx_web
isort --check-only --skip __init__.py --filter-files tests
lint-fix:
black onnx_web/
black tests/
flake8 onnx_web
flake8 tests
isort --skip __init__.py --filter-files onnx_web
isort --skip __init__.py --filter-files tests

View File

@ -19,6 +19,14 @@ class StageResult:
def empty():
return StageResult(images=[])
@staticmethod
def from_arrays(arrays: List[np.ndarray]):
return StageResult(arrays=arrays)
@staticmethod
def from_images(images: List[Image.Image]):
return StageResult(images=images)
def __init__(self, arrays=None, images=None) -> None:
if arrays is not None and images is not None:
raise ValueError("stages must only return one type of result")

View File

@ -25,9 +25,7 @@ class TileCallback(Protocol):
Definition for a tile job function.
"""
def __call__(
self, image: Image.Image, dims: Tuple[int, int, int]
) -> StageResult:
def __call__(self, image: Image.Image, dims: Tuple[int, int, int]) -> StageResult:
"""
Run this stage against a single tile.
"""
@ -319,6 +317,9 @@ def process_tile_stack(
if mask:
tile_mask = mask.crop((left, top, right, bottom))
if isinstance(tile_stack, list):
tile_stack = StageResult.from_images(tile_stack)
for image_filter in filters:
tile_stack = image_filter(tile_stack, tile_mask, (left, top, tile))

View File

@ -48,6 +48,7 @@ def main():
# debug options
if server.debug:
import debugpy
debugpy.listen(5678)
logger.warning("waiting for debugger")
debugpy.wait_for_client()

View File

@ -9,13 +9,15 @@ from onnx_web.chain.result import StageResult
class BlendGridStageTests(unittest.TestCase):
def test_stage(self):
stage = BlendGridStage()
sources = StageResult(images=[
sources = StageResult(
images=[
Image.new("RGB", (64, 64), "black"),
Image.new("RGB", (64, 64), "white"),
Image.new("RGB", (64, 64), "black"),
Image.new("RGB", (64, 64), "white"),
])
]
)
result = stage.run(None, None, None, None, sources, height=2, width=2)
self.assertEqual(len(result), 5)
self.assertEqual(result.as_image()[-1].getpixel((0,0)), (0, 0, 0))
self.assertEqual(result.as_image()[-1].getpixel((0, 0)), (0, 0, 0))

View File

@ -6,21 +6,38 @@ from onnx_web.chain.blend_img2img import BlendImg2ImgStage
from onnx_web.params import DeviceParams, ImageParams
from onnx_web.server.context import ServerContext
from onnx_web.worker.context import WorkerContext
from tests.helpers import TEST_MODEL_DIFFUSION_SD15, test_needs_models
class BlendImg2ImgStageTests(unittest.TestCase):
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
def test_stage(self):
"""
stage = BlendImg2ImgStage()
params = ImageParams("runwayml/stable-diffusion-v1-5", "txt2img", "euler-a", "an astronaut eating a hamburger", 3.0, 1, 1)
server = ServerContext()
worker = WorkerContext("test", DeviceParams("cpu", "CPUProvider"), None, None, None, None, None, None, 0)
params = ImageParams(
TEST_MODEL_DIFFUSION_SD15,
"txt2img",
"euler-a",
"an astronaut eating a hamburger",
3.0,
1,
1,
)
server = ServerContext(model_path="../models", output_path="../outputs")
worker = WorkerContext(
"test",
DeviceParams("cpu", "CPUProvider"),
None,
None,
None,
None,
None,
None,
0,
)
sources = [
Image.new("RGB", (64, 64), "black"),
]
result = stage.run(worker, server, None, params, sources, strength=0.5, steps=1)
self.assertEqual(len(result), 1)
self.assertEqual(result[0].getpixel((0,0)), (127, 127, 127))
"""
pass
self.assertEqual(result[0].getpixel((0, 0)), (127, 127, 127))

View File

@ -9,11 +9,15 @@ from onnx_web.chain.result import StageResult
class BlendLinearStageTests(unittest.TestCase):
def test_stage(self):
stage = BlendLinearStage()
sources = StageResult(images=[
sources = StageResult(
images=[
Image.new("RGB", (64, 64), "black"),
])
]
)
stage_source = Image.new("RGB", (64, 64), "white")
result = stage.run(None, None, None, None, sources, alpha=0.5, stage_source=stage_source)
result = stage.run(
None, None, None, None, sources, alpha=0.5, stage_source=stage_source
)
self.assertEqual(len(result), 1)
self.assertEqual(result.as_image()[0].getpixel((0,0)), (127, 127, 127))
self.assertEqual(result.as_image()[0].getpixel((0, 0)), (127, 127, 127))

View File

@ -2,6 +2,7 @@ import unittest
from PIL import Image
from onnx_web.chain.result import StageResult
from onnx_web.chain.tile import (
complete_tile,
generate_tile_grid,
@ -48,7 +49,7 @@ class TestNeedsTile(unittest.TestCase):
self.assertFalse(needs_tile(64, 64, size=small))
def test_with_oversized_source(self):
def test_with_oversized_size(self):
large = Size(64, 64)
self.assertTrue(needs_tile(32, 32, size=large))
@ -82,21 +83,21 @@ class TestGenerateTileGrid(unittest.TestCase):
tiles = generate_tile_grid(16, 16, 8, 0.0)
self.assertEqual(len(tiles), 4)
self.assertEqual(tiles, [(0, 0), (8, 0), (8, 8), (0, 8)])
self.assertEqual(tiles, [(0, 0), (8, 0), (0, 8), (8, 8)])
def test_grid_no_overlap(self):
tiles = generate_tile_grid(64, 64, 8, 0.0)
self.assertEqual(len(tiles), 64)
self.assertEqual(tiles[0:4], [(0, 0), (8, 0), (16, 0), (24, 0)])
self.assertEqual(tiles[-5:-1], [(16, 24), (24, 24), (32, 24), (32, 32)])
self.assertEqual(tiles[-5:-1], [(24, 56), (32, 56), (40, 56), (48, 56)])
def test_grid_50_overlap(self):
tiles = generate_tile_grid(64, 64, 8, 0.5)
self.assertEqual(len(tiles), 225)
self.assertEqual(len(tiles), 256)
self.assertEqual(tiles[0:4], [(0, 0), (4, 0), (8, 0), (12, 0)])
self.assertEqual(tiles[-5:-1], [(32, 32), (28, 32), (24, 32), (24, 28)])
self.assertEqual(tiles[-5:-1], [(44, 60), (48, 60), (52, 60), (56, 60)])
class TestGenerateTileSpiral(unittest.TestCase):
@ -124,12 +125,16 @@ class TestGenerateTileSpiral(unittest.TestCase):
class TestProcessTileStack(unittest.TestCase):
def test_grid_full(self):
source = Image.new("RGB", (64, 64))
blend = process_tile_stack(source, 32, 1, [])
blend = process_tile_stack(
StageResult(images=[source]), 32, 1, [], generate_tile_grid
)
self.assertEqual(blend.size, (64, 64))
self.assertEqual(blend[0].size, (64, 64))
def test_grid_partial(self):
source = Image.new("RGB", (72, 72))
blend = process_tile_stack(source, 32, 1, [])
blend = process_tile_stack(
StageResult(images=[source]), 32, 1, [], generate_tile_grid
)
self.assertEqual(blend.size, (72, 72))
self.assertEqual(blend[0].size, (72, 72))

View File

@ -9,6 +9,14 @@ class UpscaleHighresStageTests(unittest.TestCase):
def test_empty(self):
stage = UpscaleHighresStage()
sources = StageResult.empty()
result = stage.run(None, None, None, None, sources, highres=HighresParams(False,1, 0, 0), upscale=UpscaleParams(""))
result = stage.run(
None,
None,
None,
None,
sources,
highres=HighresParams(False, 1, 0, 0),
upscale=UpscaleParams(""),
)
self.assertEqual(len(result), 0)

View File

@ -6,7 +6,6 @@ from onnx import GraphProto, ModelProto, NodeProto
from onnx.numpy_helper import from_array
from onnx_web.convert.diffusion.lora import (
blend_loras,
blend_node_conv_gemm,
blend_node_matmul,
blend_weights_loha,
@ -33,7 +32,6 @@ class SumWeightsTests(unittest.TestCase):
weights = sum_weights(np.zeros((4, 4)), np.ones((4, 4, 1, 1)))
self.assertEqual(weights.shape, (4, 4, 1, 1))
def test_3x3_kernel(self):
"""
weights = sum_weights(np.zeros((4, 4, 3, 3)), np.ones((4, 4)))
@ -53,14 +51,20 @@ class BufferExternalDataTensorTests(unittest.TestCase):
)
(slim_model, external_weights) = buffer_external_data_tensors(model)
self.assertEqual(len(slim_model.graph.initializer), len(model.graph.initializer))
self.assertEqual(
len(slim_model.graph.initializer), len(model.graph.initializer)
)
self.assertEqual(len(external_weights), 1)
class FixInitializerKeyTests(unittest.TestCase):
def test_fix_name(self):
inputs = ["lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0.lora_down.weight"]
outputs = ["lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0_lora_down_weight"]
inputs = [
"lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0.lora_down.weight"
]
outputs = [
"lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0_lora_down_weight"
]
for input, output in zip(inputs, outputs):
self.assertEqual(fix_initializer_name(input), output)
@ -92,25 +96,37 @@ class FixXLNameTests(unittest.TestCase):
nodes = {
"input_block_proj.lora_down.weight": {},
}
fixed = fix_xl_names(nodes, [
fixed = fix_xl_names(
nodes,
[
NodeProto(name="/down_blocks_proj/MatMul"),
])
],
)
self.assertEqual(fixed, {
self.assertEqual(
fixed,
{
"down_blocks_proj": nodes["input_block_proj.lora_down.weight"],
})
},
)
def test_middle_block(self):
nodes = {
"middle_block_proj.lora_down.weight": {},
}
fixed = fix_xl_names(nodes, [
fixed = fix_xl_names(
nodes,
[
NodeProto(name="/mid_blocks_proj/MatMul"),
])
],
)
self.assertEqual(fixed, {
self.assertEqual(
fixed,
{
"mid_blocks_proj": nodes["middle_block_proj.lora_down.weight"],
})
},
)
def test_output_block(self):
pass
@ -133,13 +149,19 @@ class FixXLNameTests(unittest.TestCase):
nodes = {
"output_block_proj_out.lora_down.weight": {},
}
fixed = fix_xl_names(nodes, [
fixed = fix_xl_names(
nodes,
[
NodeProto(name="/up_blocks_proj_out/MatMul"),
])
],
)
self.assertEqual(fixed, {
self.assertEqual(
fixed,
{
"up_blocks_proj_out": nodes["output_block_proj_out.lora_down.weight"],
})
},
)
class KernelSliceTests(unittest.TestCase):
@ -250,6 +272,7 @@ class BlendWeightsLoHATests(unittest.TestCase):
self.assertEqual(result.shape, (4, 4))
"""
class BlendWeightsLoRATests(unittest.TestCase):
def test_blend_kernel_none(self):
model = {
@ -260,7 +283,6 @@ class BlendWeightsLoRATests(unittest.TestCase):
key, result = blend_weights_lora("foo.lora_down", "", model, torch.float32)
self.assertEqual(result.shape, (4, 4))
def test_blend_kernel_1x1(self):
model = {
"foo.lora_down": torch.from_numpy(np.ones((1, 4, 1, 1))),

View File

@ -10,7 +10,6 @@ from onnx_web.convert.diffusion.textual_inversion import (
blend_embedding_embeddings,
blend_embedding_node,
blend_embedding_parameters,
blend_textual_inversions,
detect_embedding_format,
)
@ -59,9 +58,15 @@ class BlendEmbeddingConceptTests(unittest.TestCase):
embeds = {
"test": np.ones(TEST_DIMS),
}
blend_embedding_concept(embeds, {
blend_embedding_concept(
embeds,
{
"<test>": torch.from_numpy(np.ones(TEST_DIMS)),
}, np.float32, "test", 1.0)
},
np.float32,
"test",
1.0,
)
self.assertIn("test", embeds)
self.assertEqual(embeds["test"].shape, TEST_DIMS)
@ -69,9 +74,15 @@ class BlendEmbeddingConceptTests(unittest.TestCase):
def test_missing_base_token(self):
embeds = {}
blend_embedding_concept(embeds, {
blend_embedding_concept(
embeds,
{
"<test>": torch.from_numpy(np.ones(TEST_DIMS)),
}, np.float32, "test", 1.0)
},
np.float32,
"test",
1.0,
)
self.assertIn("test", embeds)
self.assertEqual(embeds["test"].shape, TEST_DIMS)
@ -80,9 +91,15 @@ class BlendEmbeddingConceptTests(unittest.TestCase):
embeds = {
"<test>": np.ones(TEST_DIMS),
}
blend_embedding_concept(embeds, {
blend_embedding_concept(
embeds,
{
"<test>": torch.from_numpy(np.ones(TEST_DIMS)),
}, np.float32, "test", 1.0)
},
np.float32,
"test",
1.0,
)
keys = list(embeds.keys())
keys.sort()
@ -92,9 +109,15 @@ class BlendEmbeddingConceptTests(unittest.TestCase):
def test_missing_token(self):
embeds = {}
blend_embedding_concept(embeds, {
blend_embedding_concept(
embeds,
{
"<test>": torch.from_numpy(np.ones(TEST_DIMS)),
}, np.float32, "test", 1.0)
},
np.float32,
"test",
1.0,
)
keys = list(embeds.keys())
keys.sort()
@ -108,9 +131,15 @@ class BlendEmbeddingParametersTests(unittest.TestCase):
embeds = {
"test": np.ones(TEST_DIMS),
}
blend_embedding_parameters(embeds, {
blend_embedding_parameters(
embeds,
{
"emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)),
}, np.float32, "test", 1.0)
},
np.float32,
"test",
1.0,
)
self.assertIn("test", embeds)
self.assertEqual(embeds["test"].shape, TEST_DIMS)
@ -118,9 +147,15 @@ class BlendEmbeddingParametersTests(unittest.TestCase):
def test_missing_base_token(self):
embeds = {}
blend_embedding_parameters(embeds, {
blend_embedding_parameters(
embeds,
{
"emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)),
}, np.float32, "test", 1.0)
},
np.float32,
"test",
1.0,
)
self.assertIn("test", embeds)
self.assertEqual(embeds["test"].shape, TEST_DIMS)
@ -129,9 +164,15 @@ class BlendEmbeddingParametersTests(unittest.TestCase):
embeds = {
"test": np.ones(TEST_DIMS_EMBEDS),
}
blend_embedding_parameters(embeds, {
blend_embedding_parameters(
embeds,
{
"emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)),
}, np.float32, "test", 1.0)
},
np.float32,
"test",
1.0,
)
keys = list(embeds.keys())
keys.sort()
@ -141,9 +182,15 @@ class BlendEmbeddingParametersTests(unittest.TestCase):
def test_missing_token(self):
embeds = {}
blend_embedding_parameters(embeds, {
blend_embedding_parameters(
embeds,
{
"emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)),
}, np.float32, "test", 1.0)
},
np.float32,
"test",
1.0,
)
keys = list(embeds.keys())
keys.sort()
@ -198,14 +245,23 @@ class BlendEmbeddingNodeTests(unittest.TestCase):
weights = from_array(np.ones(TEST_DIMS))
weights.name = "text_model.embeddings.token_embedding.weight"
model = ModelProto(graph=GraphProto(initializer=[
model = ModelProto(
graph=GraphProto(
initializer=[
weights,
]))
]
)
)
embeds = {}
blend_embedding_node(model, {
'convert_tokens_to_ids': lambda t: t,
}, embeds, 2)
blend_embedding_node(
model,
{
"convert_tokens_to_ids": lambda t: t,
},
embeds,
2,
)
result = to_array(model.graph.initializer[0])

View File

@ -13,7 +13,6 @@ from onnx_web.convert.utils import (
tuple_to_upscaling,
)
from tests.helpers import (
TEST_MODEL_DIFFUSION_SD15,
TEST_MODEL_UPSCALING_SWINIR,
test_needs_models,
)
@ -194,21 +193,23 @@ class TupleToUpscalingTests(unittest.TestCase):
class SourceFormatTests(unittest.TestCase):
def test_with_format(self):
result = source_format({
result = source_format(
{
"format": "foo",
})
}
)
self.assertEqual(result, "foo")
def test_source_known_extension(self):
result = source_format({
result = source_format(
{
"source": "foo.safetensors",
})
}
)
self.assertEqual(result, "safetensors")
def test_source_unknown_extension(self):
result = source_format({
"source": "foo.none"
})
result = source_format({"source": "foo.none"})
self.assertEqual(result, None)
def test_incomplete_model(self):
@ -234,7 +235,10 @@ class LoadTensorTests(unittest.TestCase):
class ResolveTensorTests(unittest.TestCase):
@test_needs_models([TEST_MODEL_UPSCALING_SWINIR])
def test_resolve_existing(self):
self.assertEqual(resolve_tensor("../models/.cache/upscaling-swinir"), TEST_MODEL_UPSCALING_SWINIR)
self.assertEqual(
resolve_tensor("../models/.cache/upscaling-swinir"),
TEST_MODEL_UPSCALING_SWINIR,
)
def test_resolve_missing(self):
self.assertIsNone(resolve_tensor("missing"))

View File

@ -6,7 +6,9 @@ from onnx_web.params import DeviceParams
def test_needs_models(models: List[str]):
return skipUnless(all([path.exists(model) for model in models]), "model does not exist")
return skipUnless(
all([path.exists(model) for model in models]), "model does not exist"
)
def test_device() -> DeviceParams:

View File

@ -1,7 +1,7 @@
from typing import Any, Optional
class MockPipeline():
class MockPipeline:
# flags
slice_size: Optional[str]
vae_slicing: Optional[bool]

View File

@ -13,7 +13,7 @@ class ParserTests(unittest.TestCase):
str(["foo"]),
str(PromptPhrase(["bar"], weight=1.5)),
str(["bin"]),
]
],
)
def test_multi_word_phrase(self):
@ -24,7 +24,7 @@ class ParserTests(unittest.TestCase):
str(["foo", "bar"]),
str(PromptPhrase(["middle", "words"], weight=1.5)),
str(["bin", "bun"]),
]
],
)
def test_nested_phrase(self):
@ -33,7 +33,7 @@ class ParserTests(unittest.TestCase):
[str(i) for i in res],
[
str(["foo"]),
str(PromptPhrase(["bar"], weight=(1.5 ** 3))),
str(PromptPhrase(["bar"], weight=(1.5**3))),
str(["bin"]),
]
],
)

View File

@ -25,71 +25,85 @@ class ConfigParamTests(unittest.TestCase):
params = get_config_params()
self.assertIsNotNone(params)
class AvailablePlatformTests(unittest.TestCase):
def test_before_setup(self):
platforms = get_available_platforms()
self.assertIsNotNone(platforms)
class CorrectModelTests(unittest.TestCase):
def test_before_setup(self):
models = get_correction_models()
self.assertIsNotNone(models)
class DiffusionModelTests(unittest.TestCase):
def test_before_setup(self):
models = get_diffusion_models()
self.assertIsNotNone(models)
class NetworkModelTests(unittest.TestCase):
def test_before_setup(self):
models = get_network_models()
self.assertIsNotNone(models)
class UpscalingModelTests(unittest.TestCase):
def test_before_setup(self):
models = get_upscaling_models()
self.assertIsNotNone(models)
class WildcardDataTests(unittest.TestCase):
def test_before_setup(self):
wildcards = get_wildcard_data()
self.assertIsNotNone(wildcards)
class ExtraStringsTests(unittest.TestCase):
def test_before_setup(self):
strings = get_extra_strings()
self.assertIsNotNone(strings)
class ExtraHashesTests(unittest.TestCase):
def test_before_setup(self):
hashes = get_extra_hashes()
self.assertIsNotNone(hashes)
class HighresMethodTests(unittest.TestCase):
def test_before_setup(self):
methods = get_highres_methods()
self.assertIsNotNone(methods)
class MaskFilterTests(unittest.TestCase):
def test_before_setup(self):
filters = get_mask_filters()
self.assertIsNotNone(filters)
class NoiseSourceTests(unittest.TestCase):
def test_before_setup(self):
sources = get_noise_sources()
self.assertIsNotNone(sources)
class SourceFilterTests(unittest.TestCase):
def test_before_setup(self):
filters = get_source_filters()
self.assertIsNotNone(filters)
class LoadExtrasTests(unittest.TestCase):
def test_default_extras(self):
server = ServerContext(extra_models=["../models/extras.json"])
load_extras(server)
class LoadModelsTests(unittest.TestCase):
def test_default_models(self):
server = ServerContext(model_path="../models")

View File

@ -122,25 +122,42 @@ class TestPatchPipeline(unittest.TestCase):
def test_unet_wrapper_not_xl(self):
server = ServerContext()
pipeline = MockPipeline()
patch_pipeline(server, pipeline, None, ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1))
patch_pipeline(
server,
pipeline,
None,
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
)
self.assertTrue(isinstance(pipeline.unet, UNetWrapper))
def test_unet_wrapper_xl(self):
server = ServerContext()
pipeline = MockPipeline()
patch_pipeline(server, pipeline, None, ImageParams("test", "txt2img-sdxl", "ddim", "test", 1.0, 10, 1))
patch_pipeline(
server,
pipeline,
None,
ImageParams("test", "txt2img-sdxl", "ddim", "test", 1.0, 10, 1),
)
self.assertTrue(isinstance(pipeline.unet, UNetWrapper))
def test_vae_wrapper(self):
server = ServerContext()
pipeline = MockPipeline()
patch_pipeline(server, pipeline, None, ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1))
patch_pipeline(
server,
pipeline,
None,
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
)
self.assertTrue(isinstance(pipeline.vae_decoder, VAEWrapper))
self.assertTrue(isinstance(pipeline.vae_encoder, VAEWrapper))
class TestLoadControlNet(unittest.TestCase):
@unittest.skipUnless(path.exists("../models/control/canny.onnx"), "model does not exist")
@unittest.skipUnless(
path.exists("../models/control/canny.onnx"), "model does not exist"
)
def test_load_existing(self):
"""
Should load a model
@ -148,7 +165,16 @@ class TestLoadControlNet(unittest.TestCase):
components = load_controlnet(
ServerContext(model_path="../models"),
DeviceParams("cpu", "CPUExecutionProvider"),
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1, control=NetworkModel("canny", "control")),
ImageParams(
"test",
"txt2img",
"ddim",
"test",
1.0,
10,
1,
control=NetworkModel("canny", "control"),
),
)
self.assertIn("controlnet", components)
@ -161,9 +187,18 @@ class TestLoadControlNet(unittest.TestCase):
components = load_controlnet(
ServerContext(),
DeviceParams("cpu", "CPUExecutionProvider"),
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1, control=NetworkModel("missing", "control")),
ImageParams(
"test",
"txt2img",
"ddim",
"test",
1.0,
10,
1,
control=NetworkModel("missing", "control"),
),
)
except:
except Exception:
self.assertNotIn("controlnet", components)
return
@ -171,7 +206,10 @@ class TestLoadControlNet(unittest.TestCase):
class TestLoadTextEncoders(unittest.TestCase):
@unittest.skipUnless(path.exists("../models/stable-diffusion-onnx-v1-5/text_encoder/model.onnx"), "model does not exist")
@unittest.skipUnless(
path.exists("../models/stable-diffusion-onnx-v1-5/text_encoder/model.onnx"),
"model does not exist",
)
def test_load_embeddings(self):
"""
Should add the token to tokenizer
@ -193,7 +231,10 @@ class TestLoadTextEncoders(unittest.TestCase):
def test_load_embeddings_xl(self):
pass
@unittest.skipUnless(path.exists("../models/stable-diffusion-onnx-v1-5/text_encoder/model.onnx"), "model does not exist")
@unittest.skipUnless(
path.exists("../models/stable-diffusion-onnx-v1-5/text_encoder/model.onnx"),
"model does not exist",
)
def test_load_loras(self):
components = load_text_encoders(
ServerContext(model_path="../models"),
@ -211,8 +252,12 @@ class TestLoadTextEncoders(unittest.TestCase):
def test_load_loras_xl(self):
pass
class TestLoadUnet(unittest.TestCase):
@unittest.skipUnless(path.exists("../models/stable-diffusion-onnx-v1-5/unet/model.onnx"), "model does not exist")
@unittest.skipUnless(
path.exists("../models/stable-diffusion-onnx-v1-5/unet/model.onnx"),
"model does not exist",
)
def test_load_unet_loras(self):
components = load_unet(
ServerContext(model_path="../models"),
@ -229,7 +274,10 @@ class TestLoadUnet(unittest.TestCase):
def test_load_unet_loras_xl(self):
pass
@unittest.skipUnless(path.exists("../models/stable-diffusion-onnx-v1-5/cnet/model.onnx"), "model does not exist")
@unittest.skipUnless(
path.exists("../models/stable-diffusion-onnx-v1-5/cnet/model.onnx"),
"model does not exist",
)
def test_load_cnet_loras(self):
components = load_unet(
ServerContext(model_path="../models"),
@ -245,7 +293,10 @@ class TestLoadUnet(unittest.TestCase):
class TestLoadVae(unittest.TestCase):
@unittest.skipUnless(path.exists("../models/upscaling-stable-diffusion-x4/vae/model.onnx"), "model does not exist")
@unittest.skipUnless(
path.exists("../models/upscaling-stable-diffusion-x4/vae/model.onnx"),
"model does not exist",
)
def test_load_single(self):
"""
Should return single component
@ -260,7 +311,10 @@ class TestLoadVae(unittest.TestCase):
self.assertNotIn("vae_decoder", components)
self.assertNotIn("vae_encoder", components)
@unittest.skipUnless(path.exists("../models/stable-diffusion-onnx-v1-5/vae_encoder/model.onnx"), "model does not exist")
@unittest.skipUnless(
path.exists("../models/stable-diffusion-onnx-v1-5/vae_encoder/model.onnx"),
"model does not exist",
)
def test_load_split(self):
"""
Should return split encoder/decoder

View File

@ -44,14 +44,70 @@ class TestTxt2ImgPipeline(unittest.TestCase):
worker,
ServerContext(model_path="../models", output_path="../outputs"),
ImageParams(
TEST_MODEL_DIFFUSION_SD15, "txt2img", "ddim", "an astronaut eating a hamburger", 3.0, 1, 1),
TEST_MODEL_DIFFUSION_SD15,
"txt2img",
"ddim",
"an astronaut eating a hamburger",
3.0,
1,
1,
),
Size(256, 256),
["test-txt2img.png"],
["test-txt2img-basic.png"],
UpscaleParams("test"),
HighresParams(False, 1, 0, 0),
)
self.assertTrue(path.exists("../outputs/test-txt2img.png"))
self.assertTrue(path.exists("../outputs/test-txt2img-basic.png"))
output = Image.open("../outputs/test-txt2img-basic.png")
self.assertEqual(output.size, (256, 256))
# TODO: test contents of image
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
def test_highres(self):
cancel = Value("L", 0)
logs = Queue()
pending = Queue()
progress = Queue()
active = Value("L", 0)
idle = Value("L", 0)
worker = WorkerContext(
"test",
test_device(),
cancel,
logs,
pending,
progress,
active,
idle,
3,
0.1,
)
worker.start("test")
run_txt2img_pipeline(
worker,
ServerContext(model_path="../models", output_path="../outputs"),
ImageParams(
TEST_MODEL_DIFFUSION_SD15,
"txt2img",
"ddim",
"an astronaut eating a hamburger",
3.0,
1,
1,
),
Size(256, 256),
["test-txt2img-highres.png"],
UpscaleParams("test"),
HighresParams(True, 2, 0, 0),
)
self.assertTrue(path.exists("../outputs/test-txt2img-highres.png"))
output = Image.open("../outputs/test-txt2img-highres.png")
self.assertEqual(output.size, (512, 512))
class TestImg2ImgPipeline(unittest.TestCase):
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
@ -82,7 +138,14 @@ class TestImg2ImgPipeline(unittest.TestCase):
worker,
ServerContext(model_path="../models", output_path="../outputs"),
ImageParams(
TEST_MODEL_DIFFUSION_SD15, "txt2img", "ddim", "an astronaut eating a hamburger", 3.0, 1, 1),
TEST_MODEL_DIFFUSION_SD15,
"txt2img",
"ddim",
"an astronaut eating a hamburger",
3.0,
1,
1,
),
["test-img2img.png"],
UpscaleParams("test"),
HighresParams(False, 1, 0, 0),
@ -92,6 +155,7 @@ class TestImg2ImgPipeline(unittest.TestCase):
self.assertTrue(path.exists("../outputs/test-img2img.png"))
class TestUpscalePipeline(unittest.TestCase):
@test_needs_models(["../models/upscaling-stable-diffusion-x4"])
def test_basic(self):
@ -121,7 +185,14 @@ class TestUpscalePipeline(unittest.TestCase):
worker,
ServerContext(model_path="../models", output_path="../outputs"),
ImageParams(
"../models/upscaling-stable-diffusion-x4", "txt2img", "ddim", "an astronaut eating a hamburger", 3.0, 1, 1),
"../models/upscaling-stable-diffusion-x4",
"txt2img",
"ddim",
"an astronaut eating a hamburger",
3.0,
1,
1,
),
Size(256, 256),
["test-upscale.png"],
UpscaleParams("test"),
@ -131,6 +202,7 @@ class TestUpscalePipeline(unittest.TestCase):
self.assertTrue(path.exists("../outputs/test-upscale.png"))
class TestBlendPipeline(unittest.TestCase):
def test_basic(self):
cancel = Value("L", 0)
@ -160,7 +232,14 @@ class TestBlendPipeline(unittest.TestCase):
worker,
ServerContext(model_path="../models", output_path="../outputs"),
ImageParams(
TEST_MODEL_DIFFUSION_SD15, "txt2img", "ddim", "an astronaut eating a hamburger", 3.0, 1, 1),
TEST_MODEL_DIFFUSION_SD15,
"txt2img",
"ddim",
"an astronaut eating a hamburger",
3.0,
1,
1,
),
Size(64, 64),
["test-blend.png"],
UpscaleParams("test"),

View File

@ -10,7 +10,6 @@ from onnx_web.diffusers.utils import (
get_loras_from_prompt,
get_scaled_latents,
get_tile_latents,
get_tokens_from_prompt,
pop_random,
slice_prompt,
)
@ -26,7 +25,10 @@ class TestExpandIntervalRanges(unittest.TestCase):
def test_prompt_with_range(self):
prompt = "an astronaut-{1,4} eating a hamburger"
result = expand_interval_ranges(prompt)
self.assertEqual(result, "an astronaut-1 astronaut-2 astronaut-3 eating a hamburger")
self.assertEqual(
result, "an astronaut-1 astronaut-2 astronaut-3 eating a hamburger"
)
class TestExpandAlternativeRanges(unittest.TestCase):
def test_prompt_with_no_ranges(self):
@ -37,7 +39,10 @@ class TestExpandAlternativeRanges(unittest.TestCase):
def test_ranges_match(self):
prompt = "(an astronaut|a squirrel) eating (a hamburger|an acorn)"
result = expand_alternative_ranges(prompt)
self.assertEqual(result, ["an astronaut eating a hamburger", "a squirrel eating an acorn"])
self.assertEqual(
result, ["an astronaut eating a hamburger", "a squirrel eating an acorn"]
)
class TestInversionsFromPrompt(unittest.TestCase):
def test_get_inversions(self):
@ -47,6 +52,7 @@ class TestInversionsFromPrompt(unittest.TestCase):
self.assertEqual(result, " an astronaut eating an embedding")
self.assertEqual(tokens, [("test", 1.0)])
class TestLoRAsFromPrompt(unittest.TestCase):
def test_get_loras(self):
prompt = "<lora:test:1.0> an astronaut eating a LoRA"
@ -55,6 +61,7 @@ class TestLoRAsFromPrompt(unittest.TestCase):
self.assertEqual(result, " an astronaut eating a LoRA")
self.assertEqual(tokens, [("test", 1.0)])
class TestLatentsFromSeed(unittest.TestCase):
def test_batch_size(self):
latents = get_latents_from_seed(1, Size(64, 64), batch=4)
@ -65,6 +72,7 @@ class TestLatentsFromSeed(unittest.TestCase):
latents2 = get_latents_from_seed(1, Size(64, 64))
self.assertTrue(np.array_equal(latents1, latents2))
class TestTileLatents(unittest.TestCase):
def test_full_tile(self):
partial = np.zeros((1, 1, 64, 64))
@ -81,6 +89,7 @@ class TestTileLatents(unittest.TestCase):
full = get_tile_latents(partial, 1, (64, 64), (0, 0, 64))
self.assertEqual(full.shape, (1, 1, 8, 8))
class TestScaledLatents(unittest.TestCase):
def test_scale_up(self):
latents = get_latents_from_seed(1, Size(16, 16))
@ -90,22 +99,29 @@ class TestScaledLatents(unittest.TestCase):
def test_scale_down(self):
latents = get_latents_from_seed(1, Size(16, 16))
scaled = get_scaled_latents(1, Size(16, 16), scale=0.5)
self.assertEqual((
latents[0, 0, 0, 0] +
latents[0, 0, 0, 1] +
latents[0, 0, 1, 0] +
latents[0, 0, 1, 1]
) / 4, scaled[0, 0, 0, 0])
self.assertEqual(
(
latents[0, 0, 0, 0]
+ latents[0, 0, 0, 1]
+ latents[0, 0, 1, 0]
+ latents[0, 0, 1, 1]
)
/ 4,
scaled[0, 0, 0, 0],
)
class TestReplaceWildcards(unittest.TestCase):
pass
class TestPopRandom(unittest.TestCase):
def test_pop(self):
items = ["1", "2", "3"]
pop_random(items)
self.assertEqual(len(items), 2)
class TestRepairNaN(unittest.TestCase):
def test_unchanged(self):
pass
@ -113,6 +129,7 @@ class TestRepairNaN(unittest.TestCase):
def test_missing(self):
pass
class TestSlicePrompt(unittest.TestCase):
def test_slice_no_delimiter(self):
slice = slice_prompt("foo", 1)

View File

@ -71,7 +71,9 @@ class TestWorkerPool(unittest.TestCase):
device1 = DeviceParams("cpu1", "CPUProvider")
device2 = DeviceParams("cpu2", "CPUProvider")
server = ServerContext()
self.pool = DevicePoolExecutor(server, [device1, device2], join_timeout=TEST_JOIN_TIMEOUT)
self.pool = DevicePoolExecutor(
server, [device1, device2], join_timeout=TEST_JOIN_TIMEOUT
)
self.pool.start()
self.assertEqual(self.pool.get_next_device(needs_device=device2), 1)
@ -83,7 +85,9 @@ class TestWorkerPool(unittest.TestCase):
device = DeviceParams("cpu", "CPUProvider")
server = ServerContext()
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1)
self.pool = DevicePoolExecutor(
server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1
)
self.pool.start(lock)
sleep(2.0)
@ -113,7 +117,9 @@ class TestWorkerPool(unittest.TestCase):
device = DeviceParams("cpu", "CPUProvider")
server = ServerContext()
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1)
self.pool = DevicePoolExecutor(
server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1
)
self.pool.start()
sleep(2.0)

View File

@ -20,9 +20,11 @@ from tests.helpers import test_device
def main_memory(_worker):
raise Exception(MEMORY_ERRORS[0])
def main_retry(_worker):
raise RetryException()
def main_interrupt(_worker):
raise KeyboardInterrupt()
@ -47,7 +49,22 @@ class WorkerMainTests(unittest.TestCase):
idle = Value("L", False)
pending.put(job)
worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit)
worker_main(
WorkerContext(
"test",
test_device(),
cancel,
logs,
pending,
progress,
pid,
idle,
0,
0.0,
),
ServerContext(),
exit=exit,
)
self.assertEqual(status, EXIT_INTERRUPT)
pass
@ -68,7 +85,22 @@ class WorkerMainTests(unittest.TestCase):
idle = Value("L", False)
pending.put(job)
worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit)
worker_main(
WorkerContext(
"test",
test_device(),
cancel,
logs,
pending,
progress,
pid,
idle,
0,
0.0,
),
ServerContext(),
exit=exit,
)
self.assertEqual(status, EXIT_ERROR)
pass
@ -88,7 +120,22 @@ class WorkerMainTests(unittest.TestCase):
idle = Value("L", False)
pending.close()
worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit)
worker_main(
WorkerContext(
"test",
test_device(),
cancel,
logs,
pending,
progress,
pid,
idle,
0,
0.0,
),
ServerContext(),
exit=exit,
)
self.assertEqual(status, EXIT_ERROR)
@ -108,11 +155,25 @@ class WorkerMainTests(unittest.TestCase):
idle = Value("L", False)
pending.put(job)
worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit)
worker_main(
WorkerContext(
"test",
test_device(),
cancel,
logs,
pending,
progress,
pid,
idle,
0,
0.0,
),
ServerContext(),
exit=exit,
)
self.assertEqual(status, EXIT_MEMORY)
def test_pending_exception_other_unknown(self):
pass
@ -130,7 +191,21 @@ class WorkerMainTests(unittest.TestCase):
pid = Value("L", 0)
idle = Value("L", False)
worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit)
worker_main(
WorkerContext(
"test",
test_device(),
cancel,
logs,
pending,
progress,
pid,
idle,
0,
0.0,
),
ServerContext(),
exit=exit,
)
self.assertEqual(status, EXIT_REPLACED)