1
0
Fork 0

lint(api): add typed errors from cancelled job and download error

This commit is contained in:
Sean Sube 2023-08-20 22:28:40 -05:00
parent 44004730ea
commit 404a314050
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 21 additions and 3 deletions

View File

@ -16,6 +16,7 @@ from packaging import version
from torch.onnx import export from torch.onnx import export
from ..constants import ONNX_WEIGHTS from ..constants import ONNX_WEIGHTS
from ..errors import RequestException
from ..server import ServerContext from ..server import ServerContext
from ..utils import get_boolean from ..utils import get_boolean
@ -103,8 +104,8 @@ def download_progress(urls: List[Tuple[str, str]]):
) )
if req.status_code != 200: if req.status_code != 200:
req.raise_for_status() # Only works for 4xx errors, per SO answer req.raise_for_status() # Only works for 4xx errors, per SO answer
raise RuntimeError( raise RequestException(
"Request to %s failed with status code: %s" % (url, req.status_code) "request to %s failed with status code: %s" % (url, req.status_code)
) )
total = int(req.headers.get("Content-Length", 0)) total = int(req.headers.get("Content-Length", 0))

View File

@ -4,3 +4,19 @@ class RetryException(Exception):
""" """
pass pass
class CancelledException(Exception):
"""
Used when a job has been cancelled and needs to stop.
"""
pass
class RequestException(Exception):
"""
Used when an HTTP request has failed.
"""
pass

View File

@ -4,6 +4,7 @@ from typing import Any, Callable, Optional
from torch.multiprocessing import Queue, Value from torch.multiprocessing import Queue, Value
from ..errors import CancelledException
from ..params import DeviceParams from ..params import DeviceParams
from .command import JobCommand, ProgressCommand from .command import JobCommand, ProgressCommand
@ -102,7 +103,7 @@ class WorkerContext:
raise RuntimeError("no job on which to set progress") raise RuntimeError("no job on which to set progress")
if self.is_cancelled(): if self.is_cancelled():
raise RuntimeError("job has been cancelled") raise CancelledException("job has been cancelled")
logger.debug("setting progress for job %s to %s", self.job, progress) logger.debug("setting progress for job %s to %s", self.job, progress)
self.last_progress = ProgressCommand( self.last_progress = ProgressCommand(