1
0
Fork 0

more tests

This commit is contained in:
Sean Sube 2023-09-14 19:35:48 -05:00
parent 5b659a948a
commit b851c234fe
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
19 changed files with 470 additions and 9 deletions

View File

@ -21,6 +21,7 @@ test:
python -m coverage run -m unittest discover -v -s tests/
python -m coverage html -i
python -m coverage xml -i
python -m coverage report -i
package: package-dist package-upload

View File

@ -95,7 +95,7 @@ class ChainPipeline:
def outputs(self, params: ImageParams, sources: int):
outputs = sources
for callback, _params, kwargs in self.stages:
outputs += callback.outputs(kwargs.get("params", params), outputs)
outputs = callback.outputs(kwargs.get("params", params), outputs)
return outputs

View File

@ -1,5 +1,5 @@
from logging import getLogger
from typing import Callable, List
from typing import Callable, List, Optional
from PIL import Image
@ -22,7 +22,7 @@ class SourceNoiseStage(BaseStage):
*,
size: Size,
noise_source: Callable,
stage_source: Image.Image,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> List[Image.Image]:
logger.info("generating image from noise source")

View File

@ -1,6 +1,6 @@
from io import BytesIO
from logging import getLogger
from typing import List
from typing import List, Optional
import requests
from PIL import Image
@ -23,7 +23,7 @@ class SourceURLStage(BaseStage):
sources: List[Image.Image],
*,
source_urls: List[str],
stage_source: Image.Image,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> List[Image.Image]:
logger.info("loading image from URL source")

View File

@ -5,7 +5,7 @@ from typing import Any, Dict, List, Literal, Optional, Tuple, Union
import numpy as np
import torch
from onnx import ModelProto, load, numpy_helper
from onnx import ModelProto, NodeProto, load, numpy_helper
from onnx.checker import check_model
from onnx.external_data_helper import (
convert_model_to_external_data,
@ -39,7 +39,7 @@ def sum_weights(a: np.ndarray, b: np.ndarray) -> np.ndarray:
lr = a
if kernel == (1, 1):
lr = np.expand_dims(lr, axis=(2, 3))
lr = np.expand_dims(lr, axis=(2, 3)) # TODO: generate axis
return hr + lr
@ -78,7 +78,7 @@ def fix_node_name(key: str):
return fixed_name
def fix_xl_names(keys: Dict[str, Any], nodes: List[Any]):
def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]):
fixed = {}
for key, value in keys.items():

View File

@ -0,0 +1,21 @@
import unittest
from PIL import Image
from onnx_web.chain.blend_grid import BlendGridStage
from onnx_web.chain.blend_linear import BlendLinearStage
class BlendGridStageTests(unittest.TestCase):
def test_stage(self):
stage = BlendGridStage()
sources = [
Image.new("RGB", (64, 64), "black"),
Image.new("RGB", (64, 64), "white"),
Image.new("RGB", (64, 64), "black"),
Image.new("RGB", (64, 64), "white"),
]
result = stage.run(None, None, None, None, sources, height=2, width=2)
self.assertEqual(len(result), 5)
self.assertEqual(result[-1].getpixel((0,0)), (0, 0, 0))

View File

@ -0,0 +1,26 @@
import unittest
from PIL import Image
from onnx_web.chain.blend_img2img import BlendImg2ImgStage
from onnx_web.params import DeviceParams, ImageParams
from onnx_web.server.context import ServerContext
from onnx_web.worker.context import WorkerContext
class BlendImg2ImgStageTests(unittest.TestCase):
def test_stage(self):
"""
stage = BlendImg2ImgStage()
params = ImageParams("runwayml/stable-diffusion-v1-5", "txt2img", "euler-a", "an astronaut eating a hamburger", 3.0, 1, 1)
server = ServerContext()
worker = WorkerContext("test", DeviceParams("cpu", "CPUProvider"), None, None, None, None, None, None, 0)
sources = [
Image.new("RGB", (64, 64), "black"),
]
result = stage.run(worker, server, None, params, sources, strength=0.5, steps=1)
self.assertEqual(len(result), 1)
self.assertEqual(result[0].getpixel((0,0)), (127, 127, 127))
"""
pass

View File

@ -1,4 +1,5 @@
import unittest
from PIL import Image
from onnx_web.chain.blend_mask import BlendMaskStage

View File

@ -0,0 +1,33 @@
import unittest
from onnx_web.chain.correct_codeformer import CorrectCodeformerStage
from onnx_web.params import DeviceParams, HighresParams, UpscaleParams
from onnx_web.server.context import ServerContext
from onnx_web.server.hacks import apply_patches
from onnx_web.worker.context import WorkerContext
class CorrectCodeformerStageTests(unittest.TestCase):
def test_empty(self):
"""
server = ServerContext()
apply_patches(server)
worker = WorkerContext(
"test",
DeviceParams("cpu", "CPUProvider"),
None,
None,
None,
None,
None,
None,
0,
)
stage = CorrectCodeformerStage()
sources = []
result = stage.run(worker, None, None, None, sources, highres=HighresParams(False,1, 0, 0), upscale=UpscaleParams(""))
self.assertEqual(len(result), 0)
"""
pass

View File

@ -0,0 +1,28 @@
import unittest
from PIL import Image
from onnx_web.chain.reduce_crop import ReduceCropStage
from onnx_web.chain.reduce_thumbnail import ReduceThumbnailStage
from onnx_web.params import HighresParams, Size, UpscaleParams
class ReduceThumbnailStageTests(unittest.TestCase):
def test_empty(self):
stage_source = Image.new("RGB", (64, 64))
stage = ReduceThumbnailStage()
sources = []
result = stage.run(
None,
None,
None,
None,
sources,
highres=HighresParams(False, 1, 0, 0),
upscale=UpscaleParams(""),
origin=Size(0, 0),
size=Size(128, 128),
stage_source=stage_source,
)
self.assertEqual(len(result), 0)

View File

@ -0,0 +1,25 @@
import unittest
from onnx_web.chain.source_noise import SourceNoiseStage
from onnx_web.image.noise_source import noise_source_fill_edge
from onnx_web.params import HighresParams, Size, UpscaleParams
class SourceNoiseStageTests(unittest.TestCase):
def test_empty(self):
stage = SourceNoiseStage()
sources = []
result = stage.run(
None,
None,
None,
None,
sources,
highres=HighresParams(False, 1, 0, 0),
upscale=UpscaleParams(""),
origin=Size(0, 0),
size=Size(128, 128),
noise_source=noise_source_fill_edge,
)
self.assertEqual(len(result), 0)

View File

@ -0,0 +1,25 @@
import unittest
from onnx_web.chain.source_s3 import SourceS3Stage
from onnx_web.params import HighresParams, Size, UpscaleParams
class SourceS3StageTests(unittest.TestCase):
def test_empty(self):
stage = SourceS3Stage()
sources = []
result = stage.run(
None,
None,
None,
None,
sources,
highres=HighresParams(False, 1, 0, 0),
upscale=UpscaleParams(""),
origin=Size(0, 0),
size=Size(128, 128),
bucket="test",
source_keys=[],
)
self.assertEqual(len(result), 0)

View File

@ -0,0 +1,24 @@
import unittest
from onnx_web.chain.source_url import SourceURLStage
from onnx_web.params import HighresParams, Size, UpscaleParams
class SourceURLStageTests(unittest.TestCase):
def test_empty(self):
stage = SourceURLStage()
sources = []
result = stage.run(
None,
None,
None,
None,
sources,
highres=HighresParams(False, 1, 0, 0),
upscale=UpscaleParams(""),
origin=Size(0, 0),
size=Size(128, 128),
source_urls=[],
)
self.assertEqual(len(result), 0)

View File

@ -2,7 +2,12 @@ import unittest
from PIL import Image
from onnx_web.chain.tile import complete_tile, generate_tile_spiral, get_tile_grads, needs_tile
from onnx_web.chain.tile import (
complete_tile,
generate_tile_spiral,
get_tile_grads,
needs_tile,
)
from onnx_web.params import Size

View File

View File

View File

@ -0,0 +1,170 @@
import unittest
import numpy as np
from onnx import GraphProto, ModelProto, NodeProto
from onnx.numpy_helper import from_array
from onnx_web.convert.diffusion.lora import (
blend_loras,
buffer_external_data_tensors,
fix_initializer_name,
fix_node_name,
fix_xl_names,
interp_to_match,
kernel_slice,
sum_weights,
)
class SumWeightsTests(unittest.TestCase):
def test_same_shape(self):
weights = sum_weights(np.zeros((4, 4)), np.ones((4, 4)))
self.assertEqual(weights.shape, (4, 4))
def test_1x1_kernel(self):
weights = sum_weights(np.zeros((4, 4, 1, 1)), np.ones((4, 4)))
self.assertEqual(weights.shape, (4, 4, 1, 1))
weights = sum_weights(np.zeros((4, 4)), np.ones((4, 4, 1, 1)))
self.assertEqual(weights.shape, (4, 4, 1, 1))
def test_3x3_kernel(self):
"""
weights = sum_weights(np.zeros((4, 4, 3, 3)), np.ones((4, 4)))
self.assertEqual(weights.shape, (4, 4, 1, 1))
"""
pass
class BufferExternalDataTensorTests(unittest.TestCase):
def test_basic_external(self):
model = ModelProto(
graph=GraphProto(
initializer=[
from_array(np.zeros((4, 4))),
],
)
)
(slim_model, external_weights) = buffer_external_data_tensors(model)
self.assertEqual(len(slim_model.graph.initializer), len(model.graph.initializer))
self.assertEqual(len(external_weights), 1)
class FixInitializerKeyTests(unittest.TestCase):
def test_fix_name(self):
inputs = ["lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0.lora_down.weight"]
outputs = ["lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0_lora_down_weight"]
for input, output in zip(inputs, outputs):
self.assertEqual(fix_initializer_name(input), output)
class FixNodeNameTests(unittest.TestCase):
def test_fix_name(self):
inputs = [
"lora_unet/up_blocks/3/attentions/2/transformer_blocks/0/attn2_to_out/0.lora_down.weight",
"_prefix",
]
outputs = [
"lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0_lora_down_weight",
"prefix",
]
for input, output in zip(inputs, outputs):
self.assertEqual(fix_node_name(input), output)
class FixXLNameTests(unittest.TestCase):
def test_empty(self):
nodes = {}
fixed = fix_xl_names(nodes, [])
self.assertEqual(fixed, {})
def test_input_block(self):
nodes = {
"input_block_proj.lora_down.weight": {},
}
fixed = fix_xl_names(nodes, [
NodeProto(name="/down_blocks_proj/MatMul"),
])
self.assertEqual(fixed, {
"down_blocks_proj": nodes["input_block_proj.lora_down.weight"],
})
def test_middle_block(self):
nodes = {
"middle_block_proj.lora_down.weight": {},
}
fixed = fix_xl_names(nodes, [
NodeProto(name="/mid_blocks_proj/MatMul"),
])
self.assertEqual(fixed, {
"mid_blocks_proj": nodes["middle_block_proj.lora_down.weight"],
})
def test_output_block(self):
pass
def test_text_model(self):
pass
def test_unknown_block(self):
pass
def test_unmatched_block(self):
nodes = {
"lora_unet.input_block.lora_down.weight": {},
}
fixed = fix_xl_names(nodes, [""])
self.assertEqual(fixed, nodes)
def test_output_projection(self):
nodes = {
"output_block_proj_o.lora_down.weight": {},
}
fixed = fix_xl_names(nodes, [
NodeProto(name="/up_blocks_proj_o/MatMul"),
])
self.assertEqual(fixed, {
"up_blocks_proj_out": nodes["output_block_proj_o.lora_down.weight"],
})
class KernelSliceTests(unittest.TestCase):
def test_within_kernel(self):
self.assertEqual(
kernel_slice(1, 1, (3, 3, 3, 3)),
(1, 1),
)
def test_outside_kernel(self):
self.assertEqual(
kernel_slice(9, 9, (3, 3, 3, 3)),
(2, 2),
)
class BlendLoRATests(unittest.TestCase):
pass
class InterpToMatchTests(unittest.TestCase):
def test_same_shape(self):
ref = np.zeros((4, 4))
resize = np.zeros((4, 4))
self.assertEqual(interp_to_match(ref, resize).shape, (4, 4))
def test_different_one_dim(self):
ref = np.zeros((4, 2))
resize = np.zeros((4, 4))
self.assertEqual(interp_to_match(ref, resize).shape, (4, 4))
def test_different_both_dims(self):
ref = np.zeros((2, 2))
resize = np.zeros((4, 4))
self.assertEqual(interp_to_match(ref, resize).shape, (4, 4))

View File

@ -0,0 +1,19 @@
import unittest
from onnx_web.convert.utils import DEFAULT_OPSET, ConversionContext, download_progress
class ConversionContextTests(unittest.TestCase):
def test_from_environ(self):
context = ConversionContext.from_environ()
self.assertEqual(context.opset, DEFAULT_OPSET)
def test_map_location(self):
context = ConversionContext.from_environ()
self.assertEqual(context.map_location.type, "cpu")
class DownloadProgressTests(unittest.TestCase):
def test_download_example(self):
path = download_progress([("https://example.com", "/tmp/example-dot-com")])
self.assertEqual(path, "/tmp/example-dot-com")

View File

@ -0,0 +1,83 @@
import unittest
from onnx_web.server.load import (
get_available_platforms,
get_config_params,
get_correction_models,
get_diffusion_models,
get_extra_hashes,
get_extra_strings,
get_highres_methods,
get_mask_filters,
get_network_models,
get_noise_sources,
get_source_filters,
get_upscaling_models,
get_wildcard_data,
)
class ConfigParamTests(unittest.TestCase):
def test_before_setup(self):
params = get_config_params()
self.assertIsNotNone(params)
class AvailablePlatformTests(unittest.TestCase):
def test_before_setup(self):
platforms = get_available_platforms()
self.assertIsNotNone(platforms)
class CorrectModelTests(unittest.TestCase):
def test_before_setup(self):
models = get_correction_models()
self.assertIsNotNone(models)
class DiffusionModelTests(unittest.TestCase):
def test_before_setup(self):
models = get_diffusion_models()
self.assertIsNotNone(models)
class NetworkModelTests(unittest.TestCase):
def test_before_setup(self):
models = get_network_models()
self.assertIsNotNone(models)
class UpscalingModelTests(unittest.TestCase):
def test_before_setup(self):
models = get_upscaling_models()
self.assertIsNotNone(models)
class WildcardDataTests(unittest.TestCase):
def test_before_setup(self):
wildcards = get_wildcard_data()
self.assertIsNotNone(wildcards)
class ExtraStringsTests(unittest.TestCase):
def test_before_setup(self):
strings = get_extra_strings()
self.assertIsNotNone(strings)
class ExtraHashesTests(unittest.TestCase):
def test_before_setup(self):
hashes = get_extra_hashes()
self.assertIsNotNone(hashes)
class HighresMethodTests(unittest.TestCase):
def test_before_setup(self):
methods = get_highres_methods()
self.assertIsNotNone(methods)
class MaskFilterTests(unittest.TestCase):
def test_before_setup(self):
filters = get_mask_filters()
self.assertIsNotNone(filters)
class NoiseSourceTests(unittest.TestCase):
def test_before_setup(self):
sources = get_noise_sources()
self.assertIsNotNone(sources)
class SourceFilterTests(unittest.TestCase):
def test_before_setup(self):
filters = get_source_filters()
self.assertIsNotNone(filters)