diff --git a/api/onnx_web/chain/edit_safety.py b/api/onnx_web/chain/edit_safety.py new file mode 100644 index 00000000..3f3c038b --- /dev/null +++ b/api/onnx_web/chain/edit_safety.py @@ -0,0 +1,72 @@ +from logging import getLogger +from typing import Optional + +from PIL import Image + +from ..params import ImageParams, SizeChart, StageParams +from ..server import ServerContext +from ..worker import ProgressCallback, WorkerContext +from .base import BaseStage +from .result import StageResult + +logger = getLogger(__name__) + + +class EditSafetyStage(BaseStage): + max_tile = SizeChart.max + + def run( + self, + _worker: WorkerContext, + server: ServerContext, + _stage: StageParams, + _params: ImageParams, + sources: StageResult, + *, + callback: Optional[ProgressCallback] = None, + **kwargs, + ) -> StageResult: + logger.info("checking results using horde safety") + + # keep these within run to make this sort of like a plugin or peer dependency + try: + from horde_safety.deep_danbooru_model import get_deep_danbooru_model + from horde_safety.interrogate import get_interrogator_no_blip + from horde_safety.nsfw_checker_class import NSFWChecker + + # set up + block_nsfw = server.has_feature("horde-safety-nsfw") + + interrogator = get_interrogator_no_blip() + deep_danbooru_model = get_deep_danbooru_model() + + nsfw_checker = NSFWChecker( + interrogator, + deep_danbooru_model, + ) + + # individual flags from NSFWResult + is_csam = False + + images = sources.as_images() + results = [] + for i, image in enumerate(images): + prompt = sources.metadata[i].prompt + check = nsfw_checker.check_for_nsfw(image, prompt=prompt) + + if check.is_csam: + logger.warning("flagging csam result: %s, %s", i, prompt) + is_csam = True + + if check.is_nsfw and block_nsfw: + logger.warning("blocking nsfw image: %s, %s", i, prompt) + results.append(Image.new("RGB", image.size, color="black")) + + if is_csam: + logger.warning("blocking csam result") + raise RuntimeError("csam detected") + else: + return StageResult.from_images(results, metadata=sources.metadata) + except ImportError: + logger.warning("horde safety not installed") + return StageResult.empty() diff --git a/api/onnx_web/chain/stages.py b/api/onnx_web/chain/stages.py index 1e094052..9ad44613 100644 --- a/api/onnx_web/chain/stages.py +++ b/api/onnx_web/chain/stages.py @@ -1,5 +1,8 @@ from logging import getLogger +from .edit_safety import EditSafetyStage +from .edit_text import EditTextStage + from .base import BaseStage from .blend_denoise_fastnlmeans import BlendDenoiseFastNLMeansStage from .blend_denoise_localstd import BlendDenoiseLocalStdStage @@ -40,6 +43,8 @@ CHAIN_STAGES = { "correct-codeformer": CorrectCodeformerStage, "correct-gfpgan": CorrectGFPGANStage, "edit-metadata": EditMetadataStage, + "edit-safety": EditSafetyStage, + "edit-text": EditTextStage, "persist-disk": PersistDiskStage, "persist-s3": PersistS3Stage, "reduce-crop": ReduceCropStage, diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py index bb90e0ab..971c6b0c 100644 --- a/api/onnx_web/diffusers/run.py +++ b/api/onnx_web/diffusers/run.py @@ -52,6 +52,16 @@ def get_highres_tile( return params.unet_tile +def add_safety_stage( + server: ServerContext, + pipeline: ChainPipeline, +) -> None: + if server.has_feature("horde-safety"): + from ..chain.edit_safety import EditSafetyStage + + pipeline.stage(EditSafetyStage(), StageParams()) + + def run_txt2img_pipeline( worker: WorkerContext, server: ServerContext, @@ -110,6 +120,8 @@ def run_txt2img_pipeline( upscale=after_upscale, ) + add_safety_stage(server, chain) + # run and save latents = get_latents_from_seed(params.seed, size, batch=params.batch) progress = worker.get_progress_callback(reset=True) @@ -209,6 +221,8 @@ def run_img2img_pipeline( chain=chain, ) + add_safety_stage(server, chain) + # run and append the filtered source progress = worker.get_progress_callback(reset=True) images = chain( @@ -382,6 +396,8 @@ def run_inpaint_pipeline( chain=chain, ) + add_safety_stage(server, chain) + # run and save latents = get_latents_from_seed(params.seed, size, batch=params.batch) progress = worker.get_progress_callback(reset=True) @@ -462,6 +478,8 @@ def run_upscale_pipeline( chain=chain, ) + add_safety_stage(server, chain) + # run and save progress = worker.get_progress_callback(reset=True) images = chain( @@ -515,6 +533,8 @@ def run_blend_pipeline( chain=chain, ) + add_safety_stage(server, chain) + # run and save progress = worker.get_progress_callback(reset=True) images = chain(