improve some logs, update test refs for panorama
This commit is contained in:
parent
92311281df
commit
7c3e9c22d0
|
@ -231,7 +231,10 @@ def load_pipeline(
|
||||||
else:
|
else:
|
||||||
if "vae" in components:
|
if "vae" in components:
|
||||||
# upscale uses a single VAE
|
# upscale uses a single VAE
|
||||||
logger.debug("assembling SD pipeline for %s with single VAE", pipeline_class.__name__)
|
logger.debug(
|
||||||
|
"assembling SD pipeline for %s with single VAE",
|
||||||
|
pipeline_class.__name__,
|
||||||
|
)
|
||||||
pipe = pipeline_class(
|
pipe = pipeline_class(
|
||||||
components["vae"],
|
components["vae"],
|
||||||
components["text_encoder"],
|
components["text_encoder"],
|
||||||
|
@ -241,7 +244,10 @@ def load_pipeline(
|
||||||
scheduler,
|
scheduler,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.debug("assembling SD pipeline for %s with VAE codec", pipeline_class.__name__)
|
logger.debug(
|
||||||
|
"assembling SD pipeline for %s with VAE codec",
|
||||||
|
pipeline_class.__name__,
|
||||||
|
)
|
||||||
pipe = pipeline_class(
|
pipe = pipeline_class(
|
||||||
components["vae_encoder"],
|
components["vae_encoder"],
|
||||||
components["vae_decoder"],
|
components["vae_decoder"],
|
||||||
|
|
BIN
api/scripts/test-refs/img2img-panorama-1024x768-pumpkin-0.png (Stored with Git LFS)
BIN
api/scripts/test-refs/img2img-panorama-1024x768-pumpkin-0.png (Stored with Git LFS)
Binary file not shown.
BIN
api/scripts/test-refs/outpaint-panorama-horizontal-512-0.png (Stored with Git LFS)
BIN
api/scripts/test-refs/outpaint-panorama-horizontal-512-0.png (Stored with Git LFS)
Binary file not shown.
BIN
api/scripts/test-refs/outpaint-panorama-vertical-512-0.png (Stored with Git LFS)
BIN
api/scripts/test-refs/outpaint-panorama-vertical-512-0.png (Stored with Git LFS)
Binary file not shown.
|
@ -69,6 +69,7 @@ TEST_DATA = [
|
||||||
TestCase(
|
TestCase(
|
||||||
"txt2img-sd-v1-5-512-muffin-deis",
|
"txt2img-sd-v1-5-512-muffin-deis",
|
||||||
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=deis",
|
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=deis",
|
||||||
|
mse_threshold=LOOSE_TEST,
|
||||||
),
|
),
|
||||||
TestCase(
|
TestCase(
|
||||||
"txt2img-sd-v1-5-512-muffin-dpm",
|
"txt2img-sd-v1-5-512-muffin-dpm",
|
||||||
|
@ -346,6 +347,39 @@ class TestError(Exception):
|
||||||
return super().__str__()
|
return super().__str__()
|
||||||
|
|
||||||
|
|
||||||
|
class TestResult:
|
||||||
|
error: Optional[str]
|
||||||
|
mse: Optional[float]
|
||||||
|
name: str
|
||||||
|
passed: bool
|
||||||
|
|
||||||
|
def __init__(self, name: str, error = None, passed = True, mse = None) -> None:
|
||||||
|
self.error = error
|
||||||
|
self.mse = mse
|
||||||
|
self.name = name
|
||||||
|
self.passed = passed
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
if self.passed:
|
||||||
|
if self.mse is not None:
|
||||||
|
return f"{self.name} ({self.mse})"
|
||||||
|
else:
|
||||||
|
return self.name
|
||||||
|
else:
|
||||||
|
if self.mse is not None:
|
||||||
|
return f"{self.name}: {self.error} ({self.mse})"
|
||||||
|
else:
|
||||||
|
return f"{self.name}: {self.error}"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def passed(self, name: str, mse = None):
|
||||||
|
return TestResult(name, mse=mse)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def failed(self, name: str, error: str, mse = None):
|
||||||
|
return TestResult(name, error=error, mse=mse, passed=False)
|
||||||
|
|
||||||
|
|
||||||
def parse_args(args: List[str]):
|
def parse_args(args: List[str]):
|
||||||
parser = ArgumentParser(
|
parser = ArgumentParser(
|
||||||
prog="onnx-web release tests",
|
prog="onnx-web release tests",
|
||||||
|
@ -452,14 +486,14 @@ def run_test(
|
||||||
host: str,
|
host: str,
|
||||||
test: TestCase,
|
test: TestCase,
|
||||||
mse_mult: float = 1.0,
|
mse_mult: float = 1.0,
|
||||||
) -> bool:
|
) -> TestResult:
|
||||||
"""
|
"""
|
||||||
Generate an image, wait for it to be ready, and calculate the MSE from the reference.
|
Generate an image, wait for it to be ready, and calculate the MSE from the reference.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
keys = generate_images(host, test)
|
keys = generate_images(host, test)
|
||||||
if keys is None:
|
if keys is None:
|
||||||
raise ValueError("could not generate image")
|
return TestResult.failed(test.name, "could not generate image")
|
||||||
|
|
||||||
ready = False
|
ready = False
|
||||||
for attempt in tqdm(range(test.max_attempts)):
|
for attempt in tqdm(range(test.max_attempts)):
|
||||||
|
@ -472,13 +506,13 @@ def run_test(
|
||||||
sleep(6)
|
sleep(6)
|
||||||
|
|
||||||
if not ready:
|
if not ready:
|
||||||
raise ValueError("image was not ready in time")
|
return TestResult.failed(test.name, "image was not ready in time")
|
||||||
|
|
||||||
results = download_images(host, keys)
|
results = download_images(host, keys)
|
||||||
if results is None:
|
if results is None or len(results) == 0:
|
||||||
raise ValueError("could not download image")
|
return TestResult.failed(test.name, "could not download image")
|
||||||
|
|
||||||
passed = True
|
passed = False
|
||||||
for i in range(len(results)):
|
for i in range(len(results)):
|
||||||
result = results[i]
|
result = results[i]
|
||||||
result.save(test_path(path.join("test-results", f"{test.name}-{i}.png")))
|
result.save(test_path(path.join("test-results", f"{test.name}-{i}.png")))
|
||||||
|
@ -491,11 +525,15 @@ def run_test(
|
||||||
|
|
||||||
if mse < threshold:
|
if mse < threshold:
|
||||||
logger.info("MSE within threshold: %.5f < %.5f", mse, threshold)
|
logger.info("MSE within threshold: %.5f < %.5f", mse, threshold)
|
||||||
|
passed = True
|
||||||
else:
|
else:
|
||||||
logger.warning("MSE above threshold: %.5f > %.5f", mse, threshold)
|
logger.warning("MSE above threshold: %.5f > %.5f", mse, threshold)
|
||||||
passed = False
|
return TestResult.failed(test.name, error="MSE above threshold", mse=mse)
|
||||||
|
|
||||||
return passed
|
if passed:
|
||||||
|
return TestResult.passed(test.name)
|
||||||
|
else:
|
||||||
|
return TestResult.failed(test.name, "no images tested")
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
@ -516,24 +554,26 @@ def main():
|
||||||
passed = []
|
passed = []
|
||||||
failed = []
|
failed = []
|
||||||
for test in tests:
|
for test in tests:
|
||||||
test_passed = False
|
result = None
|
||||||
|
|
||||||
for _i in range(3):
|
for _i in range(3):
|
||||||
try:
|
try:
|
||||||
logger.info("starting test: %s", test.name)
|
logger.info("starting test: %s", test.name)
|
||||||
if run_test(args.host, test, mse_mult=args.mse):
|
result = run_test(args.host, test, mse_mult=args.mse)
|
||||||
|
if result.passed:
|
||||||
logger.info("test passed: %s", test.name)
|
logger.info("test passed: %s", test.name)
|
||||||
test_passed = True
|
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
logger.warning("test failed: %s", test.name)
|
logger.warning("test failed: %s", test.name)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("error running test for %s", test.name)
|
logger.exception("error running test for %s", test.name)
|
||||||
|
result = TestResult.failed(test.name, "TODO: exception message")
|
||||||
|
|
||||||
if test_passed:
|
if result is not None:
|
||||||
passed.append(test.name)
|
if result.passed:
|
||||||
|
passed.append(result)
|
||||||
else:
|
else:
|
||||||
failed.append(test.name)
|
failed.append(result)
|
||||||
|
|
||||||
logger.info("%s of %s tests passed", len(passed), len(tests))
|
logger.info("%s of %s tests passed", len(passed), len(tests))
|
||||||
failed = list(set(failed))
|
failed = list(set(failed))
|
||||||
|
|
Loading…
Reference in New Issue