lint all the new stuff

Sean Sube 2023-02-26 14:15:30 -06:00
parent b880b7a121
commit 584dddb5d6
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
34 changed files with 183 additions and 174 deletions

View File

@@ -57,4 +57,4 @@ from .utils import (
 )
 from .worker import (
     DevicePoolExecutor,
 )

View File

@@ -29,4 +29,4 @@ CHAIN_STAGES = {
     "upscale-outpaint": upscale_outpaint,
     "upscale-resrgan": upscale_resrgan,
     "upscale-stable-diffusion": upscale_stable_diffusion,
 }

View File

@@ -7,9 +7,9 @@ from PIL import Image
 from ..output import save_image
 from ..params import ImageParams, StageParams
-from ..worker import WorkerContext, ProgressCallback
 from ..server import ServerContext
 from ..utils import is_debug
+from ..worker import ProgressCallback, WorkerContext
 from .utils import process_tile_order
 
 logger = getLogger(__name__)

View File

@@ -7,8 +7,8 @@ from PIL import Image
 from ..diffusion.load import load_pipeline
 from ..params import ImageParams, StageParams
-from ..worker import WorkerContext, ProgressCallback
 from ..server import ServerContext
+from ..worker import ProgressCallback, WorkerContext
 
 logger = getLogger(__name__)

View File

@@ -10,9 +10,9 @@ from ..diffusion.load import get_latents_from_seed, load_pipeline
 from ..image import expand_image, mask_filter_none, noise_source_histogram
 from ..output import save_image
 from ..params import Border, ImageParams, Size, SizeChart, StageParams
-from ..worker import WorkerContext, ProgressCallback
 from ..server import ServerContext
 from ..utils import is_debug
+from ..worker import ProgressCallback, WorkerContext
 from .utils import process_tile_order
 
 logger = getLogger(__name__)

View File

@@ -7,9 +7,9 @@ from onnx_web.image import valid_image
 from onnx_web.output import save_image
 from ..params import ImageParams, StageParams
-from ..worker import WorkerContext, ProgressCallback
 from ..server import ServerContext
 from ..utils import is_debug
+from ..worker import ProgressCallback, WorkerContext
 
 logger = getLogger(__name__)

View File

@@ -3,8 +3,8 @@ from logging import getLogger
 from PIL import Image
 
 from ..params import ImageParams, StageParams, UpscaleParams
-from ..worker import WorkerContext
 from ..server import ServerContext
+from ..worker import WorkerContext
 
 logger = getLogger(__name__)

View File

@@ -5,10 +5,9 @@ import numpy as np
 from PIL import Image
 
 from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
+from ..server import ServerContext
 from ..utils import run_gc
 from ..worker import WorkerContext
-from ..server import ServerContext
 
 logger = getLogger(__name__)

View File

@@ -4,9 +4,8 @@ from PIL import Image
 from ..output import save_image
 from ..params import ImageParams, StageParams
-from ..worker import WorkerContext
 from ..server import ServerContext
+from ..worker import WorkerContext
 
 logger = getLogger(__name__)

View File

@@ -5,9 +5,8 @@ from boto3 import Session
 from PIL import Image
 
 from ..params import ImageParams, StageParams
-from ..worker import WorkerContext
 from ..server import ServerContext
+from ..worker import WorkerContext
 
 logger = getLogger(__name__)

View File

@@ -3,9 +3,8 @@ from logging import getLogger
 from PIL import Image
 
 from ..params import ImageParams, Size, StageParams
-from ..worker import WorkerContext
 from ..server import ServerContext
+from ..worker import WorkerContext
 
 logger = getLogger(__name__)

View File

@@ -3,9 +3,8 @@ from logging import getLogger
 from PIL import Image
 
 from ..params import ImageParams, Size, StageParams
-from ..worker import WorkerContext
 from ..server import ServerContext
+from ..worker import WorkerContext
 
 logger = getLogger(__name__)

View File

@@ -4,9 +4,8 @@ from typing import Callable
 from PIL import Image
 
 from ..params import ImageParams, Size, StageParams
-from ..worker import WorkerContext
 from ..server import ServerContext
+from ..worker import WorkerContext
 
 logger = getLogger(__name__)

View File

@@ -7,9 +7,8 @@ from PIL import Image
 from ..diffusion.load import get_latents_from_seed, load_pipeline
 from ..params import ImageParams, Size, StageParams
-from ..worker import WorkerContext, ProgressCallback
 from ..server import ServerContext
+from ..worker import ProgressCallback, WorkerContext
 
 logger = getLogger(__name__)

View File

@@ -10,9 +10,9 @@ from ..diffusion.load import get_latents_from_seed, get_tile_latents, load_pipeline
 from ..image import expand_image, mask_filter_none, noise_source_histogram
 from ..output import save_image
 from ..params import Border, ImageParams, Size, SizeChart, StageParams
-from ..worker import WorkerContext, ProgressCallback
 from ..server import ServerContext
 from ..utils import is_debug
+from ..worker import ProgressCallback, WorkerContext
 from .utils import process_tile_grid, process_tile_order
 
 logger = getLogger(__name__)

View File

@@ -6,9 +6,9 @@ from PIL import Image
 from ..onnx import OnnxNet
 from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
-from ..worker import WorkerContext
 from ..server import ServerContext
 from ..utils import run_gc
+from ..worker import WorkerContext
 
 logger = getLogger(__name__)

View File

@@ -10,9 +10,9 @@ from ..diffusion.pipeline_onnx_stable_diffusion_upscale import (
     OnnxStableDiffusionUpscalePipeline,
 )
 from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
-from ..worker import WorkerContext, ProgressCallback
 from ..server import ServerContext
 from ..utils import run_gc
+from ..worker import ProgressCallback, WorkerContext
 
 logger = getLogger(__name__)

View File

@@ -4,11 +4,9 @@ from typing import Any, Optional, Tuple
 import numpy as np
 
 from diffusers import (
-    DiffusionPipeline,
-    OnnxRuntimeModel,
-    StableDiffusionPipeline,
     DDIMScheduler,
     DDPMScheduler,
+    DiffusionPipeline,
     DPMSolverMultistepScheduler,
     DPMSolverSinglestepScheduler,
     EulerAncestralDiscreteScheduler,
@@ -19,7 +17,9 @@ from diffusers import (
     KDPM2AncestralDiscreteScheduler,
     KDPM2DiscreteScheduler,
     LMSDiscreteScheduler,
+    OnnxRuntimeModel,
     PNDMScheduler,
+    StableDiffusionPipeline,
 )
 
 try:

View File

@@ -12,10 +12,10 @@ from onnx_web.chain.base import ChainProgress
 from ..chain import upscale_outpaint
 from ..output import save_image, save_params
 from ..params import Border, ImageParams, Size, StageParams, UpscaleParams
-from ..worker import WorkerContext
 from ..server import ServerContext
 from ..upscale import run_upscale_correction
 from ..utils import run_gc
+from ..worker import WorkerContext
 from .load import get_latents_from_seed, load_pipeline
 
 logger = getLogger(__name__)

View File

@@ -7,49 +7,53 @@ from huggingface_hub.utils.tqdm import disable_progress_bars
 from torch.multiprocessing import set_start_method
 
 from .server.api import register_api_routes
-from .server.static import register_static_routes
-from .server.config import get_available_platforms, load_models, load_params, load_platforms
-from .server.utils import check_paths
+from .server.config import (
+    get_available_platforms,
+    load_models,
+    load_params,
+    load_platforms,
+)
 from .server.context import ServerContext
 from .server.hacks import apply_patches
-from .utils import (
-    is_debug,
-)
+from .server.static import register_static_routes
+from .server.utils import check_paths
+from .utils import is_debug
 from .worker import DevicePoolExecutor
 
 
 def main():
     set_start_method("spawn", force=True)
 
     context = ServerContext.from_environ()
     apply_patches(context)
     check_paths(context)
     load_models(context)
     load_params(context)
     load_platforms(context)
 
     if is_debug():
         gc.set_debug(gc.DEBUG_STATS)
 
     if not context.show_progress:
         disable_progress_bar()
         disable_progress_bars()
 
     app = Flask(__name__)
     CORS(app, origins=context.cors_origin)
 
     # any is a fake device, should not be in the pool
-    pool = DevicePoolExecutor(context, [p for p in get_available_platforms() if p.device != "any"])
+    pool = DevicePoolExecutor(
+        context, [p for p in get_available_platforms() if p.device != "any"]
+    )
 
     # register routes
     register_static_routes(app, context, pool)
     register_api_routes(app, context, pool)
 
     return app, pool
 
 
 if __name__ == "__main__":
     app, pool = main()
     app.run("0.0.0.0", 5000, debug=is_debug())
     pool.join()
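
Note on the hunk above: main() calls set_start_method("spawn", force=True) before any worker processes are created because CUDA contexts generally cannot survive fork(); every device worker has to be spawned fresh. A minimal standalone sketch of the same pattern (the check_cuda worker is illustrative, not part of this commit):

from torch.multiprocessing import Process, set_start_method

def check_cuda():
    # each spawned process initializes its own CUDA context
    import torch
    print(torch.cuda.is_available())

if __name__ == "__main__":
    set_start_method("spawn", force=True)
    p = Process(target=check_cuda)
    p.start()
    p.join()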

View File

@@ -3,9 +3,9 @@ from typing import Any, Optional
 import numpy as np
 import torch
-from onnxruntime import InferenceSession, SessionOptions
 
 from ..server import ServerContext
+from .torch_before_ort import InferenceSession, SessionOptions
 
 
 class OnnxTensor:

View File

@@ -0,0 +1,5 @@
+# this file exists to make sure torch is always imported before onnxruntime
+# to work around https://github.com/microsoft/onnxruntime/issues/11092
+import torch  # NOQA
+
+from onnxruntime import *  # NOQA
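
The new module above is a re-export shim: importing it loads torch first, then re-exports the whole onnxruntime namespace, so callers inherit the safe import order by importing onnxruntime symbols from here instead. A usage sketch (the model path is a placeholder, not from this commit):

from onnx_web.onnx.torch_before_ort import InferenceSession, SessionOptions

options = SessionOptions()
# by the time onnxruntime initializes, torch is guaranteed to be loaded
session = InferenceSession("model.onnx", sess_options=options)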

View File

@@ -2,8 +2,7 @@ from enum import IntEnum
 from logging import getLogger
 from typing import Any, Dict, List, Literal, Optional, Tuple, Union
 
-import torch
-from onnxruntime import GraphOptimizationLevel, SessionOptions
+from .onnx.torch_before_ort import GraphOptimizationLevel, SessionOptions
 
 logger = getLogger(__name__)

View File

@@ -7,27 +7,7 @@ from flask import Flask, jsonify, make_response, request, url_for
 from jsonschema import validate
 from PIL import Image
 
-from .context import ServerContext
-from .utils import wrap_route
-from ..worker.pool import DevicePoolExecutor
-from .config import (
-    get_available_platforms,
-    get_config_params,
-    get_config_value,
-    get_correction_models,
-    get_diffusion_models,
-    get_inversion_models,
-    get_mask_filters,
-    get_noise_sources,
-    get_upscaling_models,
-)
-from .params import border_from_request, pipeline_from_request, upscale_from_request
-from ..chain import (
-    CHAIN_STAGES,
-    ChainPipeline,
-)
+from ..chain import CHAIN_STAGES, ChainPipeline
 from ..diffusion.load import get_pipeline_schedulers
 from ..diffusion.run import (
     run_blend_pipeline,
@@ -36,16 +16,9 @@ from ..diffusion.run import (
     run_txt2img_pipeline,
     run_upscale_pipeline,
 )
-from ..image import (  # mask filters; noise sources
-    valid_image,
-)
+from ..image import valid_image  # mask filters; noise sources
 from ..output import json_params, make_output_name
-from ..params import (
-    Border,
-    StageParams,
-    TileOrder,
-    UpscaleParams,
-)
+from ..params import Border, StageParams, TileOrder, UpscaleParams
 from ..transformers import run_txt2txt_pipeline
 from ..utils import (
     base_join,
@@ -56,6 +29,21 @@ from ..utils import (
     get_not_empty,
     get_size,
 )
+from ..worker.pool import DevicePoolExecutor
+from .config import (
+    get_available_platforms,
+    get_config_params,
+    get_config_value,
+    get_correction_models,
+    get_diffusion_models,
+    get_inversion_models,
+    get_mask_filters,
+    get_noise_sources,
+    get_upscaling_models,
+)
+from .context import ServerContext
+from .params import border_from_request, pipeline_from_request, upscale_from_request
+from .utils import wrap_route
 
 logger = getLogger(__name__)
@@ -456,22 +444,38 @@ def status(context: ServerContext, pool: DevicePoolExecutor):
 
 def register_api_routes(app: Flask, context: ServerContext, pool: DevicePoolExecutor):
     return [
         app.route("/api")(wrap_route(introspect, context, app=app)),
         app.route("/api/settings/masks")(wrap_route(list_mask_filters, context)),
         app.route("/api/settings/models")(wrap_route(list_models, context)),
         app.route("/api/settings/noises")(wrap_route(list_noise_sources, context)),
         app.route("/api/settings/params")(wrap_route(list_params, context)),
         app.route("/api/settings/platforms")(wrap_route(list_platforms, context)),
         app.route("/api/settings/schedulers")(wrap_route(list_schedulers, context)),
-        app.route("/api/img2img", methods=["POST"])(wrap_route(img2img, context, pool=pool)),
-        app.route("/api/txt2img", methods=["POST"])(wrap_route(txt2img, context, pool=pool)),
-        app.route("/api/txt2txt", methods=["POST"])(wrap_route(txt2txt, context, pool=pool)),
-        app.route("/api/inpaint", methods=["POST"])(wrap_route(inpaint, context, pool=pool)),
-        app.route("/api/upscale", methods=["POST"])(wrap_route(upscale, context, pool=pool)),
-        app.route("/api/chain", methods=["POST"])(wrap_route(chain, context, pool=pool)),
-        app.route("/api/blend", methods=["POST"])(wrap_route(blend, context, pool=pool)),
-        app.route("/api/cancel", methods=["PUT"])(wrap_route(cancel, context, pool=pool)),
+        app.route("/api/img2img", methods=["POST"])(
+            wrap_route(img2img, context, pool=pool)
+        ),
+        app.route("/api/txt2img", methods=["POST"])(
+            wrap_route(txt2img, context, pool=pool)
+        ),
+        app.route("/api/txt2txt", methods=["POST"])(
+            wrap_route(txt2txt, context, pool=pool)
+        ),
+        app.route("/api/inpaint", methods=["POST"])(
+            wrap_route(inpaint, context, pool=pool)
+        ),
+        app.route("/api/upscale", methods=["POST"])(
+            wrap_route(upscale, context, pool=pool)
+        ),
+        app.route("/api/chain", methods=["POST"])(
+            wrap_route(chain, context, pool=pool)
+        ),
+        app.route("/api/blend", methods=["POST"])(
+            wrap_route(blend, context, pool=pool)
+        ),
+        app.route("/api/cancel", methods=["PUT"])(
+            wrap_route(cancel, context, pool=pool)
+        ),
         app.route("/api/ready")(wrap_route(ready, context, pool=pool)),
         app.route("/api/status")(wrap_route(status, context, pool=pool)),
     ]

View File

@@ -2,13 +2,11 @@ from functools import cmp_to_key
 from glob import glob
 from logging import getLogger
 from os import path
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Union
 
 import torch
 import yaml
-from onnxruntime import get_available_providers
 
-from .context import ServerContext
 from ..image import (  # mask filters; noise sources
     mask_filter_gaussian_multiply,
     mask_filter_gaussian_screen,
@@ -20,9 +18,9 @@ from ..image import (  # mask filters; noise sources
     noise_source_normal,
     noise_source_uniform,
 )
-from ..params import (
-    DeviceParams,
-)
+from ..onnx.torch_before_ort import get_available_providers
+from ..params import DeviceParams
+from .context import ServerContext
 
 logger = getLogger(__name__)
@@ -221,4 +219,3 @@ def load_platforms(context: ServerContext) -> None:
         "available acceleration platforms: %s",
         ", ".join([str(p) for p in available_platforms]),
     )
-

View File

@@ -4,30 +4,24 @@ from typing import Tuple
 import numpy as np
 from flask import request
 
-from .context import ServerContext
-from .config import get_available_platforms, get_config_value, get_correction_models, get_upscaling_models
-from .utils import get_model_path
 from ..diffusion.load import pipeline_schedulers
-from ..params import (
-    Border,
-    DeviceParams,
-    ImageParams,
-    Size,
-    UpscaleParams,
-)
-from ..utils import (
-    get_and_clamp_float,
-    get_and_clamp_int,
-    get_from_list,
-    get_not_empty,
-)
+from ..params import Border, DeviceParams, ImageParams, Size, UpscaleParams
+from ..utils import get_and_clamp_float, get_and_clamp_int, get_from_list, get_not_empty
+from .config import (
+    get_available_platforms,
+    get_config_value,
+    get_correction_models,
+    get_upscaling_models,
+)
+from .context import ServerContext
+from .utils import get_model_path
 
 logger = getLogger(__name__)
 
 
-def pipeline_from_request(context: ServerContext) -> Tuple[DeviceParams, ImageParams, Size]:
+def pipeline_from_request(
+    context: ServerContext,
+) -> Tuple[DeviceParams, ImageParams, Size]:
     user = request.remote_addr
 
     # platform stuff
@@ -43,9 +37,7 @@ def pipeline_from_request(context: ServerContext) -> Tuple[DeviceParams, ImageParams, Size]:
     lpw = get_not_empty(request.args, "lpw", "false") == "true"
     model = get_not_empty(request.args, "model", get_config_value("model"))
     model_path = get_model_path(context, model)
-    scheduler = get_from_list(
-        request.args, "scheduler", pipeline_schedulers.keys()
-    )
+    scheduler = get_from_list(request.args, "scheduler", pipeline_schedulers.keys())
 
     if scheduler is None:
         scheduler = get_config_value("scheduler")

View File

@@ -2,9 +2,9 @@ from os import path
 
 from flask import Flask, send_from_directory
 
-from .utils import wrap_route
-from .context import ServerContext
 from ..worker.pool import DevicePoolExecutor
+from .context import ServerContext
+from .utils import wrap_route
 
 
 def serve_bundle_file(context: ServerContext, filename="index.html"):
@@ -26,9 +26,11 @@ def output(context: ServerContext, filename: str):
     )
 
 
-def register_static_routes(app: Flask, context: ServerContext, pool: DevicePoolExecutor):
+def register_static_routes(
+    app: Flask, context: ServerContext, pool: DevicePoolExecutor
+):
     return [
         app.route("/")(wrap_route(index, context)),
         app.route("/<path:filename>")(wrap_route(index_path, context)),
         app.route("/output/<path:filename>")(wrap_route(output, context)),
     ]

View File

@@ -1,6 +1,6 @@
+from functools import partial, update_wrapper
 from os import makedirs, path
 from typing import Callable, Dict, List, Tuple
-from functools import partial, update_wrapper
 
 from flask import Flask
@@ -22,7 +22,12 @@ def get_model_path(context: ServerContext, model: str):
     return base_join(context.model_path, model)
 
 
-def register_routes(app: Flask, context: ServerContext, pool: DevicePoolExecutor, routes: List[Tuple[str, Dict, Callable]]):
+def register_routes(
+    app: Flask,
+    context: ServerContext,
+    pool: DevicePoolExecutor,
+    routes: List[Tuple[str, Dict, Callable]],
+):
     pass
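
The functools imports moved to the top of this file feed the wrap_route helper used by the route tables elsewhere in the diff. Its body is not shown here, so this is only a plausible sketch, assuming it pre-binds shared arguments such as the ServerContext:

from functools import partial, update_wrapper

def wrap_route(func, *args, **kwargs):
    # pre-bind the ServerContext (and pool, where needed) to the view
    wrapped = partial(func, *args, **kwargs)
    # copy __name__ and friends so Flask sees a distinct endpoint per view
    update_wrapper(wrapped, func)
    return wrapped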

View File

@@ -11,7 +11,7 @@ from .chain import (
 )
 from .params import ImageParams, SizeChart, StageParams, UpscaleParams
 from .server import ServerContext
-from .worker import WorkerContext, ProgressCallback
+from .worker import ProgressCallback, WorkerContext
 
 logger = getLogger(__name__)

View File

@@ -1,2 +1,2 @@
 from .context import WorkerContext, ProgressCallback
 from .pool import DevicePoolExecutor

View File

@@ -1,7 +1,8 @@
 from logging import getLogger
-from torch.multiprocessing import Queue, Value
 from typing import Any, Callable, Tuple
 
+from torch.multiprocessing import Queue, Value
+
 from ..params import DeviceParams
 
 logger = getLogger(__name__)
@@ -9,6 +10,7 @@ logger = getLogger(__name__)
 
 ProgressCallback = Callable[[int, int, Any], None]
 
+
 class WorkerContext:
     cancel: "Value[bool]" = None
     key: str = None
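
ProgressCallback above is a plain Callable alias, so any function with a matching signature can be handed to the stages that report progress. A small illustrative example; the argument meanings (current step, total steps, extra payload) are an assumption based on the (int, int, Any) shape:

from typing import Any

def log_progress(step: int, total: int, extra: Any) -> None:
    print(f"step {step}/{total}")

# pass log_progress wherever a ProgressCallback parameter is expected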

View File

@@ -1 +1 @@
-# TODO: queue-based logger
+# TODO: queue-based logger

View File

@@ -1,9 +1,10 @@
 from collections import Counter
 from logging import getLogger
 from multiprocessing import Queue
-from torch.multiprocessing import Lock, Process, Value
 from typing import Callable, Dict, List, Optional, Tuple
 
+from torch.multiprocessing import Lock, Process, Value
+
 from ..params import DeviceParams
 from ..server import ServerContext
 from .context import WorkerContext
@@ -35,7 +36,7 @@ class DevicePoolExecutor:
         self.pending = {}
         self.progress = {}
         self.workers = {}
-        self.jobs = {} # Dict[Output, Device]
+        self.jobs = {}  # Dict[Output, Device]
         self.job_count = 0
 
         # TODO: make this a method
@@ -58,15 +59,21 @@ class DevicePoolExecutor:
         cancel = Value("B", False, lock=lock)
         finished = Value("B", False)
         self.finished[name] = finished
-        progress = Value("I", 0) # , lock=lock) # needs its own lock for some reason. TODO: why?
+        progress = Value(
+            "I", 0
+        )  # , lock=lock) # needs its own lock for some reason. TODO: why?
         self.progress[name] = progress
         pending = Queue()
         self.pending[name] = pending
-        context = WorkerContext(name, cancel, device, pending, progress, self.log_queue, finished)
+        context = WorkerContext(
+            name, cancel, device, pending, progress, self.log_queue, finished
+        )
         self.context[name] = context
 
         logger.debug("starting worker for device %s", device)
-        self.workers[name] = Process(target=worker_init, args=(lock, context, server))
+        self.workers[name] = Process(
+            target=worker_init, args=(lock, context, server)
+        )
         self.workers[name].start()
 
     def cancel(self, key: str) -> bool:
@@ -78,7 +85,7 @@ class DevicePoolExecutor:
         raise NotImplementedError()
 
     def done(self, key: str) -> Tuple[Optional[bool], int]:
-        if not key in self.jobs:
+        if key not in self.jobs:
             logger.warn("checking status for unknown key: %s", key)
             return (None, 0)
@@ -88,7 +95,6 @@ class DevicePoolExecutor:
         return (finished.value, progress.value)
 
     def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int:
-
         # respect overrides if possible
         if needs_device is not None:
@@ -96,9 +102,7 @@ class DevicePoolExecutor:
             if self.devices[i].device == needs_device.device:
                 return i
 
-        pending = [
-            self.pending[d.device].qsize() for d in self.devices
-        ]
+        pending = [self.pending[d.device].qsize() for d in self.devices]
         jobs = Counter(range(len(self.devices)))
         jobs.update(pending)
@@ -128,7 +132,7 @@ class DevicePoolExecutor:
             finished_count - self.finished_limit,
             finished_count,
         )
-        self.finished[:] = self.finished[-self.finished_limit:]
+        self.finished[:] = self.finished[-self.finished_limit :]
 
     def recycle(self):
         for name, proc in self.workers.items():
@@ -149,10 +153,11 @@ class DevicePoolExecutor:
             lock = self.locks[name]
 
             logger.debug("starting worker for device %s", name)
-            self.workers[name] = Process(target=worker_init, args=(lock, context, self.server))
+            self.workers[name] = Process(
+                target=worker_init, args=(lock, context, self.server)
+            )
             self.workers[name].start()
 
     def submit(
         self,
         key: str,
@@ -171,7 +176,10 @@ class DevicePoolExecutor:
         self.prune()
         device_idx = self.get_next_device(needs_device=needs_device)
         logger.info(
-            "assigning job %s to device %s: %s", key, device_idx, self.devices[device_idx]
+            "assigning job %s to device %s: %s",
+            key,
+            device_idx,
+            self.devices[device_idx],
         )
 
         device = self.devices[device_idx]
@@ -180,7 +188,6 @@ class DevicePoolExecutor:
         self.jobs[key] = device.device
 
     def status(self) -> List[Tuple[str, int, bool, int]]:
-
         pending = [
             (
View File

@@ -1,12 +1,12 @@
 from logging import getLogger
-import torch # has to come before ORT
-from onnxruntime import get_available_providers
-from torch.multiprocessing import Lock, Queue
 from traceback import format_exception
-from setproctitle import setproctitle
 
-from .context import WorkerContext
+from setproctitle import setproctitle
+from torch.multiprocessing import Lock, Queue
+
+from ..onnx.torch_before_ort import get_available_providers
 from ..server import ServerContext, apply_patches
+from .context import WorkerContext
 
 logger = getLogger(__name__)
@@ -29,7 +29,7 @@ def worker_init(lock: Lock, context: WorkerContext, server: ServerContext):
     logger.info("checking in from worker, %s, %s", lock, get_available_providers())
 
     apply_patches(server)
-    setproctitle("onnx-web worker: %s", context.device.device)
+    setproctitle("onnx-web worker: %s" % (context.device.device))
 
     while True:
         job = context.pending.get()
@@ -52,4 +52,3 @@ def worker_init(lock: Lock, context: WorkerContext, server: ServerContext):
         except Exception as e:
             logger.error(format_exception(type(e), e, e.__traceback__))
-
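
The setproctitle hunk above is a behavior fix, not just style: setproctitle() takes a single title string, so the old printf-style call passed an unexpected second argument, and the formatting has to happen before the call. A standalone sketch (the device name is a stand-in value):

from setproctitle import setproctitle

device_name = "cuda:0"  # illustrative value
setproctitle("onnx-web worker: %s" % (device_name))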