1
0
Fork 0

more release tests

This commit is contained in:
Sean Sube 2023-02-19 11:20:59 -06:00
parent f561dfae83
commit 47643867be
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
6 changed files with 79 additions and 27 deletions

BIN
api/scripts/test-refs/txt2img-knollingcase-512-muffin.png (Stored with Git LFS) Normal file

Binary file not shown.

BIN
api/scripts/test-refs/txt2img-openjourney-512-muffin.png (Stored with Git LFS) Normal file

Binary file not shown.

BIN
api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-deis.png (Stored with Git LFS) Normal file

Binary file not shown.

BIN
api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-dpm.png (Stored with Git LFS) Normal file

Binary file not shown.

BIN
api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-heun.png (Stored with Git LFS) Normal file

Binary file not shown.

View File

@ -6,6 +6,7 @@ from logging.config import dictConfig
from os import environ, path from os import environ, path
from time import sleep from time import sleep
from typing import Optional from typing import Optional
from collections import Counter
import cv2 import cv2
import numpy as np import numpy as np
@ -13,23 +14,57 @@ import requests
from PIL import Image from PIL import Image
from yaml import safe_load from yaml import safe_load
class TestCase:
def __init__(
self,
name: str,
query: str,
max_attempts: int = 20,
mse_threshold: float = 0.0001,
) -> None:
self.name = name
self.query = query
self.max_attempts = max_attempts
self.mse_threshold = mse_threshold
TEST_DATA = [ TEST_DATA = [
( TestCase(
"txt2img-sd-v1-5-256-muffin", "txt2img-sd-v1-5-256-muffin",
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&width=256&height=256", "txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&width=256&height=256",
), ),
( TestCase(
"txt2img-sd-v1-5-512-muffin", "txt2img-sd-v1-5-512-muffin",
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim", "txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim",
), ),
( TestCase(
"txt2img-sd-v1-5-512-muffin-deis",
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=deis",
),
TestCase(
"txt2img-sd-v1-5-512-muffin-dpm",
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=dpm-multi",
),
TestCase(
"txt2img-sd-v1-5-512-muffin-heun",
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=heun",
),
TestCase(
"txt2img-sd-v2-1-512-muffin", "txt2img-sd-v2-1-512-muffin",
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v2-1", "txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v2-1",
), ),
( TestCase(
"txt2img-sd-v2-1-768-muffin", "txt2img-sd-v2-1-768-muffin",
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v2-1&width=768&height=768", "txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v2-1&width=768&height=768",
), ),
TestCase(
"txt2img-openjourney-512-muffin",
"txt2img?prompt=mdjrny-v4+style+a+giant+muffin&seed=0&scheduler=ddim&model=diffusion-openjourney",
),
TestCase(
"txt2img-knollingcase-512-muffin",
"txt2img?prompt=knollingcase+display+case+with+a+giant+muffin&seed=0&scheduler=ddim&model=diffusion-knollingcase",
),
] ]
logging_path = environ.get("ONNX_WEB_LOGGING_PATH", "./logging.yaml") logging_path = environ.get("ONNX_WEB_LOGGING_PATH", "./logging.yaml")
@ -105,39 +140,36 @@ def find_mse(result: Image.Image, ref: Image.Image) -> float:
def run_test( def run_test(
root: str, root: str,
name: str, test: TestCase,
params: str,
ref: Image.Image, ref: Image.Image,
max_attempts: int = 20,
mse_threshold: float = 0.0001,
) -> bool: ) -> bool:
""" """
Generate an image, wait for it to be ready, and calculate the MSE from the reference. Generate an image, wait for it to be ready, and calculate the MSE from the reference.
""" """
logger.info("running test: %s", params) logger.info("running test: %s", test.query)
key = generate_image(root, params) key = generate_image(root, test.query)
if key is None: if key is None:
raise ValueError("could not generate") raise ValueError("could not generate")
attempts = 0 attempts = 0
while attempts < max_attempts and not check_ready(root, key): while attempts < test.max_attempts and not check_ready(root, key):
logger.debug("waiting for image to be ready") logger.debug("waiting for image to be ready")
sleep(6) sleep(6)
if attempts == max_attempts: if attempts == test.max_attempts:
raise ValueError("image was not ready in time") raise ValueError("image was not ready in time")
result = download_image(root, key) result = download_image(root, key)
result.save(test_path(path.join("test-results", f"{name}.png"))) result.save(test_path(path.join("test-results", f"{test.name}.png")))
mse = find_mse(result, ref) mse = find_mse(result, ref)
if mse < mse_threshold: if mse < test.mse_threshold:
logger.debug("MSE within threshold: %.4f < %.4f", mse, mse_threshold) logger.debug("MSE within threshold: %.4f < %.4f", mse, test.mse_threshold)
return True return True
else: else:
logger.warning("MSE above threshold: %.4f > %.4f", mse, mse_threshold) logger.warning("MSE above threshold: %.4f > %.4f", mse, test.mse_threshold)
return False return False
@ -145,23 +177,28 @@ def main():
root = test_root() root = test_root()
logger.info("running release tests against API: %s", root) logger.info("running release tests against API: %s", root)
failures = 0 results = Counter({
for name, query in TEST_DATA: True: 0,
False: 0,
})
for test in TEST_DATA:
try: try:
ref_name = test_path(path.join("test-refs", f"{name}.png")) ref_name = test_path(path.join("test-refs", f"{test.name}.png"))
ref = Image.open(ref_name) if path.exists(ref_name) else None ref = Image.open(ref_name) if path.exists(ref_name) else None
if run_test(root, name, query, ref): if run_test(root, test, ref):
logger.info("test passed: %s", name) logger.info("test passed: %s", test.name)
results[True] += 1
else: else:
logger.warning("test failed: %s", name) logger.warning("test failed: %s", test.name)
failures += 1 results[False] += 1
except Exception as e: except Exception as e:
traceback.print_exception(type(e), e, e.__traceback__) traceback.print_exception(type(e), e, e.__traceback__)
logger.error("error running test for %s: %s", name, e) logger.error("error running test for %s: %s", test.name, e)
failures += 1 results[False] += 1
if failures > 0: logger.info("%s of %s tests passed", results[True], results[True] + results[False])
logger.error("%s tests had errors", failures) if results[False] > 0:
logger.error("%s tests had errors", results[False])
sys.exit(1) sys.exit(1)
if __name__ == "__main__": if __name__ == "__main__":