1
0
Fork 0

apply lint

This commit is contained in:
Sean Sube 2023-02-17 22:49:13 -06:00
parent d09446ca68
commit ce8c7205dc
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 8 additions and 3 deletions

View File

@@ -266,6 +266,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
         except Exception as e:
             logger.error("error converting correction model %s: %s", name, e)
 
+
 def main() -> int:
     parser = ArgumentParser(
         prog="onnx-web model converter", description="convert checkpoint models to ONNX"

View File

@@ -222,7 +222,9 @@ def load_tensor(name: str, map_location=None):
         )
         checkpoint = torch.load(name, map_location=map_location)
         checkpoint = (
-            checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
+            checkpoint["state_dict"]
+            if "state_dict" in checkpoint
+            else checkpoint
         )
     else:
         logger.debug("loading ckpt")

View File

@@ -1,7 +1,7 @@
 import gc
-import threading
 from logging import getLogger
 from os import environ, path
+import threading
 from typing import Any, Dict, List, Optional, Union
 
 import torch
@@ -136,7 +136,9 @@ def get_size(val: Union[int, str, None]) -> SizeChart:
 def run_gc(devices: List[DeviceParams] = []):
-    logger.debug("running garbage collection with %s active threads", threading.active_count())
+    logger.debug(
+        "running garbage collection with %s active threads", threading.active_count()
+    )
     gc.collect()
     if torch.cuda.is_available():