provide empty metadata for pipelines with source images
This commit is contained in:
parent
3e5a95548b
commit
3a647ad9bd
|
@ -7,6 +7,7 @@ entry.py
|
|||
*.swp
|
||||
*.pyc
|
||||
|
||||
.cache/
|
||||
__pycache__/
|
||||
dist/
|
||||
htmlcov/
|
||||
|
|
|
@ -36,6 +36,14 @@ class ImageMetadata:
|
|||
loras: Optional[List[NetworkMetadata]]
|
||||
models: Optional[List[NetworkMetadata]]
|
||||
|
||||
@staticmethod
|
||||
def unknown_image() -> "ImageMetadata":
|
||||
UNKNOWN_STR = "unknown"
|
||||
return ImageMetadata(
|
||||
ImageParams(UNKNOWN_STR, UNKNOWN_STR, UNKNOWN_STR, "", 0, 0, 0),
|
||||
Size(0, 0),
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
params: ImageParams,
|
||||
|
@ -212,22 +220,18 @@ class StageResult:
|
|||
return StageResult(images=[])
|
||||
|
||||
@staticmethod
|
||||
def from_arrays(
|
||||
arrays: List[np.ndarray], metadata: Optional[List[ImageMetadata]] = None
|
||||
):
|
||||
def from_arrays(arrays: List[np.ndarray], metadata: List[ImageMetadata]):
|
||||
return StageResult(arrays=arrays, metadata=metadata)
|
||||
|
||||
@staticmethod
|
||||
def from_images(
|
||||
images: List[Image.Image], metadata: Optional[List[ImageMetadata]] = None
|
||||
):
|
||||
def from_images(images: List[Image.Image], metadata: List[ImageMetadata]):
|
||||
return StageResult(images=images, metadata=metadata)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
arrays: Optional[List[np.ndarray]] = None,
|
||||
images: Optional[List[Image.Image]] = None,
|
||||
metadata: Optional[List[ImageMetadata]] = None,
|
||||
metadata: Optional[List[ImageMetadata]] = None, # TODO: should not be optional
|
||||
source: Optional[Any] = None,
|
||||
) -> None:
|
||||
data_provided = sum(
|
||||
|
|
|
@ -9,7 +9,7 @@ from ..params import ImageParams, StageParams
|
|||
from ..server import ServerContext
|
||||
from ..worker import WorkerContext
|
||||
from .base import BaseStage
|
||||
from .result import StageResult
|
||||
from .result import ImageMetadata, StageResult
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -50,7 +50,8 @@ class SourceS3Stage(BaseStage):
|
|||
logger.exception("error loading image from S3")
|
||||
|
||||
# TODO: attempt to load metadata from s3 or load it from the image itself (exif data)
|
||||
return StageResult(outputs)
|
||||
metadata = [ImageMetadata.unknown_image()] * len(outputs)
|
||||
return StageResult(outputs, metadata=metadata)
|
||||
|
||||
def outputs(
|
||||
self,
|
||||
|
|
|
@ -9,7 +9,7 @@ from ..params import ImageParams, StageParams
|
|||
from ..server import ServerContext
|
||||
from ..worker import WorkerContext
|
||||
from .base import BaseStage
|
||||
from .result import StageResult
|
||||
from .result import ImageMetadata, StageResult
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -42,7 +42,8 @@ class SourceURLStage(BaseStage):
|
|||
logger.info("final output image size: %sx%s", output.width, output.height)
|
||||
outputs.append(output)
|
||||
|
||||
return StageResult(images=outputs)
|
||||
metadata = [ImageMetadata.unknown_image()] * len(outputs)
|
||||
return StageResult(images=outputs, metadata=metadata)
|
||||
|
||||
def outputs(
|
||||
self,
|
||||
|
|
|
@ -13,7 +13,7 @@ from ..chain import (
|
|||
UpscaleOutpaintStage,
|
||||
)
|
||||
from ..chain.highres import stage_highres
|
||||
from ..chain.result import StageResult
|
||||
from ..chain.result import ImageMetadata, StageResult
|
||||
from ..chain.upscale import split_upscale, stage_upscale_correction
|
||||
from ..image import expand_image
|
||||
from ..output import save_image, save_result
|
||||
|
@ -212,7 +212,11 @@ def run_img2img_pipeline(
|
|||
# run and append the filtered source
|
||||
progress = worker.get_progress_callback(reset=True)
|
||||
images = chain(
|
||||
worker, server, params, StageResult(images=[source]), callback=progress
|
||||
worker,
|
||||
server,
|
||||
params,
|
||||
StageResult(images=[source], metadata=[ImageMetadata.unknown_image()]),
|
||||
callback=progress,
|
||||
)
|
||||
|
||||
if source_filter is not None and source_filter != "none":
|
||||
|
@ -385,7 +389,9 @@ def run_inpaint_pipeline(
|
|||
worker,
|
||||
server,
|
||||
params,
|
||||
StageResult(images=[source]), # TODO: load metadata from source image
|
||||
StageResult(
|
||||
images=[source], metadata=[ImageMetadata.unknown_image()]
|
||||
), # TODO: load metadata from source image
|
||||
callback=progress,
|
||||
latents=latents,
|
||||
)
|
||||
|
@ -459,7 +465,11 @@ def run_upscale_pipeline(
|
|||
# run and save
|
||||
progress = worker.get_progress_callback(reset=True)
|
||||
images = chain(
|
||||
worker, server, params, StageResult(images=[source]), callback=progress
|
||||
worker,
|
||||
server,
|
||||
params,
|
||||
StageResult(images=[source], metadata=[ImageMetadata.unknown_image()]),
|
||||
callback=progress,
|
||||
)
|
||||
|
||||
save_result(server, images, worker.job)
|
||||
|
@ -508,7 +518,13 @@ def run_blend_pipeline(
|
|||
# run and save
|
||||
progress = worker.get_progress_callback(reset=True)
|
||||
images = chain(
|
||||
worker, server, params, StageResult(images=sources), callback=progress
|
||||
worker,
|
||||
server,
|
||||
params,
|
||||
StageResult(
|
||||
images=sources, metadata=[ImageMetadata.unknown_image()] * len(sources)
|
||||
),
|
||||
callback=progress,
|
||||
)
|
||||
|
||||
save_result(server, images, worker.job)
|
||||
|
|
|
@ -64,6 +64,8 @@ def save_result(
|
|||
) -> List[str]:
|
||||
images = result.as_images()
|
||||
outputs = make_output_names(server, base_name, len(images))
|
||||
logger.debug("saving %s images: %s", len(images), outputs)
|
||||
|
||||
results = []
|
||||
for image, metadata, filename in zip(images, result.metadata, outputs):
|
||||
results.append(
|
||||
|
|
|
@ -485,10 +485,10 @@ def check_ready(host: str, key: str) -> bool:
|
|||
raise TestError("error getting image status")
|
||||
|
||||
def check_outputs(host: str, key: str) -> List[str]:
|
||||
resp = requests.get(f"{host}/api/ready?output={key}")
|
||||
resp = requests.get(f"{host}/api/job/status?jobs={key}")
|
||||
if resp.status_code == 200:
|
||||
json = resp.json()
|
||||
outputs = json.get("outputs", [])
|
||||
outputs = json[0].get("outputs", [])
|
||||
return outputs
|
||||
|
||||
logger.warning("getting outputs failed: %s: %s", resp.status_code, resp.text)
|
||||
|
@ -499,12 +499,13 @@ def download_images(host: str, key: str) -> List[Image.Image]:
|
|||
|
||||
images = []
|
||||
for key in outputs:
|
||||
resp = requests.get(f"{host}/output/{key}")
|
||||
url = f"{host}/output/{key}"
|
||||
resp = requests.get(url)
|
||||
if resp.status_code == 200:
|
||||
logger.debug("downloading image: %s", key)
|
||||
images.append(Image.open(BytesIO(resp.content)))
|
||||
else:
|
||||
logger.warning("download request failed: %s", resp.status_code)
|
||||
logger.warning("download request failed: %s: %s", url, resp.status_code)
|
||||
raise TestError("error downloading image")
|
||||
|
||||
return images
|
||||
|
|
|
@ -2,7 +2,7 @@ import unittest
|
|||
|
||||
from PIL import Image
|
||||
|
||||
from onnx_web.chain.result import StageResult
|
||||
from onnx_web.chain.result import ImageMetadata, StageResult
|
||||
from onnx_web.chain.tile import (
|
||||
complete_tile,
|
||||
generate_tile_grid,
|
||||
|
@ -126,7 +126,11 @@ class TestProcessTileStack(unittest.TestCase):
|
|||
def test_grid_full(self):
|
||||
source = Image.new("RGB", (64, 64))
|
||||
blend = process_tile_stack(
|
||||
StageResult(images=[source]), 32, 1, [], generate_tile_grid
|
||||
StageResult(images=[source], metadata=[ImageMetadata.unknown_image()]),
|
||||
32,
|
||||
1,
|
||||
[],
|
||||
generate_tile_grid,
|
||||
)
|
||||
|
||||
self.assertEqual(blend[0].size, (64, 64))
|
||||
|
@ -134,7 +138,11 @@ class TestProcessTileStack(unittest.TestCase):
|
|||
def test_grid_partial(self):
|
||||
source = Image.new("RGB", (72, 72))
|
||||
blend = process_tile_stack(
|
||||
StageResult(images=[source]), 32, 1, [], generate_tile_grid
|
||||
StageResult(images=[source], metadata=[ImageMetadata.unknown_image()]),
|
||||
32,
|
||||
1,
|
||||
[],
|
||||
generate_tile_grid,
|
||||
)
|
||||
|
||||
self.assertEqual(blend[0].size, (72, 72))
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -6,6 +7,7 @@ from onnx import GraphProto, ModelProto, NodeProto
|
|||
from onnx.numpy_helper import from_array
|
||||
|
||||
from onnx_web.convert.diffusion.lora import (
|
||||
blend_loras,
|
||||
blend_node_conv_gemm,
|
||||
blend_node_matmul,
|
||||
blend_weights_loha,
|
||||
|
@ -226,6 +228,30 @@ class BlendLoRATests(unittest.TestCase):
|
|||
def test_node_dtype(self):
|
||||
pass
|
||||
|
||||
@patch("onnx_web.convert.diffusion.lora.load")
|
||||
@patch("onnx_web.convert.diffusion.lora.load_tensor")
|
||||
def test_blend_loras_load_str(self, mock_load_tensor, mock_load):
|
||||
base_name = "model.onnx"
|
||||
loras = [("loras/model1.safetensors", 0.5), ("loras/safetensors.onnx", 0.5)]
|
||||
model_type = "unet"
|
||||
model_index = 2
|
||||
xl = True
|
||||
|
||||
mock_load.return_value = MagicMock()
|
||||
mock_load_tensor.return_value = MagicMock()
|
||||
|
||||
# Call the blend_loras function
|
||||
blended_model = blend_loras(None, base_name, loras, model_type, model_index, xl)
|
||||
|
||||
# Assert that the InferenceSession is called with the correct arguments
|
||||
mock_load.assert_called_once_with(base_name)
|
||||
|
||||
# Assert that the model is loaded successfully
|
||||
self.assertEqual(blended_model, mock_load.return_value)
|
||||
|
||||
# Assert that the blending logic is executed correctly
|
||||
# (assertions specific to the blending logic can be added here)
|
||||
|
||||
|
||||
class BlendWeightsLoHATests(unittest.TestCase):
|
||||
def test_blend_t1_t2(self):
|
||||
|
|
|
@ -1,11 +1,24 @@
|
|||
import unittest
|
||||
from os import path
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from onnx_web.image.source_filter import (
|
||||
filter_model_path,
|
||||
pil_to_cv2,
|
||||
source_filter_canny,
|
||||
source_filter_depth,
|
||||
source_filter_face,
|
||||
source_filter_gaussian,
|
||||
source_filter_hed,
|
||||
source_filter_mlsd,
|
||||
source_filter_noise,
|
||||
source_filter_none,
|
||||
source_filter_normal,
|
||||
source_filter_openpose,
|
||||
source_filter_scribble,
|
||||
source_filter_segment,
|
||||
)
|
||||
from onnx_web.server.context import ServerContext
|
||||
|
||||
|
@ -35,3 +48,119 @@ class SourceFilterNoiseTests(unittest.TestCase):
|
|||
source = Image.new("RGB", dims)
|
||||
result = source_filter_noise(server, source)
|
||||
self.assertEqual(result.size, dims)
|
||||
|
||||
|
||||
class PILToCV2Tests(unittest.TestCase):
|
||||
def test_conversion(self):
|
||||
dims = (64, 64)
|
||||
source = Image.new("RGB", dims)
|
||||
result = pil_to_cv2(source)
|
||||
self.assertIsInstance(result, np.ndarray)
|
||||
self.assertEqual(result.shape, (dims[1], dims[0], 3))
|
||||
self.assertEqual(result.dtype, np.uint8)
|
||||
|
||||
|
||||
class FilterModelPathTests(unittest.TestCase):
|
||||
def test_filter_model_path(self):
|
||||
server = ServerContext()
|
||||
filter_name = "gaussian"
|
||||
expected_path = path.join(server.model_path, "filter", filter_name)
|
||||
result = filter_model_path(server, filter_name)
|
||||
self.assertEqual(result, expected_path)
|
||||
|
||||
|
||||
class SourceFilterFaceTests(unittest.TestCase): # Added new test class
|
||||
def test_basic(self):
|
||||
dims = (64, 64)
|
||||
server = ServerContext()
|
||||
source = Image.new("RGB", dims)
|
||||
result = source_filter_face(server, source)
|
||||
self.assertEqual(result.size, dims)
|
||||
|
||||
|
||||
class SourceFilterSegmentTests(
|
||||
unittest.TestCase
|
||||
): # Added SourceFilterSegmentTests class
|
||||
def test_basic(self):
|
||||
dims = (64, 64)
|
||||
server = ServerContext()
|
||||
source = Image.new("RGB", dims)
|
||||
result = source_filter_segment(server, source)
|
||||
self.assertEqual(result.size, dims)
|
||||
|
||||
|
||||
class SourceFilterMLSDTests(unittest.TestCase): # Added SourceFilterMLSDTests class
|
||||
def test_basic(self):
|
||||
dims = (64, 64)
|
||||
server = ServerContext()
|
||||
source = Image.new("RGB", dims)
|
||||
result = source_filter_mlsd(server, source)
|
||||
self.assertEqual(result.size, (512, 512))
|
||||
|
||||
|
||||
class SourceFilterNormalTests(unittest.TestCase): # Added SourceFilterNormalTests class
|
||||
def test_basic(self):
|
||||
dims = (64, 64)
|
||||
server = ServerContext()
|
||||
source = Image.new("RGB", dims)
|
||||
result = source_filter_normal(server, source)
|
||||
|
||||
# normal will resize inputs to 384x384
|
||||
self.assertEqual(result.size, (384, 384))
|
||||
|
||||
|
||||
class SourceFilterHEDTests(unittest.TestCase):
|
||||
def test_basic(self):
|
||||
dims = (64, 64)
|
||||
server = ServerContext()
|
||||
source = Image.new("RGB", dims)
|
||||
result = source_filter_hed(server, source)
|
||||
self.assertEqual(result.size, (512, 512))
|
||||
|
||||
|
||||
class SourceFilterScribbleTests(
|
||||
unittest.TestCase
|
||||
): # Added SourceFilterScribbleTests class
|
||||
def test_basic(self):
|
||||
dims = (64, 64)
|
||||
server = ServerContext()
|
||||
source = Image.new("RGB", dims)
|
||||
result = source_filter_scribble(server, source)
|
||||
|
||||
# scribble will resize inputs to 512x512
|
||||
self.assertEqual(result.size, (512, 512))
|
||||
|
||||
|
||||
class SourceFilterDepthTests(
|
||||
unittest.TestCase
|
||||
): # Added SourceFilterScribbleTests class
|
||||
def test_basic(self):
|
||||
dims = (64, 64)
|
||||
server = ServerContext()
|
||||
source = Image.new("RGB", dims)
|
||||
result = source_filter_depth(server, source)
|
||||
self.assertEqual(result.size, dims)
|
||||
|
||||
|
||||
class SourceFilterCannyTests(
|
||||
unittest.TestCase
|
||||
): # Added SourceFilterScribbleTests class
|
||||
def test_basic(self):
|
||||
dims = (64, 64)
|
||||
server = ServerContext()
|
||||
source = Image.new("RGB", dims)
|
||||
result = source_filter_canny(server, source)
|
||||
self.assertEqual(result.size, dims)
|
||||
|
||||
|
||||
class SourceFilterOpenPoseTests(
|
||||
unittest.TestCase
|
||||
): # Added SourceFilterScribbleTests class
|
||||
def test_basic(self):
|
||||
dims = (64, 64)
|
||||
server = ServerContext()
|
||||
source = Image.new("RGB", dims)
|
||||
result = source_filter_openpose(server, source)
|
||||
|
||||
# openpose will resize inputs to 512x512
|
||||
self.assertEqual(result.size, (512, 512))
|
||||
|
|
Loading…
Reference in New Issue