wire up worker jobs
This commit is contained in:
parent
f898de8c54
commit
943281feb5
|
@ -1,6 +1,6 @@
|
|||
from logging import getLogger
|
||||
from torch.multiprocessing import Queue, Value
|
||||
from typing import Any, Callable
|
||||
from typing import Any, Callable, Tuple
|
||||
|
||||
from ..params import DeviceParams
|
||||
|
||||
|
@ -12,6 +12,7 @@ ProgressCallback = Callable[[int, int, Any], None]
|
|||
class WorkerContext:
|
||||
cancel: "Value[bool]" = None
|
||||
key: str = None
|
||||
pending: "Queue[Tuple[Callable, Any, Any]]" = None
|
||||
progress: "Value[int]" = None
|
||||
|
||||
def __init__(
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from logging import getLogger
|
||||
from torch.multiprocessing import Lock
|
||||
from time import sleep
|
||||
from traceback import print_exception
|
||||
|
||||
from .context import WorkerContext
|
||||
|
||||
|
@ -30,3 +31,10 @@ def worker_init(lock: Lock, context: WorkerContext):
|
|||
else:
|
||||
job = context.pending.get()
|
||||
logger.info("got job: %s", job)
|
||||
try:
|
||||
fn, args, kwargs = job
|
||||
fn(context, *args, **kwargs)
|
||||
logger.info("finished job")
|
||||
except Exception as e:
|
||||
print_exception(type(e), e, e.__traceback__)
|
||||
|
||||
|
|
Loading…
Reference in New Issue