more release tests
This commit is contained in:
parent
f561dfae83
commit
47643867be
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -6,6 +6,7 @@ from logging.config import dictConfig
|
||||||
from os import environ, path
|
from os import environ, path
|
||||||
from time import sleep
|
from time import sleep
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
from collections import Counter
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -13,23 +14,57 @@ import requests
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from yaml import safe_load
|
from yaml import safe_load
|
||||||
|
|
||||||
|
class TestCase:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
query: str,
|
||||||
|
max_attempts: int = 20,
|
||||||
|
mse_threshold: float = 0.0001,
|
||||||
|
) -> None:
|
||||||
|
self.name = name
|
||||||
|
self.query = query
|
||||||
|
self.max_attempts = max_attempts
|
||||||
|
self.mse_threshold = mse_threshold
|
||||||
|
|
||||||
|
|
||||||
TEST_DATA = [
|
TEST_DATA = [
|
||||||
(
|
TestCase(
|
||||||
"txt2img-sd-v1-5-256-muffin",
|
"txt2img-sd-v1-5-256-muffin",
|
||||||
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&width=256&height=256",
|
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&width=256&height=256",
|
||||||
),
|
),
|
||||||
(
|
TestCase(
|
||||||
"txt2img-sd-v1-5-512-muffin",
|
"txt2img-sd-v1-5-512-muffin",
|
||||||
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim",
|
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim",
|
||||||
),
|
),
|
||||||
(
|
TestCase(
|
||||||
|
"txt2img-sd-v1-5-512-muffin-deis",
|
||||||
|
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=deis",
|
||||||
|
),
|
||||||
|
TestCase(
|
||||||
|
"txt2img-sd-v1-5-512-muffin-dpm",
|
||||||
|
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=dpm-multi",
|
||||||
|
),
|
||||||
|
TestCase(
|
||||||
|
"txt2img-sd-v1-5-512-muffin-heun",
|
||||||
|
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=heun",
|
||||||
|
),
|
||||||
|
TestCase(
|
||||||
"txt2img-sd-v2-1-512-muffin",
|
"txt2img-sd-v2-1-512-muffin",
|
||||||
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v2-1",
|
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v2-1",
|
||||||
),
|
),
|
||||||
(
|
TestCase(
|
||||||
"txt2img-sd-v2-1-768-muffin",
|
"txt2img-sd-v2-1-768-muffin",
|
||||||
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v2-1&width=768&height=768",
|
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v2-1&width=768&height=768",
|
||||||
),
|
),
|
||||||
|
TestCase(
|
||||||
|
"txt2img-openjourney-512-muffin",
|
||||||
|
"txt2img?prompt=mdjrny-v4+style+a+giant+muffin&seed=0&scheduler=ddim&model=diffusion-openjourney",
|
||||||
|
),
|
||||||
|
TestCase(
|
||||||
|
"txt2img-knollingcase-512-muffin",
|
||||||
|
"txt2img?prompt=knollingcase+display+case+with+a+giant+muffin&seed=0&scheduler=ddim&model=diffusion-knollingcase",
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
logging_path = environ.get("ONNX_WEB_LOGGING_PATH", "./logging.yaml")
|
logging_path = environ.get("ONNX_WEB_LOGGING_PATH", "./logging.yaml")
|
||||||
|
@ -105,39 +140,36 @@ def find_mse(result: Image.Image, ref: Image.Image) -> float:
|
||||||
|
|
||||||
def run_test(
|
def run_test(
|
||||||
root: str,
|
root: str,
|
||||||
name: str,
|
test: TestCase,
|
||||||
params: str,
|
|
||||||
ref: Image.Image,
|
ref: Image.Image,
|
||||||
max_attempts: int = 20,
|
|
||||||
mse_threshold: float = 0.0001,
|
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Generate an image, wait for it to be ready, and calculate the MSE from the reference.
|
Generate an image, wait for it to be ready, and calculate the MSE from the reference.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
logger.info("running test: %s", params)
|
logger.info("running test: %s", test.query)
|
||||||
|
|
||||||
key = generate_image(root, params)
|
key = generate_image(root, test.query)
|
||||||
if key is None:
|
if key is None:
|
||||||
raise ValueError("could not generate")
|
raise ValueError("could not generate")
|
||||||
|
|
||||||
attempts = 0
|
attempts = 0
|
||||||
while attempts < max_attempts and not check_ready(root, key):
|
while attempts < test.max_attempts and not check_ready(root, key):
|
||||||
logger.debug("waiting for image to be ready")
|
logger.debug("waiting for image to be ready")
|
||||||
sleep(6)
|
sleep(6)
|
||||||
|
|
||||||
if attempts == max_attempts:
|
if attempts == test.max_attempts:
|
||||||
raise ValueError("image was not ready in time")
|
raise ValueError("image was not ready in time")
|
||||||
|
|
||||||
result = download_image(root, key)
|
result = download_image(root, key)
|
||||||
result.save(test_path(path.join("test-results", f"{name}.png")))
|
result.save(test_path(path.join("test-results", f"{test.name}.png")))
|
||||||
mse = find_mse(result, ref)
|
mse = find_mse(result, ref)
|
||||||
|
|
||||||
if mse < mse_threshold:
|
if mse < test.mse_threshold:
|
||||||
logger.debug("MSE within threshold: %.4f < %.4f", mse, mse_threshold)
|
logger.debug("MSE within threshold: %.4f < %.4f", mse, test.mse_threshold)
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
logger.warning("MSE above threshold: %.4f > %.4f", mse, mse_threshold)
|
logger.warning("MSE above threshold: %.4f > %.4f", mse, test.mse_threshold)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
@ -145,23 +177,28 @@ def main():
|
||||||
root = test_root()
|
root = test_root()
|
||||||
logger.info("running release tests against API: %s", root)
|
logger.info("running release tests against API: %s", root)
|
||||||
|
|
||||||
failures = 0
|
results = Counter({
|
||||||
for name, query in TEST_DATA:
|
True: 0,
|
||||||
|
False: 0,
|
||||||
|
})
|
||||||
|
for test in TEST_DATA:
|
||||||
try:
|
try:
|
||||||
ref_name = test_path(path.join("test-refs", f"{name}.png"))
|
ref_name = test_path(path.join("test-refs", f"{test.name}.png"))
|
||||||
ref = Image.open(ref_name) if path.exists(ref_name) else None
|
ref = Image.open(ref_name) if path.exists(ref_name) else None
|
||||||
if run_test(root, name, query, ref):
|
if run_test(root, test, ref):
|
||||||
logger.info("test passed: %s", name)
|
logger.info("test passed: %s", test.name)
|
||||||
|
results[True] += 1
|
||||||
else:
|
else:
|
||||||
logger.warning("test failed: %s", name)
|
logger.warning("test failed: %s", test.name)
|
||||||
failures += 1
|
results[False] += 1
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exception(type(e), e, e.__traceback__)
|
traceback.print_exception(type(e), e, e.__traceback__)
|
||||||
logger.error("error running test for %s: %s", name, e)
|
logger.error("error running test for %s: %s", test.name, e)
|
||||||
failures += 1
|
results[False] += 1
|
||||||
|
|
||||||
if failures > 0:
|
logger.info("%s of %s tests passed", results[True], results[True] + results[False])
|
||||||
logger.error("%s tests had errors", failures)
|
if results[False] > 0:
|
||||||
|
logger.error("%s tests had errors", results[False])
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
Loading…
Reference in New Issue