pass correct paths to torch extraction
This commit is contained in:
parent
9ffe266384
commit
769350115c
|
@ -178,7 +178,7 @@ class TrainingConfig:
|
|||
):
|
||||
model_name = sanitize_name(model_name)
|
||||
model_dir = os.path.join(conversion.cache_path, model_name)
|
||||
working_dir = os.path.join(model_dir, "working")
|
||||
working_dir = f"{model_dir}-torch"
|
||||
|
||||
if not os.path.exists(working_dir):
|
||||
os.makedirs(working_dir)
|
||||
|
@ -1688,33 +1688,33 @@ def extract_checkpoint(
|
|||
def convert_extract_checkpoint(
|
||||
conversion: ConversionContext,
|
||||
source: str,
|
||||
dest: str,
|
||||
name: str,
|
||||
is_inpainting: Optional[bool] = False,
|
||||
config_file: Optional[str] = None,
|
||||
vae_file: Optional[str] = None,
|
||||
) -> Tuple[bool, str]:
|
||||
working_name = os.path.join(conversion.cache_path, dest, "working")
|
||||
model_index = os.path.join(working_name, "model_index.json")
|
||||
dest_name = os.path.join(conversion.cache_path, f"{name}-torch")
|
||||
model_index = os.path.join(dest_name, "model_index.json")
|
||||
|
||||
if os.path.exists(working_name) and os.path.exists(model_index):
|
||||
logger.info("extracted Torch model already exists, reusing: %s", working_name)
|
||||
if os.path.exists(dest_name) and os.path.exists(model_index):
|
||||
logger.info("extracted Torch model already exists, reusing: %s", dest_name)
|
||||
else:
|
||||
logger.info(
|
||||
"extracting checkpoint to Torch model: %s -> %s",
|
||||
source,
|
||||
dest,
|
||||
name,
|
||||
)
|
||||
if extract_checkpoint(
|
||||
conversion,
|
||||
dest,
|
||||
name,
|
||||
source,
|
||||
config_file=config_file,
|
||||
is_inpainting=is_inpainting,
|
||||
vae_file=vae_file,
|
||||
):
|
||||
logger.info("extracted checkpoint to Torch model: %s", working_name)
|
||||
logger.info("extracted checkpoint to Torch model: %s", dest_name)
|
||||
else:
|
||||
logger.error("unable to convert checkpoint to Torch model")
|
||||
raise ValueError("unable to convert checkpoint to Torch model")
|
||||
|
||||
return working_name
|
||||
return dest_name
|
||||
|
|
|
@ -389,8 +389,6 @@ def convert_diffusion_diffusers(
|
|||
return (False, dest_path)
|
||||
|
||||
cache_path = fetch_model(conversion, name, source, format=format)
|
||||
temp_path = path.join(conversion.cache_path, f"{name}-torch")
|
||||
|
||||
pipe_class = CONVERT_PIPELINES.get(pipe_type)
|
||||
v2, pipe_args = get_model_version(
|
||||
cache_path, conversion.map_location, size=image_size, version=version
|
||||
|
@ -419,7 +417,7 @@ def convert_diffusion_diffusers(
|
|||
torch_source = convert_extract_checkpoint(
|
||||
conversion,
|
||||
cache_path,
|
||||
temp_path,
|
||||
name,
|
||||
is_inpainting=is_inpainting,
|
||||
config_file=config_path,
|
||||
vae_file=replace_vae,
|
||||
|
|
Loading…
Reference in New Issue