diff --git a/api/onnx_web/chain/result.py b/api/onnx_web/chain/result.py index 8c94d798..57f8619e 100644 --- a/api/onnx_web/chain/result.py +++ b/api/onnx_web/chain/result.py @@ -333,8 +333,8 @@ class StageResult: ) elif self.arrays is not None: return Size( - max([image.height for image in self.images], default=0), - max([image.height for image in self.images], default=0), + max([array.shape[0] for array in self.arrays], default=0), + max([array.shape[1] for array in self.arrays], default=0), ) # TODO: which fields within the shape are width/height? else: return Size(0, 0) diff --git a/api/onnx_web/server/api.py b/api/onnx_web/server/api.py index 26777ce6..fc52f0f7 100644 --- a/api/onnx_web/server/api.py +++ b/api/onnx_web/server/api.py @@ -607,9 +607,9 @@ def ready(server: ServerContext, pool: DevicePoolExecutor): return error_reply("output name is required") output_file = sanitize_name(output_file) - pending, progress = pool.done(output_file) + status, progress = pool.status(output_file) - if pending: + if status == JobStatus.PENDING: return ready_reply(pending=True) if progress is None: @@ -623,10 +623,10 @@ def ready(server: ServerContext, pool: DevicePoolExecutor): ) # is a missing image really an error? yes will display the retry button return ready_reply( - ready=progress.finished, - progress=progress.progress, - failed=progress.failed, - cancelled=progress.cancelled, + ready=(status == JobStatus.SUCCESS), + progress=progress.steps.current, + failed=(status == JobStatus.FAILED), + cancelled=(status == JobStatus.CANCELLED), ) diff --git a/api/scripts/test-release.py b/api/scripts/test-release.py index 1af512c5..c281cc10 100644 --- a/api/scripts/test-release.py +++ b/api/scripts/test-release.py @@ -463,7 +463,7 @@ def generate_images(host: str, test: TestCase) -> Optional[str]: resp = requests.post(f"{host}/api/{test.query}", files=files) if resp.status_code == 200: json = resp.json() - return json.get("outputs") + return json.get("name") else: logger.warning("generate request failed: %s: %s", resp.status_code, resp.text) raise TestError("error generating image") @@ -484,10 +484,21 @@ def check_ready(host: str, key: str) -> bool: logger.warning("ready request failed: %s", resp.status_code) raise TestError("error getting image status") +def check_outputs(host: str, key: str) -> List[str]: + resp = requests.get(f"{host}/api/ready?output={key}") + if resp.status_code == 200: + json = resp.json() + outputs = json.get("outputs", []) + return outputs + + logger.warning("getting outputs failed: %s: %s", resp.status_code, resp.text) + raise TestError("error getting image outputs") + +def download_images(host: str, key: str) -> List[Image.Image]: + outputs = check_outputs(host, key) -def download_images(host: str, keys: List[str]) -> List[Image.Image]: images = [] - for key in keys: + for key in outputs: resp = requests.get(f"{host}/output/{key}") if resp.status_code == 200: logger.debug("downloading image: %s", key) @@ -528,14 +539,14 @@ def run_test( Generate an image, wait for it to be ready, and calculate the MSE from the reference. """ - keys = generate_images(host, test) - if keys is None: + job = generate_images(host, test) + if job is None: return TestResult.failed(test.name, "could not generate image") ready = False for attempt in tqdm(range(test.max_attempts * time_mult)): - if check_ready(host, keys[0]): - logger.debug("image is ready: %s", keys) + if check_ready(host, job): + logger.debug("image is ready: %s", job) ready = True break else: @@ -545,7 +556,7 @@ def run_test( if not ready: return TestResult.failed(test.name, "image was not ready in time") - results = download_images(host, keys) + results = download_images(host, job) if results is None or len(results) == 0: return TestResult.failed(test.name, "could not download image") diff --git a/api/tests/helpers.py b/api/tests/helpers.py index f0c10edd..4803e109 100644 --- a/api/tests/helpers.py +++ b/api/tests/helpers.py @@ -3,7 +3,7 @@ from os import path from typing import List from unittest import skipUnless -from onnx_web.params import DeviceParams +from onnx_web.params import DeviceParams, ImageParams, Size from onnx_web.worker.context import WorkerContext @@ -23,6 +23,14 @@ def test_device() -> DeviceParams: return DeviceParams("cpu", "CPUExecutionProvider") +def test_size() -> Size: + return Size(64, 64) + + +def test_params() -> ImageParams: + return ImageParams("test", "txt2img", "ddim", "test", 5.0, 25, 0) + + def test_worker() -> WorkerContext: cancel = Value("L", 0) logs = Queue()