cache horde safety models
This commit is contained in:
parent
8f31d16fa3
commit
5e40ba949e
|
@ -1,11 +1,12 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..errors import CancelledException
|
from ..errors import CancelledException
|
||||||
from ..params import ImageParams, SizeChart, StageParams
|
from ..params import ImageParams, SizeChart, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
|
from ..server.model_cache import ModelTypes
|
||||||
from ..worker import ProgressCallback, WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
from .result import StageResult
|
from .result import StageResult
|
||||||
|
@ -16,6 +17,29 @@ logger = getLogger(__name__)
|
||||||
class EditSafetyStage(BaseStage):
|
class EditSafetyStage(BaseStage):
|
||||||
max_tile = SizeChart.max
|
max_tile = SizeChart.max
|
||||||
|
|
||||||
|
def load(self, server: ServerContext) -> Any:
|
||||||
|
from horde_safety.deep_danbooru_model import get_deep_danbooru_model
|
||||||
|
from horde_safety.interrogate import get_interrogator_no_blip
|
||||||
|
from horde_safety.nsfw_checker_class import NSFWChecker
|
||||||
|
|
||||||
|
cache_key = ("horde-safety",)
|
||||||
|
cache_checker = server.cache.get(ModelTypes.safety, cache_key)
|
||||||
|
if cache_checker is not None:
|
||||||
|
return cache_checker
|
||||||
|
|
||||||
|
# set up
|
||||||
|
interrogator = get_interrogator_no_blip()
|
||||||
|
deep_danbooru_model = get_deep_danbooru_model()
|
||||||
|
|
||||||
|
nsfw_checker = NSFWChecker(
|
||||||
|
interrogator,
|
||||||
|
deep_danbooru_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
server.cache.set(ModelTypes.safety, cache_key)
|
||||||
|
|
||||||
|
return nsfw_checker
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
_worker: WorkerContext,
|
_worker: WorkerContext,
|
||||||
|
@ -31,24 +55,12 @@ class EditSafetyStage(BaseStage):
|
||||||
|
|
||||||
# keep these within run to make this sort of like a plugin or peer dependency
|
# keep these within run to make this sort of like a plugin or peer dependency
|
||||||
try:
|
try:
|
||||||
from horde_safety.deep_danbooru_model import get_deep_danbooru_model
|
|
||||||
from horde_safety.interrogate import get_interrogator_no_blip
|
|
||||||
from horde_safety.nsfw_checker_class import NSFWChecker
|
|
||||||
|
|
||||||
# set up
|
# set up
|
||||||
|
nsfw_checker = self.load(server)
|
||||||
block_nsfw = server.has_feature("horde-safety-nsfw")
|
block_nsfw = server.has_feature("horde-safety-nsfw")
|
||||||
|
|
||||||
interrogator = get_interrogator_no_blip()
|
|
||||||
deep_danbooru_model = get_deep_danbooru_model()
|
|
||||||
|
|
||||||
nsfw_checker = NSFWChecker(
|
|
||||||
interrogator,
|
|
||||||
deep_danbooru_model,
|
|
||||||
)
|
|
||||||
|
|
||||||
# individual flags from NSFWResult
|
|
||||||
is_csam = False
|
is_csam = False
|
||||||
|
|
||||||
|
# check each output
|
||||||
images = sources.as_images()
|
images = sources.as_images()
|
||||||
results = []
|
results = []
|
||||||
for i, image in enumerate(images):
|
for i, image in enumerate(images):
|
||||||
|
@ -68,6 +80,7 @@ class EditSafetyStage(BaseStage):
|
||||||
results.append(image)
|
results.append(image)
|
||||||
|
|
||||||
if is_csam:
|
if is_csam:
|
||||||
|
# TODO: save metadata to a report file
|
||||||
logger.warning("blocking csam result")
|
logger.warning("blocking csam result")
|
||||||
raise CancelledException(reason="csam")
|
raise CancelledException(reason="csam")
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -12,6 +12,7 @@ class ModelTypes(str, Enum):
|
||||||
diffusion = "diffusion"
|
diffusion = "diffusion"
|
||||||
scheduler = "scheduler"
|
scheduler = "scheduler"
|
||||||
upscaling = "upscaling"
|
upscaling = "upscaling"
|
||||||
|
safety = "safety"
|
||||||
|
|
||||||
|
|
||||||
class ModelCache:
|
class ModelCache:
|
||||||
|
|
Loading…
Reference in New Issue