feat(api): add an option for custom checkpoint config to extras file (fixes #130)
This commit is contained in:
parent
82487f5771
commit
d6201c9d32
|
@ -1110,7 +1110,9 @@ def get_config_path(
|
|||
return os.path.abspath(parts)
|
||||
|
||||
|
||||
def get_config_file(train_unfrozen=False, v2=False, prediction_type="epsilon"):
|
||||
def get_config_file(train_unfrozen=False, v2=False, prediction_type="epsilon", config_file=None):
|
||||
if config_file is not None:
|
||||
return config_file
|
||||
|
||||
config_base_name = "training"
|
||||
|
||||
|
@ -1142,6 +1144,7 @@ def extract_checkpoint(
|
|||
extract_ema=False,
|
||||
train_unfrozen=False,
|
||||
is_512=True,
|
||||
config_file=None,
|
||||
):
|
||||
"""
|
||||
|
||||
|
@ -1229,7 +1232,7 @@ def extract_checkpoint(
|
|||
else:
|
||||
prediction_type = "epsilon"
|
||||
|
||||
original_config_file = get_config_file(train_unfrozen, v2, prediction_type)
|
||||
original_config_file = get_config_file(train_unfrozen, v2, prediction_type, config_file=config_file)
|
||||
|
||||
logger.info(f"Pred and size are {prediction_type} and {image_size}, using config: {original_config_file}")
|
||||
db_config.resolution = image_size
|
||||
|
@ -1406,6 +1409,7 @@ def convert_diffusion_original(
|
|||
model: ModelDict,
|
||||
source: str,
|
||||
):
|
||||
config = model["config"]
|
||||
name = model["name"]
|
||||
source = source or model["source"]
|
||||
|
||||
|
@ -1424,7 +1428,7 @@ def convert_diffusion_original(
|
|||
logger.info("Torch pipeline already exists, reusing: %s", torch_path)
|
||||
else:
|
||||
logger.info("Converting original Diffusers check to Torch model: %s -> %s", source, torch_path)
|
||||
extract_checkpoint(ctx, torch_name, source)
|
||||
extract_checkpoint(ctx, torch_name, source, config_file=config)
|
||||
logger.info("Converted original Diffusers checkpoint to Torch model.")
|
||||
|
||||
convert_diffusion_stable(ctx, model, working_name)
|
||||
|
|
|
@ -33,6 +33,10 @@ $defs:
|
|||
diffusion_model:
|
||||
allOf:
|
||||
- $ref: "#/$defs/base_model"
|
||||
- type: object
|
||||
properties:
|
||||
config:
|
||||
type: string
|
||||
|
||||
upscaling_model:
|
||||
allOf:
|
||||
|
|
Loading…
Reference in New Issue