diff --git a/api/onnx_web/convert/diffusion_original.py b/api/onnx_web/convert/diffusion_original.py index 4a3db675..e88bf615 100644 --- a/api/onnx_web/convert/diffusion_original.py +++ b/api/onnx_web/convert/diffusion_original.py @@ -1110,7 +1110,9 @@ def get_config_path( return os.path.abspath(parts) -def get_config_file(train_unfrozen=False, v2=False, prediction_type="epsilon"): +def get_config_file(train_unfrozen=False, v2=False, prediction_type="epsilon", config_file=None): + if config_file is not None: + return config_file config_base_name = "training" @@ -1142,6 +1144,7 @@ def extract_checkpoint( extract_ema=False, train_unfrozen=False, is_512=True, + config_file=None, ): """ @@ -1229,7 +1232,7 @@ def extract_checkpoint( else: prediction_type = "epsilon" - original_config_file = get_config_file(train_unfrozen, v2, prediction_type) + original_config_file = get_config_file(train_unfrozen, v2, prediction_type, config_file=config_file) logger.info(f"Pred and size are {prediction_type} and {image_size}, using config: {original_config_file}") db_config.resolution = image_size @@ -1406,6 +1409,7 @@ def convert_diffusion_original( model: ModelDict, source: str, ): + config = model["config"] name = model["name"] source = source or model["source"] @@ -1424,7 +1428,7 @@ def convert_diffusion_original( logger.info("Torch pipeline already exists, reusing: %s", torch_path) else: logger.info("Converting original Diffusers check to Torch model: %s -> %s", source, torch_path) - extract_checkpoint(ctx, torch_name, source) + extract_checkpoint(ctx, torch_name, source, config_file=config) logger.info("Converted original Diffusers checkpoint to Torch model.") convert_diffusion_stable(ctx, model, working_name) diff --git a/api/schemas/extras.yaml b/api/schemas/extras.yaml index b611d0e3..78c4c9b1 100644 --- a/api/schemas/extras.yaml +++ b/api/schemas/extras.yaml @@ -33,6 +33,10 @@ $defs: diffusion_model: allOf: - $ref: "#/$defs/base_model" + - type: object + properties: + config: + type: string upscaling_model: allOf: