fix(api): tile stages based on input image or size param
This commit is contained in:
parent
b8aef2cd32
commit
c9a1ace40b
|
@ -11,7 +11,7 @@ from ..server import ServerContext
|
|||
from ..utils import is_debug
|
||||
from ..worker import ProgressCallback, WorkerContext
|
||||
from .stage import BaseStage
|
||||
from .tile import process_tile_order
|
||||
from .tile import needs_tile, process_tile_order
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -149,13 +149,16 @@ class ChainPipeline:
|
|||
"running stage %s without source image, %s", name, kwargs.keys()
|
||||
)
|
||||
|
||||
if image is not None and (
|
||||
image.width > stage_params.tile_size
|
||||
or image.height > stage_params.tile_size
|
||||
if needs_tile(
|
||||
stage_pipe.max_tile,
|
||||
stage_params.tile_size,
|
||||
size=kwargs.get("size", None),
|
||||
source=image,
|
||||
):
|
||||
tile = min(stage_pipe.max_tile, stage_params.tile_size)
|
||||
logger.info(
|
||||
"image larger than tile size of %s, tiling stage",
|
||||
stage_params.tile_size,
|
||||
tile,
|
||||
)
|
||||
|
||||
def stage_tile(tile: Image.Image, _dims) -> Image.Image:
|
||||
|
@ -177,7 +180,7 @@ class ChainPipeline:
|
|||
image = process_tile_order(
|
||||
stage_params.tile_order,
|
||||
image,
|
||||
stage_params.tile_size,
|
||||
tile,
|
||||
stage_params.outscale,
|
||||
[stage_tile],
|
||||
**kwargs,
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
from logging import getLogger
|
||||
from math import ceil
|
||||
from typing import List, Protocol, Tuple
|
||||
from typing import List, Optional, Protocol, Tuple
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from ..params import TileOrder
|
||||
from ..params import Size, TileOrder
|
||||
|
||||
# from skimage.exposure import match_histograms
|
||||
|
||||
|
@ -37,6 +37,23 @@ def complete_tile(
|
|||
return source
|
||||
|
||||
|
||||
def needs_tile(
|
||||
max_tile: int,
|
||||
stage_tile: int,
|
||||
size: Optional[Size] = None,
|
||||
source: Optional[Image.Image] = None,
|
||||
) -> bool:
|
||||
tile = min(max_tile, stage_tile)
|
||||
|
||||
if source is not None:
|
||||
return source.width > tile or source.height > tile
|
||||
|
||||
if size is not None:
|
||||
return size.width > tile or size.height > tile
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def get_tile_grads(
|
||||
left: int,
|
||||
top: int,
|
||||
|
|
Loading…
Reference in New Issue