1
0
Fork 0

fix array sizing, update release tests

This commit is contained in:
Sean Sube 2024-01-06 02:33:01 -06:00
parent 1a8d538bfe
commit 3e5a95548b
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 36 additions and 17 deletions

View File

@ -333,8 +333,8 @@ class StageResult:
) )
elif self.arrays is not None: elif self.arrays is not None:
return Size( return Size(
max([image.height for image in self.images], default=0), max([array.shape[0] for array in self.arrays], default=0),
max([image.height for image in self.images], default=0), max([array.shape[1] for array in self.arrays], default=0),
) # TODO: which fields within the shape are width/height? ) # TODO: which fields within the shape are width/height?
else: else:
return Size(0, 0) return Size(0, 0)

View File

@ -607,9 +607,9 @@ def ready(server: ServerContext, pool: DevicePoolExecutor):
return error_reply("output name is required") return error_reply("output name is required")
output_file = sanitize_name(output_file) output_file = sanitize_name(output_file)
pending, progress = pool.done(output_file) status, progress = pool.status(output_file)
if pending: if status == JobStatus.PENDING:
return ready_reply(pending=True) return ready_reply(pending=True)
if progress is None: if progress is None:
@ -623,10 +623,10 @@ def ready(server: ServerContext, pool: DevicePoolExecutor):
) # is a missing image really an error? yes will display the retry button ) # is a missing image really an error? yes will display the retry button
return ready_reply( return ready_reply(
ready=progress.finished, ready=(status == JobStatus.SUCCESS),
progress=progress.progress, progress=progress.steps.current,
failed=progress.failed, failed=(status == JobStatus.FAILED),
cancelled=progress.cancelled, cancelled=(status == JobStatus.CANCELLED),
) )

View File

@ -463,7 +463,7 @@ def generate_images(host: str, test: TestCase) -> Optional[str]:
resp = requests.post(f"{host}/api/{test.query}", files=files) resp = requests.post(f"{host}/api/{test.query}", files=files)
if resp.status_code == 200: if resp.status_code == 200:
json = resp.json() json = resp.json()
return json.get("outputs") return json.get("name")
else: else:
logger.warning("generate request failed: %s: %s", resp.status_code, resp.text) logger.warning("generate request failed: %s: %s", resp.status_code, resp.text)
raise TestError("error generating image") raise TestError("error generating image")
@ -484,10 +484,21 @@ def check_ready(host: str, key: str) -> bool:
logger.warning("ready request failed: %s", resp.status_code) logger.warning("ready request failed: %s", resp.status_code)
raise TestError("error getting image status") raise TestError("error getting image status")
def check_outputs(host: str, key: str) -> List[str]:
resp = requests.get(f"{host}/api/ready?output={key}")
if resp.status_code == 200:
json = resp.json()
outputs = json.get("outputs", [])
return outputs
logger.warning("getting outputs failed: %s: %s", resp.status_code, resp.text)
raise TestError("error getting image outputs")
def download_images(host: str, key: str) -> List[Image.Image]:
outputs = check_outputs(host, key)
def download_images(host: str, keys: List[str]) -> List[Image.Image]:
images = [] images = []
for key in keys: for key in outputs:
resp = requests.get(f"{host}/output/{key}") resp = requests.get(f"{host}/output/{key}")
if resp.status_code == 200: if resp.status_code == 200:
logger.debug("downloading image: %s", key) logger.debug("downloading image: %s", key)
@ -528,14 +539,14 @@ def run_test(
Generate an image, wait for it to be ready, and calculate the MSE from the reference. Generate an image, wait for it to be ready, and calculate the MSE from the reference.
""" """
keys = generate_images(host, test) job = generate_images(host, test)
if keys is None: if job is None:
return TestResult.failed(test.name, "could not generate image") return TestResult.failed(test.name, "could not generate image")
ready = False ready = False
for attempt in tqdm(range(test.max_attempts * time_mult)): for attempt in tqdm(range(test.max_attempts * time_mult)):
if check_ready(host, keys[0]): if check_ready(host, job):
logger.debug("image is ready: %s", keys) logger.debug("image is ready: %s", job)
ready = True ready = True
break break
else: else:
@ -545,7 +556,7 @@ def run_test(
if not ready: if not ready:
return TestResult.failed(test.name, "image was not ready in time") return TestResult.failed(test.name, "image was not ready in time")
results = download_images(host, keys) results = download_images(host, job)
if results is None or len(results) == 0: if results is None or len(results) == 0:
return TestResult.failed(test.name, "could not download image") return TestResult.failed(test.name, "could not download image")

View File

@ -3,7 +3,7 @@ from os import path
from typing import List from typing import List
from unittest import skipUnless from unittest import skipUnless
from onnx_web.params import DeviceParams from onnx_web.params import DeviceParams, ImageParams, Size
from onnx_web.worker.context import WorkerContext from onnx_web.worker.context import WorkerContext
@ -23,6 +23,14 @@ def test_device() -> DeviceParams:
return DeviceParams("cpu", "CPUExecutionProvider") return DeviceParams("cpu", "CPUExecutionProvider")
def test_size() -> Size:
return Size(64, 64)
def test_params() -> ImageParams:
return ImageParams("test", "txt2img", "ddim", "test", 5.0, 25, 0)
def test_worker() -> WorkerContext: def test_worker() -> WorkerContext:
cancel = Value("L", 0) cancel = Value("L", 0)
logs = Queue() logs = Queue()