fix(api): load horde safety model to worker's torch device
This commit is contained in:
parent
f61808c990
commit
9d87e92a1c
|
@ -18,7 +18,7 @@ logger = getLogger(__name__)
|
|||
class EditSafetyStage(BaseStage):
|
||||
max_tile = SizeChart.max
|
||||
|
||||
def load(self, server: ServerContext) -> Any:
|
||||
def load(self, server: ServerContext, device: str) -> Any:
|
||||
# keep these within run to make this sort of like a plugin or peer dependency
|
||||
from horde_safety.deep_danbooru_model import get_deep_danbooru_model
|
||||
from horde_safety.interrogate import get_interrogator_no_blip
|
||||
|
@ -31,8 +31,8 @@ class EditSafetyStage(BaseStage):
|
|||
return cache_checker
|
||||
|
||||
# set up
|
||||
interrogator = get_interrogator_no_blip()
|
||||
deep_danbooru_model = get_deep_danbooru_model()
|
||||
interrogator = get_interrogator_no_blip(device=device)
|
||||
deep_danbooru_model = get_deep_danbooru_model(device=device)
|
||||
|
||||
nsfw_checker = NSFWChecker(
|
||||
interrogator,
|
||||
|
@ -58,7 +58,8 @@ class EditSafetyStage(BaseStage):
|
|||
|
||||
try:
|
||||
# set up
|
||||
nsfw_checker = self.load(server)
|
||||
torch_device = worker.device.torch_str()
|
||||
nsfw_checker = self.load(server, torch_device)
|
||||
block_nsfw = server.has_feature("horde-safety-nsfw")
|
||||
is_csam = False
|
||||
|
||||
|
|
Loading…
Reference in New Issue