diff --git a/api/onnx_web/server/hacks.py b/api/onnx_web/server/hacks.py index b59bb73a..69053b75 100644 --- a/api/onnx_web/server/hacks.py +++ b/api/onnx_web/server/hacks.py @@ -134,25 +134,40 @@ def patch_cache_path(server: ServerContext, url: str, **kwargs) -> str: def apply_patch_basicsr(server: ServerContext): logger.debug("patching BasicSR module") - import basicsr.utils.download_util + try: + import basicsr.utils.download_util - basicsr.utils.download_util.download_file_from_google_drive = patch_not_impl - basicsr.utils.download_util.load_file_from_url = partial(patch_cache_path, server) + basicsr.utils.download_util.download_file_from_google_drive = patch_not_impl + basicsr.utils.download_util.load_file_from_url = partial(patch_cache_path, server) + except ImportError: + logger.info("unable to import basicsr utils for patching") + except AttributeError: + logger.warning("unable to patch basicsr utils") def apply_patch_codeformer(server: ServerContext): logger.debug("patching CodeFormer module") - import codeformer.facelib.utils.misc + try: + import codeformer.facelib.utils.misc - codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl - codeformer.facelib.utils.misc.load_file_from_url = partial(patch_cache_path, server) + codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl + codeformer.facelib.utils.misc.load_file_from_url = partial(patch_cache_path, server) + except ImportError: + logger.info("unable to import codeformer utils for patching") + except AttributeError: + logger.warning("unable to patch codeformer utils") def apply_patch_facexlib(server: ServerContext): logger.debug("patching Facexlib module") - import facexlib.utils + try: + import facexlib.utils - facexlib.utils.load_file_from_url = partial(patch_cache_path, server) + facexlib.utils.load_file_from_url = partial(patch_cache_path, server) + except ImportError: + logger.info("unable to import facexlib for patching") + except AttributeError: + logger.warning("unable to patch facexlib utils") def apply_patches(server: ServerContext): diff --git a/api/onnx_web/worker/worker.py b/api/onnx_web/worker/worker.py index a55ba4a2..55ebcaac 100644 --- a/api/onnx_web/worker/worker.py +++ b/api/onnx_web/worker/worker.py @@ -50,7 +50,7 @@ def worker_main( getpid(), worker.get_active(), ) - exit(EXIT_REPLACED) + return exit(EXIT_REPLACED) # wait briefly for the next job job = worker.pending.get(timeout=worker.timeout) @@ -73,15 +73,15 @@ def worker_main( except KeyboardInterrupt: logger.debug("worker got keyboard interrupt") worker.fail() - exit(EXIT_INTERRUPT) + return exit(EXIT_INTERRUPT) except RetryException: logger.exception("retry error in worker, exiting") worker.fail() - exit(EXIT_ERROR) + return exit(EXIT_ERROR) except ValueError: logger.exception("value error in worker, exiting") worker.fail() - exit(EXIT_ERROR) + return exit(EXIT_ERROR) except Exception as e: e_str = str(e) # restart the worker on memory errors @@ -89,7 +89,7 @@ def worker_main( if e_mem in e_str: logger.error("detected out-of-memory error, exiting: %s", e) worker.fail() - exit(EXIT_MEMORY) + return exit(EXIT_MEMORY) # carry on for other errors logger.exception( diff --git a/api/tests/worker/test_pool.py b/api/tests/worker/test_pool.py index 7ea73451..d0a36982 100644 --- a/api/tests/worker/test_pool.py +++ b/api/tests/worker/test_pool.py @@ -77,6 +77,9 @@ class TestWorkerPool(unittest.TestCase): self.assertEqual(self.pool.get_next_device(needs_device=device2), 1) def test_done_running(self): + """ + TODO: flaky + """ device = DeviceParams("cpu", "CPUProvider") server = ServerContext() @@ -104,6 +107,9 @@ class TestWorkerPool(unittest.TestCase): lock.set() def test_done_finished(self): + """ + TODO: flaky + """ device = DeviceParams("cpu", "CPUProvider") server = ServerContext() diff --git a/api/tests/worker/test_worker.py b/api/tests/worker/test_worker.py index 06c0822d..f0c3e89c 100644 --- a/api/tests/worker/test_worker.py +++ b/api/tests/worker/test_worker.py @@ -1,11 +1,23 @@ import unittest from multiprocessing import Queue, Value +from os import getpid +from onnx_web.errors import RetryException from onnx_web.server.context import ServerContext +from onnx_web.worker.command import JobCommand from onnx_web.worker.context import WorkerContext -from onnx_web.worker.worker import EXIT_INTERRUPT, worker_main +from onnx_web.worker.worker import EXIT_ERROR, EXIT_INTERRUPT, EXIT_MEMORY, EXIT_REPLACED, MEMORY_ERRORS, worker_main from tests.helpers import test_device +def main_memory(_worker): + raise Exception(MEMORY_ERRORS[0]) + +def main_retry(_worker): + raise RetryException() + +def main_interrupt(_worker): + raise KeyboardInterrupt() + class WorkerMainTests(unittest.TestCase): def test_pending_exception_empty(self): @@ -15,28 +27,102 @@ class WorkerMainTests(unittest.TestCase): status = None def exit(exit_status): + nonlocal status + status = exit_status + + job = JobCommand("test", "test", main_interrupt, [], {}) + cancel = Value("L", False) + logs = Queue() + pending = Queue() + progress = Queue() + pid = Value("L", getpid()) + idle = Value("L", False) + + pending.put(job) + worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit) + + self.assertEqual(status, EXIT_INTERRUPT) + pass + + def test_pending_exception_retry(self): + status = None + + def exit(exit_status): + nonlocal status + status = exit_status + + job = JobCommand("test", "test", main_retry, [], {}) + cancel = Value("L", False) + logs = Queue() + pending = Queue() + progress = Queue() + pid = Value("L", getpid()) + idle = Value("L", False) + + pending.put(job) + worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit) + + self.assertEqual(status, EXIT_ERROR) + pass + + def test_pending_exception_value(self): + status = None + + def exit(exit_status): + nonlocal status status = exit_status cancel = Value("L", False) logs = Queue() pending = Queue() progress = Queue() - pid = Value("L", False) + pid = Value("L", getpid()) idle = Value("L", False) pending.close() - # worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit) + worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit) - self.assertEqual(status, EXIT_INTERRUPT) - - def test_pending_exception_retry(self): - pass - - def test_pending_exception_value(self): - pass + self.assertEqual(status, EXIT_ERROR) def test_pending_exception_other_memory(self): - pass + status = None + + def exit(exit_status): + nonlocal status + status = exit_status + + job = JobCommand("test", "test", main_memory, [], {}) + cancel = Value("L", False) + logs = Queue() + pending = Queue() + progress = Queue() + pid = Value("L", getpid()) + idle = Value("L", False) + + pending.put(job) + worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit) + + self.assertEqual(status, EXIT_MEMORY) + def test_pending_exception_other_unknown(self): pass + + def test_pending_replaced(self): + status = None + + def exit(exit_status): + nonlocal status + status = exit_status + + cancel = Value("L", False) + logs = Queue() + pending = Queue() + progress = Queue() + pid = Value("L", 0) + idle = Value("L", False) + + worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit) + + self.assertEqual(status, EXIT_REPLACED) +