1
0
Fork 0

cache horde safety models

This commit is contained in:
Sean Sube 2024-01-14 09:59:19 -06:00
parent 8f31d16fa3
commit 5e40ba949e
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 29 additions and 15 deletions

View File

@ -1,11 +1,12 @@
from logging import getLogger from logging import getLogger
from typing import Optional from typing import Any, Optional
from PIL import Image from PIL import Image
from ..errors import CancelledException from ..errors import CancelledException
from ..params import ImageParams, SizeChart, StageParams from ..params import ImageParams, SizeChart, StageParams
from ..server import ServerContext from ..server import ServerContext
from ..server.model_cache import ModelTypes
from ..worker import ProgressCallback, WorkerContext from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage from .base import BaseStage
from .result import StageResult from .result import StageResult
@ -16,6 +17,29 @@ logger = getLogger(__name__)
class EditSafetyStage(BaseStage): class EditSafetyStage(BaseStage):
max_tile = SizeChart.max max_tile = SizeChart.max
def load(self, server: ServerContext) -> Any:
from horde_safety.deep_danbooru_model import get_deep_danbooru_model
from horde_safety.interrogate import get_interrogator_no_blip
from horde_safety.nsfw_checker_class import NSFWChecker
cache_key = ("horde-safety",)
cache_checker = server.cache.get(ModelTypes.safety, cache_key)
if cache_checker is not None:
return cache_checker
# set up
interrogator = get_interrogator_no_blip()
deep_danbooru_model = get_deep_danbooru_model()
nsfw_checker = NSFWChecker(
interrogator,
deep_danbooru_model,
)
server.cache.set(ModelTypes.safety, cache_key)
return nsfw_checker
def run( def run(
self, self,
_worker: WorkerContext, _worker: WorkerContext,
@ -31,24 +55,12 @@ class EditSafetyStage(BaseStage):
# keep these within run to make this sort of like a plugin or peer dependency # keep these within run to make this sort of like a plugin or peer dependency
try: try:
from horde_safety.deep_danbooru_model import get_deep_danbooru_model
from horde_safety.interrogate import get_interrogator_no_blip
from horde_safety.nsfw_checker_class import NSFWChecker
# set up # set up
nsfw_checker = self.load(server)
block_nsfw = server.has_feature("horde-safety-nsfw") block_nsfw = server.has_feature("horde-safety-nsfw")
interrogator = get_interrogator_no_blip()
deep_danbooru_model = get_deep_danbooru_model()
nsfw_checker = NSFWChecker(
interrogator,
deep_danbooru_model,
)
# individual flags from NSFWResult
is_csam = False is_csam = False
# check each output
images = sources.as_images() images = sources.as_images()
results = [] results = []
for i, image in enumerate(images): for i, image in enumerate(images):
@ -68,6 +80,7 @@ class EditSafetyStage(BaseStage):
results.append(image) results.append(image)
if is_csam: if is_csam:
# TODO: save metadata to a report file
logger.warning("blocking csam result") logger.warning("blocking csam result")
raise CancelledException(reason="csam") raise CancelledException(reason="csam")
else: else:

View File

@ -12,6 +12,7 @@ class ModelTypes(str, Enum):
diffusion = "diffusion" diffusion = "diffusion"
scheduler = "scheduler" scheduler = "scheduler"
upscaling = "upscaling" upscaling = "upscaling"
safety = "safety"
class ModelCache: class ModelCache: