more tests
This commit is contained in:
parent
5b659a948a
commit
b851c234fe
|
@ -21,6 +21,7 @@ test:
|
|||
python -m coverage run -m unittest discover -v -s tests/
|
||||
python -m coverage html -i
|
||||
python -m coverage xml -i
|
||||
python -m coverage report -i
|
||||
|
||||
package: package-dist package-upload
|
||||
|
||||
|
|
|
@ -95,7 +95,7 @@ class ChainPipeline:
|
|||
def outputs(self, params: ImageParams, sources: int):
|
||||
outputs = sources
|
||||
for callback, _params, kwargs in self.stages:
|
||||
outputs += callback.outputs(kwargs.get("params", params), outputs)
|
||||
outputs = callback.outputs(kwargs.get("params", params), outputs)
|
||||
|
||||
return outputs
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from logging import getLogger
|
||||
from typing import Callable, List
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
@ -22,7 +22,7 @@ class SourceNoiseStage(BaseStage):
|
|||
*,
|
||||
size: Size,
|
||||
noise_source: Callable,
|
||||
stage_source: Image.Image,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
**kwargs,
|
||||
) -> List[Image.Image]:
|
||||
logger.info("generating image from noise source")
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from io import BytesIO
|
||||
from logging import getLogger
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
import requests
|
||||
from PIL import Image
|
||||
|
@ -23,7 +23,7 @@ class SourceURLStage(BaseStage):
|
|||
sources: List[Image.Image],
|
||||
*,
|
||||
source_urls: List[str],
|
||||
stage_source: Image.Image,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
**kwargs,
|
||||
) -> List[Image.Image]:
|
||||
logger.info("loading image from URL source")
|
||||
|
|
|
@ -5,7 +5,7 @@ from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
|||
|
||||
import numpy as np
|
||||
import torch
|
||||
from onnx import ModelProto, load, numpy_helper
|
||||
from onnx import ModelProto, NodeProto, load, numpy_helper
|
||||
from onnx.checker import check_model
|
||||
from onnx.external_data_helper import (
|
||||
convert_model_to_external_data,
|
||||
|
@ -39,7 +39,7 @@ def sum_weights(a: np.ndarray, b: np.ndarray) -> np.ndarray:
|
|||
lr = a
|
||||
|
||||
if kernel == (1, 1):
|
||||
lr = np.expand_dims(lr, axis=(2, 3))
|
||||
lr = np.expand_dims(lr, axis=(2, 3)) # TODO: generate axis
|
||||
|
||||
return hr + lr
|
||||
|
||||
|
@ -78,7 +78,7 @@ def fix_node_name(key: str):
|
|||
return fixed_name
|
||||
|
||||
|
||||
def fix_xl_names(keys: Dict[str, Any], nodes: List[Any]):
|
||||
def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]):
|
||||
fixed = {}
|
||||
|
||||
for key, value in keys.items():
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
import unittest
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from onnx_web.chain.blend_grid import BlendGridStage
|
||||
from onnx_web.chain.blend_linear import BlendLinearStage
|
||||
|
||||
|
||||
class BlendGridStageTests(unittest.TestCase):
|
||||
def test_stage(self):
|
||||
stage = BlendGridStage()
|
||||
sources = [
|
||||
Image.new("RGB", (64, 64), "black"),
|
||||
Image.new("RGB", (64, 64), "white"),
|
||||
Image.new("RGB", (64, 64), "black"),
|
||||
Image.new("RGB", (64, 64), "white"),
|
||||
]
|
||||
result = stage.run(None, None, None, None, sources, height=2, width=2)
|
||||
|
||||
self.assertEqual(len(result), 5)
|
||||
self.assertEqual(result[-1].getpixel((0,0)), (0, 0, 0))
|
|
@ -0,0 +1,26 @@
|
|||
import unittest
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from onnx_web.chain.blend_img2img import BlendImg2ImgStage
|
||||
from onnx_web.params import DeviceParams, ImageParams
|
||||
from onnx_web.server.context import ServerContext
|
||||
from onnx_web.worker.context import WorkerContext
|
||||
|
||||
|
||||
class BlendImg2ImgStageTests(unittest.TestCase):
|
||||
def test_stage(self):
|
||||
"""
|
||||
stage = BlendImg2ImgStage()
|
||||
params = ImageParams("runwayml/stable-diffusion-v1-5", "txt2img", "euler-a", "an astronaut eating a hamburger", 3.0, 1, 1)
|
||||
server = ServerContext()
|
||||
worker = WorkerContext("test", DeviceParams("cpu", "CPUProvider"), None, None, None, None, None, None, 0)
|
||||
sources = [
|
||||
Image.new("RGB", (64, 64), "black"),
|
||||
]
|
||||
result = stage.run(worker, server, None, params, sources, strength=0.5, steps=1)
|
||||
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertEqual(result[0].getpixel((0,0)), (127, 127, 127))
|
||||
"""
|
||||
pass
|
|
@ -1,4 +1,5 @@
|
|||
import unittest
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from onnx_web.chain.blend_mask import BlendMaskStage
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
import unittest
|
||||
|
||||
from onnx_web.chain.correct_codeformer import CorrectCodeformerStage
|
||||
from onnx_web.params import DeviceParams, HighresParams, UpscaleParams
|
||||
from onnx_web.server.context import ServerContext
|
||||
from onnx_web.server.hacks import apply_patches
|
||||
from onnx_web.worker.context import WorkerContext
|
||||
|
||||
|
||||
class CorrectCodeformerStageTests(unittest.TestCase):
|
||||
def test_empty(self):
|
||||
"""
|
||||
server = ServerContext()
|
||||
apply_patches(server)
|
||||
|
||||
worker = WorkerContext(
|
||||
"test",
|
||||
DeviceParams("cpu", "CPUProvider"),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
0,
|
||||
)
|
||||
stage = CorrectCodeformerStage()
|
||||
sources = []
|
||||
result = stage.run(worker, None, None, None, sources, highres=HighresParams(False,1, 0, 0), upscale=UpscaleParams(""))
|
||||
|
||||
self.assertEqual(len(result), 0)
|
||||
"""
|
||||
pass
|
|
@ -0,0 +1,28 @@
|
|||
import unittest
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from onnx_web.chain.reduce_crop import ReduceCropStage
|
||||
from onnx_web.chain.reduce_thumbnail import ReduceThumbnailStage
|
||||
from onnx_web.params import HighresParams, Size, UpscaleParams
|
||||
|
||||
|
||||
class ReduceThumbnailStageTests(unittest.TestCase):
|
||||
def test_empty(self):
|
||||
stage_source = Image.new("RGB", (64, 64))
|
||||
stage = ReduceThumbnailStage()
|
||||
sources = []
|
||||
result = stage.run(
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
sources,
|
||||
highres=HighresParams(False, 1, 0, 0),
|
||||
upscale=UpscaleParams(""),
|
||||
origin=Size(0, 0),
|
||||
size=Size(128, 128),
|
||||
stage_source=stage_source,
|
||||
)
|
||||
|
||||
self.assertEqual(len(result), 0)
|
|
@ -0,0 +1,25 @@
|
|||
import unittest
|
||||
|
||||
from onnx_web.chain.source_noise import SourceNoiseStage
|
||||
from onnx_web.image.noise_source import noise_source_fill_edge
|
||||
from onnx_web.params import HighresParams, Size, UpscaleParams
|
||||
|
||||
|
||||
class SourceNoiseStageTests(unittest.TestCase):
|
||||
def test_empty(self):
|
||||
stage = SourceNoiseStage()
|
||||
sources = []
|
||||
result = stage.run(
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
sources,
|
||||
highres=HighresParams(False, 1, 0, 0),
|
||||
upscale=UpscaleParams(""),
|
||||
origin=Size(0, 0),
|
||||
size=Size(128, 128),
|
||||
noise_source=noise_source_fill_edge,
|
||||
)
|
||||
|
||||
self.assertEqual(len(result), 0)
|
|
@ -0,0 +1,25 @@
|
|||
import unittest
|
||||
|
||||
from onnx_web.chain.source_s3 import SourceS3Stage
|
||||
from onnx_web.params import HighresParams, Size, UpscaleParams
|
||||
|
||||
|
||||
class SourceS3StageTests(unittest.TestCase):
|
||||
def test_empty(self):
|
||||
stage = SourceS3Stage()
|
||||
sources = []
|
||||
result = stage.run(
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
sources,
|
||||
highres=HighresParams(False, 1, 0, 0),
|
||||
upscale=UpscaleParams(""),
|
||||
origin=Size(0, 0),
|
||||
size=Size(128, 128),
|
||||
bucket="test",
|
||||
source_keys=[],
|
||||
)
|
||||
|
||||
self.assertEqual(len(result), 0)
|
|
@ -0,0 +1,24 @@
|
|||
import unittest
|
||||
|
||||
from onnx_web.chain.source_url import SourceURLStage
|
||||
from onnx_web.params import HighresParams, Size, UpscaleParams
|
||||
|
||||
|
||||
class SourceURLStageTests(unittest.TestCase):
|
||||
def test_empty(self):
|
||||
stage = SourceURLStage()
|
||||
sources = []
|
||||
result = stage.run(
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
sources,
|
||||
highres=HighresParams(False, 1, 0, 0),
|
||||
upscale=UpscaleParams(""),
|
||||
origin=Size(0, 0),
|
||||
size=Size(128, 128),
|
||||
source_urls=[],
|
||||
)
|
||||
|
||||
self.assertEqual(len(result), 0)
|
|
@ -2,7 +2,12 @@ import unittest
|
|||
|
||||
from PIL import Image
|
||||
|
||||
from onnx_web.chain.tile import complete_tile, generate_tile_spiral, get_tile_grads, needs_tile
|
||||
from onnx_web.chain.tile import (
|
||||
complete_tile,
|
||||
generate_tile_spiral,
|
||||
get_tile_grads,
|
||||
needs_tile,
|
||||
)
|
||||
from onnx_web.params import Size
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,170 @@
|
|||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from onnx import GraphProto, ModelProto, NodeProto
|
||||
from onnx.numpy_helper import from_array
|
||||
|
||||
from onnx_web.convert.diffusion.lora import (
|
||||
blend_loras,
|
||||
buffer_external_data_tensors,
|
||||
fix_initializer_name,
|
||||
fix_node_name,
|
||||
fix_xl_names,
|
||||
interp_to_match,
|
||||
kernel_slice,
|
||||
sum_weights,
|
||||
)
|
||||
|
||||
|
||||
class SumWeightsTests(unittest.TestCase):
|
||||
def test_same_shape(self):
|
||||
weights = sum_weights(np.zeros((4, 4)), np.ones((4, 4)))
|
||||
self.assertEqual(weights.shape, (4, 4))
|
||||
|
||||
def test_1x1_kernel(self):
|
||||
weights = sum_weights(np.zeros((4, 4, 1, 1)), np.ones((4, 4)))
|
||||
self.assertEqual(weights.shape, (4, 4, 1, 1))
|
||||
|
||||
weights = sum_weights(np.zeros((4, 4)), np.ones((4, 4, 1, 1)))
|
||||
self.assertEqual(weights.shape, (4, 4, 1, 1))
|
||||
|
||||
|
||||
def test_3x3_kernel(self):
|
||||
"""
|
||||
weights = sum_weights(np.zeros((4, 4, 3, 3)), np.ones((4, 4)))
|
||||
self.assertEqual(weights.shape, (4, 4, 1, 1))
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class BufferExternalDataTensorTests(unittest.TestCase):
|
||||
def test_basic_external(self):
|
||||
model = ModelProto(
|
||||
graph=GraphProto(
|
||||
initializer=[
|
||||
from_array(np.zeros((4, 4))),
|
||||
],
|
||||
)
|
||||
)
|
||||
(slim_model, external_weights) = buffer_external_data_tensors(model)
|
||||
|
||||
self.assertEqual(len(slim_model.graph.initializer), len(model.graph.initializer))
|
||||
self.assertEqual(len(external_weights), 1)
|
||||
|
||||
|
||||
class FixInitializerKeyTests(unittest.TestCase):
|
||||
def test_fix_name(self):
|
||||
inputs = ["lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0.lora_down.weight"]
|
||||
outputs = ["lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0_lora_down_weight"]
|
||||
|
||||
for input, output in zip(inputs, outputs):
|
||||
self.assertEqual(fix_initializer_name(input), output)
|
||||
|
||||
|
||||
class FixNodeNameTests(unittest.TestCase):
|
||||
def test_fix_name(self):
|
||||
inputs = [
|
||||
"lora_unet/up_blocks/3/attentions/2/transformer_blocks/0/attn2_to_out/0.lora_down.weight",
|
||||
"_prefix",
|
||||
]
|
||||
outputs = [
|
||||
"lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0_lora_down_weight",
|
||||
"prefix",
|
||||
]
|
||||
|
||||
for input, output in zip(inputs, outputs):
|
||||
self.assertEqual(fix_node_name(input), output)
|
||||
|
||||
|
||||
class FixXLNameTests(unittest.TestCase):
|
||||
def test_empty(self):
|
||||
nodes = {}
|
||||
fixed = fix_xl_names(nodes, [])
|
||||
|
||||
self.assertEqual(fixed, {})
|
||||
|
||||
def test_input_block(self):
|
||||
nodes = {
|
||||
"input_block_proj.lora_down.weight": {},
|
||||
}
|
||||
fixed = fix_xl_names(nodes, [
|
||||
NodeProto(name="/down_blocks_proj/MatMul"),
|
||||
])
|
||||
|
||||
self.assertEqual(fixed, {
|
||||
"down_blocks_proj": nodes["input_block_proj.lora_down.weight"],
|
||||
})
|
||||
|
||||
def test_middle_block(self):
|
||||
nodes = {
|
||||
"middle_block_proj.lora_down.weight": {},
|
||||
}
|
||||
fixed = fix_xl_names(nodes, [
|
||||
NodeProto(name="/mid_blocks_proj/MatMul"),
|
||||
])
|
||||
|
||||
self.assertEqual(fixed, {
|
||||
"mid_blocks_proj": nodes["middle_block_proj.lora_down.weight"],
|
||||
})
|
||||
|
||||
def test_output_block(self):
|
||||
pass
|
||||
|
||||
def test_text_model(self):
|
||||
pass
|
||||
|
||||
def test_unknown_block(self):
|
||||
pass
|
||||
|
||||
def test_unmatched_block(self):
|
||||
nodes = {
|
||||
"lora_unet.input_block.lora_down.weight": {},
|
||||
}
|
||||
fixed = fix_xl_names(nodes, [""])
|
||||
|
||||
self.assertEqual(fixed, nodes)
|
||||
|
||||
def test_output_projection(self):
|
||||
nodes = {
|
||||
"output_block_proj_o.lora_down.weight": {},
|
||||
}
|
||||
fixed = fix_xl_names(nodes, [
|
||||
NodeProto(name="/up_blocks_proj_o/MatMul"),
|
||||
])
|
||||
|
||||
self.assertEqual(fixed, {
|
||||
"up_blocks_proj_out": nodes["output_block_proj_o.lora_down.weight"],
|
||||
})
|
||||
|
||||
|
||||
class KernelSliceTests(unittest.TestCase):
|
||||
def test_within_kernel(self):
|
||||
self.assertEqual(
|
||||
kernel_slice(1, 1, (3, 3, 3, 3)),
|
||||
(1, 1),
|
||||
)
|
||||
|
||||
def test_outside_kernel(self):
|
||||
self.assertEqual(
|
||||
kernel_slice(9, 9, (3, 3, 3, 3)),
|
||||
(2, 2),
|
||||
)
|
||||
|
||||
class BlendLoRATests(unittest.TestCase):
|
||||
pass
|
||||
|
||||
class InterpToMatchTests(unittest.TestCase):
|
||||
def test_same_shape(self):
|
||||
ref = np.zeros((4, 4))
|
||||
resize = np.zeros((4, 4))
|
||||
self.assertEqual(interp_to_match(ref, resize).shape, (4, 4))
|
||||
|
||||
def test_different_one_dim(self):
|
||||
ref = np.zeros((4, 2))
|
||||
resize = np.zeros((4, 4))
|
||||
self.assertEqual(interp_to_match(ref, resize).shape, (4, 4))
|
||||
|
||||
def test_different_both_dims(self):
|
||||
ref = np.zeros((2, 2))
|
||||
resize = np.zeros((4, 4))
|
||||
self.assertEqual(interp_to_match(ref, resize).shape, (4, 4))
|
|
@ -0,0 +1,19 @@
|
|||
import unittest
|
||||
|
||||
from onnx_web.convert.utils import DEFAULT_OPSET, ConversionContext, download_progress
|
||||
|
||||
|
||||
class ConversionContextTests(unittest.TestCase):
|
||||
def test_from_environ(self):
|
||||
context = ConversionContext.from_environ()
|
||||
self.assertEqual(context.opset, DEFAULT_OPSET)
|
||||
|
||||
def test_map_location(self):
|
||||
context = ConversionContext.from_environ()
|
||||
self.assertEqual(context.map_location.type, "cpu")
|
||||
|
||||
|
||||
class DownloadProgressTests(unittest.TestCase):
|
||||
def test_download_example(self):
|
||||
path = download_progress([("https://example.com", "/tmp/example-dot-com")])
|
||||
self.assertEqual(path, "/tmp/example-dot-com")
|
|
@ -0,0 +1,83 @@
|
|||
import unittest
|
||||
|
||||
from onnx_web.server.load import (
|
||||
get_available_platforms,
|
||||
get_config_params,
|
||||
get_correction_models,
|
||||
get_diffusion_models,
|
||||
get_extra_hashes,
|
||||
get_extra_strings,
|
||||
get_highres_methods,
|
||||
get_mask_filters,
|
||||
get_network_models,
|
||||
get_noise_sources,
|
||||
get_source_filters,
|
||||
get_upscaling_models,
|
||||
get_wildcard_data,
|
||||
)
|
||||
|
||||
|
||||
class ConfigParamTests(unittest.TestCase):
|
||||
def test_before_setup(self):
|
||||
params = get_config_params()
|
||||
self.assertIsNotNone(params)
|
||||
|
||||
class AvailablePlatformTests(unittest.TestCase):
|
||||
def test_before_setup(self):
|
||||
platforms = get_available_platforms()
|
||||
self.assertIsNotNone(platforms)
|
||||
|
||||
class CorrectModelTests(unittest.TestCase):
|
||||
def test_before_setup(self):
|
||||
models = get_correction_models()
|
||||
self.assertIsNotNone(models)
|
||||
|
||||
class DiffusionModelTests(unittest.TestCase):
|
||||
def test_before_setup(self):
|
||||
models = get_diffusion_models()
|
||||
self.assertIsNotNone(models)
|
||||
|
||||
class NetworkModelTests(unittest.TestCase):
|
||||
def test_before_setup(self):
|
||||
models = get_network_models()
|
||||
self.assertIsNotNone(models)
|
||||
|
||||
class UpscalingModelTests(unittest.TestCase):
|
||||
def test_before_setup(self):
|
||||
models = get_upscaling_models()
|
||||
self.assertIsNotNone(models)
|
||||
|
||||
class WildcardDataTests(unittest.TestCase):
|
||||
def test_before_setup(self):
|
||||
wildcards = get_wildcard_data()
|
||||
self.assertIsNotNone(wildcards)
|
||||
|
||||
class ExtraStringsTests(unittest.TestCase):
|
||||
def test_before_setup(self):
|
||||
strings = get_extra_strings()
|
||||
self.assertIsNotNone(strings)
|
||||
|
||||
class ExtraHashesTests(unittest.TestCase):
|
||||
def test_before_setup(self):
|
||||
hashes = get_extra_hashes()
|
||||
self.assertIsNotNone(hashes)
|
||||
|
||||
class HighresMethodTests(unittest.TestCase):
|
||||
def test_before_setup(self):
|
||||
methods = get_highres_methods()
|
||||
self.assertIsNotNone(methods)
|
||||
|
||||
class MaskFilterTests(unittest.TestCase):
|
||||
def test_before_setup(self):
|
||||
filters = get_mask_filters()
|
||||
self.assertIsNotNone(filters)
|
||||
|
||||
class NoiseSourceTests(unittest.TestCase):
|
||||
def test_before_setup(self):
|
||||
sources = get_noise_sources()
|
||||
self.assertIsNotNone(sources)
|
||||
|
||||
class SourceFilterTests(unittest.TestCase):
|
||||
def test_before_setup(self):
|
||||
filters = get_source_filters()
|
||||
self.assertIsNotNone(filters)
|
Loading…
Reference in New Issue