fix(api): only run GC is devices are passed
This commit is contained in:
parent
3789862a6f
commit
30978e3e5b
|
@ -61,12 +61,12 @@ class ChainPipeline:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
stages: List[PipelineStage] = [],
|
||||
stages: List[PipelineStage] = None,
|
||||
):
|
||||
"""
|
||||
Create a new pipeline that will run the given stages.
|
||||
"""
|
||||
self.stages = list(stages)
|
||||
self.stages = list(stages or [])
|
||||
|
||||
def append(self, stage: PipelineStage):
|
||||
"""
|
||||
|
|
|
@ -82,13 +82,13 @@ def get_size(val: Union[int, str, None]) -> SizeChart:
|
|||
raise Exception("invalid size")
|
||||
|
||||
|
||||
def run_gc(devices: List[DeviceParams] = []):
|
||||
def run_gc(devices: List[DeviceParams] = None):
|
||||
logger.debug(
|
||||
"running garbage collection with %s active threads", threading.active_count()
|
||||
)
|
||||
gc.collect()
|
||||
|
||||
if torch.cuda.is_available():
|
||||
if torch.cuda.is_available() and devices is not None:
|
||||
for device in devices:
|
||||
logger.debug("running Torch garbage collection for device: %s", device)
|
||||
with torch.cuda.device(device.torch_str()):
|
||||
|
|
Loading…
Reference in New Issue