feat(api): add an option for custom checkpoint config to extras file (fixes #130)
This commit is contained in:
parent
82487f5771
commit
d6201c9d32
|
@ -1110,7 +1110,9 @@ def get_config_path(
|
||||||
return os.path.abspath(parts)
|
return os.path.abspath(parts)
|
||||||
|
|
||||||
|
|
||||||
def get_config_file(train_unfrozen=False, v2=False, prediction_type="epsilon"):
|
def get_config_file(train_unfrozen=False, v2=False, prediction_type="epsilon", config_file=None):
|
||||||
|
if config_file is not None:
|
||||||
|
return config_file
|
||||||
|
|
||||||
config_base_name = "training"
|
config_base_name = "training"
|
||||||
|
|
||||||
|
@ -1142,6 +1144,7 @@ def extract_checkpoint(
|
||||||
extract_ema=False,
|
extract_ema=False,
|
||||||
train_unfrozen=False,
|
train_unfrozen=False,
|
||||||
is_512=True,
|
is_512=True,
|
||||||
|
config_file=None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -1229,7 +1232,7 @@ def extract_checkpoint(
|
||||||
else:
|
else:
|
||||||
prediction_type = "epsilon"
|
prediction_type = "epsilon"
|
||||||
|
|
||||||
original_config_file = get_config_file(train_unfrozen, v2, prediction_type)
|
original_config_file = get_config_file(train_unfrozen, v2, prediction_type, config_file=config_file)
|
||||||
|
|
||||||
logger.info(f"Pred and size are {prediction_type} and {image_size}, using config: {original_config_file}")
|
logger.info(f"Pred and size are {prediction_type} and {image_size}, using config: {original_config_file}")
|
||||||
db_config.resolution = image_size
|
db_config.resolution = image_size
|
||||||
|
@ -1406,6 +1409,7 @@ def convert_diffusion_original(
|
||||||
model: ModelDict,
|
model: ModelDict,
|
||||||
source: str,
|
source: str,
|
||||||
):
|
):
|
||||||
|
config = model["config"]
|
||||||
name = model["name"]
|
name = model["name"]
|
||||||
source = source or model["source"]
|
source = source or model["source"]
|
||||||
|
|
||||||
|
@ -1424,7 +1428,7 @@ def convert_diffusion_original(
|
||||||
logger.info("Torch pipeline already exists, reusing: %s", torch_path)
|
logger.info("Torch pipeline already exists, reusing: %s", torch_path)
|
||||||
else:
|
else:
|
||||||
logger.info("Converting original Diffusers check to Torch model: %s -> %s", source, torch_path)
|
logger.info("Converting original Diffusers check to Torch model: %s -> %s", source, torch_path)
|
||||||
extract_checkpoint(ctx, torch_name, source)
|
extract_checkpoint(ctx, torch_name, source, config_file=config)
|
||||||
logger.info("Converted original Diffusers checkpoint to Torch model.")
|
logger.info("Converted original Diffusers checkpoint to Torch model.")
|
||||||
|
|
||||||
convert_diffusion_stable(ctx, model, working_name)
|
convert_diffusion_stable(ctx, model, working_name)
|
||||||
|
|
|
@ -33,6 +33,10 @@ $defs:
|
||||||
diffusion_model:
|
diffusion_model:
|
||||||
allOf:
|
allOf:
|
||||||
- $ref: "#/$defs/base_model"
|
- $ref: "#/$defs/base_model"
|
||||||
|
- type: object
|
||||||
|
properties:
|
||||||
|
config:
|
||||||
|
type: string
|
||||||
|
|
||||||
upscaling_model:
|
upscaling_model:
|
||||||
allOf:
|
allOf:
|
||||||
|
|
Loading…
Reference in New Issue