1
0
Fork 0

more tests

This commit is contained in:
Sean Sube 2023-09-14 19:35:48 -05:00
parent 5b659a948a
commit b851c234fe
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
19 changed files with 470 additions and 9 deletions

View File

@ -21,6 +21,7 @@ test:
python -m coverage run -m unittest discover -v -s tests/ python -m coverage run -m unittest discover -v -s tests/
python -m coverage html -i python -m coverage html -i
python -m coverage xml -i python -m coverage xml -i
python -m coverage report -i
package: package-dist package-upload package: package-dist package-upload

View File

@ -95,7 +95,7 @@ class ChainPipeline:
def outputs(self, params: ImageParams, sources: int): def outputs(self, params: ImageParams, sources: int):
outputs = sources outputs = sources
for callback, _params, kwargs in self.stages: for callback, _params, kwargs in self.stages:
outputs += callback.outputs(kwargs.get("params", params), outputs) outputs = callback.outputs(kwargs.get("params", params), outputs)
return outputs return outputs

View File

@ -1,5 +1,5 @@
from logging import getLogger from logging import getLogger
from typing import Callable, List from typing import Callable, List, Optional
from PIL import Image from PIL import Image
@ -22,7 +22,7 @@ class SourceNoiseStage(BaseStage):
*, *,
size: Size, size: Size,
noise_source: Callable, noise_source: Callable,
stage_source: Image.Image, stage_source: Optional[Image.Image] = None,
**kwargs, **kwargs,
) -> List[Image.Image]: ) -> List[Image.Image]:
logger.info("generating image from noise source") logger.info("generating image from noise source")

View File

@ -1,6 +1,6 @@
from io import BytesIO from io import BytesIO
from logging import getLogger from logging import getLogger
from typing import List from typing import List, Optional
import requests import requests
from PIL import Image from PIL import Image
@ -23,7 +23,7 @@ class SourceURLStage(BaseStage):
sources: List[Image.Image], sources: List[Image.Image],
*, *,
source_urls: List[str], source_urls: List[str],
stage_source: Image.Image, stage_source: Optional[Image.Image] = None,
**kwargs, **kwargs,
) -> List[Image.Image]: ) -> List[Image.Image]:
logger.info("loading image from URL source") logger.info("loading image from URL source")

View File

@ -5,7 +5,7 @@ from typing import Any, Dict, List, Literal, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
from onnx import ModelProto, load, numpy_helper from onnx import ModelProto, NodeProto, load, numpy_helper
from onnx.checker import check_model from onnx.checker import check_model
from onnx.external_data_helper import ( from onnx.external_data_helper import (
convert_model_to_external_data, convert_model_to_external_data,
@ -39,7 +39,7 @@ def sum_weights(a: np.ndarray, b: np.ndarray) -> np.ndarray:
lr = a lr = a
if kernel == (1, 1): if kernel == (1, 1):
lr = np.expand_dims(lr, axis=(2, 3)) lr = np.expand_dims(lr, axis=(2, 3)) # TODO: generate axis
return hr + lr return hr + lr
@ -78,7 +78,7 @@ def fix_node_name(key: str):
return fixed_name return fixed_name
def fix_xl_names(keys: Dict[str, Any], nodes: List[Any]): def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]):
fixed = {} fixed = {}
for key, value in keys.items(): for key, value in keys.items():

View File

@ -0,0 +1,21 @@
import unittest
from PIL import Image
from onnx_web.chain.blend_grid import BlendGridStage
from onnx_web.chain.blend_linear import BlendLinearStage
class BlendGridStageTests(unittest.TestCase):
def test_stage(self):
stage = BlendGridStage()
sources = [
Image.new("RGB", (64, 64), "black"),
Image.new("RGB", (64, 64), "white"),
Image.new("RGB", (64, 64), "black"),
Image.new("RGB", (64, 64), "white"),
]
result = stage.run(None, None, None, None, sources, height=2, width=2)
self.assertEqual(len(result), 5)
self.assertEqual(result[-1].getpixel((0,0)), (0, 0, 0))

View File

@ -0,0 +1,26 @@
import unittest
from PIL import Image
from onnx_web.chain.blend_img2img import BlendImg2ImgStage
from onnx_web.params import DeviceParams, ImageParams
from onnx_web.server.context import ServerContext
from onnx_web.worker.context import WorkerContext
class BlendImg2ImgStageTests(unittest.TestCase):
def test_stage(self):
"""
stage = BlendImg2ImgStage()
params = ImageParams("runwayml/stable-diffusion-v1-5", "txt2img", "euler-a", "an astronaut eating a hamburger", 3.0, 1, 1)
server = ServerContext()
worker = WorkerContext("test", DeviceParams("cpu", "CPUProvider"), None, None, None, None, None, None, 0)
sources = [
Image.new("RGB", (64, 64), "black"),
]
result = stage.run(worker, server, None, params, sources, strength=0.5, steps=1)
self.assertEqual(len(result), 1)
self.assertEqual(result[0].getpixel((0,0)), (127, 127, 127))
"""
pass

View File

@ -1,4 +1,5 @@
import unittest import unittest
from PIL import Image from PIL import Image
from onnx_web.chain.blend_mask import BlendMaskStage from onnx_web.chain.blend_mask import BlendMaskStage

View File

@ -0,0 +1,33 @@
import unittest
from onnx_web.chain.correct_codeformer import CorrectCodeformerStage
from onnx_web.params import DeviceParams, HighresParams, UpscaleParams
from onnx_web.server.context import ServerContext
from onnx_web.server.hacks import apply_patches
from onnx_web.worker.context import WorkerContext
class CorrectCodeformerStageTests(unittest.TestCase):
def test_empty(self):
"""
server = ServerContext()
apply_patches(server)
worker = WorkerContext(
"test",
DeviceParams("cpu", "CPUProvider"),
None,
None,
None,
None,
None,
None,
0,
)
stage = CorrectCodeformerStage()
sources = []
result = stage.run(worker, None, None, None, sources, highres=HighresParams(False,1, 0, 0), upscale=UpscaleParams(""))
self.assertEqual(len(result), 0)
"""
pass

View File

@ -0,0 +1,28 @@
import unittest
from PIL import Image
from onnx_web.chain.reduce_crop import ReduceCropStage
from onnx_web.chain.reduce_thumbnail import ReduceThumbnailStage
from onnx_web.params import HighresParams, Size, UpscaleParams
class ReduceThumbnailStageTests(unittest.TestCase):
def test_empty(self):
stage_source = Image.new("RGB", (64, 64))
stage = ReduceThumbnailStage()
sources = []
result = stage.run(
None,
None,
None,
None,
sources,
highres=HighresParams(False, 1, 0, 0),
upscale=UpscaleParams(""),
origin=Size(0, 0),
size=Size(128, 128),
stage_source=stage_source,
)
self.assertEqual(len(result), 0)

View File

@ -0,0 +1,25 @@
import unittest
from onnx_web.chain.source_noise import SourceNoiseStage
from onnx_web.image.noise_source import noise_source_fill_edge
from onnx_web.params import HighresParams, Size, UpscaleParams
class SourceNoiseStageTests(unittest.TestCase):
def test_empty(self):
stage = SourceNoiseStage()
sources = []
result = stage.run(
None,
None,
None,
None,
sources,
highres=HighresParams(False, 1, 0, 0),
upscale=UpscaleParams(""),
origin=Size(0, 0),
size=Size(128, 128),
noise_source=noise_source_fill_edge,
)
self.assertEqual(len(result), 0)

View File

@ -0,0 +1,25 @@
import unittest
from onnx_web.chain.source_s3 import SourceS3Stage
from onnx_web.params import HighresParams, Size, UpscaleParams
class SourceS3StageTests(unittest.TestCase):
def test_empty(self):
stage = SourceS3Stage()
sources = []
result = stage.run(
None,
None,
None,
None,
sources,
highres=HighresParams(False, 1, 0, 0),
upscale=UpscaleParams(""),
origin=Size(0, 0),
size=Size(128, 128),
bucket="test",
source_keys=[],
)
self.assertEqual(len(result), 0)

View File

