1
0
Fork 0

pass correct paths to torch extraction

This commit is contained in:
Sean Sube 2023-12-24 06:00:44 -06:00
parent 9ffe266384
commit 769350115c
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 11 additions and 13 deletions

View File

@ -178,7 +178,7 @@ class TrainingConfig:
):
model_name = sanitize_name(model_name)
model_dir = os.path.join(conversion.cache_path, model_name)
working_dir = os.path.join(model_dir, "working")
working_dir = f"{model_dir}-torch"
if not os.path.exists(working_dir):
os.makedirs(working_dir)
@ -1688,33 +1688,33 @@ def extract_checkpoint(
def convert_extract_checkpoint(
conversion: ConversionContext,
source: str,
dest: str,
name: str,
is_inpainting: Optional[bool] = False,
config_file: Optional[str] = None,
vae_file: Optional[str] = None,
) -> Tuple[bool, str]:
working_name = os.path.join(conversion.cache_path, dest, "working")
model_index = os.path.join(working_name, "model_index.json")
dest_name = os.path.join(conversion.cache_path, f"{name}-torch")
model_index = os.path.join(dest_name, "model_index.json")
if os.path.exists(working_name) and os.path.exists(model_index):
logger.info("extracted Torch model already exists, reusing: %s", working_name)
if os.path.exists(dest_name) and os.path.exists(model_index):
logger.info("extracted Torch model already exists, reusing: %s", dest_name)
else:
logger.info(
"extracting checkpoint to Torch model: %s -> %s",
source,
dest,
name,
)
if extract_checkpoint(
conversion,
dest,
name,
source,
config_file=config_file,
is_inpainting=is_inpainting,
vae_file=vae_file,
):
logger.info("extracted checkpoint to Torch model: %s", working_name)
logger.info("extracted checkpoint to Torch model: %s", dest_name)
else:
logger.error("unable to convert checkpoint to Torch model")
raise ValueError("unable to convert checkpoint to Torch model")
return working_name
return dest_name

View File

@ -389,8 +389,6 @@ def convert_diffusion_diffusers(
return (False, dest_path)
cache_path = fetch_model(conversion, name, source, format=format)
temp_path = path.join(conversion.cache_path, f"{name}-torch")
pipe_class = CONVERT_PIPELINES.get(pipe_type)
v2, pipe_args = get_model_version(
cache_path, conversion.map_location, size=image_size, version=version
@ -419,7 +417,7 @@ def convert_diffusion_diffusers(
torch_source = convert_extract_checkpoint(
conversion,
cache_path,
temp_path,
name,
is_inpainting=is_inpainting,
config_file=config_path,
vae_file=replace_vae,