1
0
Fork 0

apply patches within workers

This commit is contained in:
Sean Sube 2023-02-26 12:32:48 -06:00
parent e1d0ad54b7
commit f115326da7
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 16 additions and 5 deletions

View File

@ -39,7 +39,7 @@ def main():
CORS(app, origins=context.cors_origin) CORS(app, origins=context.cors_origin)
# any is a fake device, should not be in the pool # any is a fake device, should not be in the pool
pool = DevicePoolExecutor([p for p in get_available_platforms() if p.device != "any"]) pool = DevicePoolExecutor(context, [p for p in get_available_platforms() if p.device != "any"])
# register routes # register routes
register_static_routes(app, context, pool) register_static_routes(app, context, pool)

View File

@ -5,6 +5,7 @@ from torch.multiprocessing import Lock, Process, Value
from typing import Callable, Dict, List, Optional, Tuple from typing import Callable, Dict, List, Optional, Tuple
from ..params import DeviceParams from ..params import DeviceParams
from ..server import ServerContext
from .context import WorkerContext from .context import WorkerContext
from .worker import logger_init, worker_init from .worker import logger_init, worker_init
@ -20,9 +21,11 @@ class DevicePoolExecutor:
def __init__( def __init__(
self, self,
server: ServerContext,
devices: List[DeviceParams], devices: List[DeviceParams],
finished_limit: int = 10, finished_limit: int = 10,
): ):
self.server = server
self.devices = devices self.devices = devices
self.finished = [] self.finished = []
self.finished_limit = finished_limit self.finished_limit = finished_limit
@ -32,9 +35,12 @@ class DevicePoolExecutor:
self.progress = {} self.progress = {}
self.workers = {} self.workers = {}
# TODO: make this a method
logger.debug("starting log worker") logger.debug("starting log worker")
self.log_queue = Queue() self.log_queue = Queue()
self.logger = Process(target=logger_init, args=(self.lock, self.log_queue)) log_lock = Lock()
self.locks["logger"] = log_lock
self.logger = Process(target=logger_init, args=(log_lock, self.log_queue))
self.logger.start() self.logger.start()
logger.debug("testing log worker") logger.debug("testing log worker")
@ -43,10 +49,11 @@ class DevicePoolExecutor:
# create a pending queue and progress value for each device # create a pending queue and progress value for each device
for device in devices: for device in devices:
name = device.device name = device.device
# TODO: make this a method
lock = Lock() lock = Lock()
self.locks[name] = lock self.locks[name] = lock
cancel = Value("B", False, lock=lock) cancel = Value("B", False, lock=lock)
progress = Value("I", 0, lock=lock) progress = Value("I", 0) # , lock=lock) # needs its own lock for some reason. TODO: why?
self.progress[name] = progress self.progress[name] = progress
pending = Queue() pending = Queue()
self.pending[name] = pending self.pending[name] = pending
@ -54,7 +61,7 @@ class DevicePoolExecutor:
self.context[name] = context self.context[name] = context
logger.debug("starting worker for device %s", device) logger.debug("starting worker for device %s", device)
self.workers[name] = Process(target=worker_init, args=(lock, context)) self.workers[name] = Process(target=worker_init, args=(lock, context, server))
self.workers[name].start() self.workers[name].start()
def cancel(self, key: str) -> bool: def cancel(self, key: str) -> bool:

View File

@ -5,9 +5,11 @@ from torch.multiprocessing import Lock, Queue
from traceback import format_exception from traceback import format_exception
from .context import WorkerContext from .context import WorkerContext
from ..server import ServerContext, apply_patches
logger = getLogger(__name__) logger = getLogger(__name__)
def logger_init(lock: Lock, logs: Queue): def logger_init(lock: Lock, logs: Queue):
with lock: with lock:
logger.info("checking in from logger, %s", lock) logger.info("checking in from logger, %s", lock)
@ -19,10 +21,12 @@ def logger_init(lock: Lock, logs: Queue):
f.write(str(job) + "\n\n") f.write(str(job) + "\n\n")
def worker_init(lock: Lock, context: WorkerContext): def worker_init(lock: Lock, context: WorkerContext, server: ServerContext):
with lock: with lock:
logger.info("checking in from worker, %s, %s", lock, get_available_providers()) logger.info("checking in from worker, %s, %s", lock, get_available_providers())
apply_patches(server)
while True: while True:
job = context.pending.get() job = context.pending.get()
logger.info("got job: %s", job) logger.info("got job: %s", job)