diff --git a/api/onnx_web/chain/edit_safety.py b/api/onnx_web/chain/edit_safety.py index c0f1a0d5..35d535b5 100644 --- a/api/onnx_web/chain/edit_safety.py +++ b/api/onnx_web/chain/edit_safety.py @@ -65,7 +65,7 @@ class EditSafetyStage(BaseStage): if is_csam: logger.warning("blocking csam result") - raise CancelledException("csam detected") + raise CancelledException(reason="csam") else: return StageResult.from_images(results, metadata=sources.metadata) except ImportError: diff --git a/api/onnx_web/chain/pipeline.py b/api/onnx_web/chain/pipeline.py index e78d12bd..ad4b3e66 100644 --- a/api/onnx_web/chain/pipeline.py +++ b/api/onnx_web/chain/pipeline.py @@ -174,7 +174,7 @@ class ChainPipeline: worker.set_tiles(0) if must_tile: logger.info( - "image contains sources or is larger than tile size of %s, tiling stage", + "image has mask or is larger than tile size of %s, tiling stage", tile, ) diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py index 971c6b0c..ac6a36c4 100644 --- a/api/onnx_web/diffusers/run.py +++ b/api/onnx_web/diffusers/run.py @@ -59,7 +59,9 @@ def add_safety_stage( if server.has_feature("horde-safety"): from ..chain.edit_safety import EditSafetyStage - pipeline.stage(EditSafetyStage(), StageParams()) + pipeline.stage( + EditSafetyStage(), StageParams(tile_size=EditSafetyStage.max_tile) + ) def run_txt2img_pipeline( diff --git a/api/onnx_web/errors.py b/api/onnx_web/errors.py index 4ab6f79b..ef7a768f 100644 --- a/api/onnx_web/errors.py +++ b/api/onnx_web/errors.py @@ -1,3 +1,6 @@ +from typing import Optional + + class RetryException(Exception): """ Used when a chain pipeline has run out of retries. @@ -11,7 +14,12 @@ class CancelledException(Exception): Used when a job has been cancelled and needs to stop. """ - pass + reason: Optional[str] + + def __init__(self, *args: object, reason: Optional[str] = None) -> None: + super().__init__(*args) + + self.reason = reason class RequestException(Exception): diff --git a/api/onnx_web/server/api.py b/api/onnx_web/server/api.py index 83820e87..f838a019 100644 --- a/api/onnx_web/server/api.py +++ b/api/onnx_web/server/api.py @@ -1,7 +1,7 @@ from io import BytesIO from logging import getLogger from os import path -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from flask import Flask, jsonify, make_response, request, url_for from jsonschema import validate @@ -117,8 +117,9 @@ def image_reply( stages: Progress = None, steps: Progress = None, tiles: Progress = None, - outputs: List[str] = None, - metadata: List[ImageMetadata] = None, + outputs: Optional[List[str]] = None, + metadata: Optional[List[ImageMetadata]] = None, + reason: Optional[str] = None, ) -> Dict[str, Any]: if queue is None: queue = EMPTY_PROGRESS @@ -141,6 +142,9 @@ def image_reply( "tiles": tiles.tojson(), } + if reason is not None: + data["reason"] = reason + if outputs is not None: if metadata is None: logger.error("metadata is required with outputs") @@ -705,6 +709,7 @@ def job_status(server: ServerContext, pool: DevicePoolExecutor): tiles=progress.tiles, outputs=outputs, metadata=metadata, + reason=progress.reason, ) ) else: diff --git a/api/onnx_web/worker/command.py b/api/onnx_web/worker/command.py index 70ac11ad..fcbc5b7b 100644 --- a/api/onnx_web/worker/command.py +++ b/api/onnx_web/worker/command.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Any, Callable, Dict +from typing import Any, Callable, Dict, Optional class JobStatus(str, Enum): @@ -64,7 +64,8 @@ class ProgressCommand: job: str job_type: str status: JobStatus - result: Any # really StageResult but that would be a very circular import + reason: Optional[str] + result: Optional[Any] # really StageResult but that would be a very circular import steps: Progress stages: Progress tiles: Progress @@ -79,6 +80,7 @@ class ProgressCommand: stages: Progress, tiles: Progress, result: Any = None, + reason: Optional[str] = None, ): self.job = job self.job_type = job_type @@ -90,6 +92,7 @@ class ProgressCommand: self.stages = stages self.tiles = tiles self.result = result + self.reason = reason class JobCommand: diff --git a/api/onnx_web/worker/context.py b/api/onnx_web/worker/context.py index da68ddc6..1f5c57a0 100644 --- a/api/onnx_web/worker/context.py +++ b/api/onnx_web/worker/context.py @@ -218,7 +218,7 @@ class WorkerContext: block=False, ) - def fail(self) -> None: + def fail(self, reason: Optional[str] = None) -> None: if self.job is None: logger.warning("setting failure without an active job") else: @@ -232,6 +232,7 @@ class WorkerContext: steps=self.steps, stages=self.stages, tiles=self.tiles, + reason=reason, # TODO: should this include partial results? ) self.progress.put( diff --git a/api/onnx_web/worker/worker.py b/api/onnx_web/worker/worker.py index 062d58e4..260f6674 100644 --- a/api/onnx_web/worker/worker.py +++ b/api/onnx_web/worker/worker.py @@ -5,7 +5,7 @@ from sys import exit from setproctitle import setproctitle -from ..errors import RetryException +from ..errors import CancelledException, RetryException from ..server import ServerContext, apply_patches from ..torch_before_ort import get_available_providers from .context import WorkerContext @@ -82,13 +82,16 @@ def worker_main( logger.exception("value error in worker, exiting") worker.fail() return exit(EXIT_ERROR) + except CancelledException as e: + logger.warning("job was cancelled, continuing") + worker.fail(e.reason or "cancelled") except Exception as e: e_str = str(e) # restart the worker on memory errors for e_mem in MEMORY_ERRORS: if e_mem in e_str: logger.error("detected out-of-memory error, exiting: %s", e) - worker.fail() + worker.fail("oom") return exit(EXIT_MEMORY) # carry on for other errors diff --git a/gui/src/components/card/ErrorCard.tsx b/gui/src/components/card/ErrorCard.tsx index c822004d..5126d573 100644 --- a/gui/src/components/card/ErrorCard.tsx +++ b/gui/src/components/card/ErrorCard.tsx @@ -89,7 +89,7 @@ export const UNKNOWN_ERROR = `${IMAGE_ERROR}unknown`; export function getImageErrorReason(image: FailedJobResponse | UnknownJobResponse) { if (image.status === JobStatus.FAILED) { - const error = image.error; + const error = image.reason; if (doesExist(error) && error.startsWith(ANY_ERROR)) { return error; } diff --git a/gui/src/strings/de.ts b/gui/src/strings/de.ts index 5e4ae8de..d2fd9e37 100644 --- a/gui/src/strings/de.ts +++ b/gui/src/strings/de.ts @@ -13,7 +13,9 @@ export const I18N_STRINGS_DE = { convert: '', error: { image: { + csam: '', memory: '', + oom: '', unknown: '', }, inpaint: { diff --git a/gui/src/strings/en.ts b/gui/src/strings/en.ts index e53ce931..ad342461 100644 --- a/gui/src/strings/en.ts +++ b/gui/src/strings/en.ts @@ -8,7 +8,9 @@ export const I18N_STRINGS_EN = { convert: 'Save and Convert', error: { image: { + csam: 'CSAM detected', memory: 'Memory error generating image', + oom: 'Out of memory generating image', unknown: 'Unknown error generating image', }, inpaint: { diff --git a/gui/src/strings/es.ts b/gui/src/strings/es.ts index 44925135..0c68942b 100644 --- a/gui/src/strings/es.ts +++ b/gui/src/strings/es.ts @@ -13,7 +13,9 @@ export const I18N_STRINGS_ES = { convert: '', error: { image: { + csam: '', memory: '', + oom: '', unknown: '', }, inpaint: { diff --git a/gui/src/strings/fr.ts b/gui/src/strings/fr.ts index a3b453df..5fa19c0d 100644 --- a/gui/src/strings/fr.ts +++ b/gui/src/strings/fr.ts @@ -13,7 +13,9 @@ export const I18N_STRINGS_FR = { convert: '', error: { image: { + csam: '', memory: '', + oom: '', unknown: '', }, inpaint: { diff --git a/gui/src/types/api-v2.ts b/gui/src/types/api-v2.ts index 093847de..3595a7c6 100644 --- a/gui/src/types/api-v2.ts +++ b/gui/src/types/api-v2.ts @@ -76,7 +76,7 @@ export interface CancelledJobResponse extends BaseJobResponse { export interface FailedJobResponse extends BaseJobResponse { status: JobStatus.FAILED; - error?: string; + reason?: string; } /**