2024-01-09 03:37:44 +00:00
|
|
|
from logging import getLogger
|
2024-01-14 15:59:19 +00:00
|
|
|
from typing import Any, Optional
|
2024-01-09 03:37:44 +00:00
|
|
|
|
|
|
|
from PIL import Image
|
|
|
|
|
2024-01-09 03:50:36 +00:00
|
|
|
from ..errors import CancelledException
|
2024-01-09 03:37:44 +00:00
|
|
|
from ..params import ImageParams, SizeChart, StageParams
|
|
|
|
from ..server import ServerContext
|
2024-01-14 15:59:19 +00:00
|
|
|
from ..server.model_cache import ModelTypes
|
2024-01-09 03:37:44 +00:00
|
|
|
from ..worker import ProgressCallback, WorkerContext
|
|
|
|
from .base import BaseStage
|
|
|
|
from .result import StageResult
|
|
|
|
|
|
|
|
# Module-level logger, following the project convention of one logger per module.
logger = getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
class EditSafetyStage(BaseStage):
    """
    Pipeline stage that screens generated images using the horde_safety
    checkers. CSAM results abort the whole job; NSFW results are replaced
    with black images when the `horde-safety-nsfw` server feature is set.

    horde_safety is treated as an optional peer dependency: it is imported
    lazily in load() and a missing install is reported (not raised) in run().
    """

    # this stage operates on whole images, never tiles
    max_tile = SizeChart.max

    def load(self, server: ServerContext) -> Any:
        """
        Load the horde_safety NSFW checker, reusing a cached instance when
        one is available in the server's model cache.

        Raises ImportError if horde_safety is not installed; run() handles it.
        """
        # keep these imports local so horde_safety stays an optional dependency
        from horde_safety.deep_danbooru_model import get_deep_danbooru_model
        from horde_safety.interrogate import get_interrogator_no_blip
        from horde_safety.nsfw_checker_class import NSFWChecker

        cache_key = ("horde-safety",)
        cache_checker = server.cache.get(ModelTypes.safety, cache_key)
        if cache_checker is not None:
            return cache_checker

        # set up a fresh checker from its two component models
        interrogator = get_interrogator_no_blip()
        deep_danbooru_model = get_deep_danbooru_model()

        nsfw_checker = NSFWChecker(
            interrogator,
            deep_danbooru_model,
        )

        # fix: the checker must be passed as the cached value; previously only
        # the key was stored, so the models were rebuilt on every call
        server.cache.set(ModelTypes.safety, cache_key, nsfw_checker)

        return nsfw_checker

    def run(
        self,
        _worker: WorkerContext,
        server: ServerContext,
        _stage: StageParams,
        _params: ImageParams,
        sources: StageResult,
        *,
        callback: Optional[ProgressCallback] = None,
        **kwargs,
    ) -> StageResult:
        """
        Check each image in `sources` for unsafe content.

        Returns a StageResult with NSFW images blacked out (when blocking is
        enabled) and all other images passed through unchanged. Returns an
        empty StageResult if horde_safety is not installed.

        Raises CancelledException (reason="csam") if any image is flagged as
        CSAM; the flagged images are never returned.
        """
        logger.info("checking results using horde safety")

        # keep these within run to make this sort of like a plugin or peer dependency
        try:
            # set up (cached after the first call, see load())
            nsfw_checker = self.load(server)
            block_nsfw = server.has_feature("horde-safety-nsfw")
            is_csam = False

            # check each output
            images = sources.as_images()
            results = []
            for i, image in enumerate(images):
                # prompt improves the checker's context for the image
                prompt = sources.metadata[i].params.prompt
                check = nsfw_checker.check_for_nsfw(image, prompt=prompt)

                if check.is_csam:
                    # do not add the image to the results; the whole batch is
                    # rejected below once every image has been checked
                    logger.warning("flagging csam result: %s, %s", i, prompt)
                    is_csam = True
                    continue

                if check.is_nsfw and block_nsfw:
                    # replace with a black placeholder of the same size
                    logger.warning("blocking nsfw image: %s, %s", i, prompt)
                    results.append(Image.new("RGB", image.size, color="black"))
                    continue

                results.append(image)

            if is_csam:
                # TODO: save metadata to a report file
                logger.warning("blocking csam result")
                # raised past the ImportError handler below, cancelling the job
                raise CancelledException(reason="csam")
            else:
                return StageResult.from_images(results, metadata=sources.metadata)
        except ImportError:
            logger.warning("horde safety not installed")
            return StageResult.empty()
|