1
0
Fork 0

lint(api): remove global paths entirely

This commit is contained in:
Sean Sube 2023-01-16 16:39:30 -06:00
parent 4615614e5e
commit c98c0ff4dd
2 changed files with 57 additions and 41 deletions

View File

@ -18,7 +18,7 @@ from flask_cors import CORS
from flask_executor import Executor
from io import BytesIO
from PIL import Image
from os import environ, makedirs, path, scandir
from os import makedirs, path, scandir
from typing import Tuple
from .image import (
@ -57,17 +57,6 @@ from .utils import (
import json
import numpy as np
# paths
bundle_path = environ.get('ONNX_WEB_BUNDLE_PATH',
path.join('..', 'gui', 'out'))
model_path = environ.get('ONNX_WEB_MODEL_PATH', path.join('..', 'models'))
output_path = environ.get('ONNX_WEB_OUTPUT_PATH', path.join('..', 'outputs'))
params_path = environ.get('ONNX_WEB_PARAMS_PATH', 'params.json')
# options
cors_origin = environ.get('ONNX_WEB_CORS_ORIGIN', '*').split(',')
num_workers = int(environ.get('ONNX_WEB_NUM_WORKERS', 1))
# pipeline caching
available_models = []
config_params = {}
@ -109,14 +98,12 @@ mask_filters = {
# TODO: load from model_path
upscale_models = [
'RealESRGAN_x4plus',
'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth', # TODO: convert GFPGAN
'GFPGANv1.3',
# TODO: convert GFPGAN
# 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
]
def serve_bundle_file(filename='index.html'):
return send_from_directory(path.join('..', bundle_path), filename)
def url_from_rule(rule) -> str:
options = {}
for arg in rule.arguments:
@ -125,10 +112,6 @@ def url_from_rule(rule) -> str:
return url_for(rule.endpoint, **options)
def get_model_path(model: str):
return safer_join(model_path, model)
def pipeline_from_request() -> Tuple[BaseParams, Size]:
user = request.remote_addr
@ -210,37 +193,51 @@ def upscale_from_request() -> UpscaleParams:
denoise=denoise,
)
def check_paths():
if not path.exists(model_path):
def check_paths(context: ServerContext):
if not path.exists(context.model_path):
raise RuntimeError('model path must exist')
if not path.exists(output_path):
makedirs(output_path)
if not path.exists(context.output_path):
makedirs(context.output_path)
def load_models():
def load_models(context: ServerContext):
global available_models
available_models = [f.name for f in scandir(model_path) if f.is_dir()]
available_models = [f.name for f in scandir(
context.model_path) if f.is_dir()]
def load_params():
def load_params(context: ServerContext):
global config_params
with open(params_path) as f:
params_file = path.join(context.params_path, 'params.json')
with open(params_file) as f:
config_params = json.load(f)
check_paths()
load_models()
load_params()
context = ServerContext()
check_paths(context)
load_models(context)
load_params(context)
app = Flask(__name__)
app.config['EXECUTOR_MAX_WORKERS'] = num_workers
app.config['EXECUTOR_MAX_WORKERS'] = context.num_workers
app.config['EXECUTOR_PROPAGATE_EXCEPTIONS'] = True
CORS(app, origins=cors_origin)
CORS(app, origins=context.cors_origin)
executor = Executor(app)
context = ServerContext(bundle_path, model_path, output_path, params_path)
# TODO: these two use context
def get_model_path(model: str):
return safer_join(context.model_path, model)
def serve_bundle_file(filename='index.html'):
return send_from_directory(path.join('..', context.bundle_path), filename)
# routes
@ -418,4 +415,4 @@ def ready():
@app.route('/api/output/<path:filename>')
def output(filename: str):
return send_from_directory(path.join('..', output_path), filename, as_attachment=False)
return send_from_directory(path.join('..', context.output_path), filename, as_attachment=False)

View File

@ -1,4 +1,4 @@
from os import path
from os import environ, path
from time import time
from struct import pack
from typing import Any, Dict, Tuple, Union
@ -54,15 +54,34 @@ class Border:
class ServerContext:
def __init__(
self,
bundle_path: str,
model_path: str,
output_path: str,
params_path: str
bundle_path: str = '.',
model_path: str = '.',
output_path: str = '.',
params_path: str = '.',
cors_origin: str = '*',
num_workers: int = 1,
) -> None:
self.bundle_path = bundle_path
self.model_path = model_path
self.output_path = output_path
self.params_path = params_path
self.cors_origin = cors_origin
self.num_workers = num_workers
@classmethod
def from_environ():
return ServerContext(
bundle_path=environ.get('ONNX_WEB_BUNDLE_PATH',
path.join('..', 'gui', 'out')),
model_path=environ.get('ONNX_WEB_MODEL_PATH',
path.join('..', 'models')),
output_path=environ.get(
'ONNX_WEB_OUTPUT_PATH', path.join('..', 'outputs')),
params_path=environ.get('ONNX_WEB_PARAMS_PATH', '.'),
# others
cors_origin=environ.get('ONNX_WEB_CORS_ORIGIN', '*').split(','),
num_workers=int(environ.get('ONNX_WEB_NUM_WORKERS', 1)),
)
class Size: