diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py index a2179743..b701c5ae 100644 --- a/api/onnx_web/convert/__main__.py +++ b/api/onnx_web/convert/__main__.py @@ -310,6 +310,7 @@ def main() -> int: # extra models parser.add_argument("--extras", nargs="*", type=str, default=[]) + parser.add_argument("--prune", nargs="*", type=str, default=[]) parser.add_argument("--skip", nargs="*", type=str, default=[]) # export options diff --git a/api/onnx_web/convert/diffusion/original.py b/api/onnx_web/convert/diffusion/original.py index a05a67e5..334b0411 100644 --- a/api/onnx_web/convert/diffusion/original.py +++ b/api/onnx_web/convert/diffusion/original.py @@ -1704,4 +1704,9 @@ def convert_diffusion_original( del model["vae"] convert_diffusion_diffusers(ctx, model, working_name) + + if "torch" in ctx.prune: + logger.info("removing intermediate Torch models: %s", torch_path) + shutil.rmtree(torch_path) + logger.info("ONNX pipeline saved to %s", name) diff --git a/api/onnx_web/convert/utils.py b/api/onnx_web/convert/utils.py index 295fdd87..a06d6126 100644 --- a/api/onnx_web/convert/utils.py +++ b/api/onnx_web/convert/utils.py @@ -29,6 +29,7 @@ class ConversionContext(ServerContext): half: Optional[bool] = False, opset: Optional[int] = None, token: Optional[str] = None, + prune: Optional[List[str]] = None, **kwargs, ) -> None: super().__init__(model_path=model_path, cache_path=cache_path) @@ -36,6 +37,7 @@ class ConversionContext(ServerContext): self.half = half self.opset = opset self.token = token + self.prune = prune or [] if device is not None: self.training_device = device