1
0
Fork 0

fix imports, lint

This commit is contained in:
Sean Sube 2023-02-26 15:21:58 -06:00
parent 85118d17c6
commit b931da1d2c
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
11 changed files with 23 additions and 13 deletions

View File

@@ -3,4 +3,4 @@ IF "%ONNX_WEB_EXTRA_MODELS%"=="" (set ONNX_WEB_EXTRA_MODELS=extras.json)
python -m onnx_web.convert --sources --diffusion --upscaling --correction --extras=%ONNX_WEB_EXTRA_MODELS% --token=%HF_TOKEN%
echo "Launching API server..."
flask --app=onnx_web.serve run --host=0.0.0.0
flask --app="onnx_web.serve:run" run --host=0.0.0.0

View File

@@ -25,4 +25,4 @@ python3 -m onnx_web.convert \
--token=${HF_TOKEN:-}
echo "Launching API server..."
flask --app='onnx_web.main:main()' run --host=0.0.0.0
flask --app='onnx_web.main:run' run --host=0.0.0.0

View File

@@ -2,4 +2,4 @@ echo "Downloading and converting models to ONNX format..."
python -m onnx_web.convert --sources --diffusion --upscaling --correction --token=%HF_TOKEN%
echo "Launching API server..."
flask --app=onnx_web.serve run --host=0.0.0.0
flask --app="onnx_web.serve:run" run --host=0.0.0.0

View File

@@ -24,4 +24,4 @@ python3 -m onnx_web.convert \
--token=${HF_TOKEN:-}
echo "Launching API server..."
flask --app='onnx_web.main:main()' run --host=0.0.0.0
flask --app='onnx_web.main:run' run --host=0.0.0.0

View File

@@ -1,4 +1,6 @@
import atexit
import gc
from logging import getLogger
from diffusers.utils.logging import disable_progress_bar
from flask import Flask
@@ -20,6 +22,8 @@ from .server.utils import check_paths
from .utils import is_debug
from .worker import DevicePoolExecutor
logger = getLogger(__name__)
def main():
set_start_method("spawn", force=True)
@@ -53,7 +57,13 @@ def main():
return app, pool
if __name__ == "__main__":
def run():
app, pool = main()
atexit.register(lambda: pool.join())
return app
if __name__ == "__main__":
app = run()
app.run("0.0.0.0", 5000, debug=is_debug())
pool.join()
logger.info("shutting down app")

View File

@@ -5,7 +5,7 @@ import numpy as np
import torch
from ..server import ServerContext
from .torch_before_ort import InferenceSession, SessionOptions
from ..torch_before_ort import InferenceSession, SessionOptions
class OnnxTensor:

View File

@@ -2,7 +2,7 @@ from enum import IntEnum
from logging import getLogger
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from .onnx.torch_before_ort import GraphOptimizationLevel, SessionOptions
from .torch_before_ort import GraphOptimizationLevel, SessionOptions
logger = getLogger(__name__)

View File

@@ -18,8 +18,8 @@ from ..image import ( # mask filters; noise sources
noise_source_normal,
noise_source_uniform,
)
from ..onnx.torch_before_ort import get_available_providers
from ..params import DeviceParams
from ..torch_before_ort import get_available_providers
from .context import ServerContext
logger = getLogger(__name__)

View File

@@ -48,7 +48,7 @@ class DevicePoolExecutor:
def create_logger_worker(self) -> None:
self.log_queue = Queue()
self.logger = Process(target=logger_init, args=(self.log_queue))
self.logger = Process(target=logger_init, args=(self.log_queue,))
logger.debug("starting log worker")
self.logger.start()

View File

@@ -4,8 +4,8 @@ from traceback import format_exception
from setproctitle import setproctitle
from torch.multiprocessing import Queue
from ..onnx.torch_before_ort import get_available_providers
from ..server import ServerContext, apply_patches
from ..torch_before_ort import get_available_providers
from .context import WorkerContext
logger = getLogger(__name__)
@@ -14,7 +14,7 @@ logger = getLogger(__name__)
def logger_init(logs: Queue):
setproctitle("onnx-web logger")
logger.info("checking in from logger, %s")
logger.info("checking in from logger")
while True:
job = logs.get()
@@ -27,7 +27,7 @@ def worker_init(context: WorkerContext, server: ServerContext):
apply_patches(server)
setproctitle("onnx-web worker: %s" % (context.device.device))
logger.info("checking in from worker, %s, %s", get_available_providers())
logger.info("checking in from worker, %s", get_available_providers())
while True:
job = context.pending.get()