2023-01-28 23:09:19 +00:00
|
|
|
from logging import getLogger
|
2023-01-16 22:39:30 +00:00
|
|
|
from os import environ, path
|
2023-02-02 14:31:35 +00:00
|
|
|
from typing import Any, Dict, List, Optional, Union
|
2023-01-16 00:54:20 +00:00
|
|
|
|
2023-02-02 03:20:48 +00:00
|
|
|
import gc
|
|
|
|
import torch
|
|
|
|
|
2023-01-28 04:48:06 +00:00
|
|
|
from .params import (
|
2023-01-29 05:06:25 +00:00
|
|
|
SizeChart,
|
2023-01-28 04:48:06 +00:00
|
|
|
)
|
2023-01-16 01:14:58 +00:00
|
|
|
|
2023-01-28 23:09:19 +00:00
|
|
|
# module-scoped logger, named after this module's import path
logger = getLogger(__name__)
|
|
|
|
|
2023-01-16 01:14:58 +00:00
|
|
|
|
2023-01-16 13:31:42 +00:00
|
|
|
class ServerContext:
    """Server configuration: filesystem paths plus request/runtime options.

    Build one directly, or use :meth:`from_environ` to read the
    ``ONNX_WEB_*`` environment variables.
    """

    def __init__(
        self,
        bundle_path: str = '.',
        model_path: str = '.',
        output_path: str = '.',
        params_path: str = '.',
        cors_origin: Union[str, List[str]] = '*',
        num_workers: int = 1,
        block_platforms: Optional[List[str]] = None,
        default_platform: Optional[str] = None,
        image_format: str = 'png',
    ) -> None:
        """Store the configuration values as attributes.

        Args:
            bundle_path: directory path stored as ``bundle_path``.
            model_path: directory path stored as ``model_path``.
            output_path: directory path stored as ``output_path``.
            params_path: directory path stored as ``params_path``.
            cors_origin: a single origin string or a list of origins
                (``from_environ`` passes a list).
            num_workers: worker count stored as ``num_workers``.
            block_platforms: platform names to block; ``None`` means none.
            default_platform: optional default platform name.
            image_format: image format name, ``'png'`` by default.
        """
        self.bundle_path = bundle_path
        self.model_path = model_path
        self.output_path = output_path
        self.params_path = params_path
        self.cors_origin = cors_origin
        self.num_workers = num_workers
        # fresh list per instance: a shared `[]` default would leak
        # mutations from one context into every other one
        self.block_platforms = [] if block_platforms is None else block_platforms
        self.default_platform = default_platform
        self.image_format = image_format

    @classmethod
    def from_environ(cls) -> 'ServerContext':
        """Build a context from ``ONNX_WEB_*`` environment variables.

        Unset variables fall back to the defaults below.
        ``ONNX_WEB_CORS_ORIGIN`` and ``ONNX_WEB_BLOCK_PLATFORMS`` are
        comma-separated lists.
        """
        # drop blank entries so an unset/empty variable yields [] rather than ['']
        block_platforms = [
            platform.strip()
            for platform in environ.get('ONNX_WEB_BLOCK_PLATFORMS', '').split(',')
            if platform.strip()
        ]

        # `cls` rather than the hard-coded class name so subclasses get their own type
        return cls(
            bundle_path=environ.get('ONNX_WEB_BUNDLE_PATH',
                                    path.join('..', 'gui', 'out')),
            model_path=environ.get('ONNX_WEB_MODEL_PATH',
                                   path.join('..', 'models')),
            output_path=environ.get(
                'ONNX_WEB_OUTPUT_PATH', path.join('..', 'outputs')),
            params_path=environ.get('ONNX_WEB_PARAMS_PATH', '.'),
            # others
            cors_origin=environ.get('ONNX_WEB_CORS_ORIGIN', '*').split(','),
            num_workers=int(environ.get('ONNX_WEB_NUM_WORKERS', 1)),
            block_platforms=block_platforms,
            default_platform=environ.get(
                'ONNX_WEB_DEFAULT_PLATFORM', None),
            image_format=environ.get(
                'ONNX_WEB_IMAGE_FORMAT', 'png'
            ),
        )
|
2023-01-16 13:31:42 +00:00
|
|
|
|
|
|
|
|
2023-02-02 14:31:35 +00:00
|
|
|
def base_join(base: str, tail: str) -> str:
    """Join ``tail`` beneath ``base`` so the result cannot escape ``base``.

    ``tail`` is first anchored at the filesystem root and normalized. Any
    leading ``..`` segments or absolute prefix collapse onto that root. It is
    then re-expressed relative to the root and appended to ``base``.
    """
    anchored = path.normpath(path.join('/', tail))
    return path.join(base, path.relpath(anchored, '/'))
