diff --git a/api/onnx_web/chain/blend_grid.py b/api/onnx_web/chain/blend_grid.py index 376c2af9..84d11393 100644 --- a/api/onnx_web/chain/blend_grid.py +++ b/api/onnx_web/chain/blend_grid.py @@ -7,7 +7,7 @@ from ..params import ImageParams, Size, SizeChart, StageParams from ..server import ServerContext from ..worker import ProgressCallback, WorkerContext from .base import BaseStage -from .result import StageResult +from .result import ImageMetadata, StageResult logger = getLogger(__name__) @@ -20,7 +20,7 @@ class BlendGridStage(BaseStage): _worker: WorkerContext, _server: ServerContext, _stage: StageParams, - _params: ImageParams, + params: ImageParams, sources: StageResult, *, height: int, @@ -53,7 +53,10 @@ class BlendGridStage(BaseStage): output.paste(images[n], (x * size.width, y * size.height)) result = StageResult(source=sources) - result.push_image(output, sources.metadata[0]) + result.push_image( + output, + ImageMetadata(params, Size(width, height), ancestors=[sources.metadata]), + ) return result def outputs( diff --git a/api/onnx_web/chain/blend_img2img.py b/api/onnx_web/chain/blend_img2img.py index b2350463..0d4a9a83 100644 --- a/api/onnx_web/chain/blend_img2img.py +++ b/api/onnx_web/chain/blend_img2img.py @@ -102,7 +102,8 @@ class BlendImg2ImgStage(BaseStage): outputs.extend(result.images) - return StageResult(images=outputs, metadata=sources.metadata) + metadata = [metadata.child(params) for metadata in sources.metadata] + return StageResult(images=outputs, metadata=metadata) def steps( self, diff --git a/api/onnx_web/chain/result.py b/api/onnx_web/chain/result.py index 807c42f2..addaddca 100644 --- a/api/onnx_web/chain/result.py +++ b/api/onnx_web/chain/result.py @@ -27,14 +27,19 @@ class NetworkMetadata: class ImageMetadata: - border: Border - highres: HighresParams + ancestors: List["ImageMetadata"] params: ImageParams size: Size - upscale: UpscaleParams - inversions: Optional[List[NetworkMetadata]] - loras: Optional[List[NetworkMetadata]] - models: Optional[List[NetworkMetadata]] + + # models + inversions: List[NetworkMetadata] + loras: List[NetworkMetadata] + models: List[NetworkMetadata] + + # optional params + border: Optional[Border] + highres: Optional[HighresParams] + upscale: Optional[UpscaleParams] @staticmethod def unknown_image() -> "ImageMetadata": @@ -54,15 +59,40 @@ class ImageMetadata: inversions: Optional[List[NetworkMetadata]] = None, loras: Optional[List[NetworkMetadata]] = None, models: Optional[List[NetworkMetadata]] = None, + ancestors: Optional[List["ImageMetadata"]] = None, ) -> None: self.params = params self.size = size self.upscale = upscale self.border = border self.highres = highres - self.inversions = inversions - self.loras = loras - self.models = models + self.inversions = inversions or [] + self.loras = loras or [] + self.models = models or [] + self.ancestors = ancestors or [] + + def child( + self, + params: ImageParams, + size: Size, + upscale: Optional[UpscaleParams] = None, + border: Optional[Border] = None, + highres: Optional[HighresParams] = None, + inversions: Optional[List[NetworkMetadata]] = None, + loras: Optional[List[NetworkMetadata]] = None, + models: Optional[List[NetworkMetadata]] = None, + ) -> "ImageMetadata": + return ImageMetadata( + params, + size, + upscale, + border, + highres, + inversions, + loras, + models, + [self], + ) def get_model_hash( self, server: ServerContext, model: Optional[str] = None diff --git a/api/onnx_web/chain/upscale_resrgan.py b/api/onnx_web/chain/upscale_resrgan.py index 5e0799c1..4d165ab1 100644 --- a/api/onnx_web/chain/upscale_resrgan.py +++ b/api/onnx_web/chain/upscale_resrgan.py @@ -117,4 +117,7 @@ class UpscaleRealESRGANStage(BaseStage): logger.info("final output image size: %s", output.shape) outputs.append(output) + for metadata in sources.metadata: + metadata.upscale = upscale + return StageResult(arrays=outputs, metadata=sources.metadata)