fix(scripts): update release tests with support for batches
This commit is contained in:
parent
cb8e9e7080
commit
6809d2da82
|
@ -25,6 +25,9 @@ except Exception as err:
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
FAST_TEST = 20
|
||||||
|
SLOW_TEST = 50
|
||||||
|
|
||||||
|
|
||||||
def test_root() -> str:
|
def test_root() -> str:
|
||||||
if len(sys.argv) > 1:
|
if len(sys.argv) > 1:
|
||||||
|
@ -42,7 +45,7 @@ class TestCase:
|
||||||
self,
|
self,
|
||||||
name: str,
|
name: str,
|
||||||
query: str,
|
query: str,
|
||||||
max_attempts: int = 20,
|
max_attempts: int = FAST_TEST,
|
||||||
mse_threshold: float = 0.001,
|
mse_threshold: float = 0.001,
|
||||||
source: Union[Image.Image, List[Image.Image]] = None,
|
source: Union[Image.Image, List[Image.Image]] = None,
|
||||||
mask: Image.Image = None,
|
mask: Image.Image = None,
|
||||||
|
@ -95,23 +98,23 @@ TEST_DATA = [
|
||||||
TestCase(
|
TestCase(
|
||||||
"img2img-sd-v1-5-512-pumpkin",
|
"img2img-sd-v1-5-512-pumpkin",
|
||||||
"img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim",
|
"img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim",
|
||||||
source="txt2img-sd-v1-5-512-muffin",
|
source="txt2img-sd-v1-5-512-muffin-0",
|
||||||
),
|
),
|
||||||
TestCase(
|
TestCase(
|
||||||
"img2img-sd-v1-5-256-pumpkin",
|
"img2img-sd-v1-5-256-pumpkin",
|
||||||
"img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim",
|
"img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim",
|
||||||
source="txt2img-sd-v1-5-256-muffin",
|
source="txt2img-sd-v1-5-256-muffin-0",
|
||||||
),
|
),
|
||||||
TestCase(
|
TestCase(
|
||||||
"inpaint-v1-512-white",
|
"inpaint-v1-512-white",
|
||||||
"inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting",
|
"inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting",
|
||||||
source="txt2img-sd-v1-5-512-muffin",
|
source="txt2img-sd-v1-5-512-muffin-0",
|
||||||
mask="mask-white",
|
mask="mask-white",
|
||||||
),
|
),
|
||||||
TestCase(
|
TestCase(
|
||||||
"inpaint-v1-512-black",
|
"inpaint-v1-512-black",
|
||||||
"inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting",
|
"inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting",
|
||||||
source="txt2img-sd-v1-5-512-muffin",
|
source="txt2img-sd-v1-5-512-muffin-0",
|
||||||
mask="mask-black",
|
mask="mask-black",
|
||||||
),
|
),
|
||||||
TestCase(
|
TestCase(
|
||||||
|
@ -120,8 +123,9 @@ TEST_DATA = [
|
||||||
"inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting&noise=fill-mask"
|
"inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting&noise=fill-mask"
|
||||||
"&top=256&bottom=256&left=256&right=256"
|
"&top=256&bottom=256&left=256&right=256"
|
||||||
),
|
),
|
||||||
source="txt2img-sd-v1-5-512-muffin",
|
source="txt2img-sd-v1-5-512-muffin-0",
|
||||||
mask="mask-black",
|
mask="mask-black",
|
||||||
|
max_attempts=SLOW_TEST,
|
||||||
mse_threshold=0.025,
|
mse_threshold=0.025,
|
||||||
),
|
),
|
||||||
TestCase(
|
TestCase(
|
||||||
|
@ -130,8 +134,9 @@ TEST_DATA = [
|
||||||
"inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting&noise=fill-mask"
|
"inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting&noise=fill-mask"
|
||||||
"&top=512&bottom=512&left=0&right=0"
|
"&top=512&bottom=512&left=0&right=0"
|
||||||
),
|
),
|
||||||
source="txt2img-sd-v1-5-512-muffin",
|
source="txt2img-sd-v1-5-512-muffin-0",
|
||||||
mask="mask-black",
|
mask="mask-black",
|
||||||
|
max_attempts=SLOW_TEST,
|
||||||
mse_threshold=0.010,
|
mse_threshold=0.010,
|
||||||
),
|
),
|
||||||
TestCase(
|
TestCase(
|
||||||
|
@ -140,27 +145,28 @@ TEST_DATA = [
|
||||||
"inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting&noise=fill-mask"
|
"inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting&noise=fill-mask"
|
||||||
"&top=0&bottom=0&left=512&right=512"
|
"&top=0&bottom=0&left=512&right=512"
|
||||||
),
|
),
|
||||||
source="txt2img-sd-v1-5-512-muffin",
|
source="txt2img-sd-v1-5-512-muffin-0",
|
||||||
mask="mask-black",
|
mask="mask-black",
|
||||||
|
max_attempts=SLOW_TEST,
|
||||||
mse_threshold=0.010,
|
mse_threshold=0.010,
|
||||||
),
|
),
|
||||||
TestCase(
|
TestCase(
|
||||||
"upscale-resrgan-x2-1024-muffin",
|
"upscale-resrgan-x2-1024-muffin",
|
||||||
"upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x2-plus&scale=2&outscale=2",
|
"upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x2-plus&scale=2&outscale=2",
|
||||||
source="txt2img-sd-v1-5-512-muffin",
|
source="txt2img-sd-v1-5-512-muffin-0",
|
||||||
),
|
),
|
||||||
TestCase(
|
TestCase(
|
||||||
"upscale-resrgan-x4-2048-muffin",
|
"upscale-resrgan-x4-2048-muffin",
|
||||||
"upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x4-plus&scale=4&outscale=4",
|
"upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x4-plus&scale=4&outscale=4",
|
||||||
source="txt2img-sd-v1-5-512-muffin",
|
source="txt2img-sd-v1-5-512-muffin-0",
|
||||||
),
|
),
|
||||||
TestCase(
|
TestCase(
|
||||||
"blend-512-muffin-black",
|
"blend-512-muffin-black",
|
||||||
"upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x2-plus&scale=2&outscale=2",
|
"upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x2-plus&scale=2&outscale=2",
|
||||||
mask="mask-black",
|
mask="mask-black",
|
||||||
source=[
|
source=[
|
||||||
"txt2img-sd-v1-5-512-muffin",
|
"txt2img-sd-v1-5-512-muffin-0",
|
||||||
"txt2img-sd-v2-1-512-muffin",
|
"txt2img-sd-v2-1-512-muffin-0",
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
TestCase(
|
TestCase(
|
||||||
|
@ -168,14 +174,14 @@ TEST_DATA = [
|
||||||
"upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x2-plus&scale=2&outscale=2",
|
"upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x2-plus&scale=2&outscale=2",
|
||||||
mask="mask-white",
|
mask="mask-white",
|
||||||
source=[
|
source=[
|
||||||
"txt2img-sd-v2-1-512-muffin",
|
"txt2img-sd-v2-1-512-muffin-0",
|
||||||
"txt2img-sd-v1-5-512-muffin",
|
"txt2img-sd-v1-5-512-muffin-0",
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def generate_image(root: str, test: TestCase) -> Optional[str]:
|
def generate_images(root: str, test: TestCase) -> Optional[str]:
|
||||||
files = {}
|
files = {}
|
||||||
if test.source is not None:
|
if test.source is not None:
|
||||||
if isinstance(test.source, list):
|
if isinstance(test.source, list):
|
||||||
|
@ -211,9 +217,9 @@ def generate_image(root: str, test: TestCase) -> Optional[str]:
|
||||||
resp = requests.post(f"{root}/api/{test.query}", files=files)
|
resp = requests.post(f"{root}/api/{test.query}", files=files)
|
||||||
if resp.status_code == 200:
|
if resp.status_code == 200:
|
||||||
json = resp.json()
|
json = resp.json()
|
||||||
return json.get("output")
|
return json.get("outputs")
|
||||||
else:
|
else:
|
||||||
logger.warning("request failed: %s", resp.status_code)
|
logger.warning("request failed: %s: %s", resp.status_code, resp.text)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@ -227,14 +233,17 @@ def check_ready(root: str, key: str) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def download_image(root: str, key: str) -> Image.Image:
|
def download_images(root: str, keys: List[str]) -> List[Image.Image]:
|
||||||
resp = requests.get(f"{root}/output/{key}")
|
images = []
|
||||||
if resp.status_code == 200:
|
for key in keys:
|
||||||
logger.debug("downloading image: %s", key)
|
resp = requests.get(f"{root}/output/{key}")
|
||||||
return Image.open(BytesIO(resp.content))
|
if resp.status_code == 200:
|
||||||
else:
|
logger.debug("downloading image: %s", key)
|
||||||
logger.warning("request failed: %s", resp.status_code)
|
images.append(Image.open(BytesIO(resp.content)))
|
||||||
return None
|
else:
|
||||||
|
logger.warning("request failed: %s", resp.status_code)
|
||||||
|
|
||||||
|
return images
|
||||||
|
|
||||||
|
|
||||||
def find_mse(result: Image.Image, ref: Image.Image) -> float:
|
def find_mse(result: Image.Image, ref: Image.Image) -> float:
|
||||||
|
@ -259,20 +268,19 @@ def find_mse(result: Image.Image, ref: Image.Image) -> float:
|
||||||
def run_test(
|
def run_test(
|
||||||
root: str,
|
root: str,
|
||||||
test: TestCase,
|
test: TestCase,
|
||||||
ref: Image.Image,
|
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Generate an image, wait for it to be ready, and calculate the MSE from the reference.
|
Generate an image, wait for it to be ready, and calculate the MSE from the reference.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
key = generate_image(root, test)
|
keys = generate_images(root, test)
|
||||||
if key is None:
|
if keys is None:
|
||||||
raise ValueError("could not generate")
|
raise ValueError("could not generate")
|
||||||
|
|
||||||
attempts = 0
|
attempts = 0
|
||||||
while attempts < test.max_attempts:
|
while attempts < test.max_attempts:
|
||||||
if check_ready(root, key):
|
if check_ready(root, keys[0]):
|
||||||
logger.debug("image is ready: %s", key)
|
logger.debug("image is ready: %s", keys)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
logger.debug("waiting for image to be ready")
|
logger.debug("waiting for image to be ready")
|
||||||
|
@ -282,16 +290,25 @@ def run_test(
|
||||||
if attempts == test.max_attempts:
|
if attempts == test.max_attempts:
|
||||||
raise ValueError("image was not ready in time")
|
raise ValueError("image was not ready in time")
|
||||||
|
|
||||||
result = download_image(root, key)
|
results = download_images(root, keys)
|
||||||
result.save(test_path(path.join("test-results", f"{test.name}.png")))
|
|
||||||
mse = find_mse(result, ref)
|
|
||||||
|
|
||||||
if mse < test.mse_threshold:
|
passed = True
|
||||||
logger.info("MSE within threshold: %.4f < %.4f", mse, test.mse_threshold)
|
for i in range(len(results)):
|
||||||
return True
|
result = results[i]
|
||||||
else:
|
result.save(test_path(path.join("test-results", f"{test.name}-{i}.png")))
|
||||||
logger.warning("MSE above threshold: %.4f > %.4f", mse, test.mse_threshold)
|
|
||||||
return False
|
ref_name = test_path(path.join("test-refs", f"{test.name}-{i}.png"))
|
||||||
|
ref = Image.open(ref_name) if path.exists(ref_name) else None
|
||||||
|
|
||||||
|
mse = find_mse(result, ref)
|
||||||
|
|
||||||
|
if mse < test.mse_threshold:
|
||||||
|
logger.info("MSE within threshold: %.4f < %.4f", mse, test.mse_threshold)
|
||||||
|
else:
|
||||||
|
logger.warning("MSE above threshold: %.4f > %.4f", mse, test.mse_threshold)
|
||||||
|
passed = False
|
||||||
|
|
||||||
|
return passed
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
@ -303,9 +320,7 @@ def main():
|
||||||
for test in TEST_DATA:
|
for test in TEST_DATA:
|
||||||
try:
|
try:
|
||||||
logger.info("starting test: %s", test.name)
|
logger.info("starting test: %s", test.name)
|
||||||
ref_name = test_path(path.join("test-refs", f"{test.name}.png"))
|
if run_test(root, test):
|
||||||
ref = Image.open(ref_name) if path.exists(ref_name) else None
|
|
||||||
if run_test(root, test, ref):
|
|
||||||
logger.info("test passed: %s", test.name)
|
logger.info("test passed: %s", test.name)
|
||||||
passed.append(test.name)
|
passed.append(test.name)
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in New Issue