|
|
|
|
|
|
|
|
|
2023-01-20 01:46:36 +00:00
|
|
|
def is_debug() -> bool:
    """Return True when a ``DEBUG`` environment variable is set to any value, even an empty one."""
    return 'DEBUG' in environ
|
|
|
|
|
|
|
|
|
2023-01-16 01:47:57 +00:00
|
|
|
def get_and_clamp_float(args: Any, key: str, default_value: float, max_value: float, min_value=0.0) -> float:
    """Read ``args[key]`` as a float and clamp it to ``[min_value, max_value]``.

    Args:
        args: mapping-like object with a ``get`` method (e.g. request args).
        key: name of the value to read.
        default_value: used when ``key`` is missing or parses to NaN.
        max_value: upper bound of the result.
        min_value: lower bound of the result.

    Raises:
        ValueError: if the value cannot be parsed as a float.
    """
    from math import isnan

    value = float(args.get(key, default_value))
    # every comparison with NaN is False, so min/max would let it through unclamped
    if isnan(value):
        value = float(default_value)
    return min(max(value, min_value), max_value)
|
|
|
|
|
|
|
|
|
|
|
|
def get_and_clamp_int(args: Any, key: str, default_value: int, max_value: int, min_value=1) -> int:
    """Read ``args[key]`` as an int and clamp it to ``[min_value, max_value]``.

    ``default_value`` is used when ``key`` is absent. A value that ``int()``
    cannot parse raises ``ValueError``.
    """
    raw = args.get(key, default_value)
    lower_bounded = max(int(raw), min_value)
    return min(lower_bounded, max_value)
|
|
|
|
|
|
|
|
|
2023-01-27 23:08:36 +00:00
|
|
|
def get_from_list(args: Any, key: str, values: List[Any]) -> Optional[Any]:
    """Return ``args[key]`` if it is one of ``values``, else a fallback.

    A missing or unrecognized selection is logged as a warning. The
    fallback is the first entry of ``values``, or ``None`` when
    ``values`` is empty.
    """
    selected = args.get(key, None)
    if selected in values:
        return selected

    # Logger.warn is a deprecated alias; warning() is the supported name
    logger.warning('invalid selection: %s', selected)

    if len(values) > 0:
        return values[0]

    return None
|
|
|
|
|
2023-01-17 02:10:52 +00:00
|
|
|
|
2023-01-22 21:44:31 +00:00
|
|
|
def get_from_map(args: Any, key: str, values: Dict[str, Any], default: Any) -> Any:
    """Map ``args[key]`` through ``values``, falling back to ``values[default]``.

    The fallback is used when ``key`` is absent from ``args`` or its value is
    not a key of ``values``. In that case a ``KeyError`` is raised if
    ``default`` is not a key of ``values`` either.
    """
    choice = args.get(key, default)
    return values[choice if choice in values else default]
|
|
|
|
|
|
|
|
|
2023-01-22 21:44:31 +00:00
|
|
|
def get_not_empty(args: Any, key: str, default: Any) -> Any:
    """Return ``args[key]`` unless it is missing, ``None``, or has zero length.

    In those cases ``default`` is returned. The value is expected to support
    ``len()``.
    """
    value = args.get(key, default)
    if value is not None and len(value) > 0:
        return value
    return default
|
|
|
|
|
|
|
|
|
2023-01-29 05:06:25 +00:00
|
|
|
def get_size(val: Union[int, str, None]) -> Union[int, SizeChart]:
    """Normalize a size setting to a ``SizeChart`` member or a plain int.

    - ``None`` -> ``SizeChart.auto``
    - ``int`` -> returned unchanged
    - ``SizeChart`` member -> returned unchanged
    - ``str`` -> the ``SizeChart`` member with that name, otherwise the
      string parsed as an int

    Raises:
        ValueError: for a string that names no member and is not an int, or
            for any other type. ``ValueError`` subclasses ``Exception``, so
            callers catching the previous bare ``Exception`` still work.
    """
    if val is None:
        return SizeChart.auto

    # exact type check: bool is an int subclass and must not count as a size
    if type(val) is int:
        return val

    # already normalized (e.g. a value that went through this function once)
    if isinstance(val, SizeChart):
        return val

    if type(val) is str:
        for size in SizeChart:
            if val == size.name:
                return size

        try:
            return int(val)
        except ValueError as err:
            raise ValueError(f'invalid size: {val}') from err

    raise ValueError('invalid size')
|
|
|
|
|
|
|
|
|
2023-02-02 03:20:48 +00:00
|
|
|
def run_gc():
    """Run a full Python garbage collection, then ask torch to release cached CUDA memory.

    ``torch.cuda.empty_cache`` frees unoccupied blocks held by torch's
    CUDA caching allocator. It does not free tensors that are still
    referenced.
    """
    logger.debug('running garbage collection')

    # collect first so tensors kept alive only by reference cycles are
    # released before the allocator cache is emptied
    gc.collect()
    torch.cuda.empty_cache()
|