fix(api): use server model path while converting SD checkpoints (#221)
This commit is contained in:
parent
9d9bd1a639
commit
c45915e558
|
@ -1300,6 +1300,7 @@ def download_model(db_config: TrainingConfig, token):
|
|||
|
||||
|
||||
def get_config_path(
|
||||
context: ConversionContext,
|
||||
model_version: str = "v1",
|
||||
train_type: str = "default",
|
||||
config_base_name: str = "training",
|
||||
|
@ -1310,11 +1311,7 @@ def get_config_path(
|
|||
)
|
||||
|
||||
parts = os.path.join(
|
||||
os.path.dirname(os.path.realpath(__file__)),
|
||||
"..",
|
||||
"..",
|
||||
"..",
|
||||
"models",
|
||||
context.model_path,
|
||||
"configs",
|
||||
f"{model_version}-{config_base_name}-{train_type}.yaml",
|
||||
)
|
||||
|
@ -1322,7 +1319,11 @@ def get_config_path(
|
|||
|
||||
|
||||
def get_config_file(
|
||||
train_unfrozen=False, v2=False, prediction_type="epsilon", config_file=None
|
||||
context: ConversionContext,
|
||||
train_unfrozen=False,
|
||||
v2=False,
|
||||
prediction_type="epsilon",
|
||||
config_file=None,
|
||||
):
|
||||
if config_file is not None:
|
||||
return config_file
|
||||
|
@ -1344,12 +1345,16 @@ def get_config_file(
|
|||
model_train_type = train_types["default"]
|
||||
|
||||
return get_config_path(
|
||||
model_version_name, model_train_type, config_base_name, prediction_type
|
||||
context,
|
||||
model_version_name,
|
||||
model_train_type,
|
||||
config_base_name,
|
||||
prediction_type,
|
||||
)
|
||||
|
||||
|
||||
def extract_checkpoint(
|
||||
ctx: ConversionContext,
|
||||
context: ConversionContext,
|
||||
new_model_name: str,
|
||||
checkpoint_file: str,
|
||||
scheduler_type="ddim",
|
||||
|
@ -1394,7 +1399,10 @@ def extract_checkpoint(
|
|||
|
||||
# Create empty config
|
||||
db_config = TrainingConfig(
|
||||
ctx, model_name=new_model_name, scheduler=scheduler_type, src=checkpoint_file
|
||||
context,
|
||||
model_name=new_model_name,
|
||||
scheduler=scheduler_type,
|
||||
src=checkpoint_file,
|
||||
)
|
||||
|
||||
original_config_file = None
|
||||
|
@ -1434,7 +1442,7 @@ def extract_checkpoint(
|
|||
prediction_type = "epsilon"
|
||||
|
||||
original_config_file = get_config_file(
|
||||
train_unfrozen, v2, prediction_type, config_file=config_file
|
||||
context, train_unfrozen, v2, prediction_type, config_file=config_file
|
||||
)
|
||||
|
||||
logger.info(
|
||||
|
@ -1525,7 +1533,7 @@ def extract_checkpoint(
|
|||
checkpoint, vae_config
|
||||
)
|
||||
else:
|
||||
vae_file = os.path.join(ctx.model_path, vae_file)
|
||||
vae_file = os.path.join(context.model_path, vae_file)
|
||||
logger.debug("loading custom VAE: %s", vae_file)
|
||||
vae_checkpoint = load_tensor(vae_file, map_location=map_location)
|
||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(
|
||||
|
|
Loading…
Reference in New Issue