1
0
Fork 0

fix(api): make with_args methods accept and ignore extra args

This commit is contained in:
Sean Sube 2024-01-28 15:15:24 -06:00
parent f8bdd76bea
commit 7c6ae6a094
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 42 additions and 16 deletions

View File

@ -187,14 +187,15 @@ class ChainPipeline:
) -> List[Image.Image]: ) -> List[Image.Image]:
for _i in range(worker.retries): for _i in range(worker.retries):
try: try:
stage_input = StageResult(
images=source_tile, metadata=stage_sources.metadata
)
tile_result = stage_pipe.run( tile_result = stage_pipe.run(
worker, worker,
server, server,
stage_params, stage_params,
per_stage_params, per_stage_params,
StageResult( stage_input,
images=source_tile, metadata=stage_sources.metadata
),
tile_mask=tile_mask, tile_mask=tile_mask,
callback=callback, callback=callback,
dims=dims, dims=dims,

View File

@ -263,7 +263,9 @@ class ImageMetadata:
loras: Optional[List[NetworkMetadata]] = None, loras: Optional[List[NetworkMetadata]] = None,
models: Optional[List[NetworkMetadata]] = None, models: Optional[List[NetworkMetadata]] = None,
ancestors: Optional[List["ImageMetadata"]] = None, ancestors: Optional[List["ImageMetadata"]] = None,
**kwargs,
) -> "ImageMetadata": ) -> "ImageMetadata":
logger.info("ignoring extra kwargs for metadata: %s", kwargs)
return ImageMetadata( return ImageMetadata(
params or self.params, params or self.params,
size or self.size, size or self.size,

View File

@ -7,7 +7,7 @@ from PIL import Image
from ..params import ImageParams, StageParams from ..params import ImageParams, StageParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import WorkerContext, ProgressCallback from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage from .base import BaseStage
from .result import ImageMetadata, StageResult from .result import ImageMetadata, StageResult

View File

@ -57,12 +57,20 @@ class Border:
"bottom": self.bottom, "bottom": self.bottom,
} }
def with_args(self, **kwargs): def with_args(
self,
left: Optional[int] = None,
right: Optional[int] = None,
top: Optional[int] = None,
bottom: Optional[int] = None,
**kwargs,
):
logger.debug("ignoring extra kwargs for border: %s", kwargs)
return Border( return Border(
kwargs.get("left", self.left), left or self.left,
kwargs.get("right", self.right), right or self.right,
kwargs.get("top", self.top), top or self.top,
kwargs.get("bottom", self.bottom), bottom or self.bottom,
) )
@classmethod @classmethod
@ -111,10 +119,16 @@ class Size:
"height": self.height, "height": self.height,
} }
def with_args(self, **kwargs): def with_args(
self,
height: Optional[int] = None,
width: Optional[int] = None,
**kwargs,
):
logger.debug("ignoring extra kwargs for size: %s", kwargs)
return Size( return Size(
kwargs.get("width", self.width), width or self.width,
kwargs.get("height", self.height), height or self.height,
) )
@ -395,13 +409,18 @@ class StageParams:
def with_args( def with_args(
self, self,
name: Optional[str] = None,
outscale: Optional[int] = None,
tile_order: Optional[str] = None,
tile_size: Optional[int] = None,
**kwargs, **kwargs,
): ):
logger.debug("ignoring extra kwargs for stage: %s", kwargs)
return StageParams( return StageParams(
name=kwargs.get("name", self.name), name=name or self.name,
outscale=kwargs.get("outscale", self.outscale), outscale=outscale or self.outscale,
tile_order=kwargs.get("tile_order", self.tile_order), tile_order=tile_order or self.tile_order,
tile_size=kwargs.get("tile_size", self.tile_size), tile_size=tile_size or self.tile_size,
) )
@ -493,7 +512,9 @@ class UpscaleParams:
pre_pad: Optional[int] = None, pre_pad: Optional[int] = None,
tile_pad: Optional[int] = None, tile_pad: Optional[int] = None,
upscale_order: Optional[UpscaleOrder] = None, upscale_order: Optional[UpscaleOrder] = None,
**kwargs,
): ):
logger.debug("ignoring extra kwargs for upscale: %s", kwargs)
return UpscaleParams( return UpscaleParams(
upscale_model=upscale_model or self.upscale_model, upscale_model=upscale_model or self.upscale_model,
correction_model=correction_model or self.correction_model, correction_model=correction_model or self.correction_model,
@ -558,7 +579,9 @@ class HighresParams:
strength: Optional[float] = None, strength: Optional[float] = None,
method: Optional[UpscaleMethod] = None, method: Optional[UpscaleMethod] = None,
iterations: Optional[int] = None, iterations: Optional[int] = None,
**kwargs,
): ):
logger.debug("ignoring extra kwargs for highres: %s", kwargs)
return HighresParams( return HighresParams(
enabled=enabled or self.enabled, enabled=enabled or self.enabled,
scale=scale or self.scale, scale=scale or self.scale,