improve some logs, update test refs for panorama
This commit is contained in:
parent
92311281df
commit
7c3e9c22d0
|
@ -231,7 +231,10 @@ def load_pipeline(
|
|||
else:
|
||||
if "vae" in components:
|
||||
# upscale uses a single VAE
|
||||
logger.debug("assembling SD pipeline for %s with single VAE", pipeline_class.__name__)
|
||||
logger.debug(
|
||||
"assembling SD pipeline for %s with single VAE",
|
||||
pipeline_class.__name__,
|
||||
)
|
||||
pipe = pipeline_class(
|
||||
components["vae"],
|
||||
components["text_encoder"],
|
||||
|
@ -241,7 +244,10 @@ def load_pipeline(
|
|||
scheduler,
|
||||
)
|
||||
else:
|
||||
logger.debug("assembling SD pipeline for %s with VAE codec", pipeline_class.__name__)
|
||||
logger.debug(
|
||||
"assembling SD pipeline for %s with VAE codec",
|
||||
pipeline_class.__name__,
|
||||
)
|
||||
pipe = pipeline_class(
|
||||
components["vae_encoder"],
|
||||
components["vae_decoder"],
|
||||
|
|
BIN
api/scripts/test-refs/img2img-panorama-1024x768-pumpkin-0.png (Stored with Git LFS)
BIN
api/scripts/test-refs/img2img-panorama-1024x768-pumpkin-0.png (Stored with Git LFS)
Binary file not shown.
BIN
api/scripts/test-refs/outpaint-panorama-horizontal-512-0.png (Stored with Git LFS)
BIN
api/scripts/test-refs/outpaint-panorama-horizontal-512-0.png (Stored with Git LFS)
Binary file not shown.
BIN
api/scripts/test-refs/outpaint-panorama-vertical-512-0.png (Stored with Git LFS)
BIN
api/scripts/test-refs/outpaint-panorama-vertical-512-0.png (Stored with Git LFS)
Binary file not shown.
|
@ -69,6 +69,7 @@ TEST_DATA = [
|
|||
TestCase(
|
||||
"txt2img-sd-v1-5-512-muffin-deis",
|
||||
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=deis",
|
||||
mse_threshold=LOOSE_TEST,
|
||||
),
|
||||
TestCase(
|
||||
"txt2img-sd-v1-5-512-muffin-dpm",
|
||||
|
@ -346,6 +347,39 @@ class TestError(Exception):
|
|||
return super().__str__()
|
||||
|
||||
|
||||
class TestResult:
|
||||
error: Optional[str]
|
||||
mse: Optional[float]
|
||||
name: str
|
||||
passed: bool
|
||||
|
||||
def __init__(self, name: str, error = None, passed = True, mse = None) -> None:
|
||||
self.error = error
|
||||
self.mse = mse
|
||||
self.name = name
|
||||
self.passed = passed
|
||||
|
||||
def __repr__(self) -> str:
|
||||
if self.passed:
|
||||
if self.mse is not None:
|
||||
return f"{self.name} ({self.mse})"
|
||||
else:
|
||||
return self.name
|
||||
else:
|
||||
if self.mse is not None:
|
||||
return f"{self.name}: {self.error} ({self.mse})"
|
||||
else:
|
||||
return f"{self.name}: {self.error}"
|
||||
|
||||
@classmethod
|
||||
def passed(self, name: str, mse = None):
|
||||
return TestResult(name, mse=mse)
|
||||
|
||||
@classmethod
|
||||
def failed(self, name: str, error: str, mse = None):
|
||||
return TestResult(name, error=error, mse=mse, passed=False)
|
||||
|
||||
|
||||
def parse_args(args: List[str]):
|
||||
parser = ArgumentParser(
|
||||
prog="onnx-web release tests",
|
||||
|
@ -452,14 +486,14 @@ def run_test(
|
|||
host: str,
|
||||
test: TestCase,
|
||||
mse_mult: float = 1.0,
|
||||
) -> bool:
|
||||
) -> TestResult:
|
||||
"""
|
||||
Generate an image, wait for it to be ready, and calculate the MSE from the reference.
|
||||
"""
|
||||
|
||||
keys = generate_images(host, test)
|
||||
if keys is None:
|
||||
raise ValueError("could not generate image")
|
||||
return TestResult.failed(test.name, "could not generate image")
|
||||
|
||||
ready = False
|
||||
for attempt in tqdm(range(test.max_attempts)):
|
||||
|
@ -472,13 +506,13 @@ def run_test(
|
|||
sleep(6)
|
||||
|
||||
if not ready:
|
||||
raise ValueError("image was not ready in time")
|
||||
return TestResult.failed(test.name, "image was not ready in time")
|
||||
|
||||
results = download_images(host, keys)
|
||||
if results is None:
|
||||
raise ValueError("could not download image")
|
||||
if results is None or len(results) == 0:
|
||||
return TestResult.failed(test.name, "could not download image")
|
||||
|
||||
passed = True
|
||||
passed = False
|
||||
for i in range(len(results)):
|
||||
result = results[i]
|
||||
result.save(test_path(path.join("test-results", f"{test.name}-{i}.png")))
|
||||
|
@ -491,11 +525,15 @@ def run_test(
|
|||
|
||||
if mse < threshold:
|
||||
logger.info("MSE within threshold: %.5f < %.5f", mse, threshold)
|
||||
passed = True
|
||||
else:
|
||||
logger.warning("MSE above threshold: %.5f > %.5f", mse, threshold)
|
||||
passed = False
|
||||
return TestResult.failed(test.name, error="MSE above threshold", mse=mse)
|
||||
|
||||
return passed
|
||||
if passed:
|
||||
return TestResult.passed(test.name)
|
||||
else:
|
||||
return TestResult.failed(test.name, "no images tested")
|
||||
|
||||
|
||||
def main():
|
||||
|
@ -516,24 +554,26 @@ def main():
|
|||
passed = []
|
||||
failed = []
|
||||
for test in tests:
|
||||
test_passed = False
|
||||
result = None
|
||||
|
||||
for _i in range(3):
|
||||
try:
|
||||
logger.info("starting test: %s", test.name)
|
||||
if run_test(args.host, test, mse_mult=args.mse):
|
||||
result = run_test(args.host, test, mse_mult=args.mse)
|
||||
if result.passed:
|
||||
logger.info("test passed: %s", test.name)
|
||||
test_passed = True
|
||||
break
|
||||
else:
|
||||
logger.warning("test failed: %s", test.name)
|
||||
except Exception:
|
||||
logger.exception("error running test for %s", test.name)
|
||||
result = TestResult.failed(test.name, "TODO: exception message")
|
||||
|
||||
if test_passed:
|
||||
passed.append(test.name)
|
||||
else:
|
||||
failed.append(test.name)
|
||||
if result is not None:
|
||||
if result.passed:
|
||||
passed.append(result)
|
||||
else:
|
||||
failed.append(result)
|
||||
|
||||
logger.info("%s of %s tests passed", len(passed), len(tests))
|
||||
failed = list(set(failed))
|
||||
|
|
Loading…
Reference in New Issue