1
0
Fork 0

run logger in a thread, clean up status

This commit is contained in:
Sean Sube 2023-02-27 17:14:53 -06:00
parent 13395933dc
commit 66a20e60fe
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 121 additions and 96 deletions

View File

@ -402,7 +402,7 @@ def blend(context: ServerContext, pool: DevicePoolExecutor):
def txt2txt(context: ServerContext, pool: DevicePoolExecutor): def txt2txt(context: ServerContext, pool: DevicePoolExecutor):
device, params, size = pipeline_from_request(context) device, params, size = pipeline_from_request(context)
output = make_output_name(context, "upscale", params, size) output = make_output_name(context, "txt2txt", params, size)
logger.info("upscale job queued for: %s", output) logger.info("upscale job queued for: %s", output)
pool.submit( pool.submit(

View File

@ -71,3 +71,17 @@ class WorkerContext:
def clear_flags(self) -> None: def clear_flags(self) -> None:
self.set_cancel(False) self.set_cancel(False)
self.set_progress(0) self.set_progress(0)
class JobStatus:
def __init__(
self,
name: str,
progress: int = 0,
cancelled: bool = False,
finished: bool = False,
) -> None:
self.name = name
self.progress = progress
self.cancelled = cancelled
self.finished = finished

View File

@ -8,7 +8,7 @@ from torch.multiprocessing import Process, Queue, Value
from ..params import DeviceParams from ..params import DeviceParams
from ..server import ServerContext from ..server import ServerContext
from .context import WorkerContext from .context import WorkerContext
from .worker import logger_init, worker_init from .worker import worker_main
logger = getLogger(__name__) logger = getLogger(__name__)
@ -18,15 +18,15 @@ class DevicePoolExecutor:
devices: List[DeviceParams] = None devices: List[DeviceParams] = None
pending: Dict[str, "Queue[WorkerContext]"] = None pending: Dict[str, "Queue[WorkerContext]"] = None
workers: Dict[str, Process] = None workers: Dict[str, Process] = None
active_jobs: Dict[str, Tuple[str, int]] = None active_jobs: Dict[str, Tuple[str, int]] = None # should be Dict[Device, JobStatus]
finished_jobs: List[Tuple[str, int, bool]] = None finished_jobs: List[Tuple[str, int, bool]] = None # should be List[JobStatus]
def __init__( def __init__(
self, self,
server: ServerContext, server: ServerContext,
devices: List[DeviceParams], devices: List[DeviceParams],
max_jobs_per_worker: int = 10, max_jobs_per_worker: int = 10,
join_timeout: float = 5.0, join_timeout: float = 1.0,
): ):
self.server = server self.server = server
self.devices = devices self.devices = devices
@ -35,28 +35,27 @@ class DevicePoolExecutor:
self.context = {} self.context = {}
self.pending = {} self.pending = {}
self.threads = {}
self.workers = {} self.workers = {}
self.active_jobs = {} self.active_jobs = {}
self.cancelled_jobs = []
self.finished_jobs = [] self.finished_jobs = []
self.total_jobs = 0 # TODO: turn this into a Dict per-worker self.total_jobs = 0 # TODO: turn this into a Dict per-worker
self.logs = Queue()
self.progress = Queue() self.progress = Queue()
self.finished = Queue() self.finished = Queue()
self.create_logger_worker() self.create_logger_worker()
self.create_queue_workers() self.create_progress_worker()
self.create_finished_worker()
for device in devices: for device in devices:
self.create_device_worker(device) self.create_device_worker(device)
logger.debug("testing log worker") logger.debug("testing log worker")
self.log_queue.put("testing") self.logs.put("testing")
def create_logger_worker(self) -> None:
self.log_queue = Queue()
self.logger = Process(target=logger_init, args=(self.log_queue,))
logger.debug("starting log worker")
self.logger.start()
def create_device_worker(self, device: DeviceParams) -> None: def create_device_worker(self, device: DeviceParams) -> None:
name = device.device name = device.device
@ -74,23 +73,54 @@ class DevicePoolExecutor:
cancel=Value("B", False), cancel=Value("B", False),
progress=self.progress, progress=self.progress,
finished=self.finished, finished=self.finished,
logs=self.log_queue, logs=self.logs,
pending=pending, pending=pending,
) )
self.context[name] = context self.context[name] = context
self.workers[name] = Process(target=worker_init, args=(context, self.server)) self.workers[name] = Process(target=worker_main, args=(context, self.server))
logger.debug("starting worker for device %s", device) logger.debug("starting worker for device %s", device)
self.workers[name].start() self.workers[name].start()
def create_queue_workers(self) -> None: def create_logger_worker(self) -> None:
def logger_worker(logs: Queue):
logger.info("checking in from logger worker thread")
while True:
job = logs.get()
with open("worker.log", "w") as f:
logger.info("got log: %s", job)
f.write(str(job) + "\n\n")
logger_thread = Thread(target=logger_worker, args=(self.logs,))
self.threads["logger"] = logger_thread
logger.debug("starting logger worker")
logger_thread.start()
def create_progress_worker(self) -> None:
def progress_worker(progress: Queue): def progress_worker(progress: Queue):
logger.info("checking in from progress worker thread") logger.info("checking in from progress worker thread")
while True: while True:
job, device, value = progress.get() try:
logger.info("progress update for job: %s, %s", job, value) job, device, value = progress.get()
self.active_jobs[job] = (device, value) logger.info("progress update for job: %s to %s", job, value)
self.active_jobs[job] = (device, value)
if job in self.cancelled_jobs:
logger.debug(
"setting flag for cancelled job: %s on %s", job, device
)
self.context[device].set_cancel()
except Exception as err:
logger.error("error during progress update", err)
progress_thread = Thread(target=progress_worker, args=(self.progress,))
self.threads["progress"] = progress_thread
logger.debug("starting progress worker")
progress_thread.start()
def create_finished_worker(self) -> None:
def finished_worker(finished: Queue): def finished_worker(finished: Queue):
logger.info("checking in from finished worker thread") logger.info("checking in from finished worker thread")
while True: while True:
@ -98,53 +128,19 @@ class DevicePoolExecutor:
logger.info("job has been finished: %s", job) logger.info("job has been finished: %s", job)
context = self.context[device] context = self.context[device]
_device, progress = self.active_jobs[job] _device, progress = self.active_jobs[job]
self.finished_jobs.append( self.finished_jobs.append((job, progress, context.cancel.value))
(job, progress, context.cancel.value)
)
del self.active_jobs[job] del self.active_jobs[job]
self.progress_thread = Thread(target=progress_worker, args=(self.progress,)) finished_thread = Thread(target=finished_worker, args=(self.finished,))
self.progress_thread.start() self.thread["finished"] = finished_thread
self.finished_thread = Thread(target=finished_worker, args=(self.finished,))
self.finished_thread.start() logger.debug("started finished worker")
finished_thread.start()
def get_job_context(self, key: str) -> WorkerContext: def get_job_context(self, key: str) -> WorkerContext:
device, _progress = self.active_jobs[key] device, _progress = self.active_jobs[key]
return self.context[device] return self.context[device]
def cancel(self, key: str) -> bool:
"""
Cancel a job. If the job has not been started, this will cancel
the future and never execute it. If the job has been started, it
should be cancelled on the next progress callback.
"""
if key not in self.active_jobs:
logger.warn("attempting to cancel unknown job: %s", key)
return False
device, _progress = self.active_jobs[key]
context = self.context[device]
logger.info("cancelling job %s on device %s", key, device)
if context.cancel.get_lock():
context.cancel.value = True
# self.finished.append((key, context.progress.value, context.cancel.value)) maybe?
return True
def done(self, key: str) -> Tuple[Optional[bool], int]:
for k, p, c in self.finished_jobs:
if k == key:
return (True, p)
if key not in self.active_jobs:
logger.warn("checking status for unknown job: %s", key)
return (None, 0)
# TODO: prune here, maybe?
_device, progress = self.active_jobs[key]
return (False, progress)
def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int: def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int:
# respect overrides if possible # respect overrides if possible
if needs_device is not None: if needs_device is not None:
@ -164,6 +160,45 @@ class DevicePoolExecutor:
return lowest_devices[0] return lowest_devices[0]
def cancel(self, key: str) -> bool:
"""
Cancel a job. If the job has not been started, this will cancel
the future and never execute it. If the job has been started, it
should be cancelled on the next progress callback.
"""
self.cancelled_jobs.append(key)
if key not in self.active_jobs:
logger.debug("cancelled job has not been started yet: %s", key)
return True
device, _progress = self.active_jobs[key]
logger.info("cancelling job %s, active on device %s", key, device)
context = self.context[device]
context.set_cancel()
return True
def done(self, key: str) -> Tuple[Optional[bool], int]:
"""
Check if a job has been finished and report the last progress update.
If the job is still active or pending, the first item will be False.
If the job is not finished or active, the first item will be None.
"""
for k, p, c in self.finished_jobs:
if k == key:
return (True, p)
if key not in self.active_jobs:
logger.warn("checking status for unknown job: %s", key)
return (None, 0)
_device, progress = self.active_jobs[key]
return (False, progress)
def join(self): def join(self):
self.progress_thread.join(self.join_timeout) self.progress_thread.join(self.join_timeout)
self.finished_thread.join(self.join_timeout) self.finished_thread.join(self.join_timeout)
@ -216,35 +251,23 @@ class DevicePoolExecutor:
self.devices[device_idx], self.devices[device_idx],
) )
device = self.devices[device_idx] device = self.devices[device_idx].device
queue = self.pending[device.device] self.pending[device].put((fn, args, kwargs))
queue.put((fn, args, kwargs))
self.active_jobs[key] = (device.device, 0) def status(self) -> List[Tuple[str, int, bool, bool]]:
history = [
def status(self) -> List[Tuple[str, int, bool, int]]: (name, progress, False, name in self.cancelled_jobs)
pending = [ for name, _device, progress in self.active_jobs.items()
(
name,
self.workers[name].is_alive(),
self.context[device].pending.qsize(),
self.context[device].cancel.value,
False,
progress,
)
for name, device, progress in self.active_jobs.items()
] ]
pending.extend( history.extend(
[ [
( (
name, name,
False,
0,
cancel,
True,
progress, progress,
True,
cancel,
) )
for name, progress, cancel in self.finished_jobs for name, progress, cancel in self.finished_jobs
] ]
) )
return pending return history

View File

@ -11,19 +11,7 @@ from .context import WorkerContext
logger = getLogger(__name__) logger = getLogger(__name__)
def logger_init(logs: Queue): def worker_main(context: WorkerContext, server: ServerContext):
setproctitle("onnx-web logger")
logger.info("checking in from logger")
while True:
job = logs.get()
with open("worker.log", "w") as f:
logger.info("got log: %s", job)
f.write(str(job) + "\n\n")
def worker_init(context: WorkerContext, server: ServerContext):
apply_patches(server) apply_patches(server)
setproctitle("onnx-web worker: %s" % (context.device.device)) setproctitle("onnx-web worker: %s" % (context.device.device))
@ -37,7 +25,7 @@ def worker_init(context: WorkerContext, server: ServerContext):
name = args[3][0] name = args[3][0]
try: try:
context.key = name # TODO: hax context.key = name # TODO: hax
context.clear_flags() context.clear_flags()
logger.info("starting job: %s", name) logger.info("starting job: %s", name)
fn(context, *args, **kwargs) fn(context, *args, **kwargs)