@ -0,0 +1,24 @@
import unittest
from onnx_web.chain.source_url import SourceURLStage
from onnx_web.params import HighresParams, Size, UpscaleParams
class SourceURLStageTests(unittest.TestCase):
def test_empty(self):
stage = SourceURLStage()
sources = []
result = stage.run(
None,
None,
None,
None,
sources,
highres=HighresParams(False, 1, 0, 0),
upscale=UpscaleParams(""),
origin=Size(0, 0),
size=Size(128, 128),
source_urls=[],
)
self.assertEqual(len(result), 0)

View File

@ -2,7 +2,12 @@ import unittest
from PIL import Image from PIL import Image
from onnx_web.chain.tile import complete_tile, generate_tile_spiral, get_tile_grads, needs_tile from onnx_web.chain.tile import (
complete_tile,
generate_tile_spiral,
get_tile_grads,
needs_tile,
)
from onnx_web.params import Size from onnx_web.params import Size

View File

View File

View File

@ -0,0 +1,170 @@
import unittest
import numpy as np
from onnx import GraphProto, ModelProto, NodeProto
from onnx.numpy_helper import from_array
from onnx_web.convert.diffusion.lora import (
blend_loras,
buffer_external_data_tensors,
fix_initializer_name,
fix_node_name,
fix_xl_names,
interp_to_match,
kernel_slice,
sum_weights,
)
class SumWeightsTests(unittest.TestCase):
def test_same_shape(self):
weights = sum_weights(np.zeros((4, 4)), np.ones((4, 4)))
self.assertEqual(weights.shape, (4, 4))
def test_1x1_kernel(self):
weights = sum_weights(np.zeros((4, 4, 1, 1)), np.ones((4, 4)))
self.assertEqual(weights.shape, (4, 4, 1, 1))
weights = sum_weights(np.zeros((4, 4)), np.ones((4, 4, 1, 1)))
self.assertEqual(weights.shape, (4, 4, 1, 1))
def test_3x3_kernel(self):
"""
weights = sum_weights(np.zeros((4, 4, 3, 3)), np.ones((4, 4)))
self.assertEqual(weights.shape, (4, 4, 1, 1))
"""
pass
class BufferExternalDataTensorTests(unittest.TestCase):
def test_basic_external(self):
model = ModelProto(
graph=GraphProto(
initializer=[
from_array(np.zeros((4, 4))),
],
)
)
(slim_model, external_weights) = buffer_external_data_tensors(model)
self.assertEqual(len(slim_model.graph.initializer), len(model.graph.initializer))
self.assertEqual(len(external_weights), 1)
class FixInitializerKeyTests(unittest.TestCase):
def test_fix_name(self):
inputs = ["lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0.lora_down.weight"]
outputs = ["lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0_lora_down_weight"]
for input, output in zip(inputs, outputs):
self.assertEqual(fix_initializer_name(input), output)
class FixNodeNameTests(unittest.TestCase):
def test_fix_name(self):
inputs = [
"lora_unet/up_blocks/3/attentions/2/transformer_blocks/0/attn2_to_out/0.lora_down.weight",
"_prefix",
]
outputs = [
"lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0_lora_down_weight",
"prefix",
]
for input, output in zip(inputs, outputs):
self.assertEqual(fix_node_name(input), output)
class FixXLNameTests(unittest.TestCase):
def test_empty(self):
nodes = {}
fixed = fix_xl_names(nodes, [])
self.assertEqual(fixed, {})
def test_input_block(self):
nodes = {
"input_block_proj.lora_down.weight": {},
}
fixed = fix_xl_names(nodes, [
NodeProto(name="/down_blocks_proj/MatMul"),
])
self.assertEqual(fixed, {
"down_blocks_proj": nodes["input_block_proj.lora_down.weight"],
})
def test_middle_block(self):
nodes = {
"middle_block_proj.lora_down.weight": {},
}
fixed = fix_xl_names(nodes, [
NodeProto(name="/mid_blocks_proj/MatMul"),
])
self.assertEqual(fixed, {
"mid_blocks_proj": nodes["middle_block_proj.lora_down.weight"],
})
def test_output_block(self):
pass
def test_text_model(self):
pass
def test_unknown_block(self):
pass
def test_unmatched_block(self):
nodes = {
"lora_unet.input_block.lora_down.weight": {},
}
fixed = fix_xl_names(nodes, [""])
self.assertEqual(fixed, nodes)
def test_output_projection(self):
nodes = {
"output_block_proj_o.lora_down.weight": {},
}
fixed = fix_xl_names(nodes, [
NodeProto(name="/up_blocks_proj_o/MatMul"),
])
self.assertEqual(fixed, {
"up_blocks_proj_out": nodes["output_block_proj_o.lora_down.weight"],
})
class KernelSliceTests(unittest.TestCase):
def test_within_kernel(self):
self.assertEqual(
kernel_slice(1, 1, (3, 3, 3, 3)),
(1, 1),
)
def test_outside_kernel(self):
self.assertEqual(
kernel_slice(9, 9, (3, 3, 3, 3)),
(2, 2),
)
class BlendLoRATests(unittest.TestCase):
pass
class InterpToMatchTests(unittest.TestCase):
def test_same_shape(self):
ref = np.zeros((4, 4))
resize = np.zeros((4, 4))
self.assertEqual(interp_to_match(ref, resize).shape, (4, 4))
def test_different_one_dim(self):
ref = np.zeros((4, 2))
resize = np.zeros((4, 4))
self.assertEqual(interp_to_match(ref, resize).shape, (4, 4))
def test_different_both_dims(self):
ref = np.zeros((2, 2))
resize = np.zeros((4, 4))
self.assertEqual(interp_to_match(ref, resize).shape, (4, 4))

