cancel unsafe jobs
This commit is contained in:
parent
a42d728006
commit
0215cb9ac6
|
@ -23,7 +23,7 @@ class BlendLinearStage(BaseStage):
|
||||||
*,
|
*,
|
||||||
alpha: float,
|
alpha: float,
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
_callback: Optional[ProgressCallback] = None,
|
callback: Optional[ProgressCallback] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> StageResult:
|
) -> StageResult:
|
||||||
logger.info("blending source images using linear interpolation")
|
logger.info("blending source images using linear interpolation")
|
||||||
|
|
|
@ -3,6 +3,7 @@ from typing import Optional
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
from ..errors import CancelledException
|
||||||
from ..params import ImageParams, SizeChart, StageParams
|
from ..params import ImageParams, SizeChart, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import ProgressCallback, WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
|
@ -51,7 +52,7 @@ class EditSafetyStage(BaseStage):
|
||||||
images = sources.as_images()
|
images = sources.as_images()
|
||||||
results = []
|
results = []
|
||||||
for i, image in enumerate(images):
|
for i, image in enumerate(images):
|
||||||
prompt = sources.metadata[i].prompt
|
prompt = sources.metadata[i].params.prompt
|
||||||
check = nsfw_checker.check_for_nsfw(image, prompt=prompt)
|
check = nsfw_checker.check_for_nsfw(image, prompt=prompt)
|
||||||
|
|
||||||
if check.is_csam:
|
if check.is_csam:
|
||||||
|
@ -64,7 +65,7 @@ class EditSafetyStage(BaseStage):
|
||||||
|
|
||||||
if is_csam:
|
if is_csam:
|
||||||
logger.warning("blocking csam result")
|
logger.warning("blocking csam result")
|
||||||
raise RuntimeError("csam detected")
|
raise CancelledException("csam detected")
|
||||||
else:
|
else:
|
||||||
return StageResult.from_images(results, metadata=sources.metadata)
|
return StageResult.from_images(results, metadata=sources.metadata)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
|
|
@ -1,8 +1,5 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
|
|
||||||
from .edit_safety import EditSafetyStage
|
|
||||||
from .edit_text import EditTextStage
|
|
||||||
|
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
from .blend_denoise_fastnlmeans import BlendDenoiseFastNLMeansStage
|
from .blend_denoise_fastnlmeans import BlendDenoiseFastNLMeansStage
|
||||||
from .blend_denoise_localstd import BlendDenoiseLocalStdStage
|
from .blend_denoise_localstd import BlendDenoiseLocalStdStage
|
||||||
|
@ -13,6 +10,8 @@ from .blend_mask import BlendMaskStage
|
||||||
from .correct_codeformer import CorrectCodeformerStage
|
from .correct_codeformer import CorrectCodeformerStage
|
||||||
from .correct_gfpgan import CorrectGFPGANStage
|
from .correct_gfpgan import CorrectGFPGANStage
|
||||||
from .edit_metadata import EditMetadataStage
|
from .edit_metadata import EditMetadataStage
|
||||||
|
from .edit_safety import EditSafetyStage
|
||||||
|
from .edit_text import EditTextStage
|
||||||
from .persist_disk import PersistDiskStage
|
from .persist_disk import PersistDiskStage
|
||||||
from .persist_s3 import PersistS3Stage
|
from .persist_s3 import PersistS3Stage
|
||||||
from .reduce_crop import ReduceCropStage
|
from .reduce_crop import ReduceCropStage
|
||||||
|
|
Loading…
Reference in New Issue