1
0
Fork 0

fix(scripts): update release tests with support for batches

This commit is contained in:
Sean Sube 2023-02-25 18:01:05 -06:00
parent cb8e9e7080
commit 6809d2da82
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
19 changed files with 57 additions and 42 deletions

View File

@ -25,6 +25,9 @@ except Exception as err:
logger = getLogger(__name__) logger = getLogger(__name__)
FAST_TEST = 20
SLOW_TEST = 50
def test_root() -> str: def test_root() -> str:
if len(sys.argv) > 1: if len(sys.argv) > 1:
@ -42,7 +45,7 @@ class TestCase:
self, self,
name: str, name: str,
query: str, query: str,
max_attempts: int = 20, max_attempts: int = FAST_TEST,
mse_threshold: float = 0.001, mse_threshold: float = 0.001,
source: Union[Image.Image, List[Image.Image]] = None, source: Union[Image.Image, List[Image.Image]] = None,
mask: Image.Image = None, mask: Image.Image = None,
@ -95,23 +98,23 @@ TEST_DATA = [
TestCase( TestCase(
"img2img-sd-v1-5-512-pumpkin", "img2img-sd-v1-5-512-pumpkin",
"img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim", "img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim",
source="txt2img-sd-v1-5-512-muffin", source="txt2img-sd-v1-5-512-muffin-0",
), ),
TestCase( TestCase(
"img2img-sd-v1-5-256-pumpkin", "img2img-sd-v1-5-256-pumpkin",
"img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim", "img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim",
source="txt2img-sd-v1-5-256-muffin", source="txt2img-sd-v1-5-256-muffin-0",
), ),
TestCase( TestCase(
"inpaint-v1-512-white", "inpaint-v1-512-white",
"inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting", "inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting",
source="txt2img-sd-v1-5-512-muffin", source="txt2img-sd-v1-5-512-muffin-0",
mask="mask-white", mask="mask-white",
), ),
TestCase( TestCase(
"inpaint-v1-512-black", "inpaint-v1-512-black",
"inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting", "inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting",
source="txt2img-sd-v1-5-512-muffin", source="txt2img-sd-v1-5-512-muffin-0",
mask="mask-black", mask="mask-black",
), ),
TestCase( TestCase(
@ -120,8 +123,9 @@ TEST_DATA = [
"inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting&noise=fill-mask" "inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting&noise=fill-mask"
"&top=256&bottom=256&left=256&right=256" "&top=256&bottom=256&left=256&right=256"
), ),
source="txt2img-sd-v1-5-512-muffin", source="txt2img-sd-v1-5-512-muffin-0",
mask="mask-black", mask="mask-black",
max_attempts=SLOW_TEST,
mse_threshold=0.025, mse_threshold=0.025,
), ),
TestCase( TestCase(
@ -130,8 +134,9 @@ TEST_DATA = [
"inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting&noise=fill-mask" "inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting&noise=fill-mask"
"&top=512&bottom=512&left=0&right=0" "&top=512&bottom=512&left=0&right=0"
), ),
source="txt2img-sd-v1-5-512-muffin", source="txt2img-sd-v1-5-512-muffin-0",
mask="mask-black", mask="mask-black",
max_attempts=SLOW_TEST,
mse_threshold=0.010, mse_threshold=0.010,
), ),
TestCase( TestCase(
@ -140,27 +145,28 @@ TEST_DATA = [
"inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting&noise=fill-mask" "inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting&noise=fill-mask"
"&top=0&bottom=0&left=512&right=512" "&top=0&bottom=0&left=512&right=512"
), ),
source="txt2img-sd-v1-5-512-muffin", source="txt2img-sd-v1-5-512-muffin-0",
mask="mask-black", mask="mask-black",
max_attempts=SLOW_TEST,
mse_threshold=0.010, mse_threshold=0.010,
), ),
TestCase( TestCase(
"upscale-resrgan-x2-1024-muffin", "upscale-resrgan-x2-1024-muffin",
"upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x2-plus&scale=2&outscale=2", "upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x2-plus&scale=2&outscale=2",
source="txt2img-sd-v1-5-512-muffin", source="txt2img-sd-v1-5-512-muffin-0",
), ),
TestCase( TestCase(
"upscale-resrgan-x4-2048-muffin", "upscale-resrgan-x4-2048-muffin",
"upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x4-plus&scale=4&outscale=4", "upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x4-plus&scale=4&outscale=4",
source="txt2img-sd-v1-5-512-muffin", source="txt2img-sd-v1-5-512-muffin-0",
), ),
TestCase( TestCase(
"blend-512-muffin-black", "blend-512-muffin-black",
"upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x2-plus&scale=2&outscale=2", "upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x2-plus&scale=2&outscale=2",
mask="mask-black", mask="mask-black",
source=[ source=[
"txt2img-sd-v1-5-512-muffin", "txt2img-sd-v1-5-512-muffin-0",
"txt2img-sd-v2-1-512-muffin", "txt2img-sd-v2-1-512-muffin-0",
], ],
), ),
TestCase( TestCase(
@ -168,14 +174,14 @@ TEST_DATA = [
"upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x2-plus&scale=2&outscale=2", "upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x2-plus&scale=2&outscale=2",
mask="mask-white", mask="mask-white",
source=[ source=[
"txt2img-sd-v2-1-512-muffin", "txt2img-sd-v2-1-512-muffin-0",
"txt2img-sd-v1-5-512-muffin", "txt2img-sd-v1-5-512-muffin-0",
], ],
), ),
] ]
def generate_image(root: str, test: TestCase) -> Optional[str]: def generate_images(root: str, test: TestCase) -> Optional[str]:
files = {} files = {}
if test.source is not None: if test.source is not None:
if isinstance(test.source, list): if isinstance(test.source, list):
@ -211,9 +217,9 @@ def generate_image(root: str, test: TestCase) -> Optional[str]:
resp = requests.post(f"{root}/api/{test.query}", files=files) resp = requests.post(f"{root}/api/{test.query}", files=files)
if resp.status_code == 200: if resp.status_code == 200:
json = resp.json() json = resp.json()
return json.get("output") return json.get("outputs")
else: else:
logger.warning("request failed: %s", resp.status_code) logger.warning("request failed: %s: %s", resp.status_code, resp.text)
return None return None
@ -227,14 +233,17 @@ def check_ready(root: str, key: str) -> bool:
return False return False
def download_image(root: str, key: str) -> Image.Image: def download_images(root: str, keys: List[str]) -> List[Image.Image]:
images = []
for key in keys:
resp = requests.get(f"{root}/output/{key}") resp = requests.get(f"{root}/output/{key}")
if resp.status_code == 200: if resp.status_code == 200:
logger.debug("downloading image: %s", key) logger.debug("downloading image: %s", key)
return Image.open(BytesIO(resp.content)) images.append(Image.open(BytesIO(resp.content)))
else: else:
logger.warning("request failed: %s", resp.status_code) logger.warning("request failed: %s", resp.status_code)
return None
return images
def find_mse(result: Image.Image, ref: Image.Image) -> float: def find_mse(result: Image.Image, ref: Image.Image) -> float:
@ -259,20 +268,19 @@ def find_mse(result: Image.Image, ref: Image.Image) -> float:
def run_test( def run_test(
root: str, root: str,
test: TestCase, test: TestCase,
ref: Image.Image,
) -> bool: ) -> bool:
""" """
Generate an image, wait for it to be ready, and calculate the MSE from the reference. Generate an image, wait for it to be ready, and calculate the MSE from the reference.
""" """
key = generate_image(root, test) keys = generate_images(root, test)
if key is None: if keys is None:
raise ValueError("could not generate") raise ValueError("could not generate")
attempts = 0 attempts = 0
while attempts < test.max_attempts: while attempts < test.max_attempts:
if check_ready(root, key): if check_ready(root, keys[0]):
logger.debug("image is ready: %s", key) logger.debug("image is ready: %s", keys)
break break
else: else:
logger.debug("waiting for image to be ready") logger.debug("waiting for image to be ready")
@ -282,16 +290,25 @@ def run_test(
if attempts == test.max_attempts: if attempts == test.max_attempts:
raise ValueError("image was not ready in time") raise ValueError("image was not ready in time")
result = download_image(root, key) results = download_images(root, keys)
result.save(test_path(path.join("test-results", f"{test.name}.png")))
passed = True
for i in range(len(results)):
result = results[i]
result.save(test_path(path.join("test-results", f"{test.name}-{i}.png")))
ref_name = test_path(path.join("test-refs", f"{test.name}-{i}.png"))
ref = Image.open(ref_name) if path.exists(ref_name) else None
mse = find_mse(result, ref) mse = find_mse(result, ref)
if mse < test.mse_threshold: if mse < test.mse_threshold:
logger.info("MSE within threshold: %.4f < %.4f", mse, test.mse_threshold) logger.info("MSE within threshold: %.4f < %.4f", mse, test.mse_threshold)
return True
else: else:
logger.warning("MSE above threshold: %.4f > %.4f", mse, test.mse_threshold) logger.warning("MSE above threshold: %.4f > %.4f", mse, test.mse_threshold)
return False passed = False
return passed
def main(): def main():
@ -303,9 +320,7 @@ def main():
for test in TEST_DATA: for test in TEST_DATA:
try: try:
logger.info("starting test: %s", test.name) logger.info("starting test: %s", test.name)
ref_name = test_path(path.join("test-refs", f"{test.name}.png")) if run_test(root, test):
ref = Image.open(ref_name) if path.exists(ref_name) else None
if run_test(root, test, ref):
logger.info("test passed: %s", test.name) logger.info("test passed: %s", test.name)
passed.append(test.name) passed.append(test.name)
else: else: