1
0
Fork 0

start inheriting metadata

This commit is contained in:
Sean Sube 2024-01-06 16:59:02 -06:00
parent 9e201fc94a
commit ec48b27f93
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 50 additions and 13 deletions

View File

@ -7,7 +7,7 @@ from ..params import ImageParams, Size, SizeChart, StageParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage from .base import BaseStage
from .result import StageResult from .result import ImageMetadata, StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
@ -20,7 +20,7 @@ class BlendGridStage(BaseStage):
_worker: WorkerContext, _worker: WorkerContext,
_server: ServerContext, _server: ServerContext,
_stage: StageParams, _stage: StageParams,
_params: ImageParams, params: ImageParams,
sources: StageResult, sources: StageResult,
*, *,
height: int, height: int,
@ -53,7 +53,10 @@ class BlendGridStage(BaseStage):
output.paste(images[n], (x * size.width, y * size.height)) output.paste(images[n], (x * size.width, y * size.height))
result = StageResult(source=sources) result = StageResult(source=sources)
result.push_image(output, sources.metadata[0]) result.push_image(
output,
ImageMetadata(params, Size(width, height), ancestors=[sources.metadata]),
)
return result return result
def outputs( def outputs(

View File

@ -102,7 +102,8 @@ class BlendImg2ImgStage(BaseStage):
outputs.extend(result.images) outputs.extend(result.images)
return StageResult(images=outputs, metadata=sources.metadata) metadata = [metadata.child(params) for metadata in sources.metadata]
return StageResult(images=outputs, metadata=metadata)
def steps( def steps(
self, self,

View File

@ -27,14 +27,19 @@ class NetworkMetadata:
class ImageMetadata: class ImageMetadata:
border: Border ancestors: List["ImageMetadata"]
highres: HighresParams
params: ImageParams params: ImageParams
size: Size size: Size
upscale: UpscaleParams
inversions: Optional[List[NetworkMetadata]] # models
loras: Optional[List[NetworkMetadata]] inversions: List[NetworkMetadata]
models: Optional[List[NetworkMetadata]] loras: List[NetworkMetadata]
models: List[NetworkMetadata]
# optional params
border: Optional[Border]
highres: Optional[HighresParams]
upscale: Optional[UpscaleParams]
@staticmethod @staticmethod
def unknown_image() -> "ImageMetadata": def unknown_image() -> "ImageMetadata":
@ -54,15 +59,40 @@ class ImageMetadata:
inversions: Optional[List[NetworkMetadata]] = None, inversions: Optional[List[NetworkMetadata]] = None,
loras: Optional[List[NetworkMetadata]] = None, loras: Optional[List[NetworkMetadata]] = None,
models: Optional[List[NetworkMetadata]] = None, models: Optional[List[NetworkMetadata]] = None,
ancestors: Optional[List["ImageMetadata"]] = None,
) -> None: ) -> None:
self.params = params self.params = params
self.size = size self.size = size
self.upscale = upscale self.upscale = upscale
self.border = border self.border = border
self.highres = highres self.highres = highres
self.inversions = inversions self.inversions = inversions or []
self.loras = loras self.loras = loras or []
self.models = models self.models = models or []
self.ancestors = ancestors or []
def child(
self,
params: ImageParams,
size: Size,
upscale: Optional[UpscaleParams] = None,
border: Optional[Border] = None,
highres: Optional[HighresParams] = None,
inversions: Optional[List[NetworkMetadata]] = None,
loras: Optional[List[NetworkMetadata]] = None,
models: Optional[List[NetworkMetadata]] = None,
) -> "ImageMetadata":
return ImageMetadata(
params,
size,
upscale,
border,
highres,
inversions,
loras,
models,
[self],
)
def get_model_hash( def get_model_hash(
self, server: ServerContext, model: Optional[str] = None self, server: ServerContext, model: Optional[str] = None

View File

@ -117,4 +117,7 @@ class UpscaleRealESRGANStage(BaseStage):
logger.info("final output image size: %s", output.shape) logger.info("final output image size: %s", output.shape)
outputs.append(output) outputs.append(output)
for metadata in sources.metadata:
metadata.upscale = upscale
return StageResult(arrays=outputs, metadata=sources.metadata) return StageResult(arrays=outputs, metadata=sources.metadata)