1
0
Fork 0

feat(api): add optional horde safety stage

This commit is contained in:
Sean Sube 2024-01-08 21:37:44 -06:00
parent c8d6afd64d
commit a42d728006
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 97 additions and 0 deletions

View File

@ -0,0 +1,72 @@
from logging import getLogger
from typing import Optional
from PIL import Image
from ..params import ImageParams, SizeChart, StageParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__)
class EditSafetyStage(BaseStage):
max_tile = SizeChart.max
def run(
self,
_worker: WorkerContext,
server: ServerContext,
_stage: StageParams,
_params: ImageParams,
sources: StageResult,
*,
callback: Optional[ProgressCallback] = None,
**kwargs,
) -> StageResult:
logger.info("checking results using horde safety")
# keep these within run to make this sort of like a plugin or peer dependency
try:
from horde_safety.deep_danbooru_model import get_deep_danbooru_model
from horde_safety.interrogate import get_interrogator_no_blip
from horde_safety.nsfw_checker_class import NSFWChecker
# set up
block_nsfw = server.has_feature("horde-safety-nsfw")
interrogator = get_interrogator_no_blip()
deep_danbooru_model = get_deep_danbooru_model()
nsfw_checker = NSFWChecker(
interrogator,
deep_danbooru_model,
)
# individual flags from NSFWResult
is_csam = False
images = sources.as_images()
results = []
for i, image in enumerate(images):
prompt = sources.metadata[i].prompt
check = nsfw_checker.check_for_nsfw(image, prompt=prompt)
if check.is_csam:
logger.warning("flagging csam result: %s, %s", i, prompt)
is_csam = True
if check.is_nsfw and block_nsfw:
logger.warning("blocking nsfw image: %s, %s", i, prompt)
results.append(Image.new("RGB", image.size, color="black"))
if is_csam:
logger.warning("blocking csam result")
raise RuntimeError("csam detected")
else:
return StageResult.from_images(results, metadata=sources.metadata)
except ImportError:
logger.warning("horde safety not installed")
return StageResult.empty()

View File

@ -1,5 +1,8 @@
from logging import getLogger
from .edit_safety import EditSafetyStage
from .edit_text import EditTextStage
from .base import BaseStage
from .blend_denoise_fastnlmeans import BlendDenoiseFastNLMeansStage
from .blend_denoise_localstd import BlendDenoiseLocalStdStage
@ -40,6 +43,8 @@ CHAIN_STAGES = {
"correct-codeformer": CorrectCodeformerStage,
"correct-gfpgan": CorrectGFPGANStage,
"edit-metadata": EditMetadataStage,
"edit-safety": EditSafetyStage,
"edit-text": EditTextStage,
"persist-disk": PersistDiskStage,
"persist-s3": PersistS3Stage,
"reduce-crop": ReduceCropStage,

View File

@ -52,6 +52,16 @@ def get_highres_tile(
return params.unet_tile
def add_safety_stage(
server: ServerContext,
pipeline: ChainPipeline,
) -> None:
if server.has_feature("horde-safety"):
from ..chain.edit_safety import EditSafetyStage
pipeline.stage(EditSafetyStage(), StageParams())
def run_txt2img_pipeline(
worker: WorkerContext,
server: ServerContext,
@ -110,6 +120,8 @@ def run_txt2img_pipeline(
upscale=after_upscale,
)
add_safety_stage(server, chain)
# run and save
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
progress = worker.get_progress_callback(reset=True)
@ -209,6 +221,8 @@ def run_img2img_pipeline(
chain=chain,
)
add_safety_stage(server, chain)
# run and append the filtered source
progress = worker.get_progress_callback(reset=True)
images = chain(
@ -382,6 +396,8 @@ def run_inpaint_pipeline(
chain=chain,
)
add_safety_stage(server, chain)
# run and save
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
progress = worker.get_progress_callback(reset=True)
@ -462,6 +478,8 @@ def run_upscale_pipeline(
chain=chain,
)
add_safety_stage(server, chain)
# run and save
progress = worker.get_progress_callback(reset=True)
images = chain(
@ -515,6 +533,8 @@ def run_blend_pipeline(
chain=chain,
)
add_safety_stage(server, chain)
# run and save
progress = worker.get_progress_callback(reset=True)
images = chain(