From 9a7770fe4811f0f5eea67c7d4ed16313d7938b04 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 19 Feb 2023 10:25:27 -0600 Subject: [PATCH] feat(test): add release regression testing script --- .../test-refs/txt2img-sd-v1-5-256-muffin.png | 3 + .../test-refs/txt2img-sd-v1-5-512-muffin.png | 3 + .../test-refs/txt2img-sd-v2-1-512-muffin.png | 3 + .../test-refs/txt2img-sd-v2-1-768-muffin.png | 3 + api/scripts/test-release.py | 168 ++++++++++++++++++ api/scripts/test-results/.gitignore | 1 + api/scripts/test-results/.gitkeep | 0 7 files changed, 181 insertions(+) create mode 100644 api/scripts/test-refs/txt2img-sd-v1-5-256-muffin.png create mode 100644 api/scripts/test-refs/txt2img-sd-v1-5-512-muffin.png create mode 100644 api/scripts/test-refs/txt2img-sd-v2-1-512-muffin.png create mode 100644 api/scripts/test-refs/txt2img-sd-v2-1-768-muffin.png create mode 100644 api/scripts/test-release.py create mode 100644 api/scripts/test-results/.gitignore create mode 100644 api/scripts/test-results/.gitkeep diff --git a/api/scripts/test-refs/txt2img-sd-v1-5-256-muffin.png b/api/scripts/test-refs/txt2img-sd-v1-5-256-muffin.png new file mode 100644 index 00000000..f1fa9ead --- /dev/null +++ b/api/scripts/test-refs/txt2img-sd-v1-5-256-muffin.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eeb6a621638eb5a0409048df4f1cf2ab0ab0c7cad378b6fca70f981e21a4d000 +size 121046 diff --git a/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin.png b/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin.png new file mode 100644 index 00000000..9af7aed2 --- /dev/null +++ b/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d4106859b033cf57d0f2b08c36cff38beb5a67f6e99e2eddefad998223be4182 +size 498897 diff --git a/api/scripts/test-refs/txt2img-sd-v2-1-512-muffin.png b/api/scripts/test-refs/txt2img-sd-v2-1-512-muffin.png new file mode 100644 index 00000000..7b1a789a --- /dev/null +++ b/api/scripts/test-refs/txt2img-sd-v2-1-512-muffin.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5dee66e2a5b003c6e892e2f5e7b7ba95c471ee301cdad34e11199083544a46d8 +size 502780 diff --git a/api/scripts/test-refs/txt2img-sd-v2-1-768-muffin.png b/api/scripts/test-refs/txt2img-sd-v2-1-768-muffin.png new file mode 100644 index 00000000..7ffe41a9 --- /dev/null +++ b/api/scripts/test-refs/txt2img-sd-v2-1-768-muffin.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bd993473982b2a7fe24b2997d4a7e61e656bd18aba83006a57b2de3e5755b2f5 +size 1062706 diff --git a/api/scripts/test-release.py b/api/scripts/test-release.py new file mode 100644 index 00000000..1ea674b4 --- /dev/null +++ b/api/scripts/test-release.py @@ -0,0 +1,168 @@ +import sys +import traceback +from io import BytesIO +from logging import getLogger +from logging.config import dictConfig +from os import environ, path +from time import sleep +from typing import Optional + +import cv2 +import numpy as np +import requests +from PIL import Image +from yaml import safe_load + +TEST_DATA = [ + ( + "txt2img-sd-v1-5-256-muffin", + "txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&width=256&height=256", + ), + ( + "txt2img-sd-v1-5-512-muffin", + "txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim", + ), + ( + "txt2img-sd-v2-1-512-muffin", + "txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v2-1", + ), + ( + "txt2img-sd-v2-1-768-muffin", + "txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v2-1&width=768&height=768", + ), +] + +logging_path = environ.get("ONNX_WEB_LOGGING_PATH", "./logging.yaml") + +try: + if path.exists(logging_path): + with open(logging_path, "r") as f: + config_logging = safe_load(f) + dictConfig(config_logging) +except Exception as err: + print("error loading logging config: %s" % (err)) + +logger = getLogger(__name__) + + +def test_root() -> str: + if len(sys.argv) > 1: + return sys.argv[1] + else: + return "http://127.0.0.1:5000" + + +def test_path(relpath: str) -> str: + return path.join(path.dirname(__file__), relpath) + + +def generate_image(root: str, params: str) -> Optional[str]: + resp = requests.post(f"{root}/api/{params}") + if resp.status_code == 200: + json = resp.json() + return json.get("output") + else: + logger.warning("request failed: %s", resp.status_code) + return None + + +def check_ready(root: str, key: str) -> bool: + resp = requests.get(f"{root}/api/ready?output={key}") + if resp.status_code == 200: + json = resp.json() + return json.get("ready", False) + else: + logger.warning("request failed: %s", resp.status_code) + return False + + +def download_image(root: str, key: str) -> Image.Image: + resp = requests.get(f"{root}/output/{key}") + if resp.status_code == 200: + return Image.open(BytesIO(resp.content)) + else: + logger.warning("request failed: %s", resp.status_code) + return None + + +def find_mse(result: Image.Image, ref: Image.Image) -> float: + if result.mode != ref.mode: + logger.warning("image mode does not match: %s vs %s", result.mode, ref.mode) + return float("inf") + + if result.size != ref.size: + logger.warning("image size does not match: %s vs %s", result.size, ref.size) + return float("inf") + + nd_result = np.array(result) + nd_ref = np.array(ref) + + diff = cv2.subtract(nd_ref, nd_result) + diff = np.sum(diff**2) + + return diff / (float(ref.height * ref.width)) / 255.0 + + +def run_test( + root: str, + name: str, + params: str, + ref: Image.Image, + max_attempts: int = 20, + mse_threshold: float = 0.0001, +) -> bool: + """ + Generate an image, wait for it to be ready, and calculate the MSE from the reference. + """ + + logger.info("running test: %s", params) + + key = generate_image(root, params) + if key is None: + raise ValueError("could not generate") + + attempts = 0 + while attempts < max_attempts and not check_ready(root, key): + logger.debug("waiting for image to be ready") + sleep(6) + + if attempts == max_attempts: + raise ValueError("image was not ready in time") + + result = download_image(root, key) + result.save(test_path(path.join("test-results", f"{name}.png"))) + mse = find_mse(result, ref) + + if mse < mse_threshold: + logger.debug("MSE within threshold: %.4f < %.4f", mse, mse_threshold) + return True + else: + logger.warning("MSE above threshold: %.4f > %.4f", mse, mse_threshold) + return False + + +def main(): + root = test_root() + logger.info("running release tests against API: %s", root) + + failures = 0 + for name, query in TEST_DATA: + try: + ref_name = test_path(path.join("test-refs", f"{name}.png")) + ref = Image.open(ref_name) if path.exists(ref_name) else None + if run_test(root, name, query, ref): + logger.info("test passed: %s", name) + else: + logger.warning("test failed: %s", name) + failures += 1 + except Exception as e: + traceback.print_exception(type(e), e, e.__traceback__) + logger.error("error running test for %s: %s", name, e) + failures += 1 + + if failures > 0: + logger.error("%s tests had errors", failures) + sys.exit(1) + +if __name__ == "__main__": + main() diff --git a/api/scripts/test-results/.gitignore b/api/scripts/test-results/.gitignore new file mode 100644 index 00000000..aab52d90 --- /dev/null +++ b/api/scripts/test-results/.gitignore @@ -0,0 +1 @@ +*.png \ No newline at end of file diff --git a/api/scripts/test-results/.gitkeep b/api/scripts/test-results/.gitkeep new file mode 100644 index 00000000..e69de29b