From 0215cb9ac61a2fa7d4464bdf2a49d252189125d2 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Mon, 8 Jan 2024 21:50:36 -0600 Subject: [PATCH] cancel unsafe jobs --- api/onnx_web/chain/blend_linear.py | 2 +- api/onnx_web/chain/edit_safety.py | 5 +++-- api/onnx_web/chain/stages.py | 5 ++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/api/onnx_web/chain/blend_linear.py b/api/onnx_web/chain/blend_linear.py index 0b5e85db..5d1baa9e 100644 --- a/api/onnx_web/chain/blend_linear.py +++ b/api/onnx_web/chain/blend_linear.py @@ -23,7 +23,7 @@ class BlendLinearStage(BaseStage): *, alpha: float, stage_source: Optional[Image.Image] = None, - _callback: Optional[ProgressCallback] = None, + callback: Optional[ProgressCallback] = None, **kwargs, ) -> StageResult: logger.info("blending source images using linear interpolation") diff --git a/api/onnx_web/chain/edit_safety.py b/api/onnx_web/chain/edit_safety.py index 3f3c038b..c0f1a0d5 100644 --- a/api/onnx_web/chain/edit_safety.py +++ b/api/onnx_web/chain/edit_safety.py @@ -3,6 +3,7 @@ from typing import Optional from PIL import Image +from ..errors import CancelledException from ..params import ImageParams, SizeChart, StageParams from ..server import ServerContext from ..worker import ProgressCallback, WorkerContext @@ -51,7 +52,7 @@ class EditSafetyStage(BaseStage): images = sources.as_images() results = [] for i, image in enumerate(images): - prompt = sources.metadata[i].prompt + prompt = sources.metadata[i].params.prompt check = nsfw_checker.check_for_nsfw(image, prompt=prompt) if check.is_csam: @@ -64,7 +65,7 @@ class EditSafetyStage(BaseStage): if is_csam: logger.warning("blocking csam result") - raise RuntimeError("csam detected") + raise CancelledException("csam detected") else: return StageResult.from_images(results, metadata=sources.metadata) except ImportError: diff --git a/api/onnx_web/chain/stages.py b/api/onnx_web/chain/stages.py index 9ad44613..9fc4bd9a 100644 --- a/api/onnx_web/chain/stages.py +++ b/api/onnx_web/chain/stages.py @@ -1,8 +1,5 @@ from logging import getLogger -from .edit_safety import EditSafetyStage -from .edit_text import EditTextStage - from .base import BaseStage from .blend_denoise_fastnlmeans import BlendDenoiseFastNLMeansStage from .blend_denoise_localstd import BlendDenoiseLocalStdStage @@ -13,6 +10,8 @@ from .blend_mask import BlendMaskStage from .correct_codeformer import CorrectCodeformerStage from .correct_gfpgan import CorrectGFPGANStage from .edit_metadata import EditMetadataStage +from .edit_safety import EditSafetyStage +from .edit_text import EditTextStage from .persist_disk import PersistDiskStage from .persist_s3 import PersistS3Stage from .reduce_crop import ReduceCropStage