track started and finished jobs
This commit is contained in:
parent
eb82e73e59
commit
525ee24e91
|
@ -3,9 +3,8 @@ from typing import List, Optional
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from onnx_web.image import valid_image
|
from ..image import valid_image
|
||||||
from onnx_web.output import save_image
|
from ..output import save_image
|
||||||
|
|
||||||
from ..params import ImageParams, StageParams
|
from ..params import ImageParams, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..utils import is_debug
|
from ..utils import is_debug
|
||||||
|
|
|
@ -6,10 +6,8 @@ import torch
|
||||||
from diffusers import OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionPipeline
|
from diffusers import OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionPipeline
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from onnx_web.chain import blend_mask
|
from ..chain import blend_mask, upscale_outpaint
|
||||||
from onnx_web.chain.base import ChainProgress
|
from ..chain.base import ChainProgress
|
||||||
|
|
||||||
from ..chain import upscale_outpaint
|
|
||||||
from ..output import save_image, save_params
|
from ..output import save_image, save_params
|
||||||
from ..params import Border, ImageParams, Size, StageParams, UpscaleParams
|
from ..params import Border, ImageParams, Size, StageParams, UpscaleParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
|
|
|
@ -118,35 +118,42 @@ def load_models(context: ServerContext) -> None:
|
||||||
)
|
)
|
||||||
diffusion_models = list(set(diffusion_models))
|
diffusion_models = list(set(diffusion_models))
|
||||||
diffusion_models.sort()
|
diffusion_models.sort()
|
||||||
|
logger.debug("loaded diffusion models from disk: %s", diffusion_models)
|
||||||
|
|
||||||
correction_models = [
|
correction_models = [
|
||||||
get_model_name(f) for f in glob(path.join(context.model_path, "correction-*"))
|
get_model_name(f) for f in glob(path.join(context.model_path, "correction-*"))
|
||||||
]
|
]
|
||||||
correction_models = list(set(correction_models))
|
correction_models = list(set(correction_models))
|
||||||
correction_models.sort()
|
correction_models.sort()
|
||||||
|
logger.debug("loaded correction models from disk: %s", correction_models)
|
||||||
|
|
||||||
inversion_models = [
|
inversion_models = [
|
||||||
get_model_name(f) for f in glob(path.join(context.model_path, "inversion-*"))
|
get_model_name(f) for f in glob(path.join(context.model_path, "inversion-*"))
|
||||||
]
|
]
|
||||||
inversion_models = list(set(inversion_models))
|
inversion_models = list(set(inversion_models))
|
||||||
inversion_models.sort()
|
inversion_models.sort()
|
||||||
|
logger.debug("loaded inversion models from disk: %s", inversion_models)
|
||||||
|
|
||||||
upscaling_models = [
|
upscaling_models = [
|
||||||
get_model_name(f) for f in glob(path.join(context.model_path, "upscaling-*"))
|
get_model_name(f) for f in glob(path.join(context.model_path, "upscaling-*"))
|
||||||
]
|
]
|
||||||
upscaling_models = list(set(upscaling_models))
|
upscaling_models = list(set(upscaling_models))
|
||||||
upscaling_models.sort()
|
upscaling_models.sort()
|
||||||
|
logger.debug("loaded upscaling models from disk: %s", upscaling_models)
|
||||||
|
|
||||||
|
|
||||||
def load_params(context: ServerContext) -> None:
|
def load_params(context: ServerContext) -> None:
|
||||||
global config_params
|
global config_params
|
||||||
|
|
||||||
params_file = path.join(context.params_path, "params.json")
|
params_file = path.join(context.params_path, "params.json")
|
||||||
|
logger.debug("loading server parameters from file: %s", params_file)
|
||||||
|
|
||||||
with open(params_file, "r") as f:
|
with open(params_file, "r") as f:
|
||||||
config_params = yaml.safe_load(f)
|
config_params = yaml.safe_load(f)
|
||||||
|
|
||||||
if "platform" in config_params and context.default_platform is not None:
|
if "platform" in config_params and context.default_platform is not None:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Overriding default platform from environment: %s",
|
"overriding default platform from environment: %s",
|
||||||
context.default_platform,
|
context.default_platform,
|
||||||
)
|
)
|
||||||
config_platform = config_params.get("platform", {})
|
config_platform = config_params.get("platform", {})
|
||||||
|
@ -157,6 +164,7 @@ def load_platforms(context: ServerContext) -> None:
|
||||||
global available_platforms
|
global available_platforms
|
||||||
|
|
||||||
providers = list(get_available_providers())
|
providers = list(get_available_providers())
|
||||||
|
logger.debug("loading available platforms from providers: %s", providers)
|
||||||
|
|
||||||
for potential in platform_providers:
|
for potential in platform_providers:
|
||||||
if (
|
if (
|
||||||
|
|
|
@ -4,9 +4,8 @@ from typing import Callable, Dict, List, Tuple
|
||||||
|
|
||||||
from flask import Flask
|
from flask import Flask
|
||||||
|
|
||||||
from onnx_web.utils import base_join
|
from ..utils import base_join
|
||||||
from onnx_web.worker.pool import DevicePoolExecutor
|
from ..worker.pool import DevicePoolExecutor
|
||||||
|
|
||||||
from .context import ServerContext
|
from .context import ServerContext
|
||||||
|
|
||||||
|
|
||||||
|
@ -28,7 +27,8 @@ def register_routes(
|
||||||
pool: DevicePoolExecutor,
|
pool: DevicePoolExecutor,
|
||||||
routes: List[Tuple[str, Dict, Callable]],
|
routes: List[Tuple[str, Dict, Callable]],
|
||||||
):
|
):
|
||||||
pass
|
for route, kwargs, method in routes:
|
||||||
|
app.route(route, **kwargs)(wrap_route(method, context, pool=pool))
|
||||||
|
|
||||||
|
|
||||||
def wrap_route(func, *args, **kwargs):
|
def wrap_route(func, *args, **kwargs):
|
||||||
|
|
|
@ -22,18 +22,20 @@ class WorkerContext:
|
||||||
key: str,
|
key: str,
|
||||||
device: DeviceParams,
|
device: DeviceParams,
|
||||||
cancel: "Value[bool]" = None,
|
cancel: "Value[bool]" = None,
|
||||||
finished: "Value[bool]" = None,
|
|
||||||
progress: "Value[int]" = None,
|
progress: "Value[int]" = None,
|
||||||
|
finished: "Queue[str]" = None,
|
||||||
logs: "Queue[str]" = None,
|
logs: "Queue[str]" = None,
|
||||||
pending: "Queue[Any]" = None,
|
pending: "Queue[Any]" = None,
|
||||||
|
started: "Queue[Tuple[str, str]]" = None,
|
||||||
):
|
):
|
||||||
self.key = key
|
self.key = key
|
||||||
self.cancel = cancel
|
|
||||||
self.device = device
|
self.device = device
|
||||||
self.pending = pending
|
self.cancel = cancel
|
||||||
self.progress = progress
|
self.progress = progress
|
||||||
self.logs = logs
|
|
||||||
self.finished = finished
|
self.finished = finished
|
||||||
|
self.logs = logs
|
||||||
|
self.pending = pending
|
||||||
|
self.started = started
|
||||||
|
|
||||||
def is_cancelled(self) -> bool:
|
def is_cancelled(self) -> bool:
|
||||||
return self.cancel.value
|
return self.cancel.value
|
||||||
|
@ -62,15 +64,16 @@ class WorkerContext:
|
||||||
with self.cancel.get_lock():
|
with self.cancel.get_lock():
|
||||||
self.cancel.value = cancel
|
self.cancel.value = cancel
|
||||||
|
|
||||||
def set_finished(self, finished: bool = True) -> None:
|
|
||||||
with self.finished.get_lock():
|
|
||||||
self.finished.value = finished
|
|
||||||
|
|
||||||
def set_progress(self, progress: int) -> None:
|
def set_progress(self, progress: int) -> None:
|
||||||
with self.progress.get_lock():
|
with self.progress.get_lock():
|
||||||
self.progress.value = progress
|
self.progress.value = progress
|
||||||
|
|
||||||
|
def put_finished(self, job: str) -> None:
|
||||||
|
self.finished.put((job, self.device.device))
|
||||||
|
|
||||||
|
def put_started(self, job: str) -> None:
|
||||||
|
self.started.put((job, self.device.device))
|
||||||
|
|
||||||
def clear_flags(self) -> None:
|
def clear_flags(self) -> None:
|
||||||
self.set_cancel(False)
|
self.set_cancel(False)
|
||||||
self.set_finished(False)
|
|
||||||
self.set_progress(0)
|
self.set_progress(0)
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from multiprocessing import Queue
|
from threading import Thread
|
||||||
from typing import Callable, Dict, List, Optional, Tuple
|
from typing import Callable, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from torch.multiprocessing import Process, Value
|
from torch.multiprocessing import Process, Queue, Value
|
||||||
|
|
||||||
from ..params import DeviceParams
|
from ..params import DeviceParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
|
@ -14,12 +14,12 @@ logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class DevicePoolExecutor:
|
class DevicePoolExecutor:
|
||||||
context: Dict[str, WorkerContext] = None
|
context: Dict[str, WorkerContext] = None # Device -> Context
|
||||||
devices: List[DeviceParams] = None
|
devices: List[DeviceParams] = None
|
||||||
pending: Dict[str, "Queue[WorkerContext]"] = None
|
pending: Dict[str, "Queue[WorkerContext]"] = None
|
||||||
workers: Dict[str, Process] = None
|
workers: Dict[str, Process] = None
|
||||||
active_job: Dict[str, str] = None
|
active_jobs: Dict[str, str] = None
|
||||||
finished: List[Tuple[str, int, bool]] = None
|
finished_jobs: List[Tuple[str, int, bool]] = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -36,11 +36,15 @@ class DevicePoolExecutor:
|
||||||
self.context = {}
|
self.context = {}
|
||||||
self.pending = {}
|
self.pending = {}
|
||||||
self.workers = {}
|
self.workers = {}
|
||||||
self.active_job = {}
|
self.active_jobs = {}
|
||||||
self.finished = []
|
self.finished_jobs = []
|
||||||
self.finished_jobs = 0 # TODO: turn this into a Dict per-worker
|
self.total_jobs = 0 # TODO: turn this into a Dict per-worker
|
||||||
|
|
||||||
|
self.started = Queue()
|
||||||
|
self.finished = Queue()
|
||||||
|
|
||||||
self.create_logger_worker()
|
self.create_logger_worker()
|
||||||
|
self.create_queue_workers()
|
||||||
for device in devices:
|
for device in devices:
|
||||||
self.create_device_worker(device)
|
self.create_device_worker(device)
|
||||||
|
|
||||||
|
@ -56,16 +60,23 @@ class DevicePoolExecutor:
|
||||||
|
|
||||||
def create_device_worker(self, device: DeviceParams) -> None:
|
def create_device_worker(self, device: DeviceParams) -> None:
|
||||||
name = device.device
|
name = device.device
|
||||||
|
|
||||||
|
# reuse the queue if possible, to keep queued jobs
|
||||||
|
if name in self.pending:
|
||||||
|
pending = self.pending[name]
|
||||||
|
else:
|
||||||
pending = Queue()
|
pending = Queue()
|
||||||
self.pending[name] = pending
|
self.pending[name] = pending
|
||||||
|
|
||||||
context = WorkerContext(
|
context = WorkerContext(
|
||||||
name,
|
name,
|
||||||
device,
|
device,
|
||||||
cancel=Value("B", False),
|
cancel=Value("B", False),
|
||||||
finished=Value("B", False),
|
|
||||||
progress=Value("I", 0),
|
progress=Value("I", 0),
|
||||||
pending=pending,
|
finished=self.finished,
|
||||||
logs=self.log_queue,
|
logs=self.log_queue,
|
||||||
|
pending=pending,
|
||||||
|
started=self.started,
|
||||||
)
|
)
|
||||||
self.context[name] = context
|
self.context[name] = context
|
||||||
self.workers[name] = Process(target=worker_init, args=(context, self.server))
|
self.workers[name] = Process(target=worker_init, args=(context, self.server))
|
||||||
|
@ -73,9 +84,32 @@ class DevicePoolExecutor:
|
||||||
logger.debug("starting worker for device %s", device)
|
logger.debug("starting worker for device %s", device)
|
||||||
self.workers[name].start()
|
self.workers[name].start()
|
||||||
|
|
||||||
def create_prune_worker(self) -> None:
|
def create_queue_workers(self) -> None:
|
||||||
# TODO: create a background thread to prune completed jobs
|
def started_worker(pending: Queue):
|
||||||
pass
|
logger.info("checking in from started thread")
|
||||||
|
while True:
|
||||||
|
job, device = pending.get()
|
||||||
|
logger.info("job has been started: %s", job)
|
||||||
|
self.active_jobs[device] = job
|
||||||
|
|
||||||
|
def finished_worker(finished: Queue):
|
||||||
|
logger.info("checking in from finished thread")
|
||||||
|
while True:
|
||||||
|
job, device = finished.get()
|
||||||
|
logger.info("job has been finished: %s", job)
|
||||||
|
context = self.get_job_context(job)
|
||||||
|
self.finished_jobs.append(
|
||||||
|
(job, context.progress.value, context.cancel.value)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.started_thread = Thread(target=started_worker, args=(self.started,))
|
||||||
|
self.started_thread.start()
|
||||||
|
self.finished_thread = Thread(target=finished_worker, args=(self.finished,))
|
||||||
|
self.finished_thread.start()
|
||||||
|
|
||||||
|
def get_job_context(self, key: str) -> WorkerContext:
|
||||||
|
device = self.active_jobs[key]
|
||||||
|
return self.context[device]
|
||||||
|
|
||||||
def cancel(self, key: str) -> bool:
|
def cancel(self, key: str) -> bool:
|
||||||
"""
|
"""
|
||||||
|
@ -83,11 +117,11 @@ class DevicePoolExecutor:
|
||||||
the future and never execute it. If the job has been started, it
|
the future and never execute it. If the job has been started, it
|
||||||
should be cancelled on the next progress callback.
|
should be cancelled on the next progress callback.
|
||||||
"""
|
"""
|
||||||
if key not in self.active_job:
|
if key not in self.active_jobs:
|
||||||
logger.warn("attempting to cancel unknown job: %s", key)
|
logger.warn("attempting to cancel unknown job: %s", key)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
device = self.active_job[key]
|
device = self.active_jobs[key]
|
||||||
context = self.context[device]
|
context = self.context[device]
|
||||||
logger.info("cancelling job %s on device %s", key, device)
|
logger.info("cancelling job %s on device %s", key, device)
|
||||||
|
|
||||||
|
@ -98,19 +132,17 @@ class DevicePoolExecutor:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def done(self, key: str) -> Tuple[Optional[bool], int]:
|
def done(self, key: str) -> Tuple[Optional[bool], int]:
|
||||||
if key not in self.active_job:
|
for k, p, c in self.finished_jobs:
|
||||||
|
if k == key:
|
||||||
|
return (c, p)
|
||||||
|
|
||||||
|
if key not in self.active_jobs:
|
||||||
logger.warn("checking status for unknown job: %s", key)
|
logger.warn("checking status for unknown job: %s", key)
|
||||||
return (None, 0)
|
return (None, 0)
|
||||||
|
|
||||||
# TODO: prune here, maybe?
|
# TODO: prune here, maybe?
|
||||||
|
context = self.get_job_context(key)
|
||||||
device = self.active_job[key]
|
return (False, context.progress.value)
|
||||||
context = self.context[device]
|
|
||||||
|
|
||||||
if context.finished.value is True:
|
|
||||||
self.finished.append((key, context.progress.value, context.cancel.value))
|
|
||||||
|
|
||||||
return (context.finished.value, context.progress.value)
|
|
||||||
|
|
||||||
def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int:
|
def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int:
|
||||||
# respect overrides if possible
|
# respect overrides if possible
|
||||||
|
@ -132,6 +164,9 @@ class DevicePoolExecutor:
|
||||||
return lowest_devices[0]
|
return lowest_devices[0]
|
||||||
|
|
||||||
def join(self):
|
def join(self):
|
||||||
|
self.started_thread.join(self.join_timeout)
|
||||||
|
self.finished_thread.join(self.join_timeout)
|
||||||
|
|
||||||
for device, worker in self.workers.items():
|
for device, worker in self.workers.items():
|
||||||
if worker.is_alive():
|
if worker.is_alive():
|
||||||
logger.info("stopping worker for device %s", device)
|
logger.info("stopping worker for device %s", device)
|
||||||
|
@ -166,11 +201,11 @@ class DevicePoolExecutor:
|
||||||
needs_device: Optional[DeviceParams] = None,
|
needs_device: Optional[DeviceParams] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.finished_jobs += 1
|
self.total_jobs += 1
|
||||||
logger.debug("pool job count: %s", self.finished_jobs)
|
logger.debug("pool job count: %s", self.total_jobs)
|
||||||
if self.finished_jobs > self.max_jobs_per_worker:
|
if self.total_jobs > self.max_jobs_per_worker:
|
||||||
self.recycle()
|
self.recycle()
|
||||||
self.finished_jobs = 0
|
self.total_jobs = 0
|
||||||
|
|
||||||
device_idx = self.get_next_device(needs_device=needs_device)
|
device_idx = self.get_next_device(needs_device=needs_device)
|
||||||
logger.info(
|
logger.info(
|
||||||
|
@ -184,7 +219,7 @@ class DevicePoolExecutor:
|
||||||
queue = self.pending[device.device]
|
queue = self.pending[device.device]
|
||||||
queue.put((fn, args, kwargs))
|
queue.put((fn, args, kwargs))
|
||||||
|
|
||||||
self.active_job[key] = device.device
|
self.active_jobs[key] = device.device
|
||||||
|
|
||||||
def status(self) -> List[Tuple[str, int, bool, int]]:
|
def status(self) -> List[Tuple[str, int, bool, int]]:
|
||||||
pending = [
|
pending = [
|
||||||
|
@ -193,10 +228,22 @@ class DevicePoolExecutor:
|
||||||
self.workers[name].is_alive(),
|
self.workers[name].is_alive(),
|
||||||
context.pending.qsize(),
|
context.pending.qsize(),
|
||||||
context.cancel.value,
|
context.cancel.value,
|
||||||
context.finished.value,
|
False,
|
||||||
context.progress.value,
|
context.progress.value,
|
||||||
)
|
)
|
||||||
for name, context in self.context.items()
|
for name, context in self.context.items()
|
||||||
]
|
]
|
||||||
pending.extend(self.finished)
|
pending.extend(
|
||||||
|
[
|
||||||
|
(
|
||||||
|
name,
|
||||||
|
False,
|
||||||
|
0,
|
||||||
|
cancel,
|
||||||
|
True,
|
||||||
|
progress,
|
||||||
|
)
|
||||||
|
for name, progress, cancel in self.finished_jobs
|
||||||
|
]
|
||||||
|
)
|
||||||
return pending
|
return pending
|
||||||
|
|
|
@ -32,12 +32,14 @@ def worker_init(context: WorkerContext, server: ServerContext):
|
||||||
while True:
|
while True:
|
||||||
job = context.pending.get()
|
job = context.pending.get()
|
||||||
logger.info("got job: %s", job)
|
logger.info("got job: %s", job)
|
||||||
try:
|
|
||||||
fn, args, kwargs = job
|
fn, args, kwargs = job
|
||||||
name = args[3][0]
|
name = args[3][0]
|
||||||
|
|
||||||
logger.info("starting job: %s", name)
|
try:
|
||||||
context.clear_flags()
|
context.clear_flags()
|
||||||
|
logger.info("starting job: %s", name)
|
||||||
|
context.put_started(name)
|
||||||
fn(context, *args, **kwargs)
|
fn(context, *args, **kwargs)
|
||||||
logger.info("job succeeded: %s", name)
|
logger.info("job succeeded: %s", name)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -46,5 +48,5 @@ def worker_init(context: WorkerContext, server: ServerContext):
|
||||||
format_exception(type(e), e, e.__traceback__),
|
format_exception(type(e), e, e.__traceback__),
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
context.set_finished()
|
context.put_finished(name)
|
||||||
logger.info("finished job: %s", name)
|
logger.info("finished job: %s", name)
|
||||||
|
|
Loading…
Reference in New Issue