diff --git a/api/onnx_web/convert/diffusion/lora.py b/api/onnx_web/convert/diffusion/lora.py
index b4ff756e..4f65eb54 100644
--- a/api/onnx_web/convert/diffusion/lora.py
+++ b/api/onnx_web/convert/diffusion/lora.py
@@ -27,7 +27,7 @@ def buffer_external_data_tensors(
     for tensor in model.graph.initializer:
         name = tensor.name
-        logger.debug("externalizing tensor: %s", name)
+        logger.trace("externalizing tensor: %s", name)
         if tensor.HasField("raw_data"):
             npt = numpy_helper.to_array(tensor)
             orv = OrtValue.ortvalue_from_numpy(npt)
@@ -74,7 +74,7 @@ def blend_loras(
     blended: Dict[str, np.ndarray] = {}
     for (lora_name, lora_weight), lora_model in zip(loras, lora_models):
-        logger.info("blending LoRA from %s with weight of %s", lora_name, lora_weight)
+        logger.debug("blending LoRA from %s with weight of %s", lora_name, lora_weight)
         if lora_model is None:
             logger.warning("unable to load tensor for LoRA")
             continue
diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py
index 97af0ae0..21c5d7a9 100644
--- a/api/onnx_web/worker/pool.py
+++ b/api/onnx_web/worker/pool.py
@@ -10,6 +10,7 @@ from ..params import DeviceParams
 from ..server import ServerContext
 from .command import JobCommand, ProgressCommand
 from .context import WorkerContext
+from .utils import Interval
 from .worker import worker_main
 
 logger = getLogger(__name__)
@@ -18,17 +19,24 @@ logger = getLogger(__name__)
 class DevicePoolExecutor:
     server: ServerContext
     devices: List[DeviceParams]
+
+    join_timeout: float
     max_jobs_per_worker: int
     max_pending_per_worker: int
-    join_timeout: float
+    progress_interval: float
+    recycle_interval: float
 
     leaking: List[Tuple[str, Process]]
     context: Dict[str, WorkerContext]  # Device -> Context
     current: Dict[str, "Value[int]"]  # Device -> pid
     pending: Dict[str, "Queue[JobCommand]"]
-    threads: Dict[str, Thread]
+    progress: Dict[str, "Queue[ProgressCommand]"]
     workers: Dict[str, Process]
 
+    health_worker: Interval
+    logger_worker: Thread
+    progress_worker: Interval
+
     cancelled_jobs: List[str]
     finished_jobs: List[ProgressCommand]
     pending_jobs: List[JobCommand]
@@ -36,8 +44,7 @@ class DevicePoolExecutor:
     total_jobs: Dict[str, int]  # Device -> job count
 
     logs: "Queue[str]"
-    progress: "Queue[ProgressCommand]"
-    recycle: Lock
+    rlock: Lock
 
     def __init__(
         self,
@@ -45,18 +52,23 @@
         devices: List[DeviceParams],
         max_pending_per_worker: int = 100,
         join_timeout: float = 1.0,
+        recycle_interval: float = 10,
+        progress_interval: float = 1.0,
     ):
         self.server = server
         self.devices = devices
+
+        self.join_timeout = join_timeout
         self.max_jobs_per_worker = server.job_limit
         self.max_pending_per_worker = max_pending_per_worker
-        self.join_timeout = join_timeout
+        self.progress_interval = progress_interval
+        self.recycle_interval = recycle_interval
 
         self.leaking = []
         self.context = {}
         self.current = {}
         self.pending = {}
-        self.threads = {}
+        self.progress = {}
         self.workers = {}
 
         self.cancelled_jobs = []
@@ -66,10 +78,10 @@ class DevicePoolExecutor:
         self.total_jobs = {}
 
         self.logs = Queue(self.max_pending_per_worker)
-        self.progress = Queue(self.max_pending_per_worker)
-        self.recycle = Lock()
+        self.rlock = Lock()
 
         # TODO: these should be part of a start method
+        self.create_health_worker()
         self.create_logger_worker()
         self.create_progress_worker()
 
@@ -79,14 +91,9 @@
     def create_device_worker(self, device: DeviceParams) -> None:
         name = device.device
 
-        # reuse the queue if possible, to keep queued jobs
-        if name in self.pending:
-            logger.debug("using existing pending job queue")
-            pending = self.pending[name]
-        else:
-            logger.debug("creating new pending job queue")
-            pending = Queue(self.max_pending_per_worker)
-            self.pending[name] = pending
+        # always recreate queues
+        self.progress[name] = Queue(self.max_pending_per_worker)
+        self.pending[name] = Queue(self.max_pending_per_worker)
 
         if name in self.current:
             logger.debug("using existing current worker value")
@@ -100,9 +107,9 @@
             name,
             device,
             cancel=Value("B", False),
-            progress=self.progress,
+            progress=self.progress[name],
             logs=self.logs,
-            pending=pending,
+            pending=self.pending[name],
             active_pid=current,
         )
         self.context[name] = context
@@ -117,8 +124,15 @@
         self.workers[name] = worker
         current.value = worker.pid
 
+    def create_health_worker(self) -> None:
+        self.health_worker = Interval(self.recycle_interval, health_main, args=(self,))
+        self.health_worker.daemon = True
+        self.health_worker.name = "onnx-web health"
+        logger.debug("starting health worker")
+        self.health_worker.start()
+
     def create_logger_worker(self) -> None:
-        logger_thread = Thread(
+        self.logger_worker = Thread(
             name="onnx-web logger",
             target=logger_main,
             args=(
                 self,
                 self.logs,
             ),
             daemon=True,
         )
-        self.threads["logger"] = logger_thread
 
         logger.debug("starting logger worker")
-        logger_thread.start()
+        self.logger_worker.start()
 
     def create_progress_worker(self) -> None:
-        progress_thread = Thread(
-            name="onnx-web progress",
-            target=progress_main,
-            args=(
-                self,
-                self.progress,
-            ),
-            daemon=True,
+        self.progress_worker = Interval(
+            self.progress_interval, progress_main, args=(self, self.progress)
         )
-        self.threads["progress"] = progress_thread
-
+        self.progress_worker.daemon = True
+        self.progress_worker.name = "onnx-web progress"
 
         logger.debug("starting progress worker")
-        progress_thread.start()
+        self.progress_worker.start()
 
     def get_job_context(self, key: str) -> WorkerContext:
         device, _progress = self.running_jobs[key]
@@ -225,10 +232,11 @@
     def join(self):
         logger.info("stopping worker pool")
 
-        with self.recycle:
+        with self.rlock:
             logger.debug("closing queues")
             self.logs.close()
-            self.progress.close()
+            for queue in self.progress.values():
+                queue.close()
 
             for queue in self.pending.values():
                 queue.close()
@@ -250,9 +258,14 @@
                 else:
                     logger.debug("worker for device %s has died", device)
 
-            for name, thread in self.threads.items():
-                logger.debug("stopping worker %s for thread %s", thread.ident, name)
-                thread.join(self.join_timeout)
+            logger.debug("stopping progress worker")
+            self.progress_worker.join(self.join_timeout)
+
+            logger.debug("stopping logger worker")
+            self.logger_worker.join(self.join_timeout)
+
+            logger.debug("stopping health worker")
+            self.health_worker.join(self.join_timeout)
 
         logger.debug("worker pool stopped")
@@ -276,7 +289,7 @@
     def recycle(self):
         logger.debug("recycling worker pool")
 
-        with self.recycle:
+        with self.rlock:
             self.join_leaking()
 
             needs_restart = []
@@ -318,22 +331,14 @@
                 self.create_device_worker(device)
                 self.total_jobs[device.device] = 0
 
-        if self.threads["logger"].is_alive():
+        if self.logger_worker.is_alive():
             logger.debug("logger worker is running")
-            if self.logs.full():
-                logger.warning("logger queue is full, restarting worker")
-                self.threads["logger"].join(self.join_timeout)
-                self.create_logger_worker()
         else:
            logger.warning("restarting logger worker")
            self.create_logger_worker()
 
-        if self.threads["progress"].is_alive():
+        if self.progress_worker.is_alive():
             logger.debug("progress worker is running")
-            if self.progress.full():
-                logger.warning("progress queue is full, restarting worker")
-                self.threads["progress"].join(self.join_timeout)
-                self.create_progress_worker()
         else:
             logger.warning("restarting progress worker")
             self.create_progress_worker()
@@ -366,7 +371,7 @@
         # recycle before attempting to run
         logger.debug("job count for device %s: %s", device, self.total_jobs[device])
         self.recycle()
 
         # build and queue job
         job = JobCommand(key, device, fn, args, kwargs)
@@ -443,15 +448,28 @@
             self.context[progress.device].set_cancel()
 
 
-def logger_main(pool: DevicePoolExecutor, logs: Queue):
+def health_main(pool: DevicePoolExecutor):
+    logger.trace("checking in from health worker thread")
+    pool.recycle()
+
+    if pool.logs.full():
+        logger.warning("logger queue is full, restarting worker")
+        pool.logger_worker.join(pool.join_timeout)
+        pool.create_logger_worker()
+
+    if any([queue.full() for queue in pool.progress.values()]):
+        logger.warning("progress queue is full, restarting worker")
+        pool.progress_worker.join(pool.join_timeout)
+        pool.create_progress_worker()
+
+
+def logger_main(pool: DevicePoolExecutor, logs: "Queue[str]"):
     logger.trace("checking in from logger worker thread")
 
     while True:
         try:
-            job = logs.get(timeout=(pool.join_timeout / 2))
-            with open("worker.log", "w") as f:
-                logger.info("got log: %s", job)
-                f.write(str(job) + "\n\n")
+            msg = logs.get(timeout=(pool.join_timeout / 2))
+            logger.debug("received logs from worker: %s", msg)
         except Empty:
             pass
         except ValueError:
@@ -460,15 +478,21 @@
             logger.exception("error in log worker")
 
 
-def progress_main(pool: DevicePoolExecutor, queue: "Queue[ProgressCommand]"):
+def progress_main(
+    pool: DevicePoolExecutor, queues: Dict[str, "Queue[ProgressCommand]"]
+):
     logger.trace("checking in from progress worker thread")
 
-    while True:
+    for device, queue in queues.items():
         try:
-            progress = queue.get(timeout=(pool.join_timeout / 2))
-            pool.update_job(progress)
+            progress = queue.get_nowait()
+            while progress is not None:
+                pool.update_job(progress)
+                progress = queue.get_nowait()
         except Empty:
+            logger.trace("empty queue in progress worker for device %s", device)
             pass
-        except ValueError:
+        except ValueError as e:
+            logger.debug("value error in progress worker for device %s: %s", device, e)
             break
         except Exception:
-            logger.exception("error in progress worker")
+            logger.exception("error in progress worker for device %s", device)
diff --git a/api/onnx_web/worker/utils.py b/api/onnx_web/worker/utils.py
new file mode 100644
index 00000000..4b1e5737
--- /dev/null
+++ b/api/onnx_web/worker/utils.py
@@ -0,0 +1,11 @@
+from threading import Timer
+
+
+class Interval(Timer):
+    """
+    From https://stackoverflow.com/a/48741004
+    """
+
+    def run(self):
+        while not self.finished.wait(self.interval):
+            self.function(*self.args, **self.kwargs)
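Note on the new Interval helper: unlike a plain threading.Timer, which fires its callback once, Interval.run keeps invoking the target every interval seconds until the inherited finished event is set, which is what lets the pool reuse one thread for the recurring health and progress checks. Below is a minimal standalone sketch of that behavior; the tick callback, the one-second interval, and the sleep duration are illustrative only and are not part of this patch.

```python
import time
from threading import Timer


class Interval(Timer):
    """
    Repeating timer: run() loops until the finished event is set,
    mirroring the helper added in api/onnx_web/worker/utils.py.
    """

    def run(self):
        while not self.finished.wait(self.interval):
            self.function(*self.args, **self.kwargs)


def tick(name: str) -> None:
    # illustrative callback; the pool passes health_main or progress_main instead
    print(f"{name}: periodic check")


if __name__ == "__main__":
    worker = Interval(1.0, tick, args=("health",))
    worker.daemon = True  # matches how the pool configures its interval workers
    worker.start()

    time.sleep(3.5)  # tick() fires roughly three times
    worker.cancel()  # sets the finished event, ending the run() loop
    worker.join()
```

Because the callback runs on the Interval's own daemon thread, a slow health or progress pass delays the next tick but never blocks job submission, and the thread will not keep the interpreter alive on shutdown.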