track started and finished jobs
This commit is contained in:
parent
eb82e73e59
commit
525ee24e91
|
@ -3,9 +3,8 @@ from typing import List, Optional
|
|||
|
||||
from PIL import Image
|
||||
|
||||
from onnx_web.image import valid_image
|
||||
from onnx_web.output import save_image
|
||||
|
||||
from ..image import valid_image
|
||||
from ..output import save_image
|
||||
from ..params import ImageParams, StageParams
|
||||
from ..server import ServerContext
|
||||
from ..utils import is_debug
|
||||
|
|
|
@ -6,10 +6,8 @@ import torch
|
|||
from diffusers import OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionPipeline
|
||||
from PIL import Image
|
||||
|
||||
from onnx_web.chain import blend_mask
|
||||
from onnx_web.chain.base import ChainProgress
|
||||
|
||||
from ..chain import upscale_outpaint
|
||||
from ..chain import blend_mask, upscale_outpaint
|
||||
from ..chain.base import ChainProgress
|
||||
from ..output import save_image, save_params
|
||||
from ..params import Border, ImageParams, Size, StageParams, UpscaleParams
|
||||
from ..server import ServerContext
|
||||
|
|
|
@ -118,35 +118,42 @@ def load_models(context: ServerContext) -> None:
|
|||
)
|
||||
diffusion_models = list(set(diffusion_models))
|
||||
diffusion_models.sort()
|
||||
logger.debug("loaded diffusion models from disk: %s", diffusion_models)
|
||||
|
||||
correction_models = [
|
||||
get_model_name(f) for f in glob(path.join(context.model_path, "correction-*"))
|
||||
]
|
||||
correction_models = list(set(correction_models))
|
||||
correction_models.sort()
|
||||
logger.debug("loaded correction models from disk: %s", correction_models)
|
||||
|
||||
inversion_models = [
|
||||
get_model_name(f) for f in glob(path.join(context.model_path, "inversion-*"))
|
||||
]
|
||||
inversion_models = list(set(inversion_models))
|
||||
inversion_models.sort()
|
||||
logger.debug("loaded inversion models from disk: %s", inversion_models)
|
||||
|
||||
upscaling_models = [
|
||||
get_model_name(f) for f in glob(path.join(context.model_path, "upscaling-*"))
|
||||
]
|
||||
upscaling_models = list(set(upscaling_models))
|
||||
upscaling_models.sort()
|
||||
logger.debug("loaded upscaling models from disk: %s", upscaling_models)
|
||||
|
||||
|
||||
def load_params(context: ServerContext) -> None:
|
||||
global config_params
|
||||
|
||||
params_file = path.join(context.params_path, "params.json")
|
||||
logger.debug("loading server parameters from file: %s", params_file)
|
||||
|
||||
with open(params_file, "r") as f:
|
||||
config_params = yaml.safe_load(f)
|
||||
|
||||
if "platform" in config_params and context.default_platform is not None:
|
||||
logger.info(
|
||||
"Overriding default platform from environment: %s",
|
||||
"overriding default platform from environment: %s",
|
||||
context.default_platform,
|
||||
)
|
||||
config_platform = config_params.get("platform", {})
|
||||
|
@ -157,6 +164,7 @@ def load_platforms(context: ServerContext) -> None:
|
|||
global available_platforms
|
||||
|
||||
providers = list(get_available_providers())
|
||||
logger.debug("loading available platforms from providers: %s", providers)
|
||||
|
||||
for potential in platform_providers:
|
||||
if (
|
||||
|
|
|
@ -4,9 +4,8 @@ from typing import Callable, Dict, List, Tuple
|
|||
|
||||
from flask import Flask
|
||||
|
||||
from onnx_web.utils import base_join
|
||||
from onnx_web.worker.pool import DevicePoolExecutor
|
||||
|
||||
from ..utils import base_join
|
||||
from ..worker.pool import DevicePoolExecutor
|
||||
from .context import ServerContext
|
||||
|
||||
|
||||
|
@ -28,7 +27,8 @@ def register_routes(
|
|||
pool: DevicePoolExecutor,
|
||||
routes: List[Tuple[str, Dict, Callable]],
|
||||
):
|
||||
pass
|
||||
for route, kwargs, method in routes:
|
||||
app.route(route, **kwargs)(wrap_route(method, context, pool=pool))
|
||||
|
||||
|
||||
def wrap_route(func, *args, **kwargs):
|
||||
|
|
|
@ -22,18 +22,20 @@ class WorkerContext:
|
|||
key: str,
|
||||
device: DeviceParams,
|
||||
cancel: "Value[bool]" = None,
|
||||
finished: "Value[bool]" = None,
|
||||
progress: "Value[int]" = None,
|
||||
finished: "Queue[str]" = None,
|
||||
logs: "Queue[str]" = None,
|
||||
pending: "Queue[Any]" = None,
|
||||
started: "Queue[Tuple[str, str]]" = None,
|
||||
):
|
||||
self.key = key
|
||||
self.cancel = cancel
|
||||
self.device = device
|
||||
self.pending = pending
|
||||
self.cancel = cancel
|
||||
self.progress = progress
|
||||
self.logs = logs
|
||||
self.finished = finished
|
||||
self.logs = logs
|
||||
self.pending = pending
|
||||
self.started = started
|
||||
|
||||
# Return the current value of this worker's `cancel` flag.
# The flag is read directly, without taking the lock that set_cancel() holds
# while writing it.
def is_cancelled(self) -> bool:
|
||||
return self.cancel.value
|
||||
|
@ -62,15 +64,16 @@ class WorkerContext:
|
|||
with self.cancel.get_lock():
|
||||
self.cancel.value = cancel
|
||||
|
||||
def set_finished(self, finished: bool = True) -> None:
|
||||
with self.finished.get_lock():
|
||||
self.finished.value = finished
|
||||
|
||||
# Store `progress` in this worker's `progress` value.
# The write happens while holding the lock returned by progress.get_lock().
def set_progress(self, progress: int) -> None:
|
||||
with self.progress.get_lock():
|
||||
self.progress.value = progress
|
||||
|
||||
# Announce that `job` is done by putting a (job, self.device.device) tuple on
# the `finished` queue.
# NOTE(review): the constructor annotates `finished` as "Queue[str]", but the
# item enqueued here is a 2-tuple - the annotation looks stale; confirm.
def put_finished(self, job: str) -> None:
|
||||
self.finished.put((job, self.device.device))
|
||||
|
||||
# Announce that `job` has begun by putting a (job, self.device.device) tuple on
# the `started` queue.
def put_started(self, job: str) -> None:
|
||||
self.started.put((job, self.device.device))
|
||||
|
||||
def clear_flags(self) -> None:
|
||||
self.set_cancel(False)
|
||||
self.set_finished(False)
|
||||
self.set_progress(0)
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
from collections import Counter
|
||||
from logging import getLogger
|
||||
from multiprocessing import Queue
|
||||
from threading import Thread
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from torch.multiprocessing import Process, Value
|
||||
from torch.multiprocessing import Process, Queue, Value
|
||||
|
||||
from ..params import DeviceParams
|
||||
from ..server import ServerContext
|
||||
|
@ -14,12 +14,12 @@ logger = getLogger(__name__)
|
|||
|
||||
|
||||
class DevicePoolExecutor:
|
||||
context: Dict[str, WorkerContext] = None
|
||||
context: Dict[str, WorkerContext] = None # Device -> Context
|
||||
devices: List[DeviceParams] = None
|
||||
pending: Dict[str, "Queue[WorkerContext]"] = None
|
||||
workers: Dict[str, Process] = None
|
||||
active_job: Dict[str, str] = None
|
||||
finished: List[Tuple[str, int, bool]] = None
|
||||
active_jobs: Dict[str, str] = None
|
||||
finished_jobs: List[Tuple[str, int, bool]] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -36,11 +36,15 @@ class DevicePoolExecutor:
|
|||
self.context = {}
|
||||
self.pending = {}
|
||||
self.workers = {}
|
||||
self.active_job = {}
|
||||
self.finished = []
|
||||
self.finished_jobs = 0 # TODO: turn this into a Dict per-worker
|
||||
self.active_jobs = {}
|
||||
self.finished_jobs = []
|
||||
self.total_jobs = 0 # TODO: turn this into a Dict per-worker
|
||||
|
||||
self.started = Queue()
|
||||
self.finished = Queue()
|
||||
|
||||
self.create_logger_worker()
|
||||
self.create_queue_workers()
|
||||
for device in devices:
|
||||
self.create_device_worker(device)
|
||||
|
||||
|
@ -56,16 +60,23 @@ class DevicePoolExecutor:
|
|||
|
||||
def create_device_worker(self, device: DeviceParams) -> None:
|
||||
name = device.device
|
||||
pending = Queue()
|
||||
self.pending[name] = pending
|
||||
|
||||
# reuse the queue if possible, to keep queued jobs
|
||||
if name in self.pending:
|
||||
pending = self.pending[name]
|
||||
else:
|
||||
pending = Queue()
|
||||
self.pending[name] = pending
|
||||
|
||||
context = WorkerContext(
|
||||
name,
|
||||
device,
|
||||
cancel=Value("B", False),
|
||||
finished=Value("B", False),
|
||||
progress=Value("I", 0),
|
||||
pending=pending,
|
||||
finished=self.finished,
|
||||
logs=self.log_queue,
|
||||
pending=pending,
|
||||
started=self.started,
|
||||
)
|
||||
self.context[name] = context
|
||||
self.workers[name] = Process(target=worker_init, args=(context, self.server))
|
||||
|
@ -73,9 +84,32 @@ class DevicePoolExecutor:
|
|||
logger.debug("starting worker for device %s", device)
|
||||
self.workers[name].start()
|
||||
|
||||
def create_prune_worker(self) -> None:
|
||||
# TODO: create a background thread to prune completed jobs
|
||||
pass
|
||||
def create_queue_workers(self) -> None:
    """
    Start the two background threads that track job state for the pool.

    One thread drains ``self.started`` and the other drains
    ``self.finished``. Both queues carry ``(job, device)`` tuples, and both
    threads loop forever on a blocking ``get()``.
    """

    def started_worker(pending: Queue):
        logger.info("checking in from started thread")
        while True:
            job, device = pending.get()
            logger.info("job has been started: %s", job)
            # active_jobs maps job -> device: get_job_context() and cancel()
            # both look it up by job key, so it must be written the same
            # way here. Writing device -> job would add entries that no
            # lookup can resolve to a worker context.
            self.active_jobs[job] = device

    def finished_worker(finished: Queue):
        logger.info("checking in from finished thread")
        while True:
            job, device = finished.get()
            logger.info("job has been finished: %s", job)
            context = self.get_job_context(job)
            # record the progress and cancel flags as they are when the
            # message is handled
            self.finished_jobs.append(
                (job, context.progress.value, context.cancel.value)
            )

    self.started_thread = Thread(target=started_worker, args=(self.started,))
    self.started_thread.start()
    self.finished_thread = Thread(target=finished_worker, args=(self.finished,))
    self.finished_thread.start()
