1
0
Fork 0

fix(scripts): update release tests with support for batches

This commit is contained in:
Sean Sube 2023-02-25 18:01:05 -06:00
parent cb8e9e7080
commit 6809d2da82
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
19 changed files with 57 additions and 42 deletions

View File

@ -25,6 +25,9 @@ except Exception as err:
logger = getLogger(__name__) logger = getLogger(__name__)
FAST_TEST = 20
SLOW_TEST = 50
def test_root() -> str: def test_root() -> str:
if len(sys.argv) > 1: if len(sys.argv) > 1:
@ -42,7 +45,7 @@ class TestCase:
self, self,
name: str, name: str,
query: str, query: str,
max_attempts: int = 20, max_attempts: int = FAST_TEST,
mse_threshold: float = 0.001, mse_threshold: float = 0.001,
source: Union[Image.Image, List[Image.Image]] = None, source: Union[Image.Image, List[Image.Image]] = None,
mask: Image.Image = None, mask: Image.Image = None,
@ -95,23 +98,23 @@ TEST_DATA = [
TestCase( TestCase(
"img2img-sd-v1-5-512-pumpkin", "img2img-sd-v1-5-512-pumpkin",
"img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim", "img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim",
source="txt2img-sd-v1-5-512-muffin", source="txt2img-sd-v1-5-512-muffin-0",
), ),
TestCase( TestCase(
"img2img-sd-v1-5-256-pumpkin", "img2img-sd-v1-5-256-pumpkin",
"img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim", "img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim",
source="txt2img-sd-v1-5-256-muffin", source="txt2img-sd-v1-5-256-muffin-0",
), ),
TestCase( TestCase(
"inpaint-v1-512-white", "inpaint-v1-512-white",
"inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting", "inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting",
source="txt2img-sd-v1-5-512-muffin", source="txt2img-sd-v1-5-512-muffin-0",
mask="mask-white", mask="mask-white",
), ),
TestCase( TestCase(
"inpaint-v1-512-black", "inpaint-v1-512-black",
"inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting", "inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting",
source="txt2img-sd-v1-5-512-muffin", source="txt2img-sd-v1-5-512-muffin-0",
mask="mask-black", mask="mask-black",
), ),
TestCase( TestCase(
@ -120,8 +123,9 @@ TEST_DATA = [
"inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting&noise=fill-mask" "inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting&noise=fill-mask"
"&top=256&bottom=256&left=256&right=256" "&top=256&bottom=256&left=256&right=256"
), ),
source="txt2img-sd-v1-5-512-muffin", source="txt2img-sd-v1-5-512-muffin-0",
mask="mask-black", mask="mask-black",
max_attempts=SLOW_TEST,
mse_threshold=0.025, mse_threshold=0.025,
), ),
TestCase( TestCase(
@ -130,8 +134,9 @@ TEST_DATA = [
"inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting&noise=fill-mask" "inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting&noise=fill-mask"
"&top=512&bottom=512&left=0&right=0" "&top=512&bottom=512&left=0&right=0"
), ),
source="txt2img-sd-v1-5-512-muffin", source="txt2img-sd-v1-5-512-muffin-0",
mask="mask-black", mask="mask-black",
max_attempts=SLOW_TEST,
mse_threshold=0.010, mse_threshold=0.010,
), ),
TestCase( TestCase(
@ -140,27 +145,28 @@ TEST_DATA = [
"inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting&noise=fill-mask" "inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting&noise=fill-mask"
"&top=0&bottom=0&left=512&right=512" "&top=0&bottom=0&left=512&right=512"
), ),
source="txt2img-sd-v1-5-512-muffin", source="txt2img-sd-v1-5-512-muffin-0",
mask="mask-black", mask="mask-black",
max_attempts=SLOW_TEST,
mse_threshold=0.010, mse_threshold=0.010,
), ),
TestCase( TestCase(
"upscale-resrgan-x2-1024-muffin", "upscale-resrgan-x2-1024-muffin",
"upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x2-plus&scale=2&outscale=2", "upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x2-plus&scale=2&outscale=2",
source="txt2img-sd-v1-5-512-muffin", source="txt2img-sd-v1-5-512-muffin-0",
), ),
TestCase( TestCase(
"upscale-resrgan-x4-2048-muffin", "upscale-resrgan-x4-2048-muffin",
"upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x4-plus&scale=4&outscale=4", "upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x4-plus&scale=4&outscale=4",
source="txt2img-sd-v1-5-512-muffin", source="txt2img-sd-v1-5-512-muffin-0",
), ),
TestCase( TestCase(
"blend-512-muffin-black", "blend-512-muffin-black",
"upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x2-plus&scale=2&outscale=2", "upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x2-plus&scale=2&outscale=2",
mask="mask-black", mask="mask-black",
source=[ source=[
"txt2img-sd-v1-5-512-muffin", "txt2img-sd-v1-5-512-muffin-0",
"txt2img-sd-v2-1-512-muffin", "txt2img-sd-v2-1-512-muffin-0",
], ],
), ),
TestCase( TestCase(
@ -168,14 +174,14 @@ TEST_DATA = [
"upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x2-plus&scale=2&outscale=2", "upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x2-plus&scale=2&outscale=2",
mask="mask-white", mask="mask-white",
source=[ source=[
"txt2img-sd-v2-1-512-muffin", "txt2img-sd-v2-1-512-muffin-0",
"txt2img-sd-v1-5-512-muffin", "txt2img-sd-v1-5-512-muffin-0",
], ],
), ),
] ]
def generate_image(root: str, test: TestCase) -> Optional[str]: def generate_images(root: str, test: TestCase) -> Optional[str]:
files = {} files = {}
if test.source is not None: if test.source is not None:
if isinstance(test.source, list): if isinstance(test.source, list):
@ -211,9 +217,9 @@ def generate_image(root: str, test: TestCase) -> Optional[str]:
resp = requests.post(f"{root}/api/{test.query}", files=files) resp = requests.post(f"{root}/api/{test.query}", files=files)
if resp.status_code == 200: if resp.status_code == 200:
json = resp.json() json = resp.json()
return json.get("output") return json.get("outputs")
else: else:
logger.warning("request failed: %s", resp.status_code) logger.warning("request failed: %s: %s", resp.status_code, resp.text)
return None return None
@ -227,14 +233,17 @@ def check_ready(root: str, key: str) -> bool:
return False return False
def download_image(root: str, key: str) -> Image.Image: def download_images(root: str, keys: List[str]) -> List[Image.Image]:
resp = requests.get(f"{root}/output/{key}") images = []
if resp.status_code == 200: for key in keys:
logger.debug("downloading image: %s", key) resp = requests.get(f"{root}/output/{key}")
return Image.open(BytesIO(resp.content)) if resp.status_code == 200:
else: logger.debug("downloading image: %s", key)
logger.warning("request failed: %s", resp.status_code) images.append(Image.open(BytesIO(resp.content)))
return None else:
logger.warning("request failed: %s", resp.status_code)
return images
def find_mse(result: Image.Image, ref: Image.Image) -> float: def find_mse(result: Image.Image, ref: Image.Image) -> float:
@ -259,20 +268,19 @@ def find_mse(result: Image.Image, ref: Image.Image) -> float:
def run_test( def run_test(
root: str, root: str,
test: TestCase, test: TestCase,
ref: Image.Image,
) -> bool: ) -> bool:
""" """
Generate an image, wait for it to be ready, and calculate the MSE from the reference. Generate an image, wait for it to be ready, and calculate the MSE from the reference.
""" """
key = generate_image(root, test) keys = generate_images(root, test)
if key is None: if keys is None:
raise ValueError("could not generate") raise ValueError("could not generate")
attempts = 0 attempts = 0
while attempts < test.max_attempts: while attempts < test.max_attempts:
if check_ready(root, key): if check_ready(root, keys[0]):
logger.debug("image is ready: %s", key) logger.debug("image is ready: %s", keys)
break break
else: else:
logger.debug("waiting for image to be ready") logger.debug("waiting for image to be ready")
@ -282,16 +290,25 @@ def run_test(
if attempts == test.max_attempts: if attempts == test.max_attempts:
raise ValueError("image was not ready in time") raise ValueError("image was not ready in time")
result = download_image(root, key) results = download_images(root, keys)
result.save(test_path(path.join("test-results", f"{test.name}.png")))
mse = find_mse(result, ref)
if mse < test.mse_threshold: passed = True
logger.info("MSE within threshold: %.4f < %.4f", mse, test.mse_threshold) for i in range(len(results)):
return True result = results[i]
else: result.save(test_path(path.join("test-results", f"{test.name}-{i}.png")))
logger.warning("MSE above threshold: %.4f > %.4f", mse, test.mse_threshold)
return False ref_name = test_path(path.join("test-refs", f"{test.name}-{i}.png"))
ref = Image.open(ref_name) if path.exists(ref_name) else None
mse = find_mse(result, ref)
if mse < test.mse_threshold:
logger.info("MSE within threshold: %.4f < %.4f", mse, test.mse_threshold)
else:
logger.warning("MSE above threshold: %.4f > %.4f", mse, test.mse_threshold)
passed = False
return passed
def main(): def main():
@ -303,9 +320,7 @@ def main():
for test in TEST_DATA: for test in TEST_DATA:
try: try:
logger.info("starting test: %s", test.name) logger.info("starting test: %s", test.name)
ref_name = test_path(path.join("test-refs", f"{test.name}.png")) if run_test(root, test):
ref = Image.open(ref_name) if path.exists(ref_name) else None
if run_test(root, test, ref):
logger.info("test passed: %s", test.name) logger.info("test passed: %s", test.name)
passed.append(test.name) passed.append(test.name)
else: else: