fix(api): load horde safety model to worker's torch device
This commit is contained in:
parent
f61808c990
commit
9d87e92a1c
|
@ -18,7 +18,7 @@ logger = getLogger(__name__)
|
||||||
class EditSafetyStage(BaseStage):
|
class EditSafetyStage(BaseStage):
|
||||||
max_tile = SizeChart.max
|
max_tile = SizeChart.max
|
||||||
|
|
||||||
def load(self, server: ServerContext) -> Any:
|
def load(self, server: ServerContext, device: str) -> Any:
|
||||||
# keep these within run to make this sort of like a plugin or peer dependency
|
# keep these within run to make this sort of like a plugin or peer dependency
|
||||||
from horde_safety.deep_danbooru_model import get_deep_danbooru_model
|
from horde_safety.deep_danbooru_model import get_deep_danbooru_model
|
||||||
from horde_safety.interrogate import get_interrogator_no_blip
|
from horde_safety.interrogate import get_interrogator_no_blip
|
||||||
|
@ -31,8 +31,8 @@ class EditSafetyStage(BaseStage):
|
||||||
return cache_checker
|
return cache_checker
|
||||||
|
|
||||||
# set up
|
# set up
|
||||||
interrogator = get_interrogator_no_blip()
|
interrogator = get_interrogator_no_blip(device=device)
|
||||||
deep_danbooru_model = get_deep_danbooru_model()
|
deep_danbooru_model = get_deep_danbooru_model(device=device)
|
||||||
|
|
||||||
nsfw_checker = NSFWChecker(
|
nsfw_checker = NSFWChecker(
|
||||||
interrogator,
|
interrogator,
|
||||||
|
@ -58,7 +58,8 @@ class EditSafetyStage(BaseStage):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# set up
|
# set up
|
||||||
nsfw_checker = self.load(server)
|
torch_device = worker.device.torch_str()
|
||||||
|
nsfw_checker = self.load(server, torch_device)
|
||||||
block_nsfw = server.has_feature("horde-safety-nsfw")
|
block_nsfw = server.has_feature("horde-safety-nsfw")
|
||||||
is_csam = False
|
is_csam = False
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue