1
0
Fork 0

improve some logs, update test refs for panorama

This commit is contained in:
Sean Sube 2023-12-03 15:34:34 -06:00
parent 92311281df
commit 7c3e9c22d0
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
5 changed files with 69 additions and 23 deletions

View File

@ -231,7 +231,10 @@ def load_pipeline(
else: else:
if "vae" in components: if "vae" in components:
# upscale uses a single VAE # upscale uses a single VAE
logger.debug("assembling SD pipeline for %s with single VAE", pipeline_class.__name__) logger.debug(
"assembling SD pipeline for %s with single VAE",
pipeline_class.__name__,
)
pipe = pipeline_class( pipe = pipeline_class(
components["vae"], components["vae"],
components["text_encoder"], components["text_encoder"],
@ -241,7 +244,10 @@ def load_pipeline(
scheduler, scheduler,
) )
else: else:
logger.debug("assembling SD pipeline for %s with VAE codec", pipeline_class.__name__) logger.debug(
"assembling SD pipeline for %s with VAE codec",
pipeline_class.__name__,
)
pipe = pipeline_class( pipe = pipeline_class(
components["vae_encoder"], components["vae_encoder"],
components["vae_decoder"], components["vae_decoder"],

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -69,6 +69,7 @@ TEST_DATA = [
TestCase( TestCase(
"txt2img-sd-v1-5-512-muffin-deis", "txt2img-sd-v1-5-512-muffin-deis",
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=deis", "txt2img?prompt=a+giant+muffin&seed=0&scheduler=deis",
mse_threshold=LOOSE_TEST,
), ),
TestCase( TestCase(
"txt2img-sd-v1-5-512-muffin-dpm", "txt2img-sd-v1-5-512-muffin-dpm",
@ -346,6 +347,39 @@ class TestError(Exception):
return super().__str__() return super().__str__()
class TestResult:
error: Optional[str]
mse: Optional[float]
name: str
passed: bool
def __init__(self, name: str, error = None, passed = True, mse = None) -> None:
self.error = error
self.mse = mse
self.name = name
self.passed = passed
def __repr__(self) -> str:
if self.passed:
if self.mse is not None:
return f"{self.name} ({self.mse})"
else:
return self.name
else:
if self.mse is not None:
return f"{self.name}: {self.error} ({self.mse})"
else:
return f"{self.name}: {self.error}"
@classmethod
def passed(self, name: str, mse = None):
return TestResult(name, mse=mse)
@classmethod
def failed(self, name: str, error: str, mse = None):
return TestResult(name, error=error, mse=mse, passed=False)
def parse_args(args: List[str]): def parse_args(args: List[str]):
parser = ArgumentParser( parser = ArgumentParser(
prog="onnx-web release tests", prog="onnx-web release tests",
@ -452,14 +486,14 @@ def run_test(
host: str, host: str,
test: TestCase, test: TestCase,
mse_mult: float = 1.0, mse_mult: float = 1.0,
) -> bool: ) -> TestResult:
""" """
Generate an image, wait for it to be ready, and calculate the MSE from the reference. Generate an image, wait for it to be ready, and calculate the MSE from the reference.
""" """
keys = generate_images(host, test) keys = generate_images(host, test)
if keys is None: if keys is None:
raise ValueError("could not generate image") return TestResult.failed(test.name, "could not generate image")
ready = False ready = False
for attempt in tqdm(range(test.max_attempts)): for attempt in tqdm(range(test.max_attempts)):
@ -472,13 +506,13 @@ def run_test(
sleep(6) sleep(6)
if not ready: if not ready:
raise ValueError("image was not ready in time") return TestResult.failed(test.name, "image was not ready in time")
results = download_images(host, keys) results = download_images(host, keys)
if results is None: if results is None or len(results) == 0:
raise ValueError("could not download image") return TestResult.failed(test.name, "could not download image")
passed = True passed = False
for i in range(len(results)): for i in range(len(results)):
result = results[i] result = results[i]
result.save(test_path(path.join("test-results", f"{test.name}-{i}.png"))) result.save(test_path(path.join("test-results", f"{test.name}-{i}.png")))
@ -491,11 +525,15 @@ def run_test(
if mse < threshold: if mse < threshold:
logger.info("MSE within threshold: %.5f < %.5f", mse, threshold) logger.info("MSE within threshold: %.5f < %.5f", mse, threshold)
passed = True
else: else:
logger.warning("MSE above threshold: %.5f > %.5f", mse, threshold) logger.warning("MSE above threshold: %.5f > %.5f", mse, threshold)
passed = False return TestResult.failed(test.name, error="MSE above threshold", mse=mse)
return passed if passed:
return TestResult.passed(test.name)
else:
return TestResult.failed(test.name, "no images tested")
def main(): def main():
@ -516,24 +554,26 @@ def main():
passed = [] passed = []
failed = [] failed = []
for test in tests: for test in tests:
test_passed = False result = None
for _i in range(3): for _i in range(3):
try: try:
logger.info("starting test: %s", test.name) logger.info("starting test: %s", test.name)
if run_test(args.host, test, mse_mult=args.mse): result = run_test(args.host, test, mse_mult=args.mse)
if result.passed:
logger.info("test passed: %s", test.name) logger.info("test passed: %s", test.name)
test_passed = True
break break
else: else:
logger.warning("test failed: %s", test.name) logger.warning("test failed: %s", test.name)
except Exception: except Exception:
logger.exception("error running test for %s", test.name) logger.exception("error running test for %s", test.name)
result = TestResult.failed(test.name, "TODO: exception message")
if test_passed: if result is not None:
passed.append(test.name) if result.passed:
passed.append(result)
else: else:
failed.append(test.name) failed.append(result)
logger.info("%s of %s tests passed", len(passed), len(tests)) logger.info("%s of %s tests passed", len(passed), len(tests))
failed = list(set(failed)) failed = list(set(failed))