1
0
Fork 0

fix array sizing, update release tests

This commit is contained in:
Sean Sube 2024-01-06 02:33:01 -06:00
parent 1a8d538bfe
commit 3e5a95548b
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 36 additions and 17 deletions

View File

@ -333,8 +333,8 @@ class StageResult:
)
elif self.arrays is not None:
return Size(
max([image.height for image in self.images], default=0),
max([image.height for image in self.images], default=0),
max([array.shape[0] for array in self.arrays], default=0),
max([array.shape[1] for array in self.arrays], default=0),
) # TODO: which fields within the shape are width/height?
else:
return Size(0, 0)

View File

@ -607,9 +607,9 @@ def ready(server: ServerContext, pool: DevicePoolExecutor):
return error_reply("output name is required")
output_file = sanitize_name(output_file)
pending, progress = pool.done(output_file)
status, progress = pool.status(output_file)
if pending:
if status == JobStatus.PENDING:
return ready_reply(pending=True)
if progress is None:
@ -623,10 +623,10 @@ def ready(server: ServerContext, pool: DevicePoolExecutor):
) # is a missing image really an error? yes will display the retry button
return ready_reply(
ready=progress.finished,
progress=progress.progress,
failed=progress.failed,
cancelled=progress.cancelled,
ready=(status == JobStatus.SUCCESS),
progress=progress.steps.current,
failed=(status == JobStatus.FAILED),
cancelled=(status == JobStatus.CANCELLED),
)

View File

@ -463,7 +463,7 @@ def generate_images(host: str, test: TestCase) -> Optional[str]:
resp = requests.post(f"{host}/api/{test.query}", files=files)
if resp.status_code == 200:
json = resp.json()
return json.get("outputs")
return json.get("name")
else:
logger.warning("generate request failed: %s: %s", resp.status_code, resp.text)
raise TestError("error generating image")
@ -484,10 +484,21 @@ def check_ready(host: str, key: str) -> bool:
logger.warning("ready request failed: %s", resp.status_code)
raise TestError("error getting image status")
def check_outputs(host: str, key: str) -> List[str]:
resp = requests.get(f"{host}/api/ready?output={key}")
if resp.status_code == 200:
json = resp.json()
outputs = json.get("outputs", [])
return outputs
logger.warning("getting outputs failed: %s: %s", resp.status_code, resp.text)
raise TestError("error getting image outputs")
def download_images(host: str, key: str) -> List[Image.Image]:
outputs = check_outputs(host, key)
def download_images(host: str, keys: List[str]) -> List[Image.Image]:
images = []
for key in keys:
for key in outputs:
resp = requests.get(f"{host}/output/{key}")
if resp.status_code == 200:
logger.debug("downloading image: %s", key)
@ -528,14 +539,14 @@ def run_test(
Generate an image, wait for it to be ready, and calculate the MSE from the reference.
"""
keys = generate_images(host, test)
if keys is None:
job = generate_images(host, test)
if job is None:
return TestResult.failed(test.name, "could not generate image")
ready = False
for attempt in tqdm(range(test.max_attempts * time_mult)):
if check_ready(host, keys[0]):
logger.debug("image is ready: %s", keys)
if check_ready(host, job):
logger.debug("image is ready: %s", job)
ready = True
break
else:
@ -545,7 +556,7 @@ def run_test(
if not ready:
return TestResult.failed(test.name, "image was not ready in time")
results = download_images(host, keys)
results = download_images(host, job)
if results is None or len(results) == 0:
return TestResult.failed(test.name, "could not download image")

View File

@ -3,7 +3,7 @@ from os import path
from typing import List
from unittest import skipUnless
from onnx_web.params import DeviceParams
from onnx_web.params import DeviceParams, ImageParams, Size
from onnx_web.worker.context import WorkerContext
@ -23,6 +23,14 @@ def test_device() -> DeviceParams:
return DeviceParams("cpu", "CPUExecutionProvider")
def test_size() -> Size:
return Size(64, 64)
def test_params() -> ImageParams:
return ImageParams("test", "txt2img", "ddim", "test", 5.0, 25, 0)
def test_worker() -> WorkerContext:
cancel = Value("L", 0)
logs = Queue()