pass correct paths to torch extraction
This commit is contained in:
parent
9ffe266384
commit
769350115c
|
@ -178,7 +178,7 @@ class TrainingConfig:
|
||||||
):
|
):
|
||||||
model_name = sanitize_name(model_name)
|
model_name = sanitize_name(model_name)
|
||||||
model_dir = os.path.join(conversion.cache_path, model_name)
|
model_dir = os.path.join(conversion.cache_path, model_name)
|
||||||
working_dir = os.path.join(model_dir, "working")
|
working_dir = f"{model_dir}-torch"
|
||||||
|
|
||||||
if not os.path.exists(working_dir):
|
if not os.path.exists(working_dir):
|
||||||
os.makedirs(working_dir)
|
os.makedirs(working_dir)
|
||||||
|
@ -1688,33 +1688,33 @@ def extract_checkpoint(
|
||||||
def convert_extract_checkpoint(
|
def convert_extract_checkpoint(
|
||||||
conversion: ConversionContext,
|
conversion: ConversionContext,
|
||||||
source: str,
|
source: str,
|
||||||
dest: str,
|
name: str,
|
||||||
is_inpainting: Optional[bool] = False,
|
is_inpainting: Optional[bool] = False,
|
||||||
config_file: Optional[str] = None,
|
config_file: Optional[str] = None,
|
||||||
vae_file: Optional[str] = None,
|
vae_file: Optional[str] = None,
|
||||||
) -> Tuple[bool, str]:
|
) -> Tuple[bool, str]:
|
||||||
working_name = os.path.join(conversion.cache_path, dest, "working")
|
dest_name = os.path.join(conversion.cache_path, f"{name}-torch")
|
||||||
model_index = os.path.join(working_name, "model_index.json")
|
model_index = os.path.join(dest_name, "model_index.json")
|
||||||
|
|
||||||
if os.path.exists(working_name) and os.path.exists(model_index):
|
if os.path.exists(dest_name) and os.path.exists(model_index):
|
||||||
logger.info("extracted Torch model already exists, reusing: %s", working_name)
|
logger.info("extracted Torch model already exists, reusing: %s", dest_name)
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info(
|
||||||
"extracting checkpoint to Torch model: %s -> %s",
|
"extracting checkpoint to Torch model: %s -> %s",
|
||||||
source,
|
source,
|
||||||
dest,
|
name,
|
||||||
)
|
)
|
||||||
if extract_checkpoint(
|
if extract_checkpoint(
|
||||||
conversion,
|
conversion,
|
||||||
dest,
|
name,
|
||||||
source,
|
source,
|
||||||
config_file=config_file,
|
config_file=config_file,
|
||||||
is_inpainting=is_inpainting,
|
is_inpainting=is_inpainting,
|
||||||
vae_file=vae_file,
|
vae_file=vae_file,
|
||||||
):
|
):
|
||||||
logger.info("extracted checkpoint to Torch model: %s", working_name)
|
logger.info("extracted checkpoint to Torch model: %s", dest_name)
|
||||||
else:
|
else:
|
||||||
logger.error("unable to convert checkpoint to Torch model")
|
logger.error("unable to convert checkpoint to Torch model")
|
||||||
raise ValueError("unable to convert checkpoint to Torch model")
|
raise ValueError("unable to convert checkpoint to Torch model")
|
||||||
|
|
||||||
return working_name
|
return dest_name
|
||||||
|
|
|
@ -389,8 +389,6 @@ def convert_diffusion_diffusers(
|
||||||
return (False, dest_path)
|
return (False, dest_path)
|
||||||
|
|
||||||
cache_path = fetch_model(conversion, name, source, format=format)
|
cache_path = fetch_model(conversion, name, source, format=format)
|
||||||
temp_path = path.join(conversion.cache_path, f"{name}-torch")
|
|
||||||
|
|
||||||
pipe_class = CONVERT_PIPELINES.get(pipe_type)
|
pipe_class = CONVERT_PIPELINES.get(pipe_type)
|
||||||
v2, pipe_args = get_model_version(
|
v2, pipe_args = get_model_version(
|
||||||
cache_path, conversion.map_location, size=image_size, version=version
|
cache_path, conversion.map_location, size=image_size, version=version
|
||||||
|
@ -419,7 +417,7 @@ def convert_diffusion_diffusers(
|
||||||
torch_source = convert_extract_checkpoint(
|
torch_source = convert_extract_checkpoint(
|
||||||
conversion,
|
conversion,
|
||||||
cache_path,
|
cache_path,
|
||||||
temp_path,
|
name,
|
||||||
is_inpainting=is_inpainting,
|
is_inpainting=is_inpainting,
|
||||||
config_file=config_path,
|
config_file=config_path,
|
||||||
vae_file=replace_vae,
|
vae_file=replace_vae,
|
||||||
|
|
Loading…
Reference in New Issue