1
0
Fork 0

keep metadata when tiling

This commit is contained in:
Sean Sube 2024-01-06 14:27:55 -06:00
parent 3a647ad9bd
commit e8e2b92436
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
6 changed files with 230 additions and 2 deletions

View File

@ -344,7 +344,7 @@ def process_tile_stack(
tile_stack = image_filter(tile_stack, tile_mask, (left, top, tile))
if isinstance(tile_stack, list):
tile_stack = StageResult.from_images(tile_stack)
tile_stack = StageResult.from_images(tile_stack, metadata=stack.metadata)
tiles.append((left, top, tile_stack.as_images()))

View File

@ -396,7 +396,7 @@ def run_inpaint_pipeline(
latents=latents,
)
for i, image, metadata in enumerate(zip(images.as_images(), images.metadata)):
for i, (image, metadata) in enumerate(zip(images.as_images(), images.metadata)):
if full_res_inpaint:
if is_debug():
save_image(server, "adjusted-output.png", image)

View File

@ -0,0 +1,51 @@
import unittest
from PIL import Image
from onnx_web.chain.blend_denoise_fastnlmeans import BlendDenoiseFastNLMeansStage
from onnx_web.chain.result import ImageMetadata, StageResult
from tests.helpers import test_params, test_size
class TestBlendDenoiseFastNLMeansStage(unittest.TestCase):
def test_run(self):
# Create a dummy image
size = test_size()
image = Image.new("RGB", (size.width, size.height), color="white")
# Create a dummy StageResult object
sources = StageResult.from_images(
[image],
metadata=[
ImageMetadata(
test_params(),
size,
)
],
)
# Create an instance of BlendDenoiseLocalStdStage
stage = BlendDenoiseFastNLMeansStage()
# Call the run method with dummy parameters
result = stage.run(
_worker=None,
_server=None,
_stage=None,
_params=None,
sources=sources,
strength=5,
range=4,
stage_source=None,
callback=None,
)
# Assert that the result is an instance of StageResult
self.assertIsInstance(result, StageResult)
# Assert that the result contains the denoised image
self.assertEqual(len(result), 1)
self.assertEqual(result.size(), size)
# Assert that the metadata is preserved
self.assertEqual(result.metadata, sources.metadata)

View File

@ -0,0 +1,9 @@
import unittest
from onnx_web.image.ade_palette import ade_palette
class TestADEPalette(unittest.TestCase):
def test_palette_length(self):
palette = ade_palette()
self.assertEqual(len(palette), 150, "Palette length should be 150")

View File

@ -0,0 +1,69 @@
import unittest
import numpy as np
from onnx_web.image.laion_face import draw_pupils, generate_annotation, reverse_channels
class TestLaionFace(unittest.TestCase):
@unittest.skip
def test_draw_pupils(self):
# Create a dummy image
image = np.zeros((100, 100, 3), dtype=np.uint8)
# Create a dummy landmark list
class LandmarkList:
def __init__(self, landmarks):
self.landmark = landmarks
# Create a dummy drawing spec
class DrawingSpec:
def __init__(self, color):
self.color = color
# Create some dummy landmarks
landmarks = [
# Add your landmarks here
]
# Create a dummy drawing spec
drawing_spec = DrawingSpec(color=(255, 0, 0)) # Red color
# Call the draw_pupils function
draw_pupils(image, LandmarkList(landmarks), drawing_spec)
self.assertNotEqual(np.sum(image), 0, "Image should be modified")
@unittest.skip
def test_generate_annotation(self):
# Create a dummy image
image = np.zeros((100, 100, 3), dtype=np.uint8)
# Call the generate_annotation function
result = generate_annotation(image, max_faces=1, min_confidence=0.5)
self.assertEqual(
result.shape,
image.shape,
"Result shape should be the same as the input image",
)
self.assertNotEqual(np.sum(result), 0, "Result should not be all zeros")
class TestReverseChannels(unittest.TestCase):
def test_reverse_channels(self):
# Create a dummy image
image = np.zeros((100, 100, 3), dtype=np.uint8)
layer = np.ones((100, 100), dtype=np.uint8)
image[:, :, 0] = layer
# Call the reverse_channels function
reversed_image = reverse_channels(image)
self.assertEqual(
image.shape, reversed_image.shape, "Image shape should be the same"
)
self.assertTrue(
np.array_equal(reversed_image[:, :, 2], layer),
"Channels should be reversed",
)

99
api/tests/test_utils.py Normal file
View File

@ -0,0 +1,99 @@
import unittest
from onnx_web.utils import (
get_and_clamp_float,
get_and_clamp_int,
get_boolean,
get_from_list,
get_from_map,
get_list,
get_not_empty,
split_list,
)
class TestUtils(unittest.TestCase):
def test_split_list_empty(self):
self.assertEqual(split_list(""), [])
self.assertEqual(split_list(" "), [])
self.assertEqual(split_list(" , "), [])
def test_split_list_single(self):
self.assertEqual(split_list("a"), ["a"])
self.assertEqual(split_list(" a "), ["a"])
self.assertEqual(split_list(" a, "), ["a"])
self.assertEqual(split_list(" a , "), ["a"])
def test_split_list_multiple(self):
self.assertEqual(split_list("a,b"), ["a", "b"])
self.assertEqual(split_list(" a , b "), ["a", "b"])
self.assertEqual(split_list(" a, b "), ["a", "b"])
self.assertEqual(split_list(" a ,b "), ["a", "b"])
def test_get_boolean_empty(self):
self.assertEqual(get_boolean({}, "key", False), False)
self.assertEqual(get_boolean({}, "key", True), True)
def test_get_boolean_true(self):
self.assertEqual(get_boolean({"key": True}, "key", False), True)
self.assertEqual(get_boolean({"key": True}, "key", True), True)
def test_get_boolean_false(self):
self.assertEqual(get_boolean({"key": False}, "key", False), False)
self.assertEqual(get_boolean({"key": False}, "key", True), False)
def test_get_list_empty(self):
self.assertEqual(get_list({}, "key", ""), [])
self.assertEqual(get_list({}, "key", "a"), ["a"])
def test_get_list_exists(self):
self.assertEqual(get_list({"key": "a,b"}, "key", ""), ["a", "b"])
self.assertEqual(get_list({"key": "a,b"}, "key", "c"), ["a", "b"])
def test_get_and_clamp_float_empty(self):
self.assertEqual(get_and_clamp_float({}, "key", 0.0, 1.0), 0.0)
self.assertEqual(get_and_clamp_float({}, "key", 1.0, 1.0), 1.0)
def test_get_and_clamp_float_clamped(self):
self.assertEqual(get_and_clamp_float({"key": -1.0}, "key", 0.0, 1.0), 0.0)
self.assertEqual(get_and_clamp_float({"key": 2.0}, "key", 0.0, 1.0), 1.0)
def test_get_and_clamp_float_normal(self):
self.assertEqual(get_and_clamp_float({"key": 0.5}, "key", 0.0, 1.0), 0.5)
def test_get_and_clamp_int_empty(self):
self.assertEqual(get_and_clamp_int({}, "key", 0, 1), 1)
self.assertEqual(get_and_clamp_int({}, "key", 1, 1), 1)
def test_get_and_clamp_int_clamped(self):
self.assertEqual(get_and_clamp_int({"key": 0}, "key", 1, 1), 1)
self.assertEqual(get_and_clamp_int({"key": 2}, "key", 1, 1), 1)
def test_get_and_clamp_int_normal(self):
self.assertEqual(get_and_clamp_int({"key": 1}, "key", 0, 1), 1)
def test_get_from_list_empty(self):
self.assertEqual(get_from_list({}, "key", ["a", "b"]), "a")
self.assertEqual(get_from_list({}, "key", ["a", "b"], "a"), "a")
def test_get_from_list_exists(self):
self.assertEqual(get_from_list({"key": "a"}, "key", ["a", "b"]), "a")
self.assertEqual(get_from_list({"key": "b"}, "key", ["a", "b"]), "b")
def test_get_from_list_invalid(self):
self.assertEqual(get_from_list({"key": "c"}, "key", ["a", "b"]), "a")
def test_get_from_map_empty(self):
self.assertEqual(get_from_map({}, "key", {"a": 1, "b": 2}, "a"), 1)
self.assertEqual(get_from_map({}, "key", {"a": 1, "b": 2}, "b"), 2)
def test_get_from_map_exists(self):
self.assertEqual(get_from_map({"key": "a"}, "key", {"a": 1, "b": 2}, "a"), 1)
self.assertEqual(get_from_map({"key": "b"}, "key", {"a": 1, "b": 2}, "a"), 2)
def test_get_not_empty_empty(self):
self.assertEqual(get_not_empty({}, "key", "a"), "a")
self.assertEqual(get_not_empty({"key": ""}, "key", "a"), "a")
def test_get_not_empty_exists(self):
self.assertEqual(get_not_empty({"key": "b"}, "key", "a"), "b")