From f898de8c5490673f60bc91129f8fedb07e7495f6 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 25 Feb 2023 23:49:39 -0600 Subject: [PATCH] background workers, logger --- api/logging.yaml | 6 +- api/onnx_web/__init__.py | 6 +- api/onnx_web/chain/base.py | 7 +- api/onnx_web/chain/blend_img2img.py | 5 +- api/onnx_web/chain/blend_inpaint.py | 5 +- api/onnx_web/chain/blend_mask.py | 5 +- api/onnx_web/chain/correct_codeformer.py | 5 +- api/onnx_web/chain/correct_gfpgan.py | 6 +- api/onnx_web/chain/persist_disk.py | 6 +- api/onnx_web/chain/persist_s3.py | 6 +- api/onnx_web/chain/reduce_crop.py | 6 +- api/onnx_web/chain/reduce_thumbnail.py | 6 +- api/onnx_web/chain/source_noise.py | 6 +- api/onnx_web/chain/source_txt2img.py | 6 +- api/onnx_web/chain/upscale_outpaint.py | 5 +- api/onnx_web/chain/upscale_resrgan.py | 5 +- .../chain/upscale_stable_diffusion.py | 5 +- api/onnx_web/diffusion/run.py | 13 +- api/onnx_web/serve.py | 3 +- api/onnx_web/server/__init__.py | 7 - api/onnx_web/server/device_pool.py | 230 ------------------ api/onnx_web/transformers.py | 5 +- api/onnx_web/upscale.py | 5 +- api/onnx_web/worker/__init__.py | 2 + api/onnx_web/worker/context.py | 60 +++++ api/onnx_web/worker/logging.py | 1 + api/onnx_web/worker/pool.py | 136 +++++++++++ api/onnx_web/worker/worker.py | 32 +++ 28 files changed, 306 insertions(+), 284 deletions(-) delete mode 100644 api/onnx_web/server/device_pool.py create mode 100644 api/onnx_web/worker/__init__.py create mode 100644 api/onnx_web/worker/context.py create mode 100644 api/onnx_web/worker/logging.py create mode 100644 api/onnx_web/worker/pool.py create mode 100644 api/onnx_web/worker/worker.py diff --git a/api/logging.yaml b/api/logging.yaml index 0bf54310..24bd3c29 100644 --- a/api/logging.yaml +++ b/api/logging.yaml @@ -5,14 +5,14 @@ formatters: handlers: console: class: logging.StreamHandler - level: INFO + level: DEBUG formatter: simple stream: ext://sys.stdout loggers: '': - level: INFO + level: DEBUG handlers: [console] 
propagate: True root: - level: INFO + level: DEBUG handlers: [console] diff --git a/api/onnx_web/__init__.py b/api/onnx_web/__init__.py index 3afedb07..7316bb87 100644 --- a/api/onnx_web/__init__.py +++ b/api/onnx_web/__init__.py @@ -25,6 +25,7 @@ from .image import ( from .onnx import OnnxNet, OnnxTensor from .params import ( Border, + DeviceParams, ImageParams, Param, Point, @@ -33,8 +34,6 @@ from .params import ( UpscaleParams, ) from .server import ( - DeviceParams, - DevicePoolExecutor, ModelCache, ServerContext, apply_patch_basicsr, @@ -51,3 +50,6 @@ from .utils import ( get_from_map, get_not_empty, ) +from .worker import ( + DevicePoolExecutor, +) \ No newline at end of file diff --git a/api/onnx_web/chain/base.py b/api/onnx_web/chain/base.py index 7e77bab7..dc807322 100644 --- a/api/onnx_web/chain/base.py +++ b/api/onnx_web/chain/base.py @@ -7,7 +7,8 @@ from PIL import Image from ..output import save_image from ..params import ImageParams, StageParams -from ..server import JobContext, ProgressCallback, ServerContext +from ..worker import WorkerContext, ProgressCallback +from ..server import ServerContext from ..utils import is_debug from .utils import process_tile_order @@ -17,7 +18,7 @@ logger = getLogger(__name__) class StageCallback(Protocol): def __call__( self, - job: JobContext, + job: WorkerContext, ctx: ServerContext, stage: StageParams, params: ImageParams, @@ -77,7 +78,7 @@ class ChainPipeline: def __call__( self, - job: JobContext, + job: WorkerContext, server: ServerContext, params: ImageParams, source: Image.Image, diff --git a/api/onnx_web/chain/blend_img2img.py b/api/onnx_web/chain/blend_img2img.py index 0ef9ef96..67531103 100644 --- a/api/onnx_web/chain/blend_img2img.py +++ b/api/onnx_web/chain/blend_img2img.py @@ -7,13 +7,14 @@ from PIL import Image from ..diffusion.load import load_pipeline from ..params import ImageParams, StageParams -from ..server import JobContext, ProgressCallback, ServerContext +from ..worker import WorkerContext, 
ProgressCallback +from ..server import ServerContext logger = getLogger(__name__) def blend_img2img( - job: JobContext, + job: WorkerContext, server: ServerContext, _stage: StageParams, params: ImageParams, diff --git a/api/onnx_web/chain/blend_inpaint.py b/api/onnx_web/chain/blend_inpaint.py index 16422cce..51b6d983 100644 --- a/api/onnx_web/chain/blend_inpaint.py +++ b/api/onnx_web/chain/blend_inpaint.py @@ -10,7 +10,8 @@ from ..diffusion.load import get_latents_from_seed, load_pipeline from ..image import expand_image, mask_filter_none, noise_source_histogram from ..output import save_image from ..params import Border, ImageParams, Size, SizeChart, StageParams -from ..server import JobContext, ProgressCallback, ServerContext +from ..worker import WorkerContext, ProgressCallback +from ..server import ServerContext from ..utils import is_debug from .utils import process_tile_order @@ -18,7 +19,7 @@ logger = getLogger(__name__) def blend_inpaint( - job: JobContext, + job: WorkerContext, server: ServerContext, stage: StageParams, params: ImageParams, diff --git a/api/onnx_web/chain/blend_mask.py b/api/onnx_web/chain/blend_mask.py index be5d1159..bfe11aab 100644 --- a/api/onnx_web/chain/blend_mask.py +++ b/api/onnx_web/chain/blend_mask.py @@ -7,14 +7,15 @@ from onnx_web.image import valid_image from onnx_web.output import save_image from ..params import ImageParams, StageParams -from ..server import JobContext, ProgressCallback, ServerContext +from ..worker import WorkerContext, ProgressCallback +from ..server import ServerContext from ..utils import is_debug logger = getLogger(__name__) def blend_mask( - _job: JobContext, + _job: WorkerContext, server: ServerContext, _stage: StageParams, _params: ImageParams, diff --git a/api/onnx_web/chain/correct_codeformer.py b/api/onnx_web/chain/correct_codeformer.py index 6b4e235d..01d61db7 100644 --- a/api/onnx_web/chain/correct_codeformer.py +++ b/api/onnx_web/chain/correct_codeformer.py @@ -3,7 +3,8 @@ from logging import 
getLogger from PIL import Image from ..params import ImageParams, StageParams, UpscaleParams -from ..server import JobContext, ServerContext +from ..worker import WorkerContext +from ..server import ServerContext logger = getLogger(__name__) @@ -11,7 +12,7 @@ device = "cpu" def correct_codeformer( - job: JobContext, + job: WorkerContext, _server: ServerContext, _stage: StageParams, _params: ImageParams, diff --git a/api/onnx_web/chain/correct_gfpgan.py b/api/onnx_web/chain/correct_gfpgan.py index 6796c1ba..afcae86b 100644 --- a/api/onnx_web/chain/correct_gfpgan.py +++ b/api/onnx_web/chain/correct_gfpgan.py @@ -5,8 +5,10 @@ import numpy as np from PIL import Image from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams -from ..server import JobContext, ServerContext from ..utils import run_gc +from ..worker import WorkerContext +from ..server import ServerContext + logger = getLogger(__name__) @@ -46,7 +48,7 @@ def load_gfpgan( def correct_gfpgan( - job: JobContext, + job: WorkerContext, server: ServerContext, stage: StageParams, _params: ImageParams, diff --git a/api/onnx_web/chain/persist_disk.py b/api/onnx_web/chain/persist_disk.py index 9a5f0cd0..58020b57 100644 --- a/api/onnx_web/chain/persist_disk.py +++ b/api/onnx_web/chain/persist_disk.py @@ -4,13 +4,15 @@ from PIL import Image from ..output import save_image from ..params import ImageParams, StageParams -from ..server import JobContext, ServerContext +from ..worker import WorkerContext +from ..server import ServerContext + logger = getLogger(__name__) def persist_disk( - _job: JobContext, + _job: WorkerContext, server: ServerContext, _stage: StageParams, _params: ImageParams, diff --git a/api/onnx_web/chain/persist_s3.py b/api/onnx_web/chain/persist_s3.py index bf3682bf..926f1598 100644 --- a/api/onnx_web/chain/persist_s3.py +++ b/api/onnx_web/chain/persist_s3.py @@ -5,13 +5,15 @@ from boto3 import Session from PIL import Image from ..params import ImageParams, StageParams -from ..server 
import JobContext, ServerContext +from ..worker import WorkerContext +from ..server import ServerContext + logger = getLogger(__name__) def persist_s3( - _job: JobContext, + _job: WorkerContext, server: ServerContext, _stage: StageParams, _params: ImageParams, diff --git a/api/onnx_web/chain/reduce_crop.py b/api/onnx_web/chain/reduce_crop.py index 226f6cf2..4cd715b1 100644 --- a/api/onnx_web/chain/reduce_crop.py +++ b/api/onnx_web/chain/reduce_crop.py @@ -3,13 +3,15 @@ from logging import getLogger from PIL import Image from ..params import ImageParams, Size, StageParams -from ..server import JobContext, ServerContext +from ..worker import WorkerContext +from ..server import ServerContext + logger = getLogger(__name__) def reduce_crop( - _job: JobContext, + _job: WorkerContext, _server: ServerContext, _stage: StageParams, _params: ImageParams, diff --git a/api/onnx_web/chain/reduce_thumbnail.py b/api/onnx_web/chain/reduce_thumbnail.py index 0037084c..4950a973 100644 --- a/api/onnx_web/chain/reduce_thumbnail.py +++ b/api/onnx_web/chain/reduce_thumbnail.py @@ -3,13 +3,15 @@ from logging import getLogger from PIL import Image from ..params import ImageParams, Size, StageParams -from ..server import JobContext, ServerContext +from ..worker import WorkerContext +from ..server import ServerContext + logger = getLogger(__name__) def reduce_thumbnail( - _job: JobContext, + _job: WorkerContext, _server: ServerContext, _stage: StageParams, _params: ImageParams, diff --git a/api/onnx_web/chain/source_noise.py b/api/onnx_web/chain/source_noise.py index f6267b26..9ab302b1 100644 --- a/api/onnx_web/chain/source_noise.py +++ b/api/onnx_web/chain/source_noise.py @@ -4,13 +4,15 @@ from typing import Callable from PIL import Image from ..params import ImageParams, Size, StageParams -from ..server import JobContext, ServerContext +from ..worker import WorkerContext +from ..server import ServerContext + logger = getLogger(__name__) def source_noise( - _job: JobContext, + _job: 
WorkerContext, _server: ServerContext, _stage: StageParams, _params: ImageParams, diff --git a/api/onnx_web/chain/source_txt2img.py b/api/onnx_web/chain/source_txt2img.py index a5cbb07f..b933ecc9 100644 --- a/api/onnx_web/chain/source_txt2img.py +++ b/api/onnx_web/chain/source_txt2img.py @@ -7,13 +7,15 @@ from PIL import Image from ..diffusion.load import get_latents_from_seed, load_pipeline from ..params import ImageParams, Size, StageParams -from ..server import JobContext, ProgressCallback, ServerContext +from ..worker import WorkerContext, ProgressCallback +from ..server import ServerContext + logger = getLogger(__name__) def source_txt2img( - job: JobContext, + job: WorkerContext, server: ServerContext, _stage: StageParams, params: ImageParams, diff --git a/api/onnx_web/chain/upscale_outpaint.py b/api/onnx_web/chain/upscale_outpaint.py index 7919e325..23393491 100644 --- a/api/onnx_web/chain/upscale_outpaint.py +++ b/api/onnx_web/chain/upscale_outpaint.py @@ -10,7 +10,8 @@ from ..diffusion.load import get_latents_from_seed, get_tile_latents, load_pipel from ..image import expand_image, mask_filter_none, noise_source_histogram from ..output import save_image from ..params import Border, ImageParams, Size, SizeChart, StageParams -from ..server import JobContext, ProgressCallback, ServerContext +from ..worker import WorkerContext, ProgressCallback +from ..server import ServerContext from ..utils import is_debug from .utils import process_tile_grid, process_tile_order @@ -18,7 +19,7 @@ logger = getLogger(__name__) def upscale_outpaint( - job: JobContext, + job: WorkerContext, server: ServerContext, stage: StageParams, params: ImageParams, diff --git a/api/onnx_web/chain/upscale_resrgan.py b/api/onnx_web/chain/upscale_resrgan.py index 178360c3..ccbb3644 100644 --- a/api/onnx_web/chain/upscale_resrgan.py +++ b/api/onnx_web/chain/upscale_resrgan.py @@ -6,7 +6,8 @@ from PIL import Image from ..onnx import OnnxNet from ..params import DeviceParams, ImageParams, 
StageParams, UpscaleParams -from ..server import JobContext, ServerContext +from ..worker import WorkerContext +from ..server import ServerContext from ..utils import run_gc logger = getLogger(__name__) @@ -96,7 +97,7 @@ def load_resrgan( def upscale_resrgan( - job: JobContext, + job: WorkerContext, server: ServerContext, stage: StageParams, _params: ImageParams, diff --git a/api/onnx_web/chain/upscale_stable_diffusion.py b/api/onnx_web/chain/upscale_stable_diffusion.py index 5747f13f..00c1b9d4 100644 --- a/api/onnx_web/chain/upscale_stable_diffusion.py +++ b/api/onnx_web/chain/upscale_stable_diffusion.py @@ -10,7 +10,8 @@ from ..diffusion.pipeline_onnx_stable_diffusion_upscale import ( OnnxStableDiffusionUpscalePipeline, ) from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams -from ..server import JobContext, ProgressCallback, ServerContext +from ..worker import WorkerContext, ProgressCallback +from ..server import ServerContext from ..utils import run_gc logger = getLogger(__name__) @@ -62,7 +63,7 @@ def load_stable_diffusion( def upscale_stable_diffusion( - job: JobContext, + job: WorkerContext, server: ServerContext, _stage: StageParams, params: ImageParams, diff --git a/api/onnx_web/diffusion/run.py b/api/onnx_web/diffusion/run.py index ec988d66..0e44b6cc 100644 --- a/api/onnx_web/diffusion/run.py +++ b/api/onnx_web/diffusion/run.py @@ -12,7 +12,8 @@ from onnx_web.chain.base import ChainProgress from ..chain import upscale_outpaint from ..output import save_image, save_params from ..params import Border, ImageParams, Size, StageParams, UpscaleParams -from ..server import JobContext, ServerContext +from ..worker import WorkerContext +from ..server import ServerContext from ..upscale import run_upscale_correction from ..utils import run_gc from .load import get_latents_from_seed, load_pipeline @@ -21,7 +22,7 @@ logger = getLogger(__name__) def run_txt2img_pipeline( - job: JobContext, + job: WorkerContext, server: ServerContext, params: 
ImageParams, size: Size, @@ -95,7 +96,7 @@ def run_txt2img_pipeline( def run_img2img_pipeline( - job: JobContext, + job: WorkerContext, server: ServerContext, params: ImageParams, outputs: List[str], @@ -167,7 +168,7 @@ def run_img2img_pipeline( def run_inpaint_pipeline( - job: JobContext, + job: WorkerContext, server: ServerContext, params: ImageParams, size: Size, @@ -217,7 +218,7 @@ def run_inpaint_pipeline( def run_upscale_pipeline( - job: JobContext, + job: WorkerContext, server: ServerContext, params: ImageParams, size: Size, @@ -243,7 +244,7 @@ def run_upscale_pipeline( def run_blend_pipeline( - job: JobContext, + job: WorkerContext, server: ServerContext, params: ImageParams, size: Size, diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py index fb95d8ef..4d20f296 100644 --- a/api/onnx_web/serve.py +++ b/api/onnx_web/serve.py @@ -63,7 +63,7 @@ from .params import ( TileOrder, UpscaleParams, ) -from .server import DevicePoolExecutor, ServerContext, apply_patches +from .server import ServerContext, apply_patches from .transformers import run_txt2txt_pipeline from .utils import ( base_join, @@ -75,6 +75,7 @@ from .utils import ( get_size, is_debug, ) +from .worker import DevicePoolExecutor logger = getLogger(__name__) diff --git a/api/onnx_web/server/__init__.py b/api/onnx_web/server/__init__.py index 0403746c..f02fa35a 100644 --- a/api/onnx_web/server/__init__.py +++ b/api/onnx_web/server/__init__.py @@ -1,10 +1,3 @@ -from .device_pool import ( - DeviceParams, - DevicePoolExecutor, - Job, - JobContext, - ProgressCallback, -) from .hacks import ( apply_patch_basicsr, apply_patch_codeformer, diff --git a/api/onnx_web/server/device_pool.py b/api/onnx_web/server/device_pool.py deleted file mode 100644 index 152e1d74..00000000 --- a/api/onnx_web/server/device_pool.py +++ /dev/null @@ -1,230 +0,0 @@ -from collections import Counter -from concurrent.futures import Future -from logging import getLogger -from multiprocessing import Queue -from 
torch.multiprocessing import Lock, Process, SimpleQueue, Value -from traceback import format_exception -from typing import Any, Callable, Dict, List, Optional, Tuple, Union -from time import sleep - -from ..params import DeviceParams -from ..utils import run_gc - -logger = getLogger(__name__) - -ProgressCallback = Callable[[int, int, Any], None] - - -def worker_init(lock: Lock, job_queue: SimpleQueue): - logger.info("checking in from worker") - - while True: - if job_queue.empty(): - logger.info("no jobs, sleeping") - sleep(5) - else: - job = job_queue.get() - logger.info("got job: %s", job) - - -class JobContext: - cancel: Value = None - device_index: Value = None - devices: List[DeviceParams] = None - key: str = None - progress: Value = None - - def __init__( - self, - key: str, - devices: List[DeviceParams], - cancel: bool = False, - device_index: int = -1, - progress: int = 0, - ): - self.key = key - self.devices = list(devices) - self.cancel = Value("B", cancel) - self.device_index = Value("i", device_index) - self.progress = Value("I", progress) - - def is_cancelled(self) -> bool: - return self.cancel.value - - def get_device(self) -> DeviceParams: - """ - Get the device assigned to this job. 
- """ - with self.device_index.get_lock(): - device_index = self.device_index.value - if device_index < 0: - raise ValueError("job has not been assigned to a device") - else: - device = self.devices[device_index] - logger.debug("job %s assigned to device %s", self.key, device) - return device - - def get_progress(self) -> int: - return self.progress.value - - def get_progress_callback(self) -> ProgressCallback: - def on_progress(step: int, timestep: int, latents: Any): - on_progress.step = step - if self.is_cancelled(): - raise RuntimeError("job has been cancelled") - else: - logger.debug("setting progress for job %s to %s", self.key, step) - self.set_progress(step) - - return on_progress - - def set_cancel(self, cancel: bool = True) -> None: - with self.cancel.get_lock(): - self.cancel.value = cancel - - def set_progress(self, progress: int) -> None: - with self.progress.get_lock(): - self.progress.value = progress - - -class Job: - """ - Link a future to its context. - """ - - context: JobContext = None - future: Future = None - key: str = None - - def __init__( - self, - key: str, - future: Future, - context: JobContext, - ): - self.context = context - self.future = future - self.key = key - - def get_progress(self) -> int: - return self.context.get_progress() - - def set_cancel(self, cancel: bool = True): - return self.context.set_cancel(cancel) - - def set_progress(self, progress: int): - return self.context.set_progress(progress) - - -class DevicePoolExecutor: - devices: List[DeviceParams] = None - finished: List[Tuple[str, int]] = None - pending: Dict[str, "Queue[Job]"] = None - progress: Dict[str, Value] = None - workers: Dict[str, Process] = None - - def __init__( - self, - devices: List[DeviceParams], - finished_limit: int = 10, - ): - self.devices = devices - self.finished = [] - self.finished_limit = finished_limit - self.lock = Lock() - self.pending = {} - self.progress = {} - self.workers = {} - - # create a pending queue and progress value for each 
device - for device in devices: - name = device.device - job_queue = Queue() - self.pending[name] = job_queue - self.progress[name] = Value("I", 0, lock=self.lock) - self.workers[name] = Process(target=worker_init, args=(self.lock, job_queue)) - - def cancel(self, key: str) -> bool: - """ - Cancel a job. If the job has not been started, this will cancel - the future and never execute it. If the job has been started, it - should be cancelled on the next progress callback. - """ - raise NotImplementedError() - - def done(self, key: str) -> Tuple[Optional[bool], int]: - for k, progress in self.finished: - if key == k: - return (True, progress) - - logger.warn("checking status for unknown key: %s", key) - return (None, 0) - - def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int: - # respect overrides if possible - if needs_device is not None: - for i in range(len(self.devices)): - if self.devices[i].device == needs_device.device: - return i - - # use the first/default device if there are no jobs - if len(self.jobs) == 0: - return 0 - - job_devices = [ - job.context.device_index.value for job in self.jobs if not job.future.done() - ] - job_counts = Counter(range(len(self.devices))) - job_counts.update(job_devices) - - queued = job_counts.most_common() - logger.debug("jobs queued by device: %s", queued) - - lowest_count = queued[-1][1] - lowest_devices = [d[0] for d in queued if d[1] == lowest_count] - lowest_devices.sort() - - return lowest_devices[0] - - def prune(self): - finished_count = len(self.finished) - if finished_count > self.finished_limit: - logger.debug( - "pruning %s of %s finished jobs", - finished_count - self.finished_limit, - finished_count, - ) - self.finished[:] = self.finished[-self.finished_limit:] - - def submit( - self, - key: str, - fn: Callable[..., None], - /, - *args, - needs_device: Optional[DeviceParams] = None, - **kwargs, - ) -> None: - self.prune() - device_idx = self.get_next_device(needs_device=needs_device) - 
logger.info( - "assigning job %s to device %s: %s", key, device_idx, self.devices[device_idx] - ) - - context = JobContext(key, self.devices, device_index=device_idx) - device = self.devices[device_idx] - - queue = self.pending[device.device] - queue.put((fn, context, args, kwargs)) - - - def status(self) -> List[Tuple[str, int, bool, int]]: - pending = [ - ( - device.device, - self.pending[device.device].qsize(), - ) - for device in self.devices - ] - pending.extend(self.finished) - return pending diff --git a/api/onnx_web/transformers.py b/api/onnx_web/transformers.py index f7a70693..18d90f0a 100644 --- a/api/onnx_web/transformers.py +++ b/api/onnx_web/transformers.py @@ -1,13 +1,14 @@ from logging import getLogger from .params import ImageParams, Size -from .server import JobContext, ServerContext +from .server import ServerContext +from .worker import WorkerContext logger = getLogger(__name__) def run_txt2txt_pipeline( - job: JobContext, + job: WorkerContext, _server: ServerContext, params: ImageParams, _size: Size, diff --git a/api/onnx_web/upscale.py b/api/onnx_web/upscale.py index c04d6efb..8636f8c1 100644 --- a/api/onnx_web/upscale.py +++ b/api/onnx_web/upscale.py @@ -10,13 +10,14 @@ from .chain import ( upscale_stable_diffusion, ) from .params import ImageParams, SizeChart, StageParams, UpscaleParams -from .server import JobContext, ProgressCallback, ServerContext +from .server import ServerContext +from .worker import WorkerContext, ProgressCallback logger = getLogger(__name__) def run_upscale_correction( - job: JobContext, + job: WorkerContext, server: ServerContext, stage: StageParams, params: ImageParams, diff --git a/api/onnx_web/worker/__init__.py b/api/onnx_web/worker/__init__.py new file mode 100644 index 00000000..0ca5eefc --- /dev/null +++ b/api/onnx_web/worker/__init__.py @@ -0,0 +1,2 @@ +from .context import WorkerContext, ProgressCallback +from .pool import DevicePoolExecutor \ No newline at end of file diff --git 
logger = getLogger(__name__)


# Diffusers-style progress callback: (step, timestep, latents) -> None.
ProgressCallback = Callable[[int, int, Any], None]


class WorkerContext:
    """
    Per-job state shared between the server and a device worker process.

    The `cancel` and `progress` values are shared-memory `Value`s so the
    server can observe and interrupt a job running in another process.
    """

    # shared cancel flag, set by the server and polled by the worker
    cancel: "Value[bool]" = None
    # device assigned to this job (assumption: fixed for the job's lifetime)
    device: DeviceParams = None
    # unique job key
    key: str = None
    # queue of pending work for the assigned device
    pending: "Queue[Any]" = None
    # shared progress counter, updated from the worker's callback
    progress: "Value[int]" = None

    def __init__(
        self,
        key: str,
        cancel: "Value[bool]",
        device: DeviceParams,
        pending: "Queue[Any]",
        progress: "Value[int]",
    ):
        self.key = key
        self.cancel = cancel
        self.device = device
        self.pending = pending
        self.progress = progress

    def is_cancelled(self) -> bool:
        """Return True once the server has requested cancellation."""
        return self.cancel.value

    def get_device(self) -> DeviceParams:
        """
        Get the device assigned to this job.
        """
        return self.device

    def get_progress(self) -> int:
        """Return the last progress step reported by the worker."""
        return self.progress.value

    def get_progress_callback(self) -> ProgressCallback:
        """
        Build a pipeline progress callback bound to this job.

        :raises RuntimeError: from inside the callback, if the job has been
            cancelled; pipelines abort by propagating this error.
        """

        def on_progress(step: int, timestep: int, latents: Any):
            # expose the latest step on the callback itself for callers
            on_progress.step = step
            if self.is_cancelled():
                raise RuntimeError("job has been cancelled")
            else:
                logger.debug("setting progress for job %s to %s", self.key, step)
                self.set_progress(step)

        return on_progress

    def set_cancel(self, cancel: bool = True) -> None:
        """Set or clear the shared cancel flag."""
        with self.cancel.get_lock():
            self.cancel.value = cancel

    def set_progress(self, progress: int) -> None:
        """Update the shared progress counter."""
        with self.progress.get_lock():
            self.progress.value = progress
logger = getLogger(__name__)


class DevicePoolExecutor:
    """
    Distribute jobs across a pool of devices, with one background worker
    process per device plus a dedicated logging process.
    """

    # devices managed by this pool
    devices: List[DeviceParams] = None
    # (key, progress) tuples for completed jobs, oldest first
    finished: List[Tuple[str, int]] = None
    # per-device queue of submitted jobs
    pending: Dict[str, "Queue[WorkerContext]"] = None
    # per-device shared progress counter
    progress: Dict[str, Value] = None
    # per-device worker process
    workers: Dict[str, Process] = None

    def __init__(
        self,
        devices: List[DeviceParams],
        finished_limit: int = 10,
    ):
        """
        :param devices: devices to start a worker process for
        :param finished_limit: how many finished jobs to keep in history
        """
        self.devices = devices
        self.finished = []
        self.finished_limit = finished_limit
        self.lock = Lock()
        self.pending = {}
        self.progress = {}
        self.workers = {}

        # the logger worker only consumes its queue; the other context
        # fields are unused, so None placeholders are safe here
        log_queue = Queue()
        logger_context = WorkerContext("logger", None, None, log_queue, None)

        logger.debug("starting log worker")
        self.logger = Process(target=logger_init, args=(self.lock, logger_context))
        self.logger.start()

        # create a pending queue and progress value for each device
        for device in devices:
            name = device.device
            cancel = Value("B", False, lock=self.lock)
            progress = Value("I", 0, lock=self.lock)
            pending = Queue()
            context = WorkerContext(name, cancel, device, pending, progress)
            self.pending[name] = pending
            # FIX: previously stored `pending` (the queue) here, dropping the
            # shared progress value entirely
            self.progress[name] = progress

            logger.debug("starting worker for device %s", device)
            self.workers[name] = Process(target=worker_init, args=(self.lock, context))
            self.workers[name].start()

        logger.debug("testing log worker")
        log_queue.put("testing")

    def cancel(self, key: str) -> bool:
        """
        Cancel a job. If the job has not been started, this will cancel
        the future and never execute it. If the job has been started, it
        should be cancelled on the next progress callback.
        """
        raise NotImplementedError()

    def done(self, key: str) -> Tuple[Optional[bool], int]:
        """
        Look up a job in the finished history.

        :return: (True, progress) when found, (None, 0) when unknown
        """
        for k, progress in self.finished:
            if key == k:
                return (True, progress)

        # logger.warn is deprecated in favor of logger.warning
        logger.warning("checking status for unknown key: %s", key)
        return (None, 0)

    def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int:
        """
        Pick the index of the least-loaded device, honoring an explicit
        device request when one is given.
        """
        # respect overrides if possible
        if needs_device is not None:
            for i in range(len(self.devices)):
                if self.devices[i].device == needs_device.device:
                    return i

        pending = [
            self.pending[d.device].qsize() for d in self.devices
        ]
        # seed the counter with every device index so idle devices still appear
        jobs = Counter(range(len(self.devices)))
        jobs.update(pending)

        queued = jobs.most_common()
        logger.debug("jobs queued by device: %s", queued)

        # most_common sorts descending, so the tail holds the lightest load;
        # break ties by the lowest device index
        lowest_count = queued[-1][1]
        lowest_devices = [d[0] for d in queued if d[1] == lowest_count]
        lowest_devices.sort()

        return lowest_devices[0]

    def prune(self):
        """Trim the finished-job history down to `finished_limit` entries."""
        finished_count = len(self.finished)
        if finished_count > self.finished_limit:
            logger.debug(
                "pruning %s of %s finished jobs",
                finished_count - self.finished_limit,
                finished_count,
            )
            self.finished[:] = self.finished[-self.finished_limit:]

    def submit(
        self,
        key: str,
        fn: Callable[..., None],
        /,
        *args,
        needs_device: Optional[DeviceParams] = None,
        **kwargs,
    ) -> None:
        """
        Queue a job on the least-loaded (or requested) device's worker.
        """
        self.prune()
        device_idx = self.get_next_device(needs_device=needs_device)
        logger.info(
            "assigning job %s to device %s: %s", key, device_idx, self.devices[device_idx]
        )

        device = self.devices[device_idx]
        queue = self.pending[device.device]
        queue.put((fn, args, kwargs))

    def status(self) -> List[Tuple[str, int, bool, int]]:
        """
        Report (device, queue depth, worker alive) per device, followed by
        the finished-job history.
        """
        pending = [
            (
                device.device,
                self.pending[device.device].qsize(),
                self.workers[device.device].is_alive(),
            )
            for device in self.devices
        ]
        pending.extend(self.finished)
        return pending
logger = getLogger(__name__)


def logger_init(lock: Lock, context: WorkerContext):
    """
    Entry point for the dedicated logging process: drain log entries from
    the context's queue into worker.log.
    """
    logger.info("checking in from logger")

    with open("worker.log", "w") as f:
        while True:
            # block until an entry arrives; the previous empty()/sleep(5)
            # polling loop added up to 5s latency and raced with get()
            job = context.pending.get()
            logger.info("got log: %s", job)
            f.write(str(job) + "\n\n")
            # flush inside the infinite loop, or nothing ever reaches disk
            f.flush()


def worker_init(lock: Lock, context: WorkerContext):
    """
    Entry point for a device worker process: consume jobs from the
    context's pending queue.
    """
    logger.info("checking in from worker")

    while True:
        # blocking get instead of empty()/sleep polling (see logger_init)
        job = context.pending.get()
        logger.info("got job: %s", job)
        # TODO(review): jobs are only logged, never executed yet; the pool
        # enqueues (fn, args, kwargs) tuples — confirm intended dispatch