fix(api): detect all mask keys, immediately bubble up cancellation errors
This commit is contained in:
parent
b29837d773
commit
95a62b17ed
|
@ -5,7 +5,7 @@ from typing import Any, List, Optional, Tuple
|
|||
|
||||
from PIL import Image
|
||||
|
||||
from ..errors import RetryException
|
||||
from ..errors import CancelledException, RetryException
|
||||
from ..output import save_image
|
||||
from ..params import ImageParams, Size, StageParams
|
||||
from ..server import ServerContext
|
||||
|
@ -146,7 +146,7 @@ class ChainPipeline:
|
|||
kwargs.pop("params")
|
||||
|
||||
# the stage must be split and tiled if any image is larger than the selected/max tile size
|
||||
must_tile = "mask" in stage_kwargs or any(
|
||||
must_tile = has_mask(stage_kwargs) or any(
|
||||
[
|
||||
needs_tile(
|
||||
stage_pipe.max_tile,
|
||||
|
@ -192,6 +192,10 @@ class ChainPipeline:
|
|||
save_image(server, f"last-tile-{j}.png", image)
|
||||
|
||||
return tile_result
|
||||
except CancelledException as err:
|
||||
worker.retries = 0
|
||||
logger.exception("job was cancelled while tiling")
|
||||
raise err
|
||||
except Exception:
|
||||
worker.retries = worker.retries - 1
|
||||
logger.exception(
|
||||
|
@ -234,6 +238,10 @@ class ChainPipeline:
|
|||
# does not like, so it throws
|
||||
stage_sources = stage_result
|
||||
break
|
||||
except CancelledException as err:
|
||||
worker.retries = 0
|
||||
logger.exception("job was cancelled during stage")
|
||||
raise err
|
||||
except Exception:
|
||||
worker.retries = worker.retries - 1
|
||||
logger.exception(
|
||||
|
@ -264,3 +272,9 @@ class ChainPipeline:
|
|||
len(stage_sources),
|
||||
)
|
||||
return stage_sources
|
||||
|
||||
|
||||
MASK_KEYS = ["mask", "stage_mask", "tile_mask"]
|
||||
|
||||
def has_mask(args: List[str]) -> bool:
|
||||
return any([key in args for key in MASK_KEYS])
|
||||
|
|
Loading…
Reference in New Issue