background workers, logger
This commit is contained in:
parent
e46a1e5fd0
commit
f898de8c54
|
@ -5,14 +5,14 @@ formatters:
|
|||
handlers:
|
||||
console:
|
||||
class: logging.StreamHandler
|
||||
level: INFO
|
||||
level: DEBUG
|
||||
formatter: simple
|
||||
stream: ext://sys.stdout
|
||||
loggers:
|
||||
'':
|
||||
level: INFO
|
||||
level: DEBUG
|
||||
handlers: [console]
|
||||
propagate: True
|
||||
root:
|
||||
level: INFO
|
||||
level: DEBUG
|
||||
handlers: [console]
|
||||
|
|
|
@ -25,6 +25,7 @@ from .image import (
|
|||
from .onnx import OnnxNet, OnnxTensor
|
||||
from .params import (
|
||||
Border,
|
||||
DeviceParams,
|
||||
ImageParams,
|
||||
Param,
|
||||
Point,
|
||||
|
@ -33,8 +34,6 @@ from .params import (
|
|||
UpscaleParams,
|
||||
)
|
||||
from .server import (
|
||||
DeviceParams,
|
||||
DevicePoolExecutor,
|
||||
ModelCache,
|
||||
ServerContext,
|
||||
apply_patch_basicsr,
|
||||
|
@ -51,3 +50,6 @@ from .utils import (
|
|||
get_from_map,
|
||||
get_not_empty,
|
||||
)
|
||||
from .worker import (
|
||||
DevicePoolExecutor,
|
||||
)
|
|
@ -7,7 +7,8 @@ from PIL import Image
|
|||
|
||||
from ..output import save_image
|
||||
from ..params import ImageParams, StageParams
|
||||
from ..server import JobContext, ProgressCallback, ServerContext
|
||||
from ..worker import WorkerContext, ProgressCallback
|
||||
from ..server import ServerContext
|
||||
from ..utils import is_debug
|
||||
from .utils import process_tile_order
|
||||
|
||||
|
@ -17,7 +18,7 @@ logger = getLogger(__name__)
|
|||
class StageCallback(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
job: JobContext,
|
||||
job: WorkerContext,
|
||||
ctx: ServerContext,
|
||||
stage: StageParams,
|
||||
params: ImageParams,
|
||||
|
@ -77,7 +78,7 @@ class ChainPipeline:
|
|||
|
||||
def __call__(
|
||||
self,
|
||||
job: JobContext,
|
||||
job: WorkerContext,
|
||||
server: ServerContext,
|
||||
params: ImageParams,
|
||||
source: Image.Image,
|
||||
|
|
|
@ -7,13 +7,14 @@ from PIL import Image
|
|||
|
||||
from ..diffusion.load import load_pipeline
|
||||
from ..params import ImageParams, StageParams
|
||||
from ..server import JobContext, ProgressCallback, ServerContext
|
||||
from ..worker import WorkerContext, ProgressCallback
|
||||
from ..server import ServerContext
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def blend_img2img(
|
||||
job: JobContext,
|
||||
job: WorkerContext,
|
||||
server: ServerContext,
|
||||
_stage: StageParams,
|
||||
params: ImageParams,
|
||||
|
|
|
@ -10,7 +10,8 @@ from ..diffusion.load import get_latents_from_seed, load_pipeline
|
|||
from ..image import expand_image, mask_filter_none, noise_source_histogram
|
||||
from ..output import save_image
|
||||
from ..params import Border, ImageParams, Size, SizeChart, StageParams
|
||||
from ..server import JobContext, ProgressCallback, ServerContext
|
||||
from ..worker import WorkerContext, ProgressCallback
|
||||
from ..server import ServerContext
|
||||
from ..utils import is_debug
|
||||
from .utils import process_tile_order
|
||||
|
||||
|
@ -18,7 +19,7 @@ logger = getLogger(__name__)
|
|||
|
||||
|
||||
def blend_inpaint(
|
||||
job: JobContext,
|
||||
job: WorkerContext,
|
||||
server: ServerContext,
|
||||
stage: StageParams,
|
||||
params: ImageParams,
|
||||
|
|
|
@ -7,14 +7,15 @@ from onnx_web.image import valid_image
|
|||
from onnx_web.output import save_image
|
||||
|
||||
from ..params import ImageParams, StageParams
|
||||
from ..server import JobContext, ProgressCallback, ServerContext
|
||||
from ..worker import WorkerContext, ProgressCallback
|
||||
from ..server import ServerContext
|
||||
from ..utils import is_debug
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def blend_mask(
|
||||
_job: JobContext,
|
||||
_job: WorkerContext,
|
||||
server: ServerContext,
|
||||
_stage: StageParams,
|
||||
_params: ImageParams,
|
||||
|
|
|
@ -3,7 +3,8 @@ from logging import getLogger
|
|||
from PIL import Image
|
||||
|
||||
from ..params import ImageParams, StageParams, UpscaleParams
|
||||
from ..server import JobContext, ServerContext
|
||||
from ..worker import WorkerContext
|
||||
from ..server import ServerContext
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -11,7 +12,7 @@ device = "cpu"
|
|||
|
||||
|
||||
def correct_codeformer(
|
||||
job: JobContext,
|
||||
job: WorkerContext,
|
||||
_server: ServerContext,
|
||||
_stage: StageParams,
|
||||
_params: ImageParams,
|
||||
|
|
|
@ -5,8 +5,10 @@ import numpy as np
|
|||
from PIL import Image
|
||||
|
||||
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
||||
from ..server import JobContext, ServerContext
|
||||
from ..utils import run_gc
|
||||
from ..worker import WorkerContext
|
||||
from ..server import ServerContext
|
||||
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -46,7 +48,7 @@ def load_gfpgan(
|
|||
|
||||
|
||||
def correct_gfpgan(
|
||||
job: JobContext,
|
||||
job: WorkerContext,
|
||||
server: ServerContext,
|
||||
stage: StageParams,
|
||||
_params: ImageParams,
|
||||
|
|
|
@ -4,13 +4,15 @@ from PIL import Image
|
|||
|
||||
from ..output import save_image
|
||||
from ..params import ImageParams, StageParams
|
||||
from ..server import JobContext, ServerContext
|
||||
from ..worker import WorkerContext
|
||||
from ..server import ServerContext
|
||||
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def persist_disk(
|
||||
_job: JobContext,
|
||||
_job: WorkerContext,
|
||||
server: ServerContext,
|
||||
_stage: StageParams,
|
||||
_params: ImageParams,
|
||||
|
|
|
@ -5,13 +5,15 @@ from boto3 import Session
|
|||
from PIL import Image
|
||||
|
||||
from ..params import ImageParams, StageParams
|
||||
from ..server import JobContext, ServerContext
|
||||
from ..worker import WorkerContext
|
||||
from ..server import ServerContext
|
||||
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def persist_s3(
|
||||
_job: JobContext,
|
||||
_job: WorkerContext,
|
||||
server: ServerContext,
|
||||
_stage: StageParams,
|
||||
_params: ImageParams,
|
||||
|
|
|
@ -3,13 +3,15 @@ from logging import getLogger
|
|||
from PIL import Image
|
||||
|
||||
from ..params import ImageParams, Size, StageParams
|
||||
from ..server import JobContext, ServerContext
|
||||
from ..worker import WorkerContext
|
||||
from ..server import ServerContext
|
||||
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def reduce_crop(
|
||||
_job: JobContext,
|
||||
_job: WorkerContext,
|
||||
_server: ServerContext,
|
||||
_stage: StageParams,
|
||||
_params: ImageParams,
|
||||
|
|
|
@ -3,13 +3,15 @@ from logging import getLogger
|
|||
from PIL import Image
|
||||
|
||||
from ..params import ImageParams, Size, StageParams
|
||||
from ..server import JobContext, ServerContext
|
||||
from ..worker import WorkerContext
|
||||
from ..server import ServerContext
|
||||
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def reduce_thumbnail(
|
||||
_job: JobContext,
|
||||
_job: WorkerContext,
|
||||
_server: ServerContext,
|
||||
_stage: StageParams,
|
||||
_params: ImageParams,
|
||||
|
|
|
@ -4,13 +4,15 @@ from typing import Callable
|
|||
from PIL import Image
|
||||
|
||||
from ..params import ImageParams, Size, StageParams
|
||||
from ..server import JobContext, ServerContext
|
||||
from ..worker import WorkerContext
|
||||
from ..server import ServerContext
|
||||
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def source_noise(
|
||||
_job: JobContext,
|
||||
_job: WorkerContext,
|
||||
_server: ServerContext,
|
||||
_stage: StageParams,
|
||||
_params: ImageParams,
|
||||
|
|
|
@ -7,13 +7,15 @@ from PIL import Image
|
|||
|
||||
from ..diffusion.load import get_latents_from_seed, load_pipeline
|
||||
from ..params import ImageParams, Size, StageParams
|
||||
from ..server import JobContext, ProgressCallback, ServerContext
|
||||
from ..worker import WorkerContext, ProgressCallback
|
||||
from ..server import ServerContext
|
||||
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def source_txt2img(
|
||||
job: JobContext,
|
||||
job: WorkerContext,
|
||||
server: ServerContext,
|
||||
_stage: StageParams,
|
||||
params: ImageParams,
|
||||
|
|
|
@ -10,7 +10,8 @@ from ..diffusion.load import get_latents_from_seed, get_tile_latents, load_pipel
|
|||
from ..image import expand_image, mask_filter_none, noise_source_histogram
|
||||
from ..output import save_image
|
||||
from ..params import Border, ImageParams, Size, SizeChart, StageParams
|
||||
from ..server import JobContext, ProgressCallback, ServerContext
|
||||
from ..worker import WorkerContext, ProgressCallback
|
||||
from ..server import ServerContext
|
||||
from ..utils import is_debug
|
||||
from .utils import process_tile_grid, process_tile_order
|
||||
|
||||
|
@ -18,7 +19,7 @@ logger = getLogger(__name__)
|
|||
|
||||
|
||||
def upscale_outpaint(
|
||||
job: JobContext,
|
||||
job: WorkerContext,
|
||||
server: ServerContext,
|
||||
stage: StageParams,
|
||||
params: ImageParams,
|
||||
|
|
|
@ -6,7 +6,8 @@ from PIL import Image
|
|||
|
||||
from ..onnx import OnnxNet
|
||||
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
||||
from ..server import JobContext, ServerContext
|
||||
from ..worker import WorkerContext
|
||||
from ..server import ServerContext
|
||||
from ..utils import run_gc
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
@ -96,7 +97,7 @@ def load_resrgan(
|
|||
|
||||
|
||||
def upscale_resrgan(
|
||||
job: JobContext,
|
||||
job: WorkerContext,
|
||||
server: ServerContext,
|
||||
stage: StageParams,
|
||||
_params: ImageParams,
|
||||
|
|
|
@ -10,7 +10,8 @@ from ..diffusion.pipeline_onnx_stable_diffusion_upscale import (
|
|||
OnnxStableDiffusionUpscalePipeline,
|
||||
)
|
||||
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
||||
from ..server import JobContext, ProgressCallback, ServerContext
|
||||
from ..worker import WorkerContext, ProgressCallback
|
||||
from ..server import ServerContext
|
||||
from ..utils import run_gc
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
@ -62,7 +63,7 @@ def load_stable_diffusion(
|
|||
|
||||
|
||||
def upscale_stable_diffusion(
|
||||
job: JobContext,
|
||||
job: WorkerContext,
|
||||
server: ServerContext,
|
||||
_stage: StageParams,
|
||||
params: ImageParams,
|
||||
|
|
|
@ -12,7 +12,8 @@ from onnx_web.chain.base import ChainProgress
|
|||
from ..chain import upscale_outpaint
|
||||
from ..output import save_image, save_params
|
||||
from ..params import Border, ImageParams, Size, StageParams, UpscaleParams
|
||||
from ..server import JobContext, ServerContext
|
||||
from ..worker import WorkerContext
|
||||
from ..server import ServerContext
|
||||
from ..upscale import run_upscale_correction
|
||||
from ..utils import run_gc
|
||||
from .load import get_latents_from_seed, load_pipeline
|
||||
|
@ -21,7 +22,7 @@ logger = getLogger(__name__)
|
|||
|
||||
|
||||
def run_txt2img_pipeline(
|
||||
job: JobContext,
|
||||
job: WorkerContext,
|
||||
server: ServerContext,
|
||||
params: ImageParams,
|
||||
size: Size,
|
||||
|
@ -95,7 +96,7 @@ def run_txt2img_pipeline(
|
|||
|
||||
|
||||
def run_img2img_pipeline(
|
||||
job: JobContext,
|
||||
job: WorkerContext,
|
||||
server: ServerContext,
|
||||
params: ImageParams,
|
||||
outputs: List[str],
|
||||
|
@ -167,7 +168,7 @@ def run_img2img_pipeline(
|
|||
|
||||
|
||||
def run_inpaint_pipeline(
|
||||
job: JobContext,
|
||||
job: WorkerContext,
|
||||
server: ServerContext,
|
||||
params: ImageParams,
|
||||
size: Size,
|
||||
|
@ -217,7 +218,7 @@ def run_inpaint_pipeline(
|
|||
|
||||
|
||||
def run_upscale_pipeline(
|
||||
job: JobContext,
|
||||
job: WorkerContext,
|
||||
server: ServerContext,
|
||||
params: ImageParams,
|
||||
size: Size,
|
||||
|
@ -243,7 +244,7 @@ def run_upscale_pipeline(
|
|||
|
||||
|
||||
def run_blend_pipeline(
|
||||
job: JobContext,
|
||||
job: WorkerContext,
|
||||
server: ServerContext,
|
||||
params: ImageParams,
|
||||
size: Size,
|
||||
|
|
|
@ -63,7 +63,7 @@ from .params import (
|
|||
TileOrder,
|
||||
UpscaleParams,
|
||||
)
|
||||
from .server import DevicePoolExecutor, ServerContext, apply_patches
|
||||
from .server import ServerContext, apply_patches
|
||||
from .transformers import run_txt2txt_pipeline
|
||||
from .utils import (
|
||||
base_join,
|
||||
|
@ -75,6 +75,7 @@ from .utils import (
|
|||
get_size,
|
||||
is_debug,
|
||||
)
|
||||
from .worker import DevicePoolExecutor
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
|
|
@ -1,10 +1,3 @@
|
|||
from .device_pool import (
|
||||
DeviceParams,
|
||||
DevicePoolExecutor,
|
||||
Job,
|
||||
JobContext,
|
||||
ProgressCallback,
|
||||
)
|
||||
from .hacks import (
|
||||
apply_patch_basicsr,
|
||||
apply_patch_codeformer,
|
||||
|
|
|
@ -1,230 +0,0 @@
|
|||
from collections import Counter
|
||||
from concurrent.futures import Future
|
||||
from logging import getLogger
|
||||
from multiprocessing import Queue
|
||||
from torch.multiprocessing import Lock, Process, SimpleQueue, Value
|
||||
from traceback import format_exception
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from time import sleep
|
||||
|
||||
from ..params import DeviceParams
|
||||
from ..utils import run_gc
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
ProgressCallback = Callable[[int, int, Any], None]
|
||||
|
||||
|
||||
def worker_init(lock: Lock, job_queue: SimpleQueue):
|
||||
logger.info("checking in from worker")
|
||||
|
||||
while True:
|
||||
if job_queue.empty():
|
||||
logger.info("no jobs, sleeping")
|
||||
sleep(5)
|
||||
else:
|
||||
job = job_queue.get()
|
||||
logger.info("got job: %s", job)
|
||||
|
||||
|
||||
class JobContext:
|
||||
cancel: Value = None
|
||||
device_index: Value = None
|
||||
devices: List[DeviceParams] = None
|
||||
key: str = None
|
||||
progress: Value = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
key: str,
|
||||
devices: List[DeviceParams],
|
||||
cancel: bool = False,
|
||||
device_index: int = -1,
|
||||
progress: int = 0,
|
||||
):
|
||||
self.key = key
|
||||
self.devices = list(devices)
|
||||
self.cancel = Value("B", cancel)
|
||||
self.device_index = Value("i", device_index)
|
||||
self.progress = Value("I", progress)
|
||||
|
||||
def is_cancelled(self) -> bool:
|
||||
return self.cancel.value
|
||||
|
||||
def get_device(self) -> DeviceParams:
|
||||
"""
|
||||
Get the device assigned to this job.
|
||||
"""
|
||||
with self.device_index.get_lock():
|
||||
device_index = self.device_index.value
|
||||
if device_index < 0:
|
||||
raise ValueError("job has not been assigned to a device")
|
||||
else:
|
||||
device = self.devices[device_index]
|
||||
logger.debug("job %s assigned to device %s", self.key, device)
|
||||
return device
|
||||
|
||||
def get_progress(self) -> int:
|
||||
return self.progress.value
|
||||
|
||||
def get_progress_callback(self) -> ProgressCallback:
|
||||
def on_progress(step: int, timestep: int, latents: Any):
|
||||
on_progress.step = step
|
||||
if self.is_cancelled():
|
||||
raise RuntimeError("job has been cancelled")
|
||||
else:
|
||||
logger.debug("setting progress for job %s to %s", self.key, step)
|
||||
self.set_progress(step)
|
||||
|
||||
return on_progress
|
||||
|
||||
def set_cancel(self, cancel: bool = True) -> None:
|
||||
with self.cancel.get_lock():
|
||||
self.cancel.value = cancel
|
||||
|
||||
def set_progress(self, progress: int) -> None:
|
||||
with self.progress.get_lock():
|
||||
self.progress.value = progress
|
||||
|
||||
|
||||
class Job:
|
||||
"""
|
||||
Link a future to its context.
|
||||
"""
|
||||
|
||||
context: JobContext = None
|
||||
future: Future = None
|
||||
key: str = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
key: str,
|
||||
future: Future,
|
||||
context: JobContext,
|
||||
):
|
||||
self.context = context
|
||||
self.future = future
|
||||
self.key = key
|
||||
|
||||
def get_progress(self) -> int:
|
||||
return self.context.get_progress()
|
||||
|
||||
def set_cancel(self, cancel: bool = True):
|
||||
return self.context.set_cancel(cancel)
|
||||
|
||||
def set_progress(self, progress: int):
|
||||
return self.context.set_progress(progress)
|
||||
|
||||
|
||||
class DevicePoolExecutor:
|
||||
devices: List[DeviceParams] = None
|
||||
finished: List[Tuple[str, int]] = None
|
||||
pending: Dict[str, "Queue[Job]"] = None
|
||||
progress: Dict[str, Value] = None
|
||||
workers: Dict[str, Process] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
devices: List[DeviceParams],
|
||||
finished_limit: int = 10,
|
||||
):
|
||||
self.devices = devices
|
||||
self.finished = []
|
||||
self.finished_limit = finished_limit
|
||||
self.lock = Lock()
|
||||
self.pending = {}
|
||||
self.progress = {}
|
||||
self.workers = {}
|
||||
|
||||
# create a pending queue and progress value for each device
|
||||
for device in devices:
|
||||
name = device.device
|
||||
job_queue = Queue()
|
||||
self.pending[name] = job_queue
|
||||
self.progress[name] = Value("I", 0, lock=self.lock)
|
||||
self.workers[name] = Process(target=worker_init, args=(self.lock, job_queue))
|
||||
|
||||
def cancel(self, key: str) -> bool:
|
||||
"""
|
||||
Cancel a job. If the job has not been started, this will cancel
|
||||
the future and never execute it. If the job has been started, it
|
||||
should be cancelled on the next progress callback.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def done(self, key: str) -> Tuple[Optional[bool], int]:
|
||||
for k, progress in self.finished:
|
||||
if key == k:
|
||||
return (True, progress)
|
||||
|
||||
logger.warn("checking status for unknown key: %s", key)
|
||||
return (None, 0)
|
||||
|
||||
def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int:
|
||||
# respect overrides if possible
|
||||
if needs_device is not None:
|
||||
for i in range(len(self.devices)):
|
||||
if self.devices[i].device == needs_device.device:
|
||||
return i
|
||||
|
||||
# use the first/default device if there are no jobs
|
||||
if len(self.jobs) == 0:
|
||||
return 0
|
||||
|
||||
job_devices = [
|
||||
job.context.device_index.value for job in self.jobs if not job.future.done()
|
||||
]
|
||||
job_counts = Counter(range(len(self.devices)))
|
||||
job_counts.update(job_devices)
|
||||
|
||||
queued = job_counts.most_common()
|
||||
logger.debug("jobs queued by device: %s", queued)
|
||||
|
||||
lowest_count = queued[-1][1]
|
||||
lowest_devices = [d[0] for d in queued if d[1] == lowest_count]
|
||||
lowest_devices.sort()
|
||||
|
||||
return lowest_devices[0]
|
||||
|
||||
def prune(self):
|
||||
finished_count = len(self.finished)
|
||||
if finished_count > self.finished_limit:
|
||||
logger.debug(
|
||||
"pruning %s of %s finished jobs",
|
||||
finished_count - self.finished_limit,
|
||||
finished_count,
|
||||
)
|
||||
self.finished[:] = self.finished[-self.finished_limit:]
|
||||
|
||||
def submit(
|
||||
self,
|
||||
key: str,
|
||||
fn: Callable[..., None],
|
||||
/,
|
||||
*args,
|
||||
needs_device: Optional[DeviceParams] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self.prune()
|
||||
device_idx = self.get_next_device(needs_device=needs_device)
|
||||
logger.info(
|
||||
"assigning job %s to device %s: %s", key, device_idx, self.devices[device_idx]
|
||||
)
|
||||
|
||||
context = JobContext(key, self.devices, device_index=device_idx)
|
||||
device = self.devices[device_idx]
|
||||
|
||||
queue = self.pending[device.device]
|
||||
queue.put((fn, context, args, kwargs))
|
||||
|
||||
|
||||
def status(self) -> List[Tuple[str, int, bool, int]]:
|
||||
pending = [
|
||||
(
|
||||
device.device,
|
||||
self.pending[device.device].qsize(),
|
||||
)
|
||||
for device in self.devices
|
||||
]
|
||||
pending.extend(self.finished)
|
||||
return pending
|
|
@ -1,13 +1,14 @@
|
|||
from logging import getLogger
|
||||
|
||||
from .params import ImageParams, Size
|
||||
from .server import JobContext, ServerContext
|
||||
from .server import ServerContext
|
||||
from .worker import WorkerContext
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def run_txt2txt_pipeline(
|
||||
job: JobContext,
|
||||
job: WorkerContext,
|
||||
_server: ServerContext,
|
||||
params: ImageParams,
|
||||
_size: Size,
|
||||
|
|
|
@ -10,13 +10,14 @@ from .chain import (
|
|||
upscale_stable_diffusion,
|
||||
)
|
||||
from .params import ImageParams, SizeChart, StageParams, UpscaleParams
|
||||
from .server import JobContext, ProgressCallback, ServerContext
|
||||
from .server import ServerContext
|
||||
from .worker import WorkerContext, ProgressCallback
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def run_upscale_correction(
|
||||
job: JobContext,
|
||||
job: WorkerContext,
|
||||
server: ServerContext,
|
||||
stage: StageParams,
|
||||
params: ImageParams,
|
||||
|
|
|
@ -0,0 +1,2 @@
|
|||
from .context import WorkerContext, ProgressCallback
|
||||
from .pool import DevicePoolExecutor
|
|
@ -0,0 +1,60 @@
|
|||
from logging import getLogger
|
||||
from torch.multiprocessing import Queue, Value
|
||||
from typing import Any, Callable
|
||||
|
||||
from ..params import DeviceParams
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
ProgressCallback = Callable[[int, int, Any], None]
|
||||
|
||||
class WorkerContext:
|
||||
cancel: "Value[bool]" = None
|
||||
key: str = None
|
||||
progress: "Value[int]" = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
key: str,
|
||||
cancel: "Value[bool]",
|
||||
device: DeviceParams,
|
||||
pending: "Queue[Any]",
|
||||
progress: "Value[int]",
|
||||
):
|
||||
self.key = key
|
||||
self.cancel = cancel
|
||||
self.device = device
|
||||
self.pending = pending
|
||||
self.progress = progress
|
||||
|
||||
def is_cancelled(self) -> bool:
|
||||
return self.cancel.value
|
||||
|
||||
def get_device(self) -> DeviceParams:
|
||||
"""
|
||||
Get the device assigned to this job.
|
||||
"""
|
||||
return self.device
|
||||
|
||||
def get_progress(self) -> int:
|
||||
return self.progress.value
|
||||
|
||||
def get_progress_callback(self) -> ProgressCallback:
|
||||
def on_progress(step: int, timestep: int, latents: Any):
|
||||
on_progress.step = step
|
||||
if self.is_cancelled():
|
||||
raise RuntimeError("job has been cancelled")
|
||||
else:
|
||||
logger.debug("setting progress for job %s to %s", self.key, step)
|
||||
self.set_progress(step)
|
||||
|
||||
return on_progress
|
||||
|
||||
def set_cancel(self, cancel: bool = True) -> None:
|
||||
with self.cancel.get_lock():
|
||||
self.cancel.value = cancel
|
||||
|
||||
def set_progress(self, progress: int) -> None:
|
||||
with self.progress.get_lock():
|
||||
self.progress.value = progress
|
|
@ -0,0 +1 @@
|
|||
# TODO: queue-based logger
|
|
@ -0,0 +1,136 @@
|
|||
from collections import Counter
|
||||
from logging import getLogger
|
||||
from multiprocessing import Queue
|
||||
from torch.multiprocessing import Lock, Process, Value
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from ..params import DeviceParams
|
||||
from .context import WorkerContext
|
||||
from .worker import logger_init, worker_init
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
class DevicePoolExecutor:
|
||||
devices: List[DeviceParams] = None
|
||||
finished: List[Tuple[str, int]] = None
|
||||
pending: Dict[str, "Queue[WorkerContext]"] = None
|
||||
progress: Dict[str, Value] = None
|
||||
workers: Dict[str, Process] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
devices: List[DeviceParams],
|
||||
finished_limit: int = 10,
|
||||
):
|
||||
self.devices = devices
|
||||
self.finished = []
|
||||
self.finished_limit = finished_limit
|
||||
self.lock = Lock()
|
||||
self.pending = {}
|
||||
self.progress = {}
|
||||
self.workers = {}
|
||||
|
||||
log_queue = Queue()
|
||||
logger_context = WorkerContext("logger", None, None, log_queue, None)
|
||||
|
||||
logger.debug("starting log worker")
|
||||
self.logger = Process(target=logger_init, args=(self.lock, logger_context))
|
||||
self.logger.start()
|
||||
|
||||
# create a pending queue and progress value for each device
|
||||
for device in devices:
|
||||
name = device.device
|
||||
cancel = Value("B", False, lock=self.lock)
|
||||
progress = Value("I", 0, lock=self.lock)
|
||||
pending = Queue()
|
||||
context = WorkerContext(name, cancel, device, pending, progress)
|
||||
self.pending[name] = pending
|
||||
self.progress[name] = pending
|
||||
|
||||
logger.debug("starting worker for device %s", device)
|
||||
self.workers[name] = Process(target=worker_init, args=(self.lock, context))
|
||||
self.workers[name].start()
|
||||
|
||||
logger.debug("testing log worker")
|
||||
log_queue.put("testing")
|
||||
|
||||
def cancel(self, key: str) -> bool:
|
||||
"""
|
||||
Cancel a job. If the job has not been started, this will cancel
|
||||
the future and never execute it. If the job has been started, it
|
||||
should be cancelled on the next progress callback.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def done(self, key: str) -> Tuple[Optional[bool], int]:
|
||||
for k, progress in self.finished:
|
||||
if key == k:
|
||||
return (True, progress)
|
||||
|
||||
logger.warn("checking status for unknown key: %s", key)
|
||||
return (None, 0)
|
||||
|
||||
def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int:
|
||||
# respect overrides if possible
|
||||
if needs_device is not None:
|
||||
for i in range(len(self.devices)):
|
||||
if self.devices[i].device == needs_device.device:
|
||||
return i
|
||||
|
||||
pending = [
|
||||
self.pending[d.device].qsize() for d in self.devices
|
||||
]
|
||||
jobs = Counter(range(len(self.devices)))
|
||||
jobs.update(pending)
|
||||
|
||||
queued = jobs.most_common()
|
||||
logger.debug("jobs queued by device: %s", queued)
|
||||
|
||||
lowest_count = queued[-1][1]
|
||||
lowest_devices = [d[0] for d in queued if d[1] == lowest_count]
|
||||
lowest_devices.sort()
|
||||
|
||||
return lowest_devices[0]
|
||||
|
||||
def prune(self):
|
||||
finished_count = len(self.finished)
|
||||
if finished_count > self.finished_limit:
|
||||
logger.debug(
|
||||
"pruning %s of %s finished jobs",
|
||||
finished_count - self.finished_limit,
|
||||
finished_count,
|
||||
)
|
||||
self.finished[:] = self.finished[-self.finished_limit:]
|
||||
|
||||
def submit(
|
||||
self,
|
||||
key: str,
|
||||
fn: Callable[..., None],
|
||||
/,
|
||||
*args,
|
||||
needs_device: Optional[DeviceParams] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self.prune()
|
||||
device_idx = self.get_next_device(needs_device=needs_device)
|
||||
logger.info(
|
||||
"assigning job %s to device %s: %s", key, device_idx, self.devices[device_idx]
|
||||
)
|
||||
|
||||
device = self.devices[device_idx]
|
||||
queue = self.pending[device.device]
|
||||
queue.put((fn, args, kwargs))
|
||||
|
||||
|
||||
def status(self) -> List[Tuple[str, int, bool, int]]:
|
||||
pending = [
|
||||
(
|
||||
device.device,
|
||||
self.pending[device.device].qsize(),
|
||||
self.workers[device.device].is_alive(),
|
||||
)
|
||||
for device in self.devices
|
||||
]
|
||||
pending.extend(self.finished)
|
||||
return pending
|
|
@ -0,0 +1,32 @@
|
|||
from logging import getLogger
|
||||
from torch.multiprocessing import Lock
|
||||
from time import sleep
|
||||
|
||||
from .context import WorkerContext
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
def logger_init(lock: Lock, context: WorkerContext):
|
||||
logger.info("checking in from logger")
|
||||
|
||||
with open("worker.log", "w") as f:
|
||||
while True:
|
||||
if context.pending.empty():
|
||||
logger.info("no logs, sleeping")
|
||||
sleep(5)
|
||||
else:
|
||||
job = context.pending.get()
|
||||
logger.info("got log: %s", job)
|
||||
f.write(str(job) + "\n\n")
|
||||
|
||||
|
||||
def worker_init(lock: Lock, context: WorkerContext):
|
||||
logger.info("checking in from worker")
|
||||
|
||||
while True:
|
||||
if context.pending.empty():
|
||||
logger.info("no jobs, sleeping")
|
||||
sleep(5)
|
||||
else:
|
||||
job = context.pending.get()
|
||||
logger.info("got job: %s", job)
|
Loading…
Reference in New Issue