provide empty metadata for pipelines with source images
This commit is contained in:
parent
3e5a95548b
commit
3a647ad9bd
|
@ -7,6 +7,7 @@ entry.py
|
||||||
*.swp
|
*.swp
|
||||||
*.pyc
|
*.pyc
|
||||||
|
|
||||||
|
.cache/
|
||||||
__pycache__/
|
__pycache__/
|
||||||
dist/
|
dist/
|
||||||
htmlcov/
|
htmlcov/
|
||||||
|
|
|
@ -36,6 +36,14 @@ class ImageMetadata:
|
||||||
loras: Optional[List[NetworkMetadata]]
|
loras: Optional[List[NetworkMetadata]]
|
||||||
models: Optional[List[NetworkMetadata]]
|
models: Optional[List[NetworkMetadata]]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def unknown_image() -> "ImageMetadata":
|
||||||
|
UNKNOWN_STR = "unknown"
|
||||||
|
return ImageMetadata(
|
||||||
|
ImageParams(UNKNOWN_STR, UNKNOWN_STR, UNKNOWN_STR, "", 0, 0, 0),
|
||||||
|
Size(0, 0),
|
||||||
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
|
@ -212,22 +220,18 @@ class StageResult:
|
||||||
return StageResult(images=[])
|
return StageResult(images=[])
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_arrays(
|
def from_arrays(arrays: List[np.ndarray], metadata: List[ImageMetadata]):
|
||||||
arrays: List[np.ndarray], metadata: Optional[List[ImageMetadata]] = None
|
|
||||||
):
|
|
||||||
return StageResult(arrays=arrays, metadata=metadata)
|
return StageResult(arrays=arrays, metadata=metadata)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_images(
|
def from_images(images: List[Image.Image], metadata: List[ImageMetadata]):
|
||||||
images: List[Image.Image], metadata: Optional[List[ImageMetadata]] = None
|
|
||||||
):
|
|
||||||
return StageResult(images=images, metadata=metadata)
|
return StageResult(images=images, metadata=metadata)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
arrays: Optional[List[np.ndarray]] = None,
|
arrays: Optional[List[np.ndarray]] = None,
|
||||||
images: Optional[List[Image.Image]] = None,
|
images: Optional[List[Image.Image]] = None,
|
||||||
metadata: Optional[List[ImageMetadata]] = None,
|
metadata: Optional[List[ImageMetadata]] = None, # TODO: should not be optional
|
||||||
source: Optional[Any] = None,
|
source: Optional[Any] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
data_provided = sum(
|
data_provided = sum(
|
||||||
|
|
|
@ -9,7 +9,7 @@ from ..params import ImageParams, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
from .result import StageResult
|
from .result import ImageMetadata, StageResult
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -50,7 +50,8 @@ class SourceS3Stage(BaseStage):
|
||||||
logger.exception("error loading image from S3")
|
logger.exception("error loading image from S3")
|
||||||
|
|
||||||
# TODO: attempt to load metadata from s3 or load it from the image itself (exif data)
|
# TODO: attempt to load metadata from s3 or load it from the image itself (exif data)
|
||||||
return StageResult(outputs)
|
metadata = [ImageMetadata.unknown_image()] * len(outputs)
|
||||||
|
return StageResult(outputs, metadata=metadata)
|
||||||
|
|
||||||
def outputs(
|
def outputs(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -9,7 +9,7 @@ from ..params import ImageParams, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
from .result import StageResult
|
from .result import ImageMetadata, StageResult
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -42,7 +42,8 @@ class SourceURLStage(BaseStage):
|
||||||
logger.info("final output image size: %sx%s", output.width, output.height)
|
logger.info("final output image size: %sx%s", output.width, output.height)
|
||||||
outputs.append(output)
|
outputs.append(output)
|
||||||
|
|
||||||
return StageResult(images=outputs)
|
metadata = [ImageMetadata.unknown_image()] * len(outputs)
|
||||||
|
return StageResult(images=outputs, metadata=metadata)
|
||||||
|
|
||||||
def outputs(
|
def outputs(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -13,7 +13,7 @@ from ..chain import (
|
||||||
UpscaleOutpaintStage,
|
UpscaleOutpaintStage,
|
||||||
)
|
)
|
||||||
from ..chain.highres import stage_highres
|
from ..chain.highres import stage_highres
|
||||||
from ..chain.result import StageResult
|
from ..chain.result import ImageMetadata, StageResult
|
||||||
from ..chain.upscale import split_upscale, stage_upscale_correction
|
from ..chain.upscale import split_upscale, stage_upscale_correction
|
||||||
from ..image import expand_image
|
from ..image import expand_image
|
||||||
from ..output import save_image, save_result
|
from ..output import save_image, save_result
|
||||||
|
@ -212,7 +212,11 @@ def run_img2img_pipeline(
|
||||||
# run and append the filtered source
|
# run and append the filtered source
|
||||||
progress = worker.get_progress_callback(reset=True)
|
progress = worker.get_progress_callback(reset=True)
|
||||||
images = chain(
|
images = chain(
|
||||||
worker, server, params, StageResult(images=[source]), callback=progress
|
worker,
|
||||||
|
server,
|
||||||
|
params,
|
||||||
|
StageResult(images=[source], metadata=[ImageMetadata.unknown_image()]),
|
||||||
|
callback=progress,
|
||||||
)
|
)
|
||||||
|
|
||||||
if source_filter is not None and source_filter != "none":
|
if source_filter is not None and source_filter != "none":
|
||||||
|
@ -385,7 +389,9 @@ def run_inpaint_pipeline(
|
||||||
worker,
|
worker,
|
||||||
server,
|
server,
|
||||||
params,
|
params,
|
||||||
StageResult(images=[source]), # TODO: load metadata from source image
|
StageResult(
|
||||||
|
images=[source], metadata=[ImageMetadata.unknown_image()]
|
||||||
|
), # TODO: load metadata from source image
|
||||||
callback=progress,
|
callback=progress,
|
||||||
latents=latents,
|
latents=latents,
|
||||||
)
|
)
|
||||||
|
@ -459,7 +465,11 @@ def run_upscale_pipeline(
|
||||||
# run and save
|
# run and save
|
||||||
progress = worker.get_progress_callback(reset=True)
|
progress = worker.get_progress_callback(reset=True)
|
||||||
images = chain(
|
images = chain(
|
||||||
worker, server, params, StageResult(images=[source]), callback=progress
|
worker,
|
||||||
|
server,
|
||||||
|
params,
|
||||||
|
StageResult(images=[source], metadata=[ImageMetadata.unknown_image()]),
|
||||||
|
callback=progress,
|
||||||
)
|
)
|
||||||
|
|
||||||
save_result(server, images, worker.job)
|
save_result(server, images, worker.job)
|
||||||
|
@ -508,7 +518,13 @@ def run_blend_pipeline(
|
||||||
# run and save
|
# run and save
|
||||||
progress = worker.get_progress_callback(reset=True)
|
progress = worker.get_progress_callback(reset=True)
|
||||||
images = chain(
|
images = chain(
|
||||||
worker, server, params, StageResult(images=sources), callback=progress
|
worker,
|
||||||
|
server,
|
||||||
|
params,
|
||||||
|
StageResult(
|
||||||
|
images=sources, metadata=[ImageMetadata.unknown_image()] * len(sources)
|
||||||
|
),
|
||||||
|
callback=progress,
|
||||||
)
|
)
|
||||||
|
|
||||||
save_result(server, images, worker.job)
|
save_result(server, images, worker.job)
|
||||||
|
|
|
@ -64,6 +64,8 @@ def save_result(
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
images = result.as_images()
|
images = result.as_images()
|
||||||
outputs = make_output_names(server, base_name, len(images))
|
outputs = make_output_names(server, base_name, len(images))
|
||||||
|
logger.debug("saving %s images: %s", len(images), outputs)
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
for image, metadata, filename in zip(images, result.metadata, outputs):
|
for image, metadata, filename in zip(images, result.metadata, outputs):
|
||||||
results.append(
|
results.append(
|
||||||
|
|
|
@ -485,10 +485,10 @@ def check_ready(host: str, key: str) -> bool:
|
||||||
raise TestError("error getting image status")
|
raise TestError("error getting image status")
|
||||||
|
|
||||||
def check_outputs(host: str, key: str) -> List[str]:
|
def check_outputs(host: str, key: str) -> List[str]:
|
||||||
resp = requests.get(f"{host}/api/ready?output={key}")
|
resp = requests.get(f"{host}/api/job/status?jobs={key}")
|
||||||
if resp.status_code == 200:
|
if resp.status_code == 200:
|
||||||
json = resp.json()
|
json = resp.json()
|
||||||
outputs = json.get("outputs", [])
|
outputs = json[0].get("outputs", [])
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
logger.warning("getting outputs failed: %s: %s", resp.status_code, resp.text)
|
logger.warning("getting outputs failed: %s: %s", resp.status_code, resp.text)
|
||||||
|
@ -499,12 +499,13 @@ def download_images(host: str, key: str) -> List[Image.Image]:
|
||||||
|
|
||||||
images = []
|
images = []
|
||||||
for key in outputs:
|
for key in outputs:
|
||||||
resp = requests.get(f"{host}/output/{key}")
|
url = f"{host}/output/{key}"
|
||||||
|
resp = requests.get(url)
|
||||||
if resp.status_code == 200:
|
if resp.status_code == 200:
|
||||||
logger.debug("downloading image: %s", key)
|
logger.debug("downloading image: %s", key)
|
||||||
images.append(Image.open(BytesIO(resp.content)))
|
images.append(Image.open(BytesIO(resp.content)))
|
||||||
else:
|
else:
|
||||||
logger.warning("download request failed: %s", resp.status_code)
|
logger.warning("download request failed: %s: %s", url, resp.status_code)
|
||||||
raise TestError("error downloading image")
|
raise TestError("error downloading image")
|
||||||
|
|
||||||
return images
|
return images
|
||||||
|
|
|
@ -2,7 +2,7 @@ import unittest
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from onnx_web.chain.result import StageResult
|
from onnx_web.chain.result import ImageMetadata, StageResult
|
||||||
from onnx_web.chain.tile import (
|
from onnx_web.chain.tile import (
|
||||||
complete_tile,
|
complete_tile,
|
||||||
generate_tile_grid,
|
generate_tile_grid,
|
||||||
|
@ -126,7 +126,11 @@ class TestProcessTileStack(unittest.TestCase):
|
||||||
def test_grid_full(self):
|
def test_grid_full(self):
|
||||||
source = Image.new("RGB", (64, 64))
|
source = Image.new("RGB", (64, 64))
|
||||||
blend = process_tile_stack(
|
blend = process_tile_stack(
|
||||||
StageResult(images=[source]), 32, 1, [], generate_tile_grid
|
StageResult(images=[source], metadata=[ImageMetadata.unknown_image()]),
|
||||||
|
32,
|
||||||
|
1,
|
||||||
|
[],
|
||||||
|
generate_tile_grid,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(blend[0].size, (64, 64))
|
self.assertEqual(blend[0].size, (64, 64))
|
||||||
|
@ -134,7 +138,11 @@ class TestProcessTileStack(unittest.TestCase):
|
||||||
def test_grid_partial(self):
|
def test_grid_partial(self):
|
||||||
source = Image.new("RGB", (72, 72))
|
source = Image.new("RGB", (72, 72))
|
||||||
blend = process_tile_stack(
|
blend = process_tile_stack(
|
||||||
StageResult(images=[source]), 32, 1, [], generate_tile_grid
|
StageResult(images=[source], metadata=[ImageMetadata.unknown_image()]),
|
||||||
|
32,
|
||||||
|
1,
|
||||||
|
[],
|
||||||
|
generate_tile_grid,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(blend[0].size, (72, 72))
|
self.assertEqual(blend[0].size, (72, 72))
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import unittest
|
import unittest
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
@ -6,6 +7,7 @@ from onnx import GraphProto, ModelProto, NodeProto
|
||||||
from onnx.numpy_helper import from_array
|
from onnx.numpy_helper import from_array
|
||||||
|
|
||||||
from onnx_web.convert.diffusion.lora import (
|
from onnx_web.convert.diffusion.lora import (
|
||||||
|
blend_loras,
|
||||||
blend_node_conv_gemm,
|
blend_node_conv_gemm,
|
||||||
blend_node_matmul,
|
blend_node_matmul,
|
||||||
blend_weights_loha,
|
blend_weights_loha,
|
||||||
|
@ -226,6 +228,30 @@ class BlendLoRATests(unittest.TestCase):
|
||||||
def test_node_dtype(self):
|
def test_node_dtype(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@patch("onnx_web.convert.diffusion.lora.load")
|
||||||
|
@patch("onnx_web.convert.diffusion.lora.load_tensor")
|
||||||
|
def test_blend_loras_load_str(self, mock_load_tensor, mock_load):
|
||||||
|
base_name = "model.onnx"
|
||||||
|
loras = [("loras/model1.safetensors", 0.5), ("loras/safetensors.onnx", 0.5)]
|
||||||
|
model_type = "unet"
|
||||||
|
model_index = 2
|
||||||
|
xl = True
|
||||||
|
|
||||||
|
mock_load.return_value = MagicMock()
|
||||||
|
mock_load_tensor.return_value = MagicMock()
|
||||||
|
|
||||||
|
# Call the blend_loras function
|
||||||
|
blended_model = blend_loras(None, base_name, loras, model_type, model_index, xl)
|
||||||
|
|
||||||
|
# Assert that the InferenceSession is called with the correct arguments
|
||||||
|
mock_load.assert_called_once_with(base_name)
|
||||||
|
|
||||||
|
# Assert that the model is loaded successfully
|
||||||
|
self.assertEqual(blended_model, mock_load.return_value)
|
||||||
|
|
||||||
|
# Assert that the blending logic is executed correctly
|
||||||
|
# (assertions specific to the blending logic can be added here)
|
||||||
|
|
||||||
|
|
||||||
class BlendWeightsLoHATests(unittest.TestCase):
|
class BlendWeightsLoHATests(unittest.TestCase):
|
||||||
def test_blend_t1_t2(self):
|
def test_blend_t1_t2(self):
|
||||||
|
|
|
@ -1,11 +1,24 @@
|
||||||
import unittest
|
import unittest
|
||||||
|
from os import path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from onnx_web.image.source_filter import (
|
from onnx_web.image.source_filter import (
|
||||||
|
filter_model_path,
|
||||||
|
pil_to_cv2,
|
||||||
|
source_filter_canny,
|
||||||
|
source_filter_depth,
|
||||||
|
source_filter_face,
|
||||||
source_filter_gaussian,
|
source_filter_gaussian,
|
||||||
|
source_filter_hed,
|
||||||
|
source_filter_mlsd,
|
||||||
source_filter_noise,
|
source_filter_noise,
|
||||||
source_filter_none,
|
source_filter_none,
|
||||||
|
source_filter_normal,
|
||||||
|
source_filter_openpose,
|
||||||
|
source_filter_scribble,
|
||||||
|
source_filter_segment,
|
||||||
)
|
)
|
||||||
from onnx_web.server.context import ServerContext
|
from onnx_web.server.context import ServerContext
|
||||||
|
|
||||||
|
@ -35,3 +48,119 @@ class SourceFilterNoiseTests(unittest.TestCase):
|
||||||
source = Image.new("RGB", dims)
|
source = Image.new("RGB", dims)
|
||||||
result = source_filter_noise(server, source)
|
result = source_filter_noise(server, source)
|
||||||
self.assertEqual(result.size, dims)
|
self.assertEqual(result.size, dims)
|
||||||
|
|
||||||
|
|
||||||
|
class PILToCV2Tests(unittest.TestCase):
|
||||||
|
def test_conversion(self):
|
||||||
|
dims = (64, 64)
|
||||||
|
source = Image.new("RGB", dims)
|
||||||
|
result = pil_to_cv2(source)
|
||||||
|
self.assertIsInstance(result, np.ndarray)
|
||||||
|
self.assertEqual(result.shape, (dims[1], dims[0], 3))
|
||||||
|
self.assertEqual(result.dtype, np.uint8)
|
||||||
|
|
||||||
|
|
||||||
|
class FilterModelPathTests(unittest.TestCase):
|
||||||
|
def test_filter_model_path(self):
|
||||||
|
server = ServerContext()
|
||||||
|
filter_name = "gaussian"
|
||||||
|
expected_path = path.join(server.model_path, "filter", filter_name)
|
||||||
|
result = filter_model_path(server, filter_name)
|
||||||
|
self.assertEqual(result, expected_path)
|
||||||
|
|
||||||
|
|
||||||
|
class SourceFilterFaceTests(unittest.TestCase): # Added new test class
|
||||||
|
def test_basic(self):
|
||||||
|
dims = (64, 64)
|
||||||
|
server = ServerContext()
|
||||||
|
source = Image.new("RGB", dims)
|
||||||
|
result = source_filter_face(server, source)
|
||||||
|
self.assertEqual(result.size, dims)
|
||||||
|
|
||||||
|
|
||||||
|
class SourceFilterSegmentTests(
|
||||||
|
unittest.TestCase
|
||||||
|
): # Added SourceFilterSegmentTests class
|
||||||
|
def test_basic(self):
|
||||||
|
dims = (64, 64)
|
||||||
|
server = ServerContext()
|
||||||
|
source = Image.new("RGB", dims)
|
||||||
|
result = source_filter_segment(server, source)
|
||||||
|
self.assertEqual(result.size, dims)
|
||||||
|
|
||||||
|
|
||||||
|
class SourceFilterMLSDTests(unittest.TestCase): # Added SourceFilterMLSDTests class
|
||||||
|
def test_basic(self):
|
||||||
|
dims = (64, 64)
|
||||||
|
server = ServerContext()
|
||||||
|
source = Image.new("RGB", dims)
|
||||||
|
result = source_filter_mlsd(server, source)
|
||||||
|
self.assertEqual(result.size, (512, 512))
|
||||||
|
|
||||||
|
|
||||||
|
class SourceFilterNormalTests(unittest.TestCase): # Added SourceFilterNormalTests class
|
||||||
|
def test_basic(self):
|
||||||
|
dims = (64, 64)
|
||||||
|
server = ServerContext()
|
||||||
|
source = Image.new("RGB", dims)
|
||||||
|
result = source_filter_normal(server, source)
|
||||||
|
|
||||||
|
# normal will resize inputs to 384x384
|
||||||
|
self.assertEqual(result.size, (384, 384))
|
||||||
|
|
||||||
|
|
||||||
|
class SourceFilterHEDTests(unittest.TestCase):
|
||||||
|
def test_basic(self):
|
||||||
|
dims = (64, 64)
|
||||||
|
server = ServerContext()
|
||||||
|
source = Image.new("RGB", dims)
|
||||||
|
result = source_filter_hed(server, source)
|
||||||
|
self.assertEqual(result.size, (512, 512))
|
||||||
|
|
||||||
|
|
||||||
|
class SourceFilterScribbleTests(
|
||||||
|
unittest.TestCase
|
||||||
|
): # Added SourceFilterScribbleTests class
|
||||||
|
def test_basic(self):
|
||||||
|
dims = (64, 64)
|
||||||
|
server = ServerContext()
|
||||||
|
source = Image.new("RGB", dims)
|
||||||
|
result = source_filter_scribble(server, source)
|
||||||
|
|
||||||
|
# scribble will resize inputs to 512x512
|
||||||
|
self.assertEqual(result.size, (512, 512))
|
||||||
|
|
||||||
|
|
||||||
|
class SourceFilterDepthTests(
|
||||||
|
unittest.TestCase
|
||||||
|
): # Added SourceFilterScribbleTests class
|
||||||
|
def test_basic(self):
|
||||||
|
dims = (64, 64)
|
||||||
|
server = ServerContext()
|
||||||
|
source = Image.new("RGB", dims)
|
||||||
|
result = source_filter_depth(server, source)
|
||||||
|
self.assertEqual(result.size, dims)
|
||||||
|
|
||||||
|
|
||||||
|
class SourceFilterCannyTests(
|
||||||
|
unittest.TestCase
|
||||||
|
): # Added SourceFilterScribbleTests class
|
||||||
|
def test_basic(self):
|
||||||
|
dims = (64, 64)
|
||||||
|
server = ServerContext()
|
||||||
|
source = Image.new("RGB", dims)
|
||||||
|
result = source_filter_canny(server, source)
|
||||||
|
self.assertEqual(result.size, dims)
|
||||||
|
|
||||||
|
|
||||||
|
class SourceFilterOpenPoseTests(
|
||||||
|
unittest.TestCase
|
||||||
|
): # Added SourceFilterScribbleTests class
|
||||||
|
def test_basic(self):
|
||||||
|
dims = (64, 64)
|
||||||
|
server = ServerContext()
|
||||||
|
source = Image.new("RGB", dims)
|
||||||
|
result = source_filter_openpose(server, source)
|
||||||
|
|
||||||
|
# openpose will resize inputs to 512x512
|
||||||
|
self.assertEqual(result.size, (512, 512))
|
||||||
|
|
Loading…
Reference in New Issue