1
0
Fork 0

fix imports, lint

This commit is contained in:
Sean Sube 2023-02-26 15:21:58 -06:00
parent 85118d17c6
commit b931da1d2c
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
11 changed files with 23 additions and 13 deletions

View File

@ -3,4 +3,4 @@ IF "%ONNX_WEB_EXTRA_MODELS%"=="" (set ONNX_WEB_EXTRA_MODELS=extras.json)
python -m onnx_web.convert --sources --diffusion --upscaling --correction --extras=%ONNX_WEB_EXTRA_MODELS% --token=%HF_TOKEN% python -m onnx_web.convert --sources --diffusion --upscaling --correction --extras=%ONNX_WEB_EXTRA_MODELS% --token=%HF_TOKEN%
echo "Launching API server..." echo "Launching API server..."
flask --app=onnx_web.serve run --host=0.0.0.0 flask --app="onnx_web.serve:run" run --host=0.0.0.0

View File

@ -25,4 +25,4 @@ python3 -m onnx_web.convert \
--token=${HF_TOKEN:-} --token=${HF_TOKEN:-}
echo "Launching API server..." echo "Launching API server..."
flask --app='onnx_web.main:main()' run --host=0.0.0.0 flask --app='onnx_web.main:run' run --host=0.0.0.0

View File

@ -2,4 +2,4 @@ echo "Downloading and converting models to ONNX format..."
python -m onnx_web.convert --sources --diffusion --upscaling --correction --token=%HF_TOKEN% python -m onnx_web.convert --sources --diffusion --upscaling --correction --token=%HF_TOKEN%
echo "Launching API server..." echo "Launching API server..."
flask --app=onnx_web.serve run --host=0.0.0.0 flask --app="onnx_web.serve:run" run --host=0.0.0.0

View File

@ -24,4 +24,4 @@ python3 -m onnx_web.convert \
--token=${HF_TOKEN:-} --token=${HF_TOKEN:-}
echo "Launching API server..." echo "Launching API server..."
flask --app='onnx_web.main:main()' run --host=0.0.0.0 flask --app='onnx_web.main:run' run --host=0.0.0.0

View File

@ -1,4 +1,6 @@
import atexit
import gc import gc
from logging import getLogger
from diffusers.utils.logging import disable_progress_bar from diffusers.utils.logging import disable_progress_bar
from flask import Flask from flask import Flask
@ -20,6 +22,8 @@ from .server.utils import check_paths
from .utils import is_debug from .utils import is_debug
from .worker import DevicePoolExecutor from .worker import DevicePoolExecutor
logger = getLogger(__name__)
def main(): def main():
set_start_method("spawn", force=True) set_start_method("spawn", force=True)
@ -53,7 +57,13 @@ def main():
return app, pool return app, pool
if __name__ == "__main__": def run():
app, pool = main() app, pool = main()
atexit.register(lambda: pool.join())
return app
if __name__ == "__main__":
app = run()
app.run("0.0.0.0", 5000, debug=is_debug()) app.run("0.0.0.0", 5000, debug=is_debug())
pool.join() logger.info("shutting down app")

View File

@ -5,7 +5,7 @@ import numpy as np
import torch import torch
from ..server import ServerContext from ..server import ServerContext
from .torch_before_ort import InferenceSession, SessionOptions from ..torch_before_ort import InferenceSession, SessionOptions
class OnnxTensor: class OnnxTensor:

View File

@ -2,7 +2,7 @@ from enum import IntEnum
from logging import getLogger from logging import getLogger
from typing import Any, Dict, List, Literal, Optional, Tuple, Union from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from .onnx.torch_before_ort import GraphOptimizationLevel, SessionOptions from .torch_before_ort import GraphOptimizationLevel, SessionOptions
logger = getLogger(__name__) logger = getLogger(__name__)

View File

@ -18,8 +18,8 @@ from ..image import ( # mask filters; noise sources
noise_source_normal, noise_source_normal,
noise_source_uniform, noise_source_uniform,
) )
from ..onnx.torch_before_ort import get_available_providers
from ..params import DeviceParams from ..params import DeviceParams
from ..torch_before_ort import get_available_providers
from .context import ServerContext from .context import ServerContext
logger = getLogger(__name__) logger = getLogger(__name__)

View File

@ -48,7 +48,7 @@ class DevicePoolExecutor:
def create_logger_worker(self) -> None: def create_logger_worker(self) -> None:
self.log_queue = Queue() self.log_queue = Queue()
self.logger = Process(target=logger_init, args=(self.log_queue)) self.logger = Process(target=logger_init, args=(self.log_queue,))
logger.debug("starting log worker") logger.debug("starting log worker")
self.logger.start() self.logger.start()

View File

@ -4,8 +4,8 @@ from traceback import format_exception
from setproctitle import setproctitle from setproctitle import setproctitle
from torch.multiprocessing import Queue from torch.multiprocessing import Queue
from ..onnx.torch_before_ort import get_available_providers
from ..server import ServerContext, apply_patches from ..server import ServerContext, apply_patches
from ..torch_before_ort import get_available_providers
from .context import WorkerContext from .context import WorkerContext
logger = getLogger(__name__) logger = getLogger(__name__)
@ -14,7 +14,7 @@ logger = getLogger(__name__)
def logger_init(logs: Queue): def logger_init(logs: Queue):
setproctitle("onnx-web logger") setproctitle("onnx-web logger")
logger.info("checking in from logger, %s") logger.info("checking in from logger")
while True: while True:
job = logs.get() job = logs.get()
@ -27,7 +27,7 @@ def worker_init(context: WorkerContext, server: ServerContext):
apply_patches(server) apply_patches(server)
setproctitle("onnx-web worker: %s" % (context.device.device)) setproctitle("onnx-web worker: %s" % (context.device.device))
logger.info("checking in from worker, %s, %s", get_available_providers()) logger.info("checking in from worker, %s", get_available_providers())
while True: while True:
job = context.pending.get() job = context.pending.get()