diff --git a/api/onnx_web/server/model_cache.py b/api/onnx_web/server/model_cache.py
index b7336fbf..c9f9df4e 100644
--- a/api/onnx_web/server/model_cache.py
+++ b/api/onnx_web/server/model_cache.py
@@ -3,26 +3,31 @@ from typing import Any, List, Tuple
 
 logger = getLogger(__name__)
 
+cache: List[Tuple[str, Any, Any]] = []
+
 
 class ModelCache:
-    cache: List[Tuple[str, Any, Any]]
+    # cache: List[Tuple[str, Any, Any]]
     limit: int
 
     def __init__(self, limit: int) -> None:
-        self.cache = []
         self.limit = limit
         logger.debug("creating model cache with limit of %s models", limit)
 
     def drop(self, tag: str, key: Any) -> int:
+        global cache
+
         logger.debug("dropping item from cache: %s %s", tag, key)
-        removed = [model for model in self.cache if model[0] == tag and model[1] == key]
+        removed = [model for model in cache if model[0] == tag and model[1] == key]
         for item in removed:
-            self.cache.remove(item)
+            cache.remove(item)
         return len(removed)
 
     def get(self, tag: str, key: Any) -> Any:
-        for t, k, v in self.cache:
+        global cache
+
+        for t, k, v in cache:
             if tag == t and key == k:
                 logger.debug("found cached model: %s %s", tag, key)
                 return v
 
@@ -31,36 +36,42 @@ class ModelCache:
         return None
 
     def set(self, tag: str, key: Any, value: Any) -> None:
+        global cache
+
         if self.limit == 0:
             logger.debug("cache limit set to 0, not caching model: %s", tag)
             return
 
-        for i in range(len(self.cache)):
-            t, k, v = self.cache[i]
+        for i in range(len(cache)):
+            t, k, v = cache[i]
             if tag == t and key != k:
                 logger.debug("updating model cache: %s %s", tag, key)
-                self.cache[i] = (tag, key, value)
+                cache[i] = (tag, key, value)
                 return
 
         logger.debug("adding new model to cache: %s %s", tag, key)
-        self.cache.append((tag, key, value))
+        cache.append((tag, key, value))
         self.prune()
 
     def prune(self):
-        total = len(self.cache)
+        global cache
+
+        total = len(cache)
         overage = total - self.limit
         if overage > 0:
-            removed = self.cache[:overage]
+            removed = cache[:overage]
             logger.info(
                 "removing %s of %s models from cache, %s",
                 overage,
                 total,
                 [m[0] for m in removed],
             )
-            self.cache[:] = self.cache[-self.limit :]
+            cache[:] = cache[-self.limit :]
         else:
             logger.debug("model cache below limit, %s of %s", total, self.limit)
 
     @property
     def size(self):
-        return len(self.cache)
+        global cache
+
+        return len(cache)