From 525ee24e916e404449929ee72752568ac66a3487 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 26 Feb 2023 20:09:42 -0600 Subject: [PATCH] track started and finished jobs --- api/onnx_web/chain/blend_mask.py | 5 +- api/onnx_web/diffusion/run.py | 6 +- api/onnx_web/server/config.py | 10 ++- api/onnx_web/server/utils.py | 8 +-- api/onnx_web/worker/context.py | 21 +++--- api/onnx_web/worker/pool.py | 113 ++++++++++++++++++++++--------- api/onnx_web/worker/worker.py | 12 ++-- 7 files changed, 116 insertions(+), 59 deletions(-) diff --git a/api/onnx_web/chain/blend_mask.py b/api/onnx_web/chain/blend_mask.py index f7b68e6f..5c53bd12 100644 --- a/api/onnx_web/chain/blend_mask.py +++ b/api/onnx_web/chain/blend_mask.py @@ -3,9 +3,8 @@ from typing import List, Optional from PIL import Image -from onnx_web.image import valid_image -from onnx_web.output import save_image - +from ..image import valid_image +from ..output import save_image from ..params import ImageParams, StageParams from ..server import ServerContext from ..utils import is_debug diff --git a/api/onnx_web/diffusion/run.py b/api/onnx_web/diffusion/run.py index 2e92294c..765d7de8 100644 --- a/api/onnx_web/diffusion/run.py +++ b/api/onnx_web/diffusion/run.py @@ -6,10 +6,8 @@ import torch from diffusers import OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionPipeline from PIL import Image -from onnx_web.chain import blend_mask -from onnx_web.chain.base import ChainProgress - -from ..chain import upscale_outpaint +from ..chain import blend_mask, upscale_outpaint +from ..chain.base import ChainProgress from ..output import save_image, save_params from ..params import Border, ImageParams, Size, StageParams, UpscaleParams from ..server import ServerContext diff --git a/api/onnx_web/server/config.py b/api/onnx_web/server/config.py index a5dc1a31..c0392104 100644 --- a/api/onnx_web/server/config.py +++ b/api/onnx_web/server/config.py @@ -118,35 +118,42 @@ def load_models(context: ServerContext) -> None: ) diffusion_models = list(set(diffusion_models)) diffusion_models.sort() + logger.debug("loaded diffusion models from disk: %s", diffusion_models) correction_models = [ get_model_name(f) for f in glob(path.join(context.model_path, "correction-*")) ] correction_models = list(set(correction_models)) correction_models.sort() + logger.debug("loaded correction models from disk: %s", correction_models) inversion_models = [ get_model_name(f) for f in glob(path.join(context.model_path, "inversion-*")) ] inversion_models = list(set(inversion_models)) inversion_models.sort() + logger.debug("loaded inversion models from disk: %s", inversion_models) upscaling_models = [ get_model_name(f) for f in glob(path.join(context.model_path, "upscaling-*")) ] upscaling_models = list(set(upscaling_models)) upscaling_models.sort() + logger.debug("loaded upscaling models from disk: %s", upscaling_models) def load_params(context: ServerContext) -> None: global config_params + params_file = path.join(context.params_path, "params.json") + logger.debug("loading server parameters from file: %s", params_file) + with open(params_file, "r") as f: config_params = yaml.safe_load(f) if "platform" in config_params and context.default_platform is not None: logger.info( - "Overriding default platform from environment: %s", + "overriding default platform from environment: %s", context.default_platform, ) config_platform = config_params.get("platform", {}) @@ -157,6 +164,7 @@ def load_platforms(context: ServerContext) -> None: global available_platforms providers = list(get_available_providers()) + logger.debug("loading available platforms from providers: %s", providers) for potential in platform_providers: if ( diff --git a/api/onnx_web/server/utils.py b/api/onnx_web/server/utils.py index 582b1c6a..56cc33e0 100644 --- a/api/onnx_web/server/utils.py +++ b/api/onnx_web/server/utils.py @@ -4,9 +4,8 @@ from typing import Callable, Dict, List, Tuple from flask import Flask -from onnx_web.utils import base_join -from onnx_web.worker.pool import DevicePoolExecutor - +from ..utils import base_join +from ..worker.pool import DevicePoolExecutor from .context import ServerContext @@ -28,7 +27,8 @@ def register_routes( pool: DevicePoolExecutor, routes: List[Tuple[str, Dict, Callable]], ): - pass + for route, kwargs, method in routes: + app.route(route, **kwargs)(wrap_route(method, context, pool=pool)) def wrap_route(func, *args, **kwargs): diff --git a/api/onnx_web/worker/context.py b/api/onnx_web/worker/context.py index 056a2517..67bc4eb6 100644 --- a/api/onnx_web/worker/context.py +++ b/api/onnx_web/worker/context.py @@ -22,18 +22,20 @@ class WorkerContext: key: str, device: DeviceParams, cancel: "Value[bool]" = None, - finished: "Value[bool]" = None, progress: "Value[int]" = None, + finished: "Queue[str]" = None, logs: "Queue[str]" = None, pending: "Queue[Any]" = None, + started: "Queue[Tuple[str, str]]" = None, ): self.key = key - self.cancel = cancel self.device = device - self.pending = pending + self.cancel = cancel self.progress = progress - self.logs = logs self.finished = finished + self.logs = logs + self.pending = pending + self.started = started def is_cancelled(self) -> bool: return self.cancel.value @@ -62,15 +64,16 @@ class WorkerContext: with self.cancel.get_lock(): self.cancel.value = cancel - def set_finished(self, finished: bool = True) -> None: - with self.finished.get_lock(): - self.finished.value = finished - def set_progress(self, progress: int) -> None: with self.progress.get_lock(): self.progress.value = progress + def put_finished(self, job: str) -> None: + self.finished.put((job, self.device.device)) + + def put_started(self, job: str) -> None: + self.started.put((job, self.device.device)) + def clear_flags(self) -> None: self.set_cancel(False) - self.set_finished(False) self.set_progress(0) diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index ec5ac480..6b4bfd96 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -1,9 +1,9 @@ from collections import Counter from logging import getLogger -from multiprocessing import Queue +from threading import Thread from typing import Callable, Dict, List, Optional, Tuple -from torch.multiprocessing import Process, Value +from torch.multiprocessing import Process, Queue, Value from ..params import DeviceParams from ..server import ServerContext @@ -14,12 +14,12 @@ logger = getLogger(__name__) class DevicePoolExecutor: - context: Dict[str, WorkerContext] = None + context: Dict[str, WorkerContext] = None # Device -> Context devices: List[DeviceParams] = None pending: Dict[str, "Queue[WorkerContext]"] = None workers: Dict[str, Process] = None - active_job: Dict[str, str] = None - finished: List[Tuple[str, int, bool]] = None + active_jobs: Dict[str, str] = None + finished_jobs: List[Tuple[str, int, bool]] = None def __init__( self, @@ -36,11 +36,15 @@ class DevicePoolExecutor: self.context = {} self.pending = {} self.workers = {} - self.active_job = {} - self.finished = [] - self.finished_jobs = 0 # TODO: turn this into a Dict per-worker + self.active_jobs = {} + self.finished_jobs = [] + self.total_jobs = 0 # TODO: turn this into a Dict per-worker + + self.started = Queue() + self.finished = Queue() self.create_logger_worker() + self.create_queue_workers() for device in devices: self.create_device_worker(device) @@ -56,16 +60,23 @@ class DevicePoolExecutor: def create_device_worker(self, device: DeviceParams) -> None: name = device.device - pending = Queue() - self.pending[name] = pending + + # reuse the queue if possible, to keep queued jobs + if name in self.pending: + pending = self.pending[name] + else: + pending = Queue() + self.pending[name] = pending + context = WorkerContext( name, device, cancel=Value("B", False), - finished=Value("B", False), progress=Value("I", 0), - pending=pending, + finished=self.finished, logs=self.log_queue, + pending=pending, + started=self.started, ) self.context[name] = context self.workers[name] = Process(target=worker_init, args=(context, self.server)) @@ -73,9 +84,32 @@ class DevicePoolExecutor: logger.debug("starting worker for device %s", device) self.workers[name].start() - def create_prune_worker(self) -> None: - # TODO: create a background thread to prune completed jobs - pass + def create_queue_workers(self) -> None: + def started_worker(pending: Queue): + logger.info("checking in from started thread") + while True: + job, device = pending.get() + logger.info("job has been started: %s", job) + self.active_jobs[device] = job + + def finished_worker(finished: Queue): + logger.info("checking in from finished thread") + while True: + job, device = finished.get() + logger.info("job has been finished: %s", job) + context = self.get_job_context(job) + self.finished_jobs.append( + (job, context.progress.value, context.cancel.value) + ) + + self.started_thread = Thread(target=started_worker, args=(self.started,)) + self.started_thread.start() + self.finished_thread = Thread(target=finished_worker, args=(self.finished,)) + self.finished_thread.start() + + def get_job_context(self, key: str) -> WorkerContext: + device = self.active_jobs[key] + return self.context[device] def cancel(self, key: str) -> bool: """ @@ -83,11 +117,11 @@ class DevicePoolExecutor: the future and never execute it. If the job has been started, it should be cancelled on the next progress callback. """ - if key not in self.active_job: + if key not in self.active_jobs: logger.warn("attempting to cancel unknown job: %s", key) return False - device = self.active_job[key] + device = self.active_jobs[key] context = self.context[device] logger.info("cancelling job %s on device %s", key, device) @@ -98,19 +132,17 @@ class DevicePoolExecutor: return True def done(self, key: str) -> Tuple[Optional[bool], int]: - if key not in self.active_job: + for k, p, c in self.finished_jobs: + if k == key: + return (c, p) + + if key not in self.active_jobs: logger.warn("checking status for unknown job: %s", key) return (None, 0) # TODO: prune here, maybe? - - device = self.active_job[key] - context = self.context[device] - - if context.finished.value is True: - self.finished.append((key, context.progress.value, context.cancel.value)) - - return (context.finished.value, context.progress.value) + context = self.get_job_context(key) + return (False, context.progress.value) def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int: # respect overrides if possible @@ -132,6 +164,9 @@ class DevicePoolExecutor: return lowest_devices[0] def join(self): + self.started_thread.join(self.join_timeout) + self.finished_thread.join(self.join_timeout) + for device, worker in self.workers.items(): if worker.is_alive(): logger.info("stopping worker for device %s", device) @@ -166,11 +201,11 @@ class DevicePoolExecutor: needs_device: Optional[DeviceParams] = None, **kwargs, ) -> None: - self.finished_jobs += 1 - logger.debug("pool job count: %s", self.finished_jobs) - if self.finished_jobs > self.max_jobs_per_worker: + self.total_jobs += 1 + logger.debug("pool job count: %s", self.total_jobs) + if self.total_jobs > self.max_jobs_per_worker: self.recycle() - self.finished_jobs = 0 + self.total_jobs = 0 device_idx = self.get_next_device(needs_device=needs_device) logger.info( @@ -184,7 +219,7 @@ class DevicePoolExecutor: queue = self.pending[device.device] queue.put((fn, args, kwargs)) - self.active_job[key] = device.device + self.active_jobs[key] = device.device def status(self) -> List[Tuple[str, int, bool, int]]: pending = [ @@ -193,10 +228,22 @@ class DevicePoolExecutor: self.workers[name].is_alive(), context.pending.qsize(), context.cancel.value, - context.finished.value, + False, context.progress.value, ) for name, context in self.context.items() ] - pending.extend(self.finished) + pending.extend( + [ + ( + name, + False, + 0, + cancel, + True, + progress, + ) + for name, progress, cancel in self.finished_jobs + ] + ) return pending diff --git a/api/onnx_web/worker/worker.py b/api/onnx_web/worker/worker.py index 24a1c4f2..db23540f 100644 --- a/api/onnx_web/worker/worker.py +++ b/api/onnx_web/worker/worker.py @@ -32,12 +32,14 @@ def worker_init(context: WorkerContext, server: ServerContext): while True: job = context.pending.get() logger.info("got job: %s", job) - try: - fn, args, kwargs = job - name = args[3][0] - logger.info("starting job: %s", name) + fn, args, kwargs = job + name = args[3][0] + + try: context.clear_flags() + logger.info("starting job: %s", name) + context.put_started(name) fn(context, *args, **kwargs) logger.info("job succeeded: %s", name) except Exception as e: @@ -46,5 +48,5 @@ def worker_init(context: WorkerContext, server: ServerContext): format_exception(type(e), e, e.__traceback__), ) finally: - context.set_finished() + context.put_finished(name) logger.info("finished job: %s", name)