1
0
Fork 0

feat(api): add an option to remove temporary Torch files after converting to ONNX (#122)

This commit is contained in:
Sean Sube 2023-03-08 21:38:17 -06:00
parent 6e71775782
commit 25c41c8d66
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 8 additions and 0 deletions

View File

@ -310,6 +310,7 @@ def main() -> int:
# extra models # extra models
parser.add_argument("--extras", nargs="*", type=str, default=[]) parser.add_argument("--extras", nargs="*", type=str, default=[])
parser.add_argument("--prune", nargs="*", type=str, default=[])
parser.add_argument("--skip", nargs="*", type=str, default=[]) parser.add_argument("--skip", nargs="*", type=str, default=[])
# export options # export options

View File

@ -1704,4 +1704,9 @@ def convert_diffusion_original(
del model["vae"] del model["vae"]
convert_diffusion_diffusers(ctx, model, working_name) convert_diffusion_diffusers(ctx, model, working_name)
if "torch" in ctx.prune:
logger.info("removing intermediate Torch models: %s", torch_path)
shutil.rmtree(torch_path)
logger.info("ONNX pipeline saved to %s", name) logger.info("ONNX pipeline saved to %s", name)

View File

@ -29,6 +29,7 @@ class ConversionContext(ServerContext):
half: Optional[bool] = False, half: Optional[bool] = False,
opset: Optional[int] = None, opset: Optional[int] = None,
token: Optional[str] = None, token: Optional[str] = None,
prune: Optional[List[str]] = None,
**kwargs, **kwargs,
) -> None: ) -> None:
super().__init__(model_path=model_path, cache_path=cache_path) super().__init__(model_path=model_path, cache_path=cache_path)
@ -36,6 +37,7 @@ class ConversionContext(ServerContext):
self.half = half self.half = half
self.opset = opset self.opset = opset
self.token = token self.token = token
self.prune = prune or []
if device is not None: if device is not None:
self.training_device = device self.training_device = device