
background workers, logger

Sean Sube 2023-02-25 23:49:39 -06:00
parent e46a1e5fd0
commit f898de8c54
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
28 changed files with 306 additions and 284 deletions

View File

@@ -5,14 +5,14 @@ formatters:
 handlers:
   console:
     class: logging.StreamHandler
-    level: INFO
+    level: DEBUG
     formatter: simple
     stream: ext://sys.stdout
 loggers:
   '':
-    level: INFO
+    level: DEBUG
    handlers: [console]
    propagate: True
 root:
-  level: INFO
+  level: DEBUG
  handlers: [console]
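
The DEBUG levels above take effect once the YAML is loaded into Python's logging system. A minimal loader sketch, assuming PyYAML and a logging.yaml file name (the actual lookup in onnx-web may differ):

from logging.config import dictConfig

import yaml

# file name is an assumption for this sketch
with open("logging.yaml", "r") as f:
    config = yaml.safe_load(f)

# applies the formatters, handlers, and DEBUG levels defined above in one call
dictConfig(config)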

View File

@@ -25,6 +25,7 @@ from .image import (
 from .onnx import OnnxNet, OnnxTensor
 from .params import (
     Border,
+    DeviceParams,
     ImageParams,
     Param,
     Point,
@@ -33,8 +34,6 @@ from .params import (
     UpscaleParams,
 )
 from .server import (
-    DeviceParams,
-    DevicePoolExecutor,
     ModelCache,
     ServerContext,
     apply_patch_basicsr,
@@ -51,3 +50,6 @@ from .utils import (
     get_from_map,
     get_not_empty,
 )
+from .worker import (
+    DevicePoolExecutor,
+)

View File

@@ -7,7 +7,8 @@ from PIL import Image
 from ..output import save_image
 from ..params import ImageParams, StageParams
-from ..server import JobContext, ProgressCallback, ServerContext
+from ..worker import WorkerContext, ProgressCallback
+from ..server import ServerContext
 from ..utils import is_debug
 from .utils import process_tile_order
@@ -17,7 +18,7 @@ logger = getLogger(__name__)
 class StageCallback(Protocol):
     def __call__(
         self,
-        job: JobContext,
+        job: WorkerContext,
         ctx: ServerContext,
         stage: StageParams,
         params: ImageParams,
@@ -77,7 +78,7 @@ class ChainPipeline:
     def __call__(
         self,
-        job: JobContext,
+        job: WorkerContext,
         server: ServerContext,
         params: ImageParams,
         source: Image.Image,
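
Stages implement this protocol as plain functions. A minimal conforming stage under the new WorkerContext signature; the trailing parameters and the body here are illustrative assumptions, not part of the commit:

from PIL import Image

from onnx_web.params import ImageParams, StageParams
from onnx_web.server import ServerContext
from onnx_web.worker import WorkerContext


def passthrough_stage(
    job: WorkerContext,
    ctx: ServerContext,
    stage: StageParams,
    params: ImageParams,
    source: Image.Image,
    **kwargs,
) -> Image.Image:
    # stages now receive the per-process WorkerContext instead of a JobContext
    job.set_progress(1)
    return source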

View File

@@ -7,13 +7,14 @@ from PIL import Image
 from ..diffusion.load import load_pipeline
 from ..params import ImageParams, StageParams
-from ..server import JobContext, ProgressCallback, ServerContext
+from ..worker import WorkerContext, ProgressCallback
+from ..server import ServerContext

 logger = getLogger(__name__)

 def blend_img2img(
-    job: JobContext,
+    job: WorkerContext,
     server: ServerContext,
     _stage: StageParams,
     params: ImageParams,

View File

@@ -10,7 +10,8 @@ from ..diffusion.load import get_latents_from_seed, load_pipeline
 from ..image import expand_image, mask_filter_none, noise_source_histogram
 from ..output import save_image
 from ..params import Border, ImageParams, Size, SizeChart, StageParams
-from ..server import JobContext, ProgressCallback, ServerContext
+from ..worker import WorkerContext, ProgressCallback
+from ..server import ServerContext
 from ..utils import is_debug
 from .utils import process_tile_order
@@ -18,7 +19,7 @@ logger = getLogger(__name__)
 def blend_inpaint(
-    job: JobContext,
+    job: WorkerContext,
     server: ServerContext,
     stage: StageParams,
     params: ImageParams,

View File

@@ -7,14 +7,15 @@ from onnx_web.image import valid_image
 from onnx_web.output import save_image
 from ..params import ImageParams, StageParams
-from ..server import JobContext, ProgressCallback, ServerContext
+from ..worker import WorkerContext, ProgressCallback
+from ..server import ServerContext
 from ..utils import is_debug

 logger = getLogger(__name__)

 def blend_mask(
-    _job: JobContext,
+    _job: WorkerContext,
     server: ServerContext,
     _stage: StageParams,
     _params: ImageParams,

View File

@@ -3,7 +3,8 @@ from logging import getLogger
 from PIL import Image
 from ..params import ImageParams, StageParams, UpscaleParams
-from ..server import JobContext, ServerContext
+from ..worker import WorkerContext
+from ..server import ServerContext

 logger = getLogger(__name__)
@@ -11,7 +12,7 @@ device = "cpu"
 def correct_codeformer(
-    job: JobContext,
+    job: WorkerContext,
     _server: ServerContext,
     _stage: StageParams,
     _params: ImageParams,

View File

@@ -5,8 +5,10 @@ import numpy as np
 from PIL import Image
 from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
-from ..server import JobContext, ServerContext
 from ..utils import run_gc
+from ..worker import WorkerContext
+from ..server import ServerContext

 logger = getLogger(__name__)
@@ -46,7 +48,7 @@ def load_gfpgan(
 def correct_gfpgan(
-    job: JobContext,
+    job: WorkerContext,
     server: ServerContext,
     stage: StageParams,
     _params: ImageParams,

View File

@@ -4,13 +4,15 @@ from PIL import Image
 from ..output import save_image
 from ..params import ImageParams, StageParams
-from ..server import JobContext, ServerContext
+from ..worker import WorkerContext
+from ..server import ServerContext

 logger = getLogger(__name__)

 def persist_disk(
-    _job: JobContext,
+    _job: WorkerContext,
     server: ServerContext,
     _stage: StageParams,
     _params: ImageParams,

View File

@@ -5,13 +5,15 @@ from boto3 import Session
 from PIL import Image
 from ..params import ImageParams, StageParams
-from ..server import JobContext, ServerContext
+from ..worker import WorkerContext
+from ..server import ServerContext

 logger = getLogger(__name__)

 def persist_s3(
-    _job: JobContext,
+    _job: WorkerContext,
     server: ServerContext,
     _stage: StageParams,
     _params: ImageParams,

View File

@@ -3,13 +3,15 @@ from logging import getLogger
 from PIL import Image
 from ..params import ImageParams, Size, StageParams
-from ..server import JobContext, ServerContext
+from ..worker import WorkerContext
+from ..server import ServerContext

 logger = getLogger(__name__)

 def reduce_crop(
-    _job: JobContext,
+    _job: WorkerContext,
     _server: ServerContext,
     _stage: StageParams,
     _params: ImageParams,

View File

@@ -3,13 +3,15 @@ from logging import getLogger
 from PIL import Image
 from ..params import ImageParams, Size, StageParams
-from ..server import JobContext, ServerContext
+from ..worker import WorkerContext
+from ..server import ServerContext

 logger = getLogger(__name__)

 def reduce_thumbnail(
-    _job: JobContext,
+    _job: WorkerContext,
     _server: ServerContext,
     _stage: StageParams,
     _params: ImageParams,

View File

@@ -4,13 +4,15 @@ from typing import Callable
 from PIL import Image
 from ..params import ImageParams, Size, StageParams
-from ..server import JobContext, ServerContext
+from ..worker import WorkerContext
+from ..server import ServerContext

 logger = getLogger(__name__)

 def source_noise(
-    _job: JobContext,
+    _job: WorkerContext,
     _server: ServerContext,
     _stage: StageParams,
     _params: ImageParams,

View File

@@ -7,13 +7,15 @@ from PIL import Image
 from ..diffusion.load import get_latents_from_seed, load_pipeline
 from ..params import ImageParams, Size, StageParams
-from ..server import JobContext, ProgressCallback, ServerContext
+from ..worker import WorkerContext, ProgressCallback
+from ..server import ServerContext

 logger = getLogger(__name__)

 def source_txt2img(
-    job: JobContext,
+    job: WorkerContext,
     server: ServerContext,
     _stage: StageParams,
     params: ImageParams,

View File

@@ -10,7 +10,8 @@ from ..diffusion.load import get_latents_from_seed, get_tile_latents, load_pipeline
 from ..image import expand_image, mask_filter_none, noise_source_histogram
 from ..output import save_image
 from ..params import Border, ImageParams, Size, SizeChart, StageParams
-from ..server import JobContext, ProgressCallback, ServerContext
+from ..worker import WorkerContext, ProgressCallback
+from ..server import ServerContext
 from ..utils import is_debug
 from .utils import process_tile_grid, process_tile_order
@@ -18,7 +19,7 @@ logger = getLogger(__name__)
 def upscale_outpaint(
-    job: JobContext,
+    job: WorkerContext,
     server: ServerContext,
     stage: StageParams,
     params: ImageParams,

View File

@@ -6,7 +6,8 @@ from PIL import Image
 from ..onnx import OnnxNet
 from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
-from ..server import JobContext, ServerContext
+from ..worker import WorkerContext
+from ..server import ServerContext
 from ..utils import run_gc

 logger = getLogger(__name__)
@@ -96,7 +97,7 @@ def load_resrgan(
 def upscale_resrgan(
-    job: JobContext,
+    job: WorkerContext,
     server: ServerContext,
     stage: StageParams,
     _params: ImageParams,

View File

@@ -10,7 +10,8 @@ from ..diffusion.pipeline_onnx_stable_diffusion_upscale import (
     OnnxStableDiffusionUpscalePipeline,
 )
 from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
-from ..server import JobContext, ProgressCallback, ServerContext
+from ..worker import WorkerContext, ProgressCallback
+from ..server import ServerContext
 from ..utils import run_gc

 logger = getLogger(__name__)
@@ -62,7 +63,7 @@ def load_stable_diffusion(
 def upscale_stable_diffusion(
-    job: JobContext,
+    job: WorkerContext,
     server: ServerContext,
     _stage: StageParams,
     params: ImageParams,

View File

@@ -12,7 +12,8 @@ from onnx_web.chain.base import ChainProgress
 from ..chain import upscale_outpaint
 from ..output import save_image, save_params
 from ..params import Border, ImageParams, Size, StageParams, UpscaleParams
-from ..server import JobContext, ServerContext
+from ..worker import WorkerContext
+from ..server import ServerContext
 from ..upscale import run_upscale_correction
 from ..utils import run_gc
 from .load import get_latents_from_seed, load_pipeline
@@ -21,7 +22,7 @@ logger = getLogger(__name__)
 def run_txt2img_pipeline(
-    job: JobContext,
+    job: WorkerContext,
     server: ServerContext,
     params: ImageParams,
     size: Size,
@@ -95,7 +96,7 @@ def run_txt2img_pipeline(
 def run_img2img_pipeline(
-    job: JobContext,
+    job: WorkerContext,
     server: ServerContext,
     params: ImageParams,
     outputs: List[str],
@@ -167,7 +168,7 @@ def run_img2img_pipeline(
 def run_inpaint_pipeline(
-    job: JobContext,
+    job: WorkerContext,
     server: ServerContext,
     params: ImageParams,
     size: Size,
@@ -217,7 +218,7 @@ def run_inpaint_pipeline(
 def run_upscale_pipeline(
-    job: JobContext,
+    job: WorkerContext,
     server: ServerContext,
     params: ImageParams,
     size: Size,
@@ -243,7 +244,7 @@ def run_upscale_pipeline(
 def run_blend_pipeline(
-    job: JobContext,
+    job: WorkerContext,
     server: ServerContext,
     params: ImageParams,
     size: Size,

View File

@@ -63,7 +63,7 @@ from .params import (
     TileOrder,
     UpscaleParams,
 )
-from .server import DevicePoolExecutor, ServerContext, apply_patches
+from .server import ServerContext, apply_patches
 from .transformers import run_txt2txt_pipeline
 from .utils import (
     base_join,
@@ -75,6 +75,7 @@ from .utils import (
     get_size,
     is_debug,
 )
+from .worker import DevicePoolExecutor

 logger = getLogger(__name__)

View File

@@ -1,10 +1,3 @@
-from .device_pool import (
-    DeviceParams,
-    DevicePoolExecutor,
-    Job,
-    JobContext,
-    ProgressCallback,
-)
 from .hacks import (
     apply_patch_basicsr,
     apply_patch_codeformer,

View File

@@ -1,230 +0,0 @@
-from collections import Counter
-from concurrent.futures import Future
-from logging import getLogger
-from multiprocessing import Queue
-from torch.multiprocessing import Lock, Process, SimpleQueue, Value
-from traceback import format_exception
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
-from time import sleep
-
-from ..params import DeviceParams
-from ..utils import run_gc
-
-logger = getLogger(__name__)
-
-ProgressCallback = Callable[[int, int, Any], None]
-
-
-def worker_init(lock: Lock, job_queue: SimpleQueue):
-    logger.info("checking in from worker")
-
-    while True:
-        if job_queue.empty():
-            logger.info("no jobs, sleeping")
-            sleep(5)
-        else:
-            job = job_queue.get()
-            logger.info("got job: %s", job)
-
-
-class JobContext:
-    cancel: Value = None
-    device_index: Value = None
-    devices: List[DeviceParams] = None
-    key: str = None
-    progress: Value = None
-
-    def __init__(
-        self,
-        key: str,
-        devices: List[DeviceParams],
-        cancel: bool = False,
-        device_index: int = -1,
-        progress: int = 0,
-    ):
-        self.key = key
-        self.devices = list(devices)
-        self.cancel = Value("B", cancel)
-        self.device_index = Value("i", device_index)
-        self.progress = Value("I", progress)
-
-    def is_cancelled(self) -> bool:
-        return self.cancel.value
-
-    def get_device(self) -> DeviceParams:
-        """
-        Get the device assigned to this job.
-        """
-        with self.device_index.get_lock():
-            device_index = self.device_index.value
-            if device_index < 0:
-                raise ValueError("job has not been assigned to a device")
-            else:
-                device = self.devices[device_index]
-                logger.debug("job %s assigned to device %s", self.key, device)
-                return device
-
-    def get_progress(self) -> int:
-        return self.progress.value
-
-    def get_progress_callback(self) -> ProgressCallback:
-        def on_progress(step: int, timestep: int, latents: Any):
-            on_progress.step = step
-            if self.is_cancelled():
-                raise RuntimeError("job has been cancelled")
-            else:
-                logger.debug("setting progress for job %s to %s", self.key, step)
-                self.set_progress(step)
-
-        return on_progress
-
-    def set_cancel(self, cancel: bool = True) -> None:
-        with self.cancel.get_lock():
-            self.cancel.value = cancel
-
-    def set_progress(self, progress: int) -> None:
-        with self.progress.get_lock():
-            self.progress.value = progress
-
-
-class Job:
-    """
-    Link a future to its context.
-    """
-
-    context: JobContext = None
-    future: Future = None
-    key: str = None
-
-    def __init__(
-        self,
-        key: str,
-        future: Future,
-        context: JobContext,
-    ):
-        self.context = context
-        self.future = future
-        self.key = key
-
-    def get_progress(self) -> int:
-        return self.context.get_progress()
-
-    def set_cancel(self, cancel: bool = True):
-        return self.context.set_cancel(cancel)
-
-    def set_progress(self, progress: int):
-        return self.context.set_progress(progress)
-
-
-class DevicePoolExecutor:
-    devices: List[DeviceParams] = None
-    finished: List[Tuple[str, int]] = None
-    pending: Dict[str, "Queue[Job]"] = None
-    progress: Dict[str, Value] = None
-    workers: Dict[str, Process] = None
-
-    def __init__(
-        self,
-        devices: List[DeviceParams],
-        finished_limit: int = 10,
-    ):
-        self.devices = devices
-        self.finished = []
-        self.finished_limit = finished_limit
-        self.lock = Lock()
-        self.pending = {}
-        self.progress = {}
-        self.workers = {}
-
-        # create a pending queue and progress value for each device
-        for device in devices:
-            name = device.device
-            job_queue = Queue()
-            self.pending[name] = job_queue
-            self.progress[name] = Value("I", 0, lock=self.lock)
-            self.workers[name] = Process(target=worker_init, args=(self.lock, job_queue))
-
-    def cancel(self, key: str) -> bool:
-        """
-        Cancel a job. If the job has not been started, this will cancel
-        the future and never execute it. If the job has been started, it
-        should be cancelled on the next progress callback.
-        """
-        raise NotImplementedError()
-
-    def done(self, key: str) -> Tuple[Optional[bool], int]:
-        for k, progress in self.finished:
-            if key == k:
-                return (True, progress)
-
-        logger.warn("checking status for unknown key: %s", key)
-        return (None, 0)
-
-    def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int:
-        # respect overrides if possible
-        if needs_device is not None:
-            for i in range(len(self.devices)):
-                if self.devices[i].device == needs_device.device:
-                    return i
-
-        # use the first/default device if there are no jobs
-        if len(self.jobs) == 0:
-            return 0
-
-        job_devices = [
-            job.context.device_index.value for job in self.jobs if not job.future.done()
-        ]
-        job_counts = Counter(range(len(self.devices)))
-        job_counts.update(job_devices)
-
-        queued = job_counts.most_common()
-        logger.debug("jobs queued by device: %s", queued)
-
-        lowest_count = queued[-1][1]
-        lowest_devices = [d[0] for d in queued if d[1] == lowest_count]
-        lowest_devices.sort()
-
-        return lowest_devices[0]
-
-    def prune(self):
-        finished_count = len(self.finished)
-        if finished_count > self.finished_limit:
-            logger.debug(
-                "pruning %s of %s finished jobs",
-                finished_count - self.finished_limit,
-                finished_count,
-            )
-            self.finished[:] = self.finished[-self.finished_limit:]
-
-    def submit(
-        self,
-        key: str,
-        fn: Callable[..., None],
-        /,
-        *args,
-        needs_device: Optional[DeviceParams] = None,
-        **kwargs,
-    ) -> None:
-        self.prune()
-        device_idx = self.get_next_device(needs_device=needs_device)
-        logger.info(
-            "assigning job %s to device %s: %s", key, device_idx, self.devices[device_idx]
-        )
-
-        context = JobContext(key, self.devices, device_index=device_idx)
-        device = self.devices[device_idx]
-        queue = self.pending[device.device]
-        queue.put((fn, context, args, kwargs))
-
-    def status(self) -> List[Tuple[str, int, bool, int]]:
-        pending = [
-            (
-                device.device,
-                self.pending[device.device].qsize(),
-            )
-            for device in self.devices
-        ]
-        pending.extend(self.finished)
-        return pending

View File

@@ -1,13 +1,14 @@
 from logging import getLogger

 from .params import ImageParams, Size
-from .server import JobContext, ServerContext
+from .server import ServerContext
+from .worker import WorkerContext

 logger = getLogger(__name__)

 def run_txt2txt_pipeline(
-    job: JobContext,
+    job: WorkerContext,
     _server: ServerContext,
     params: ImageParams,
     _size: Size,

View File

@@ -10,13 +10,14 @@ from .chain import (
     upscale_stable_diffusion,
 )
 from .params import ImageParams, SizeChart, StageParams, UpscaleParams
-from .server import JobContext, ProgressCallback, ServerContext
+from .server import ServerContext
+from .worker import WorkerContext, ProgressCallback

 logger = getLogger(__name__)

 def run_upscale_correction(
-    job: JobContext,
+    job: WorkerContext,
     server: ServerContext,
     stage: StageParams,
     params: ImageParams,

View File

@@ -0,0 +1,2 @@
+from .context import WorkerContext, ProgressCallback
+from .pool import DevicePoolExecutor

View File

@@ -0,0 +1,60 @@
+from logging import getLogger
+from torch.multiprocessing import Queue, Value
+from typing import Any, Callable
+
+from ..params import DeviceParams
+
+logger = getLogger(__name__)
+
+ProgressCallback = Callable[[int, int, Any], None]
+
+
+class WorkerContext:
+    cancel: "Value[bool]" = None
+    key: str = None
+    progress: "Value[int]" = None
+
+    def __init__(
+        self,
+        key: str,
+        cancel: "Value[bool]",
+        device: DeviceParams,
+        pending: "Queue[Any]",
+        progress: "Value[int]",
+    ):
+        self.key = key
+        self.cancel = cancel
+        self.device = device
+        self.pending = pending
+        self.progress = progress
+
+    def is_cancelled(self) -> bool:
+        return self.cancel.value
+
+    def get_device(self) -> DeviceParams:
+        """
+        Get the device assigned to this job.
+        """
+        return self.device
+
+    def get_progress(self) -> int:
+        return self.progress.value
+
+    def get_progress_callback(self) -> ProgressCallback:
+        def on_progress(step: int, timestep: int, latents: Any):
+            on_progress.step = step
+            if self.is_cancelled():
+                raise RuntimeError("job has been cancelled")
+            else:
+                logger.debug("setting progress for job %s to %s", self.key, step)
+                self.set_progress(step)
+
+        return on_progress
+
+    def set_cancel(self, cancel: bool = True) -> None:
+        with self.cancel.get_lock():
+            self.cancel.value = cancel
+
+    def set_progress(self, progress: int) -> None:
+        with self.progress.get_lock():
+            self.progress.value = progress
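
A minimal usage sketch for WorkerContext (not part of the commit; the DeviceParams constructor arguments are assumed):

from torch.multiprocessing import Queue, Value

from onnx_web.params import DeviceParams
from onnx_web.worker import WorkerContext

device = DeviceParams("cpu", "CPUExecutionProvider")  # constructor args assumed
context = WorkerContext(
    "job-1",
    Value("B", False),
    device,
    Queue(),
    Value("I", 0),
)

callback = context.get_progress_callback()
callback(5, 0, None)           # records step 5 in the shared Value
print(context.get_progress())  # 5

context.set_cancel()
# the next callback invocation now raises RuntimeError("job has been cancelled")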

View File

@@ -0,0 +1 @@
+# TODO: queue-based logger

api/onnx_web/worker/pool.py
View File
View File

@@ -0,0 +1,136 @@
+from collections import Counter
+from logging import getLogger
+from multiprocessing import Queue
+from torch.multiprocessing import Lock, Process, Value
+from typing import Callable, Dict, List, Optional, Tuple
+
+from ..params import DeviceParams
+from .context import WorkerContext
+from .worker import logger_init, worker_init
+
+logger = getLogger(__name__)
+
+
+class DevicePoolExecutor:
+    devices: List[DeviceParams] = None
+    finished: List[Tuple[str, int]] = None
+    pending: Dict[str, "Queue[WorkerContext]"] = None
+    progress: Dict[str, Value] = None
+    workers: Dict[str, Process] = None
+
+    def __init__(
+        self,
+        devices: List[DeviceParams],
+        finished_limit: int = 10,
+    ):
+        self.devices = devices
+        self.finished = []
+        self.finished_limit = finished_limit
+        self.lock = Lock()
+        self.pending = {}
+        self.progress = {}
+        self.workers = {}
+
+        log_queue = Queue()
+        logger_context = WorkerContext("logger", None, None, log_queue, None)
+
+        logger.debug("starting log worker")
+        self.logger = Process(target=logger_init, args=(self.lock, logger_context))
+        self.logger.start()
+
+        # create a pending queue and progress value for each device
+        for device in devices:
+            name = device.device
+            cancel = Value("B", False, lock=self.lock)
+            progress = Value("I", 0, lock=self.lock)
+            pending = Queue()
+            context = WorkerContext(name, cancel, device, pending, progress)
+            self.pending[name] = pending
+            self.progress[name] = progress
+
+            logger.debug("starting worker for device %s", device)
+            self.workers[name] = Process(target=worker_init, args=(self.lock, context))
+            self.workers[name].start()
+
+        logger.debug("testing log worker")
+        log_queue.put("testing")
+
+    def cancel(self, key: str) -> bool:
+        """
+        Cancel a job. If the job has not been started, this will cancel
+        the future and never execute it. If the job has been started, it
+        should be cancelled on the next progress callback.
+        """
+        raise NotImplementedError()
+
+    def done(self, key: str) -> Tuple[Optional[bool], int]:
+        for k, progress in self.finished:
+            if key == k:
+                return (True, progress)
+
+        logger.warning("checking status for unknown key: %s", key)
+        return (None, 0)
+
+    def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int:
+        # respect overrides if possible
+        if needs_device is not None:
+            for i in range(len(self.devices)):
+                if self.devices[i].device == needs_device.device:
+                    return i
+
+        # count queued jobs per device index, seeding each device with one entry
+        pending = [
+            self.pending[d.device].qsize() for d in self.devices
+        ]
+        jobs = Counter(range(len(self.devices)))
+        jobs.update(dict(enumerate(pending)))
+
+        queued = jobs.most_common()
+        logger.debug("jobs queued by device: %s", queued)
+
+        lowest_count = queued[-1][1]
+        lowest_devices = [d[0] for d in queued if d[1] == lowest_count]
+        lowest_devices.sort()
+
+        return lowest_devices[0]
+
+    def prune(self):
+        finished_count = len(self.finished)
+        if finished_count > self.finished_limit:
+            logger.debug(
+                "pruning %s of %s finished jobs",
+                finished_count - self.finished_limit,
+                finished_count,
+            )
+            self.finished[:] = self.finished[-self.finished_limit:]
+
+    def submit(
+        self,
+        key: str,
+        fn: Callable[..., None],
+        /,
+        *args,
+        needs_device: Optional[DeviceParams] = None,
+        **kwargs,
+    ) -> None:
+        self.prune()
+        device_idx = self.get_next_device(needs_device=needs_device)
+        logger.info(
+            "assigning job %s to device %s: %s", key, device_idx, self.devices[device_idx]
+        )
+
+        device = self.devices[device_idx]
+        queue = self.pending[device.device]
+        queue.put((fn, args, kwargs))
+
+    def status(self) -> List[Tuple[str, int, bool, int]]:
+        pending = [
+            (
+                device.device,
+                self.pending[device.device].qsize(),
+                self.workers[device.device].is_alive(),
+            )
+            for device in self.devices
+        ]
+        pending.extend(self.finished)
+        return pending
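
A minimal usage sketch for the new pool (not part of the commit; the DeviceParams arguments are assumed, and in this commit the worker only logs the jobs it dequeues):

from onnx_web.params import DeviceParams
from onnx_web.worker import DevicePoolExecutor


def fake_job(context, prompt):
    # stand-in for a real pipeline function such as run_txt2img_pipeline
    print("rendering", prompt, "on", context.device)


if __name__ == "__main__":
    devices = [DeviceParams("cpu", "CPUExecutionProvider")]  # args assumed
    pool = DevicePoolExecutor(devices)

    # queued on the least-loaded device; the worker receives (fn, args, kwargs)
    pool.submit("txt2img-1234", fake_job, "an astronaut riding a horse")
    print(pool.status())  # [(device, queue depth, worker alive), ...] plus finished jobs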

View File

@@ -0,0 +1,32 @@
+from logging import getLogger
+from torch.multiprocessing import Lock
+from time import sleep
+
+from .context import WorkerContext
+
+logger = getLogger(__name__)
+
+
+def logger_init(lock: Lock, context: WorkerContext):
+    logger.info("checking in from logger")
+
+    with open("worker.log", "w") as f:
+        while True:
+            if context.pending.empty():
+                logger.info("no logs, sleeping")
+                sleep(5)
+            else:
+                job = context.pending.get()
+                logger.info("got log: %s", job)
+                f.write(str(job) + "\n\n")
+
+
+def worker_init(lock: Lock, context: WorkerContext):
+    logger.info("checking in from worker")
+
+    while True:
+        if context.pending.empty():
+            logger.info("no jobs, sleeping")
+            sleep(5)
+        else:
+            job = context.pending.get()
+            logger.info("got job: %s", job)