1
0
Fork 0

fix(api): only run GC is devices are passed

This commit is contained in:
Sean Sube 2023-02-19 07:41:16 -06:00
parent 3789862a6f
commit 30978e3e5b
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 4 additions and 4 deletions

View File

@ -61,12 +61,12 @@ class ChainPipeline:
def __init__(
self,
stages: List[PipelineStage] = [],
stages: List[PipelineStage] = None,
):
"""
Create a new pipeline that will run the given stages.
"""
self.stages = list(stages)
self.stages = list(stages or [])
def append(self, stage: PipelineStage):
"""

View File

@ -82,13 +82,13 @@ def get_size(val: Union[int, str, None]) -> SizeChart:
raise Exception("invalid size")
def run_gc(devices: List[DeviceParams] = []):
def run_gc(devices: List[DeviceParams] = None):
logger.debug(
"running garbage collection with %s active threads", threading.active_count()
)
gc.collect()
if torch.cuda.is_available():
if torch.cuda.is_available() and devices is not None:
for device in devices:
logger.debug("running Torch garbage collection for device: %s", device)
with torch.cuda.device(device.torch_str()):