diff --git a/api/onnx_web/server/context.py b/api/onnx_web/server/context.py
index 2e60bd41..e8b2687f 100644
--- a/api/onnx_web/server/context.py
+++ b/api/onnx_web/server/context.py
@@ -16,12 +16,11 @@ class ServerContext:
         output_path: str = ".",
         params_path: str = ".",
         cors_origin: str = "*",
-        num_workers: int = 1,
         any_platform: bool = True,
         block_platforms: Optional[List[str]] = None,
         default_platform: Optional[str] = None,
         image_format: str = "png",
-        cache: Optional[ModelCache] = None,
+        cache_limit: Optional[int] = 1,
         cache_path: Optional[str] = None,
         show_progress: bool = True,
         optimizations: Optional[List[str]] = None,
@@ -32,12 +31,12 @@ class ServerContext:
         self.output_path = output_path
         self.params_path = params_path
         self.cors_origin = cors_origin
-        self.num_workers = num_workers
         self.any_platform = any_platform
         self.block_platforms = block_platforms or []
         self.default_platform = default_platform
         self.image_format = image_format
-        self.cache = cache or ModelCache(num_workers)
+        self.cache = ModelCache(cache_limit)
+        self.cache_limit = cache_limit
         self.cache_path = cache_path or path.join(model_path, ".cache")
         self.show_progress = show_progress
         self.optimizations = optimizations or []
@@ -45,9 +44,6 @@ class ServerContext:
 
     @classmethod
     def from_environ(cls):
-        num_workers = int(environ.get("ONNX_WEB_NUM_WORKERS", 1))
-        cache_limit = int(environ.get("ONNX_WEB_CACHE_MODELS", num_workers + 2))
-
         return cls(
             bundle_path=environ.get(
                 "ONNX_WEB_BUNDLE_PATH", path.join("..", "gui", "out")
             ),
@@ -57,12 +53,11 @@ class ServerContext:
             params_path=environ.get("ONNX_WEB_PARAMS_PATH", "."),
             # others
             cors_origin=environ.get("ONNX_WEB_CORS_ORIGIN", "*").split(","),
-            num_workers=num_workers,
             any_platform=get_boolean(environ, "ONNX_WEB_ANY_PLATFORM", True),
             block_platforms=environ.get("ONNX_WEB_BLOCK_PLATFORMS", "").split(","),
             default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None),
             image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"),
-            cache=ModelCache(limit=cache_limit),
+            cache_limit=int(environ.get("ONNX_WEB_CACHE_MODELS", 2)),
             show_progress=get_boolean(environ, "ONNX_WEB_SHOW_PROGRESS", True),
            optimizations=environ.get("ONNX_WEB_OPTIMIZATIONS", "").split(","),
             extra_models=environ.get("ONNX_WEB_EXTRA_MODELS", "").split(","),
diff --git a/api/onnx_web/server/model_cache.py b/api/onnx_web/server/model_cache.py
index 7f1ef522..b7336fbf 100644
--- a/api/onnx_web/server/model_cache.py
+++ b/api/onnx_web/server/model_cache.py
@@ -11,6 +11,7 @@ class ModelCache:
     def __init__(self, limit: int) -> None:
         self.cache = []
         self.limit = limit
+        logger.debug("creating model cache with limit of %s models", limit)
 
     def drop(self, tag: str, key: Any) -> int:
         logger.debug("dropping item from cache: %s %s", tag, key)