fix(scripts): update release tests with support for batches
This commit is contained in:
parent
cb8e9e7080
commit
6809d2da82
|
@ -25,6 +25,9 @@ except Exception as err:
|
|||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
FAST_TEST = 20
|
||||
SLOW_TEST = 50
|
||||
|
||||
|
||||
def test_root() -> str:
|
||||
if len(sys.argv) > 1:
|
||||
|
@ -42,7 +45,7 @@ class TestCase:
|
|||
self,
|
||||
name: str,
|
||||
query: str,
|
||||
max_attempts: int = 20,
|
||||
max_attempts: int = FAST_TEST,
|
||||
mse_threshold: float = 0.001,
|
||||
source: Union[Image.Image, List[Image.Image]] = None,
|
||||
mask: Image.Image = None,
|
||||
|
@ -95,23 +98,23 @@ TEST_DATA = [
|
|||
TestCase(
|
||||
"img2img-sd-v1-5-512-pumpkin",
|
||||
"img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim",
|
||||
source="txt2img-sd-v1-5-512-muffin",
|
||||
source="txt2img-sd-v1-5-512-muffin-0",
|
||||
),
|
||||
TestCase(
|
||||
"img2img-sd-v1-5-256-pumpkin",
|
||||
"img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim",
|
||||
source="txt2img-sd-v1-5-256-muffin",
|
||||
source="txt2img-sd-v1-5-256-muffin-0",
|
||||
),
|
||||
TestCase(
|
||||
"inpaint-v1-512-white",
|
||||
"inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting",
|
||||
source="txt2img-sd-v1-5-512-muffin",
|
||||
source="txt2img-sd-v1-5-512-muffin-0",
|
||||
mask="mask-white",
|
||||
),
|
||||
TestCase(
|
||||
"inpaint-v1-512-black",
|
||||
"inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting",
|
||||
source="txt2img-sd-v1-5-512-muffin",
|
||||
source="txt2img-sd-v1-5-512-muffin-0",
|
||||
mask="mask-black",
|
||||
),
|
||||
TestCase(
|
||||
|
@ -120,8 +123,9 @@ TEST_DATA = [
|
|||
"inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting&noise=fill-mask"
|
||||
"&top=256&bottom=256&left=256&right=256"
|
||||
),
|
||||
source="txt2img-sd-v1-5-512-muffin",
|
||||
source="txt2img-sd-v1-5-512-muffin-0",
|
||||
mask="mask-black",
|
||||
max_attempts=SLOW_TEST,
|
||||
mse_threshold=0.025,
|
||||
),
|
||||
TestCase(
|
||||
|
@ -130,8 +134,9 @@ TEST_DATA = [
|
|||
"inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting&noise=fill-mask"
|
||||
"&top=512&bottom=512&left=0&right=0"
|
||||
),
|
||||
source="txt2img-sd-v1-5-512-muffin",
|
||||
source="txt2img-sd-v1-5-512-muffin-0",
|
||||
mask="mask-black",
|
||||
max_attempts=SLOW_TEST,
|
||||
mse_threshold=0.010,
|
||||
),
|
||||
TestCase(
|
||||
|
@ -140,27 +145,28 @@ TEST_DATA = [
|
|||
"inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting&noise=fill-mask"
|
||||
"&top=0&bottom=0&left=512&right=512"
|
||||
),
|
||||
source="txt2img-sd-v1-5-512-muffin",
|
||||
source="txt2img-sd-v1-5-512-muffin-0",
|
||||
mask="mask-black",
|
||||
max_attempts=SLOW_TEST,
|
||||
mse_threshold=0.010,
|
||||
),
|
||||
TestCase(
|
||||
"upscale-resrgan-x2-1024-muffin",
|
||||
"upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x2-plus&scale=2&outscale=2",
|
||||
source="txt2img-sd-v1-5-512-muffin",
|
||||
source="txt2img-sd-v1-5-512-muffin-0",
|
||||
),
|
||||
TestCase(
|
||||
"upscale-resrgan-x4-2048-muffin",
|
||||
"upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x4-plus&scale=4&outscale=4",
|
||||
source="txt2img-sd-v1-5-512-muffin",
|
||||
source="txt2img-sd-v1-5-512-muffin-0",
|
||||
),
|
||||
TestCase(
|
||||
"blend-512-muffin-black",
|
||||
"upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x2-plus&scale=2&outscale=2",
|
||||
mask="mask-black",
|
||||
source=[
|
||||
"txt2img-sd-v1-5-512-muffin",
|
||||
"txt2img-sd-v2-1-512-muffin",
|
||||
"txt2img-sd-v1-5-512-muffin-0",
|
||||
"txt2img-sd-v2-1-512-muffin-0",
|
||||
],
|
||||
),
|
||||
TestCase(
|
||||
|
@ -168,14 +174,14 @@ TEST_DATA = [
|
|||
"upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x2-plus&scale=2&outscale=2",
|
||||
mask="mask-white",
|
||||
source=[
|
||||
"txt2img-sd-v2-1-512-muffin",
|
||||
"txt2img-sd-v1-5-512-muffin",
|
||||
"txt2img-sd-v2-1-512-muffin-0",
|
||||
"txt2img-sd-v1-5-512-muffin-0",
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def generate_image(root: str, test: TestCase) -> Optional[str]:
|
||||
def generate_images(root: str, test: TestCase) -> Optional[str]:
|
||||
files = {}
|
||||
if test.source is not None:
|
||||
if isinstance(test.source, list):
|
||||
|
@ -211,9 +217,9 @@ def generate_image(root: str, test: TestCase) -> Optional[str]:
|
|||
resp = requests.post(f"{root}/api/{test.query}", files=files)
|
||||
if resp.status_code == 200:
|
||||
json = resp.json()
|
||||
return json.get("output")
|
||||
return json.get("outputs")
|
||||
else:
|
||||
logger.warning("request failed: %s", resp.status_code)
|
||||
logger.warning("request failed: %s: %s", resp.status_code, resp.text)
|
||||
return None
|
||||
|
||||
|
||||
|
@ -227,14 +233,17 @@ def check_ready(root: str, key: str) -> bool:
|
|||
return False
|
||||
|
||||
|
||||
def download_image(root: str, key: str) -> Image.Image:
|
||||
def download_images(root: str, keys: List[str]) -> List[Image.Image]:
|
||||
images = []
|
||||
for key in keys:
|
||||
resp = requests.get(f"{root}/output/{key}")
|
||||
if resp.status_code == 200:
|
||||
logger.debug("downloading image: %s", key)
|
||||
return Image.open(BytesIO(resp.content))
|
||||
images.append(Image.open(BytesIO(resp.content)))
|
||||
else:
|
||||
logger.warning("request failed: %s", resp.status_code)
|
||||
return None
|
||||
|
||||
return images
|
||||
|
||||
|
||||
def find_mse(result: Image.Image, ref: Image.Image) -> float:
|
||||
|
@ -259,20 +268,19 @@ def find_mse(result: Image.Image, ref: Image.Image) -> float:
|
|||
def run_test(
|
||||
root: str,
|
||||
test: TestCase,
|
||||
ref: Image.Image,
|
||||
) -> bool:
|
||||
"""
|
||||
Generate an image, wait for it to be ready, and calculate the MSE from the reference.
|
||||
"""
|
||||
|
||||
key = generate_image(root, test)
|
||||
if key is None:
|
||||
keys = generate_images(root, test)
|
||||
if keys is None:
|
||||
raise ValueError("could not generate")
|
||||
|
||||
attempts = 0
|
||||
while attempts < test.max_attempts:
|
||||
if check_ready(root, key):
|
||||
logger.debug("image is ready: %s", key)
|
||||
if check_ready(root, keys[0]):
|
||||
logger.debug("image is ready: %s", keys)
|
||||
break
|
||||
else:
|
||||
logger.debug("waiting for image to be ready")
|
||||
|
@ -282,16 +290,25 @@ def run_test(
|
|||
if attempts == test.max_attempts:
|
||||
raise ValueError("image was not ready in time")
|
||||
|
||||
result = download_image(root, key)
|
||||
result.save(test_path(path.join("test-results", f"{test.name}.png")))
|
||||
results = download_images(root, keys)
|
||||
|
||||
passed = True
|
||||
for i in range(len(results)):
|
||||
result = results[i]
|
||||
result.save(test_path(path.join("test-results", f"{test.name}-{i}.png")))
|
||||
|
||||
ref_name = test_path(path.join("test-refs", f"{test.name}-{i}.png"))
|
||||
ref = Image.open(ref_name) if path.exists(ref_name) else None
|
||||
|
||||
mse = find_mse(result, ref)
|
||||
|
||||
if mse < test.mse_threshold:
|
||||
logger.info("MSE within threshold: %.4f < %.4f", mse, test.mse_threshold)
|
||||
return True
|
||||
else:
|
||||
logger.warning("MSE above threshold: %.4f > %.4f", mse, test.mse_threshold)
|
||||
return False
|
||||
passed = False
|
||||
|
||||
return passed
|
||||
|
||||
|
||||
def main():
|
||||
|
@ -303,9 +320,7 @@ def main():
|
|||
for test in TEST_DATA:
|
||||
try:
|
||||
logger.info("starting test: %s", test.name)
|
||||
ref_name = test_path(path.join("test-refs", f"{test.name}.png"))
|
||||
ref = Image.open(ref_name) if path.exists(ref_name) else None
|
||||
if run_test(root, test, ref):
|
||||
if run_test(root, test):
|
||||
logger.info("test passed: %s", test.name)
|
||||
passed.append(test.name)
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue