From 584dddb5d69c5a7091617b5e3278665070a386ab Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 26 Feb 2023 14:15:30 -0600 Subject: [PATCH] lint all the new stuff --- api/onnx_web/__init__.py | 2 +- api/onnx_web/chain/__init__.py | 2 +- api/onnx_web/chain/base.py | 2 +- api/onnx_web/chain/blend_img2img.py | 2 +- api/onnx_web/chain/blend_inpaint.py | 2 +- api/onnx_web/chain/blend_mask.py | 2 +- api/onnx_web/chain/correct_codeformer.py | 2 +- api/onnx_web/chain/correct_gfpgan.py | 3 +- api/onnx_web/chain/persist_disk.py | 3 +- api/onnx_web/chain/persist_s3.py | 3 +- api/onnx_web/chain/reduce_crop.py | 3 +- api/onnx_web/chain/reduce_thumbnail.py | 3 +- api/onnx_web/chain/source_noise.py | 3 +- api/onnx_web/chain/source_txt2img.py | 3 +- api/onnx_web/chain/upscale_outpaint.py | 2 +- api/onnx_web/chain/upscale_resrgan.py | 2 +- .../chain/upscale_stable_diffusion.py | 2 +- api/onnx_web/diffusion/load.py | 6 +- api/onnx_web/diffusion/run.py | 2 +- api/onnx_web/main.py | 64 +++++------ api/onnx_web/onnx/onnx_net.py | 2 +- api/onnx_web/onnx/torch_before_ort.py | 5 + api/onnx_web/params.py | 3 +- api/onnx_web/server/api.py | 102 +++++++++--------- api/onnx_web/server/config.py | 11 +- api/onnx_web/server/params.py | 34 +++--- api/onnx_web/server/static.py | 18 ++-- api/onnx_web/server/utils.py | 9 +- api/onnx_web/upscale.py | 2 +- api/onnx_web/worker/__init__.py | 2 +- api/onnx_web/worker/context.py | 4 +- api/onnx_web/worker/logging.py | 2 +- api/onnx_web/worker/pool.py | 37 ++++--- api/onnx_web/worker/worker.py | 13 ++- 34 files changed, 183 insertions(+), 174 deletions(-) create mode 100644 api/onnx_web/onnx/torch_before_ort.py diff --git a/api/onnx_web/__init__.py b/api/onnx_web/__init__.py index b019bb3d..5294121c 100644 --- a/api/onnx_web/__init__.py +++ b/api/onnx_web/__init__.py @@ -57,4 +57,4 @@ from .utils import ( ) from .worker import ( DevicePoolExecutor, -) \ No newline at end of file +) diff --git a/api/onnx_web/chain/__init__.py b/api/onnx_web/chain/__init__.py index 44fdd6c1..a983c849 100644 --- a/api/onnx_web/chain/__init__.py +++ b/api/onnx_web/chain/__init__.py @@ -29,4 +29,4 @@ CHAIN_STAGES = { "upscale-outpaint": upscale_outpaint, "upscale-resrgan": upscale_resrgan, "upscale-stable-diffusion": upscale_stable_diffusion, -} \ No newline at end of file +} diff --git a/api/onnx_web/chain/base.py b/api/onnx_web/chain/base.py index dc807322..2ddc59ae 100644 --- a/api/onnx_web/chain/base.py +++ b/api/onnx_web/chain/base.py @@ -7,9 +7,9 @@ from PIL import Image from ..output import save_image from ..params import ImageParams, StageParams -from ..worker import WorkerContext, ProgressCallback from ..server import ServerContext from ..utils import is_debug +from ..worker import ProgressCallback, WorkerContext from .utils import process_tile_order logger = getLogger(__name__) diff --git a/api/onnx_web/chain/blend_img2img.py b/api/onnx_web/chain/blend_img2img.py index 67531103..f7c51605 100644 --- a/api/onnx_web/chain/blend_img2img.py +++ b/api/onnx_web/chain/blend_img2img.py @@ -7,8 +7,8 @@ from PIL import Image from ..diffusion.load import load_pipeline from ..params import ImageParams, StageParams -from ..worker import WorkerContext, ProgressCallback from ..server import ServerContext +from ..worker import ProgressCallback, WorkerContext logger = getLogger(__name__) diff --git a/api/onnx_web/chain/blend_inpaint.py b/api/onnx_web/chain/blend_inpaint.py index 51b6d983..7d864b5e 100644 --- a/api/onnx_web/chain/blend_inpaint.py +++ b/api/onnx_web/chain/blend_inpaint.py @@ -10,9 +10,9 @@ from ..diffusion.load import get_latents_from_seed, load_pipeline from ..image import expand_image, mask_filter_none, noise_source_histogram from ..output import save_image from ..params import Border, ImageParams, Size, SizeChart, StageParams -from ..worker import WorkerContext, ProgressCallback from ..server import ServerContext from ..utils import is_debug +from ..worker import ProgressCallback, WorkerContext from .utils import process_tile_order logger = getLogger(__name__) diff --git a/api/onnx_web/chain/blend_mask.py b/api/onnx_web/chain/blend_mask.py index bfe11aab..f7b68e6f 100644 --- a/api/onnx_web/chain/blend_mask.py +++ b/api/onnx_web/chain/blend_mask.py @@ -7,9 +7,9 @@ from onnx_web.image import valid_image from onnx_web.output import save_image from ..params import ImageParams, StageParams -from ..worker import WorkerContext, ProgressCallback from ..server import ServerContext from ..utils import is_debug +from ..worker import ProgressCallback, WorkerContext logger = getLogger(__name__) diff --git a/api/onnx_web/chain/correct_codeformer.py b/api/onnx_web/chain/correct_codeformer.py index 01d61db7..c3eaec65 100644 --- a/api/onnx_web/chain/correct_codeformer.py +++ b/api/onnx_web/chain/correct_codeformer.py @@ -3,8 +3,8 @@ from logging import getLogger from PIL import Image from ..params import ImageParams, StageParams, UpscaleParams -from ..worker import WorkerContext from ..server import ServerContext +from ..worker import WorkerContext logger = getLogger(__name__) diff --git a/api/onnx_web/chain/correct_gfpgan.py b/api/onnx_web/chain/correct_gfpgan.py index afcae86b..2cff2e18 100644 --- a/api/onnx_web/chain/correct_gfpgan.py +++ b/api/onnx_web/chain/correct_gfpgan.py @@ -5,10 +5,9 @@ import numpy as np from PIL import Image from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams +from ..server import ServerContext from ..utils import run_gc from ..worker import WorkerContext -from ..server import ServerContext - logger = getLogger(__name__) diff --git a/api/onnx_web/chain/persist_disk.py b/api/onnx_web/chain/persist_disk.py index 58020b57..eac0f36c 100644 --- a/api/onnx_web/chain/persist_disk.py +++ b/api/onnx_web/chain/persist_disk.py @@ -4,9 +4,8 @@ from PIL import Image from ..output import save_image from ..params import ImageParams, StageParams -from ..worker import WorkerContext from ..server import ServerContext - +from ..worker import WorkerContext logger = getLogger(__name__) diff --git a/api/onnx_web/chain/persist_s3.py b/api/onnx_web/chain/persist_s3.py index 926f1598..3e01b9ce 100644 --- a/api/onnx_web/chain/persist_s3.py +++ b/api/onnx_web/chain/persist_s3.py @@ -5,9 +5,8 @@ from boto3 import Session from PIL import Image from ..params import ImageParams, StageParams -from ..worker import WorkerContext from ..server import ServerContext - +from ..worker import WorkerContext logger = getLogger(__name__) diff --git a/api/onnx_web/chain/reduce_crop.py b/api/onnx_web/chain/reduce_crop.py index 4cd715b1..cce82f0c 100644 --- a/api/onnx_web/chain/reduce_crop.py +++ b/api/onnx_web/chain/reduce_crop.py @@ -3,9 +3,8 @@ from logging import getLogger from PIL import Image from ..params import ImageParams, Size, StageParams -from ..worker import WorkerContext from ..server import ServerContext - +from ..worker import WorkerContext logger = getLogger(__name__) diff --git a/api/onnx_web/chain/reduce_thumbnail.py b/api/onnx_web/chain/reduce_thumbnail.py index 4950a973..6df2ed6e 100644 --- a/api/onnx_web/chain/reduce_thumbnail.py +++ b/api/onnx_web/chain/reduce_thumbnail.py @@ -3,9 +3,8 @@ from logging import getLogger from PIL import Image from ..params import ImageParams, Size, StageParams -from ..worker import WorkerContext from ..server import ServerContext - +from ..worker import WorkerContext logger = getLogger(__name__) diff --git a/api/onnx_web/chain/source_noise.py b/api/onnx_web/chain/source_noise.py index 9ab302b1..0092292c 100644 --- a/api/onnx_web/chain/source_noise.py +++ b/api/onnx_web/chain/source_noise.py @@ -4,9 +4,8 @@ from typing import Callable from PIL import Image from ..params import ImageParams, Size, StageParams -from ..worker import WorkerContext from ..server import ServerContext - +from ..worker import WorkerContext logger = getLogger(__name__) diff --git a/api/onnx_web/chain/source_txt2img.py b/api/onnx_web/chain/source_txt2img.py index b933ecc9..1dec3243 100644 --- a/api/onnx_web/chain/source_txt2img.py +++ b/api/onnx_web/chain/source_txt2img.py @@ -7,9 +7,8 @@ from PIL import Image from ..diffusion.load import get_latents_from_seed, load_pipeline from ..params import ImageParams, Size, StageParams -from ..worker import WorkerContext, ProgressCallback from ..server import ServerContext - +from ..worker import ProgressCallback, WorkerContext logger = getLogger(__name__) diff --git a/api/onnx_web/chain/upscale_outpaint.py b/api/onnx_web/chain/upscale_outpaint.py index 23393491..69565205 100644 --- a/api/onnx_web/chain/upscale_outpaint.py +++ b/api/onnx_web/chain/upscale_outpaint.py @@ -10,9 +10,9 @@ from ..diffusion.load import get_latents_from_seed, get_tile_latents, load_pipel from ..image import expand_image, mask_filter_none, noise_source_histogram from ..output import save_image from ..params import Border, ImageParams, Size, SizeChart, StageParams -from ..worker import WorkerContext, ProgressCallback from ..server import ServerContext from ..utils import is_debug +from ..worker import ProgressCallback, WorkerContext from .utils import process_tile_grid, process_tile_order logger = getLogger(__name__) diff --git a/api/onnx_web/chain/upscale_resrgan.py b/api/onnx_web/chain/upscale_resrgan.py index ccbb3644..055319d0 100644 --- a/api/onnx_web/chain/upscale_resrgan.py +++ b/api/onnx_web/chain/upscale_resrgan.py @@ -6,9 +6,9 @@ from PIL import Image from ..onnx import OnnxNet from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams -from ..worker import WorkerContext from ..server import ServerContext from ..utils import run_gc +from ..worker import WorkerContext logger = getLogger(__name__) diff --git a/api/onnx_web/chain/upscale_stable_diffusion.py b/api/onnx_web/chain/upscale_stable_diffusion.py index 00c1b9d4..0accc854 100644 --- a/api/onnx_web/chain/upscale_stable_diffusion.py +++ b/api/onnx_web/chain/upscale_stable_diffusion.py @@ -10,9 +10,9 @@ from ..diffusion.pipeline_onnx_stable_diffusion_upscale import ( OnnxStableDiffusionUpscalePipeline, ) from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams -from ..worker import WorkerContext, ProgressCallback from ..server import ServerContext from ..utils import run_gc +from ..worker import ProgressCallback, WorkerContext logger = getLogger(__name__) diff --git a/api/onnx_web/diffusion/load.py b/api/onnx_web/diffusion/load.py index 1fff3f1d..e3b54510 100644 --- a/api/onnx_web/diffusion/load.py +++ b/api/onnx_web/diffusion/load.py @@ -4,11 +4,9 @@ from typing import Any, Optional, Tuple import numpy as np from diffusers import ( - DiffusionPipeline, - OnnxRuntimeModel, - StableDiffusionPipeline, DDIMScheduler, DDPMScheduler, + DiffusionPipeline, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler, EulerAncestralDiscreteScheduler, @@ -19,7 +17,9 @@ from diffusers import ( KDPM2AncestralDiscreteScheduler, KDPM2DiscreteScheduler, LMSDiscreteScheduler, + OnnxRuntimeModel, PNDMScheduler, + StableDiffusionPipeline, ) try: diff --git a/api/onnx_web/diffusion/run.py b/api/onnx_web/diffusion/run.py index 0e44b6cc..2e92294c 100644 --- a/api/onnx_web/diffusion/run.py +++ b/api/onnx_web/diffusion/run.py @@ -12,10 +12,10 @@ from onnx_web.chain.base import ChainProgress from ..chain import upscale_outpaint from ..output import save_image, save_params from ..params import Border, ImageParams, Size, StageParams, UpscaleParams -from ..worker import WorkerContext from ..server import ServerContext from ..upscale import run_upscale_correction from ..utils import run_gc +from ..worker import WorkerContext from .load import get_latents_from_seed, load_pipeline logger = getLogger(__name__) diff --git a/api/onnx_web/main.py b/api/onnx_web/main.py index 86a5896d..3f7c9a84 100644 --- a/api/onnx_web/main.py +++ b/api/onnx_web/main.py @@ -7,49 +7,53 @@ from huggingface_hub.utils.tqdm import disable_progress_bars from torch.multiprocessing import set_start_method from .server.api import register_api_routes -from .server.static import register_static_routes -from .server.config import get_available_platforms, load_models, load_params, load_platforms -from .server.utils import check_paths +from .server.config import ( + get_available_platforms, + load_models, + load_params, + load_platforms, +) from .server.context import ServerContext from .server.hacks import apply_patches -from .utils import ( - is_debug, -) +from .server.static import register_static_routes +from .server.utils import check_paths +from .utils import is_debug from .worker import DevicePoolExecutor def main(): - set_start_method("spawn", force=True) + set_start_method("spawn", force=True) - context = ServerContext.from_environ() - apply_patches(context) - check_paths(context) - load_models(context) - load_params(context) - load_platforms(context) + context = ServerContext.from_environ() + apply_patches(context) + check_paths(context) + load_models(context) + load_params(context) + load_platforms(context) - if is_debug(): - gc.set_debug(gc.DEBUG_STATS) + if is_debug(): + gc.set_debug(gc.DEBUG_STATS) - if not context.show_progress: - disable_progress_bar() - disable_progress_bars() + if not context.show_progress: + disable_progress_bar() + disable_progress_bars() - app = Flask(__name__) - CORS(app, origins=context.cors_origin) + app = Flask(__name__) + CORS(app, origins=context.cors_origin) - # any is a fake device, should not be in the pool - pool = DevicePoolExecutor(context, [p for p in get_available_platforms() if p.device != "any"]) + # any is a fake device, should not be in the pool + pool = DevicePoolExecutor( + context, [p for p in get_available_platforms() if p.device != "any"] + ) - # register routes - register_static_routes(app, context, pool) - register_api_routes(app, context, pool) + # register routes + register_static_routes(app, context, pool) + register_api_routes(app, context, pool) - return app, pool + return app, pool if __name__ == "__main__": - app, pool = main() - app.run("0.0.0.0", 5000, debug=is_debug()) - pool.join() - + app, pool = main() + app.run("0.0.0.0", 5000, debug=is_debug()) + pool.join() diff --git a/api/onnx_web/onnx/onnx_net.py b/api/onnx_web/onnx/onnx_net.py index 97d5c8b0..42f00a4d 100644 --- a/api/onnx_web/onnx/onnx_net.py +++ b/api/onnx_web/onnx/onnx_net.py @@ -3,9 +3,9 @@ from typing import Any, Optional import numpy as np import torch -from onnxruntime import InferenceSession, SessionOptions from ..server import ServerContext +from .torch_before_ort import InferenceSession, SessionOptions class OnnxTensor: diff --git a/api/onnx_web/onnx/torch_before_ort.py b/api/onnx_web/onnx/torch_before_ort.py new file mode 100644 index 00000000..506c1478 --- /dev/null +++ b/api/onnx_web/onnx/torch_before_ort.py @@ -0,0 +1,5 @@ +# this file exists to make sure torch is always imported before onnxruntime +# to work around https://github.com/microsoft/onnxruntime/issues/11092 + +import torch # NOQA +from onnxruntime import * # NOQA diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index 9bd5e819..f92328c6 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -2,8 +2,7 @@ from enum import IntEnum from logging import getLogger from typing import Any, Dict, List, Literal, Optional, Tuple, Union -import torch -from onnxruntime import GraphOptimizationLevel, SessionOptions +from .onnx.torch_before_ort import GraphOptimizationLevel, SessionOptions logger = getLogger(__name__) diff --git a/api/onnx_web/server/api.py b/api/onnx_web/server/api.py index ed7dc8c0..7ff61192 100644 --- a/api/onnx_web/server/api.py +++ b/api/onnx_web/server/api.py @@ -7,27 +7,7 @@ from flask import Flask, jsonify, make_response, request, url_for from jsonschema import validate from PIL import Image -from .context import ServerContext -from .utils import wrap_route -from ..worker.pool import DevicePoolExecutor - -from .config import ( - get_available_platforms, - get_config_params, - get_config_value, - get_correction_models, - get_diffusion_models, - get_inversion_models, - get_mask_filters, - get_noise_sources, - get_upscaling_models, -) -from .params import border_from_request, pipeline_from_request, upscale_from_request - -from ..chain import ( - CHAIN_STAGES, - ChainPipeline, -) +from ..chain import CHAIN_STAGES, ChainPipeline from ..diffusion.load import get_pipeline_schedulers from ..diffusion.run import ( run_blend_pipeline, @@ -36,16 +16,9 @@ from ..diffusion.run import ( run_txt2img_pipeline, run_upscale_pipeline, ) -from ..image import ( # mask filters; noise sources - valid_image, -) +from ..image import valid_image # mask filters; noise sources from ..output import json_params, make_output_name -from ..params import ( - Border, - StageParams, - TileOrder, - UpscaleParams, -) +from ..params import Border, StageParams, TileOrder, UpscaleParams from ..transformers import run_txt2txt_pipeline from ..utils import ( base_join, @@ -56,6 +29,21 @@ from ..utils import ( get_not_empty, get_size, ) +from ..worker.pool import DevicePoolExecutor +from .config import ( + get_available_platforms, + get_config_params, + get_config_value, + get_correction_models, + get_diffusion_models, + get_inversion_models, + get_mask_filters, + get_noise_sources, + get_upscaling_models, +) +from .context import ServerContext +from .params import border_from_request, pipeline_from_request, upscale_from_request +from .utils import wrap_route logger = getLogger(__name__) @@ -456,22 +444,38 @@ def status(context: ServerContext, pool: DevicePoolExecutor): def register_api_routes(app: Flask, context: ServerContext, pool: DevicePoolExecutor): - return [ - app.route("/api")(wrap_route(introspect, context, app=app)), - app.route("/api/settings/masks")(wrap_route(list_mask_filters, context)), - app.route("/api/settings/models")(wrap_route(list_models, context)), - app.route("/api/settings/noises")(wrap_route(list_noise_sources, context)), - app.route("/api/settings/params")(wrap_route(list_params, context)), - app.route("/api/settings/platforms")(wrap_route(list_platforms, context)), - app.route("/api/settings/schedulers")(wrap_route(list_schedulers, context)), - app.route("/api/img2img", methods=["POST"])(wrap_route(img2img, context, pool=pool)), - app.route("/api/txt2img", methods=["POST"])(wrap_route(txt2img, context, pool=pool)), - app.route("/api/txt2txt", methods=["POST"])(wrap_route(txt2txt, context, pool=pool)), - app.route("/api/inpaint", methods=["POST"])(wrap_route(inpaint, context, pool=pool)), - app.route("/api/upscale", methods=["POST"])(wrap_route(upscale, context, pool=pool)), - app.route("/api/chain", methods=["POST"])(wrap_route(chain, context, pool=pool)), - app.route("/api/blend", methods=["POST"])(wrap_route(blend, context, pool=pool)), - app.route("/api/cancel", methods=["PUT"])(wrap_route(cancel, context, pool=pool)), - app.route("/api/ready")(wrap_route(ready, context, pool=pool)), - app.route("/api/status")(wrap_route(status, context, pool=pool)), - ] + return [ + app.route("/api")(wrap_route(introspect, context, app=app)), + app.route("/api/settings/masks")(wrap_route(list_mask_filters, context)), + app.route("/api/settings/models")(wrap_route(list_models, context)), + app.route("/api/settings/noises")(wrap_route(list_noise_sources, context)), + app.route("/api/settings/params")(wrap_route(list_params, context)), + app.route("/api/settings/platforms")(wrap_route(list_platforms, context)), + app.route("/api/settings/schedulers")(wrap_route(list_schedulers, context)), + app.route("/api/img2img", methods=["POST"])( + wrap_route(img2img, context, pool=pool) + ), + app.route("/api/txt2img", methods=["POST"])( + wrap_route(txt2img, context, pool=pool) + ), + app.route("/api/txt2txt", methods=["POST"])( + wrap_route(txt2txt, context, pool=pool) + ), + app.route("/api/inpaint", methods=["POST"])( + wrap_route(inpaint, context, pool=pool) + ), + app.route("/api/upscale", methods=["POST"])( + wrap_route(upscale, context, pool=pool) + ), + app.route("/api/chain", methods=["POST"])( + wrap_route(chain, context, pool=pool) + ), + app.route("/api/blend", methods=["POST"])( + wrap_route(blend, context, pool=pool) + ), + app.route("/api/cancel", methods=["PUT"])( + wrap_route(cancel, context, pool=pool) + ), + app.route("/api/ready")(wrap_route(ready, context, pool=pool)), + app.route("/api/status")(wrap_route(status, context, pool=pool)), + ] diff --git a/api/onnx_web/server/config.py b/api/onnx_web/server/config.py index 01110548..71b24709 100644 --- a/api/onnx_web/server/config.py +++ b/api/onnx_web/server/config.py @@ -2,13 +2,11 @@ from functools import cmp_to_key from glob import glob from logging import getLogger from os import path -from typing import Dict, List, Optional, Union +from typing import Dict, List, Union import torch import yaml -from onnxruntime import get_available_providers -from .context import ServerContext from ..image import ( # mask filters; noise sources mask_filter_gaussian_multiply, mask_filter_gaussian_screen, @@ -20,9 +18,9 @@ from ..image import ( # mask filters; noise sources noise_source_normal, noise_source_uniform, ) -from ..params import ( - DeviceParams, -) +from ..onnx.torch_before_ort import get_available_providers +from ..params import DeviceParams +from .context import ServerContext logger = getLogger(__name__) @@ -221,4 +219,3 @@ def load_platforms(context: ServerContext) -> None: "available acceleration platforms: %s", ", ".join([str(p) for p in available_platforms]), ) - diff --git a/api/onnx_web/server/params.py b/api/onnx_web/server/params.py index b70ef79a..c16d19ef 100644 --- a/api/onnx_web/server/params.py +++ b/api/onnx_web/server/params.py @@ -4,30 +4,24 @@ from typing import Tuple import numpy as np from flask import request -from .context import ServerContext - -from .config import get_available_platforms, get_config_value, get_correction_models, get_upscaling_models -from .utils import get_model_path - from ..diffusion.load import pipeline_schedulers -from ..params import ( - Border, - DeviceParams, - ImageParams, - Size, - UpscaleParams, -) -from ..utils import ( - get_and_clamp_float, - get_and_clamp_int, - get_from_list, - get_not_empty, +from ..params import Border, DeviceParams, ImageParams, Size, UpscaleParams +from ..utils import get_and_clamp_float, get_and_clamp_int, get_from_list, get_not_empty +from .config import ( + get_available_platforms, + get_config_value, + get_correction_models, + get_upscaling_models, ) +from .context import ServerContext +from .utils import get_model_path logger = getLogger(__name__) -def pipeline_from_request(context: ServerContext) -> Tuple[DeviceParams, ImageParams, Size]: +def pipeline_from_request( + context: ServerContext, +) -> Tuple[DeviceParams, ImageParams, Size]: user = request.remote_addr # platform stuff @@ -43,9 +37,7 @@ def pipeline_from_request(context: ServerContext) -> Tuple[DeviceParams, ImagePa lpw = get_not_empty(request.args, "lpw", "false") == "true" model = get_not_empty(request.args, "model", get_config_value("model")) model_path = get_model_path(context, model) - scheduler = get_from_list( - request.args, "scheduler", pipeline_schedulers.keys() - ) + scheduler = get_from_list(request.args, "scheduler", pipeline_schedulers.keys()) if scheduler is None: scheduler = get_config_value("scheduler") diff --git a/api/onnx_web/server/static.py b/api/onnx_web/server/static.py index 296c8deb..9a67bf0e 100644 --- a/api/onnx_web/server/static.py +++ b/api/onnx_web/server/static.py @@ -2,9 +2,9 @@ from os import path from flask import Flask, send_from_directory -from .utils import wrap_route -from .context import ServerContext from ..worker.pool import DevicePoolExecutor +from .context import ServerContext +from .utils import wrap_route def serve_bundle_file(context: ServerContext, filename="index.html"): @@ -26,9 +26,11 @@ def output(context: ServerContext, filename: str): ) -def register_static_routes(app: Flask, context: ServerContext, pool: DevicePoolExecutor): - return [ - app.route("/")(wrap_route(index, context)), - app.route("/")(wrap_route(index_path, context)), - app.route("/output/")(wrap_route(output, context)), - ] +def register_static_routes( + app: Flask, context: ServerContext, pool: DevicePoolExecutor +): + return [ + app.route("/")(wrap_route(index, context)), + app.route("/")(wrap_route(index_path, context)), + app.route("/output/")(wrap_route(output, context)), + ] diff --git a/api/onnx_web/server/utils.py b/api/onnx_web/server/utils.py index 8dd359a1..582b1c6a 100644 --- a/api/onnx_web/server/utils.py +++ b/api/onnx_web/server/utils.py @@ -1,6 +1,6 @@ +from functools import partial, update_wrapper from os import makedirs, path from typing import Callable, Dict, List, Tuple -from functools import partial, update_wrapper from flask import Flask @@ -22,7 +22,12 @@ def get_model_path(context: ServerContext, model: str): return base_join(context.model_path, model) -def register_routes(app: Flask, context: ServerContext, pool: DevicePoolExecutor, routes: List[Tuple[str, Dict, Callable]]): +def register_routes( + app: Flask, + context: ServerContext, + pool: DevicePoolExecutor, + routes: List[Tuple[str, Dict, Callable]], +): pass diff --git a/api/onnx_web/upscale.py b/api/onnx_web/upscale.py index 8636f8c1..098ce817 100644 --- a/api/onnx_web/upscale.py +++ b/api/onnx_web/upscale.py @@ -11,7 +11,7 @@ from .chain import ( ) from .params import ImageParams, SizeChart, StageParams, UpscaleParams from .server import ServerContext -from .worker import WorkerContext, ProgressCallback +from .worker import ProgressCallback, WorkerContext logger = getLogger(__name__) diff --git a/api/onnx_web/worker/__init__.py b/api/onnx_web/worker/__init__.py index 0ca5eefc..c1f2d794 100644 --- a/api/onnx_web/worker/__init__.py +++ b/api/onnx_web/worker/__init__.py @@ -1,2 +1,2 @@ from .context import WorkerContext, ProgressCallback -from .pool import DevicePoolExecutor \ No newline at end of file +from .pool import DevicePoolExecutor diff --git a/api/onnx_web/worker/context.py b/api/onnx_web/worker/context.py index 59f55fdd..ae083509 100644 --- a/api/onnx_web/worker/context.py +++ b/api/onnx_web/worker/context.py @@ -1,7 +1,8 @@ from logging import getLogger -from torch.multiprocessing import Queue, Value from typing import Any, Callable, Tuple +from torch.multiprocessing import Queue, Value + from ..params import DeviceParams logger = getLogger(__name__) @@ -9,6 +10,7 @@ logger = getLogger(__name__) ProgressCallback = Callable[[int, int, Any], None] + class WorkerContext: cancel: "Value[bool]" = None key: str = None diff --git a/api/onnx_web/worker/logging.py b/api/onnx_web/worker/logging.py index 39808a64..ab90a266 100644 --- a/api/onnx_web/worker/logging.py +++ b/api/onnx_web/worker/logging.py @@ -1 +1 @@ -# TODO: queue-based logger \ No newline at end of file +# TODO: queue-based logger diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index 34e5b70c..810bb91a 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -1,9 +1,10 @@ from collections import Counter from logging import getLogger from multiprocessing import Queue -from torch.multiprocessing import Lock, Process, Value from typing import Callable, Dict, List, Optional, Tuple +from torch.multiprocessing import Lock, Process, Value + from ..params import DeviceParams from ..server import ServerContext from .context import WorkerContext @@ -35,7 +36,7 @@ class DevicePoolExecutor: self.pending = {} self.progress = {} self.workers = {} - self.jobs = {} # Dict[Output, Device] + self.jobs = {} # Dict[Output, Device] self.job_count = 0 # TODO: make this a method @@ -58,15 +59,21 @@ class DevicePoolExecutor: cancel = Value("B", False, lock=lock) finished = Value("B", False) self.finished[name] = finished - progress = Value("I", 0) # , lock=lock) # needs its own lock for some reason. TODO: why? + progress = Value( + "I", 0 + ) # , lock=lock) # needs its own lock for some reason. TODO: why? self.progress[name] = progress pending = Queue() self.pending[name] = pending - context = WorkerContext(name, cancel, device, pending, progress, self.log_queue, finished) + context = WorkerContext( + name, cancel, device, pending, progress, self.log_queue, finished + ) self.context[name] = context logger.debug("starting worker for device %s", device) - self.workers[name] = Process(target=worker_init, args=(lock, context, server)) + self.workers[name] = Process( + target=worker_init, args=(lock, context, server) + ) self.workers[name].start() def cancel(self, key: str) -> bool: @@ -78,7 +85,7 @@ class DevicePoolExecutor: raise NotImplementedError() def done(self, key: str) -> Tuple[Optional[bool], int]: - if not key in self.jobs: + if key not in self.jobs: logger.warn("checking status for unknown key: %s", key) return (None, 0) @@ -88,7 +95,6 @@ class DevicePoolExecutor: return (finished.value, progress.value) - def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int: # respect overrides if possible if needs_device is not None: @@ -96,9 +102,7 @@ class DevicePoolExecutor: if self.devices[i].device == needs_device.device: return i - pending = [ - self.pending[d.device].qsize() for d in self.devices - ] + pending = [self.pending[d.device].qsize() for d in self.devices] jobs = Counter(range(len(self.devices))) jobs.update(pending) @@ -128,7 +132,7 @@ class DevicePoolExecutor: finished_count - self.finished_limit, finished_count, ) - self.finished[:] = self.finished[-self.finished_limit:] + self.finished[:] = self.finished[-self.finished_limit :] def recycle(self): for name, proc in self.workers.items(): @@ -149,10 +153,11 @@ class DevicePoolExecutor: lock = self.locks[name] logger.debug("starting worker for device %s", name) - self.workers[name] = Process(target=worker_init, args=(lock, context, self.server)) + self.workers[name] = Process( + target=worker_init, args=(lock, context, self.server) + ) self.workers[name].start() - def submit( self, key: str, @@ -171,7 +176,10 @@ class DevicePoolExecutor: self.prune() device_idx = self.get_next_device(needs_device=needs_device) logger.info( - "assigning job %s to device %s: %s", key, device_idx, self.devices[device_idx] + "assigning job %s to device %s: %s", + key, + device_idx, + self.devices[device_idx], ) device = self.devices[device_idx] @@ -180,7 +188,6 @@ class DevicePoolExecutor: self.jobs[key] = device.device - def status(self) -> List[Tuple[str, int, bool, int]]: pending = [ ( diff --git a/api/onnx_web/worker/worker.py b/api/onnx_web/worker/worker.py index 07c6bb02..dbd86896 100644 --- a/api/onnx_web/worker/worker.py +++ b/api/onnx_web/worker/worker.py @@ -1,12 +1,12 @@ from logging import getLogger -import torch # has to come before ORT -from onnxruntime import get_available_providers -from torch.multiprocessing import Lock, Queue from traceback import format_exception -from setproctitle import setproctitle -from .context import WorkerContext +from setproctitle import setproctitle +from torch.multiprocessing import Lock, Queue + +from ..onnx.torch_before_ort import get_available_providers from ..server import ServerContext, apply_patches +from .context import WorkerContext logger = getLogger(__name__) @@ -29,7 +29,7 @@ def worker_init(lock: Lock, context: WorkerContext, server: ServerContext): logger.info("checking in from worker, %s, %s", lock, get_available_providers()) apply_patches(server) - setproctitle("onnx-web worker: %s", context.device.device) + setproctitle("onnx-web worker: %s" % (context.device.device)) while True: job = context.pending.get() @@ -52,4 +52,3 @@ def worker_init(lock: Lock, context: WorkerContext, server: ServerContext): except Exception as e: logger.error(format_exception(type(e), e, e.__traceback__)) -