1
0
Fork 0

fix(api): load horde safety model to worker's torch device

This commit is contained in:
Sean Sube 2024-02-25 07:44:46 -06:00
parent f61808c990
commit 9d87e92a1c
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 5 additions and 4 deletions

View File

@ -18,7 +18,7 @@ logger = getLogger(__name__)
class EditSafetyStage(BaseStage):
max_tile = SizeChart.max
def load(self, server: ServerContext) -> Any:
def load(self, server: ServerContext, device: str) -> Any:
# keep these within run to make this sort of like a plugin or peer dependency
from horde_safety.deep_danbooru_model import get_deep_danbooru_model
from horde_safety.interrogate import get_interrogator_no_blip
@ -31,8 +31,8 @@ class EditSafetyStage(BaseStage):
return cache_checker
# set up
interrogator = get_interrogator_no_blip()
deep_danbooru_model = get_deep_danbooru_model()
interrogator = get_interrogator_no_blip(device=device)
deep_danbooru_model = get_deep_danbooru_model(device=device)
nsfw_checker = NSFWChecker(
interrogator,
@ -58,7 +58,8 @@ class EditSafetyStage(BaseStage):
try:
# set up
nsfw_checker = self.load(server)
torch_device = worker.device.torch_str()
nsfw_checker = self.load(server, torch_device)
block_nsfw = server.has_feature("horde-safety-nsfw")
is_csam = False