From 21666abf03f0c4ac3f907247e4b7769152c6c7fe Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 6 Jan 2024 17:20:37 -0600 Subject: [PATCH] fix up highres metadata --- api/onnx_web/chain/blend_img2img.py | 4 ++- api/onnx_web/chain/edit_metadata.py | 39 +++++++++++++++++++++++++++++ api/onnx_web/chain/edit_text.py | 0 api/onnx_web/chain/highres.py | 8 ++++++ api/onnx_web/chain/result.py | 2 ++ api/onnx_web/chain/stages.py | 2 ++ 6 files changed, 54 insertions(+), 1 deletion(-) create mode 100644 api/onnx_web/chain/edit_metadata.py create mode 100644 api/onnx_web/chain/edit_text.py diff --git a/api/onnx_web/chain/blend_img2img.py b/api/onnx_web/chain/blend_img2img.py index 0d4a9a83..e6ee29a4 100644 --- a/api/onnx_web/chain/blend_img2img.py +++ b/api/onnx_web/chain/blend_img2img.py @@ -102,7 +102,9 @@ class BlendImg2ImgStage(BaseStage): outputs.extend(result.images) - metadata = [metadata.child(params) for metadata in sources.metadata] + metadata = [ + metadata.child(params, metadata.size) for metadata in sources.metadata + ] return StageResult(images=outputs, metadata=metadata) def steps( diff --git a/api/onnx_web/chain/edit_metadata.py b/api/onnx_web/chain/edit_metadata.py new file mode 100644 index 00000000..522ddb81 --- /dev/null +++ b/api/onnx_web/chain/edit_metadata.py @@ -0,0 +1,39 @@ +from typing import Any + +from ..params import HighresParams, ImageParams, Size, StageParams, UpscaleParams +from ..server import ServerContext +from ..worker import WorkerContext +from .result import StageResult + + +class EditMetadataStage: + def run( + self, + _worker: WorkerContext, + _server: ServerContext, + _stage: StageParams, + _params: ImageParams, + source: StageResult, + *, + size: Size = None, + upscale: UpscaleParams = None, + highres: HighresParams = None, + note: str = None, + **kwargs, + ) -> Any: + # Modify the source image's metadata using the provided parameters + for metadata in source.metadata: + if note is not None: + metadata.note = note + + if size is not None: + metadata.size = size + + if upscale is not None: + metadata.upscale = upscale + + if highres is not None: + metadata.highres = highres + + # Return the modified source image + return source diff --git a/api/onnx_web/chain/edit_text.py b/api/onnx_web/chain/edit_text.py new file mode 100644 index 00000000..e69de29b diff --git a/api/onnx_web/chain/highres.py b/api/onnx_web/chain/highres.py index 8c5f0fb1..5e6fd22e 100644 --- a/api/onnx_web/chain/highres.py +++ b/api/onnx_web/chain/highres.py @@ -2,6 +2,7 @@ from logging import getLogger from typing import Optional from ..chain.blend_img2img import BlendImg2ImgStage +from ..chain.edit_metadata import EditMetadataStage from ..chain.upscale import stage_upscale_correction from ..chain.upscale_simple import UpscaleSimpleStage from ..params import HighresParams, ImageParams, StageParams, UpscaleParams @@ -66,4 +67,11 @@ def stage_highres( strength=highres.strength, ) + # add highres parameters to the image metadata + chain.stage( + EditMetadataStage(), + stage.with_args(outscale=1), + highres=highres, + ) + return chain diff --git a/api/onnx_web/chain/result.py b/api/onnx_web/chain/result.py index addaddca..e8b40ea3 100644 --- a/api/onnx_web/chain/result.py +++ b/api/onnx_web/chain/result.py @@ -28,6 +28,7 @@ class NetworkMetadata: class ImageMetadata: ancestors: List["ImageMetadata"] + note: str params: ImageParams size: Size @@ -70,6 +71,7 @@ class ImageMetadata: self.loras = loras or [] self.models = models or [] self.ancestors = ancestors or [] + self.note = "" def child( self, diff --git a/api/onnx_web/chain/stages.py b/api/onnx_web/chain/stages.py index 0b3e6359..1e094052 100644 --- a/api/onnx_web/chain/stages.py +++ b/api/onnx_web/chain/stages.py @@ -9,6 +9,7 @@ from .blend_linear import BlendLinearStage from .blend_mask import BlendMaskStage from .correct_codeformer import CorrectCodeformerStage from .correct_gfpgan import CorrectGFPGANStage +from .edit_metadata import EditMetadataStage from .persist_disk import PersistDiskStage from .persist_s3 import PersistS3Stage from .reduce_crop import ReduceCropStage @@ -38,6 +39,7 @@ CHAIN_STAGES = { "blend-mask": BlendMaskStage, "correct-codeformer": CorrectCodeformerStage, "correct-gfpgan": CorrectGFPGANStage, + "edit-metadata": EditMetadataStage, "persist-disk": PersistDiskStage, "persist-s3": PersistS3Stage, "reduce-crop": ReduceCropStage,