diff --git a/api/onnx_web/chain/edit_safety.py b/api/onnx_web/chain/edit_safety.py index 0a1d281f..ee5eeb78 100644 --- a/api/onnx_web/chain/edit_safety.py +++ b/api/onnx_web/chain/edit_safety.py @@ -1,11 +1,12 @@ from logging import getLogger -from typing import Optional +from typing import Any, Optional from PIL import Image from ..errors import CancelledException from ..params import ImageParams, SizeChart, StageParams from ..server import ServerContext +from ..server.model_cache import ModelTypes from ..worker import ProgressCallback, WorkerContext from .base import BaseStage from .result import StageResult @@ -16,6 +17,29 @@ logger = getLogger(__name__) class EditSafetyStage(BaseStage): max_tile = SizeChart.max + def load(self, server: ServerContext) -> Any: + from horde_safety.deep_danbooru_model import get_deep_danbooru_model + from horde_safety.interrogate import get_interrogator_no_blip + from horde_safety.nsfw_checker_class import NSFWChecker + + cache_key = ("horde-safety",) + cache_checker = server.cache.get(ModelTypes.safety, cache_key) + if cache_checker is not None: + return cache_checker + + # set up + interrogator = get_interrogator_no_blip() + deep_danbooru_model = get_deep_danbooru_model() + + nsfw_checker = NSFWChecker( + interrogator, + deep_danbooru_model, + ) + + server.cache.set(ModelTypes.safety, cache_key) + + return nsfw_checker + def run( self, _worker: WorkerContext, @@ -31,24 +55,12 @@ class EditSafetyStage(BaseStage): # keep these within run to make this sort of like a plugin or peer dependency try: - from horde_safety.deep_danbooru_model import get_deep_danbooru_model - from horde_safety.interrogate import get_interrogator_no_blip - from horde_safety.nsfw_checker_class import NSFWChecker - # set up + nsfw_checker = self.load(server) block_nsfw = server.has_feature("horde-safety-nsfw") - - interrogator = get_interrogator_no_blip() - deep_danbooru_model = get_deep_danbooru_model() - - nsfw_checker = NSFWChecker( - interrogator, - deep_danbooru_model, - ) - - # individual flags from NSFWResult is_csam = False + # check each output images = sources.as_images() results = [] for i, image in enumerate(images): @@ -68,6 +80,7 @@ class EditSafetyStage(BaseStage): results.append(image) if is_csam: + # TODO: save metadata to a report file logger.warning("blocking csam result") raise CancelledException(reason="csam") else: diff --git a/api/onnx_web/server/model_cache.py b/api/onnx_web/server/model_cache.py index 6525d4ae..8c157858 100644 --- a/api/onnx_web/server/model_cache.py +++ b/api/onnx_web/server/model_cache.py @@ -12,6 +12,7 @@ class ModelTypes(str, Enum): diffusion = "diffusion" scheduler = "scheduler" upscaling = "upscaling" + safety = "safety" class ModelCache: