1
0
Fork 0

cancel unsafe jobs

This commit is contained in:
Sean Sube 2024-01-08 21:50:36 -06:00
parent a42d728006
commit 0215cb9ac6
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 6 additions and 6 deletions

View File

@ -23,7 +23,7 @@ class BlendLinearStage(BaseStage):
*, *,
alpha: float, alpha: float,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
_callback: Optional[ProgressCallback] = None, callback: Optional[ProgressCallback] = None,
**kwargs, **kwargs,
) -> StageResult: ) -> StageResult:
logger.info("blending source images using linear interpolation") logger.info("blending source images using linear interpolation")

View File

@ -3,6 +3,7 @@ from typing import Optional
from PIL import Image from PIL import Image
from ..errors import CancelledException
from ..params import ImageParams, SizeChart, StageParams from ..params import ImageParams, SizeChart, StageParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext from ..worker import ProgressCallback, WorkerContext
@ -51,7 +52,7 @@ class EditSafetyStage(BaseStage):
images = sources.as_images() images = sources.as_images()
results = [] results = []
for i, image in enumerate(images): for i, image in enumerate(images):
prompt = sources.metadata[i].prompt prompt = sources.metadata[i].params.prompt
check = nsfw_checker.check_for_nsfw(image, prompt=prompt) check = nsfw_checker.check_for_nsfw(image, prompt=prompt)
if check.is_csam: if check.is_csam:
@ -64,7 +65,7 @@ class EditSafetyStage(BaseStage):
if is_csam: if is_csam:
logger.warning("blocking csam result") logger.warning("blocking csam result")
raise RuntimeError("csam detected") raise CancelledException("csam detected")
else: else:
return StageResult.from_images(results, metadata=sources.metadata) return StageResult.from_images(results, metadata=sources.metadata)
except ImportError: except ImportError:

View File

@ -1,8 +1,5 @@
from logging import getLogger from logging import getLogger
from .edit_safety import EditSafetyStage
from .edit_text import EditTextStage
from .base import BaseStage from .base import BaseStage
from .blend_denoise_fastnlmeans import BlendDenoiseFastNLMeansStage from .blend_denoise_fastnlmeans import BlendDenoiseFastNLMeansStage
from .blend_denoise_localstd import BlendDenoiseLocalStdStage from .blend_denoise_localstd import BlendDenoiseLocalStdStage
@ -13,6 +10,8 @@ from .blend_mask import BlendMaskStage
from .correct_codeformer import CorrectCodeformerStage from .correct_codeformer import CorrectCodeformerStage
from .correct_gfpgan import CorrectGFPGANStage from .correct_gfpgan import CorrectGFPGANStage
from .edit_metadata import EditMetadataStage from .edit_metadata import EditMetadataStage
from .edit_safety import EditSafetyStage
from .edit_text import EditTextStage
from .persist_disk import PersistDiskStage from .persist_disk import PersistDiskStage
from .persist_s3 import PersistS3Stage from .persist_s3 import PersistS3Stage
from .reduce_crop import ReduceCropStage from .reduce_crop import ReduceCropStage