diff --git a/api/onnx_web/convert/diffusion/checkpoint.py b/api/onnx_web/convert/diffusion/checkpoint.py index ee2731a1..a486d3f9 100644 --- a/api/onnx_web/convert/diffusion/checkpoint.py +++ b/api/onnx_web/convert/diffusion/checkpoint.py @@ -178,7 +178,7 @@ class TrainingConfig: ): model_name = sanitize_name(model_name) model_dir = os.path.join(conversion.cache_path, model_name) - working_dir = os.path.join(model_dir, "working") + working_dir = f"{model_dir}-torch" if not os.path.exists(working_dir): os.makedirs(working_dir) @@ -1688,33 +1688,33 @@ def extract_checkpoint( def convert_extract_checkpoint( conversion: ConversionContext, source: str, - dest: str, + name: str, is_inpainting: Optional[bool] = False, config_file: Optional[str] = None, vae_file: Optional[str] = None, ) -> Tuple[bool, str]: - working_name = os.path.join(conversion.cache_path, dest, "working") - model_index = os.path.join(working_name, "model_index.json") + dest_name = os.path.join(conversion.cache_path, f"{name}-torch") + model_index = os.path.join(dest_name, "model_index.json") - if os.path.exists(working_name) and os.path.exists(model_index): - logger.info("extracted Torch model already exists, reusing: %s", working_name) + if os.path.exists(dest_name) and os.path.exists(model_index): + logger.info("extracted Torch model already exists, reusing: %s", dest_name) else: logger.info( "extracting checkpoint to Torch model: %s -> %s", source, - dest, + name, ) if extract_checkpoint( conversion, - dest, + name, source, config_file=config_file, is_inpainting=is_inpainting, vae_file=vae_file, ): - logger.info("extracted checkpoint to Torch model: %s", working_name) + logger.info("extracted checkpoint to Torch model: %s", dest_name) else: logger.error("unable to convert checkpoint to Torch model") raise ValueError("unable to convert checkpoint to Torch model") - return working_name + return dest_name diff --git a/api/onnx_web/convert/diffusion/diffusion.py b/api/onnx_web/convert/diffusion/diffusion.py index 303bb8e6..214d3de5 100644 --- a/api/onnx_web/convert/diffusion/diffusion.py +++ b/api/onnx_web/convert/diffusion/diffusion.py @@ -389,8 +389,6 @@ def convert_diffusion_diffusers( return (False, dest_path) cache_path = fetch_model(conversion, name, source, format=format) - temp_path = path.join(conversion.cache_path, f"{name}-torch") - pipe_class = CONVERT_PIPELINES.get(pipe_type) v2, pipe_args = get_model_version( cache_path, conversion.map_location, size=image_size, version=version @@ -419,7 +417,7 @@ def convert_diffusion_diffusers( torch_source = convert_extract_checkpoint( conversion, cache_path, - temp_path, + name, is_inpainting=is_inpainting, config_file=config_path, vae_file=replace_vae,