diff --git a/api/.gitignore b/api/.gitignore index 165cb300..1f387ff6 100644 --- a/api/.gitignore +++ b/api/.gitignore @@ -7,6 +7,7 @@ entry.py *.swp *.pyc +.cache/ __pycache__/ dist/ htmlcov/ diff --git a/api/onnx_web/chain/result.py b/api/onnx_web/chain/result.py index 57f8619e..807c42f2 100644 --- a/api/onnx_web/chain/result.py +++ b/api/onnx_web/chain/result.py @@ -36,6 +36,14 @@ class ImageMetadata: loras: Optional[List[NetworkMetadata]] models: Optional[List[NetworkMetadata]] + @staticmethod + def unknown_image() -> "ImageMetadata": + UNKNOWN_STR = "unknown" + return ImageMetadata( + ImageParams(UNKNOWN_STR, UNKNOWN_STR, UNKNOWN_STR, "", 0, 0, 0), + Size(0, 0), + ) + def __init__( self, params: ImageParams, @@ -212,22 +220,18 @@ class StageResult: return StageResult(images=[]) @staticmethod - def from_arrays( - arrays: List[np.ndarray], metadata: Optional[List[ImageMetadata]] = None - ): + def from_arrays(arrays: List[np.ndarray], metadata: List[ImageMetadata]): return StageResult(arrays=arrays, metadata=metadata) @staticmethod - def from_images( - images: List[Image.Image], metadata: Optional[List[ImageMetadata]] = None - ): + def from_images(images: List[Image.Image], metadata: List[ImageMetadata]): return StageResult(images=images, metadata=metadata) def __init__( self, arrays: Optional[List[np.ndarray]] = None, images: Optional[List[Image.Image]] = None, - metadata: Optional[List[ImageMetadata]] = None, + metadata: Optional[List[ImageMetadata]] = None, # TODO: should not be optional source: Optional[Any] = None, ) -> None: data_provided = sum( diff --git a/api/onnx_web/chain/source_s3.py b/api/onnx_web/chain/source_s3.py index 7e5666d2..a34d087f 100644 --- a/api/onnx_web/chain/source_s3.py +++ b/api/onnx_web/chain/source_s3.py @@ -9,7 +9,7 @@ from ..params import ImageParams, StageParams from ..server import ServerContext from ..worker import WorkerContext from .base import BaseStage -from .result import StageResult +from .result import ImageMetadata, StageResult logger = getLogger(__name__) @@ -50,7 +50,8 @@ class SourceS3Stage(BaseStage): logger.exception("error loading image from S3") # TODO: attempt to load metadata from s3 or load it from the image itself (exif data) - return StageResult(outputs) + metadata = [ImageMetadata.unknown_image()] * len(outputs) + return StageResult(outputs, metadata=metadata) def outputs( self, diff --git a/api/onnx_web/chain/source_url.py b/api/onnx_web/chain/source_url.py index 7fe158bf..60a3ca4f 100644 --- a/api/onnx_web/chain/source_url.py +++ b/api/onnx_web/chain/source_url.py @@ -9,7 +9,7 @@ from ..params import ImageParams, StageParams from ..server import ServerContext from ..worker import WorkerContext from .base import BaseStage -from .result import StageResult +from .result import ImageMetadata, StageResult logger = getLogger(__name__) @@ -42,7 +42,8 @@ class SourceURLStage(BaseStage): logger.info("final output image size: %sx%s", output.width, output.height) outputs.append(output) - return StageResult(images=outputs) + metadata = [ImageMetadata.unknown_image()] * len(outputs) + return StageResult(images=outputs, metadata=metadata) def outputs( self, diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py index eed650d2..9b9cfce2 100644 --- a/api/onnx_web/diffusers/run.py +++ b/api/onnx_web/diffusers/run.py @@ -13,7 +13,7 @@ from ..chain import ( UpscaleOutpaintStage, ) from ..chain.highres import stage_highres -from ..chain.result import StageResult +from ..chain.result import ImageMetadata, StageResult from ..chain.upscale import split_upscale, stage_upscale_correction from ..image import expand_image from ..output import save_image, save_result @@ -212,7 +212,11 @@ def run_img2img_pipeline( # run and append the filtered source progress = worker.get_progress_callback(reset=True) images = chain( - worker, server, params, StageResult(images=[source]), callback=progress + worker, + server, + params, + StageResult(images=[source], metadata=[ImageMetadata.unknown_image()]), + callback=progress, ) if source_filter is not None and source_filter != "none": @@ -385,7 +389,9 @@ def run_inpaint_pipeline( worker, server, params, - StageResult(images=[source]), # TODO: load metadata from source image + StageResult( + images=[source], metadata=[ImageMetadata.unknown_image()] + ), # TODO: load metadata from source image callback=progress, latents=latents, ) @@ -459,7 +465,11 @@ def run_upscale_pipeline( # run and save progress = worker.get_progress_callback(reset=True) images = chain( - worker, server, params, StageResult(images=[source]), callback=progress + worker, + server, + params, + StageResult(images=[source], metadata=[ImageMetadata.unknown_image()]), + callback=progress, ) save_result(server, images, worker.job) @@ -508,7 +518,13 @@ def run_blend_pipeline( # run and save progress = worker.get_progress_callback(reset=True) images = chain( - worker, server, params, StageResult(images=sources), callback=progress + worker, + server, + params, + StageResult( + images=sources, metadata=[ImageMetadata.unknown_image()] * len(sources) + ), + callback=progress, ) save_result(server, images, worker.job) diff --git a/api/onnx_web/output.py b/api/onnx_web/output.py index 0e9a32c4..49106ac7 100644 --- a/api/onnx_web/output.py +++ b/api/onnx_web/output.py @@ -64,6 +64,8 @@ def save_result( ) -> List[str]: images = result.as_images() outputs = make_output_names(server, base_name, len(images)) + logger.debug("saving %s images: %s", len(images), outputs) + results = [] for image, metadata, filename in zip(images, result.metadata, outputs): results.append( diff --git a/api/scripts/test-release.py b/api/scripts/test-release.py index c281cc10..7372cbfd 100644 --- a/api/scripts/test-release.py +++ b/api/scripts/test-release.py @@ -485,10 +485,10 @@ def check_ready(host: str, key: str) -> bool: raise TestError("error getting image status") def check_outputs(host: str, key: str) -> List[str]: - resp = requests.get(f"{host}/api/ready?output={key}") + resp = requests.get(f"{host}/api/job/status?jobs={key}") if resp.status_code == 200: json = resp.json() - outputs = json.get("outputs", []) + outputs = json[0].get("outputs", []) return outputs logger.warning("getting outputs failed: %s: %s", resp.status_code, resp.text) @@ -499,12 +499,13 @@ def download_images(host: str, key: str) -> List[Image.Image]: images = [] for key in outputs: - resp = requests.get(f"{host}/output/{key}") + url = f"{host}/output/{key}" + resp = requests.get(url) if resp.status_code == 200: logger.debug("downloading image: %s", key) images.append(Image.open(BytesIO(resp.content))) else: - logger.warning("download request failed: %s", resp.status_code) + logger.warning("download request failed: %s: %s", url, resp.status_code) raise TestError("error downloading image") return images diff --git a/api/tests/chain/test_tile.py b/api/tests/chain/test_tile.py index c27cb077..caf88335 100644 --- a/api/tests/chain/test_tile.py +++ b/api/tests/chain/test_tile.py @@ -2,7 +2,7 @@ import unittest from PIL import Image -from onnx_web.chain.result import StageResult +from onnx_web.chain.result import ImageMetadata, StageResult from onnx_web.chain.tile import ( complete_tile, generate_tile_grid, @@ -126,7 +126,11 @@ class TestProcessTileStack(unittest.TestCase): def test_grid_full(self): source = Image.new("RGB", (64, 64)) blend = process_tile_stack( - StageResult(images=[source]), 32, 1, [], generate_tile_grid + StageResult(images=[source], metadata=[ImageMetadata.unknown_image()]), + 32, + 1, + [], + generate_tile_grid, ) self.assertEqual(blend[0].size, (64, 64)) @@ -134,7 +138,11 @@ class TestProcessTileStack(unittest.TestCase): def test_grid_partial(self): source = Image.new("RGB", (72, 72)) blend = process_tile_stack( - StageResult(images=[source]), 32, 1, [], generate_tile_grid + StageResult(images=[source], metadata=[ImageMetadata.unknown_image()]), + 32, + 1, + [], + generate_tile_grid, ) self.assertEqual(blend[0].size, (72, 72)) diff --git a/api/tests/convert/diffusion/test_lora.py b/api/tests/convert/diffusion/test_lora.py index a39b0b8b..17dc3b6a 100644 --- a/api/tests/convert/diffusion/test_lora.py +++ b/api/tests/convert/diffusion/test_lora.py @@ -1,4 +1,5 @@ import unittest +from unittest.mock import MagicMock, patch import numpy as np import torch @@ -6,6 +7,7 @@ from onnx import GraphProto, ModelProto, NodeProto from onnx.numpy_helper import from_array from onnx_web.convert.diffusion.lora import ( + blend_loras, blend_node_conv_gemm, blend_node_matmul, blend_weights_loha, @@ -226,6 +228,30 @@ class BlendLoRATests(unittest.TestCase): def test_node_dtype(self): pass + @patch("onnx_web.convert.diffusion.lora.load") + @patch("onnx_web.convert.diffusion.lora.load_tensor") + def test_blend_loras_load_str(self, mock_load_tensor, mock_load): + base_name = "model.onnx" + loras = [("loras/model1.safetensors", 0.5), ("loras/safetensors.onnx", 0.5)] + model_type = "unet" + model_index = 2 + xl = True + + mock_load.return_value = MagicMock() + mock_load_tensor.return_value = MagicMock() + + # Call the blend_loras function + blended_model = blend_loras(None, base_name, loras, model_type, model_index, xl) + + # Assert that the InferenceSession is called with the correct arguments + mock_load.assert_called_once_with(base_name) + + # Assert that the model is loaded successfully + self.assertEqual(blended_model, mock_load.return_value) + + # Assert that the blending logic is executed correctly + # (assertions specific to the blending logic can be added here) + class BlendWeightsLoHATests(unittest.TestCase): def test_blend_t1_t2(self): diff --git a/api/tests/image/test_source_filter.py b/api/tests/image/test_source_filter.py index fb44073e..85dc2c5c 100644 --- a/api/tests/image/test_source_filter.py +++ b/api/tests/image/test_source_filter.py @@ -1,11 +1,24 @@ import unittest +from os import path +import numpy as np from PIL import Image from onnx_web.image.source_filter import ( + filter_model_path, + pil_to_cv2, + source_filter_canny, + source_filter_depth, + source_filter_face, source_filter_gaussian, + source_filter_hed, + source_filter_mlsd, source_filter_noise, source_filter_none, + source_filter_normal, + source_filter_openpose, + source_filter_scribble, + source_filter_segment, ) from onnx_web.server.context import ServerContext @@ -35,3 +48,119 @@ class SourceFilterNoiseTests(unittest.TestCase): source = Image.new("RGB", dims) result = source_filter_noise(server, source) self.assertEqual(result.size, dims) + + +class PILToCV2Tests(unittest.TestCase): + def test_conversion(self): + dims = (64, 64) + source = Image.new("RGB", dims) + result = pil_to_cv2(source) + self.assertIsInstance(result, np.ndarray) + self.assertEqual(result.shape, (dims[1], dims[0], 3)) + self.assertEqual(result.dtype, np.uint8) + + +class FilterModelPathTests(unittest.TestCase): + def test_filter_model_path(self): + server = ServerContext() + filter_name = "gaussian" + expected_path = path.join(server.model_path, "filter", filter_name) + result = filter_model_path(server, filter_name) + self.assertEqual(result, expected_path) + + +class SourceFilterFaceTests(unittest.TestCase): # Added new test class + def test_basic(self): + dims = (64, 64) + server = ServerContext() + source = Image.new("RGB", dims) + result = source_filter_face(server, source) + self.assertEqual(result.size, dims) + + +class SourceFilterSegmentTests( + unittest.TestCase +): # Added SourceFilterSegmentTests class + def test_basic(self): + dims = (64, 64) + server = ServerContext() + source = Image.new("RGB", dims) + result = source_filter_segment(server, source) + self.assertEqual(result.size, dims) + + +class SourceFilterMLSDTests(unittest.TestCase): # Added SourceFilterMLSDTests class + def test_basic(self): + dims = (64, 64) + server = ServerContext() + source = Image.new("RGB", dims) + result = source_filter_mlsd(server, source) + self.assertEqual(result.size, (512, 512)) + + +class SourceFilterNormalTests(unittest.TestCase): # Added SourceFilterNormalTests class + def test_basic(self): + dims = (64, 64) + server = ServerContext() + source = Image.new("RGB", dims) + result = source_filter_normal(server, source) + + # normal will resize inputs to 384x384 + self.assertEqual(result.size, (384, 384)) + + +class SourceFilterHEDTests(unittest.TestCase): + def test_basic(self): + dims = (64, 64) + server = ServerContext() + source = Image.new("RGB", dims) + result = source_filter_hed(server, source) + self.assertEqual(result.size, (512, 512)) + + +class SourceFilterScribbleTests( + unittest.TestCase +): # Added SourceFilterScribbleTests class + def test_basic(self): + dims = (64, 64) + server = ServerContext() + source = Image.new("RGB", dims) + result = source_filter_scribble(server, source) + + # scribble will resize inputs to 512x512 + self.assertEqual(result.size, (512, 512)) + + +class SourceFilterDepthTests( + unittest.TestCase +): # Added SourceFilterScribbleTests class + def test_basic(self): + dims = (64, 64) + server = ServerContext() + source = Image.new("RGB", dims) + result = source_filter_depth(server, source) + self.assertEqual(result.size, dims) + + +class SourceFilterCannyTests( + unittest.TestCase +): # Added SourceFilterScribbleTests class + def test_basic(self): + dims = (64, 64) + server = ServerContext() + source = Image.new("RGB", dims) + result = source_filter_canny(server, source) + self.assertEqual(result.size, dims) + + +class SourceFilterOpenPoseTests( + unittest.TestCase +): # Added SourceFilterScribbleTests class + def test_basic(self): + dims = (64, 64) + server = ServerContext() + source = Image.new("RGB", dims) + result = source_filter_openpose(server, source) + + # openpose will resize inputs to 512x512 + self.assertEqual(result.size, (512, 512))