lint(api): remove global paths entirely
This commit is contained in:
parent
4615614e5e
commit
c98c0ff4dd
|
@ -18,7 +18,7 @@ from flask_cors import CORS
|
|||
from flask_executor import Executor
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
from os import environ, makedirs, path, scandir
|
||||
from os import makedirs, path, scandir
|
||||
from typing import Tuple
|
||||
|
||||
from .image import (
|
||||
|
@ -57,17 +57,6 @@ from .utils import (
|
|||
import json
|
||||
import numpy as np
|
||||
|
||||
# paths
|
||||
bundle_path = environ.get('ONNX_WEB_BUNDLE_PATH',
|
||||
path.join('..', 'gui', 'out'))
|
||||
model_path = environ.get('ONNX_WEB_MODEL_PATH', path.join('..', 'models'))
|
||||
output_path = environ.get('ONNX_WEB_OUTPUT_PATH', path.join('..', 'outputs'))
|
||||
params_path = environ.get('ONNX_WEB_PARAMS_PATH', 'params.json')
|
||||
|
||||
# options
|
||||
cors_origin = environ.get('ONNX_WEB_CORS_ORIGIN', '*').split(',')
|
||||
num_workers = int(environ.get('ONNX_WEB_NUM_WORKERS', 1))
|
||||
|
||||
# pipeline caching
|
||||
available_models = []
|
||||
config_params = {}
|
||||
|
@ -109,14 +98,12 @@ mask_filters = {
|
|||
# TODO: load from model_path
|
||||
upscale_models = [
|
||||
'RealESRGAN_x4plus',
|
||||
'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth', # TODO: convert GFPGAN
|
||||
'GFPGANv1.3',
|
||||
# TODO: convert GFPGAN
|
||||
# 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
|
||||
]
|
||||
|
||||
|
||||
def serve_bundle_file(filename='index.html'):
|
||||
return send_from_directory(path.join('..', bundle_path), filename)
|
||||
|
||||
|
||||
def url_from_rule(rule) -> str:
|
||||
options = {}
|
||||
for arg in rule.arguments:
|
||||
|
@ -125,10 +112,6 @@ def url_from_rule(rule) -> str:
|
|||
return url_for(rule.endpoint, **options)
|
||||
|
||||
|
||||
def get_model_path(model: str):
|
||||
return safer_join(model_path, model)
|
||||
|
||||
|
||||
def pipeline_from_request() -> Tuple[BaseParams, Size]:
|
||||
user = request.remote_addr
|
||||
|
||||
|
@ -210,37 +193,51 @@ def upscale_from_request() -> UpscaleParams:
|
|||
denoise=denoise,
|
||||
)
|
||||
|
||||
def check_paths():
|
||||
if not path.exists(model_path):
|
||||
|
||||
def check_paths(context: ServerContext):
|
||||
if not path.exists(context.model_path):
|
||||
raise RuntimeError('model path must exist')
|
||||
|
||||
if not path.exists(output_path):
|
||||
makedirs(output_path)
|
||||
if not path.exists(context.output_path):
|
||||
makedirs(context.output_path)
|
||||
|
||||
|
||||
def load_models():
|
||||
def load_models(context: ServerContext):
|
||||
global available_models
|
||||
available_models = [f.name for f in scandir(model_path) if f.is_dir()]
|
||||
available_models = [f.name for f in scandir(
|
||||
context.model_path) if f.is_dir()]
|
||||
|
||||
|
||||
def load_params():
|
||||
def load_params(context: ServerContext):
|
||||
global config_params
|
||||
with open(params_path) as f:
|
||||
params_file = path.join(context.params_path, 'params.json')
|
||||
with open(params_file) as f:
|
||||
config_params = json.load(f)
|
||||
|
||||
|
||||
check_paths()
|
||||
load_models()
|
||||
load_params()
|
||||
context = ServerContext()
|
||||
|
||||
check_paths(context)
|
||||
load_models(context)
|
||||
load_params(context)
|
||||
|
||||
app = Flask(__name__)
|
||||
app.config['EXECUTOR_MAX_WORKERS'] = num_workers
|
||||
app.config['EXECUTOR_MAX_WORKERS'] = context.num_workers
|
||||
app.config['EXECUTOR_PROPAGATE_EXCEPTIONS'] = True
|
||||
|
||||
CORS(app, origins=cors_origin)
|
||||
CORS(app, origins=context.cors_origin)
|
||||
executor = Executor(app)
|
||||
|
||||
context = ServerContext(bundle_path, model_path, output_path, params_path)
|
||||
|
||||
# TODO: these two use context
|
||||
|
||||
def get_model_path(model: str):
|
||||
return safer_join(context.model_path, model)
|
||||
|
||||
|
||||
def serve_bundle_file(filename='index.html'):
|
||||
return send_from_directory(path.join('..', context.bundle_path), filename)
|
||||
|
||||
|
||||
# routes
|
||||
|
||||
|
@ -418,4 +415,4 @@ def ready():
|
|||
|
||||
@app.route('/api/output/<path:filename>')
|
||||
def output(filename: str):
|
||||
return send_from_directory(path.join('..', output_path), filename, as_attachment=False)
|
||||
return send_from_directory(path.join('..', context.output_path), filename, as_attachment=False)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from os import path
|
||||
from os import environ, path
|
||||
from time import time
|
||||
from struct import pack
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
|
@ -54,15 +54,34 @@ class Border:
|
|||
class ServerContext:
|
||||
def __init__(
|
||||
self,
|
||||
bundle_path: str,
|
||||
model_path: str,
|
||||
output_path: str,
|
||||
params_path: str
|
||||
bundle_path: str = '.',
|
||||
model_path: str = '.',
|
||||
output_path: str = '.',
|
||||
params_path: str = '.',
|
||||
cors_origin: str = '*',
|
||||
num_workers: int = 1,
|
||||
) -> None:
|
||||
self.bundle_path = bundle_path
|
||||
self.model_path = model_path
|
||||
self.output_path = output_path
|
||||
self.params_path = params_path
|
||||
self.cors_origin = cors_origin
|
||||
self.num_workers = num_workers
|
||||
|
||||
@classmethod
|
||||
def from_environ():
|
||||
return ServerContext(
|
||||
bundle_path=environ.get('ONNX_WEB_BUNDLE_PATH',
|
||||
path.join('..', 'gui', 'out')),
|
||||
model_path=environ.get('ONNX_WEB_MODEL_PATH',
|
||||
path.join('..', 'models')),
|
||||
output_path=environ.get(
|
||||
'ONNX_WEB_OUTPUT_PATH', path.join('..', 'outputs')),
|
||||
params_path=environ.get('ONNX_WEB_PARAMS_PATH', '.'),
|
||||
# others
|
||||
cors_origin=environ.get('ONNX_WEB_CORS_ORIGIN', '*').split(','),
|
||||
num_workers=int(environ.get('ONNX_WEB_NUM_WORKERS', 1)),
|
||||
)
|
||||
|
||||
|
||||
class Size:
|
||||
|
|
Loading…
Reference in New Issue