feat(api): add an option to remove temporary Torch files after converting to ONNX (#122)
This commit is contained in:
parent
6e71775782
commit
25c41c8d66
|
@ -310,6 +310,7 @@ def main() -> int:
|
||||||
|
|
||||||
# extra models
|
# extra models
|
||||||
parser.add_argument("--extras", nargs="*", type=str, default=[])
|
parser.add_argument("--extras", nargs="*", type=str, default=[])
|
||||||
|
parser.add_argument("--prune", nargs="*", type=str, default=[])
|
||||||
parser.add_argument("--skip", nargs="*", type=str, default=[])
|
parser.add_argument("--skip", nargs="*", type=str, default=[])
|
||||||
|
|
||||||
# export options
|
# export options
|
||||||
|
|
|
@ -1704,4 +1704,9 @@ def convert_diffusion_original(
|
||||||
del model["vae"]
|
del model["vae"]
|
||||||
|
|
||||||
convert_diffusion_diffusers(ctx, model, working_name)
|
convert_diffusion_diffusers(ctx, model, working_name)
|
||||||
|
|
||||||
|
if "torch" in ctx.prune:
|
||||||
|
logger.info("removing intermediate Torch models: %s", torch_path)
|
||||||
|
shutil.rmtree(torch_path)
|
||||||
|
|
||||||
logger.info("ONNX pipeline saved to %s", name)
|
logger.info("ONNX pipeline saved to %s", name)
|
||||||
|
|
|
@ -29,6 +29,7 @@ class ConversionContext(ServerContext):
|
||||||
half: Optional[bool] = False,
|
half: Optional[bool] = False,
|
||||||
opset: Optional[int] = None,
|
opset: Optional[int] = None,
|
||||||
token: Optional[str] = None,
|
token: Optional[str] = None,
|
||||||
|
prune: Optional[List[str]] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(model_path=model_path, cache_path=cache_path)
|
super().__init__(model_path=model_path, cache_path=cache_path)
|
||||||
|
@ -36,6 +37,7 @@ class ConversionContext(ServerContext):
|
||||||
self.half = half
|
self.half = half
|
||||||
self.opset = opset
|
self.opset = opset
|
||||||
self.token = token
|
self.token = token
|
||||||
|
self.prune = prune or []
|
||||||
|
|
||||||
if device is not None:
|
if device is not None:
|
||||||
self.training_device = device
|
self.training_device = device
|
||||||
|
|
Loading…
Reference in New Issue