background workers, logger
This commit is contained in:
parent
e46a1e5fd0
commit
f898de8c54
|
@ -5,14 +5,14 @@ formatters:
|
||||||
handlers:
|
handlers:
|
||||||
console:
|
console:
|
||||||
class: logging.StreamHandler
|
class: logging.StreamHandler
|
||||||
level: INFO
|
level: DEBUG
|
||||||
formatter: simple
|
formatter: simple
|
||||||
stream: ext://sys.stdout
|
stream: ext://sys.stdout
|
||||||
loggers:
|
loggers:
|
||||||
'':
|
'':
|
||||||
level: INFO
|
level: DEBUG
|
||||||
handlers: [console]
|
handlers: [console]
|
||||||
propagate: True
|
propagate: True
|
||||||
root:
|
root:
|
||||||
level: INFO
|
level: DEBUG
|
||||||
handlers: [console]
|
handlers: [console]
|
||||||
|
|
|
@ -25,6 +25,7 @@ from .image import (
|
||||||
from .onnx import OnnxNet, OnnxTensor
|
from .onnx import OnnxNet, OnnxTensor
|
||||||
from .params import (
|
from .params import (
|
||||||
Border,
|
Border,
|
||||||
|
DeviceParams,
|
||||||
ImageParams,
|
ImageParams,
|
||||||
Param,
|
Param,
|
||||||
Point,
|
Point,
|
||||||
|
@ -33,8 +34,6 @@ from .params import (
|
||||||
UpscaleParams,
|
UpscaleParams,
|
||||||
)
|
)
|
||||||
from .server import (
|
from .server import (
|
||||||
DeviceParams,
|
|
||||||
DevicePoolExecutor,
|
|
||||||
ModelCache,
|
ModelCache,
|
||||||
ServerContext,
|
ServerContext,
|
||||||
apply_patch_basicsr,
|
apply_patch_basicsr,
|
||||||
|
@ -51,3 +50,6 @@ from .utils import (
|
||||||
get_from_map,
|
get_from_map,
|
||||||
get_not_empty,
|
get_not_empty,
|
||||||
)
|
)
|
||||||
|
from .worker import (
|
||||||
|
DevicePoolExecutor,
|
||||||
|
)
|
|
@ -7,7 +7,8 @@ from PIL import Image
|
||||||
|
|
||||||
from ..output import save_image
|
from ..output import save_image
|
||||||
from ..params import ImageParams, StageParams
|
from ..params import ImageParams, StageParams
|
||||||
from ..server import JobContext, ProgressCallback, ServerContext
|
from ..worker import WorkerContext, ProgressCallback
|
||||||
|
from ..server import ServerContext
|
||||||
from ..utils import is_debug
|
from ..utils import is_debug
|
||||||
from .utils import process_tile_order
|
from .utils import process_tile_order
|
||||||
|
|
||||||
|
@ -17,7 +18,7 @@ logger = getLogger(__name__)
|
||||||
class StageCallback(Protocol):
|
class StageCallback(Protocol):
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
job: JobContext,
|
job: WorkerContext,
|
||||||
ctx: ServerContext,
|
ctx: ServerContext,
|
||||||
stage: StageParams,
|
stage: StageParams,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
|
@ -77,7 +78,7 @@ class ChainPipeline:
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
job: JobContext,
|
job: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
source: Image.Image,
|
source: Image.Image,
|
||||||
|
|
|
@ -7,13 +7,14 @@ from PIL import Image
|
||||||
|
|
||||||
from ..diffusion.load import load_pipeline
|
from ..diffusion.load import load_pipeline
|
||||||
from ..params import ImageParams, StageParams
|
from ..params import ImageParams, StageParams
|
||||||
from ..server import JobContext, ProgressCallback, ServerContext
|
from ..worker import WorkerContext, ProgressCallback
|
||||||
|
from ..server import ServerContext
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def blend_img2img(
|
def blend_img2img(
|
||||||
job: JobContext,
|
job: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
|
|
|
@ -10,7 +10,8 @@ from ..diffusion.load import get_latents_from_seed, load_pipeline
|
||||||
from ..image import expand_image, mask_filter_none, noise_source_histogram
|
from ..image import expand_image, mask_filter_none, noise_source_histogram
|
||||||
from ..output import save_image
|
from ..output import save_image
|
||||||
from ..params import Border, ImageParams, Size, SizeChart, StageParams
|
from ..params import Border, ImageParams, Size, SizeChart, StageParams
|
||||||
from ..server import JobContext, ProgressCallback, ServerContext
|
from ..worker import WorkerContext, ProgressCallback
|
||||||
|
from ..server import ServerContext
|
||||||
from ..utils import is_debug
|
from ..utils import is_debug
|
||||||
from .utils import process_tile_order
|
from .utils import process_tile_order
|
||||||
|
|
||||||
|
@ -18,7 +19,7 @@ logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def blend_inpaint(
|
def blend_inpaint(
|
||||||
job: JobContext,
|
job: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
stage: StageParams,
|
stage: StageParams,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
|
|
|
@ -7,14 +7,15 @@ from onnx_web.image import valid_image
|
||||||
from onnx_web.output import save_image
|
from onnx_web.output import save_image
|
||||||
|
|
||||||
from ..params import ImageParams, StageParams
|
from ..params import ImageParams, StageParams
|
||||||
from ..server import JobContext, ProgressCallback, ServerContext
|
from ..worker import WorkerContext, ProgressCallback
|
||||||
|
from ..server import ServerContext
|
||||||
from ..utils import is_debug
|
from ..utils import is_debug
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def blend_mask(
|
def blend_mask(
|
||||||
_job: JobContext,
|
_job: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
|
|
|
@ -3,7 +3,8 @@ from logging import getLogger
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..params import ImageParams, StageParams, UpscaleParams
|
from ..params import ImageParams, StageParams, UpscaleParams
|
||||||
from ..server import JobContext, ServerContext
|
from ..worker import WorkerContext
|
||||||
|
from ..server import ServerContext
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -11,7 +12,7 @@ device = "cpu"
|
||||||
|
|
||||||
|
|
||||||
def correct_codeformer(
|
def correct_codeformer(
|
||||||
job: JobContext,
|
job: WorkerContext,
|
||||||
_server: ServerContext,
|
_server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
|
|
|
@ -5,8 +5,10 @@ import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
||||||
from ..server import JobContext, ServerContext
|
|
||||||
from ..utils import run_gc
|
from ..utils import run_gc
|
||||||
|
from ..worker import WorkerContext
|
||||||
|
from ..server import ServerContext
|
||||||
|
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -46,7 +48,7 @@ def load_gfpgan(
|
||||||
|
|
||||||
|
|
||||||
def correct_gfpgan(
|
def correct_gfpgan(
|
||||||
job: JobContext,
|
job: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
stage: StageParams,
|
stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
|
|
|
@ -4,13 +4,15 @@ from PIL import Image
|
||||||
|
|
||||||
from ..output import save_image
|
from ..output import save_image
|
||||||
from ..params import ImageParams, StageParams
|
from ..params import ImageParams, StageParams
|
||||||
from ..server import JobContext, ServerContext
|
from ..worker import WorkerContext
|
||||||
|
from ..server import ServerContext
|
||||||
|
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def persist_disk(
|
def persist_disk(
|
||||||
_job: JobContext,
|
_job: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
|
|
|
@ -5,13 +5,15 @@ from boto3 import Session
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..params import ImageParams, StageParams
|
from ..params import ImageParams, StageParams
|
||||||
from ..server import JobContext, ServerContext
|
from ..worker import WorkerContext
|
||||||
|
from ..server import ServerContext
|
||||||
|
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def persist_s3(
|
def persist_s3(
|
||||||
_job: JobContext,
|
_job: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
|
|
|
@ -3,13 +3,15 @@ from logging import getLogger
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..params import ImageParams, Size, StageParams
|
from ..params import ImageParams, Size, StageParams
|
||||||
from ..server import JobContext, ServerContext
|
from ..worker import WorkerContext
|
||||||
|
from ..server import ServerContext
|
||||||
|
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def reduce_crop(
|
def reduce_crop(
|
||||||
_job: JobContext,
|
_job: WorkerContext,
|
||||||
_server: ServerContext,
|
_server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
|
|
|
@ -3,13 +3,15 @@ from logging import getLogger
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..params import ImageParams, Size, StageParams
|
from ..params import ImageParams, Size, StageParams
|
||||||
from ..server import JobContext, ServerContext
|
from ..worker import WorkerContext
|
||||||
|
from ..server import ServerContext
|
||||||
|
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def reduce_thumbnail(
|
def reduce_thumbnail(
|
||||||
_job: JobContext,
|
_job: WorkerContext,
|
||||||
_server: ServerContext,
|
_server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
|
|
|
@ -4,13 +4,15 @@ from typing import Callable
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..params import ImageParams, Size, StageParams
|
from ..params import ImageParams, Size, StageParams
|
||||||
from ..server import JobContext, ServerContext
|
from ..worker import WorkerContext
|
||||||
|
from ..server import ServerContext
|
||||||
|
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def source_noise(
|
def source_noise(
|
||||||
_job: JobContext,
|
_job: WorkerContext,
|
||||||
_server: ServerContext,
|
_server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
|
|
|
@ -7,13 +7,15 @@ from PIL import Image
|
||||||
|
|
||||||
from ..diffusion.load import get_latents_from_seed, load_pipeline
|
from ..diffusion.load import get_latents_from_seed, load_pipeline
|
||||||
from ..params import ImageParams, Size, StageParams
|
from ..params import ImageParams, Size, StageParams
|
||||||
from ..server import JobContext, ProgressCallback, ServerContext
|
from ..worker import WorkerContext, ProgressCallback
|
||||||
|
from ..server import ServerContext
|
||||||
|
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def source_txt2img(
|
def source_txt2img(
|
||||||
job: JobContext,
|
job: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
|
|
|
@ -10,7 +10,8 @@ from ..diffusion.load import get_latents_from_seed, get_tile_latents, load_pipel
|
||||||
from ..image import expand_image, mask_filter_none, noise_source_histogram
|
from ..image import expand_image, mask_filter_none, noise_source_histogram
|
||||||
from ..output import save_image
|
from ..output import save_image
|
||||||
from ..params import Border, ImageParams, Size, SizeChart, StageParams
|
from ..params import Border, ImageParams, Size, SizeChart, StageParams
|
||||||
from ..server import JobContext, ProgressCallback, ServerContext
|
from ..worker import WorkerContext, ProgressCallback
|
||||||
|
from ..server import ServerContext
|
||||||
from ..utils import is_debug
|
from ..utils import is_debug
|
||||||
from .utils import process_tile_grid, process_tile_order
|
from .utils import process_tile_grid, process_tile_order
|
||||||
|
|
||||||
|
@ -18,7 +19,7 @@ logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def upscale_outpaint(
|
def upscale_outpaint(
|
||||||
job: JobContext,
|
job: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
stage: StageParams,
|
stage: StageParams,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
|
|
|
@ -6,7 +6,8 @@ from PIL import Image
|
||||||
|
|
||||||
from ..onnx import OnnxNet
|
from ..onnx import OnnxNet
|
||||||
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
||||||
from ..server import JobContext, ServerContext
|
from ..worker import WorkerContext
|
||||||
|
from ..server import ServerContext
|
||||||
from ..utils import run_gc
|
from ..utils import run_gc
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
@ -96,7 +97,7 @@ def load_resrgan(
|
||||||
|
|
||||||
|
|
||||||
def upscale_resrgan(
|
def upscale_resrgan(
|
||||||
job: JobContext,
|
job: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
stage: StageParams,
|
stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
|
|
|
@ -10,7 +10,8 @@ from ..diffusion.pipeline_onnx_stable_diffusion_upscale import (
|
||||||
OnnxStableDiffusionUpscalePipeline,
|
OnnxStableDiffusionUpscalePipeline,
|
||||||
)
|
)
|
||||||
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
||||||
from ..server import JobContext, ProgressCallback, ServerContext
|
from ..worker import WorkerContext, ProgressCallback
|
||||||
|
from ..server import ServerContext
|
||||||
from ..utils import run_gc
|
from ..utils import run_gc
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
@ -62,7 +63,7 @@ def load_stable_diffusion(
|
||||||
|
|
||||||
|
|
||||||
def upscale_stable_diffusion(
|
def upscale_stable_diffusion(
|
||||||
job: JobContext,
|
job: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
|
|
|
@ -12,7 +12,8 @@ from onnx_web.chain.base import ChainProgress
|
||||||
from ..chain import upscale_outpaint
|
from ..chain import upscale_outpaint
|
||||||
from ..output import save_image, save_params
|
from ..output import save_image, save_params
|
||||||
from ..params import Border, ImageParams, Size, StageParams, UpscaleParams
|
from ..params import Border, ImageParams, Size, StageParams, UpscaleParams
|
||||||
from ..server import JobContext, ServerContext
|
from ..worker import WorkerContext
|
||||||
|
from ..server import ServerContext
|
||||||
from ..upscale import run_upscale_correction
|
from ..upscale import run_upscale_correction
|
||||||
from ..utils import run_gc
|
from ..utils import run_gc
|
||||||
from .load import get_latents_from_seed, load_pipeline
|
from .load import get_latents_from_seed, load_pipeline
|
||||||
|
@ -21,7 +22,7 @@ logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def run_txt2img_pipeline(
|
def run_txt2img_pipeline(
|
||||||
job: JobContext,
|
job: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
size: Size,
|
size: Size,
|
||||||
|
@ -95,7 +96,7 @@ def run_txt2img_pipeline(
|
||||||
|
|
||||||
|
|
||||||
def run_img2img_pipeline(
|
def run_img2img_pipeline(
|
||||||
job: JobContext,
|
job: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
outputs: List[str],
|
outputs: List[str],
|
||||||
|
@ -167,7 +168,7 @@ def run_img2img_pipeline(
|
||||||
|
|
||||||
|
|
||||||
def run_inpaint_pipeline(
|
def run_inpaint_pipeline(
|
||||||
job: JobContext,
|
job: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
size: Size,
|
size: Size,
|
||||||
|
@ -217,7 +218,7 @@ def run_inpaint_pipeline(
|
||||||
|
|
||||||
|
|
||||||
def run_upscale_pipeline(
|
def run_upscale_pipeline(
|
||||||
job: JobContext,
|
job: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
size: Size,
|
size: Size,
|
||||||
|
@ -243,7 +244,7 @@ def run_upscale_pipeline(
|
||||||
|
|
||||||
|
|
||||||
def run_blend_pipeline(
|
def run_blend_pipeline(
|
||||||
job: JobContext,
|
job: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
size: Size,
|
size: Size,
|
||||||
|
|
|
@ -63,7 +63,7 @@ from .params import (
|
||||||
TileOrder,
|
TileOrder,
|
||||||
UpscaleParams,
|
UpscaleParams,
|
||||||
)
|
)
|
||||||
from .server import DevicePoolExecutor, ServerContext, apply_patches
|
from .server import ServerContext, apply_patches
|
||||||
from .transformers import run_txt2txt_pipeline
|
from .transformers import run_txt2txt_pipeline
|
||||||
from .utils import (
|
from .utils import (
|
||||||
base_join,
|
base_join,
|
||||||
|
@ -75,6 +75,7 @@ from .utils import (
|
||||||
get_size,
|
get_size,
|
||||||
is_debug,
|
is_debug,
|
||||||
)
|
)
|
||||||
|
from .worker import DevicePoolExecutor
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
@ -1,10 +1,3 @@
|
||||||
from .device_pool import (
|
|
||||||
DeviceParams,
|
|
||||||
DevicePoolExecutor,
|
|
||||||
Job,
|
|
||||||
JobContext,
|
|
||||||
ProgressCallback,
|
|
||||||
)
|
|
||||||
from .hacks import (
|
from .hacks import (
|
||||||
apply_patch_basicsr,
|
apply_patch_basicsr,
|
||||||
apply_patch_codeformer,
|
apply_patch_codeformer,
|
||||||
|
|
|
@ -1,230 +0,0 @@
|
||||||
from collections import Counter
|
|
||||||
from concurrent.futures import Future
|
|
||||||
from logging import getLogger
|
|
||||||
from multiprocessing import Queue
|
|
||||||
from torch.multiprocessing import Lock, Process, SimpleQueue, Value
|
|
||||||
from traceback import format_exception
|
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
||||||
from time import sleep
|
|
||||||
|
|
||||||
from ..params import DeviceParams
|
|
||||||
from ..utils import run_gc
|
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
|
||||||
|
|
||||||
ProgressCallback = Callable[[int, int, Any], None]
|
|
||||||
|
|
||||||
|
|
||||||
def worker_init(lock: Lock, job_queue: SimpleQueue):
|
|
||||||
logger.info("checking in from worker")
|
|
||||||
|
|
||||||
while True:
|
|
||||||
if job_queue.empty():
|
|
||||||
logger.info("no jobs, sleeping")
|
|
||||||
sleep(5)
|
|
||||||
else:
|
|
||||||
job = job_queue.get()
|
|
||||||
logger.info("got job: %s", job)
|
|
||||||
|
|
||||||
|
|
||||||
class JobContext:
|
|
||||||
cancel: Value = None
|
|
||||||
device_index: Value = None
|
|
||||||
devices: List[DeviceParams] = None
|
|
||||||
key: str = None
|
|
||||||
progress: Value = None
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
key: str,
|
|
||||||
devices: List[DeviceParams],
|
|
||||||
cancel: bool = False,
|
|
||||||
device_index: int = -1,
|
|
||||||
progress: int = 0,
|
|
||||||
):
|
|
||||||
self.key = key
|
|
||||||
self.devices = list(devices)
|
|
||||||
self.cancel = Value("B", cancel)
|
|
||||||
self.device_index = Value("i", device_index)
|
|
||||||
self.progress = Value("I", progress)
|
|
||||||
|
|
||||||
def is_cancelled(self) -> bool:
|
|
||||||
return self.cancel.value
|
|
||||||
|
|
||||||
def get_device(self) -> DeviceParams:
|
|
||||||
"""
|
|
||||||
Get the device assigned to this job.
|
|
||||||
"""
|
|
||||||
with self.device_index.get_lock():
|
|
||||||
device_index = self.device_index.value
|
|
||||||
if device_index < 0:
|
|
||||||
raise ValueError("job has not been assigned to a device")
|
|
||||||
else:
|
|
||||||
device = self.devices[device_index]
|
|
||||||
logger.debug("job %s assigned to device %s", self.key, device)
|
|
||||||
return device
|
|
||||||
|
|
||||||
def get_progress(self) -> int:
|
|
||||||
return self.progress.value
|
|
||||||
|
|
||||||
def get_progress_callback(self) -> ProgressCallback:
|
|
||||||
def on_progress(step: int, timestep: int, latents: Any):
|
|
||||||
on_progress.step = step
|
|
||||||
if self.is_cancelled():
|
|
||||||
raise RuntimeError("job has been cancelled")
|
|
||||||
else:
|
|
||||||
logger.debug("setting progress for job %s to %s", self.key, step)
|
|
||||||
self.set_progress(step)
|
|
||||||
|
|
||||||
return on_progress
|
|
||||||
|
|
||||||
def set_cancel(self, cancel: bool = True) -> None:
|
|
||||||
with self.cancel.get_lock():
|
|
||||||
self.cancel.value = cancel
|
|
||||||
|
|
||||||
def set_progress(self, progress: int) -> None:
|
|
||||||
with self.progress.get_lock():
|
|
||||||
self.progress.value = progress
|
|
||||||
|
|
||||||
|
|
||||||
class Job:
|
|
||||||
"""
|
|
||||||
Link a future to its context.
|
|
||||||
"""
|
|
||||||
|
|
||||||
context: JobContext = None
|
|
||||||
future: Future = None
|
|
||||||
key: str = None
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
key: str,
|
|
||||||
future: Future,
|
|
||||||
context: JobContext,
|
|
||||||
):
|
|
||||||
self.context = context
|
|
||||||
self.future = future
|
|
||||||
self.key = key
|
|
||||||
|
|
||||||
def get_progress(self) -> int:
|
|
||||||
return self.context.get_progress()
|
|
||||||
|
|
||||||
def set_cancel(self, cancel: bool = True):
|
|
||||||
return self.context.set_cancel(cancel)
|
|
||||||
|
|
||||||
def set_progress(self, progress: int):
|
|
||||||
return self.context.set_progress(progress)
|
|
||||||
|
|
||||||
|
|
||||||
class DevicePoolExecutor:
|
|
||||||
devices: List[DeviceParams] = None
|
|
||||||
finished: List[Tuple[str, int]] = None
|
|
||||||
pending: Dict[str, "Queue[Job]"] = None
|
|
||||||
progress: Dict[str, Value] = None
|
|
||||||
workers: Dict[str, Process] = None
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
devices: List[DeviceParams],
|
|
||||||
finished_limit: int = 10,
|
|
||||||
):
|
|
||||||
self.devices = devices
|
|
||||||
self.finished = []
|
|
||||||
self.finished_limit = finished_limit
|
|
||||||
self.lock = Lock()
|
|
||||||
self.pending = {}
|
|
||||||
self.progress = {}
|
|
||||||
self.workers = {}
|
|
||||||
|
|
||||||
# create a pending queue and progress value for each device
|
|
||||||
for device in devices:
|
|
||||||
name = device.device
|
|
||||||
job_queue = Queue()
|
|
||||||
self.pending[name] = job_queue
|
|
||||||
self.progress[name] = Value("I", 0, lock=self.lock)
|
|
||||||
self.workers[name] = Process(target=worker_init, args=(self.lock, job_queue))
|
|
||||||
|
|
||||||
def cancel(self, key: str) -> bool:
|
|
||||||
"""
|
|
||||||
Cancel a job. If the job has not been started, this will cancel
|
|
||||||
the future and never execute it. If the job has been started, it
|
|
||||||
should be cancelled on the next progress callback.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def done(self, key: str) -> Tuple[Optional[bool], int]:
|
|
||||||
for k, progress in self.finished:
|
|
||||||
if key == k:
|
|
||||||
return (True, progress)
|
|
||||||
|
|
||||||
logger.warn("checking status for unknown key: %s", key)
|
|
||||||
return (None, 0)
|
|
||||||
|
|
||||||
def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int:
|
|
||||||
# respect overrides if possible
|
|
||||||
if needs_device is not None:
|
|
||||||
for i in range(len(self.devices)):
|
|
||||||
if self.devices[i].device == needs_device.device:
|
|
||||||
return i
|
|
||||||
|
|
||||||
# use the first/default device if there are no jobs
|
|
||||||
if len(self.jobs) == 0:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
job_devices = [
|
|
||||||
job.context.device_index.value for job in self.jobs if not job.future.done()
|
|
||||||
]
|
|
||||||
job_counts = Counter(range(len(self.devices)))
|
|
||||||
job_counts.update(job_devices)
|
|
||||||
|
|
||||||
queued = job_counts.most_common()
|
|
||||||
logger.debug("jobs queued by device: %s", queued)
|
|
||||||
|
|
||||||
lowest_count = queued[-1][1]
|
|
||||||
lowest_devices = [d[0] for d in queued if d[1] == lowest_count]
|
|
||||||
lowest_devices.sort()
|
|
||||||
|
|
||||||
return lowest_devices[0]
|
|
||||||
|
|
||||||
def prune(self):
|
|
||||||
finished_count = len(self.finished)
|
|
||||||
if finished_count > self.finished_limit:
|
|
||||||
logger.debug(
|
|
||||||
"pruning %s of %s finished jobs",
|
|
||||||
finished_count - self.finished_limit,
|
|
||||||
finished_count,
|
|
||||||
)
|
|
||||||
self.finished[:] = self.finished[-self.finished_limit:]
|
|
||||||
|
|
||||||
def submit(
|
|
||||||
self,
|
|
||||||
key: str,
|
|
||||||
fn: Callable[..., None],
|
|
||||||
/,
|
|
||||||
*args,
|
|
||||||
needs_device: Optional[DeviceParams] = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> None:
|
|
||||||
self.prune()
|
|
||||||
device_idx = self.get_next_device(needs_device=needs_device)
|
|
||||||
logger.info(
|
|
||||||
"assigning job %s to device %s: %s", key, device_idx, self.devices[device_idx]
|
|
||||||
)
|
|
||||||
|
|
||||||
context = JobContext(key, self.devices, device_index=device_idx)
|
|
||||||
device = self.devices[device_idx]
|
|
||||||
|
|
||||||
queue = self.pending[device.device]
|
|
||||||
queue.put((fn, context, args, kwargs))
|
|
||||||
|
|
||||||
|
|
||||||
def status(self) -> List[Tuple[str, int, bool, int]]:
|
|
||||||
pending = [
|
|
||||||
(
|
|
||||||
device.device,
|
|
||||||
self.pending[device.device].qsize(),
|
|
||||||
)
|
|
||||||
for device in self.devices
|
|
||||||
]
|
|
||||||
pending.extend(self.finished)
|
|
||||||
return pending
|
|
|
@ -1,13 +1,14 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
|
|
||||||
from .params import ImageParams, Size
|
from .params import ImageParams, Size
|
||||||
from .server import JobContext, ServerContext
|
from .server import ServerContext
|
||||||
|
from .worker import WorkerContext
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def run_txt2txt_pipeline(
|
def run_txt2txt_pipeline(
|
||||||
job: JobContext,
|
job: WorkerContext,
|
||||||
_server: ServerContext,
|
_server: ServerContext,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
_size: Size,
|
_size: Size,
|
||||||
|
|
|
@ -10,13 +10,14 @@ from .chain import (
|
||||||
upscale_stable_diffusion,
|
upscale_stable_diffusion,
|
||||||
)
|
)
|
||||||
from .params import ImageParams, SizeChart, StageParams, UpscaleParams
|
from .params import ImageParams, SizeChart, StageParams, UpscaleParams
|
||||||
from .server import JobContext, ProgressCallback, ServerContext
|
from .server import ServerContext
|
||||||
|
from .worker import WorkerContext, ProgressCallback
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def run_upscale_correction(
|
def run_upscale_correction(
|
||||||
job: JobContext,
|
job: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
stage: StageParams,
|
stage: StageParams,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
|
|
|
@ -0,0 +1,2 @@
|
||||||
|
from .context import WorkerContext, ProgressCallback
|
||||||
|
from .pool import DevicePoolExecutor
|
|
@ -0,0 +1,60 @@
|
||||||
|
from logging import getLogger
|
||||||
|
from torch.multiprocessing import Queue, Value
|
||||||
|
from typing import Any, Callable
|
||||||
|
|
||||||
|
from ..params import DeviceParams
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
ProgressCallback = Callable[[int, int, Any], None]
|
||||||
|
|
||||||
|
class WorkerContext:
|
||||||
|
cancel: "Value[bool]" = None
|
||||||
|
key: str = None
|
||||||
|
progress: "Value[int]" = None
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
key: str,
|
||||||
|
cancel: "Value[bool]",
|
||||||
|
device: DeviceParams,
|
||||||
|
pending: "Queue[Any]",
|
||||||
|
progress: "Value[int]",
|
||||||
|
):
|
||||||
|
self.key = key
|
||||||
|
self.cancel = cancel
|
||||||
|
self.device = device
|
||||||
|
self.pending = pending
|
||||||
|
self.progress = progress
|
||||||
|
|
||||||
|
def is_cancelled(self) -> bool:
|
||||||
|
return self.cancel.value
|
||||||
|
|
||||||
|
def get_device(self) -> DeviceParams:
|
||||||
|
"""
|
||||||
|
Get the device assigned to this job.
|
||||||
|
"""
|
||||||
|
return self.device
|
||||||
|
|
||||||
|
def get_progress(self) -> int:
|
||||||
|
return self.progress.value
|
||||||
|
|
||||||
|
def get_progress_callback(self) -> ProgressCallback:
|
||||||
|
def on_progress(step: int, timestep: int, latents: Any):
|
||||||
|
on_progress.step = step
|
||||||
|
if self.is_cancelled():
|
||||||
|
raise RuntimeError("job has been cancelled")
|
||||||
|
else:
|
||||||
|
logger.debug("setting progress for job %s to %s", self.key, step)
|
||||||
|
self.set_progress(step)
|
||||||
|
|
||||||
|
return on_progress
|
||||||
|
|
||||||
|
def set_cancel(self, cancel: bool = True) -> None:
|
||||||
|
with self.cancel.get_lock():
|
||||||
|
self.cancel.value = cancel
|
||||||
|
|
||||||
|
def set_progress(self, progress: int) -> None:
|
||||||
|
with self.progress.get_lock():
|
||||||
|
self.progress.value = progress
|
|
@ -0,0 +1 @@
|
||||||
|
# TODO: queue-based logger
|
|
@ -0,0 +1,136 @@
|
||||||
|
from collections import Counter
|
||||||
|
from logging import getLogger
|
||||||
|
from multiprocessing import Queue
|
||||||
|
from torch.multiprocessing import Lock, Process, Value
|
||||||
|
from typing import Callable, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
from ..params import DeviceParams
|
||||||
|
from .context import WorkerContext
|
||||||
|
from .worker import logger_init, worker_init
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DevicePoolExecutor:
    """
    Process-based job pool that assigns jobs to a set of devices.

    One worker process and one pending-job queue is created per device,
    plus a single logging process that drains a shared log queue.
    """

    devices: List[DeviceParams] = None
    finished: List[Tuple[str, int]] = None
    pending: Dict[str, "Queue[WorkerContext]"] = None
    progress: Dict[str, Value] = None
    workers: Dict[str, Process] = None

    def __init__(
        self,
        devices: List[DeviceParams],
        finished_limit: int = 10,
    ):
        """
        Args:
            devices: devices to run jobs on, one worker process each.
            finished_limit: how many finished-job records to retain.
        """
        self.devices = devices
        self.finished = []
        self.finished_limit = finished_limit
        self.lock = Lock()
        self.pending = {}
        self.progress = {}
        self.workers = {}

        log_queue = Queue()
        logger_context = WorkerContext("logger", None, None, log_queue, None)

        logger.debug("starting log worker")
        self.logger = Process(target=logger_init, args=(self.lock, logger_context))
        self.logger.start()

        # create a pending queue and progress value for each device
        for device in devices:
            name = device.device
            cancel = Value("B", False, lock=self.lock)
            progress = Value("I", 0, lock=self.lock)
            pending = Queue()
            context = WorkerContext(name, cancel, device, pending, progress)
            self.pending[name] = pending
            # fix: previously stored the pending queue here, leaving the
            # progress dict holding queues instead of progress values
            self.progress[name] = progress

            logger.debug("starting worker for device %s", device)
            self.workers[name] = Process(target=worker_init, args=(self.lock, context))
            self.workers[name].start()

        logger.debug("testing log worker")
        log_queue.put("testing")

    def cancel(self, key: str) -> bool:
        """
        Cancel a job. If the job has not been started, this will cancel
        the future and never execute it. If the job has been started, it
        should be cancelled on the next progress callback.
        """
        raise NotImplementedError()

    def done(self, key: str) -> Tuple[Optional[bool], int]:
        """
        Check whether a job has finished.

        Returns:
            (True, progress) for a finished job, (None, 0) for a key that
            is not in the finished list.
        """
        for k, progress in self.finished:
            if key == k:
                return (True, progress)

        # logger.warn is a deprecated alias for logger.warning
        logger.warning("checking status for unknown key: %s", key)
        return (None, 0)

    def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int:
        """
        Pick the index of the device that should run the next job: the
        requested device when given, otherwise the device with the fewest
        pending jobs (lowest index wins ties).
        """
        # respect overrides if possible
        if needs_device is not None:
            for i in range(len(self.devices)):
                if self.devices[i].device == needs_device.device:
                    return i

        # fix: the previous Counter-based version called update() with the
        # list of queue sizes, which counts occurrences of each *size*
        # rather than tracking jobs per device index
        pending = [self.pending[d.device].qsize() for d in self.devices]
        logger.debug("jobs queued by device: %s", pending)

        return pending.index(min(pending))

    def prune(self):
        """
        Drop the oldest finished-job records beyond finished_limit.
        """
        finished_count = len(self.finished)
        if finished_count > self.finished_limit:
            logger.debug(
                "pruning %s of %s finished jobs",
                finished_count - self.finished_limit,
                finished_count,
            )
            self.finished[:] = self.finished[-self.finished_limit:]

    def submit(
        self,
        key: str,
        fn: Callable[..., None],
        /,
        *args,
        needs_device: Optional[DeviceParams] = None,
        **kwargs,
    ) -> None:
        """
        Queue a job on the next available device (or the requested one).

        Args:
            key: unique identifier for the job.
            fn: callable to run in the worker process.
            needs_device: optional device override for this job.
        """
        self.prune()
        device_idx = self.get_next_device(needs_device=needs_device)
        logger.info(
            "assigning job %s to device %s: %s", key, device_idx, self.devices[device_idx]
        )

        device = self.devices[device_idx]
        queue = self.pending[device.device]
        queue.put((fn, args, kwargs))

    def status(self) -> List[Tuple[str, int, bool, int]]:
        """
        Report (device name, queue depth, worker alive) for each device,
        followed by the retained finished-job records.
        """
        pending = [
            (
                device.device,
                self.pending[device.device].qsize(),
                self.workers[device.device].is_alive(),
            )
            for device in self.devices
        ]
        pending.extend(self.finished)
        return pending
|
|
@ -0,0 +1,32 @@
|
||||||
|
from logging import getLogger
|
||||||
|
from torch.multiprocessing import Lock
|
||||||
|
from time import sleep
|
||||||
|
|
||||||
|
from .context import WorkerContext
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
def logger_init(lock: Lock, context: WorkerContext):
    """
    Entry point for the logging worker process: drain messages from the
    context's pending queue and append them to worker.log. Never returns.

    The lock is accepted for signature parity with worker_init but is not
    used here.
    """
    logger.info("checking in from logger")

    with open("worker.log", "w", encoding="utf-8") as f:
        while True:
            # block on get() instead of polling empty()/sleep(5): empty()
            # is unreliable across processes and the poll added up to 5s
            # of latency per message
            job = context.pending.get()
            logger.info("got log: %s", job)
            f.write(str(job) + "\n\n")
            # flush so messages are visible while the process is running;
            # this loop never exits, so buffered output would be lost
            f.flush()
|
def worker_init(lock: Lock, context: WorkerContext):
    """
    Entry point for a per-device worker process.

    Polls the context's pending queue for submitted jobs, sleeping
    between polls while the queue is empty.
    """
    logger.info("checking in from worker")

    while True:
        if context.pending.empty():
            # nothing queued for this device yet; poll again in 5 seconds
            logger.info("no jobs, sleeping")
            sleep(5)
        else:
            # NOTE(review): the visible code only dequeues and logs the job;
            # execution of the (fn, args, kwargs) tuple presumably follows
            # past this excerpt - confirm against the full file
            job = context.pending.get()
            logger.info("got job: %s", job)
|
Loading…
Reference in New Issue