1
0
Fork 0

pass correct paths to torch extraction

This commit is contained in:
Sean Sube 2023-12-24 06:00:44 -06:00
parent 9ffe266384
commit 769350115c
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 11 additions and 13 deletions

View File

@ -178,7 +178,7 @@ class TrainingConfig:
): ):
model_name = sanitize_name(model_name) model_name = sanitize_name(model_name)
model_dir = os.path.join(conversion.cache_path, model_name) model_dir = os.path.join(conversion.cache_path, model_name)
working_dir = os.path.join(model_dir, "working") working_dir = f"{model_dir}-torch"
if not os.path.exists(working_dir): if not os.path.exists(working_dir):
os.makedirs(working_dir) os.makedirs(working_dir)
@ -1688,33 +1688,33 @@ def extract_checkpoint(
def convert_extract_checkpoint( def convert_extract_checkpoint(
conversion: ConversionContext, conversion: ConversionContext,
source: str, source: str,
dest: str, name: str,
is_inpainting: Optional[bool] = False, is_inpainting: Optional[bool] = False,
config_file: Optional[str] = None, config_file: Optional[str] = None,
vae_file: Optional[str] = None, vae_file: Optional[str] = None,
) -> Tuple[bool, str]: ) -> Tuple[bool, str]:
working_name = os.path.join(conversion.cache_path, dest, "working") dest_name = os.path.join(conversion.cache_path, f"{name}-torch")
model_index = os.path.join(working_name, "model_index.json") model_index = os.path.join(dest_name, "model_index.json")
if os.path.exists(working_name) and os.path.exists(model_index): if os.path.exists(dest_name) and os.path.exists(model_index):
logger.info("extracted Torch model already exists, reusing: %s", working_name) logger.info("extracted Torch model already exists, reusing: %s", dest_name)
else: else:
logger.info( logger.info(
"extracting checkpoint to Torch model: %s -> %s", "extracting checkpoint to Torch model: %s -> %s",
source, source,
dest, name,
) )
if extract_checkpoint( if extract_checkpoint(
conversion, conversion,
dest, name,
source, source,
config_file=config_file, config_file=config_file,
is_inpainting=is_inpainting, is_inpainting=is_inpainting,
vae_file=vae_file, vae_file=vae_file,
): ):
logger.info("extracted checkpoint to Torch model: %s", working_name) logger.info("extracted checkpoint to Torch model: %s", dest_name)
else: else:
logger.error("unable to convert checkpoint to Torch model") logger.error("unable to convert checkpoint to Torch model")
raise ValueError("unable to convert checkpoint to Torch model") raise ValueError("unable to convert checkpoint to Torch model")
return working_name return dest_name

View File

@ -389,8 +389,6 @@ def convert_diffusion_diffusers(
return (False, dest_path) return (False, dest_path)
cache_path = fetch_model(conversion, name, source, format=format) cache_path = fetch_model(conversion, name, source, format=format)
temp_path = path.join(conversion.cache_path, f"{name}-torch")
pipe_class = CONVERT_PIPELINES.get(pipe_type) pipe_class = CONVERT_PIPELINES.get(pipe_type)
v2, pipe_args = get_model_version( v2, pipe_args = get_model_version(
cache_path, conversion.map_location, size=image_size, version=version cache_path, conversion.map_location, size=image_size, version=version
@ -419,7 +417,7 @@ def convert_diffusion_diffusers(
torch_source = convert_extract_checkpoint( torch_source = convert_extract_checkpoint(
conversion, conversion,
cache_path, cache_path,
temp_path, name,
is_inpainting=is_inpainting, is_inpainting=is_inpainting,
config_file=config_path, config_file=config_path,
vae_file=replace_vae, vae_file=replace_vae,