cache horde safety models
parent 8f31d16fa3
commit 5e40ba949e
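Moves construction of the horde_safety NSFW checker out of EditSafetyStage.run and into a new load method, which checks the server's ModelCache under a new safety model type and only builds the interrogator, deep danbooru model, and NSFWChecker when no cached instance exists.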
@@ -1,11 +1,12 @@
 from logging import getLogger
-from typing import Optional
+from typing import Any, Optional
 
 from PIL import Image
 
 from ..errors import CancelledException
 from ..params import ImageParams, SizeChart, StageParams
 from ..server import ServerContext
+from ..server.model_cache import ModelTypes
 from ..worker import ProgressCallback, WorkerContext
 from .base import BaseStage
 from .result import StageResult
@@ -16,6 +17,29 @@ logger = getLogger(__name__)
 class EditSafetyStage(BaseStage):
     max_tile = SizeChart.max
 
+    def load(self, server: ServerContext) -> Any:
+        from horde_safety.deep_danbooru_model import get_deep_danbooru_model
+        from horde_safety.interrogate import get_interrogator_no_blip
+        from horde_safety.nsfw_checker_class import NSFWChecker
+
+        cache_key = ("horde-safety",)
+        cache_checker = server.cache.get(ModelTypes.safety, cache_key)
+        if cache_checker is not None:
+            return cache_checker
+
+        # set up
+        interrogator = get_interrogator_no_blip()
+        deep_danbooru_model = get_deep_danbooru_model()
+
+        nsfw_checker = NSFWChecker(
+            interrogator,
+            deep_danbooru_model,
+        )
+
+        server.cache.set(ModelTypes.safety, cache_key, nsfw_checker)
+
+        return nsfw_checker
+
     def run(
         self,
         _worker: WorkerContext,
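For context, load assumes ModelCache acts as a keyed store per model type: get(tag, key) returns a previously stored value or None, and set(tag, key, value) stores one. A minimal sketch of that contract, illustrative only (the real ModelCache in this repo may add size limits or eviction):

# Minimal sketch of the cache contract assumed by EditSafetyStage.load.
# SketchModelCache is a hypothetical stand-in, not the repo's ModelCache.
from typing import Any, List, Tuple


class SketchModelCache:
    def __init__(self) -> None:
        # (tag, key, value) entries; a flat list keeps the sketch simple
        self.cache: List[Tuple[str, Any, Any]] = []

    def get(self, tag: str, key: Any) -> Any:
        # return the stored value for this tag/key pair, or None on a miss
        for t, k, v in self.cache:
            if t == tag and k == key:
                return v
        return None

    def set(self, tag: str, key: Any, value: Any) -> None:
        # replace any existing entry for this tag/key pair
        self.cache = [e for e in self.cache if (e[0], e[1]) != (tag, key)]
        self.cache.append((tag, key, value))

Written this way, set needs the checker itself as the third argument, which is why the call in load passes nsfw_checker along with the model type tag and the cache key.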
@@ -31,24 +55,12 @@ class EditSafetyStage(BaseStage):
 
         # keep these within run to make this sort of like a plugin or peer dependency
         try:
-            from horde_safety.deep_danbooru_model import get_deep_danbooru_model
-            from horde_safety.interrogate import get_interrogator_no_blip
-            from horde_safety.nsfw_checker_class import NSFWChecker
-
             # set up
-            interrogator = get_interrogator_no_blip()
-            deep_danbooru_model = get_deep_danbooru_model()
-
-            nsfw_checker = NSFWChecker(
-                interrogator,
-                deep_danbooru_model,
-            )
+            nsfw_checker = self.load(server)
+            block_nsfw = server.has_feature("horde-safety-nsfw")
 
             # individual flags from NSFWResult
             is_csam = False
 
             # check each output
             images = sources.as_images()
             results = []
             for i, image in enumerate(images):
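The comment kept above describes the deferred-import pattern that makes horde_safety behave like an optional peer dependency: the package is imported inside the method, under a try block, rather than at module scope. A minimal standalone sketch of the pattern (load_interrogator is a hypothetical helper, not part of this repo):

# Minimal sketch of the deferred-import pattern for an optional dependency.
def load_interrogator():
    try:
        # imported here rather than at module scope, so the rest of the
        # server still works when horde_safety is not installed
        from horde_safety.interrogate import get_interrogator_no_blip
    except ImportError as e:
        raise RuntimeError(
            "horde_safety must be installed to run the safety stage"
        ) from e

    return get_interrogator_no_blip()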
@@ -68,6 +80,7 @@ class EditSafetyStage(BaseStage):
             results.append(image)
 
         if is_csam:
+            # TODO: save metadata to a report file
             logger.warning("blocking csam result")
             raise CancelledException(reason="csam")
         else:
@@ -12,6 +12,7 @@ class ModelTypes(str, Enum):
     diffusion = "diffusion"
     scheduler = "scheduler"
     upscaling = "upscaling"
+    safety = "safety"
 
 
 class ModelCache:
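Taken together, the hunks make repeated safety checks reuse one checker per server. A rough usage sketch, assuming a configured ServerContext named server and horde_safety installed (variable names here are illustrative):

# Illustrative only: the second load should return the cached checker.
stage = EditSafetyStage()

checker = stage.load(server)  # builds interrogator + deep danbooru model, caches the result
cached = stage.load(server)   # cache hit: returns the same NSFWChecker

assert cached is checker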