View File

@ -0,0 +1,19 @@
import unittest
from onnx_web.convert.utils import DEFAULT_OPSET, ConversionContext, download_progress
class ConversionContextTests(unittest.TestCase):
def test_from_environ(self):
context = ConversionContext.from_environ()
self.assertEqual(context.opset, DEFAULT_OPSET)
def test_map_location(self):
context = ConversionContext.from_environ()
self.assertEqual(context.map_location.type, "cpu")
class DownloadProgressTests(unittest.TestCase):
def test_download_example(self):
path = download_progress([("https://example.com", "/tmp/example-dot-com")])
self.assertEqual(path, "/tmp/example-dot-com")

View File

@ -0,0 +1,83 @@
import unittest
from onnx_web.server.load import (
get_available_platforms,
get_config_params,
get_correction_models,
get_diffusion_models,
get_extra_hashes,
get_extra_strings,
get_highres_methods,
get_mask_filters,
get_network_models,
get_noise_sources,
get_source_filters,
get_upscaling_models,
get_wildcard_data,
)
class ConfigParamTests(unittest.TestCase):
def test_before_setup(self):
params = get_config_params()
self.assertIsNotNone(params)
class AvailablePlatformTests(unittest.TestCase):
def test_before_setup(self):
platforms = get_available_platforms()
self.assertIsNotNone(platforms)
class CorrectModelTests(unittest.TestCase):
def test_before_setup(self):
models = get_correction_models()
self.assertIsNotNone(models)
class DiffusionModelTests(unittest.TestCase):
def test_before_setup(self):
models = get_diffusion_models()
self.assertIsNotNone(models)
class NetworkModelTests(unittest.TestCase):
def test_before_setup(self):
models = get_network_models()
self.assertIsNotNone(models)
class UpscalingModelTests(unittest.TestCase):
def test_before_setup(self):
models = get_upscaling_models()
self.assertIsNotNone(models)
class WildcardDataTests(unittest.TestCase):
def test_before_setup(self):
wildcards = get_wildcard_data()
self.assertIsNotNone(wildcards)
class ExtraStringsTests(unittest.TestCase):
def test_before_setup(self):
strings = get_extra_strings()
self.assertIsNotNone(strings)
class ExtraHashesTests(unittest.TestCase):
def test_before_setup(self):
hashes = get_extra_hashes()
self.assertIsNotNone(hashes)
class HighresMethodTests(unittest.TestCase):
def test_before_setup(self):
methods = get_highres_methods()
self.assertIsNotNone(methods)
class MaskFilterTests(unittest.TestCase):
def test_before_setup(self):
filters = get_mask_filters()
self.assertIsNotNone(filters)
class NoiseSourceTests(unittest.TestCase):
def test_before_setup(self):
sources = get_noise_sources()
self.assertIsNotNone(sources)
class SourceFilterTests(unittest.TestCase):
def test_before_setup(self):
filters = get_source_filters()
self.assertIsNotNone(filters)