1
0
Fork 0

track started and finished jobs

This commit is contained in:
Sean Sube 2023-02-26 20:09:42 -06:00
parent eb82e73e59
commit 525ee24e91
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
7 changed files with 116 additions and 59 deletions

View File

@ -3,9 +3,8 @@ from typing import List, Optional
from PIL import Image from PIL import Image
from onnx_web.image import valid_image from ..image import valid_image
from onnx_web.output import save_image from ..output import save_image
from ..params import ImageParams, StageParams from ..params import ImageParams, StageParams
from ..server import ServerContext from ..server import ServerContext
from ..utils import is_debug from ..utils import is_debug

View File

@ -6,10 +6,8 @@ import torch
from diffusers import OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionPipeline from diffusers import OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionPipeline
from PIL import Image from PIL import Image
from onnx_web.chain import blend_mask from ..chain import blend_mask, upscale_outpaint
from onnx_web.chain.base import ChainProgress from ..chain.base import ChainProgress
from ..chain import upscale_outpaint
from ..output import save_image, save_params from ..output import save_image, save_params
from ..params import Border, ImageParams, Size, StageParams, UpscaleParams from ..params import Border, ImageParams, Size, StageParams, UpscaleParams
from ..server import ServerContext from ..server import ServerContext

View File

@ -118,35 +118,42 @@ def load_models(context: ServerContext) -> None:
) )
diffusion_models = list(set(diffusion_models)) diffusion_models = list(set(diffusion_models))
diffusion_models.sort() diffusion_models.sort()
logger.debug("loaded diffusion models from disk: %s", diffusion_models)
correction_models = [ correction_models = [
get_model_name(f) for f in glob(path.join(context.model_path, "correction-*")) get_model_name(f) for f in glob(path.join(context.model_path, "correction-*"))
] ]
correction_models = list(set(correction_models)) correction_models = list(set(correction_models))
correction_models.sort() correction_models.sort()
logger.debug("loaded correction models from disk: %s", correction_models)
inversion_models = [ inversion_models = [
get_model_name(f) for f in glob(path.join(context.model_path, "inversion-*")) get_model_name(f) for f in glob(path.join(context.model_path, "inversion-*"))
] ]
inversion_models = list(set(inversion_models)) inversion_models = list(set(inversion_models))
inversion_models.sort() inversion_models.sort()
logger.debug("loaded inversion models from disk: %s", inversion_models)
upscaling_models = [ upscaling_models = [
get_model_name(f) for f in glob(path.join(context.model_path, "upscaling-*")) get_model_name(f) for f in glob(path.join(context.model_path, "upscaling-*"))
] ]
upscaling_models = list(set(upscaling_models)) upscaling_models = list(set(upscaling_models))
upscaling_models.sort() upscaling_models.sort()
logger.debug("loaded upscaling models from disk: %s", upscaling_models)
def load_params(context: ServerContext) -> None: def load_params(context: ServerContext) -> None:
global config_params global config_params
params_file = path.join(context.params_path, "params.json") params_file = path.join(context.params_path, "params.json")
logger.debug("loading server parameters from file: %s", params_file)
with open(params_file, "r") as f: with open(params_file, "r") as f:
config_params = yaml.safe_load(f) config_params = yaml.safe_load(f)
if "platform" in config_params and context.default_platform is not None: if "platform" in config_params and context.default_platform is not None:
logger.info( logger.info(
"Overriding default platform from environment: %s", "overriding default platform from environment: %s",
context.default_platform, context.default_platform,
) )
config_platform = config_params.get("platform", {}) config_platform = config_params.get("platform", {})
@ -157,6 +164,7 @@ def load_platforms(context: ServerContext) -> None:
global available_platforms global available_platforms
providers = list(get_available_providers()) providers = list(get_available_providers())
logger.debug("loading available platforms from providers: %s", providers)
for potential in platform_providers: for potential in platform_providers:
if ( if (

View File

@ -4,9 +4,8 @@ from typing import Callable, Dict, List, Tuple
from flask import Flask from flask import Flask
from onnx_web.utils import base_join from ..utils import base_join
from onnx_web.worker.pool import DevicePoolExecutor from ..worker.pool import DevicePoolExecutor
from .context import ServerContext from .context import ServerContext
@ -28,7 +27,8 @@ def register_routes(
pool: DevicePoolExecutor, pool: DevicePoolExecutor,
routes: List[Tuple[str, Dict, Callable]], routes: List[Tuple[str, Dict, Callable]],
): ):
pass for route, kwargs, method in routes:
app.route(route, **kwargs)(wrap_route(method, context, pool=pool))
def wrap_route(func, *args, **kwargs): def wrap_route(func, *args, **kwargs):

View File

@ -22,18 +22,20 @@ class WorkerContext:
key: str, key: str,
device: DeviceParams, device: DeviceParams,
cancel: "Value[bool]" = None, cancel: "Value[bool]" = None,
finished: "Value[bool]" = None,
progress: "Value[int]" = None, progress: "Value[int]" = None,
finished: "Queue[str]" = None,
logs: "Queue[str]" = None, logs: "Queue[str]" = None,
pending: "Queue[Any]" = None, pending: "Queue[Any]" = None,
started: "Queue[Tuple[str, str]]" = None,
): ):
self.key = key self.key = key
self.cancel = cancel
self.device = device self.device = device
self.pending = pending self.cancel = cancel
self.progress = progress self.progress = progress
self.logs = logs
self.finished = finished self.finished = finished
self.logs = logs
self.pending = pending
self.started = started
def is_cancelled(self) -> bool: def is_cancelled(self) -> bool:
return self.cancel.value return self.cancel.value
@ -62,15 +64,16 @@ class WorkerContext:
with self.cancel.get_lock(): with self.cancel.get_lock():
self.cancel.value = cancel self.cancel.value = cancel
def set_finished(self, finished: bool = True) -> None:
with self.finished.get_lock():
self.finished.value = finished
def set_progress(self, progress: int) -> None: def set_progress(self, progress: int) -> None:
with self.progress.get_lock(): with self.progress.get_lock():
self.progress.value = progress self.progress.value = progress
def put_finished(self, job: str) -> None:
self.finished.put((job, self.device.device))
def put_started(self, job: str) -> None:
self.started.put((job, self.device.device))
def clear_flags(self) -> None: def clear_flags(self) -> None:
self.set_cancel(False) self.set_cancel(False)
self.set_finished(False)
self.set_progress(0) self.set_progress(0)

