onnx-web/api/onnx_web/chain/edit_safety.py

from logging import getLogger
from typing import Optional

from PIL import Image

from ..errors import CancelledException
from ..params import ImageParams, SizeChart, StageParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage
from .result import StageResult

logger = getLogger(__name__)


class EditSafetyStage(BaseStage):
    max_tile = SizeChart.max

    def run(
        self,
        _worker: WorkerContext,
        server: ServerContext,
        _stage: StageParams,
        _params: ImageParams,
        sources: StageResult,
        *,
        callback: Optional[ProgressCallback] = None,
        **kwargs,
    ) -> StageResult:
        logger.info("checking results using horde safety")

        # keep these within run to make this sort of like a plugin or peer dependency
        try:
            from horde_safety.deep_danbooru_model import get_deep_danbooru_model
            from horde_safety.interrogate import get_interrogator_no_blip
            from horde_safety.nsfw_checker_class import NSFWChecker

            # set up
            block_nsfw = server.has_feature("horde-safety-nsfw")
            interrogator = get_interrogator_no_blip()
            deep_danbooru_model = get_deep_danbooru_model()
            nsfw_checker = NSFWChecker(
                interrogator,
                deep_danbooru_model,
            )

            # individual flags from NSFWResult
            is_csam = False

            images = sources.as_images()
            results = []
            for i, image in enumerate(images):
                prompt = sources.metadata[i].params.prompt
                check = nsfw_checker.check_for_nsfw(image, prompt=prompt)

                if check.is_csam:
                    logger.warning("flagging csam result: %s, %s", i, prompt)
                    is_csam = True
                    continue

                if check.is_nsfw and block_nsfw:
                    logger.warning("blocking nsfw image: %s, %s", i, prompt)
                    results.append(Image.new("RGB", image.size, color="black"))
                    continue

                results.append(image)

            if is_csam:
                logger.warning("blocking csam result")
                raise CancelledException(reason="csam")
            else:
                return StageResult.from_images(results, metadata=sources.metadata)
        except ImportError:
            logger.warning("horde safety not installed")
            return StageResult.empty()
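
For reference, a minimal smoke-test sketch of the blocking branch, kept separate from the file above. It stubs the horde_safety modules through sys.modules and replaces the onnx-web contexts with mocks, so only the branching inside EditSafetyStage.run is exercised. install_horde_safety_stub and FakeNSFWChecker are illustrative names, not part of either library; the sketch also assumes the onnx-web api package and its dependencies are importable and that StageResult simply stores the metadata objects it is given.

# smoke_test_edit_safety.py (illustrative sketch, not part of onnx-web)
import sys
from types import ModuleType, SimpleNamespace
from unittest.mock import MagicMock

from PIL import Image

from onnx_web.chain.edit_safety import EditSafetyStage


def install_horde_safety_stub(is_nsfw: bool, is_csam: bool) -> None:
    """Register fake horde_safety modules that always return a fixed check result."""
    check = SimpleNamespace(is_nsfw=is_nsfw, is_csam=is_csam)

    class FakeNSFWChecker:
        def __init__(self, interrogator, model):
            pass

        def check_for_nsfw(self, image, prompt=None):
            return check

    names = [
        "horde_safety",
        "horde_safety.interrogate",
        "horde_safety.deep_danbooru_model",
        "horde_safety.nsfw_checker_class",
    ]
    modules = {name: ModuleType(name) for name in names}
    modules["horde_safety.interrogate"].get_interrogator_no_blip = lambda: MagicMock()
    modules["horde_safety.deep_danbooru_model"].get_deep_danbooru_model = lambda: MagicMock()
    modules["horde_safety.nsfw_checker_class"].NSFWChecker = FakeNSFWChecker
    sys.modules.update(modules)


if __name__ == "__main__":
    install_horde_safety_stub(is_nsfw=True, is_csam=False)

    # only has_feature is called on the server context by this stage
    server = MagicMock()
    server.has_feature.return_value = True  # enable NSFW blocking

    # only as_images() and metadata[i].params.prompt are read from sources
    sources = MagicMock()
    sources.as_images.return_value = [Image.new("RGB", (64, 64), color="white")]
    sources.metadata = [SimpleNamespace(params=SimpleNamespace(prompt="test prompt"))]

    result = EditSafetyStage().run(None, server, None, None, sources)
    print(len(result.as_images()), "image(s) returned")  # the flagged image comes back blacked out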