1
0
Fork 0

fix up highres metadata

This commit is contained in:
Sean Sube 2024-01-06 17:20:37 -06:00
parent ec48b27f93
commit 21666abf03
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
6 changed files with 54 additions and 1 deletions

View File

@ -102,7 +102,9 @@ class BlendImg2ImgStage(BaseStage):
outputs.extend(result.images)
metadata = [metadata.child(params) for metadata in sources.metadata]
metadata = [
metadata.child(params, metadata.size) for metadata in sources.metadata
]
return StageResult(images=outputs, metadata=metadata)
def steps(

View File

@ -0,0 +1,39 @@
from typing import Any
from ..params import HighresParams, ImageParams, Size, StageParams, UpscaleParams
from ..server import ServerContext
from ..worker import WorkerContext
from .result import StageResult
class EditMetadataStage:
def run(
self,
_worker: WorkerContext,
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
source: StageResult,
*,
size: Size = None,
upscale: UpscaleParams = None,
highres: HighresParams = None,
note: str = None,
**kwargs,
) -> Any:
# Modify the source image's metadata using the provided parameters
for metadata in source.metadata:
if note is not None:
metadata.note = note
if size is not None:
metadata.size = size
if upscale is not None:
metadata.upscale = upscale
if highres is not None:
metadata.highres = highres
# Return the modified source image
return source

View File

View File

@ -2,6 +2,7 @@ from logging import getLogger
from typing import Optional
from ..chain.blend_img2img import BlendImg2ImgStage
from ..chain.edit_metadata import EditMetadataStage
from ..chain.upscale import stage_upscale_correction
from ..chain.upscale_simple import UpscaleSimpleStage
from ..params import HighresParams, ImageParams, StageParams, UpscaleParams
@ -66,4 +67,11 @@ def stage_highres(
strength=highres.strength,
)
# add highres parameters to the image metadata
chain.stage(
EditMetadataStage(),
stage.with_args(outscale=1),
highres=highres,
)
return chain

View File

@ -28,6 +28,7 @@ class NetworkMetadata:
class ImageMetadata:
ancestors: List["ImageMetadata"]
note: str
params: ImageParams
size: Size
@ -70,6 +71,7 @@ class ImageMetadata:
self.loras = loras or []
self.models = models or []
self.ancestors = ancestors or []
self.note = ""
def child(
self,

View File

@ -9,6 +9,7 @@ from .blend_linear import BlendLinearStage
from .blend_mask import BlendMaskStage
from .correct_codeformer import CorrectCodeformerStage
from .correct_gfpgan import CorrectGFPGANStage
from .edit_metadata import EditMetadataStage
from .persist_disk import PersistDiskStage
from .persist_s3 import PersistS3Stage
from .reduce_crop import ReduceCropStage
@ -38,6 +39,7 @@ CHAIN_STAGES = {
"blend-mask": BlendMaskStage,
"correct-codeformer": CorrectCodeformerStage,
"correct-gfpgan": CorrectGFPGANStage,
"edit-metadata": EditMetadataStage,
"persist-disk": PersistDiskStage,
"persist-s3": PersistS3Stage,
"reduce-crop": ReduceCropStage,