From c45915e5587c25e7b2d87cf20b2afa2a16222be5 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Tue, 7 Mar 2023 18:55:14 -0600 Subject: [PATCH] fix(api): use server model path while converting SD checkpoints (#221) --- api/onnx_web/convert/diffusion/original.py | 30 ++++++++++++++-------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/api/onnx_web/convert/diffusion/original.py b/api/onnx_web/convert/diffusion/original.py index 3a304183..8833806e 100644 --- a/api/onnx_web/convert/diffusion/original.py +++ b/api/onnx_web/convert/diffusion/original.py @@ -1300,6 +1300,7 @@ def download_model(db_config: TrainingConfig, token): def get_config_path( + context: ConversionContext, model_version: str = "v1", train_type: str = "default", config_base_name: str = "training", @@ -1310,11 +1311,7 @@ def get_config_path( ) parts = os.path.join( - os.path.dirname(os.path.realpath(__file__)), - "..", - "..", - "..", - "models", + context.model_path, "configs", f"{model_version}-{config_base_name}-{train_type}.yaml", ) @@ -1322,7 +1319,11 @@ def get_config_path( def get_config_file( - train_unfrozen=False, v2=False, prediction_type="epsilon", config_file=None + context: ConversionContext, + train_unfrozen=False, + v2=False, + prediction_type="epsilon", + config_file=None, ): if config_file is not None: return config_file @@ -1344,12 +1345,16 @@ def get_config_file( model_train_type = train_types["default"] return get_config_path( - model_version_name, model_train_type, config_base_name, prediction_type + context, + model_version_name, + model_train_type, + config_base_name, + prediction_type, ) def extract_checkpoint( - ctx: ConversionContext, + context: ConversionContext, new_model_name: str, checkpoint_file: str, scheduler_type="ddim", @@ -1394,7 +1399,10 @@ def extract_checkpoint( # Create empty config db_config = TrainingConfig( - ctx, model_name=new_model_name, scheduler=scheduler_type, src=checkpoint_file + context, + model_name=new_model_name, + scheduler=scheduler_type, + src=checkpoint_file, ) original_config_file = None @@ -1434,7 +1442,7 @@ def extract_checkpoint( prediction_type = "epsilon" original_config_file = get_config_file( - train_unfrozen, v2, prediction_type, config_file=config_file + context, train_unfrozen, v2, prediction_type, config_file=config_file ) logger.info( @@ -1525,7 +1533,7 @@ def extract_checkpoint( checkpoint, vae_config ) else: - vae_file = os.path.join(ctx.model_path, vae_file) + vae_file = os.path.join(context.model_path, vae_file) logger.debug("loading custom VAE: %s", vae_file) vae_checkpoint = load_tensor(vae_file, map_location=map_location) converted_vae_checkpoint = convert_ldm_vae_checkpoint(