1
0
Fork 0

improve some logs, update test refs for panorama

This commit is contained in:
Sean Sube 2023-12-03 15:34:34 -06:00
parent 92311281df
commit 7c3e9c22d0
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
5 changed files with 69 additions and 23 deletions

View File

@ -231,7 +231,10 @@ def load_pipeline(
else:
if "vae" in components:
# upscale uses a single VAE
logger.debug("assembling SD pipeline for %s with single VAE", pipeline_class.__name__)
logger.debug(
"assembling SD pipeline for %s with single VAE",
pipeline_class.__name__,
)
pipe = pipeline_class(
components["vae"],
components["text_encoder"],
@ -241,7 +244,10 @@ def load_pipeline(
scheduler,
)
else:
logger.debug("assembling SD pipeline for %s with VAE codec", pipeline_class.__name__)
logger.debug(
"assembling SD pipeline for %s with VAE codec",
pipeline_class.__name__,
)
pipe = pipeline_class(
components["vae_encoder"],
components["vae_decoder"],

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -69,6 +69,7 @@ TEST_DATA = [
TestCase(
"txt2img-sd-v1-5-512-muffin-deis",
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=deis",
mse_threshold=LOOSE_TEST,
),
TestCase(
"txt2img-sd-v1-5-512-muffin-dpm",
@ -346,6 +347,39 @@ class TestError(Exception):
return super().__str__()
class TestResult:
error: Optional[str]
mse: Optional[float]
name: str
passed: bool
def __init__(self, name: str, error = None, passed = True, mse = None) -> None:
self.error = error
self.mse = mse
self.name = name
self.passed = passed
def __repr__(self) -> str:
if self.passed:
if self.mse is not None:
return f"{self.name} ({self.mse})"
else:
return self.name
else:
if self.mse is not None:
return f"{self.name}: {self.error} ({self.mse})"
else:
return f"{self.name}: {self.error}"
@classmethod
def passed(self, name: str, mse = None):
return TestResult(name, mse=mse)
@classmethod
def failed(self, name: str, error: str, mse = None):
return TestResult(name, error=error, mse=mse, passed=False)
def parse_args(args: List[str]):
parser = ArgumentParser(
prog="onnx-web release tests",
@ -452,14 +486,14 @@ def run_test(
host: str,
test: TestCase,
mse_mult: float = 1.0,
) -> bool:
) -> TestResult:
"""
Generate an image, wait for it to be ready, and calculate the MSE from the reference.
"""
keys = generate_images(host, test)
if keys is None:
raise ValueError("could not generate image")
return TestResult.failed(test.name, "could not generate image")
ready = False
for attempt in tqdm(range(test.max_attempts)):
@ -472,13 +506,13 @@ def run_test(
sleep(6)
if not ready:
raise ValueError("image was not ready in time")
return TestResult.failed(test.name, "image was not ready in time")
results = download_images(host, keys)
if results is None:
raise ValueError("could not download image")
if results is None or len(results) == 0:
return TestResult.failed(test.name, "could not download image")
passed = True
passed = False
for i in range(len(results)):
result = results[i]
result.save(test_path(path.join("test-results", f"{test.name}-{i}.png")))
@ -491,11 +525,15 @@ def run_test(
if mse < threshold:
logger.info("MSE within threshold: %.5f < %.5f", mse, threshold)
passed = True
else:
logger.warning("MSE above threshold: %.5f > %.5f", mse, threshold)
passed = False
return TestResult.failed(test.name, error="MSE above threshold", mse=mse)
return passed
if passed:
return TestResult.passed(test.name)
else:
return TestResult.failed(test.name, "no images tested")
def main():
@ -516,24 +554,26 @@ def main():
passed = []
failed = []
for test in tests:
test_passed = False
result = None
for _i in range(3):
try:
logger.info("starting test: %s", test.name)
if run_test(args.host, test, mse_mult=args.mse):
result = run_test(args.host, test, mse_mult=args.mse)
if result.passed:
logger.info("test passed: %s", test.name)
test_passed = True
break
else:
logger.warning("test failed: %s", test.name)
except Exception:
logger.exception("error running test for %s", test.name)
result = TestResult.failed(test.name, "TODO: exception message")
if test_passed:
passed.append(test.name)
else:
failed.append(test.name)
if result is not None:
if result.passed:
passed.append(result)
else:
failed.append(result)
logger.info("%s of %s tests passed", len(passed), len(tests))
failed = list(set(failed))