
fix(api): make cache global within each worker process (#227)

This commit is contained in:
Sean Sube 2023-03-11 13:30:11 -06:00
parent 01d3519aa3
commit 575cb8831b
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed file with 24 additions and 13 deletions


@@ -3,26 +3,31 @@ from typing import Any, List, Tuple
 logger = getLogger(__name__)
+cache: List[Tuple[str, Any, Any]] = []
 class ModelCache:
-    cache: List[Tuple[str, Any, Any]]
+    # cache: List[Tuple[str, Any, Any]]
     limit: int
     def __init__(self, limit: int) -> None:
-        self.cache = []
         self.limit = limit
         logger.debug("creating model cache with limit of %s models", limit)
     def drop(self, tag: str, key: Any) -> int:
+        global cache
         logger.debug("dropping item from cache: %s %s", tag, key)
-        removed = [model for model in self.cache if model[0] == tag and model[1] == key]
+        removed = [model for model in cache if model[0] == tag and model[1] == key]
         for item in removed:
-            self.cache.remove(item)
+            cache.remove(item)
         return len(removed)
     def get(self, tag: str, key: Any) -> Any:
-        for t, k, v in self.cache:
+        global cache
+        for t, k, v in cache:
             if tag == t and key == k:
                 logger.debug("found cached model: %s %s", tag, key)
                 return v
@@ -31,36 +36,42 @@ class ModelCache:
         return None
     def set(self, tag: str, key: Any, value: Any) -> None:
+        global cache
         if self.limit == 0:
             logger.debug("cache limit set to 0, not caching model: %s", tag)
             return
-        for i in range(len(self.cache)):
-            t, k, v = self.cache[i]
+        for i in range(len(cache)):
+            t, k, v = cache[i]
             if tag == t and key != k:
                 logger.debug("updating model cache: %s %s", tag, key)
-                self.cache[i] = (tag, key, value)
+                cache[i] = (tag, key, value)
                 return
         logger.debug("adding new model to cache: %s %s", tag, key)
-        self.cache.append((tag, key, value))
+        cache.append((tag, key, value))
         self.prune()
     def prune(self):
-        total = len(self.cache)
+        global cache
+        total = len(cache)
         overage = total - self.limit
         if overage > 0:
-            removed = self.cache[:overage]
+            removed = cache[:overage]
             logger.info(
                 "removing %s of %s models from cache, %s",
                 overage,
                 total,
                 [m[0] for m in removed],
             )
-            self.cache[:] = self.cache[-self.limit :]
+            cache[:] = cache[-self.limit :]
         else:
             logger.debug("model cache below limit, %s of %s", total, self.limit)
     @property
     def size(self):
-        return len(self.cache)
+        global cache
+        return len(cache)