diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py index 366ccb17..30a87863 100644 --- a/api/onnx_web/diffusers/load.py +++ b/api/onnx_web/diffusers/load.py @@ -231,7 +231,10 @@ def load_pipeline( else: if "vae" in components: # upscale uses a single VAE - logger.debug("assembling SD pipeline for %s with single VAE", pipeline_class.__name__) + logger.debug( + "assembling SD pipeline for %s with single VAE", + pipeline_class.__name__, + ) pipe = pipeline_class( components["vae"], components["text_encoder"], @@ -241,7 +244,10 @@ def load_pipeline( scheduler, ) else: - logger.debug("assembling SD pipeline for %s with VAE codec", pipeline_class.__name__) + logger.debug( + "assembling SD pipeline for %s with VAE codec", + pipeline_class.__name__, + ) pipe = pipeline_class( components["vae_encoder"], components["vae_decoder"], diff --git a/api/scripts/test-refs/img2img-panorama-1024x768-pumpkin-0.png b/api/scripts/test-refs/img2img-panorama-1024x768-pumpkin-0.png index 739d4544..4bfe6bd7 100644 --- a/api/scripts/test-refs/img2img-panorama-1024x768-pumpkin-0.png +++ b/api/scripts/test-refs/img2img-panorama-1024x768-pumpkin-0.png @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:5359014a0963adbd896832c8ca0e08a30deb0ec8306d7db5fcbd150d49aad04a -size 1450590 +oid sha256:613ce059320abadb89f4adf00546d45a20d504ea508106499ceca78df389515f +size 1469930 diff --git a/api/scripts/test-refs/outpaint-panorama-horizontal-512-0.png b/api/scripts/test-refs/outpaint-panorama-horizontal-512-0.png index 56761cfc..14d22eee 100644 --- a/api/scripts/test-refs/outpaint-panorama-horizontal-512-0.png +++ b/api/scripts/test-refs/outpaint-panorama-horizontal-512-0.png @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:151c811488c933d9b858b70157da49ef626e2ce99207120fddc3791238e5a065 -size 1949395 +oid sha256:f6c4cd00f206bc3127c888dbb026edcef08fcc86be0487bf73abb77fd64bc419 +size 1680355 diff --git a/api/scripts/test-refs/outpaint-panorama-vertical-512-0.png b/api/scripts/test-refs/outpaint-panorama-vertical-512-0.png index f467648a..fdf38af0 100644 --- a/api/scripts/test-refs/outpaint-panorama-vertical-512-0.png +++ b/api/scripts/test-refs/outpaint-panorama-vertical-512-0.png @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:2311f9f2c273065a854636113f43fbbae642792e7c7f265c8bb3cbda0e308182 -size 1894231 +oid sha256:2c5486aebb193a2cfc155553934b33ad8fe224ef7fbf04b56dbefdea3ac14a30 +size 1584202 diff --git a/api/scripts/test-release.py b/api/scripts/test-release.py index dd8443d5..42806b55 100644 --- a/api/scripts/test-release.py +++ b/api/scripts/test-release.py @@ -69,6 +69,7 @@ TEST_DATA = [ TestCase( "txt2img-sd-v1-5-512-muffin-deis", "txt2img?prompt=a+giant+muffin&seed=0&scheduler=deis", + mse_threshold=LOOSE_TEST, ), TestCase( "txt2img-sd-v1-5-512-muffin-dpm", @@ -346,6 +347,39 @@ class TestError(Exception): return super().__str__() +class TestResult: + error: Optional[str] + mse: Optional[float] + name: str + passed: bool + + def __init__(self, name: str, error = None, passed = True, mse = None) -> None: + self.error = error + self.mse = mse + self.name = name + self.passed = passed + + def __repr__(self) -> str: + if self.passed: + if self.mse is not None: + return f"{self.name} ({self.mse})" + else: + return self.name + else: + if self.mse is not None: + return f"{self.name}: {self.error} ({self.mse})" + else: + return f"{self.name}: {self.error}" + + @classmethod + def passed(self, name: str, mse = None): + return TestResult(name, mse=mse) + + @classmethod + def failed(self, name: str, error: str, mse = None): + return TestResult(name, error=error, mse=mse, passed=False) + + def parse_args(args: List[str]): parser = ArgumentParser( prog="onnx-web release tests", @@ -452,14 +486,14 @@ def run_test( host: str, test: TestCase, mse_mult: float = 1.0, -) -> bool: +) -> TestResult: """ Generate an image, wait for it to be ready, and calculate the MSE from the reference. """ keys = generate_images(host, test) if keys is None: - raise ValueError("could not generate image") + return TestResult.failed(test.name, "could not generate image") ready = False for attempt in tqdm(range(test.max_attempts)): @@ -472,13 +506,13 @@ def run_test( sleep(6) if not ready: - raise ValueError("image was not ready in time") + return TestResult.failed(test.name, "image was not ready in time") results = download_images(host, keys) - if results is None: - raise ValueError("could not download image") + if results is None or len(results) == 0: + return TestResult.failed(test.name, "could not download image") - passed = True + passed = False for i in range(len(results)): result = results[i] result.save(test_path(path.join("test-results", f"{test.name}-{i}.png"))) @@ -491,11 +525,15 @@ def run_test( if mse < threshold: logger.info("MSE within threshold: %.5f < %.5f", mse, threshold) + passed = True else: logger.warning("MSE above threshold: %.5f > %.5f", mse, threshold) - passed = False + return TestResult.failed(test.name, error="MSE above threshold", mse=mse) - return passed + if passed: + return TestResult.passed(test.name) + else: + return TestResult.failed(test.name, "no images tested") def main(): @@ -516,24 +554,26 @@ def main(): passed = [] failed = [] for test in tests: - test_passed = False + result = None for _i in range(3): try: logger.info("starting test: %s", test.name) - if run_test(args.host, test, mse_mult=args.mse): + result = run_test(args.host, test, mse_mult=args.mse) + if result.passed: logger.info("test passed: %s", test.name) - test_passed = True break else: logger.warning("test failed: %s", test.name) except Exception: logger.exception("error running test for %s", test.name) + result = TestResult.failed(test.name, "TODO: exception message") - if test_passed: - passed.append(test.name) - else: - failed.append(test.name) + if result is not None: + if result.passed: + passed.append(result) + else: + failed.append(result) logger.info("%s of %s tests passed", len(passed), len(tests)) failed = list(set(failed))