
track started and finished jobs

Sean Sube 2023-02-26 20:09:42 -06:00
parent eb82e73e59
commit 525ee24e91
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
7 changed files with 116 additions and 59 deletions

View File

@@ -3,9 +3,8 @@ from typing import List, Optional
from PIL import Image
from onnx_web.image import valid_image
from onnx_web.output import save_image
from ..image import valid_image
from ..output import save_image
from ..params import ImageParams, StageParams
from ..server import ServerContext
from ..utils import is_debug

View File

@@ -6,10 +6,8 @@ import torch
from diffusers import OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionPipeline
from PIL import Image
from onnx_web.chain import blend_mask
from onnx_web.chain.base import ChainProgress
from ..chain import upscale_outpaint
from ..chain import blend_mask, upscale_outpaint
from ..chain.base import ChainProgress
from ..output import save_image, save_params
from ..params import Border, ImageParams, Size, StageParams, UpscaleParams
from ..server import ServerContext

View File

@@ -118,35 +118,42 @@ def load_models(context: ServerContext) -> None:
)
diffusion_models = list(set(diffusion_models))
diffusion_models.sort()
logger.debug("loaded diffusion models from disk: %s", diffusion_models)
correction_models = [
get_model_name(f) for f in glob(path.join(context.model_path, "correction-*"))
]
correction_models = list(set(correction_models))
correction_models.sort()
logger.debug("loaded correction models from disk: %s", correction_models)
inversion_models = [
get_model_name(f) for f in glob(path.join(context.model_path, "inversion-*"))
]
inversion_models = list(set(inversion_models))
inversion_models.sort()
logger.debug("loaded inversion models from disk: %s", inversion_models)
upscaling_models = [
get_model_name(f) for f in glob(path.join(context.model_path, "upscaling-*"))
]
upscaling_models = list(set(upscaling_models))
upscaling_models.sort()
logger.debug("loaded upscaling models from disk: %s", upscaling_models)
def load_params(context: ServerContext) -> None:
global config_params
params_file = path.join(context.params_path, "params.json")
logger.debug("loading server parameters from file: %s", params_file)
with open(params_file, "r") as f:
config_params = yaml.safe_load(f)
if "platform" in config_params and context.default_platform is not None:
logger.info(
"Overriding default platform from environment: %s",
"overriding default platform from environment: %s",
context.default_platform,
)
config_platform = config_params.get("platform", {})
@@ -157,6 +164,7 @@ def load_platforms(context: ServerContext) -> None:
global available_platforms
providers = list(get_available_providers())
logger.debug("loading available platforms from providers: %s", providers)
for potential in platform_providers:
if (

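For reference, the model-scanning pattern added to load_models above (glob a prefixed filename, reduce each match to a model name, dedupe, sort, and log) can be condensed into one helper. This is a sketch only: scan_models and its name-extraction logic are illustrative stand-ins, not the project's get_model_name.

from glob import glob
from logging import getLogger
from os import path
from typing import List

logger = getLogger(__name__)


def scan_models(model_path: str, prefix: str) -> List[str]:
    # collect files named like "<prefix>-*" in the model directory
    files = glob(path.join(model_path, "%s-*" % prefix))
    # strip the directory and extension to get a displayable model name
    names = [path.splitext(path.basename(f))[0] for f in files]
    # dedupe and sort, matching the behavior in load_models above
    names = sorted(set(names))
    logger.debug("loaded %s models from disk: %s", prefix, names)
    return names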
View File

@@ -4,9 +4,8 @@ from typing import Callable, Dict, List, Tuple
from flask import Flask
from onnx_web.utils import base_join
from onnx_web.worker.pool import DevicePoolExecutor
from ..utils import base_join
from ..worker.pool import DevicePoolExecutor
from .context import ServerContext
@@ -28,7 +27,8 @@ def register_routes(
pool: DevicePoolExecutor,
routes: List[Tuple[str, Dict, Callable]],
):
pass
for route, kwargs, method in routes:
app.route(route, **kwargs)(wrap_route(method, context, pool=pool))
def wrap_route(func, *args, **kwargs):

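The new loop body replaces the old pass: each (route, kwargs, method) tuple is registered on the Flask app with the server context and worker pool bound in through wrap_route. The definition of wrap_route is cut off above; the following is only a plausible sketch, assuming it is a thin functools.partial wrapper rather than the confirmed implementation.

from functools import partial, update_wrapper
from typing import Any, Callable


def wrap_route(func: Callable, *args: Any, **kwargs: Any) -> Callable:
    # bind the server context and pool ahead of the per-request arguments
    wrapped = partial(func, *args, **kwargs)
    # copy __name__ and related metadata so Flask can derive a unique endpoint name
    update_wrapper(wrapped, func)
    return wrapped


# usage, mirroring the loop above:
#   app.route(route, **kwargs)(wrap_route(method, context, pool=pool))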
View File

@@ -22,18 +22,20 @@ class WorkerContext:
key: str,
device: DeviceParams,
cancel: "Value[bool]" = None,
finished: "Value[bool]" = None,
progress: "Value[int]" = None,
finished: "Queue[str]" = None,
logs: "Queue[str]" = None,
pending: "Queue[Any]" = None,
started: "Queue[Tuple[str, str]]" = None,
):
self.key = key
self.cancel = cancel
self.device = device
self.pending = pending
self.cancel = cancel
self.progress = progress
self.logs = logs
self.finished = finished
self.logs = logs
self.pending = pending
self.started = started
def is_cancelled(self) -> bool:
return self.cancel.value
@@ -62,15 +64,16 @@ class WorkerContext:
with self.cancel.get_lock():
self.cancel.value = cancel
def set_finished(self, finished: bool = True) -> None:
with self.finished.get_lock():
self.finished.value = finished
def set_progress(self, progress: int) -> None:
with self.progress.get_lock():
self.progress.value = progress
def put_finished(self, job: str) -> None:
self.finished.put((job, self.device.device))
def put_started(self, job: str) -> None:
self.started.put((job, self.device.device))
def clear_flags(self) -> None:
self.set_cancel(False)
self.set_finished(False)
self.set_progress(0)
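The key change in this file: the per-job finished flag (a shared Value[bool]) becomes a pair of queues, so workers push (job, device) tuples onto started and finished and the pool can attribute each event to a specific job instead of polling a single boolean. A trimmed-down sketch of the producer side, using the same torch.multiprocessing primitives the pool imports below; the device is simplified to a plain string here.

from torch.multiprocessing import Queue, Value


class MiniWorkerContext:
    def __init__(self, key: str, device: str, started: Queue, finished: Queue):
        self.key = key
        self.device = device
        self.cancel = Value("B", False)
        self.progress = Value("I", 0)
        self.started = started
        self.finished = finished

    def put_started(self, job: str) -> None:
        # announce the job to the pool's started-watcher thread
        self.started.put((job, self.device))

    def put_finished(self, job: str) -> None:
        # report completion; progress and cancel stay readable through the shared values
        self.finished.put((job, self.device))

    def clear_flags(self) -> None:
        # reset per-job state before the next job starts
        with self.cancel.get_lock():
            self.cancel.value = False
        with self.progress.get_lock():
            self.progress.value = 0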

View File

@@ -1,9 +1,9 @@
from collections import Counter
from logging import getLogger
from multiprocessing import Queue
from threading import Thread
from typing import Callable, Dict, List, Optional, Tuple
from torch.multiprocessing import Process, Value
from torch.multiprocessing import Process, Queue, Value
from ..params import DeviceParams
from ..server import ServerContext
@@ -14,12 +14,12 @@ logger = getLogger(__name__)
class DevicePoolExecutor:
context: Dict[str, WorkerContext] = None
context: Dict[str, WorkerContext] = None # Device -> Context
devices: List[DeviceParams] = None
pending: Dict[str, "Queue[WorkerContext]"] = None
workers: Dict[str, Process] = None
active_job: Dict[str, str] = None
finished: List[Tuple[str, int, bool]] = None
active_jobs: Dict[str, str] = None
finished_jobs: List[Tuple[str, int, bool]] = None
def __init__(
self,
@@ -36,11 +36,15 @@ class DevicePoolExecutor:
self.context = {}
self.pending = {}
self.workers = {}
self.active_job = {}
self.finished = []
self.finished_jobs = 0 # TODO: turn this into a Dict per-worker
self.active_jobs = {}
self.finished_jobs = []
self.total_jobs = 0 # TODO: turn this into a Dict per-worker
self.started = Queue()
self.finished = Queue()
self.create_logger_worker()
self.create_queue_workers()
for device in devices:
self.create_device_worker(device)
@@ -56,16 +60,23 @@ class DevicePoolExecutor:
def create_device_worker(self, device: DeviceParams) -> None:
name = device.device
pending = Queue()
self.pending[name] = pending
# reuse the queue if possible, to keep queued jobs
if name in self.pending:
pending = self.pending[name]
else:
pending = Queue()
self.pending[name] = pending
context = WorkerContext(
name,
device,
cancel=Value("B", False),
finished=Value("B", False),
progress=Value("I", 0),
pending=pending,
finished=self.finished,
logs=self.log_queue,
pending=pending,
started=self.started,
)
self.context[name] = context
self.workers[name] = Process(target=worker_init, args=(context, self.server))
@@ -73,9 +84,32 @@ class DevicePoolExecutor:
logger.debug("starting worker for device %s", device)
self.workers[name].start()
def create_prune_worker(self) -> None:
# TODO: create a background thread to prune completed jobs
pass
def create_queue_workers(self) -> None:
def started_worker(pending: Queue):
logger.info("checking in from started thread")
while True:
job, device = pending.get()
logger.info("job has been started: %s", job)
self.active_jobs[device] = job
def finished_worker(finished: Queue):
logger.info("checking in from finished thread")
while True:
job, device = finished.get()
logger.info("job has been finished: %s", job)
context = self.get_job_context(job)
self.finished_jobs.append(
(job, context.progress.value, context.cancel.value)
)
self.started_thread = Thread(target=started_worker, args=(self.started,))
self.started_thread.start()
self.finished_thread = Thread(target=finished_worker, args=(self.finished,))
self.finished_thread.start()
def get_job_context(self, key: str) -> WorkerContext:
device = self.active_jobs[key]
return self.context[device]
def cancel(self, key: str) -> bool:
"""
@@ -83,11 +117,11 @@ class DevicePoolExecutor:
the future and never execute it. If the job has been started, it
should be cancelled on the next progress callback.
"""
if key not in self.active_job:
if key not in self.active_jobs:
logger.warn("attempting to cancel unknown job: %s", key)
return False
device = self.active_job[key]
device = self.active_jobs[key]
context = self.context[device]
logger.info("cancelling job %s on device %s", key, device)
@@ -98,19 +132,17 @@ class DevicePoolExecutor:
return True
def done(self, key: str) -> Tuple[Optional[bool], int]:
if key not in self.active_job:
for k, p, c in self.finished_jobs:
if k == key:
return (c, p)
if key not in self.active_jobs:
logger.warn("checking status for unknown job: %s", key)
return (None, 0)
# TODO: prune here, maybe?
device = self.active_job[key]
context = self.context[device]
if context.finished.value is True:
self.finished.append((key, context.progress.value, context.cancel.value))
return (context.finished.value, context.progress.value)
context = self.get_job_context(key)
return (False, context.progress.value)
def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int:
# respect overrides if possible
@@ -132,6 +164,9 @@ class DevicePoolExecutor:
return lowest_devices[0]
def join(self):
self.started_thread.join(self.join_timeout)
self.finished_thread.join(self.join_timeout)
for device, worker in self.workers.items():
if worker.is_alive():
logger.info("stopping worker for device %s", device)
@@ -166,11 +201,11 @@ class DevicePoolExecutor:
needs_device: Optional[DeviceParams] = None,
**kwargs,
) -> None:
self.finished_jobs += 1
logger.debug("pool job count: %s", self.finished_jobs)
if self.finished_jobs > self.max_jobs_per_worker:
self.total_jobs += 1
logger.debug("pool job count: %s", self.total_jobs)
if self.total_jobs > self.max_jobs_per_worker:
self.recycle()
self.finished_jobs = 0
self.total_jobs = 0
device_idx = self.get_next_device(needs_device=needs_device)
logger.info(
@@ -184,7 +219,7 @@ class DevicePoolExecutor:
queue = self.pending[device.device]
queue.put((fn, args, kwargs))
self.active_job[key] = device.device
self.active_jobs[key] = device.device
def status(self) -> List[Tuple[str, int, bool, int]]:
pending = [
@@ -193,10 +228,22 @@ class DevicePoolExecutor:
self.workers[name].is_alive(),
context.pending.qsize(),
context.cancel.value,
context.finished.value,
False,
context.progress.value,
)
for name, context in self.context.items()
]
pending.extend(self.finished)
pending.extend(
[
(
name,
False,
0,
cancel,
True,
progress,
)
for name, progress, cancel in self.finished_jobs
]
)
return pending
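On the pool side, the two queues are drained by long-running threads (create_queue_workers above): the started watcher records which device a job landed on, and the finished watcher snapshots the job's final progress and cancel state into finished_jobs so done() can answer for jobs that are no longer active. A self-contained sketch of that consumer pattern, reduced to plain dictionaries and daemon threads; the real pool joins its threads with a timeout instead of marking them daemon, and reads progress/cancel from the worker context.

from threading import Thread
from typing import Dict, List, Tuple

from torch.multiprocessing import Queue

started: Queue = Queue()
finished: Queue = Queue()
active_jobs: Dict[str, str] = {}  # job -> device, matching how submit() and cancel() index it
finished_jobs: List[Tuple[str, int, bool]] = []  # (job, progress, cancelled)


def started_worker(queue: Queue) -> None:
    while True:
        job, device = queue.get()
        active_jobs[job] = device


def finished_worker(queue: Queue) -> None:
    while True:
        job, _device = queue.get()
        # the real pool looks up the worker context here to capture progress and cancel
        finished_jobs.append((job, 0, False))


Thread(target=started_worker, args=(started,), daemon=True).start()
Thread(target=finished_worker, args=(finished,), daemon=True).start()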

View File

@@ -32,12 +32,14 @@ def worker_init(context: WorkerContext, server: ServerContext):
while True:
job = context.pending.get()
logger.info("got job: %s", job)
try:
fn, args, kwargs = job
name = args[3][0]
logger.info("starting job: %s", name)
fn, args, kwargs = job
name = args[3][0]
try:
context.clear_flags()
logger.info("starting job: %s", name)
context.put_started(name)
fn(context, *args, **kwargs)
logger.info("job succeeded: %s", name)
except Exception as e:
@@ -46,5 +48,5 @@ def worker_init(context: WorkerContext, server: ServerContext):
format_exception(type(e), e, e.__traceback__),
)
finally:
context.set_finished()
context.put_finished(name)
logger.info("finished job: %s", name)
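Putting the pieces together, the worker loop now brackets every job with lifecycle events: clear the shared flags, announce the start, run the job, and always report completion from the finally block, even when the function raises. A condensed sketch of that loop, assuming the pending queue carries (fn, args, kwargs) tuples and that args[3][0] is the job name, as in the diff; the error message text is abbreviated.

from logging import getLogger
from traceback import format_exception

logger = getLogger(__name__)


def worker_loop(context) -> None:
    while True:
        fn, args, kwargs = context.pending.get()
        name = args[3][0]  # job name, following the convention shown above
        try:
            context.clear_flags()
            logger.info("starting job: %s", name)
            context.put_started(name)
            fn(context, *args, **kwargs)
            logger.info("job succeeded: %s", name)
        except Exception as e:
            logger.error(
                "error while running job: %s",
                format_exception(type(e), e, e.__traceback__),
            )
        finally:
            # always report completion so the pool records the final progress and cancel state
            context.put_finished(name)
            logger.info("finished job: %s", name)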