fix array sizing, update release tests
This commit is contained in:
parent
1a8d538bfe
commit
3e5a95548b
|
@ -333,8 +333,8 @@ class StageResult:
|
|||
)
|
||||
elif self.arrays is not None:
|
||||
return Size(
|
||||
max([image.height for image in self.images], default=0),
|
||||
max([image.height for image in self.images], default=0),
|
||||
max([array.shape[0] for array in self.arrays], default=0),
|
||||
max([array.shape[1] for array in self.arrays], default=0),
|
||||
) # TODO: which fields within the shape are width/height?
|
||||
else:
|
||||
return Size(0, 0)
|
||||
|
|
|
@ -607,9 +607,9 @@ def ready(server: ServerContext, pool: DevicePoolExecutor):
|
|||
return error_reply("output name is required")
|
||||
|
||||
output_file = sanitize_name(output_file)
|
||||
pending, progress = pool.done(output_file)
|
||||
status, progress = pool.status(output_file)
|
||||
|
||||
if pending:
|
||||
if status == JobStatus.PENDING:
|
||||
return ready_reply(pending=True)
|
||||
|
||||
if progress is None:
|
||||
|
@ -623,10 +623,10 @@ def ready(server: ServerContext, pool: DevicePoolExecutor):
|
|||
) # is a missing image really an error? yes will display the retry button
|
||||
|
||||
return ready_reply(
|
||||
ready=progress.finished,
|
||||
progress=progress.progress,
|
||||
failed=progress.failed,
|
||||
cancelled=progress.cancelled,
|
||||
ready=(status == JobStatus.SUCCESS),
|
||||
progress=progress.steps.current,
|
||||
failed=(status == JobStatus.FAILED),
|
||||
cancelled=(status == JobStatus.CANCELLED),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -463,7 +463,7 @@ def generate_images(host: str, test: TestCase) -> Optional[str]:
|
|||
resp = requests.post(f"{host}/api/{test.query}", files=files)
|
||||
if resp.status_code == 200:
|
||||
json = resp.json()
|
||||
return json.get("outputs")
|
||||
return json.get("name")
|
||||
else:
|
||||
logger.warning("generate request failed: %s: %s", resp.status_code, resp.text)
|
||||
raise TestError("error generating image")
|
||||
|
@ -484,10 +484,21 @@ def check_ready(host: str, key: str) -> bool:
|
|||
logger.warning("ready request failed: %s", resp.status_code)
|
||||
raise TestError("error getting image status")
|
||||
|
||||
def check_outputs(host: str, key: str) -> List[str]:
|
||||
resp = requests.get(f"{host}/api/ready?output={key}")
|
||||
if resp.status_code == 200:
|
||||
json = resp.json()
|
||||
outputs = json.get("outputs", [])
|
||||
return outputs
|
||||
|
||||
logger.warning("getting outputs failed: %s: %s", resp.status_code, resp.text)
|
||||
raise TestError("error getting image outputs")
|
||||
|
||||
def download_images(host: str, key: str) -> List[Image.Image]:
|
||||
outputs = check_outputs(host, key)
|
||||
|
||||
def download_images(host: str, keys: List[str]) -> List[Image.Image]:
|
||||
images = []
|
||||
for key in keys:
|
||||
for key in outputs:
|
||||
resp = requests.get(f"{host}/output/{key}")
|
||||
if resp.status_code == 200:
|
||||
logger.debug("downloading image: %s", key)
|
||||
|
@ -528,14 +539,14 @@ def run_test(
|
|||
Generate an image, wait for it to be ready, and calculate the MSE from the reference.
|
||||
"""
|
||||
|
||||
keys = generate_images(host, test)
|
||||
if keys is None:
|
||||
job = generate_images(host, test)
|
||||
if job is None:
|
||||
return TestResult.failed(test.name, "could not generate image")
|
||||
|
||||
ready = False
|
||||
for attempt in tqdm(range(test.max_attempts * time_mult)):
|
||||
if check_ready(host, keys[0]):
|
||||
logger.debug("image is ready: %s", keys)
|
||||
if check_ready(host, job):
|
||||
logger.debug("image is ready: %s", job)
|
||||
ready = True
|
||||
break
|
||||
else:
|
||||
|
@ -545,7 +556,7 @@ def run_test(
|
|||
if not ready:
|
||||
return TestResult.failed(test.name, "image was not ready in time")
|
||||
|
||||
results = download_images(host, keys)
|
||||
results = download_images(host, job)
|
||||
if results is None or len(results) == 0:
|
||||
return TestResult.failed(test.name, "could not download image")
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ from os import path
|
|||
from typing import List
|
||||
from unittest import skipUnless
|
||||
|
||||
from onnx_web.params import DeviceParams
|
||||
from onnx_web.params import DeviceParams, ImageParams, Size
|
||||
from onnx_web.worker.context import WorkerContext
|
||||
|
||||
|
||||
|
@ -23,6 +23,14 @@ def test_device() -> DeviceParams:
|
|||
return DeviceParams("cpu", "CPUExecutionProvider")
|
||||
|
||||
|
||||
def test_size() -> Size:
|
||||
return Size(64, 64)
|
||||
|
||||
|
||||
def test_params() -> ImageParams:
|
||||
return ImageParams("test", "txt2img", "ddim", "test", 5.0, 25, 0)
|
||||
|
||||
|
||||
def test_worker() -> WorkerContext:
|
||||
cancel = Value("L", 0)
|
||||
logs = Queue()
|
||||
|
|
Loading…
Reference in New Issue