diff --git a/api/onnx_web/server/context.py b/api/onnx_web/server/context.py
index e8b2687f..5b459ecc 100644
--- a/api/onnx_web/server/context.py
+++ b/api/onnx_web/server/context.py
@@ -7,6 +7,10 @@ from .model_cache import ModelCache
 
 logger = getLogger(__name__)
 
+DEFAULT_CACHE_LIMIT = 2
+DEFAULT_JOB_LIMIT = 10
+DEFAULT_IMAGE_FORMAT = "png"
+
 
 class ServerContext:
     def __init__(
@@ -19,12 +23,14 @@ class ServerContext:
         any_platform: bool = True,
         block_platforms: Optional[List[str]] = None,
         default_platform: Optional[str] = None,
-        image_format: str = "png",
-        cache_limit: Optional[int] = 1,
+        image_format: str = DEFAULT_IMAGE_FORMAT,
+        cache_limit: int = DEFAULT_CACHE_LIMIT,
         cache_path: Optional[str] = None,
         show_progress: bool = True,
         optimizations: Optional[List[str]] = None,
         extra_models: Optional[List[str]] = None,
+        job_limit: int = DEFAULT_JOB_LIMIT,
+        memory_limit: Optional[int] = None,
     ) -> None:
         self.bundle_path = bundle_path
         self.model_path = model_path
@@ -35,15 +41,22 @@ class ServerContext:
         self.block_platforms = block_platforms or []
         self.default_platform = default_platform
         self.image_format = image_format
-        self.cache = ModelCache(cache_limit)
         self.cache_limit = cache_limit
         self.cache_path = cache_path or path.join(model_path, ".cache")
         self.show_progress = show_progress
         self.optimizations = optimizations or []
         self.extra_models = extra_models or []
+        self.job_limit = job_limit
+        self.memory_limit = memory_limit
+
+        self.cache = ModelCache(self.cache_limit)
 
     @classmethod
     def from_environ(cls):
+        memory_limit = environ.get("ONNX_WEB_MEMORY_LIMIT", None)
+        if memory_limit is not None:
+            memory_limit = int(memory_limit)
+
         return cls(
             bundle_path=environ.get(
                 "ONNX_WEB_BUNDLE_PATH", path.join("..", "gui", "out")
@@ -57,8 +70,10 @@ class ServerContext:
             block_platforms=environ.get("ONNX_WEB_BLOCK_PLATFORMS", "").split(","),
             default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None),
             image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"),
-            cache_limit=int(environ.get("ONNX_WEB_CACHE_MODELS", 2)),
+            cache_limit=int(environ.get("ONNX_WEB_CACHE_MODELS", DEFAULT_CACHE_LIMIT)),
             show_progress=get_boolean(environ, "ONNX_WEB_SHOW_PROGRESS", True),
             optimizations=environ.get("ONNX_WEB_OPTIMIZATIONS", "").split(","),
             extra_models=environ.get("ONNX_WEB_EXTRA_MODELS", "").split(","),
+            job_limit=int(environ.get("ONNX_WEB_JOB_LIMIT", DEFAULT_JOB_LIMIT)),
+            memory_limit=memory_limit,
         )
diff --git a/api/onnx_web/server/load.py b/api/onnx_web/server/load.py
index c8183bb4..6a0ce060 100644
--- a/api/onnx_web/server/load.py
+++ b/api/onnx_web/server/load.py
@@ -268,13 +268,19 @@ def load_platforms(context: ServerContext) -> None:
     ):
         if potential == "cuda":
             for i in range(torch.cuda.device_count()):
+                options = {
+                    "device_id": i,
+                }
+
+                if context.memory_limit is not None:
+                    options["arena_extend_strategy"] = "kSameAsRequested"
+                    options["gpu_mem_limit"] = context.memory_limit
+
                 available_platforms.append(
                     DeviceParams(
                         potential,
                         platform_providers[potential],
-                        {
-                            "device_id": i,
-                        },
+                        options,
                         context.optimizations,
                     )
                 )
diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py
index e03a4687..e222a6b4 100644
--- a/api/onnx_web/worker/pool.py
+++ b/api/onnx_web/worker/pool.py
@@ -41,13 +41,12 @@ class DevicePoolExecutor:
     def __init__(
         self,
         server: ServerContext,
         devices: List[DeviceParams],
-        max_jobs_per_worker: int = 10,
         max_pending_per_worker: int = 100,
         join_timeout: float = 1.0,
     ):
         self.server = server
         self.devices = devices
-        self.max_jobs_per_worker = max_jobs_per_worker
+        self.max_jobs_per_worker = server.job_limit
        self.max_pending_per_worker = max_pending_per_worker
        self.join_timeout = join_timeout
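
For reference, a minimal sketch (not part of this changeset) of how the new provider options are expected to behave downstream: ONNX Runtime's CUDA execution provider accepts `gpu_mem_limit` (a byte count) and `arena_extend_strategy` as per-provider options, matching the dictionary built in `load_platforms`. The session-creation call, model path, and 4 GiB figure below are illustrative assumptions, not code from this diff.

```python
import onnxruntime as ort

# Provider options shaped like those built in load_platforms when
# ONNX_WEB_MEMORY_LIMIT is set. gpu_mem_limit is in bytes;
# kSameAsRequested keeps the memory arena from growing past what was requested.
cuda_options = {
    "device_id": 0,
    "arena_extend_strategy": "kSameAsRequested",
    "gpu_mem_limit": 4 * 1024 * 1024 * 1024,  # e.g. ONNX_WEB_MEMORY_LIMIT=4294967296
}

session = ort.InferenceSession(
    "model.onnx",  # hypothetical model path
    providers=[("CUDAExecutionProvider", cuda_options)],
)
```

Operators would opt into this behavior through the new environment variables, e.g. `ONNX_WEB_MEMORY_LIMIT=4294967296` for the GPU arena cap and `ONNX_WEB_JOB_LIMIT=10` for the per-worker job limit; leaving `ONNX_WEB_MEMORY_LIMIT` unset keeps the previous unlimited behavior.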