feat(api): add optional horde safety stage
This commit is contained in:
parent
c8d6afd64d
commit
a42d728006
|
@ -0,0 +1,72 @@
|
|||
from logging import getLogger
|
||||
from typing import Optional
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from ..params import ImageParams, SizeChart, StageParams
|
||||
from ..server import ServerContext
|
||||
from ..worker import ProgressCallback, WorkerContext
|
||||
from .base import BaseStage
|
||||
from .result import StageResult
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
class EditSafetyStage(BaseStage):
|
||||
max_tile = SizeChart.max
|
||||
|
||||
def run(
|
||||
self,
|
||||
_worker: WorkerContext,
|
||||
server: ServerContext,
|
||||
_stage: StageParams,
|
||||
_params: ImageParams,
|
||||
sources: StageResult,
|
||||
*,
|
||||
callback: Optional[ProgressCallback] = None,
|
||||
**kwargs,
|
||||
) -> StageResult:
|
||||
logger.info("checking results using horde safety")
|
||||
|
||||
# keep these within run to make this sort of like a plugin or peer dependency
|
||||
try:
|
||||
from horde_safety.deep_danbooru_model import get_deep_danbooru_model
|
||||
from horde_safety.interrogate import get_interrogator_no_blip
|
||||
from horde_safety.nsfw_checker_class import NSFWChecker
|
||||
|
||||
# set up
|
||||
block_nsfw = server.has_feature("horde-safety-nsfw")
|
||||
|
||||
interrogator = get_interrogator_no_blip()
|
||||
deep_danbooru_model = get_deep_danbooru_model()
|
||||
|
||||
nsfw_checker = NSFWChecker(
|
||||
interrogator,
|
||||
deep_danbooru_model,
|
||||
)
|
||||
|
||||
# individual flags from NSFWResult
|
||||
is_csam = False
|
||||
|
||||
images = sources.as_images()
|
||||
results = []
|
||||
for i, image in enumerate(images):
|
||||
prompt = sources.metadata[i].prompt
|
||||
check = nsfw_checker.check_for_nsfw(image, prompt=prompt)
|
||||
|
||||
if check.is_csam:
|
||||
logger.warning("flagging csam result: %s, %s", i, prompt)
|
||||
is_csam = True
|
||||
|
||||
if check.is_nsfw and block_nsfw:
|
||||
logger.warning("blocking nsfw image: %s, %s", i, prompt)
|
||||
results.append(Image.new("RGB", image.size, color="black"))
|
||||
|
||||
if is_csam:
|
||||
logger.warning("blocking csam result")
|
||||
raise RuntimeError("csam detected")
|
||||
else:
|
||||
return StageResult.from_images(results, metadata=sources.metadata)
|
||||
except ImportError:
|
||||
logger.warning("horde safety not installed")
|
||||
return StageResult.empty()
|
|
@ -1,5 +1,8 @@
|
|||
from logging import getLogger
|
||||
|
||||
from .edit_safety import EditSafetyStage
|
||||
from .edit_text import EditTextStage
|
||||
|
||||
from .base import BaseStage
|
||||
from .blend_denoise_fastnlmeans import BlendDenoiseFastNLMeansStage
|
||||
from .blend_denoise_localstd import BlendDenoiseLocalStdStage
|
||||
|
@ -40,6 +43,8 @@ CHAIN_STAGES = {
|
|||
"correct-codeformer": CorrectCodeformerStage,
|
||||
"correct-gfpgan": CorrectGFPGANStage,
|
||||
"edit-metadata": EditMetadataStage,
|
||||
"edit-safety": EditSafetyStage,
|
||||
"edit-text": EditTextStage,
|
||||
"persist-disk": PersistDiskStage,
|
||||
"persist-s3": PersistS3Stage,
|
||||
"reduce-crop": ReduceCropStage,
|
||||
|
|
|
@ -52,6 +52,16 @@ def get_highres_tile(
|
|||
return params.unet_tile
|
||||
|
||||
|
||||
def add_safety_stage(
|
||||
server: ServerContext,
|
||||
pipeline: ChainPipeline,
|
||||
) -> None:
|
||||
if server.has_feature("horde-safety"):
|
||||
from ..chain.edit_safety import EditSafetyStage
|
||||
|
||||
pipeline.stage(EditSafetyStage(), StageParams())
|
||||
|
||||
|
||||
def run_txt2img_pipeline(
|
||||
worker: WorkerContext,
|
||||
server: ServerContext,
|
||||
|
@ -110,6 +120,8 @@ def run_txt2img_pipeline(
|
|||
upscale=after_upscale,
|
||||
)
|
||||
|
||||
add_safety_stage(server, chain)
|
||||
|
||||
# run and save
|
||||
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
||||
progress = worker.get_progress_callback(reset=True)
|
||||
|
@ -209,6 +221,8 @@ def run_img2img_pipeline(
|
|||
chain=chain,
|
||||
)
|
||||
|
||||
add_safety_stage(server, chain)
|
||||
|
||||
# run and append the filtered source
|
||||
progress = worker.get_progress_callback(reset=True)
|
||||
images = chain(
|
||||
|
@ -382,6 +396,8 @@ def run_inpaint_pipeline(
|
|||
chain=chain,
|
||||
)
|
||||
|
||||
add_safety_stage(server, chain)
|
||||
|
||||
# run and save
|
||||
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
||||
progress = worker.get_progress_callback(reset=True)
|
||||
|
@ -462,6 +478,8 @@ def run_upscale_pipeline(
|
|||
chain=chain,
|
||||
)
|
||||
|
||||
add_safety_stage(server, chain)
|
||||
|
||||
# run and save
|
||||
progress = worker.get_progress_callback(reset=True)
|
||||
images = chain(
|
||||
|
@ -515,6 +533,8 @@ def run_blend_pipeline(
|
|||
chain=chain,
|
||||
)
|
||||
|
||||
add_safety_stage(server, chain)
|
||||
|
||||
# run and save
|
||||
progress = worker.get_progress_callback(reset=True)
|
||||
images = chain(
|
||||
|
|
Loading…
Reference in New Issue