1
0
Fork 0

fix(api): detect all mask keys, immediately bubble up cancellation errors

This commit is contained in:
Sean Sube 2023-12-04 18:44:58 -06:00
parent b29837d773
commit 95a62b17ed
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 16 additions and 2 deletions

View File

@ -5,7 +5,7 @@ from typing import Any, List, Optional, Tuple
from PIL import Image
from ..errors import RetryException
from ..errors import CancelledException, RetryException
from ..output import save_image
from ..params import ImageParams, Size, StageParams
from ..server import ServerContext
@ -146,7 +146,7 @@ class ChainPipeline:
kwargs.pop("params")
# the stage must be split and tiled if any image is larger than the selected/max tile size
must_tile = "mask" in stage_kwargs or any(
must_tile = has_mask(stage_kwargs) or any(
[
needs_tile(
stage_pipe.max_tile,
@ -192,6 +192,10 @@ class ChainPipeline:
save_image(server, f"last-tile-{j}.png", image)
return tile_result
except CancelledException as err:
worker.retries = 0
logger.exception("job was cancelled while tiling")
raise err
except Exception:
worker.retries = worker.retries - 1
logger.exception(
@ -234,6 +238,10 @@ class ChainPipeline:
# does not like, so it throws
stage_sources = stage_result
break
except CancelledException as err:
worker.retries = 0
logger.exception("job was cancelled during stage")
raise err
except Exception:
worker.retries = worker.retries - 1
logger.exception(
@ -264,3 +272,9 @@ class ChainPipeline:
len(stage_sources),
)
return stage_sources
MASK_KEYS = ["mask", "stage_mask", "tile_mask"]
def has_mask(args: List[str]) -> bool:
return any([key in args for key in MASK_KEYS])