1
0
Fork 0

start inheriting metadata

This commit is contained in:
Sean Sube 2024-01-06 16:59:02 -06:00
parent 9e201fc94a
commit ec48b27f93
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 50 additions and 13 deletions

View File

@ -7,7 +7,7 @@ from ..params import ImageParams, Size, SizeChart, StageParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage
from .result import StageResult
from .result import ImageMetadata, StageResult
logger = getLogger(__name__)
@ -20,7 +20,7 @@ class BlendGridStage(BaseStage):
_worker: WorkerContext,
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
params: ImageParams,
sources: StageResult,
*,
height: int,
@ -53,7 +53,10 @@ class BlendGridStage(BaseStage):
output.paste(images[n], (x * size.width, y * size.height))
result = StageResult(source=sources)
result.push_image(output, sources.metadata[0])
result.push_image(
output,
ImageMetadata(params, Size(width, height), ancestors=[sources.metadata]),
)
return result
def outputs(

View File

@ -102,7 +102,8 @@ class BlendImg2ImgStage(BaseStage):
outputs.extend(result.images)
return StageResult(images=outputs, metadata=sources.metadata)
metadata = [metadata.child(params) for metadata in sources.metadata]
return StageResult(images=outputs, metadata=metadata)
def steps(
self,

View File

@ -27,14 +27,19 @@ class NetworkMetadata:
class ImageMetadata:
border: Border
highres: HighresParams
ancestors: List["ImageMetadata"]
params: ImageParams
size: Size
upscale: UpscaleParams
inversions: Optional[List[NetworkMetadata]]
loras: Optional[List[NetworkMetadata]]
models: Optional[List[NetworkMetadata]]
# models
inversions: List[NetworkMetadata]
loras: List[NetworkMetadata]
models: List[NetworkMetadata]
# optional params
border: Optional[Border]
highres: Optional[HighresParams]
upscale: Optional[UpscaleParams]
@staticmethod
def unknown_image() -> "ImageMetadata":
@ -54,15 +59,40 @@ class ImageMetadata:
inversions: Optional[List[NetworkMetadata]] = None,
loras: Optional[List[NetworkMetadata]] = None,
models: Optional[List[NetworkMetadata]] = None,
ancestors: Optional[List["ImageMetadata"]] = None,
) -> None:
self.params = params
self.size = size
self.upscale = upscale
self.border = border
self.highres = highres
self.inversions = inversions
self.loras = loras
self.models = models
self.inversions = inversions or []
self.loras = loras or []
self.models = models or []
self.ancestors = ancestors or []
def child(
self,
params: ImageParams,
size: Size,
upscale: Optional[UpscaleParams] = None,
border: Optional[Border] = None,
highres: Optional[HighresParams] = None,
inversions: Optional[List[NetworkMetadata]] = None,
loras: Optional[List[NetworkMetadata]] = None,
models: Optional[List[NetworkMetadata]] = None,
) -> "ImageMetadata":
return ImageMetadata(
params,
size,
upscale,
border,
highres,
inversions,
loras,
models,
[self],
)
def get_model_hash(
self, server: ServerContext, model: Optional[str] = None

View File

@ -117,4 +117,7 @@ class UpscaleRealESRGANStage(BaseStage):
logger.info("final output image size: %s", output.shape)
outputs.append(output)
for metadata in sources.metadata:
metadata.upscale = upscale
return StageResult(arrays=outputs, metadata=sources.metadata)