diff --git a/api/scripts/test-refs/img2img-sd-v1-5-256-pumpkin.png b/api/scripts/test-refs/img2img-sd-v1-5-256-pumpkin-0.png similarity index 100% rename from api/scripts/test-refs/img2img-sd-v1-5-256-pumpkin.png rename to api/scripts/test-refs/img2img-sd-v1-5-256-pumpkin-0.png diff --git a/api/scripts/test-refs/img2img-sd-v1-5-512-pumpkin.png b/api/scripts/test-refs/img2img-sd-v1-5-512-pumpkin-0.png similarity index 100% rename from api/scripts/test-refs/img2img-sd-v1-5-512-pumpkin.png rename to api/scripts/test-refs/img2img-sd-v1-5-512-pumpkin-0.png diff --git a/api/scripts/test-refs/inpaint-v1-512-black.png b/api/scripts/test-refs/inpaint-v1-512-black-0.png similarity index 100% rename from api/scripts/test-refs/inpaint-v1-512-black.png rename to api/scripts/test-refs/inpaint-v1-512-black-0.png diff --git a/api/scripts/test-refs/inpaint-v1-512-white.png b/api/scripts/test-refs/inpaint-v1-512-white-0.png similarity index 100% rename from api/scripts/test-refs/inpaint-v1-512-white.png rename to api/scripts/test-refs/inpaint-v1-512-white-0.png diff --git a/api/scripts/test-refs/outpaint-even-256.png b/api/scripts/test-refs/outpaint-even-256-0.png similarity index 100% rename from api/scripts/test-refs/outpaint-even-256.png rename to api/scripts/test-refs/outpaint-even-256-0.png diff --git a/api/scripts/test-refs/outpaint-horizontal-512.png b/api/scripts/test-refs/outpaint-horizontal-512-0.png similarity index 100% rename from api/scripts/test-refs/outpaint-horizontal-512.png rename to api/scripts/test-refs/outpaint-horizontal-512-0.png diff --git a/api/scripts/test-refs/outpaint-vertical-512.png b/api/scripts/test-refs/outpaint-vertical-512-0.png similarity index 100% rename from api/scripts/test-refs/outpaint-vertical-512.png rename to api/scripts/test-refs/outpaint-vertical-512-0.png diff --git a/api/scripts/test-refs/txt2img-knollingcase-512-muffin.png b/api/scripts/test-refs/txt2img-knollingcase-512-muffin-0.png similarity index 100% rename from api/scripts/test-refs/txt2img-knollingcase-512-muffin.png rename to api/scripts/test-refs/txt2img-knollingcase-512-muffin-0.png diff --git a/api/scripts/test-refs/txt2img-openjourney-512-muffin.png b/api/scripts/test-refs/txt2img-openjourney-512-muffin-0.png similarity index 100% rename from api/scripts/test-refs/txt2img-openjourney-512-muffin.png rename to api/scripts/test-refs/txt2img-openjourney-512-muffin-0.png diff --git a/api/scripts/test-refs/txt2img-sd-v1-5-256-muffin.png b/api/scripts/test-refs/txt2img-sd-v1-5-256-muffin-0.png similarity index 100% rename from api/scripts/test-refs/txt2img-sd-v1-5-256-muffin.png rename to api/scripts/test-refs/txt2img-sd-v1-5-256-muffin-0.png diff --git a/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin.png b/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-0.png similarity index 100% rename from api/scripts/test-refs/txt2img-sd-v1-5-512-muffin.png rename to api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-0.png diff --git a/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-deis.png b/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-deis-0.png similarity index 100% rename from api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-deis.png rename to api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-deis-0.png diff --git a/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-dpm.png b/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-dpm-0.png similarity index 100% rename from api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-dpm.png rename to api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-dpm-0.png diff --git a/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-heun.png b/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-heun-0.png similarity index 100% rename from api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-heun.png rename to api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-heun-0.png diff --git a/api/scripts/test-refs/txt2img-sd-v2-1-512-muffin.png b/api/scripts/test-refs/txt2img-sd-v2-1-512-muffin-0.png similarity index 100% rename from api/scripts/test-refs/txt2img-sd-v2-1-512-muffin.png rename to api/scripts/test-refs/txt2img-sd-v2-1-512-muffin-0.png diff --git a/api/scripts/test-refs/txt2img-sd-v2-1-768-muffin.png b/api/scripts/test-refs/txt2img-sd-v2-1-768-muffin-0.png similarity index 100% rename from api/scripts/test-refs/txt2img-sd-v2-1-768-muffin.png rename to api/scripts/test-refs/txt2img-sd-v2-1-768-muffin-0.png diff --git a/api/scripts/test-refs/upscale-resrgan-x2-1024-muffin.png b/api/scripts/test-refs/upscale-resrgan-x2-1024-muffin-0.png similarity index 100% rename from api/scripts/test-refs/upscale-resrgan-x2-1024-muffin.png rename to api/scripts/test-refs/upscale-resrgan-x2-1024-muffin-0.png diff --git a/api/scripts/test-refs/upscale-resrgan-x4-2048-muffin.png b/api/scripts/test-refs/upscale-resrgan-x4-2048-muffin-0.png similarity index 100% rename from api/scripts/test-refs/upscale-resrgan-x4-2048-muffin.png rename to api/scripts/test-refs/upscale-resrgan-x4-2048-muffin-0.png diff --git a/api/scripts/test-release.py b/api/scripts/test-release.py index ff31397c..65572083 100644 --- a/api/scripts/test-release.py +++ b/api/scripts/test-release.py @@ -25,6 +25,9 @@ except Exception as err: logger = getLogger(__name__) +FAST_TEST = 20 +SLOW_TEST = 50 + def test_root() -> str: if len(sys.argv) > 1: @@ -42,7 +45,7 @@ class TestCase: self, name: str, query: str, - max_attempts: int = 20, + max_attempts: int = FAST_TEST, mse_threshold: float = 0.001, source: Union[Image.Image, List[Image.Image]] = None, mask: Image.Image = None, @@ -95,23 +98,23 @@ TEST_DATA = [ TestCase( "img2img-sd-v1-5-512-pumpkin", "img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim", - source="txt2img-sd-v1-5-512-muffin", + source="txt2img-sd-v1-5-512-muffin-0", ), TestCase( "img2img-sd-v1-5-256-pumpkin", "img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim", - source="txt2img-sd-v1-5-256-muffin", + source="txt2img-sd-v1-5-256-muffin-0", ), TestCase( "inpaint-v1-512-white", "inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting", - source="txt2img-sd-v1-5-512-muffin", + source="txt2img-sd-v1-5-512-muffin-0", mask="mask-white", ), TestCase( "inpaint-v1-512-black", "inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting", - source="txt2img-sd-v1-5-512-muffin", + source="txt2img-sd-v1-5-512-muffin-0", mask="mask-black", ), TestCase( @@ -120,8 +123,9 @@ TEST_DATA = [ "inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting&noise=fill-mask" "&top=256&bottom=256&left=256&right=256" ), - source="txt2img-sd-v1-5-512-muffin", + source="txt2img-sd-v1-5-512-muffin-0", mask="mask-black", + max_attempts=SLOW_TEST, mse_threshold=0.025, ), TestCase( @@ -130,8 +134,9 @@ TEST_DATA = [ "inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting&noise=fill-mask" "&top=512&bottom=512&left=0&right=0" ), - source="txt2img-sd-v1-5-512-muffin", + source="txt2img-sd-v1-5-512-muffin-0", mask="mask-black", + max_attempts=SLOW_TEST, mse_threshold=0.010, ), TestCase( @@ -140,27 +145,28 @@ TEST_DATA = [ "inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting&noise=fill-mask" "&top=0&bottom=0&left=512&right=512" ), - source="txt2img-sd-v1-5-512-muffin", + source="txt2img-sd-v1-5-512-muffin-0", mask="mask-black", + max_attempts=SLOW_TEST, mse_threshold=0.010, ), TestCase( "upscale-resrgan-x2-1024-muffin", "upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x2-plus&scale=2&outscale=2", - source="txt2img-sd-v1-5-512-muffin", + source="txt2img-sd-v1-5-512-muffin-0", ), TestCase( "upscale-resrgan-x4-2048-muffin", "upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x4-plus&scale=4&outscale=4", - source="txt2img-sd-v1-5-512-muffin", + source="txt2img-sd-v1-5-512-muffin-0", ), TestCase( "blend-512-muffin-black", "upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x2-plus&scale=2&outscale=2", mask="mask-black", source=[ - "txt2img-sd-v1-5-512-muffin", - "txt2img-sd-v2-1-512-muffin", + "txt2img-sd-v1-5-512-muffin-0", + "txt2img-sd-v2-1-512-muffin-0", ], ), TestCase( @@ -168,14 +174,14 @@ TEST_DATA = [ "upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x2-plus&scale=2&outscale=2", mask="mask-white", source=[ - "txt2img-sd-v2-1-512-muffin", - "txt2img-sd-v1-5-512-muffin", + "txt2img-sd-v2-1-512-muffin-0", + "txt2img-sd-v1-5-512-muffin-0", ], ), ] -def generate_image(root: str, test: TestCase) -> Optional[str]: +def generate_images(root: str, test: TestCase) -> Optional[str]: files = {} if test.source is not None: if isinstance(test.source, list): @@ -211,9 +217,9 @@ def generate_image(root: str, test: TestCase) -> Optional[str]: resp = requests.post(f"{root}/api/{test.query}", files=files) if resp.status_code == 200: json = resp.json() - return json.get("output") + return json.get("outputs") else: - logger.warning("request failed: %s", resp.status_code) + logger.warning("request failed: %s: %s", resp.status_code, resp.text) return None @@ -227,14 +233,17 @@ def check_ready(root: str, key: str) -> bool: return False -def download_image(root: str, key: str) -> Image.Image: - resp = requests.get(f"{root}/output/{key}") - if resp.status_code == 200: - logger.debug("downloading image: %s", key) - return Image.open(BytesIO(resp.content)) - else: - logger.warning("request failed: %s", resp.status_code) - return None +def download_images(root: str, keys: List[str]) -> List[Image.Image]: + images = [] + for key in keys: + resp = requests.get(f"{root}/output/{key}") + if resp.status_code == 200: + logger.debug("downloading image: %s", key) + images.append(Image.open(BytesIO(resp.content))) + else: + logger.warning("request failed: %s", resp.status_code) + + return images def find_mse(result: Image.Image, ref: Image.Image) -> float: @@ -259,20 +268,19 @@ def find_mse(result: Image.Image, ref: Image.Image) -> float: def run_test( root: str, test: TestCase, - ref: Image.Image, ) -> bool: """ Generate an image, wait for it to be ready, and calculate the MSE from the reference. """ - key = generate_image(root, test) - if key is None: + keys = generate_images(root, test) + if keys is None: raise ValueError("could not generate") attempts = 0 while attempts < test.max_attempts: - if check_ready(root, key): - logger.debug("image is ready: %s", key) + if check_ready(root, keys[0]): + logger.debug("image is ready: %s", keys) break else: logger.debug("waiting for image to be ready") @@ -282,16 +290,25 @@ def run_test( if attempts == test.max_attempts: raise ValueError("image was not ready in time") - result = download_image(root, key) - result.save(test_path(path.join("test-results", f"{test.name}.png"))) - mse = find_mse(result, ref) + results = download_images(root, keys) - if mse < test.mse_threshold: - logger.info("MSE within threshold: %.4f < %.4f", mse, test.mse_threshold) - return True - else: - logger.warning("MSE above threshold: %.4f > %.4f", mse, test.mse_threshold) - return False + passed = True + for i in range(len(results)): + result = results[i] + result.save(test_path(path.join("test-results", f"{test.name}-{i}.png"))) + + ref_name = test_path(path.join("test-refs", f"{test.name}-{i}.png")) + ref = Image.open(ref_name) if path.exists(ref_name) else None + + mse = find_mse(result, ref) + + if mse < test.mse_threshold: + logger.info("MSE within threshold: %.4f < %.4f", mse, test.mse_threshold) + else: + logger.warning("MSE above threshold: %.4f > %.4f", mse, test.mse_threshold) + passed = False + + return passed def main(): @@ -303,9 +320,7 @@ def main(): for test in TEST_DATA: try: logger.info("starting test: %s", test.name) - ref_name = test_path(path.join("test-refs", f"{test.name}.png")) - ref = Image.open(ref_name) if path.exists(ref_name) else None - if run_test(root, test, ref): + if run_test(root, test): logger.info("test passed: %s", test.name) passed.append(test.name) else: