feat(api): add an option to remove temporary Torch files after converting to ONNX (#122)
This commit is contained in:
parent
6e71775782
commit
25c41c8d66
|
@ -310,6 +310,7 @@ def main() -> int:
|
|||
|
||||
# extra models
|
||||
parser.add_argument("--extras", nargs="*", type=str, default=[])
|
||||
parser.add_argument("--prune", nargs="*", type=str, default=[])
|
||||
parser.add_argument("--skip", nargs="*", type=str, default=[])
|
||||
|
||||
# export options
|
||||
|
|
|
@ -1704,4 +1704,9 @@ def convert_diffusion_original(
|
|||
del model["vae"]
|
||||
|
||||
convert_diffusion_diffusers(ctx, model, working_name)
|
||||
|
||||
if "torch" in ctx.prune:
|
||||
logger.info("removing intermediate Torch models: %s", torch_path)
|
||||
shutil.rmtree(torch_path)
|
||||
|
||||
logger.info("ONNX pipeline saved to %s", name)
|
||||
|
|
|
@ -29,6 +29,7 @@ class ConversionContext(ServerContext):
|
|||
half: Optional[bool] = False,
|
||||
opset: Optional[int] = None,
|
||||
token: Optional[str] = None,
|
||||
prune: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(model_path=model_path, cache_path=cache_path)
|
||||
|
@ -36,6 +37,7 @@ class ConversionContext(ServerContext):
|
|||
self.half = half
|
||||
self.opset = opset
|
||||
self.token = token
|
||||
self.prune = prune or []
|
||||
|
||||
if device is not None:
|
||||
self.training_device = device
|
||||
|
|
Loading…
Reference in New Issue