fix imports, lint
This commit is contained in:
parent
85118d17c6
commit
b931da1d2c
|
@ -3,4 +3,4 @@ IF "%ONNX_WEB_EXTRA_MODELS%"=="" (set ONNX_WEB_EXTRA_MODELS=extras.json)
|
||||||
python -m onnx_web.convert --sources --diffusion --upscaling --correction --extras=%ONNX_WEB_EXTRA_MODELS% --token=%HF_TOKEN%
|
python -m onnx_web.convert --sources --diffusion --upscaling --correction --extras=%ONNX_WEB_EXTRA_MODELS% --token=%HF_TOKEN%
|
||||||
|
|
||||||
echo "Launching API server..."
|
echo "Launching API server..."
|
||||||
flask --app=onnx_web.serve run --host=0.0.0.0
|
flask --app="onnx_web.serve:run" run --host=0.0.0.0
|
||||||
|
|
|
@ -25,4 +25,4 @@ python3 -m onnx_web.convert \
|
||||||
--token=${HF_TOKEN:-}
|
--token=${HF_TOKEN:-}
|
||||||
|
|
||||||
echo "Launching API server..."
|
echo "Launching API server..."
|
||||||
flask --app='onnx_web.main:main()' run --host=0.0.0.0
|
flask --app='onnx_web.main:run' run --host=0.0.0.0
|
||||||
|
|
|
@ -2,4 +2,4 @@ echo "Downloading and converting models to ONNX format..."
|
||||||
python -m onnx_web.convert --sources --diffusion --upscaling --correction --token=%HF_TOKEN%
|
python -m onnx_web.convert --sources --diffusion --upscaling --correction --token=%HF_TOKEN%
|
||||||
|
|
||||||
echo "Launching API server..."
|
echo "Launching API server..."
|
||||||
flask --app=onnx_web.serve run --host=0.0.0.0
|
flask --app="onnx_web.serve:run" run --host=0.0.0.0
|
||||||
|
|
|
@ -24,4 +24,4 @@ python3 -m onnx_web.convert \
|
||||||
--token=${HF_TOKEN:-}
|
--token=${HF_TOKEN:-}
|
||||||
|
|
||||||
echo "Launching API server..."
|
echo "Launching API server..."
|
||||||
flask --app='onnx_web.main:main()' run --host=0.0.0.0
|
flask --app='onnx_web.main:run' run --host=0.0.0.0
|
||||||
|
|
|
@ -1,4 +1,6 @@
|
||||||
|
import atexit
|
||||||
import gc
|
import gc
|
||||||
|
from logging import getLogger
|
||||||
|
|
||||||
from diffusers.utils.logging import disable_progress_bar
|
from diffusers.utils.logging import disable_progress_bar
|
||||||
from flask import Flask
|
from flask import Flask
|
||||||
|
@ -20,6 +22,8 @@ from .server.utils import check_paths
|
||||||
from .utils import is_debug
|
from .utils import is_debug
|
||||||
from .worker import DevicePoolExecutor
|
from .worker import DevicePoolExecutor
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
set_start_method("spawn", force=True)
|
set_start_method("spawn", force=True)
|
||||||
|
@ -53,7 +57,13 @@ def main():
|
||||||
return app, pool
|
return app, pool
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
def run():
|
||||||
app, pool = main()
|
app, pool = main()
|
||||||
|
atexit.register(lambda: pool.join())
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
app = run()
|
||||||
app.run("0.0.0.0", 5000, debug=is_debug())
|
app.run("0.0.0.0", 5000, debug=is_debug())
|
||||||
pool.join()
|
logger.info("shutting down app")
|
||||||
|
|
|
@ -5,7 +5,7 @@ import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from .torch_before_ort import InferenceSession, SessionOptions
|
from ..torch_before_ort import InferenceSession, SessionOptions
|
||||||
|
|
||||||
|
|
||||||
class OnnxTensor:
|
class OnnxTensor:
|
||||||
|
|
|
@ -2,7 +2,7 @@ from enum import IntEnum
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
||||||
|
|
||||||
from .onnx.torch_before_ort import GraphOptimizationLevel, SessionOptions
|
from .torch_before_ort import GraphOptimizationLevel, SessionOptions
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
@ -18,8 +18,8 @@ from ..image import ( # mask filters; noise sources
|
||||||
noise_source_normal,
|
noise_source_normal,
|
||||||
noise_source_uniform,
|
noise_source_uniform,
|
||||||
)
|
)
|
||||||
from ..onnx.torch_before_ort import get_available_providers
|
|
||||||
from ..params import DeviceParams
|
from ..params import DeviceParams
|
||||||
|
from ..torch_before_ort import get_available_providers
|
||||||
from .context import ServerContext
|
from .context import ServerContext
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
|
@ -48,7 +48,7 @@ class DevicePoolExecutor:
|
||||||
|
|
||||||
def create_logger_worker(self) -> None:
|
def create_logger_worker(self) -> None:
|
||||||
self.log_queue = Queue()
|
self.log_queue = Queue()
|
||||||
self.logger = Process(target=logger_init, args=(self.log_queue))
|
self.logger = Process(target=logger_init, args=(self.log_queue,))
|
||||||
|
|
||||||
logger.debug("starting log worker")
|
logger.debug("starting log worker")
|
||||||
self.logger.start()
|
self.logger.start()
|
||||||
|
|
|
@ -4,8 +4,8 @@ from traceback import format_exception
|
||||||
from setproctitle import setproctitle
|
from setproctitle import setproctitle
|
||||||
from torch.multiprocessing import Queue
|
from torch.multiprocessing import Queue
|
||||||
|
|
||||||
from ..onnx.torch_before_ort import get_available_providers
|
|
||||||
from ..server import ServerContext, apply_patches
|
from ..server import ServerContext, apply_patches
|
||||||
|
from ..torch_before_ort import get_available_providers
|
||||||
from .context import WorkerContext
|
from .context import WorkerContext
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
@ -14,7 +14,7 @@ logger = getLogger(__name__)
|
||||||
def logger_init(logs: Queue):
|
def logger_init(logs: Queue):
|
||||||
setproctitle("onnx-web logger")
|
setproctitle("onnx-web logger")
|
||||||
|
|
||||||
logger.info("checking in from logger, %s")
|
logger.info("checking in from logger")
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
job = logs.get()
|
job = logs.get()
|
||||||
|
@ -27,7 +27,7 @@ def worker_init(context: WorkerContext, server: ServerContext):
|
||||||
apply_patches(server)
|
apply_patches(server)
|
||||||
setproctitle("onnx-web worker: %s" % (context.device.device))
|
setproctitle("onnx-web worker: %s" % (context.device.device))
|
||||||
|
|
||||||
logger.info("checking in from worker, %s, %s", get_available_providers())
|
logger.info("checking in from worker, %s", get_available_providers())
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
job = context.pending.get()
|
job = context.pending.get()
|
||||||
|
|
Loading…
Reference in New Issue