fix(tests): expand worker tests
This commit is contained in:
parent
535b685a57
commit
5a517704ea
|
@ -134,25 +134,40 @@ def patch_cache_path(server: ServerContext, url: str, **kwargs) -> str:
|
|||
|
||||
def apply_patch_basicsr(server: ServerContext):
|
||||
logger.debug("patching BasicSR module")
|
||||
try:
|
||||
import basicsr.utils.download_util
|
||||
|
||||
basicsr.utils.download_util.download_file_from_google_drive = patch_not_impl
|
||||
basicsr.utils.download_util.load_file_from_url = partial(patch_cache_path, server)
|
||||
except ImportError:
|
||||
logger.info("unable to import basicsr utils for patching")
|
||||
except AttributeError:
|
||||
logger.warning("unable to patch basicsr utils")
|
||||
|
||||
|
||||
def apply_patch_codeformer(server: ServerContext):
|
||||
logger.debug("patching CodeFormer module")
|
||||
try:
|
||||
import codeformer.facelib.utils.misc
|
||||
|
||||
codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl
|
||||
codeformer.facelib.utils.misc.load_file_from_url = partial(patch_cache_path, server)
|
||||
except ImportError:
|
||||
logger.info("unable to import codeformer utils for patching")
|
||||
except AttributeError:
|
||||
logger.warning("unable to patch codeformer utils")
|
||||
|
||||
|
||||
def apply_patch_facexlib(server: ServerContext):
|
||||
logger.debug("patching Facexlib module")
|
||||
try:
|
||||
import facexlib.utils
|
||||
|
||||
facexlib.utils.load_file_from_url = partial(patch_cache_path, server)
|
||||
except ImportError:
|
||||
logger.info("unable to import facexlib for patching")
|
||||
except AttributeError:
|
||||
logger.warning("unable to patch facexlib utils")
|
||||
|
||||
|
||||
def apply_patches(server: ServerContext):
|
||||
|
|
|
@ -50,7 +50,7 @@ def worker_main(
|
|||
getpid(),
|
||||
worker.get_active(),
|
||||
)
|
||||
exit(EXIT_REPLACED)
|
||||
return exit(EXIT_REPLACED)
|
||||
|
||||
# wait briefly for the next job
|
||||
job = worker.pending.get(timeout=worker.timeout)
|
||||
|
@ -73,15 +73,15 @@ def worker_main(
|
|||
except KeyboardInterrupt:
|
||||
logger.debug("worker got keyboard interrupt")
|
||||
worker.fail()
|
||||
exit(EXIT_INTERRUPT)
|
||||
return exit(EXIT_INTERRUPT)
|
||||
except RetryException:
|
||||
logger.exception("retry error in worker, exiting")
|
||||
worker.fail()
|
||||
exit(EXIT_ERROR)
|
||||
return exit(EXIT_ERROR)
|
||||
except ValueError:
|
||||
logger.exception("value error in worker, exiting")
|
||||
worker.fail()
|
||||
exit(EXIT_ERROR)
|
||||
return exit(EXIT_ERROR)
|
||||
except Exception as e:
|
||||
e_str = str(e)
|
||||
# restart the worker on memory errors
|
||||
|
@ -89,7 +89,7 @@ def worker_main(
|
|||
if e_mem in e_str:
|
||||
logger.error("detected out-of-memory error, exiting: %s", e)
|
||||
worker.fail()
|
||||
exit(EXIT_MEMORY)
|
||||
return exit(EXIT_MEMORY)
|
||||
|
||||
# carry on for other errors
|
||||
logger.exception(
|
||||
|
|
|
@ -77,6 +77,9 @@ class TestWorkerPool(unittest.TestCase):
|
|||
self.assertEqual(self.pool.get_next_device(needs_device=device2), 1)
|
||||
|
||||
def test_done_running(self):
|
||||
"""
|
||||
TODO: flaky
|
||||
"""
|
||||
device = DeviceParams("cpu", "CPUProvider")
|
||||
server = ServerContext()
|
||||
|
||||
|
@ -104,6 +107,9 @@ class TestWorkerPool(unittest.TestCase):
|
|||
lock.set()
|
||||
|
||||
def test_done_finished(self):
|
||||
"""
|
||||
TODO: flaky
|
||||
"""
|
||||
device = DeviceParams("cpu", "CPUProvider")
|
||||
server = ServerContext()
|
||||
|
||||
|
|
|
@ -1,11 +1,23 @@
|
|||
import unittest
|
||||
from multiprocessing import Queue, Value
|
||||
from os import getpid
|
||||
from onnx_web.errors import RetryException
|
||||
|
||||
from onnx_web.server.context import ServerContext
|
||||
from onnx_web.worker.command import JobCommand
|
||||
from onnx_web.worker.context import WorkerContext
|
||||
from onnx_web.worker.worker import EXIT_INTERRUPT, worker_main
|
||||
from onnx_web.worker.worker import EXIT_ERROR, EXIT_INTERRUPT, EXIT_MEMORY, EXIT_REPLACED, MEMORY_ERRORS, worker_main
|
||||
from tests.helpers import test_device
|
||||
|
||||
def main_memory(_worker):
|
||||
raise Exception(MEMORY_ERRORS[0])
|
||||
|
||||
def main_retry(_worker):
|
||||
raise RetryException()
|
||||
|
||||
def main_interrupt(_worker):
|
||||
raise KeyboardInterrupt()
|
||||
|
||||
|
||||
class WorkerMainTests(unittest.TestCase):
|
||||
def test_pending_exception_empty(self):
|
||||
|
@ -15,28 +27,102 @@ class WorkerMainTests(unittest.TestCase):
|
|||
status = None
|
||||
|
||||
def exit(exit_status):
|
||||
nonlocal status
|
||||
status = exit_status
|
||||
|
||||
job = JobCommand("test", "test", main_interrupt, [], {})
|
||||
cancel = Value("L", False)
|
||||
logs = Queue()
|
||||
pending = Queue()
|
||||
progress = Queue()
|
||||
pid = Value("L", getpid())
|
||||
idle = Value("L", False)
|
||||
|
||||
pending.put(job)
|
||||
worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit)
|
||||
|
||||
self.assertEqual(status, EXIT_INTERRUPT)
|
||||
pass
|
||||
|
||||
def test_pending_exception_retry(self):
|
||||
status = None
|
||||
|
||||
def exit(exit_status):
|
||||
nonlocal status
|
||||
status = exit_status
|
||||
|
||||
job = JobCommand("test", "test", main_retry, [], {})
|
||||
cancel = Value("L", False)
|
||||
logs = Queue()
|
||||
pending = Queue()
|
||||
progress = Queue()
|
||||
pid = Value("L", getpid())
|
||||
idle = Value("L", False)
|
||||
|
||||
pending.put(job)
|
||||
worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit)
|
||||
|
||||
self.assertEqual(status, EXIT_ERROR)
|
||||
pass
|
||||
|
||||
def test_pending_exception_value(self):
|
||||
status = None
|
||||
|
||||
def exit(exit_status):
|
||||
nonlocal status
|
||||
status = exit_status
|
||||
|
||||
cancel = Value("L", False)
|
||||
logs = Queue()
|
||||
pending = Queue()
|
||||
progress = Queue()
|
||||
pid = Value("L", False)
|
||||
pid = Value("L", getpid())
|
||||
idle = Value("L", False)
|
||||
|
||||
pending.close()
|
||||
# worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit)
|
||||
worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit)
|
||||
|
||||
self.assertEqual(status, EXIT_INTERRUPT)
|
||||
|
||||
def test_pending_exception_retry(self):
|
||||
pass
|
||||
|
||||
def test_pending_exception_value(self):
|
||||
pass
|
||||
self.assertEqual(status, EXIT_ERROR)
|
||||
|
||||
def test_pending_exception_other_memory(self):
|
||||
pass
|
||||
status = None
|
||||
|
||||
def exit(exit_status):
|
||||
nonlocal status
|
||||
status = exit_status
|
||||
|
||||
job = JobCommand("test", "test", main_memory, [], {})
|
||||
cancel = Value("L", False)
|
||||
logs = Queue()
|
||||
pending = Queue()
|
||||
progress = Queue()
|
||||
pid = Value("L", getpid())
|
||||
idle = Value("L", False)
|
||||
|
||||
pending.put(job)
|
||||
worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit)
|
||||
|
||||
self.assertEqual(status, EXIT_MEMORY)
|
||||
|
||||
|
||||
def test_pending_exception_other_unknown(self):
|
||||
pass
|
||||
|
||||
def test_pending_replaced(self):
|
||||
status = None
|
||||
|
||||
def exit(exit_status):
|
||||
nonlocal status
|
||||
status = exit_status
|
||||
|
||||
cancel = Value("L", False)
|
||||
logs = Queue()
|
||||
pending = Queue()
|
||||
progress = Queue()
|
||||
pid = Value("L", 0)
|
||||
idle = Value("L", False)
|
||||
|
||||
worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit)
|
||||
|
||||
self.assertEqual(status, EXIT_REPLACED)
|
||||
|
||||
|
|
Loading…
Reference in New Issue