1
0
Fork 0

provide empty metadata for pipelines with source images

This commit is contained in:
Sean Sube 2024-01-06 14:17:26 -06:00
parent 3e5a95548b
commit 3a647ad9bd
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
10 changed files with 212 additions and 23 deletions

1
api/.gitignore vendored
View File

@ -7,6 +7,7 @@ entry.py
*.swp *.swp
*.pyc *.pyc
.cache/
__pycache__/ __pycache__/
dist/ dist/
htmlcov/ htmlcov/

View File

@ -36,6 +36,14 @@ class ImageMetadata:
loras: Optional[List[NetworkMetadata]] loras: Optional[List[NetworkMetadata]]
models: Optional[List[NetworkMetadata]] models: Optional[List[NetworkMetadata]]
@staticmethod
def unknown_image() -> "ImageMetadata":
UNKNOWN_STR = "unknown"
return ImageMetadata(
ImageParams(UNKNOWN_STR, UNKNOWN_STR, UNKNOWN_STR, "", 0, 0, 0),
Size(0, 0),
)
def __init__( def __init__(
self, self,
params: ImageParams, params: ImageParams,
@ -212,22 +220,18 @@ class StageResult:
return StageResult(images=[]) return StageResult(images=[])
@staticmethod @staticmethod
def from_arrays( def from_arrays(arrays: List[np.ndarray], metadata: List[ImageMetadata]):
arrays: List[np.ndarray], metadata: Optional[List[ImageMetadata]] = None
):
return StageResult(arrays=arrays, metadata=metadata) return StageResult(arrays=arrays, metadata=metadata)
@staticmethod @staticmethod
def from_images( def from_images(images: List[Image.Image], metadata: List[ImageMetadata]):
images: List[Image.Image], metadata: Optional[List[ImageMetadata]] = None
):
return StageResult(images=images, metadata=metadata) return StageResult(images=images, metadata=metadata)
def __init__( def __init__(
self, self,
arrays: Optional[List[np.ndarray]] = None, arrays: Optional[List[np.ndarray]] = None,
images: Optional[List[Image.Image]] = None, images: Optional[List[Image.Image]] = None,
metadata: Optional[List[ImageMetadata]] = None, metadata: Optional[List[ImageMetadata]] = None, # TODO: should not be optional
source: Optional[Any] = None, source: Optional[Any] = None,
) -> None: ) -> None:
data_provided = sum( data_provided = sum(

View File

@ -9,7 +9,7 @@ from ..params import ImageParams, StageParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import WorkerContext from ..worker import WorkerContext
from .base import BaseStage from .base import BaseStage
from .result import StageResult from .result import ImageMetadata, StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
@ -50,7 +50,8 @@ class SourceS3Stage(BaseStage):
logger.exception("error loading image from S3") logger.exception("error loading image from S3")
# TODO: attempt to load metadata from s3 or load it from the image itself (exif data) # TODO: attempt to load metadata from s3 or load it from the image itself (exif data)
return StageResult(outputs) metadata = [ImageMetadata.unknown_image()] * len(outputs)
return StageResult(outputs, metadata=metadata)
def outputs( def outputs(
self, self,

View File

@ -9,7 +9,7 @@ from ..params import ImageParams, StageParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import WorkerContext from ..worker import WorkerContext
from .base import BaseStage from .base import BaseStage
from .result import StageResult from .result import ImageMetadata, StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
@ -42,7 +42,8 @@ class SourceURLStage(BaseStage):
logger.info("final output image size: %sx%s", output.width, output.height) logger.info("final output image size: %sx%s", output.width, output.height)
outputs.append(output) outputs.append(output)
return StageResult(images=outputs) metadata = [ImageMetadata.unknown_image()] * len(outputs)
return StageResult(images=outputs, metadata=metadata)
def outputs( def outputs(
self, self,

View File

@ -13,7 +13,7 @@ from ..chain import (
UpscaleOutpaintStage, UpscaleOutpaintStage,
) )
from ..chain.highres import stage_highres from ..chain.highres import stage_highres
from ..chain.result import StageResult from ..chain.result import ImageMetadata, StageResult
from ..chain.upscale import split_upscale, stage_upscale_correction from ..chain.upscale import split_upscale, stage_upscale_correction
from ..image import expand_image from ..image import expand_image
from ..output import save_image, save_result from ..output import save_image, save_result
@ -212,7 +212,11 @@ def run_img2img_pipeline(
# run and append the filtered source # run and append the filtered source
progress = worker.get_progress_callback(reset=True) progress = worker.get_progress_callback(reset=True)
images = chain( images = chain(
worker, server, params, StageResult(images=[source]), callback=progress worker,
server,
params,
StageResult(images=[source], metadata=[ImageMetadata.unknown_image()]),
callback=progress,
) )
if source_filter is not None and source_filter != "none": if source_filter is not None and source_filter != "none":
@ -385,7 +389,9 @@ def run_inpaint_pipeline(
worker, worker,
server, server,
params, params,
StageResult(images=[source]), # TODO: load metadata from source image StageResult(
images=[source], metadata=[ImageMetadata.unknown_image()]
), # TODO: load metadata from source image
callback=progress, callback=progress,
latents=latents, latents=latents,
) )
@ -459,7 +465,11 @@ def run_upscale_pipeline(
# run and save # run and save
progress = worker.get_progress_callback(reset=True) progress = worker.get_progress_callback(reset=True)
images = chain( images = chain(
worker, server, params, StageResult(images=[source]), callback=progress worker,
server,
params,
StageResult(images=[source], metadata=[ImageMetadata.unknown_image()]),
callback=progress,
) )
save_result(server, images, worker.job) save_result(server, images, worker.job)
@ -508,7 +518,13 @@ def run_blend_pipeline(
# run and save # run and save
progress = worker.get_progress_callback(reset=True) progress = worker.get_progress_callback(reset=True)
images = chain( images = chain(
worker, server, params, StageResult(images=sources), callback=progress worker,
server,
params,
StageResult(
images=sources, metadata=[ImageMetadata.unknown_image()] * len(sources)
),
callback=progress,
) )
save_result(server, images, worker.job) save_result(server, images, worker.job)

View File

@ -64,6 +64,8 @@ def save_result(
) -> List[str]: ) -> List[str]:
images = result.as_images() images = result.as_images()
outputs = make_output_names(server, base_name, len(images)) outputs = make_output_names(server, base_name, len(images))
logger.debug("saving %s images: %s", len(images), outputs)
results = [] results = []
for image, metadata, filename in zip(images, result.metadata, outputs): for image, metadata, filename in zip(images, result.metadata, outputs):
results.append( results.append(

View File

@ -485,10 +485,10 @@ def check_ready(host: str, key: str) -> bool:
raise TestError("error getting image status") raise TestError("error getting image status")
def check_outputs(host: str, key: str) -> List[str]: def check_outputs(host: str, key: str) -> List[str]:
resp = requests.get(f"{host}/api/ready?output={key}") resp = requests.get(f"{host}/api/job/status?jobs={key}")
if resp.status_code == 200: if resp.status_code == 200:
json = resp.json() json = resp.json()
outputs = json.get("outputs", []) outputs = json[0].get("outputs", [])
return outputs return outputs
logger.warning("getting outputs failed: %s: %s", resp.status_code, resp.text) logger.warning("getting outputs failed: %s: %s", resp.status_code, resp.text)
@ -499,12 +499,13 @@ def download_images(host: str, key: str) -> List[Image.Image]:
images = [] images = []
for key in outputs: for key in outputs:
resp = requests.get(f"{host}/output/{key}") url = f"{host}/output/{key}"
resp = requests.get(url)
if resp.status_code == 200: if resp.status_code == 200:
logger.debug("downloading image: %s", key) logger.debug("downloading image: %s", key)
images.append(Image.open(BytesIO(resp.content))) images.append(Image.open(BytesIO(resp.content)))
else: else:
logger.warning("download request failed: %s", resp.status_code) logger.warning("download request failed: %s: %s", url, resp.status_code)
raise TestError("error downloading image") raise TestError("error downloading image")
return images return images

View File

@ -2,7 +2,7 @@ import unittest
from PIL import Image from PIL import Image
from onnx_web.chain.result import StageResult from onnx_web.chain.result import ImageMetadata, StageResult
from onnx_web.chain.tile import ( from onnx_web.chain.tile import (
complete_tile, complete_tile,
generate_tile_grid, generate_tile_grid,
@ -126,7 +126,11 @@ class TestProcessTileStack(unittest.TestCase):
def test_grid_full(self): def test_grid_full(self):
source = Image.new("RGB", (64, 64)) source = Image.new("RGB", (64, 64))
blend = process_tile_stack( blend = process_tile_stack(
StageResult(images=[source]), 32, 1, [], generate_tile_grid StageResult(images=[source], metadata=[ImageMetadata.unknown_image()]),
32,
1,
[],
generate_tile_grid,
) )
self.assertEqual(blend[0].size, (64, 64)) self.assertEqual(blend[0].size, (64, 64))
@ -134,7 +138,11 @@ class TestProcessTileStack(unittest.TestCase):
def test_grid_partial(self): def test_grid_partial(self):
source = Image.new("RGB", (72, 72)) source = Image.new("RGB", (72, 72))
blend = process_tile_stack( blend = process_tile_stack(
StageResult(images=[source]), 32, 1, [], generate_tile_grid StageResult(images=[source], metadata=[ImageMetadata.unknown_image()]),
32,
1,
[],
generate_tile_grid,
) )
self.assertEqual(blend[0].size, (72, 72)) self.assertEqual(blend[0].size, (72, 72))

View File

@ -1,4 +1,5 @@
import unittest import unittest
from unittest.mock import MagicMock, patch
import numpy as np import numpy as np
import torch import torch
@ -6,6 +7,7 @@ from onnx import GraphProto, ModelProto, NodeProto
from onnx.numpy_helper import from_array from onnx.numpy_helper import from_array
from onnx_web.convert.diffusion.lora import ( from onnx_web.convert.diffusion.lora import (
blend_loras,
blend_node_conv_gemm, blend_node_conv_gemm,
blend_node_matmul, blend_node_matmul,
blend_weights_loha, blend_weights_loha,
@ -226,6 +228,30 @@ class BlendLoRATests(unittest.TestCase):
def test_node_dtype(self): def test_node_dtype(self):
pass pass
@patch("onnx_web.convert.diffusion.lora.load")
@patch("onnx_web.convert.diffusion.lora.load_tensor")
def test_blend_loras_load_str(self, mock_load_tensor, mock_load):
base_name = "model.onnx"
loras = [("loras/model1.safetensors", 0.5), ("loras/safetensors.onnx", 0.5)]
model_type = "unet"
model_index = 2
xl = True
mock_load.return_value = MagicMock()
mock_load_tensor.return_value = MagicMock()
# Call the blend_loras function
blended_model = blend_loras(None, base_name, loras, model_type, model_index, xl)
# Assert that the InferenceSession is called with the correct arguments
mock_load.assert_called_once_with(base_name)
# Assert that the model is loaded successfully
self.assertEqual(blended_model, mock_load.return_value)
# Assert that the blending logic is executed correctly
# (assertions specific to the blending logic can be added here)
class BlendWeightsLoHATests(unittest.TestCase): class BlendWeightsLoHATests(unittest.TestCase):
def test_blend_t1_t2(self): def test_blend_t1_t2(self):

View File

@ -1,11 +1,24 @@
import unittest import unittest
from os import path
import numpy as np
from PIL import Image from PIL import Image
from onnx_web.image.source_filter import ( from onnx_web.image.source_filter import (
filter_model_path,
pil_to_cv2,
source_filter_canny,
source_filter_depth,
source_filter_face,
source_filter_gaussian, source_filter_gaussian,
source_filter_hed,
source_filter_mlsd,
source_filter_noise, source_filter_noise,
source_filter_none, source_filter_none,
source_filter_normal,
source_filter_openpose,
source_filter_scribble,
source_filter_segment,
) )
from onnx_web.server.context import ServerContext from onnx_web.server.context import ServerContext
@ -35,3 +48,119 @@ class SourceFilterNoiseTests(unittest.TestCase):
source = Image.new("RGB", dims) source = Image.new("RGB", dims)
result = source_filter_noise(server, source) result = source_filter_noise(server, source)
self.assertEqual(result.size, dims) self.assertEqual(result.size, dims)
class PILToCV2Tests(unittest.TestCase):
def test_conversion(self):
dims = (64, 64)
source = Image.new("RGB", dims)
result = pil_to_cv2(source)
self.assertIsInstance(result, np.ndarray)
self.assertEqual(result.shape, (dims[1], dims[0], 3))
self.assertEqual(result.dtype, np.uint8)
class FilterModelPathTests(unittest.TestCase):
def test_filter_model_path(self):
server = ServerContext()
filter_name = "gaussian"
expected_path = path.join(server.model_path, "filter", filter_name)
result = filter_model_path(server, filter_name)
self.assertEqual(result, expected_path)
class SourceFilterFaceTests(unittest.TestCase): # Added new test class
def test_basic(self):
dims = (64, 64)
server = ServerContext()
source = Image.new("RGB", dims)
result = source_filter_face(server, source)
self.assertEqual(result.size, dims)
class SourceFilterSegmentTests(
unittest.TestCase
): # Added SourceFilterSegmentTests class
def test_basic(self):
dims = (64, 64)
server = ServerContext()
source = Image.new("RGB", dims)
result = source_filter_segment(server, source)
self.assertEqual(result.size, dims)
class SourceFilterMLSDTests(unittest.TestCase): # Added SourceFilterMLSDTests class
def test_basic(self):
dims = (64, 64)
server = ServerContext()
source = Image.new("RGB", dims)
result = source_filter_mlsd(server, source)
self.assertEqual(result.size, (512, 512))
class SourceFilterNormalTests(unittest.TestCase): # Added SourceFilterNormalTests class
def test_basic(self):
dims = (64, 64)
server = ServerContext()
source = Image.new("RGB", dims)
result = source_filter_normal(server, source)
# normal will resize inputs to 384x384
self.assertEqual(result.size, (384, 384))
class SourceFilterHEDTests(unittest.TestCase):
def test_basic(self):
dims = (64, 64)
server = ServerContext()
source = Image.new("RGB", dims)
result = source_filter_hed(server, source)
self.assertEqual(result.size, (512, 512))
class SourceFilterScribbleTests(
unittest.TestCase
): # Added SourceFilterScribbleTests class
def test_basic(self):
dims = (64, 64)
server = ServerContext()
source = Image.new("RGB", dims)
result = source_filter_scribble(server, source)
# scribble will resize inputs to 512x512
self.assertEqual(result.size, (512, 512))
class SourceFilterDepthTests(
unittest.TestCase
): # Added SourceFilterScribbleTests class
def test_basic(self):
dims = (64, 64)
server = ServerContext()
source = Image.new("RGB", dims)
result = source_filter_depth(server, source)
self.assertEqual(result.size, dims)
class SourceFilterCannyTests(
unittest.TestCase
): # Added SourceFilterScribbleTests class
def test_basic(self):
dims = (64, 64)
server = ServerContext()
source = Image.new("RGB", dims)
result = source_filter_canny(server, source)
self.assertEqual(result.size, dims)
class SourceFilterOpenPoseTests(
unittest.TestCase
): # Added SourceFilterScribbleTests class
def test_basic(self):
dims = (64, 64)
server = ServerContext()
source = Image.new("RGB", dims)
result = source_filter_openpose(server, source)
# openpose will resize inputs to 512x512
self.assertEqual(result.size, (512, 512))