diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py index ea62fd6b..bfeabb00 100644 --- a/api/onnx_web/convert/__main__.py +++ b/api/onnx_web/convert/__main__.py @@ -266,6 +266,7 @@ def convert_models(ctx: ConversionContext, args, models: Models): except Exception as e: logger.error("error converting correction model %s: %s", name, e) + def main() -> int: parser = ArgumentParser( prog="onnx-web model converter", description="convert checkpoint models to ONNX" diff --git a/api/onnx_web/convert/utils.py b/api/onnx_web/convert/utils.py index 7ddfe5f2..676e2663 100644 --- a/api/onnx_web/convert/utils.py +++ b/api/onnx_web/convert/utils.py @@ -222,7 +222,9 @@ def load_tensor(name: str, map_location=None): ) checkpoint = torch.load(name, map_location=map_location) checkpoint = ( - checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint + checkpoint["state_dict"] + if "state_dict" in checkpoint + else checkpoint ) else: logger.debug("loading ckpt") diff --git a/api/onnx_web/utils.py b/api/onnx_web/utils.py index 104cf99e..a23198e4 100644 --- a/api/onnx_web/utils.py +++ b/api/onnx_web/utils.py @@ -1,7 +1,7 @@ import gc +import threading from logging import getLogger from os import environ, path -import threading from typing import Any, Dict, List, Optional, Union import torch @@ -136,7 +136,9 @@ def get_size(val: Union[int, str, None]) -> SizeChart: def run_gc(devices: List[DeviceParams] = []): - logger.debug("running garbage collection with %s active threads", threading.active_count()) + logger.debug( + "running garbage collection with %s active threads", threading.active_count() + ) gc.collect() if torch.cuda.is_available():