1
0
Fork 0

fix(api): use server model path while converting SD checkpoints (#221)

This commit is contained in:
Sean Sube 2023-03-07 18:55:14 -06:00
parent 9d9bd1a639
commit c45915e558
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 19 additions and 11 deletions

View File

@ -1300,6 +1300,7 @@ def download_model(db_config: TrainingConfig, token):
def get_config_path( def get_config_path(
context: ConversionContext,
model_version: str = "v1", model_version: str = "v1",
train_type: str = "default", train_type: str = "default",
config_base_name: str = "training", config_base_name: str = "training",
@ -1310,11 +1311,7 @@ def get_config_path(
) )
parts = os.path.join( parts = os.path.join(
os.path.dirname(os.path.realpath(__file__)), context.model_path,
"..",
"..",
"..",
"models",
"configs", "configs",
f"{model_version}-{config_base_name}-{train_type}.yaml", f"{model_version}-{config_base_name}-{train_type}.yaml",
) )
@ -1322,7 +1319,11 @@ def get_config_path(
def get_config_file( def get_config_file(
train_unfrozen=False, v2=False, prediction_type="epsilon", config_file=None context: ConversionContext,
train_unfrozen=False,
v2=False,
prediction_type="epsilon",
config_file=None,
): ):
if config_file is not None: if config_file is not None:
return config_file return config_file
@ -1344,12 +1345,16 @@ def get_config_file(
model_train_type = train_types["default"] model_train_type = train_types["default"]
return get_config_path( return get_config_path(
model_version_name, model_train_type, config_base_name, prediction_type context,
model_version_name,
model_train_type,
config_base_name,
prediction_type,
) )
def extract_checkpoint( def extract_checkpoint(
ctx: ConversionContext, context: ConversionContext,
new_model_name: str, new_model_name: str,
checkpoint_file: str, checkpoint_file: str,
scheduler_type="ddim", scheduler_type="ddim",
@ -1394,7 +1399,10 @@ def extract_checkpoint(
# Create empty config # Create empty config
db_config = TrainingConfig( db_config = TrainingConfig(
ctx, model_name=new_model_name, scheduler=scheduler_type, src=checkpoint_file context,
model_name=new_model_name,
scheduler=scheduler_type,
src=checkpoint_file,
) )
original_config_file = None original_config_file = None
@ -1434,7 +1442,7 @@ def extract_checkpoint(
prediction_type = "epsilon" prediction_type = "epsilon"
original_config_file = get_config_file( original_config_file = get_config_file(
train_unfrozen, v2, prediction_type, config_file=config_file context, train_unfrozen, v2, prediction_type, config_file=config_file
) )
logger.info( logger.info(
@ -1525,7 +1533,7 @@ def extract_checkpoint(
checkpoint, vae_config checkpoint, vae_config
) )
else: else:
vae_file = os.path.join(ctx.model_path, vae_file) vae_file = os.path.join(context.model_path, vae_file)
logger.debug("loading custom VAE: %s", vae_file) logger.debug("loading custom VAE: %s", vae_file)
vae_checkpoint = load_tensor(vae_file, map_location=map_location) vae_checkpoint = load_tensor(vae_file, map_location=map_location)
converted_vae_checkpoint = convert_ldm_vae_checkpoint( converted_vae_checkpoint = convert_ldm_vae_checkpoint(