fix(api): add name filter and MSE multiplier args to release tests
This commit is contained in:
parent
d9f251c88b
commit
5378619ef2
|
@ -1,4 +1,5 @@
|
|||
import sys
|
||||
from argparse import ArgumentParser
|
||||
from io import BytesIO
|
||||
from logging import getLogger
|
||||
from logging.config import dictConfig
|
||||
|
@ -28,13 +29,6 @@ FAST_TEST = 20
|
|||
SLOW_TEST = 50
|
||||
|
||||
|
||||
def test_root() -> str:
|
||||
if len(sys.argv) > 1:
|
||||
return sys.argv[1]
|
||||
else:
|
||||
return "http://127.0.0.1:5000"
|
||||
|
||||
|
||||
def test_path(relpath: str) -> str:
|
||||
return path.join(path.dirname(__file__), relpath)
|
||||
|
||||
|
@ -321,7 +315,18 @@ class TestError(Exception):
|
|||
return super().__str__()
|
||||
|
||||
|
||||
def generate_images(root: str, test: TestCase) -> Optional[str]:
|
||||
def parse_args(args: List[str]):
|
||||
parser = ArgumentParser(
|
||||
prog="onnx-web release tests",
|
||||
description="regression tests for onnx-web",
|
||||
)
|
||||
parser.add_argument("--host", default="http://127.0.0.1:5000")
|
||||
parser.add_argument("-n", "--name")
|
||||
parser.add_argument("-m", "--mse", default=1.0, type=float)
|
||||
return parser.parse_args(args)
|
||||
|
||||
|
||||
def generate_images(host: str, test: TestCase) -> Optional[str]:
|
||||
files = {}
|
||||
if test.source is not None:
|
||||
if isinstance(test.source, list):
|
||||
|
@ -354,7 +359,7 @@ def generate_images(root: str, test: TestCase) -> Optional[str]:
|
|||
files["mask"] = mask_bytes
|
||||
|
||||
logger.debug("generating image: %s", test.query)
|
||||
resp = requests.post(f"{root}/api/{test.query}", files=files)
|
||||
resp = requests.post(f"{host}/api/{test.query}", files=files)
|
||||
if resp.status_code == 200:
|
||||
json = resp.json()
|
||||
return json.get("outputs")
|
||||
|
@ -363,8 +368,8 @@ def generate_images(root: str, test: TestCase) -> Optional[str]:
|
|||
raise TestError("error generating image")
|
||||
|
||||
|
||||
def check_ready(root: str, key: str) -> bool:
|
||||
resp = requests.get(f"{root}/api/ready?output={key}")
|
||||
def check_ready(host: str, key: str) -> bool:
|
||||
resp = requests.get(f"{host}/api/ready?output={key}")
|
||||
if resp.status_code == 200:
|
||||
json = resp.json()
|
||||
ready = json.get("ready", False)
|
||||
|
@ -379,10 +384,10 @@ def check_ready(root: str, key: str) -> bool:
|
|||
raise TestError("error getting image status")
|
||||
|
||||
|
||||
def download_images(root: str, keys: List[str]) -> List[Image.Image]:
|
||||
def download_images(host: str, keys: List[str]) -> List[Image.Image]:
|
||||
images = []
|
||||
for key in keys:
|
||||
resp = requests.get(f"{root}/output/{key}")
|
||||
resp = requests.get(f"{host}/output/{key}")
|
||||
if resp.status_code == 200:
|
||||
logger.debug("downloading image: %s", key)
|
||||
images.append(Image.open(BytesIO(resp.content)))
|
||||
|
@ -413,20 +418,21 @@ def find_mse(result: Image.Image, ref: Image.Image) -> float:
|
|||
|
||||
|
||||
def run_test(
|
||||
root: str,
|
||||
host: str,
|
||||
test: TestCase,
|
||||
mse_mult: float = 1.0,
|
||||
) -> bool:
|
||||
"""
|
||||
Generate an image, wait for it to be ready, and calculate the MSE from the reference.
|
||||
"""
|
||||
|
||||
keys = generate_images(root, test)
|
||||
keys = generate_images(host, test)
|
||||
if keys is None:
|
||||
raise ValueError("could not generate image")
|
||||
|
||||
attempts = 0
|
||||
while attempts < test.max_attempts:
|
||||
if check_ready(root, keys[0]):
|
||||
if check_ready(host, keys[0]):
|
||||
logger.debug("image is ready: %s", keys)
|
||||
break
|
||||
else:
|
||||
|
@ -437,7 +443,7 @@ def run_test(
|
|||
if attempts == test.max_attempts:
|
||||
raise ValueError("image was not ready in time")
|
||||
|
||||
results = download_images(root, keys)
|
||||
results = download_images(host, keys)
|
||||
if results is None:
|
||||
raise ValueError("could not download image")
|
||||
|
||||
|
@ -451,7 +457,7 @@ def run_test(
|
|||
|
||||
mse = find_mse(result, ref)
|
||||
|
||||
if mse < test.mse_threshold:
|
||||
if mse < (test.mse_threshold * mse_mult):
|
||||
logger.info("MSE within threshold: %.5f < %.5f", mse, test.mse_threshold)
|
||||
else:
|
||||
logger.warning("MSE above threshold: %.5f > %.5f", mse, test.mse_threshold)
|
||||
|
@ -461,24 +467,29 @@ def run_test(
|
|||
|
||||
|
||||
def main():
|
||||
root = test_root()
|
||||
logger.info("running release tests against API: %s", root)
|
||||
args = parse_args(sys.argv[1:])
|
||||
logger.info("running release tests against API: %s", args.host)
|
||||
|
||||
if args.name is None:
|
||||
tests = TEST_DATA
|
||||
else:
|
||||
tests = [test for test in TEST_DATA if args.name in test.name]
|
||||
|
||||
# make sure tests have unique names
|
||||
test_names = [test.name for test in TEST_DATA]
|
||||
test_names = [test.name for test in tests]
|
||||
if len(test_names) != len(set(test_names)):
|
||||
logger.error("tests must have unique names: %s", test_names)
|
||||
sys.exit(1)
|
||||
|
||||
passed = []
|
||||
failed = []
|
||||
for test in TEST_DATA:
|
||||
for test in tests:
|
||||
test_passed = False
|
||||
|
||||
for _i in range(3):
|
||||
try:
|
||||
logger.info("starting test: %s", test.name)
|
||||
if run_test(root, test):
|
||||
if run_test(args.host, test, mse_mult=args.mse):
|
||||
logger.info("test passed: %s", test.name)
|
||||
test_passed = True
|
||||
break
|
||||
|
@ -492,7 +503,7 @@ def main():
|
|||
else:
|
||||
failed.append(test.name)
|
||||
|
||||
logger.info("%s of %s tests passed", len(passed), len(TEST_DATA))
|
||||
logger.info("%s of %s tests passed", len(passed), len(tests))
|
||||
failed = list(set(failed))
|
||||
if len(failed) > 0:
|
||||
logger.error("%s tests had errors: %s", len(failed), failed)
|
||||
|
|
Loading…
Reference in New Issue