fix(api): keep metadata from tiled stages
This commit is contained in:
parent
ce48ce2d10
commit
f785548fb8
|
@ -203,7 +203,7 @@ class ChainPipeline:
|
|||
)
|
||||
|
||||
if is_debug():
|
||||
for j, image in enumerate(tile_result.as_image()):
|
||||
for j, image in enumerate(tile_result.as_images()):
|
||||
save_image(server, f"last-tile-{j}.png", image)
|
||||
|
||||
worker.set_tiles(current=progress[0], total=progress[1])
|
||||
|
@ -224,7 +224,7 @@ class ChainPipeline:
|
|||
|
||||
raise RetryException("exhausted retries on tile")
|
||||
|
||||
stage_results = process_tile_order(
|
||||
stage_result = process_tile_order(
|
||||
stage_params.tile_order,
|
||||
stage_sources,
|
||||
tile,
|
||||
|
@ -233,8 +233,7 @@ class ChainPipeline:
|
|||
**kwargs,
|
||||
)
|
||||
|
||||
metadata = stage_sources.metadata
|
||||
stage_sources = StageResult(images=stage_results, metadata=metadata)
|
||||
stage_sources = stage_result
|
||||
else:
|
||||
logger.debug(
|
||||
"image does not contain sources and is within tile size of %s, running stage",
|
||||
|
|
|
@ -9,7 +9,7 @@ from PIL import Image
|
|||
|
||||
from ..image.noise_source import noise_source_histogram
|
||||
from ..params import Size, TileOrder
|
||||
from .result import StageResult
|
||||
from .result import ImageMetadata, StageResult
|
||||
|
||||
# from skimage.exposure import match_histograms
|
||||
|
||||
|
@ -266,7 +266,7 @@ def process_tile_stack(
|
|||
tile_generator: TileGenerator,
|
||||
overlap: float = 0.5,
|
||||
**kwargs,
|
||||
) -> List[Image.Image]:
|
||||
) -> StageResult:
|
||||
sources = stack.as_images()
|
||||
|
||||
width, height = kwargs.get("size", sources[0].size if len(sources) > 0 else None)
|
||||
|
@ -276,6 +276,7 @@ def process_tile_stack(
|
|||
if not mask:
|
||||
tile_mask = None
|
||||
|
||||
metadata: List[ImageMetadata] = stack.metadata
|
||||
tiles: List[Tuple[int, int, Image.Image]] = []
|
||||
tile_coords = tile_generator(width, height, tile, overlap)
|
||||
|
||||
|
@ -343,7 +344,6 @@ def process_tile_stack(
|
|||
)
|
||||
tile_mask = Image.new("L", (tile, tile), color=0)
|
||||
tile_mask.paste(base_mask, (left_margin, top_margin))
|
||||
|
||||
else:
|
||||
logger.debug("tiling normally")
|
||||
tile_stack = get_result_tile(stack, (left, top), Size(tile, tile))
|
||||
|
@ -359,6 +359,9 @@ def process_tile_stack(
|
|||
if isinstance(tile_stack, list):
|
||||
tile_stack = StageResult.from_images(tile_stack, metadata=stack.metadata)
|
||||
|
||||
# metadata gets replaced rather than combined, since it should be the same for each tile
|
||||
# this will need to change if tiles can have individual metadata
|
||||
metadata = tile_stack.metadata
|
||||
tiles.append((left, top, tile_stack.as_images()))
|
||||
|
||||
lefts, tops, stacks = list(zip(*tiles))
|
||||
|
@ -371,7 +374,7 @@ def process_tile_stack(
|
|||
stack_tiles = [(left, top, tile) for (left, top), tile in stack_tiles]
|
||||
result.append(blend_tiles(stack_tiles, scale, width, height, tile, overlap))
|
||||
|
||||
return result
|
||||
return StageResult(images=result, metadata=metadata)
|
||||
|
||||
|
||||
def process_tile_order(
|
||||
|
@ -381,7 +384,7 @@ def process_tile_order(
|
|||
scale: int,
|
||||
filters: List[TileCallback],
|
||||
**kwargs,
|
||||
) -> List[Image.Image]:
|
||||
) -> StageResult:
|
||||
if order == TileOrder.grid:
|
||||
logger.debug("using grid tile order with tile size: %s", tile)
|
||||
return process_tile_stack(
|
||||
|
|
Loading…
Reference in New Issue