
provide empty metadata for pipelines with source images

Sean Sube 2024-01-06 14:17:26 -06:00
parent 3e5a95548b
commit 3a647ad9bd
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
10 changed files with 212 additions and 23 deletions

api/.gitignore

@@ -7,6 +7,7 @@ entry.py
*.swp
*.pyc
.cache/
__pycache__/
dist/
htmlcov/


@@ -36,6 +36,14 @@ class ImageMetadata:
loras: Optional[List[NetworkMetadata]]
models: Optional[List[NetworkMetadata]]
@staticmethod
def unknown_image() -> "ImageMetadata":
UNKNOWN_STR = "unknown"
return ImageMetadata(
ImageParams(UNKNOWN_STR, UNKNOWN_STR, UNKNOWN_STR, "", 0, 0, 0),
Size(0, 0),
)
def __init__(
self,
params: ImageParams,
@@ -212,22 +220,18 @@ class StageResult:
return StageResult(images=[])
@staticmethod
def from_arrays(
arrays: List[np.ndarray], metadata: Optional[List[ImageMetadata]] = None
):
def from_arrays(arrays: List[np.ndarray], metadata: List[ImageMetadata]):
return StageResult(arrays=arrays, metadata=metadata)
@staticmethod
def from_images(
images: List[Image.Image], metadata: Optional[List[ImageMetadata]] = None
):
def from_images(images: List[Image.Image], metadata: List[ImageMetadata]):
return StageResult(images=images, metadata=metadata)
def __init__(
self,
arrays: Optional[List[np.ndarray]] = None,
images: Optional[List[Image.Image]] = None,
metadata: Optional[List[ImageMetadata]] = None,
metadata: Optional[List[ImageMetadata]] = None, # TODO: should not be optional
source: Optional[Any] = None,
) -> None:
data_provided = sum(
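
Note: from_arrays and from_images now require a metadata list, and the constructor is expected to follow (per the TODO). Callers that only have raw sources can pass ImageMetadata.unknown_image() placeholders; a minimal sketch of the pattern, with the module path taken from the test imports later in this commit:

from PIL import Image

from onnx_web.chain.result import ImageMetadata, StageResult

# one placeholder metadata entry per source image keeps the two lists the same length
source = Image.new("RGB", (512, 512))
result = StageResult(images=[source], metadata=[ImageMetadata.unknown_image()])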


@@ -9,7 +9,7 @@ from ..params import ImageParams, StageParams
from ..server import ServerContext
from ..worker import WorkerContext
from .base import BaseStage
from .result import StageResult
from .result import ImageMetadata, StageResult
logger = getLogger(__name__)
@@ -50,7 +50,8 @@ class SourceS3Stage(BaseStage):
logger.exception("error loading image from S3")
# TODO: attempt to load metadata from s3 or load it from the image itself (exif data)
return StageResult(outputs)
metadata = [ImageMetadata.unknown_image()] * len(outputs)
return StageResult(outputs, metadata=metadata)
def outputs(
self,
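
Note: the TODO above leaves real metadata loading for later. Purely as an illustration of the EXIF half of that idea, and not part of this commit, Pillow's standard Exif accessor could be used roughly like this (the helper name and tag choice are assumptions):

from PIL import Image


def exif_description(image: Image.Image) -> str:
    # 0x010E is the standard EXIF ImageDescription tag
    exif = image.getexif()
    return str(exif.get(0x010E, "unknown"))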


@@ -9,7 +9,7 @@ from ..params import ImageParams, StageParams
from ..server import ServerContext
from ..worker import WorkerContext
from .base import BaseStage
from .result import StageResult
from .result import ImageMetadata, StageResult
logger = getLogger(__name__)
@@ -42,7 +42,8 @@ class SourceURLStage(BaseStage):
logger.info("final output image size: %sx%s", output.width, output.height)
outputs.append(output)
return StageResult(images=outputs)
metadata = [ImageMetadata.unknown_image()] * len(outputs)
return StageResult(images=outputs, metadata=metadata)
def outputs(
self,
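
Note: [ImageMetadata.unknown_image()] * len(outputs) repeats one object reference len(outputs) times. That is fine for a read-only placeholder, but if later stages ever mutate their metadata entry in place, every entry would change together; a comprehension builds independent entries instead (whether this matters depends on downstream use, which is an assumption here):

# one independent ImageMetadata per output, rather than N references to a single shared instance
metadata = [ImageMetadata.unknown_image() for _ in outputs]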


@@ -13,7 +13,7 @@ from ..chain import (
UpscaleOutpaintStage,
)
from ..chain.highres import stage_highres
from ..chain.result import StageResult
from ..chain.result import ImageMetadata, StageResult
from ..chain.upscale import split_upscale, stage_upscale_correction
from ..image import expand_image
from ..output import save_image, save_result
@@ -212,7 +212,11 @@ def run_img2img_pipeline(
# run and append the filtered source
progress = worker.get_progress_callback(reset=True)
images = chain(
worker, server, params, StageResult(images=[source]), callback=progress
worker,
server,
params,
StageResult(images=[source], metadata=[ImageMetadata.unknown_image()]),
callback=progress,
)
if source_filter is not None and source_filter != "none":
@@ -385,7 +389,9 @@ def run_inpaint_pipeline(
worker,
server,
params,
StageResult(images=[source]), # TODO: load metadata from source image
StageResult(
images=[source], metadata=[ImageMetadata.unknown_image()]
), # TODO: load metadata from source image
callback=progress,
latents=latents,
)
@@ -459,7 +465,11 @@ def run_upscale_pipeline(
# run and save
progress = worker.get_progress_callback(reset=True)
images = chain(
worker, server, params, StageResult(images=[source]), callback=progress
worker,
server,
params,
StageResult(images=[source], metadata=[ImageMetadata.unknown_image()]),
callback=progress,
)
save_result(server, images, worker.job)
@@ -508,7 +518,13 @@ def run_blend_pipeline(
# run and save
progress = worker.get_progress_callback(reset=True)
images = chain(
worker, server, params, StageResult(images=sources), callback=progress
worker,
server,
params,
StageResult(
images=sources, metadata=[ImageMetadata.unknown_image()] * len(sources)
),
callback=progress,
)
save_result(server, images, worker.job)
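
Note: the same StageResult(images=..., metadata=[ImageMetadata.unknown_image()] ...) seed is repeated in each pipeline above. A hypothetical helper, not part of this commit, could factor that out:

from typing import List

from PIL import Image

from onnx_web.chain.result import ImageMetadata, StageResult


def result_from_sources(sources: List[Image.Image]) -> StageResult:
    # hypothetical helper: wrap plain source images with one placeholder metadata entry each
    return StageResult(
        images=sources,
        metadata=[ImageMetadata.unknown_image() for _ in sources],
    )

Each call site would then reduce to chain(worker, server, params, result_from_sources([source]), callback=progress).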


@@ -64,6 +64,8 @@ def save_result(
) -> List[str]:
images = result.as_images()
outputs = make_output_names(server, base_name, len(images))
logger.debug("saving %s images: %s", len(images), outputs)
results = []
for image, metadata, filename in zip(images, result.metadata, outputs):
results.append(
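
Note: zip() stops at the shortest of images, result.metadata, and outputs, so a missing metadata entry would silently drop images from the saved results; that is the gap this commit closes by always supplying metadata. A small self-contained demonstration, using zip(strict=True) from Python 3.10+ (whether the project targets 3.10+ is an assumption):

images = ["a.png", "b.png"]
metadata = ["unknown"]  # one entry short of images
outputs = ["out-a.png", "out-b.png"]

# plain zip() would silently yield only the first triple; strict=True raises instead
try:
    list(zip(images, metadata, outputs, strict=True))
except ValueError:
    print("metadata does not match the number of images")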


@@ -485,10 +485,10 @@ def check_ready(host: str, key: str) -> bool:
raise TestError("error getting image status")
def check_outputs(host: str, key: str) -> List[str]:
resp = requests.get(f"{host}/api/ready?output={key}")
resp = requests.get(f"{host}/api/job/status?jobs={key}")
if resp.status_code == 200:
json = resp.json()
outputs = json.get("outputs", [])
outputs = json[0].get("outputs", [])
return outputs
logger.warning("getting outputs failed: %s: %s", resp.status_code, resp.text)
@@ -499,12 +499,13 @@ def download_images(host: str, key: str) -> List[Image.Image]:
images = []
for key in outputs:
resp = requests.get(f"{host}/output/{key}")
url = f"{host}/output/{key}"
resp = requests.get(url)
if resp.status_code == 200:
logger.debug("downloading image: %s", key)
images.append(Image.open(BytesIO(resp.content)))
else:
logger.warning("download request failed: %s", resp.status_code)
logger.warning("download request failed: %s: %s", url, resp.status_code)
raise TestError("error downloading image")
return images
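
Note: the switch to /api/job/status?jobs={key} implies the endpoint returns a JSON array of job records, since the script indexes json[0]. If an unknown job comes back as an empty array, that indexing raises IndexError. A slightly more defensive sketch of the same check, reusing the script's requests import and TestError (the empty-array behavior is an assumption inferred from this diff):

def check_outputs_defensive(host: str, key: str) -> List[str]:
    resp = requests.get(f"{host}/api/job/status?jobs={key}")
    if resp.status_code != 200:
        raise TestError(f"error getting job status: {resp.status_code}")

    jobs = resp.json()
    if not jobs:
        # unknown or expired jobs are assumed to come back as an empty array
        raise TestError(f"no status returned for job: {key}")

    return jobs[0].get("outputs", [])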


@@ -2,7 +2,7 @@ import unittest
from PIL import Image
from onnx_web.chain.result import StageResult
from onnx_web.chain.result import ImageMetadata, StageResult
from onnx_web.chain.tile import (
complete_tile,
generate_tile_grid,
@@ -126,7 +126,11 @@ class TestProcessTileStack(unittest.TestCase):
def test_grid_full(self):
source = Image.new("RGB", (64, 64))
blend = process_tile_stack(
StageResult(images=[source]), 32, 1, [], generate_tile_grid
StageResult(images=[source], metadata=[ImageMetadata.unknown_image()]),
32,
1,
[],
generate_tile_grid,
)
self.assertEqual(blend[0].size, (64, 64))
@@ -134,7 +138,11 @@ class TestProcessTileStack(unittest.TestCase):
def test_grid_partial(self):
source = Image.new("RGB", (72, 72))
blend = process_tile_stack(
StageResult(images=[source]), 32, 1, [], generate_tile_grid
StageResult(images=[source], metadata=[ImageMetadata.unknown_image()]),
32,
1,
[],
generate_tile_grid,
)
self.assertEqual(blend[0].size, (72, 72))


@@ -1,4 +1,5 @@
import unittest
from unittest.mock import MagicMock, patch
import numpy as np
import torch
@@ -6,6 +7,7 @@ from onnx import GraphProto, ModelProto, NodeProto
from onnx.numpy_helper import from_array
from onnx_web.convert.diffusion.lora import (
blend_loras,
blend_node_conv_gemm,
blend_node_matmul,
blend_weights_loha,
@@ -226,6 +228,30 @@ class BlendLoRATests(unittest.TestCase):
def test_node_dtype(self):
pass
@patch("onnx_web.convert.diffusion.lora.load")
@patch("onnx_web.convert.diffusion.lora.load_tensor")
def test_blend_loras_load_str(self, mock_load_tensor, mock_load):
base_name = "model.onnx"
loras = [("loras/model1.safetensors", 0.5), ("loras/safetensors.onnx", 0.5)]
model_type = "unet"
model_index = 2
xl = True
mock_load.return_value = MagicMock()
mock_load_tensor.return_value = MagicMock()
# Call the blend_loras function
blended_model = blend_loras(None, base_name, loras, model_type, model_index, xl)
# Assert that the base model is loaded with the correct path
mock_load.assert_called_once_with(base_name)
# Assert that blend_loras returns the loaded model
self.assertEqual(blended_model, mock_load.return_value)
# Assert that the blending logic is executed correctly
# (assertions specific to the blending logic can be added here)
class BlendWeightsLoHATests(unittest.TestCase):
def test_blend_t1_t2(self):


@@ -1,11 +1,24 @@
import unittest
from os import path
import numpy as np
from PIL import Image
from onnx_web.image.source_filter import (
filter_model_path,
pil_to_cv2,
source_filter_canny,
source_filter_depth,
source_filter_face,
source_filter_gaussian,
source_filter_hed,
source_filter_mlsd,
source_filter_noise,
source_filter_none,
source_filter_normal,
source_filter_openpose,
source_filter_scribble,
source_filter_segment,
)
from onnx_web.server.context import ServerContext
@@ -35,3 +48,119 @@ class SourceFilterNoiseTests(unittest.TestCase):
source = Image.new("RGB", dims)
result = source_filter_noise(server, source)
self.assertEqual(result.size, dims)
class PILToCV2Tests(unittest.TestCase):
def test_conversion(self):
dims = (64, 64)
source = Image.new("RGB", dims)
result = pil_to_cv2(source)
self.assertIsInstance(result, np.ndarray)
self.assertEqual(result.shape, (dims[1], dims[0], 3))
self.assertEqual(result.dtype, np.uint8)
class FilterModelPathTests(unittest.TestCase):
def test_filter_model_path(self):
server = ServerContext()
filter_name = "gaussian"
expected_path = path.join(server.model_path, "filter", filter_name)
result = filter_model_path(server, filter_name)
self.assertEqual(result, expected_path)
class SourceFilterFaceTests(unittest.TestCase): # Added new test class
def test_basic(self):
dims = (64, 64)
server = ServerContext()
source = Image.new("RGB", dims)
result = source_filter_face(server, source)
self.assertEqual(result.size, dims)
class SourceFilterSegmentTests(
unittest.TestCase
): # Added SourceFilterSegmentTests class
def test_basic(self):
dims = (64, 64)
server = ServerContext()
source = Image.new("RGB", dims)
result = source_filter_segment(server, source)
self.assertEqual(result.size, dims)
class SourceFilterMLSDTests(unittest.TestCase): # Added SourceFilterMLSDTests class
def test_basic(self):
dims = (64, 64)
server = ServerContext()
source = Image.new("RGB", dims)
result = source_filter_mlsd(server, source)
self.assertEqual(result.size, (512, 512))
class SourceFilterNormalTests(unittest.TestCase): # Added SourceFilterNormalTests class
def test_basic(self):
dims = (64, 64)
server = ServerContext()
source = Image.new("RGB", dims)
result = source_filter_normal(server, source)
# normal will resize inputs to 384x384
self.assertEqual(result.size, (384, 384))
class SourceFilterHEDTests(unittest.TestCase):
def test_basic(self):
dims = (64, 64)
server = ServerContext()
source = Image.new("RGB", dims)
result = source_filter_hed(server, source)
self.assertEqual(result.size, (512, 512))
class SourceFilterScribbleTests(
unittest.TestCase
): # Added SourceFilterScribbleTests class
def test_basic(self):
dims = (64, 64)
server = ServerContext()
source = Image.new("RGB", dims)
result = source_filter_scribble(server, source)
# scribble will resize inputs to 512x512
self.assertEqual(result.size, (512, 512))
class SourceFilterDepthTests(
unittest.TestCase
): # Added SourceFilterDepthTests class
def test_basic(self):
dims = (64, 64)
server = ServerContext()
source = Image.new("RGB", dims)
result = source_filter_depth(server, source)
self.assertEqual(result.size, dims)
class SourceFilterCannyTests(
unittest.TestCase
): # Added SourceFilterCannyTests class
def test_basic(self):
dims = (64, 64)
server = ServerContext()
source = Image.new("RGB", dims)
result = source_filter_canny(server, source)
self.assertEqual(result.size, dims)
class SourceFilterOpenPoseTests(
unittest.TestCase
): # Added SourceFilterOpenPoseTests class
def test_basic(self):
dims = (64, 64)
server = ServerContext()
source = Image.new("RGB", dims)
result = source_filter_openpose(server, source)
# openpose will resize inputs to 512x512
self.assertEqual(result.size, (512, 512))
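
Note: PILToCV2Tests above only checks shape and dtype. A conversion with those properties typically reorders RGB into the BGR channel order expected by OpenCV; whether pil_to_cv2 does exactly this is an assumption, and the sketch below only matches the asserted shape and dtype:

import numpy as np
from PIL import Image


def pil_to_cv2_sketch(source: Image.Image) -> np.ndarray:
    # uint8 array of shape (height, width, 3), with the channel axis reversed (RGB -> BGR)
    return np.array(source.convert("RGB"))[:, :, ::-1]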