from logging import getLogger
from typing import Any, Optional

from PIL import Image

from ..errors import CancelledException
from ..params import ImageParams, SizeChart, StageParams
from ..server import ServerContext
from ..server.model_cache import ModelTypes
from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage
from .result import StageResult

logger = getLogger(__name__)


class EditSafetyStage(BaseStage):
    max_tile = SizeChart.max

    def load(self, server: ServerContext) -> Any:
        # import horde_safety lazily so it stays an optional dependency
        from horde_safety.deep_danbooru_model import get_deep_danbooru_model
        from horde_safety.interrogate import get_interrogator_no_blip
        from horde_safety.nsfw_checker_class import NSFWChecker

        # reuse a cached checker if one has already been loaded
        cache_key = ("horde-safety",)
        cache_checker = server.cache.get(ModelTypes.safety, cache_key)
        if cache_checker is not None:
            return cache_checker

        # set up the interrogator and NSFW checker
        interrogator = get_interrogator_no_blip()
        deep_danbooru_model = get_deep_danbooru_model()
        nsfw_checker = NSFWChecker(
            interrogator,
            deep_danbooru_model,
        )

        server.cache.set(ModelTypes.safety, cache_key, nsfw_checker)
        return nsfw_checker

    def run(
        self,
        _worker: WorkerContext,
        server: ServerContext,
        _stage: StageParams,
        _params: ImageParams,
        sources: StageResult,
        *,
        callback: Optional[ProgressCallback] = None,
        **kwargs,
    ) -> StageResult:
        logger.info("checking results using horde safety")

        # the horde_safety imports happen in load, keeping this stage
        # something like a plugin with a peer dependency
        try:
            # set up
            nsfw_checker = self.load(server)
            block_nsfw = server.has_feature("horde-safety-nsfw")
            is_csam = False

            # check each output image against its prompt
            images = sources.as_images()
            results = []
            for i, image in enumerate(images):
                prompt = sources.metadata[i].params.prompt
                check = nsfw_checker.check_for_nsfw(image, prompt=prompt)

                if check.is_csam:
                    logger.warning("flagging csam result: %s, %s", i, prompt)
                    is_csam = True
                    continue

                if check.is_nsfw and block_nsfw:
                    # replace blocked images with a black placeholder of the same size
                    logger.warning("blocking nsfw image: %s, %s", i, prompt)
                    results.append(Image.new("RGB", image.size, color="black"))
                    continue

                results.append(image)

            if is_csam:
                # TODO: save metadata to a report file
                logger.warning("blocking csam result")
                raise CancelledException(reason="csam")

            return StageResult.from_images(results, metadata=sources.metadata)
        except ImportError:
            logger.warning("horde safety not installed")
            return StageResult.empty()