fix array sizing, update release tests
This commit is contained in:
parent
1a8d538bfe
commit
3e5a95548b
|
@ -333,8 +333,8 @@ class StageResult:
|
||||||
)
|
)
|
||||||
elif self.arrays is not None:
|
elif self.arrays is not None:
|
||||||
return Size(
|
return Size(
|
||||||
max([image.height for image in self.images], default=0),
|
max([array.shape[0] for array in self.arrays], default=0),
|
||||||
max([image.height for image in self.images], default=0),
|
max([array.shape[1] for array in self.arrays], default=0),
|
||||||
) # TODO: which fields within the shape are width/height?
|
) # TODO: which fields within the shape are width/height?
|
||||||
else:
|
else:
|
||||||
return Size(0, 0)
|
return Size(0, 0)
|
||||||
|
|
|
@ -607,9 +607,9 @@ def ready(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
return error_reply("output name is required")
|
return error_reply("output name is required")
|
||||||
|
|
||||||
output_file = sanitize_name(output_file)
|
output_file = sanitize_name(output_file)
|
||||||
pending, progress = pool.done(output_file)
|
status, progress = pool.status(output_file)
|
||||||
|
|
||||||
if pending:
|
if status == JobStatus.PENDING:
|
||||||
return ready_reply(pending=True)
|
return ready_reply(pending=True)
|
||||||
|
|
||||||
if progress is None:
|
if progress is None:
|
||||||
|
@ -623,10 +623,10 @@ def ready(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
) # is a missing image really an error? yes will display the retry button
|
) # is a missing image really an error? yes will display the retry button
|
||||||
|
|
||||||
return ready_reply(
|
return ready_reply(
|
||||||
ready=progress.finished,
|
ready=(status == JobStatus.SUCCESS),
|
||||||
progress=progress.progress,
|
progress=progress.steps.current,
|
||||||
failed=progress.failed,
|
failed=(status == JobStatus.FAILED),
|
||||||
cancelled=progress.cancelled,
|
cancelled=(status == JobStatus.CANCELLED),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -463,7 +463,7 @@ def generate_images(host: str, test: TestCase) -> Optional[str]:
|
||||||
resp = requests.post(f"{host}/api/{test.query}", files=files)
|
resp = requests.post(f"{host}/api/{test.query}", files=files)
|
||||||
if resp.status_code == 200:
|
if resp.status_code == 200:
|
||||||
json = resp.json()
|
json = resp.json()
|
||||||
return json.get("outputs")
|
return json.get("name")
|
||||||
else:
|
else:
|
||||||
logger.warning("generate request failed: %s: %s", resp.status_code, resp.text)
|
logger.warning("generate request failed: %s: %s", resp.status_code, resp.text)
|
||||||
raise TestError("error generating image")
|
raise TestError("error generating image")
|
||||||
|
@ -484,10 +484,21 @@ def check_ready(host: str, key: str) -> bool:
|
||||||
logger.warning("ready request failed: %s", resp.status_code)
|
logger.warning("ready request failed: %s", resp.status_code)
|
||||||
raise TestError("error getting image status")
|
raise TestError("error getting image status")
|
||||||
|
|
||||||
|
def check_outputs(host: str, key: str) -> List[str]:
|
||||||
|
resp = requests.get(f"{host}/api/ready?output={key}")
|
||||||
|
if resp.status_code == 200:
|
||||||
|
json = resp.json()
|
||||||
|
outputs = json.get("outputs", [])
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
logger.warning("getting outputs failed: %s: %s", resp.status_code, resp.text)
|
||||||
|
raise TestError("error getting image outputs")
|
||||||
|
|
||||||
|
def download_images(host: str, key: str) -> List[Image.Image]:
|
||||||
|
outputs = check_outputs(host, key)
|
||||||
|
|
||||||
def download_images(host: str, keys: List[str]) -> List[Image.Image]:
|
|
||||||
images = []
|
images = []
|
||||||
for key in keys:
|
for key in outputs:
|
||||||
resp = requests.get(f"{host}/output/{key}")
|
resp = requests.get(f"{host}/output/{key}")
|
||||||
if resp.status_code == 200:
|
if resp.status_code == 200:
|
||||||
logger.debug("downloading image: %s", key)
|
logger.debug("downloading image: %s", key)
|
||||||
|
@ -528,14 +539,14 @@ def run_test(
|
||||||
Generate an image, wait for it to be ready, and calculate the MSE from the reference.
|
Generate an image, wait for it to be ready, and calculate the MSE from the reference.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
keys = generate_images(host, test)
|
job = generate_images(host, test)
|
||||||
if keys is None:
|
if job is None:
|
||||||
return TestResult.failed(test.name, "could not generate image")
|
return TestResult.failed(test.name, "could not generate image")
|
||||||
|
|
||||||
ready = False
|
ready = False
|
||||||
for attempt in tqdm(range(test.max_attempts * time_mult)):
|
for attempt in tqdm(range(test.max_attempts * time_mult)):
|
||||||
if check_ready(host, keys[0]):
|
if check_ready(host, job):
|
||||||
logger.debug("image is ready: %s", keys)
|
logger.debug("image is ready: %s", job)
|
||||||
ready = True
|
ready = True
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
|
@ -545,7 +556,7 @@ def run_test(
|
||||||
if not ready:
|
if not ready:
|
||||||
return TestResult.failed(test.name, "image was not ready in time")
|
return TestResult.failed(test.name, "image was not ready in time")
|
||||||
|
|
||||||
results = download_images(host, keys)
|
results = download_images(host, job)
|
||||||
if results is None or len(results) == 0:
|
if results is None or len(results) == 0:
|
||||||
return TestResult.failed(test.name, "could not download image")
|
return TestResult.failed(test.name, "could not download image")
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,7 @@ from os import path
|
||||||
from typing import List
|
from typing import List
|
||||||
from unittest import skipUnless
|
from unittest import skipUnless
|
||||||
|
|
||||||
from onnx_web.params import DeviceParams
|
from onnx_web.params import DeviceParams, ImageParams, Size
|
||||||
from onnx_web.worker.context import WorkerContext
|
from onnx_web.worker.context import WorkerContext
|
||||||
|
|
||||||
|
|
||||||
|
@ -23,6 +23,14 @@ def test_device() -> DeviceParams:
|
||||||
return DeviceParams("cpu", "CPUExecutionProvider")
|
return DeviceParams("cpu", "CPUExecutionProvider")
|
||||||
|
|
||||||
|
|
||||||
|
def test_size() -> Size:
|
||||||
|
return Size(64, 64)
|
||||||
|
|
||||||
|
|
||||||
|
def test_params() -> ImageParams:
|
||||||
|
return ImageParams("test", "txt2img", "ddim", "test", 5.0, 25, 0)
|
||||||
|
|
||||||
|
|
||||||
def test_worker() -> WorkerContext:
|
def test_worker() -> WorkerContext:
|
||||||
cancel = Value("L", 0)
|
cancel = Value("L", 0)
|
||||||
logs = Queue()
|
logs = Queue()
|
||||||
|
|
Loading…
Reference in New Issue