fix imports, lint
This commit is contained in:
parent
85118d17c6
commit
b931da1d2c
|
@ -3,4 +3,4 @@ IF "%ONNX_WEB_EXTRA_MODELS%"=="" (set ONNX_WEB_EXTRA_MODELS=extras.json)
|
|||
python -m onnx_web.convert --sources --diffusion --upscaling --correction --extras=%ONNX_WEB_EXTRA_MODELS% --token=%HF_TOKEN%
|
||||
|
||||
echo "Launching API server..."
|
||||
flask --app=onnx_web.serve run --host=0.0.0.0
|
||||
flask --app="onnx_web.serve:run" run --host=0.0.0.0
|
||||
|
|
|
@ -25,4 +25,4 @@ python3 -m onnx_web.convert \
|
|||
--token=${HF_TOKEN:-}
|
||||
|
||||
echo "Launching API server..."
|
||||
flask --app='onnx_web.main:main()' run --host=0.0.0.0
|
||||
flask --app='onnx_web.main:run' run --host=0.0.0.0
|
||||
|
|
|
@ -2,4 +2,4 @@ echo "Downloading and converting models to ONNX format..."
|
|||
python -m onnx_web.convert --sources --diffusion --upscaling --correction --token=%HF_TOKEN%
|
||||
|
||||
echo "Launching API server..."
|
||||
flask --app=onnx_web.serve run --host=0.0.0.0
|
||||
flask --app="onnx_web.serve:run" run --host=0.0.0.0
|
||||
|
|
|
@ -24,4 +24,4 @@ python3 -m onnx_web.convert \
|
|||
--token=${HF_TOKEN:-}
|
||||
|
||||
echo "Launching API server..."
|
||||
flask --app='onnx_web.main:main()' run --host=0.0.0.0
|
||||
flask --app='onnx_web.main:run' run --host=0.0.0.0
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
import atexit
|
||||
import gc
|
||||
from logging import getLogger
|
||||
|
||||
from diffusers.utils.logging import disable_progress_bar
|
||||
from flask import Flask
|
||||
|
@ -20,6 +22,8 @@ from .server.utils import check_paths
|
|||
from .utils import is_debug
|
||||
from .worker import DevicePoolExecutor
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def main():
|
||||
set_start_method("spawn", force=True)
|
||||
|
@ -53,7 +57,13 @@ def main():
|
|||
return app, pool
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def run():
|
||||
app, pool = main()
|
||||
atexit.register(lambda: pool.join())
|
||||
return app
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app = run()
|
||||
app.run("0.0.0.0", 5000, debug=is_debug())
|
||||
pool.join()
|
||||
logger.info("shutting down app")
|
||||
|
|
|
@ -5,7 +5,7 @@ import numpy as np
|
|||
import torch
|
||||
|
||||
from ..server import ServerContext
|
||||
from .torch_before_ort import InferenceSession, SessionOptions
|
||||
from ..torch_before_ort import InferenceSession, SessionOptions
|
||||
|
||||
|
||||
class OnnxTensor:
|
||||
|
|
|
@ -2,7 +2,7 @@ from enum import IntEnum
|
|||
from logging import getLogger
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
from .onnx.torch_before_ort import GraphOptimizationLevel, SessionOptions
|
||||
from .torch_before_ort import GraphOptimizationLevel, SessionOptions
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
|
|
@ -18,8 +18,8 @@ from ..image import ( # mask filters; noise sources
|
|||
noise_source_normal,
|
||||
noise_source_uniform,
|
||||
)
|
||||
from ..onnx.torch_before_ort import get_available_providers
|
||||
from ..params import DeviceParams
|
||||
from ..torch_before_ort import get_available_providers
|
||||
from .context import ServerContext
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
|
|
@ -48,7 +48,7 @@ class DevicePoolExecutor:
|
|||
|
||||
def create_logger_worker(self) -> None:
|
||||
self.log_queue = Queue()
|
||||
self.logger = Process(target=logger_init, args=(self.log_queue))
|
||||
self.logger = Process(target=logger_init, args=(self.log_queue,))
|
||||
|
||||
logger.debug("starting log worker")
|
||||
self.logger.start()
|
||||
|
|
|
@ -4,8 +4,8 @@ from traceback import format_exception
|
|||
from setproctitle import setproctitle
|
||||
from torch.multiprocessing import Queue
|
||||
|
||||
from ..onnx.torch_before_ort import get_available_providers
|
||||
from ..server import ServerContext, apply_patches
|
||||
from ..torch_before_ort import get_available_providers
|
||||
from .context import WorkerContext
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
@ -14,7 +14,7 @@ logger = getLogger(__name__)
|
|||
def logger_init(logs: Queue):
|
||||
setproctitle("onnx-web logger")
|
||||
|
||||
logger.info("checking in from logger, %s")
|
||||
logger.info("checking in from logger")
|
||||
|
||||
while True:
|
||||
job = logs.get()
|
||||
|
@ -27,7 +27,7 @@ def worker_init(context: WorkerContext, server: ServerContext):
|
|||
apply_patches(server)
|
||||
setproctitle("onnx-web worker: %s" % (context.device.device))
|
||||
|
||||
logger.info("checking in from worker, %s, %s", get_available_providers())
|
||||
logger.info("checking in from worker, %s", get_available_providers())
|
||||
|
||||
while True:
|
||||
job = context.pending.get()
|
||||
|
|
Loading…
Reference in New Issue