diff --git a/api/onnx_web/chain/edit_safety.py b/api/onnx_web/chain/edit_safety.py index 8778c963..afcb7fb6 100644 --- a/api/onnx_web/chain/edit_safety.py +++ b/api/onnx_web/chain/edit_safety.py @@ -18,7 +18,7 @@ logger = getLogger(__name__) class EditSafetyStage(BaseStage): max_tile = SizeChart.max - def load(self, server: ServerContext) -> Any: + def load(self, server: ServerContext, device: str) -> Any: # keep these within run to make this sort of like a plugin or peer dependency from horde_safety.deep_danbooru_model import get_deep_danbooru_model from horde_safety.interrogate import get_interrogator_no_blip @@ -31,8 +31,8 @@ class EditSafetyStage(BaseStage): return cache_checker # set up - interrogator = get_interrogator_no_blip() - deep_danbooru_model = get_deep_danbooru_model() + interrogator = get_interrogator_no_blip(device=device) + deep_danbooru_model = get_deep_danbooru_model(device=device) nsfw_checker = NSFWChecker( interrogator, @@ -58,7 +58,8 @@ class EditSafetyStage(BaseStage): try: # set up - nsfw_checker = self.load(server) + torch_device = worker.device.torch_str() + nsfw_checker = self.load(server, torch_device) block_nsfw = server.has_feature("horde-safety-nsfw") is_csam = False