diff --git a/api/scripts/test-refs/txt2img-knollingcase-512-muffin.png b/api/scripts/test-refs/txt2img-knollingcase-512-muffin.png new file mode 100644 index 00000000..196ae4e3 --- /dev/null +++ b/api/scripts/test-refs/txt2img-knollingcase-512-muffin.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:680920d9275dddd73cc22de361465b20619223a19c19dcb0db970455eaa30f7f +size 386421 diff --git a/api/scripts/test-refs/txt2img-openjourney-512-muffin.png b/api/scripts/test-refs/txt2img-openjourney-512-muffin.png new file mode 100644 index 00000000..552481c9 --- /dev/null +++ b/api/scripts/test-refs/txt2img-openjourney-512-muffin.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:21001d9064aceaa1b2fdba49457b4583c6e5530b51debff03ac6e2f24622385b +size 433455 diff --git a/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-deis.png b/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-deis.png new file mode 100644 index 00000000..c3b23071 --- /dev/null +++ b/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-deis.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a543d18a572d24c983f30fbae535409cc743932fdb20abcdd765aa04436193d8 +size 432524 diff --git a/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-dpm.png b/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-dpm.png new file mode 100644 index 00000000..cf819b63 --- /dev/null +++ b/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-dpm.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c500a76798a89799d8037ee5c71bf68ca8bb2caa02a519ff7c42e0e56108e78a +size 526240 diff --git a/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-heun.png b/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-heun.png new file mode 100644 index 00000000..d3c2b16b --- /dev/null +++ b/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-heun.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b4f97d2e321fd091272452832350f4cbfd13bc4ba4a8aeeafcfaa396e6bad3f9 +size 497895 diff --git a/api/scripts/test-release.py b/api/scripts/test-release.py index 1ea674b4..25a15e67 100644 --- a/api/scripts/test-release.py +++ b/api/scripts/test-release.py @@ -6,6 +6,7 @@ from logging.config import dictConfig from os import environ, path from time import sleep from typing import Optional +from collections import Counter import cv2 import numpy as np @@ -13,23 +14,57 @@ import requests from PIL import Image from yaml import safe_load +class TestCase: + def __init__( + self, + name: str, + query: str, + max_attempts: int = 20, + mse_threshold: float = 0.0001, + ) -> None: + self.name = name + self.query = query + self.max_attempts = max_attempts + self.mse_threshold = mse_threshold + + TEST_DATA = [ - ( + TestCase( "txt2img-sd-v1-5-256-muffin", "txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&width=256&height=256", ), - ( + TestCase( "txt2img-sd-v1-5-512-muffin", "txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim", ), - ( + TestCase( + "txt2img-sd-v1-5-512-muffin-deis", + "txt2img?prompt=a+giant+muffin&seed=0&scheduler=deis", + ), + TestCase( + "txt2img-sd-v1-5-512-muffin-dpm", + "txt2img?prompt=a+giant+muffin&seed=0&scheduler=dpm-multi", + ), + TestCase( + "txt2img-sd-v1-5-512-muffin-heun", + "txt2img?prompt=a+giant+muffin&seed=0&scheduler=heun", + ), + TestCase( "txt2img-sd-v2-1-512-muffin", "txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v2-1", ), - ( + TestCase( "txt2img-sd-v2-1-768-muffin", "txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v2-1&width=768&height=768", ), + TestCase( + "txt2img-openjourney-512-muffin", + "txt2img?prompt=mdjrny-v4+style+a+giant+muffin&seed=0&scheduler=ddim&model=diffusion-openjourney", + ), + TestCase( + "txt2img-knollingcase-512-muffin", + "txt2img?prompt=knollingcase+display+case+with+a+giant+muffin&seed=0&scheduler=ddim&model=diffusion-knollingcase", + ), ] logging_path = environ.get("ONNX_WEB_LOGGING_PATH", "./logging.yaml") @@ -105,39 +140,36 @@ def find_mse(result: Image.Image, ref: Image.Image) -> float: def run_test( root: str, - name: str, - params: str, + test: TestCase, ref: Image.Image, - max_attempts: int = 20, - mse_threshold: float = 0.0001, ) -> bool: """ Generate an image, wait for it to be ready, and calculate the MSE from the reference. """ - logger.info("running test: %s", params) + logger.info("running test: %s", test.query) - key = generate_image(root, params) + key = generate_image(root, test.query) if key is None: raise ValueError("could not generate") attempts = 0 - while attempts < max_attempts and not check_ready(root, key): + while attempts < test.max_attempts and not check_ready(root, key): logger.debug("waiting for image to be ready") sleep(6) - if attempts == max_attempts: + if attempts == test.max_attempts: raise ValueError("image was not ready in time") result = download_image(root, key) - result.save(test_path(path.join("test-results", f"{name}.png"))) + result.save(test_path(path.join("test-results", f"{test.name}.png"))) mse = find_mse(result, ref) - if mse < mse_threshold: - logger.debug("MSE within threshold: %.4f < %.4f", mse, mse_threshold) + if mse < test.mse_threshold: + logger.debug("MSE within threshold: %.4f < %.4f", mse, test.mse_threshold) return True else: - logger.warning("MSE above threshold: %.4f > %.4f", mse, mse_threshold) + logger.warning("MSE above threshold: %.4f > %.4f", mse, test.mse_threshold) return False @@ -145,23 +177,28 @@ def main(): root = test_root() logger.info("running release tests against API: %s", root) - failures = 0 - for name, query in TEST_DATA: + results = Counter({ + True: 0, + False: 0, + }) + for test in TEST_DATA: try: - ref_name = test_path(path.join("test-refs", f"{name}.png")) + ref_name = test_path(path.join("test-refs", f"{test.name}.png")) ref = Image.open(ref_name) if path.exists(ref_name) else None - if run_test(root, name, query, ref): - logger.info("test passed: %s", name) + if run_test(root, test, ref): + logger.info("test passed: %s", test.name) + results[True] += 1 else: - logger.warning("test failed: %s", name) - failures += 1 + logger.warning("test failed: %s", test.name) + results[False] += 1 except Exception as e: traceback.print_exception(type(e), e, e.__traceback__) - logger.error("error running test for %s: %s", name, e) - failures += 1 + logger.error("error running test for %s: %s", test.name, e) + results[False] += 1 - if failures > 0: - logger.error("%s tests had errors", failures) + logger.info("%s of %s tests passed", results[True], results[True] + results[False]) + if results[False] > 0: + logger.error("%s tests had errors", results[False]) sys.exit(1) if __name__ == "__main__":