feat(api): add optional horde safety stage
This commit is contained in:
parent
c8d6afd64d
commit
a42d728006
|
@ -0,0 +1,72 @@
|
||||||
|
from logging import getLogger
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from ..params import ImageParams, SizeChart, StageParams
|
||||||
|
from ..server import ServerContext
|
||||||
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class EditSafetyStage(BaseStage):
|
||||||
|
max_tile = SizeChart.max
|
||||||
|
|
||||||
|
def run(
|
||||||
|
self,
|
||||||
|
_worker: WorkerContext,
|
||||||
|
server: ServerContext,
|
||||||
|
_stage: StageParams,
|
||||||
|
_params: ImageParams,
|
||||||
|
sources: StageResult,
|
||||||
|
*,
|
||||||
|
callback: Optional[ProgressCallback] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> StageResult:
|
||||||
|
logger.info("checking results using horde safety")
|
||||||
|
|
||||||
|
# keep these within run to make this sort of like a plugin or peer dependency
|
||||||
|
try:
|
||||||
|
from horde_safety.deep_danbooru_model import get_deep_danbooru_model
|
||||||
|
from horde_safety.interrogate import get_interrogator_no_blip
|
||||||
|
from horde_safety.nsfw_checker_class import NSFWChecker
|
||||||
|
|
||||||
|
# set up
|
||||||
|
block_nsfw = server.has_feature("horde-safety-nsfw")
|
||||||
|
|
||||||
|
interrogator = get_interrogator_no_blip()
|
||||||
|
deep_danbooru_model = get_deep_danbooru_model()
|
||||||
|
|
||||||
|
nsfw_checker = NSFWChecker(
|
||||||
|
interrogator,
|
||||||
|
deep_danbooru_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
# individual flags from NSFWResult
|
||||||
|
is_csam = False
|
||||||
|
|
||||||
|
images = sources.as_images()
|
||||||
|
results = []
|
||||||
|
for i, image in enumerate(images):
|
||||||
|
prompt = sources.metadata[i].prompt
|
||||||
|
check = nsfw_checker.check_for_nsfw(image, prompt=prompt)
|
||||||
|
|
||||||
|
if check.is_csam:
|
||||||
|
logger.warning("flagging csam result: %s, %s", i, prompt)
|
||||||
|
is_csam = True
|
||||||
|
|
||||||
|
if check.is_nsfw and block_nsfw:
|
||||||
|
logger.warning("blocking nsfw image: %s, %s", i, prompt)
|
||||||
|
results.append(Image.new("RGB", image.size, color="black"))
|
||||||
|
|
||||||
|
if is_csam:
|
||||||
|
logger.warning("blocking csam result")
|
||||||
|
raise RuntimeError("csam detected")
|
||||||
|
else:
|
||||||
|
return StageResult.from_images(results, metadata=sources.metadata)
|
||||||
|
except ImportError:
|
||||||
|
logger.warning("horde safety not installed")
|
||||||
|
return StageResult.empty()
|
|
@ -1,5 +1,8 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
|
|
||||||
|
from .edit_safety import EditSafetyStage
|
||||||
|
from .edit_text import EditTextStage
|
||||||
|
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
from .blend_denoise_fastnlmeans import BlendDenoiseFastNLMeansStage
|
from .blend_denoise_fastnlmeans import BlendDenoiseFastNLMeansStage
|
||||||
from .blend_denoise_localstd import BlendDenoiseLocalStdStage
|
from .blend_denoise_localstd import BlendDenoiseLocalStdStage
|
||||||
|
@ -40,6 +43,8 @@ CHAIN_STAGES = {
|
||||||
"correct-codeformer": CorrectCodeformerStage,
|
"correct-codeformer": CorrectCodeformerStage,
|
||||||
"correct-gfpgan": CorrectGFPGANStage,
|
"correct-gfpgan": CorrectGFPGANStage,
|
||||||
"edit-metadata": EditMetadataStage,
|
"edit-metadata": EditMetadataStage,
|
||||||
|
"edit-safety": EditSafetyStage,
|
||||||
|
"edit-text": EditTextStage,
|
||||||
"persist-disk": PersistDiskStage,
|
"persist-disk": PersistDiskStage,
|
||||||
"persist-s3": PersistS3Stage,
|
"persist-s3": PersistS3Stage,
|
||||||
"reduce-crop": ReduceCropStage,
|
"reduce-crop": ReduceCropStage,
|
||||||
|
|
|
@ -52,6 +52,16 @@ def get_highres_tile(
|
||||||
return params.unet_tile
|
return params.unet_tile
|
||||||
|
|
||||||
|
|
||||||
|
def add_safety_stage(
|
||||||
|
server: ServerContext,
|
||||||
|
pipeline: ChainPipeline,
|
||||||
|
) -> None:
|
||||||
|
if server.has_feature("horde-safety"):
|
||||||
|
from ..chain.edit_safety import EditSafetyStage
|
||||||
|
|
||||||
|
pipeline.stage(EditSafetyStage(), StageParams())
|
||||||
|
|
||||||
|
|
||||||
def run_txt2img_pipeline(
|
def run_txt2img_pipeline(
|
||||||
worker: WorkerContext,
|
worker: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
|
@ -110,6 +120,8 @@ def run_txt2img_pipeline(
|
||||||
upscale=after_upscale,
|
upscale=after_upscale,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
add_safety_stage(server, chain)
|
||||||
|
|
||||||
# run and save
|
# run and save
|
||||||
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
||||||
progress = worker.get_progress_callback(reset=True)
|
progress = worker.get_progress_callback(reset=True)
|
||||||
|
@ -209,6 +221,8 @@ def run_img2img_pipeline(
|
||||||
chain=chain,
|
chain=chain,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
add_safety_stage(server, chain)
|
||||||
|
|
||||||
# run and append the filtered source
|
# run and append the filtered source
|
||||||
progress = worker.get_progress_callback(reset=True)
|
progress = worker.get_progress_callback(reset=True)
|
||||||
images = chain(
|
images = chain(
|
||||||
|
@ -382,6 +396,8 @@ def run_inpaint_pipeline(
|
||||||
chain=chain,
|
chain=chain,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
add_safety_stage(server, chain)
|
||||||
|
|
||||||
# run and save
|
# run and save
|
||||||
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
||||||
progress = worker.get_progress_callback(reset=True)
|
progress = worker.get_progress_callback(reset=True)
|
||||||
|
@ -462,6 +478,8 @@ def run_upscale_pipeline(
|
||||||
chain=chain,
|
chain=chain,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
add_safety_stage(server, chain)
|
||||||
|
|
||||||
# run and save
|
# run and save
|
||||||
progress = worker.get_progress_callback(reset=True)
|
progress = worker.get_progress_callback(reset=True)
|
||||||
images = chain(
|
images = chain(
|
||||||
|
@ -515,6 +533,8 @@ def run_blend_pipeline(
|
||||||
chain=chain,
|
chain=chain,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
add_safety_stage(server, chain)
|
||||||
|
|
||||||
# run and save
|
# run and save
|
||||||
progress = worker.get_progress_callback(reset=True)
|
progress = worker.get_progress_callback(reset=True)
|
||||||
images = chain(
|
images = chain(
|
||||||
|
|
Loading…
Reference in New Issue