diff --git a/api/Makefile b/api/Makefile index 142ddd9a..5ac36e7e 100644 --- a/api/Makefile +++ b/api/Makefile @@ -17,8 +17,8 @@ pip-dev: check-venv test: python -m coverage run -m unittest discover -s tests/ - python -m coverage html - python -m coverage xml + python -m coverage html -i + python -m coverage xml -i package: package-dist package-upload diff --git a/api/onnx_web/server/model_cache.py b/api/onnx_web/server/model_cache.py index 31213436..00717223 100644 --- a/api/onnx_web/server/model_cache.py +++ b/api/onnx_web/server/model_cache.py @@ -12,11 +12,15 @@ class ModelCache: self.cache = [] self.limit = limit - def drop(self, tag: str, key: Any) -> None: + def drop(self, tag: str, key: Any) -> int: logger.debug("dropping item from cache: %s", tag) - self.cache[:] = [ - model for model in self.cache if model[0] != tag and model[1] != key + removed = [ + model for model in self.cache if model[0] == tag and model[1] == key ] + for item in removed: + self.cache.remove(item) + + return len(removed) def get(self, tag: str, key: Any) -> Any: for t, k, v in self.cache: @@ -52,3 +56,7 @@ class ModelCache: self.cache[:] = self.cache[-self.limit :] else: logger.debug("model cache below limit, %s of %s", total, self.limit) + + @property + def size(self): + return len(self.cache) \ No newline at end of file diff --git a/api/tests/server/__init__.py b/api/tests/server/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/api/tests/server/test_model_cache.py b/api/tests/server/test_model_cache.py new file mode 100644 index 00000000..149b22e9 --- /dev/null +++ b/api/tests/server/test_model_cache.py @@ -0,0 +1,30 @@ +import unittest + +from onnx_web.server.model_cache import ModelCache + +class TestStringMethods(unittest.TestCase): + def test_drop_existing(self): + cache = ModelCache(10) + cache.set("foo", ("bar",), {}) + self.assertGreater(cache.size, 0) + self.assertEqual(cache.drop("foo", ("bar",)), 1) + + def test_drop_missing(self): + cache = ModelCache(10) + cache.set("foo", ("bar",), {}) + self.assertGreater(cache.size, 0) + self.assertEqual(cache.drop("foo", ("bin",)), 0) + + def test_get_existing(self): + cache = ModelCache(10) + value = {} + cache.set("foo", ("bar",), value) + self.assertGreater(cache.size, 0) + self.assertIs(cache.get("foo", ("bar",)), value) + + def test_get_missing(self): + cache = ModelCache(10) + value = {} + cache.set("foo", ("bar",), value) + self.assertGreater(cache.size, 0) + self.assertIs(cache.get("foo", ("bin",)), None)