|
||||
|
||||
# Resolve the worker context for a job: `key` is looked up in active_jobs to
# get a device name, and that name is then looked up in self.context.
# Neither lookup has a fallback, so an unknown job key or device name raises
# KeyError.
def get_job_context(self, key: str) -> WorkerContext:
|
||||
device = self.active_jobs[key]
|
||||
return self.context[device]
|
||||
|
||||
def cancel(self, key: str) -> bool:
|
||||
"""
|
||||
|
@ -83,11 +117,11 @@ class DevicePoolExecutor:
|
|||
the future and never execute it. If the job has been started, it
|
||||
should be cancelled on the next progress callback.
|
||||
"""
|
||||
if key not in self.active_job:
|
||||
if key not in self.active_jobs:
|
||||
logger.warn("attempting to cancel unknown job: %s", key)
|
||||
return False
|
||||
|
||||
device = self.active_job[key]
|
||||
device = self.active_jobs[key]
|
||||
context = self.context[device]
|
||||
logger.info("cancelling job %s on device %s", key, device)
|
||||
|
||||
|
@ -98,19 +132,17 @@ class DevicePoolExecutor:
|
|||
return True
|
||||
|
||||
def done(self, key: str) -> Tuple[Optional[bool], int]:
|
||||
if key not in self.active_job:
|
||||
for k, p, c in self.finished_jobs:
|
||||
if k == key:
|
||||
return (c, p)
|
||||
|
||||
if key not in self.active_jobs:
|
||||
logger.warn("checking status for unknown job: %s", key)
|
||||
return (None, 0)
|
||||
|
||||
# TODO: prune here, maybe?
|
||||
|
||||
device = self.active_job[key]
|
||||
context = self.context[device]
|
||||
|
||||
if context.finished.value is True:
|
||||
self.finished.append((key, context.progress.value, context.cancel.value))
|
||||
|
||||
return (context.finished.value, context.progress.value)
|
||||
context = self.get_job_context(key)
|
||||
return (False, context.progress.value)
|
||||
|
||||
def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int:
|
||||
# respect overrides if possible
|
||||
|
@ -132,6 +164,9 @@ class DevicePoolExecutor:
|
|||
return lowest_devices[0]
|
||||
|
||||
def join(self):
|
||||
self.started_thread.join(self.join_timeout)
|
||||
self.finished_thread.join(self.join_timeout)
|
||||
|
||||
for device, worker in self.workers.items():
|
||||
if worker.is_alive():
|
||||
logger.info("stopping worker for device %s", device)
|
||||
|
@ -166,11 +201,11 @@ class DevicePoolExecutor:
|
|||
needs_device: Optional[DeviceParams] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self.finished_jobs += 1
|
||||
logger.debug("pool job count: %s", self.finished_jobs)
|
||||
if self.finished_jobs > self.max_jobs_per_worker:
|
||||
self.total_jobs += 1
|
||||
logger.debug("pool job count: %s", self.total_jobs)
|
||||
if self.total_jobs > self.max_jobs_per_worker:
|
||||
self.recycle()
|
||||
self.finished_jobs = 0
|
||||
self.total_jobs = 0
|
||||
|
||||
device_idx = self.get_next_device(needs_device=needs_device)
|
||||
logger.info(
|
||||
|
@ -184,7 +219,7 @@ class DevicePoolExecutor:
|
|||
queue = self.pending[device.device]
|
||||
queue.put((fn, args, kwargs))
|
||||
|
||||
self.active_job[key] = device.device
|
||||
self.active_jobs[key] = device.device
|
||||
|
||||
def status(self) -> List[Tuple[str, int, bool, int]]:
|
||||
pending = [
|
||||
|
@ -193,10 +228,22 @@ class DevicePoolExecutor:
|
|||
self.workers[name].is_alive(),
|
||||
context.pending.qsize(),
|
||||
context.cancel.value,
|
||||
context.finished.value,
|
||||
False,
|
||||
context.progress.value,
|
||||
)
|
||||
for name, context in self.context.items()
|
||||
]
|
||||
pending.extend(self.finished)
|
||||
pending.extend(
|
||||
[
|
||||
(
|
||||
name,
|
||||
False,
|
||||
0,
|
||||
cancel,
|
||||
True,
|
||||
progress,
|
||||
)
|
||||
for name, progress, cancel in self.finished_jobs
|
||||
]
|
||||
)
|
||||
return pending
|
||||
|
|
|
@ -32,12 +32,14 @@ def worker_init(context: WorkerContext, server: ServerContext):
|
|||
while True:
|
||||
job = context.pending.get()
|
||||
logger.info("got job: %s", job)
|
||||
try:
|
||||
fn, args, kwargs = job
|
||||
name = args[3][0]
|
||||
|
||||
logger.info("starting job: %s", name)
|
||||
fn, args, kwargs = job
|
||||
name = args[3][0]
|
||||
|
||||
try:
|
||||
context.clear_flags()
|
||||
logger.info("starting job: %s", name)
|
||||
context.put_started(name)
|
||||
fn(context, *args, **kwargs)
|
||||
logger.info("job succeeded: %s", name)
|
||||
except Exception as e:
|
||||
|
@ -46,5 +48,5 @@ def worker_init(context: WorkerContext, server: ServerContext):
|
|||
format_exception(type(e), e, e.__traceback__),
|
||||
)
|
||||
finally:
|
||||
context.set_finished()
|
||||
context.put_finished(name)
|
||||
logger.info("finished job: %s", name)
|
||||
|
|
Loading…
Reference in New Issue