View File

@ -1,9 +1,9 @@
from collections import Counter from collections import Counter
from logging import getLogger from logging import getLogger
from multiprocessing import Queue from threading import Thread
from typing import Callable, Dict, List, Optional, Tuple from typing import Callable, Dict, List, Optional, Tuple
from torch.multiprocessing import Process, Value from torch.multiprocessing import Process, Queue, Value
from ..params import DeviceParams from ..params import DeviceParams
from ..server import ServerContext from ..server import ServerContext
@ -14,12 +14,12 @@ logger = getLogger(__name__)
class DevicePoolExecutor: class DevicePoolExecutor:
context: Dict[str, WorkerContext] = None context: Dict[str, WorkerContext] = None # Device -> Context
devices: List[DeviceParams] = None devices: List[DeviceParams] = None
pending: Dict[str, "Queue[WorkerContext]"] = None pending: Dict[str, "Queue[WorkerContext]"] = None
workers: Dict[str, Process] = None workers: Dict[str, Process] = None
active_job: Dict[str, str] = None active_jobs: Dict[str, str] = None
finished: List[Tuple[str, int, bool]] = None finished_jobs: List[Tuple[str, int, bool]] = None
def __init__( def __init__(
self, self,
@ -36,11 +36,15 @@ class DevicePoolExecutor:
self.context = {} self.context = {}
self.pending = {} self.pending = {}
self.workers = {} self.workers = {}
self.active_job = {} self.active_jobs = {}
self.finished = [] self.finished_jobs = []
self.finished_jobs = 0 # TODO: turn this into a Dict per-worker self.total_jobs = 0 # TODO: turn this into a Dict per-worker
self.started = Queue()
self.finished = Queue()
self.create_logger_worker() self.create_logger_worker()
self.create_queue_workers()
for device in devices: for device in devices:
self.create_device_worker(device) self.create_device_worker(device)
@ -56,16 +60,23 @@ class DevicePoolExecutor:
def create_device_worker(self, device: DeviceParams) -> None: def create_device_worker(self, device: DeviceParams) -> None:
name = device.device name = device.device
pending = Queue()
self.pending[name] = pending # reuse the queue if possible, to keep queued jobs
if name in self.pending:
pending = self.pending[name]
else:
pending = Queue()
self.pending[name] = pending
context = WorkerContext( context = WorkerContext(
name, name,
device, device,
cancel=Value("B", False), cancel=Value("B", False),
finished=Value("B", False),
progress=Value("I", 0), progress=Value("I", 0),
pending=pending, finished=self.finished,
logs=self.log_queue, logs=self.log_queue,
pending=pending,
started=self.started,
) )
self.context[name] = context self.context[name] = context
self.workers[name] = Process(target=worker_init, args=(context, self.server)) self.workers[name] = Process(target=worker_init, args=(context, self.server))
@ -73,9 +84,32 @@ class DevicePoolExecutor:
logger.debug("starting worker for device %s", device) logger.debug("starting worker for device %s", device)
self.workers[name].start() self.workers[name].start()
def create_prune_worker(self) -> None: def create_queue_workers(self) -> None:
# TODO: create a background thread to prune completed jobs def started_worker(pending: Queue):
pass logger.info("checking in from started thread")
while True:
job, device = pending.get()
logger.info("job has been started: %s", job)
self.active_jobs[device] = job
def finished_worker(finished: Queue):
logger.info("checking in from finished thread")
while True:
job, device = finished.get()
logger.info("job has been finished: %s", job)
context = self.get_job_context(job)
self.finished_jobs.append(
(job, context.progress.value, context.cancel.value)
)
self.started_thread = Thread(target=started_worker, args=(self.started,))
self.started_thread.start()
self.finished_thread = Thread(target=finished_worker, args=(self.finished,))
self.finished_thread.start()
def get_job_context(self, key: str) -> WorkerContext:
device = self.active_jobs[key]
return self.context[device]
def cancel(self, key: str) -> bool: def cancel(self, key: str) -> bool:
""" """
@ -83,11 +117,11 @@ class DevicePoolExecutor:
the future and never execute it. If the job has been started, it the future and never execute it. If the job has been started, it
should be cancelled on the next progress callback. should be cancelled on the next progress callback.
""" """
if key not in self.active_job: if key not in self.active_jobs:
logger.warn("attempting to cancel unknown job: %s", key) logger.warn("attempting to cancel unknown job: %s", key)
return False return False
device = self.active_job[key] device = self.active_jobs[key]
context = self.context[device] context = self.context[device]
logger.info("cancelling job %s on device %s", key, device) logger.info("cancelling job %s on device %s", key, device)
@ -98,19 +132,17 @@ class DevicePoolExecutor:
return True return True
def done(self, key: str) -> Tuple[Optional[bool], int]: def done(self, key: str) -> Tuple[Optional[bool], int]:
if key not in self.active_job: for k, p, c in self.finished_jobs:
if k == key:
return (c, p)
if key not in self.active_jobs:
logger.warn("checking status for unknown job: %s", key) logger.warn("checking status for unknown job: %s", key)
return (None, 0) return (None, 0)
# TODO: prune here, maybe? # TODO: prune here, maybe?
context = self.get_job_context(key)
device = self.active_job[key] return (False, context.progress.value)
context = self.context[device]
if context.finished.value is True:
self.finished.append((key, context.progress.value, context.cancel.value))
return (context.finished.value, context.progress.value)
def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int: def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int:
# respect overrides if possible # respect overrides if possible
@ -132,6 +164,9 @@ class DevicePoolExecutor:
return lowest_devices[0] return lowest_devices[0]
def join(self): def join(self):
self.started_thread.join(self.join_timeout)
self.finished_thread.join(self.join_timeout)
for device, worker in self.workers.items(): for device, worker in self.workers.items():
if worker.is_alive(): if worker.is_alive():
logger.info("stopping worker for device %s", device) logger.info("stopping worker for device %s", device)
@ -166,11 +201,11 @@ class DevicePoolExecutor:
needs_device: Optional[DeviceParams] = None, needs_device: Optional[DeviceParams] = None,
**kwargs, **kwargs,
) -> None: ) -> None:
self.finished_jobs += 1 self.total_jobs += 1
logger.debug("pool job count: %s", self.finished_jobs) logger.debug("pool job count: %s", self.total_jobs)
if self.finished_jobs > self.max_jobs_per_worker: if self.total_jobs > self.max_jobs_per_worker:
self.recycle() self.recycle()
self.finished_jobs = 0 self.total_jobs = 0
device_idx = self.get_next_device(needs_device=needs_device) device_idx = self.get_next_device(needs_device=needs_device)
logger.info( logger.info(
@ -184,7 +219,7 @@ class DevicePoolExecutor:
queue = self.pending[device.device] queue = self.pending[device.device]
queue.put((fn, args, kwargs)) queue.put((fn, args, kwargs))
self.active_job[key] = device.device self.active_jobs[key] = device.device
def status(self) -> List[Tuple[str, int, bool, int]]: def status(self) -> List[Tuple[str, int, bool, int]]:
pending = [ pending = [
@ -193,10 +228,22 @@ class DevicePoolExecutor:
self.workers[name].is_alive(), self.workers[name].is_alive(),
context.pending.qsize(), context.pending.qsize(),
context.cancel.value, context.cancel.value,
context.finished.value, False,
context.progress.value, context.progress.value,
) )
for name, context in self.context.items() for name, context in self.context.items()
] ]
pending.extend(self.finished) pending.extend(
[
(
name,
False,
0,
cancel,
True,
progress,
)
for name, progress, cancel in self.finished_jobs
]
)
return pending return pending

View File

@ -32,12 +32,14 @@ def worker_init(context: WorkerContext, server: ServerContext):
while True: while True:
job = context.pending.get() job = context.pending.get()
logger.info("got job: %s", job) logger.info("got job: %s", job)
try:
fn, args, kwargs = job
name = args[3][0]
logger.info("starting job: %s", name) fn, args, kwargs = job
name = args[3][0]
try:
context.clear_flags() context.clear_flags()
logger.info("starting job: %s", name)
context.put_started(name)
fn(context, *args, **kwargs) fn(context, *args, **kwargs)
logger.info("job succeeded: %s", name) logger.info("job succeeded: %s", name)
except Exception as e: except Exception as e:
@ -46,5 +48,5 @@ def worker_init(context: WorkerContext, server: ServerContext):
format_exception(type(e), e, e.__traceback__), format_exception(type(e), e, e.__traceback__),
) )
finally: finally:
context.set_finished() context.put_finished(name)
logger.info("finished job: %s", name) logger.info("finished job: %s", name)