onnx-web/api/onnx_web/chain/edit_safety.py
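"""
Chain stage that screens generated images with horde_safety: NSFW outputs
can be blacked out, and any CSAM detection cancels the entire job and saves
a metadata report.
"""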

from logging import getLogger
from typing import Any, Optional

from PIL import Image

from ..errors import CancelledException
from ..output import save_metadata
from ..params import ImageParams, SizeChart, StageParams
from ..server import ServerContext
from ..server.model_cache import ModelTypes
from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage
from .result import StageResult

logger = getLogger(__name__)

class EditSafetyStage(BaseStage):
    max_tile = SizeChart.max

    def load(self, server: ServerContext) -> Any:
        # keep these imports inside load so horde_safety behaves like a plugin
        # or peer dependency and the server can run without it installed
        from horde_safety.deep_danbooru_model import get_deep_danbooru_model
        from horde_safety.interrogate import get_interrogator_no_blip
        from horde_safety.nsfw_checker_class import NSFWChecker

        # check the cache for an existing checker before loading the models
        cache_key = ("horde-safety",)
        cache_checker = server.cache.get(ModelTypes.safety, cache_key)
        if cache_checker is not None:
            return cache_checker

        # set up the interrogator and DeepDanbooru model, then cache the
        # assembled checker so later calls can reuse it
        interrogator = get_interrogator_no_blip()
        deep_danbooru_model = get_deep_danbooru_model()
        nsfw_checker = NSFWChecker(
            interrogator,
            deep_danbooru_model,
        )
        server.cache.set(ModelTypes.safety, cache_key, nsfw_checker)

        return nsfw_checker

    def run(
        self,
        worker: WorkerContext,
        server: ServerContext,
        _stage: StageParams,
        _params: ImageParams,
        sources: StageResult,
        *,
        callback: Optional[ProgressCallback] = None,
        **kwargs,
    ) -> StageResult:
        logger.info("checking results using horde safety")

        try:
            # load the NSFW checker, reusing the cached copy if one exists
            nsfw_checker = self.load(server)
            block_nsfw = server.has_feature("horde-safety-nsfw")
            is_csam = False

            # check each output image against the prompt that produced it
            images = sources.as_images()
            results = []
            for i, image in enumerate(images):
                metadata = sources.metadata[i]
                prompt = metadata.params.prompt
                check = nsfw_checker.check_for_nsfw(image, prompt=prompt)

                if check.is_csam:
                    logger.warning("flagging csam result: %s, %s", i, prompt)
                    is_csam = True

                    report_name = f"csam-report-{worker.job}-{i}"
                    report_path = save_metadata(server, report_name, metadata)
                    logger.info("saved csam report: %s", report_path)
                elif check.is_nsfw and block_nsfw:
                    # replace blocked NSFW images with a black placeholder of the same size
                    logger.warning("blocking nsfw image: %s, %s", i, prompt)
                    results.append(Image.new("RGB", image.size, color="black"))
                else:
                    results.append(image)

            if is_csam:
                # cancel the entire job rather than returning any of the results
                logger.warning("blocking csam result")
                raise CancelledException(reason="csam")
            else:
                return StageResult.from_images(results, metadata=sources.metadata)
        except ImportError:
            # horde_safety is an optional dependency; without it, return an empty result
            logger.warning("horde safety not installed")
            return StageResult.empty()
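
# A minimal usage sketch (hypothetical variable names; the chain pipeline
# normally constructs the contexts and stage parameters):
#
#   stage = EditSafetyStage()
#   result = stage.run(worker, server, stage_params, image_params, sources)
#
# `worker` is a WorkerContext, `server` a ServerContext, and `sources` a
# StageResult holding the images to check; on a CSAM detection, run() raises
# CancelledException instead of returning a result.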