1
0
Fork 0

fix(api): add name filter and MSE multiplier args to release tests

This commit is contained in:
Sean Sube 2023-06-30 22:31:46 -05:00
parent d9f251c88b
commit 5378619ef2
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 35 additions and 24 deletions

View File

@ -1,4 +1,5 @@
import sys
from argparse import ArgumentParser
from io import BytesIO
from logging import getLogger
from logging.config import dictConfig
@ -28,13 +29,6 @@ FAST_TEST = 20
SLOW_TEST = 50
def test_root() -> str:
if len(sys.argv) > 1:
return sys.argv[1]
else:
return "http://127.0.0.1:5000"
def test_path(relpath: str) -> str:
return path.join(path.dirname(__file__), relpath)
@ -321,7 +315,18 @@ class TestError(Exception):
return super().__str__()
def generate_images(root: str, test: TestCase) -> Optional[str]:
def parse_args(args: List[str]):
parser = ArgumentParser(
prog="onnx-web release tests",
description="regression tests for onnx-web",
)
parser.add_argument("--host", default="http://127.0.0.1:5000")
parser.add_argument("-n", "--name")
parser.add_argument("-m", "--mse", default=1.0, type=float)
return parser.parse_args(args)
def generate_images(host: str, test: TestCase) -> Optional[str]:
files = {}
if test.source is not None:
if isinstance(test.source, list):
@ -354,7 +359,7 @@ def generate_images(root: str, test: TestCase) -> Optional[str]:
files["mask"] = mask_bytes
logger.debug("generating image: %s", test.query)
resp = requests.post(f"{root}/api/{test.query}", files=files)
resp = requests.post(f"{host}/api/{test.query}", files=files)
if resp.status_code == 200:
json = resp.json()
return json.get("outputs")
@ -363,8 +368,8 @@ def generate_images(root: str, test: TestCase) -> Optional[str]:
raise TestError("error generating image")
def check_ready(root: str, key: str) -> bool:
resp = requests.get(f"{root}/api/ready?output={key}")
def check_ready(host: str, key: str) -> bool:
resp = requests.get(f"{host}/api/ready?output={key}")
if resp.status_code == 200:
json = resp.json()
ready = json.get("ready", False)
@ -379,10 +384,10 @@ def check_ready(root: str, key: str) -> bool:
raise TestError("error getting image status")
def download_images(root: str, keys: List[str]) -> List[Image.Image]:
def download_images(host: str, keys: List[str]) -> List[Image.Image]:
images = []
for key in keys:
resp = requests.get(f"{root}/output/{key}")
resp = requests.get(f"{host}/output/{key}")
if resp.status_code == 200:
logger.debug("downloading image: %s", key)
images.append(Image.open(BytesIO(resp.content)))
@ -413,20 +418,21 @@ def find_mse(result: Image.Image, ref: Image.Image) -> float:
def run_test(
root: str,
host: str,
test: TestCase,
mse_mult: float = 1.0,
) -> bool:
"""
Generate an image, wait for it to be ready, and calculate the MSE from the reference.
"""
keys = generate_images(root, test)
keys = generate_images(host, test)
if keys is None:
raise ValueError("could not generate image")
attempts = 0
while attempts < test.max_attempts:
if check_ready(root, keys[0]):
if check_ready(host, keys[0]):
logger.debug("image is ready: %s", keys)
break
else:
@ -437,7 +443,7 @@ def run_test(
if attempts == test.max_attempts:
raise ValueError("image was not ready in time")
results = download_images(root, keys)
results = download_images(host, keys)
if results is None:
raise ValueError("could not download image")
@ -451,7 +457,7 @@ def run_test(
mse = find_mse(result, ref)
if mse < test.mse_threshold:
if mse < (test.mse_threshold * mse_mult):
logger.info("MSE within threshold: %.5f < %.5f", mse, test.mse_threshold)
else:
logger.warning("MSE above threshold: %.5f > %.5f", mse, test.mse_threshold)
@ -461,24 +467,29 @@ def run_test(
def main():
root = test_root()
logger.info("running release tests against API: %s", root)
args = parse_args(sys.argv[1:])
logger.info("running release tests against API: %s", args.host)
if args.name is None:
tests = TEST_DATA
else:
tests = [test for test in TEST_DATA if args.name in test.name]
# make sure tests have unique names
test_names = [test.name for test in TEST_DATA]
test_names = [test.name for test in tests]
if len(test_names) != len(set(test_names)):
logger.error("tests must have unique names: %s", test_names)
sys.exit(1)
passed = []
failed = []
for test in TEST_DATA:
for test in tests:
test_passed = False
for _i in range(3):
try:
logger.info("starting test: %s", test.name)
if run_test(root, test):
if run_test(args.host, test, mse_mult=args.mse):
logger.info("test passed: %s", test.name)
test_passed = True
break
@ -492,7 +503,7 @@ def main():
else:
failed.append(test.name)
logger.info("%s of %s tests passed", len(passed), len(TEST_DATA))
logger.info("%s of %s tests passed", len(passed), len(tests))
failed = list(set(failed))
if len(failed) > 0:
logger.error("%s tests had errors: %s", len(failed), failed)