apply patches within workers
This commit is contained in:
parent
e1d0ad54b7
commit
f115326da7
|
@ -39,7 +39,7 @@ def main():
|
||||||
CORS(app, origins=context.cors_origin)
|
CORS(app, origins=context.cors_origin)
|
||||||
|
|
||||||
# any is a fake device, should not be in the pool
|
# any is a fake device, should not be in the pool
|
||||||
pool = DevicePoolExecutor([p for p in get_available_platforms() if p.device != "any"])
|
pool = DevicePoolExecutor(context, [p for p in get_available_platforms() if p.device != "any"])
|
||||||
|
|
||||||
# register routes
|
# register routes
|
||||||
register_static_routes(app, context, pool)
|
register_static_routes(app, context, pool)
|
||||||
|
|
|
@ -5,6 +5,7 @@ from torch.multiprocessing import Lock, Process, Value
|
||||||
from typing import Callable, Dict, List, Optional, Tuple
|
from typing import Callable, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from ..params import DeviceParams
|
from ..params import DeviceParams
|
||||||
|
from ..server import ServerContext
|
||||||
from .context import WorkerContext
|
from .context import WorkerContext
|
||||||
from .worker import logger_init, worker_init
|
from .worker import logger_init, worker_init
|
||||||
|
|
||||||
|
@ -20,9 +21,11 @@ class DevicePoolExecutor:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
server: ServerContext,
|
||||||
devices: List[DeviceParams],
|
devices: List[DeviceParams],
|
||||||
finished_limit: int = 10,
|
finished_limit: int = 10,
|
||||||
):
|
):
|
||||||
|
self.server = server
|
||||||
self.devices = devices
|
self.devices = devices
|
||||||
self.finished = []
|
self.finished = []
|
||||||
self.finished_limit = finished_limit
|
self.finished_limit = finished_limit
|
||||||
|
@ -32,9 +35,12 @@ class DevicePoolExecutor:
|
||||||
self.progress = {}
|
self.progress = {}
|
||||||
self.workers = {}
|
self.workers = {}
|
||||||
|
|
||||||
|
# TODO: make this a method
|
||||||
logger.debug("starting log worker")
|
logger.debug("starting log worker")
|
||||||
self.log_queue = Queue()
|
self.log_queue = Queue()
|
||||||
self.logger = Process(target=logger_init, args=(self.lock, self.log_queue))
|
log_lock = Lock()
|
||||||
|
self.locks["logger"] = log_lock
|
||||||
|
self.logger = Process(target=logger_init, args=(log_lock, self.log_queue))
|
||||||
self.logger.start()
|
self.logger.start()
|
||||||
|
|
||||||
logger.debug("testing log worker")
|
logger.debug("testing log worker")
|
||||||
|
@ -43,10 +49,11 @@ class DevicePoolExecutor:
|
||||||
# create a pending queue and progress value for each device
|
# create a pending queue and progress value for each device
|
||||||
for device in devices:
|
for device in devices:
|
||||||
name = device.device
|
name = device.device
|
||||||
|
# TODO: make this a method
|
||||||
lock = Lock()
|
lock = Lock()
|
||||||
self.locks[name] = lock
|
self.locks[name] = lock
|
||||||
cancel = Value("B", False, lock=lock)
|
cancel = Value("B", False, lock=lock)
|
||||||
progress = Value("I", 0, lock=lock)
|
progress = Value("I", 0) # , lock=lock) # needs its own lock for some reason. TODO: why?
|
||||||
self.progress[name] = progress
|
self.progress[name] = progress
|
||||||
pending = Queue()
|
pending = Queue()
|
||||||
self.pending[name] = pending
|
self.pending[name] = pending
|
||||||
|
@ -54,7 +61,7 @@ class DevicePoolExecutor:
|
||||||
self.context[name] = context
|
self.context[name] = context
|
||||||
|
|
||||||
logger.debug("starting worker for device %s", device)
|
logger.debug("starting worker for device %s", device)
|
||||||
self.workers[name] = Process(target=worker_init, args=(lock, context))
|
self.workers[name] = Process(target=worker_init, args=(lock, context, server))
|
||||||
self.workers[name].start()
|
self.workers[name].start()
|
||||||
|
|
||||||
def cancel(self, key: str) -> bool:
|
def cancel(self, key: str) -> bool:
|
||||||
|
|
|
@ -5,9 +5,11 @@ from torch.multiprocessing import Lock, Queue
|
||||||
from traceback import format_exception
|
from traceback import format_exception
|
||||||
|
|
||||||
from .context import WorkerContext
|
from .context import WorkerContext
|
||||||
|
from ..server import ServerContext, apply_patches
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def logger_init(lock: Lock, logs: Queue):
|
def logger_init(lock: Lock, logs: Queue):
|
||||||
with lock:
|
with lock:
|
||||||
logger.info("checking in from logger, %s", lock)
|
logger.info("checking in from logger, %s", lock)
|
||||||
|
@ -19,10 +21,12 @@ def logger_init(lock: Lock, logs: Queue):
|
||||||
f.write(str(job) + "\n\n")
|
f.write(str(job) + "\n\n")
|
||||||
|
|
||||||
|
|
||||||
def worker_init(lock: Lock, context: WorkerContext):
|
def worker_init(lock: Lock, context: WorkerContext, server: ServerContext):
|
||||||
with lock:
|
with lock:
|
||||||
logger.info("checking in from worker, %s, %s", lock, get_available_providers())
|
logger.info("checking in from worker, %s, %s", lock, get_available_providers())
|
||||||
|
|
||||||
|
apply_patches(server)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
job = context.pending.get()
|
job = context.pending.get()
|
||||||
logger.info("got job: %s", job)
|
logger.info("got job: %s", job)
|
||||||
|
|
Loading…
Reference in New Issue