From 25c41c8d6677bbf008c1c9b109f234f352e26a80 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Wed, 8 Mar 2023 21:38:17 -0600 Subject: [PATCH] feat(api): add an option to remove temporary Torch files after converting to ONNX (#122) --- api/onnx_web/convert/__main__.py | 1 + api/onnx_web/convert/diffusion/original.py | 5 +++++ api/onnx_web/convert/utils.py | 2 ++ 3 files changed, 8 insertions(+) diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py index a2179743..b701c5ae 100644 --- a/api/onnx_web/convert/__main__.py +++ b/api/onnx_web/convert/__main__.py @@ -310,6 +310,7 @@ def main() -> int: # extra models parser.add_argument("--extras", nargs="*", type=str, default=[]) + parser.add_argument("--prune", nargs="*", type=str, default=[]) parser.add_argument("--skip", nargs="*", type=str, default=[]) # export options diff --git a/api/onnx_web/convert/diffusion/original.py b/api/onnx_web/convert/diffusion/original.py index a05a67e5..334b0411 100644 --- a/api/onnx_web/convert/diffusion/original.py +++ b/api/onnx_web/convert/diffusion/original.py @@ -1704,4 +1704,9 @@ def convert_diffusion_original( del model["vae"] convert_diffusion_diffusers(ctx, model, working_name) + + if "torch" in ctx.prune: + logger.info("removing intermediate Torch models: %s", torch_path) + shutil.rmtree(torch_path) + logger.info("ONNX pipeline saved to %s", name) diff --git a/api/onnx_web/convert/utils.py b/api/onnx_web/convert/utils.py index 295fdd87..a06d6126 100644 --- a/api/onnx_web/convert/utils.py +++ b/api/onnx_web/convert/utils.py @@ -29,6 +29,7 @@ class ConversionContext(ServerContext): half: Optional[bool] = False, opset: Optional[int] = None, token: Optional[str] = None, + prune: Optional[List[str]] = None, **kwargs, ) -> None: super().__init__(model_path=model_path, cache_path=cache_path) @@ -36,6 +37,7 @@ class ConversionContext(ServerContext): self.half = half self.opset = opset self.token = token + self.prune = prune or [] if device is not None: self.training_device = device