1
0
Fork 0

fix(tests): expand worker tests

This commit is contained in:
Sean Sube 2023-11-18 17:20:45 -06:00
parent 535b685a57
commit 5a517704ea
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 131 additions and 24 deletions

View File

@ -134,25 +134,40 @@ def patch_cache_path(server: ServerContext, url: str, **kwargs) -> str:
def apply_patch_basicsr(server: ServerContext):
logger.debug("patching BasicSR module")
try:
import basicsr.utils.download_util
basicsr.utils.download_util.download_file_from_google_drive = patch_not_impl
basicsr.utils.download_util.load_file_from_url = partial(patch_cache_path, server)
except ImportError:
logger.info("unable to import basicsr utils for patching")
except AttributeError:
logger.warning("unable to patch basicsr utils")
def apply_patch_codeformer(server: ServerContext):
logger.debug("patching CodeFormer module")
try:
import codeformer.facelib.utils.misc
codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl
codeformer.facelib.utils.misc.load_file_from_url = partial(patch_cache_path, server)
except ImportError:
logger.info("unable to import codeformer utils for patching")
except AttributeError:
logger.warning("unable to patch codeformer utils")
def apply_patch_facexlib(server: ServerContext):
logger.debug("patching Facexlib module")
try:
import facexlib.utils
facexlib.utils.load_file_from_url = partial(patch_cache_path, server)
except ImportError:
logger.info("unable to import facexlib for patching")
except AttributeError:
logger.warning("unable to patch facexlib utils")
def apply_patches(server: ServerContext):

View File

@ -50,7 +50,7 @@ def worker_main(
getpid(),
worker.get_active(),
)
exit(EXIT_REPLACED)
return exit(EXIT_REPLACED)
# wait briefly for the next job
job = worker.pending.get(timeout=worker.timeout)
@ -73,15 +73,15 @@ def worker_main(
except KeyboardInterrupt:
logger.debug("worker got keyboard interrupt")
worker.fail()
exit(EXIT_INTERRUPT)
return exit(EXIT_INTERRUPT)
except RetryException:
logger.exception("retry error in worker, exiting")
worker.fail()
exit(EXIT_ERROR)
return exit(EXIT_ERROR)
except ValueError:
logger.exception("value error in worker, exiting")
worker.fail()
exit(EXIT_ERROR)
return exit(EXIT_ERROR)
except Exception as e:
e_str = str(e)
# restart the worker on memory errors
@ -89,7 +89,7 @@ def worker_main(
if e_mem in e_str:
logger.error("detected out-of-memory error, exiting: %s", e)
worker.fail()
exit(EXIT_MEMORY)
return exit(EXIT_MEMORY)
# carry on for other errors
logger.exception(

View File

@ -77,6 +77,9 @@ class TestWorkerPool(unittest.TestCase):
self.assertEqual(self.pool.get_next_device(needs_device=device2), 1)
def test_done_running(self):
"""
TODO: flaky
"""
device = DeviceParams("cpu", "CPUProvider")
server = ServerContext()
@ -104,6 +107,9 @@ class TestWorkerPool(unittest.TestCase):
lock.set()
def test_done_finished(self):
"""
TODO: flaky
"""
device = DeviceParams("cpu", "CPUProvider")
server = ServerContext()

View File

@ -1,11 +1,23 @@
import unittest
from multiprocessing import Queue, Value
from os import getpid
from onnx_web.errors import RetryException
from onnx_web.server.context import ServerContext
from onnx_web.worker.command import JobCommand
from onnx_web.worker.context import WorkerContext
from onnx_web.worker.worker import EXIT_INTERRUPT, worker_main
from onnx_web.worker.worker import EXIT_ERROR, EXIT_INTERRUPT, EXIT_MEMORY, EXIT_REPLACED, MEMORY_ERRORS, worker_main
from tests.helpers import test_device
def main_memory(_worker):
raise Exception(MEMORY_ERRORS[0])
def main_retry(_worker):
raise RetryException()
def main_interrupt(_worker):
raise KeyboardInterrupt()
class WorkerMainTests(unittest.TestCase):
def test_pending_exception_empty(self):
@ -15,28 +27,102 @@ class WorkerMainTests(unittest.TestCase):
status = None
def exit(exit_status):
nonlocal status
status = exit_status
job = JobCommand("test", "test", main_interrupt, [], {})
cancel = Value("L", False)
logs = Queue()
pending = Queue()
progress = Queue()
pid = Value("L", getpid())
idle = Value("L", False)
pending.put(job)
worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit)
self.assertEqual(status, EXIT_INTERRUPT)
pass
def test_pending_exception_retry(self):
status = None
def exit(exit_status):
nonlocal status
status = exit_status
job = JobCommand("test", "test", main_retry, [], {})
cancel = Value("L", False)
logs = Queue()
pending = Queue()
progress = Queue()
pid = Value("L", getpid())
idle = Value("L", False)
pending.put(job)
worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit)
self.assertEqual(status, EXIT_ERROR)
pass
def test_pending_exception_value(self):
status = None
def exit(exit_status):
nonlocal status
status = exit_status
cancel = Value("L", False)
logs = Queue()
pending = Queue()
progress = Queue()
pid = Value("L", False)
pid = Value("L", getpid())
idle = Value("L", False)
pending.close()
# worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit)
worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit)
self.assertEqual(status, EXIT_INTERRUPT)
def test_pending_exception_retry(self):
pass
def test_pending_exception_value(self):
pass
self.assertEqual(status, EXIT_ERROR)
def test_pending_exception_other_memory(self):
pass
status = None
def exit(exit_status):
nonlocal status
status = exit_status
job = JobCommand("test", "test", main_memory, [], {})
cancel = Value("L", False)
logs = Queue()
pending = Queue()
progress = Queue()
pid = Value("L", getpid())
idle = Value("L", False)
pending.put(job)
worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit)
self.assertEqual(status, EXIT_MEMORY)
def test_pending_exception_other_unknown(self):
pass
def test_pending_replaced(self):
status = None
def exit(exit_status):
nonlocal status
status = exit_status
cancel = Value("L", False)
logs = Queue()
pending = Queue()
progress = Queue()
pid = Value("L", 0)
idle = Value("L", False)
worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit)
self.assertEqual(status, EXIT_REPLACED)