# onnx-web/api/onnx_web/server/context.py
from logging import getLogger
from os import environ, path
from secrets import token_urlsafe
from typing import List, Optional

import torch

from ..utils import get_boolean
from .model_cache import ModelCache

logger = getLogger(__name__)

DEFAULT_ANY_PLATFORM = True
DEFAULT_CACHE_LIMIT = 5
DEFAULT_JOB_LIMIT = 10
DEFAULT_IMAGE_FORMAT = "png"
DEFAULT_SERVER_VERSION = "v0.10.0"
DEFAULT_SHOW_PROGRESS = True
DEFAULT_WORKER_RETRIES = 3

class ServerContext:
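    """
    Runtime configuration for the onnx-web API server: filesystem paths,
    platform selection, model cache settings, and worker behavior. Instances
    are typically built from ONNX_WEB_* environment variables via `from_environ`.
    """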
    bundle_path: str
    model_path: str
    output_path: str
    params_path: str
    cors_origin: str
    any_platform: bool
    block_platforms: List[str]
    default_platform: str
    image_format: str
    cache_limit: int
    cache_path: str
    show_progress: bool
    optimizations: List[str]
    extra_models: List[str]
    job_limit: int
    memory_limit: int
    admin_token: str
    server_version: str
    worker_retries: int
    feature_flags: List[str]
    plugins: List[str]

    def __init__(
        self,
        bundle_path: str = ".",
        model_path: str = ".",
        output_path: str = ".",
        params_path: str = ".",
        cors_origin: str = "*",
        any_platform: bool = DEFAULT_ANY_PLATFORM,
        block_platforms: Optional[List[str]] = None,
        default_platform: Optional[str] = None,
        image_format: str = DEFAULT_IMAGE_FORMAT,
        cache_limit: int = DEFAULT_CACHE_LIMIT,
        cache_path: Optional[str] = None,
        show_progress: bool = DEFAULT_SHOW_PROGRESS,
        optimizations: Optional[List[str]] = None,
        extra_models: Optional[List[str]] = None,
        job_limit: int = DEFAULT_JOB_LIMIT,
        memory_limit: Optional[int] = None,
        admin_token: Optional[str] = None,
        server_version: Optional[str] = DEFAULT_SERVER_VERSION,
        worker_retries: Optional[int] = DEFAULT_WORKER_RETRIES,
        feature_flags: Optional[List[str]] = None,
    ) -> None:
        self.bundle_path = bundle_path
        self.model_path = model_path
        self.output_path = output_path
        self.params_path = params_path
        self.cors_origin = cors_origin
        self.any_platform = any_platform
        self.block_platforms = block_platforms or []
        self.default_platform = default_platform
        self.image_format = image_format
        self.cache_limit = cache_limit
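        # the model cache directory defaults to ".cache" under the model path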
        self.cache_path = cache_path or path.join(model_path, ".cache")
        self.show_progress = show_progress
        self.optimizations = optimizations or []
        self.extra_models = extra_models or []
        self.job_limit = job_limit
        self.memory_limit = memory_limit
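        # generate a random admin token when none has been configured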
        self.admin_token = admin_token or token_urlsafe()
        self.server_version = server_version
        self.worker_retries = worker_retries
        self.feature_flags = feature_flags or []

        self.cache = ModelCache(self.cache_limit)

    @classmethod
    def from_environ(cls):
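        """
        Build a ServerContext from ONNX_WEB_* environment variables, falling
        back to the module-level defaults for anything that is not set.
        """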
        memory_limit = environ.get("ONNX_WEB_MEMORY_LIMIT", None)
        if memory_limit is not None:
            memory_limit = int(memory_limit)

        return cls(
            bundle_path=environ.get(
                "ONNX_WEB_BUNDLE_PATH", path.join("..", "gui", "out")
            ),
            model_path=environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models")),
            output_path=environ.get("ONNX_WEB_OUTPUT_PATH", path.join("..", "outputs")),
            params_path=environ.get("ONNX_WEB_PARAMS_PATH", "."),
            # others
            cors_origin=environ.get("ONNX_WEB_CORS_ORIGIN", "*").split(","),
            any_platform=get_boolean(
                environ, "ONNX_WEB_ANY_PLATFORM", DEFAULT_ANY_PLATFORM
            ),
            block_platforms=environ.get("ONNX_WEB_BLOCK_PLATFORMS", "").split(","),
            default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None),
            image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"),
            cache_limit=int(environ.get("ONNX_WEB_CACHE_MODELS", DEFAULT_CACHE_LIMIT)),
            show_progress=get_boolean(
                environ, "ONNX_WEB_SHOW_PROGRESS", DEFAULT_SHOW_PROGRESS
            ),
            optimizations=environ.get("ONNX_WEB_OPTIMIZATIONS", "").split(","),
            extra_models=environ.get("ONNX_WEB_EXTRA_MODELS", "").split(","),
            job_limit=int(environ.get("ONNX_WEB_JOB_LIMIT", DEFAULT_JOB_LIMIT)),
            memory_limit=memory_limit,
            admin_token=environ.get("ONNX_WEB_ADMIN_TOKEN", None),
            server_version=environ.get(
                "ONNX_WEB_SERVER_VERSION", DEFAULT_SERVER_VERSION
            ),
            worker_retries=int(
                environ.get("ONNX_WEB_WORKER_RETRIES", DEFAULT_WORKER_RETRIES)
            ),
            feature_flags=environ.get("ONNX_WEB_FEATURE_FLAGS", "").split(","),
        )

    def has_feature(self, flag: str) -> bool:
        return flag in self.feature_flags

    def has_optimization(self, opt: str) -> bool:
        return opt in self.optimizations

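    # half precision is only selected when the "torch-fp16" optimization flag
    # is enabled; everything else falls back to full float32 precision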
    def torch_dtype(self):
        if self.has_optimization("torch-fp16"):
            return torch.float16
        else:
            return torch.float32
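

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: build a context
    # from the ONNX_WEB_* environment variables read in from_environ above
    # and inspect a couple of derived values. The optimization flag set here
    # is only an example.
    environ.setdefault("ONNX_WEB_OPTIMIZATIONS", "torch-fp16")

    server = ServerContext.from_environ()
    print("model cache path:", server.cache_path)
    print("selected torch dtype:", server.torch_dtype())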