
fix(api): use server model path while converting SD checkpoints (#221)

Sean Sube 2023-03-07 18:55:14 -06:00
parent 9d9bd1a639
commit c45915e558
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed file with 19 additions and 11 deletions


@@ -1300,6 +1300,7 @@ def download_model(db_config: TrainingConfig, token):
 def get_config_path(
+    context: ConversionContext,
     model_version: str = "v1",
     train_type: str = "default",
     config_base_name: str = "training",
@@ -1310,11 +1311,7 @@ def get_config_path(
     )
     parts = os.path.join(
-        os.path.dirname(os.path.realpath(__file__)),
-        "..",
-        "..",
-        "..",
-        "models",
+        context.model_path,
         "configs",
         f"{model_version}-{config_base_name}-{train_type}.yaml",
     )
@@ -1322,7 +1319,11 @@ def get_config_file(
 def get_config_file(
-    train_unfrozen=False, v2=False, prediction_type="epsilon", config_file=None
+    context: ConversionContext,
+    train_unfrozen=False,
+    v2=False,
+    prediction_type="epsilon",
+    config_file=None,
 ):
     if config_file is not None:
         return config_file
@@ -1344,12 +1345,16 @@ def get_config_file(
         model_train_type = train_types["default"]

     return get_config_path(
-        model_version_name, model_train_type, config_base_name, prediction_type
+        context,
+        model_version_name,
+        model_train_type,
+        config_base_name,
+        prediction_type,
     )


 def extract_checkpoint(
-    ctx: ConversionContext,
+    context: ConversionContext,
     new_model_name: str,
     checkpoint_file: str,
     scheduler_type="ddim",
@@ -1394,7 +1399,10 @@ def extract_checkpoint(
     # Create empty config
     db_config = TrainingConfig(
-        ctx, model_name=new_model_name, scheduler=scheduler_type, src=checkpoint_file
+        context,
+        model_name=new_model_name,
+        scheduler=scheduler_type,
+        src=checkpoint_file,
     )

     original_config_file = None
@@ -1434,7 +1442,7 @@ def extract_checkpoint(
             prediction_type = "epsilon"

     original_config_file = get_config_file(
-        train_unfrozen, v2, prediction_type, config_file=config_file
+        context, train_unfrozen, v2, prediction_type, config_file=config_file
     )

     logger.info(
@@ -1525,7 +1533,7 @@ def extract_checkpoint(
                 checkpoint, vae_config
             )
         else:
-            vae_file = os.path.join(ctx.model_path, vae_file)
+            vae_file = os.path.join(context.model_path, vae_file)
             logger.debug("loading custom VAE: %s", vae_file)
             vae_checkpoint = load_tensor(vae_file, map_location=map_location)
             converted_vae_checkpoint = convert_ldm_vae_checkpoint(
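
For orientation, a minimal sketch of how the patched get_config_path resolves a training config after this change. The ConversionContext below is a hypothetical stand-in that assumes only a model_path attribute, and the function is assumed to return the joined path directly; the real class and function in the onnx-web API carry more state than shown here.

import os

class ConversionContext:
    # hypothetical stand-in: the real context carries more server state
    def __init__(self, model_path: str):
        self.model_path = model_path

def get_config_path(
    context: ConversionContext,
    model_version: str = "v1",
    train_type: str = "default",
    config_base_name: str = "training",
) -> str:
    # configs are now looked up under the server's model path instead of
    # a path derived from the source file's location on disk
    return os.path.join(
        context.model_path,
        "configs",
        f"{model_version}-{config_base_name}-{train_type}.yaml",
    )

# with a server model path of /opt/onnx-web/models, the default lookup becomes:
#   /opt/onnx-web/models/configs/v1-training-default.yaml
print(get_config_path(ConversionContext("/opt/onnx-web/models")))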