1
0
Fork 0

fix(scripts): update release tests with support for batches

This commit is contained in:
Sean Sube 2023-02-25 18:01:05 -06:00
parent cb8e9e7080
commit 6809d2da82
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
19 changed files with 57 additions and 42 deletions

View File

@ -25,6 +25,9 @@ except Exception as err:
logger = getLogger(__name__)
FAST_TEST = 20
SLOW_TEST = 50
def test_root() -> str:
if len(sys.argv) > 1:
@ -42,7 +45,7 @@ class TestCase:
self,
name: str,
query: str,
max_attempts: int = 20,
max_attempts: int = FAST_TEST,
mse_threshold: float = 0.001,
source: Union[Image.Image, List[Image.Image]] = None,
mask: Image.Image = None,
@ -95,23 +98,23 @@ TEST_DATA = [
TestCase(
"img2img-sd-v1-5-512-pumpkin",
"img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim",
source="txt2img-sd-v1-5-512-muffin",
source="txt2img-sd-v1-5-512-muffin-0",
),
TestCase(
"img2img-sd-v1-5-256-pumpkin",
"img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim",
source="txt2img-sd-v1-5-256-muffin",
source="txt2img-sd-v1-5-256-muffin-0",
),
TestCase(
"inpaint-v1-512-white",
"inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting",
source="txt2img-sd-v1-5-512-muffin",
source="txt2img-sd-v1-5-512-muffin-0",
mask="mask-white",
),
TestCase(
"inpaint-v1-512-black",
"inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting",
source="txt2img-sd-v1-5-512-muffin",
source="txt2img-sd-v1-5-512-muffin-0",
mask="mask-black",
),
TestCase(
@ -120,8 +123,9 @@ TEST_DATA = [
"inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting&noise=fill-mask"
"&top=256&bottom=256&left=256&right=256"
),
source="txt2img-sd-v1-5-512-muffin",
source="txt2img-sd-v1-5-512-muffin-0",
mask="mask-black",
max_attempts=SLOW_TEST,
mse_threshold=0.025,
),
TestCase(
@ -130,8 +134,9 @@ TEST_DATA = [
"inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting&noise=fill-mask"
"&top=512&bottom=512&left=0&right=0"
),
source="txt2img-sd-v1-5-512-muffin",
source="txt2img-sd-v1-5-512-muffin-0",
mask="mask-black",
max_attempts=SLOW_TEST,
mse_threshold=0.010,
),
TestCase(
@ -140,27 +145,28 @@ TEST_DATA = [
"inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting&noise=fill-mask"
"&top=0&bottom=0&left=512&right=512"
),
source="txt2img-sd-v1-5-512-muffin",
source="txt2img-sd-v1-5-512-muffin-0",
mask="mask-black",
max_attempts=SLOW_TEST,
mse_threshold=0.010,
),
TestCase(
"upscale-resrgan-x2-1024-muffin",
"upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x2-plus&scale=2&outscale=2",
source="txt2img-sd-v1-5-512-muffin",
source="txt2img-sd-v1-5-512-muffin-0",
),
TestCase(
"upscale-resrgan-x4-2048-muffin",
"upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x4-plus&scale=4&outscale=4",
source="txt2img-sd-v1-5-512-muffin",
source="txt2img-sd-v1-5-512-muffin-0",
),
TestCase(
"blend-512-muffin-black",
"upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x2-plus&scale=2&outscale=2",
mask="mask-black",
source=[
"txt2img-sd-v1-5-512-muffin",
"txt2img-sd-v2-1-512-muffin",
"txt2img-sd-v1-5-512-muffin-0",
"txt2img-sd-v2-1-512-muffin-0",
],
),
TestCase(
@ -168,14 +174,14 @@ TEST_DATA = [
"upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x2-plus&scale=2&outscale=2",
mask="mask-white",
source=[
"txt2img-sd-v2-1-512-muffin",
"txt2img-sd-v1-5-512-muffin",
"txt2img-sd-v2-1-512-muffin-0",
"txt2img-sd-v1-5-512-muffin-0",
],
),
]
def generate_image(root: str, test: TestCase) -> Optional[str]:
def generate_images(root: str, test: TestCase) -> Optional[str]:
files = {}
if test.source is not None:
if isinstance(test.source, list):
@ -211,9 +217,9 @@ def generate_image(root: str, test: TestCase) -> Optional[str]:
resp = requests.post(f"{root}/api/{test.query}", files=files)
if resp.status_code == 200:
json = resp.json()
return json.get("output")
return json.get("outputs")
else:
logger.warning("request failed: %s", resp.status_code)
logger.warning("request failed: %s: %s", resp.status_code, resp.text)
return None
@ -227,14 +233,17 @@ def check_ready(root: str, key: str) -> bool:
return False
def download_image(root: str, key: str) -> Image.Image:
resp = requests.get(f"{root}/output/{key}")
if resp.status_code == 200:
logger.debug("downloading image: %s", key)
return Image.open(BytesIO(resp.content))
else:
logger.warning("request failed: %s", resp.status_code)
return None
def download_images(root: str, keys: List[str]) -> List[Image.Image]:
images = []
for key in keys:
resp = requests.get(f"{root}/output/{key}")
if resp.status_code == 200:
logger.debug("downloading image: %s", key)
images.append(Image.open(BytesIO(resp.content)))
else:
logger.warning("request failed: %s", resp.status_code)
return images
def find_mse(result: Image.Image, ref: Image.Image) -> float:
@ -259,20 +268,19 @@ def find_mse(result: Image.Image, ref: Image.Image) -> float:
def run_test(
root: str,
test: TestCase,
ref: Image.Image,
) -> bool:
"""
Generate an image, wait for it to be ready, and calculate the MSE from the reference.
"""
key = generate_image(root, test)
if key is None:
keys = generate_images(root, test)
if keys is None:
raise ValueError("could not generate")
attempts = 0
while attempts < test.max_attempts:
if check_ready(root, key):
logger.debug("image is ready: %s", key)
if check_ready(root, keys[0]):
logger.debug("image is ready: %s", keys)
break
else:
logger.debug("waiting for image to be ready")
@ -282,16 +290,25 @@ def run_test(
if attempts == test.max_attempts:
raise ValueError("image was not ready in time")
result = download_image(root, key)
result.save(test_path(path.join("test-results", f"{test.name}.png")))
mse = find_mse(result, ref)
results = download_images(root, keys)
if mse < test.mse_threshold:
logger.info("MSE within threshold: %.4f < %.4f", mse, test.mse_threshold)
return True
else:
logger.warning("MSE above threshold: %.4f > %.4f", mse, test.mse_threshold)
return False
passed = True
for i in range(len(results)):
result = results[i]
result.save(test_path(path.join("test-results", f"{test.name}-{i}.png")))
ref_name = test_path(path.join("test-refs", f"{test.name}-{i}.png"))
ref = Image.open(ref_name) if path.exists(ref_name) else None
mse = find_mse(result, ref)
if mse < test.mse_threshold:
logger.info("MSE within threshold: %.4f < %.4f", mse, test.mse_threshold)
else:
logger.warning("MSE above threshold: %.4f > %.4f", mse, test.mse_threshold)
passed = False
return passed
def main():
@ -303,9 +320,7 @@ def main():
for test in TEST_DATA:
try:
logger.info("starting test: %s", test.name)
ref_name = test_path(path.join("test-refs", f"{test.name}.png"))
ref = Image.open(ref_name) if path.exists(ref_name) else None
if run_test(root, test, ref):
if run_test(root, test):
logger.info("test passed: %s", test.name)
passed.append(test.name)
else: