feat(api): add server setting for CUDA memory limit (#211)
This commit is contained in:
parent
af326a784f
commit
aec540a524
|
@ -7,6 +7,10 @@ from .model_cache import ModelCache
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
DEFAULT_CACHE_LIMIT = 2
|
||||||
|
DEFAULT_JOB_LIMIT = 10
|
||||||
|
DEFAULT_IMAGE_FORMAT = "png"
|
||||||
|
|
||||||
|
|
||||||
class ServerContext:
|
class ServerContext:
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -19,12 +23,14 @@ class ServerContext:
|
||||||
any_platform: bool = True,
|
any_platform: bool = True,
|
||||||
block_platforms: Optional[List[str]] = None,
|
block_platforms: Optional[List[str]] = None,
|
||||||
default_platform: Optional[str] = None,
|
default_platform: Optional[str] = None,
|
||||||
image_format: str = "png",
|
image_format: str = DEFAULT_IMAGE_FORMAT,
|
||||||
cache_limit: Optional[int] = 1,
|
cache_limit: int = DEFAULT_CACHE_LIMIT,
|
||||||
cache_path: Optional[str] = None,
|
cache_path: Optional[str] = None,
|
||||||
show_progress: bool = True,
|
show_progress: bool = True,
|
||||||
optimizations: Optional[List[str]] = None,
|
optimizations: Optional[List[str]] = None,
|
||||||
extra_models: Optional[List[str]] = None,
|
extra_models: Optional[List[str]] = None,
|
||||||
|
job_limit: int = DEFAULT_JOB_LIMIT,
|
||||||
|
memory_limit: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.bundle_path = bundle_path
|
self.bundle_path = bundle_path
|
||||||
self.model_path = model_path
|
self.model_path = model_path
|
||||||
|
@ -35,15 +41,22 @@ class ServerContext:
|
||||||
self.block_platforms = block_platforms or []
|
self.block_platforms = block_platforms or []
|
||||||
self.default_platform = default_platform
|
self.default_platform = default_platform
|
||||||
self.image_format = image_format
|
self.image_format = image_format
|
||||||
self.cache = ModelCache(cache_limit)
|
|
||||||
self.cache_limit = cache_limit
|
self.cache_limit = cache_limit
|
||||||
self.cache_path = cache_path or path.join(model_path, ".cache")
|
self.cache_path = cache_path or path.join(model_path, ".cache")
|
||||||
self.show_progress = show_progress
|
self.show_progress = show_progress
|
||||||
self.optimizations = optimizations or []
|
self.optimizations = optimizations or []
|
||||||
self.extra_models = extra_models or []
|
self.extra_models = extra_models or []
|
||||||
|
self.job_limit = job_limit
|
||||||
|
self.memory_limit = memory_limit
|
||||||
|
|
||||||
|
self.cache = ModelCache(self.cache_limit)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_environ(cls):
|
def from_environ(cls):
|
||||||
|
memory_limit = environ.get("ONNX_WEB_MEMORY_LIMIT", None)
|
||||||
|
if memory_limit is not None:
|
||||||
|
memory_limit = int(memory_limit)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
bundle_path=environ.get(
|
bundle_path=environ.get(
|
||||||
"ONNX_WEB_BUNDLE_PATH", path.join("..", "gui", "out")
|
"ONNX_WEB_BUNDLE_PATH", path.join("..", "gui", "out")
|
||||||
|
@ -57,8 +70,10 @@ class ServerContext:
|
||||||
block_platforms=environ.get("ONNX_WEB_BLOCK_PLATFORMS", "").split(","),
|
block_platforms=environ.get("ONNX_WEB_BLOCK_PLATFORMS", "").split(","),
|
||||||
default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None),
|
default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None),
|
||||||
image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"),
|
image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"),
|
||||||
cache_limit=int(environ.get("ONNX_WEB_CACHE_MODELS", 2)),
|
cache_limit=int(environ.get("ONNX_WEB_CACHE_MODELS", DEFAULT_CACHE_LIMIT)),
|
||||||
show_progress=get_boolean(environ, "ONNX_WEB_SHOW_PROGRESS", True),
|
show_progress=get_boolean(environ, "ONNX_WEB_SHOW_PROGRESS", True),
|
||||||
optimizations=environ.get("ONNX_WEB_OPTIMIZATIONS", "").split(","),
|
optimizations=environ.get("ONNX_WEB_OPTIMIZATIONS", "").split(","),
|
||||||
extra_models=environ.get("ONNX_WEB_EXTRA_MODELS", "").split(","),
|
extra_models=environ.get("ONNX_WEB_EXTRA_MODELS", "").split(","),
|
||||||
|
job_limit=int(environ.get("ONNX_WEB_JOB_LIMIT", DEFAULT_JOB_LIMIT)),
|
||||||
|
memory_limit=memory_limit,
|
||||||
)
|
)
|
||||||
|
|
|
@ -268,13 +268,19 @@ def load_platforms(context: ServerContext) -> None:
|
||||||
):
|
):
|
||||||
if potential == "cuda":
|
if potential == "cuda":
|
||||||
for i in range(torch.cuda.device_count()):
|
for i in range(torch.cuda.device_count()):
|
||||||
|
options = {
|
||||||
|
"device_id": i,
|
||||||
|
}
|
||||||
|
|
||||||
|
if context.memory_limit is not None:
|
||||||
|
options["arena_extend_strategy"] = "kSameAsRequested"
|
||||||
|
options["gpu_mem_limit"] = context.memory_limit
|
||||||
|
|
||||||
available_platforms.append(
|
available_platforms.append(
|
||||||
DeviceParams(
|
DeviceParams(
|
||||||
potential,
|
potential,
|
||||||
platform_providers[potential],
|
platform_providers[potential],
|
||||||
{
|
options,
|
||||||
"device_id": i,
|
|
||||||
},
|
|
||||||
context.optimizations,
|
context.optimizations,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
|
@ -41,13 +41,12 @@ class DevicePoolExecutor:
|
||||||
self,
|
self,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
devices: List[DeviceParams],
|
devices: List[DeviceParams],
|
||||||
max_jobs_per_worker: int = 10,
|
|
||||||
max_pending_per_worker: int = 100,
|
max_pending_per_worker: int = 100,
|
||||||
join_timeout: float = 1.0,
|
join_timeout: float = 1.0,
|
||||||
):
|
):
|
||||||
self.server = server
|
self.server = server
|
||||||
self.devices = devices
|
self.devices = devices
|
||||||
self.max_jobs_per_worker = max_jobs_per_worker
|
self.max_jobs_per_worker = server.job_limit
|
||||||
self.max_pending_per_worker = max_pending_per_worker
|
self.max_pending_per_worker = max_pending_per_worker
|
||||||
self.join_timeout = join_timeout
|
self.join_timeout = join_timeout
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue