1
0
Fork 0

feat(api): add an option for custom checkpoint config to extras file (fixes #130)

This commit is contained in:
Sean Sube 2023-02-12 14:10:30 -06:00
parent 82487f5771
commit d6201c9d32
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 11 additions and 3 deletions

View File

@ -1110,7 +1110,9 @@ def get_config_path(
return os.path.abspath(parts) return os.path.abspath(parts)
def get_config_file(train_unfrozen=False, v2=False, prediction_type="epsilon"): def get_config_file(train_unfrozen=False, v2=False, prediction_type="epsilon", config_file=None):
if config_file is not None:
return config_file
config_base_name = "training" config_base_name = "training"
@ -1142,6 +1144,7 @@ def extract_checkpoint(
extract_ema=False, extract_ema=False,
train_unfrozen=False, train_unfrozen=False,
is_512=True, is_512=True,
config_file=None,
): ):
""" """
@ -1229,7 +1232,7 @@ def extract_checkpoint(
else: else:
prediction_type = "epsilon" prediction_type = "epsilon"
original_config_file = get_config_file(train_unfrozen, v2, prediction_type) original_config_file = get_config_file(train_unfrozen, v2, prediction_type, config_file=config_file)
logger.info(f"Pred and size are {prediction_type} and {image_size}, using config: {original_config_file}") logger.info(f"Pred and size are {prediction_type} and {image_size}, using config: {original_config_file}")
db_config.resolution = image_size db_config.resolution = image_size
@ -1406,6 +1409,7 @@ def convert_diffusion_original(
model: ModelDict, model: ModelDict,
source: str, source: str,
): ):
config = model["config"]
name = model["name"] name = model["name"]
source = source or model["source"] source = source or model["source"]
@ -1424,7 +1428,7 @@ def convert_diffusion_original(
logger.info("Torch pipeline already exists, reusing: %s", torch_path) logger.info("Torch pipeline already exists, reusing: %s", torch_path)
else: else:
logger.info("Converting original Diffusers check to Torch model: %s -> %s", source, torch_path) logger.info("Converting original Diffusers check to Torch model: %s -> %s", source, torch_path)
extract_checkpoint(ctx, torch_name, source) extract_checkpoint(ctx, torch_name, source, config_file=config)
logger.info("Converted original Diffusers checkpoint to Torch model.") logger.info("Converted original Diffusers checkpoint to Torch model.")
convert_diffusion_stable(ctx, model, working_name) convert_diffusion_stable(ctx, model, working_name)

View File

@ -33,6 +33,10 @@ $defs:
diffusion_model: diffusion_model:
allOf: allOf:
- $ref: "#/$defs/base_model" - $ref: "#/$defs/base_model"
- type: object
properties:
config:
type: string
upscaling_model: upscaling_model:
allOf: allOf: