fix(api): use server model path while converting SD checkpoints (#221)
This commit is contained in:
parent
9d9bd1a639
commit
c45915e558
|
@ -1300,6 +1300,7 @@ def download_model(db_config: TrainingConfig, token):
|
||||||
|
|
||||||
|
|
||||||
def get_config_path(
|
def get_config_path(
|
||||||
|
context: ConversionContext,
|
||||||
model_version: str = "v1",
|
model_version: str = "v1",
|
||||||
train_type: str = "default",
|
train_type: str = "default",
|
||||||
config_base_name: str = "training",
|
config_base_name: str = "training",
|
||||||
|
@ -1310,11 +1311,7 @@ def get_config_path(
|
||||||
)
|
)
|
||||||
|
|
||||||
parts = os.path.join(
|
parts = os.path.join(
|
||||||
os.path.dirname(os.path.realpath(__file__)),
|
context.model_path,
|
||||||
"..",
|
|
||||||
"..",
|
|
||||||
"..",
|
|
||||||
"models",
|
|
||||||
"configs",
|
"configs",
|
||||||
f"{model_version}-{config_base_name}-{train_type}.yaml",
|
f"{model_version}-{config_base_name}-{train_type}.yaml",
|
||||||
)
|
)
|
||||||
|
@ -1322,7 +1319,11 @@ def get_config_path(
|
||||||
|
|
||||||
|
|
||||||
def get_config_file(
|
def get_config_file(
|
||||||
train_unfrozen=False, v2=False, prediction_type="epsilon", config_file=None
|
context: ConversionContext,
|
||||||
|
train_unfrozen=False,
|
||||||
|
v2=False,
|
||||||
|
prediction_type="epsilon",
|
||||||
|
config_file=None,
|
||||||
):
|
):
|
||||||
if config_file is not None:
|
if config_file is not None:
|
||||||
return config_file
|
return config_file
|
||||||
|
@ -1344,12 +1345,16 @@ def get_config_file(
|
||||||
model_train_type = train_types["default"]
|
model_train_type = train_types["default"]
|
||||||
|
|
||||||
return get_config_path(
|
return get_config_path(
|
||||||
model_version_name, model_train_type, config_base_name, prediction_type
|
context,
|
||||||
|
model_version_name,
|
||||||
|
model_train_type,
|
||||||
|
config_base_name,
|
||||||
|
prediction_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def extract_checkpoint(
|
def extract_checkpoint(
|
||||||
ctx: ConversionContext,
|
context: ConversionContext,
|
||||||
new_model_name: str,
|
new_model_name: str,
|
||||||
checkpoint_file: str,
|
checkpoint_file: str,
|
||||||
scheduler_type="ddim",
|
scheduler_type="ddim",
|
||||||
|
@ -1394,7 +1399,10 @@ def extract_checkpoint(
|
||||||
|
|
||||||
# Create empty config
|
# Create empty config
|
||||||
db_config = TrainingConfig(
|
db_config = TrainingConfig(
|
||||||
ctx, model_name=new_model_name, scheduler=scheduler_type, src=checkpoint_file
|
context,
|
||||||
|
model_name=new_model_name,
|
||||||
|
scheduler=scheduler_type,
|
||||||
|
src=checkpoint_file,
|
||||||
)
|
)
|
||||||
|
|
||||||
original_config_file = None
|
original_config_file = None
|
||||||
|
@ -1434,7 +1442,7 @@ def extract_checkpoint(
|
||||||
prediction_type = "epsilon"
|
prediction_type = "epsilon"
|
||||||
|
|
||||||
original_config_file = get_config_file(
|
original_config_file = get_config_file(
|
||||||
train_unfrozen, v2, prediction_type, config_file=config_file
|
context, train_unfrozen, v2, prediction_type, config_file=config_file
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
|
@ -1525,7 +1533,7 @@ def extract_checkpoint(
|
||||||
checkpoint, vae_config
|
checkpoint, vae_config
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
vae_file = os.path.join(ctx.model_path, vae_file)
|
vae_file = os.path.join(context.model_path, vae_file)
|
||||||
logger.debug("loading custom VAE: %s", vae_file)
|
logger.debug("loading custom VAE: %s", vae_file)
|
||||||
vae_checkpoint = load_tensor(vae_file, map_location=map_location)
|
vae_checkpoint = load_tensor(vae_file, map_location=map_location)
|
||||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(
|
converted_vae_checkpoint = convert_ldm_vae_checkpoint(
|
||||||
|
|
Loading…
Reference in New Issue