apply lint
This commit is contained in:
parent
d09446ca68
commit
ce8c7205dc
|
@ -266,6 +266,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("error converting correction model %s: %s", name, e)
|
logger.error("error converting correction model %s: %s", name, e)
|
||||||
|
|
||||||
|
|
||||||
def main() -> int:
|
def main() -> int:
|
||||||
parser = ArgumentParser(
|
parser = ArgumentParser(
|
||||||
prog="onnx-web model converter", description="convert checkpoint models to ONNX"
|
prog="onnx-web model converter", description="convert checkpoint models to ONNX"
|
||||||
|
|
|
@ -222,7 +222,9 @@ def load_tensor(name: str, map_location=None):
|
||||||
)
|
)
|
||||||
checkpoint = torch.load(name, map_location=map_location)
|
checkpoint = torch.load(name, map_location=map_location)
|
||||||
checkpoint = (
|
checkpoint = (
|
||||||
checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
|
checkpoint["state_dict"]
|
||||||
|
if "state_dict" in checkpoint
|
||||||
|
else checkpoint
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.debug("loading ckpt")
|
logger.debug("loading ckpt")
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import gc
|
import gc
|
||||||
|
import threading
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import environ, path
|
from os import environ, path
|
||||||
import threading
|
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
@ -136,7 +136,9 @@ def get_size(val: Union[int, str, None]) -> SizeChart:
|
||||||
|
|
||||||
|
|
||||||
def run_gc(devices: List[DeviceParams] = []):
|
def run_gc(devices: List[DeviceParams] = []):
|
||||||
logger.debug("running garbage collection with %s active threads", threading.active_count())
|
logger.debug(
|
||||||
|
"running garbage collection with %s active threads", threading.active_count()
|
||||||
|
)
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
|
|
Loading…
Reference in New Issue