From b931da1d2c11cde4c0369a8482478c042cf5aadd Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 26 Feb 2023 15:21:58 -0600 Subject: [PATCH] fix imports, lint --- api/launch-extras.bat | 2 +- api/launch-extras.sh | 2 +- api/launch.bat | 2 +- api/launch.sh | 2 +- api/onnx_web/main.py | 14 ++++++++++++-- api/onnx_web/onnx/onnx_net.py | 2 +- api/onnx_web/params.py | 2 +- api/onnx_web/server/config.py | 2 +- api/onnx_web/{onnx => }/torch_before_ort.py | 0 api/onnx_web/worker/pool.py | 2 +- api/onnx_web/worker/worker.py | 6 +++--- 11 files changed, 23 insertions(+), 13 deletions(-) rename api/onnx_web/{onnx => }/torch_before_ort.py (100%) diff --git a/api/launch-extras.bat b/api/launch-extras.bat index fa3c8908..2f0b95c0 100644 --- a/api/launch-extras.bat +++ b/api/launch-extras.bat @@ -3,4 +3,4 @@ IF "%ONNX_WEB_EXTRA_MODELS%"=="" (set ONNX_WEB_EXTRA_MODELS=extras.json) python -m onnx_web.convert --sources --diffusion --upscaling --correction --extras=%ONNX_WEB_EXTRA_MODELS% --token=%HF_TOKEN% echo "Launching API server..." -flask --app=onnx_web.serve run --host=0.0.0.0 +flask --app="onnx_web.main:run" run --host=0.0.0.0 diff --git a/api/launch-extras.sh b/api/launch-extras.sh index f18e14c0..50572aa4 100755 --- a/api/launch-extras.sh +++ b/api/launch-extras.sh @@ -25,4 +25,4 @@ python3 -m onnx_web.convert \ --token=${HF_TOKEN:-} echo "Launching API server..." -flask --app='onnx_web.main:main()' run --host=0.0.0.0 +flask --app='onnx_web.main:run' run --host=0.0.0.0 diff --git a/api/launch.bat b/api/launch.bat index f589bd11..4ee27fd6 100644 --- a/api/launch.bat +++ b/api/launch.bat @@ -2,4 +2,4 @@ echo "Downloading and converting models to ONNX format..." python -m onnx_web.convert --sources --diffusion --upscaling --correction --token=%HF_TOKEN% echo "Launching API server..." 
-flask --app=onnx_web.serve run --host=0.0.0.0 +flask --app="onnx_web.main:run" run --host=0.0.0.0 diff --git a/api/launch.sh b/api/launch.sh index 983e0930..55b6ff72 100755 --- a/api/launch.sh +++ b/api/launch.sh @@ -24,4 +24,4 @@ python3 -m onnx_web.convert \ --token=${HF_TOKEN:-} echo "Launching API server..." -flask --app='onnx_web.main:main()' run --host=0.0.0.0 +flask --app='onnx_web.main:run' run --host=0.0.0.0 diff --git a/api/onnx_web/main.py b/api/onnx_web/main.py index 3f7c9a84..bbfa1039 100644 --- a/api/onnx_web/main.py +++ b/api/onnx_web/main.py @@ -1,4 +1,6 @@ +import atexit import gc +from logging import getLogger from diffusers.utils.logging import disable_progress_bar from flask import Flask @@ -20,6 +22,8 @@ from .server.utils import check_paths from .utils import is_debug from .worker import DevicePoolExecutor +logger = getLogger(__name__) + def main(): set_start_method("spawn", force=True) @@ -53,7 +57,13 @@ def main(): return app, pool -if __name__ == "__main__": +def run(): app, pool = main() + atexit.register(lambda: pool.join()) + return app + + +if __name__ == "__main__": + app = run() app.run("0.0.0.0", 5000, debug=is_debug()) - pool.join() + logger.info("shutting down app") diff --git a/api/onnx_web/onnx/onnx_net.py b/api/onnx_web/onnx/onnx_net.py index 42f00a4d..a974aff4 100644 --- a/api/onnx_web/onnx/onnx_net.py +++ b/api/onnx_web/onnx/onnx_net.py @@ -5,7 +5,7 @@ import numpy as np import torch from ..server import ServerContext -from .torch_before_ort import InferenceSession, SessionOptions +from ..torch_before_ort import InferenceSession, SessionOptions class OnnxTensor: diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index f92328c6..32440c07 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -2,7 +2,7 @@ from enum import IntEnum from logging import getLogger from typing import Any, Dict, List, Literal, Optional, Tuple, Union -from .onnx.torch_before_ort import GraphOptimizationLevel, SessionOptions 
+from .torch_before_ort import GraphOptimizationLevel, SessionOptions logger = getLogger(__name__) diff --git a/api/onnx_web/server/config.py b/api/onnx_web/server/config.py index 71b24709..a5dc1a31 100644 --- a/api/onnx_web/server/config.py +++ b/api/onnx_web/server/config.py @@ -18,8 +18,8 @@ from ..image import ( # mask filters; noise sources noise_source_normal, noise_source_uniform, ) -from ..onnx.torch_before_ort import get_available_providers from ..params import DeviceParams +from ..torch_before_ort import get_available_providers from .context import ServerContext logger = getLogger(__name__) diff --git a/api/onnx_web/onnx/torch_before_ort.py b/api/onnx_web/torch_before_ort.py similarity index 100% rename from api/onnx_web/onnx/torch_before_ort.py rename to api/onnx_web/torch_before_ort.py diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index fc58ba57..4f2b66ff 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -48,7 +48,7 @@ class DevicePoolExecutor: def create_logger_worker(self) -> None: self.log_queue = Queue() - self.logger = Process(target=logger_init, args=(self.log_queue)) + self.logger = Process(target=logger_init, args=(self.log_queue,)) logger.debug("starting log worker") self.logger.start() diff --git a/api/onnx_web/worker/worker.py b/api/onnx_web/worker/worker.py index efb598f6..24a1c4f2 100644 --- a/api/onnx_web/worker/worker.py +++ b/api/onnx_web/worker/worker.py @@ -4,8 +4,8 @@ from traceback import format_exception from setproctitle import setproctitle from torch.multiprocessing import Queue -from ..onnx.torch_before_ort import get_available_providers from ..server import ServerContext, apply_patches +from ..torch_before_ort import get_available_providers from .context import WorkerContext logger = getLogger(__name__) @@ -14,7 +14,7 @@ logger = getLogger(__name__) def logger_init(logs: Queue): setproctitle("onnx-web logger") - logger.info("checking in from logger, %s") + 
logger.info("checking in from logger") while True: job = logs.get() @@ -27,7 +27,7 @@ def worker_init(context: WorkerContext, server: ServerContext): apply_patches(server) setproctitle("onnx-web worker: %s" % (context.device.device)) - logger.info("checking in from worker, %s, %s", get_available_providers()) + logger.info("checking in from worker, %s", get_available_providers()) while True: job = context.pending.get()