diff --git a/api/onnx_web/convert/diffusion/checkpoint.py b/api/onnx_web/convert/diffusion/checkpoint.py index c47dd836..880167af 100644 --- a/api/onnx_web/convert/diffusion/checkpoint.py +++ b/api/onnx_web/convert/diffusion/checkpoint.py @@ -57,6 +57,29 @@ from ..utils import ConversionContext, load_tensor logger = getLogger(__name__) +class Config(object): + """ + Shim for pydantic-style config. + """ + + def __init__(self, kwargs): + self.__dict__.update(kwargs) + for k, v in self.__dict__.items(): + Config.config_from_key(self, k, v) + + def __iter__(self): + for k in self.__dict__.keys(): + yield k + + @classmethod + def config_from_key(cls, target, k, v): + if isinstance(v, dict): + tmp = Config(v) + setattr(target, k, tmp) + else: + setattr(target, k, v) + + class TrainingConfig: """ From https://github.com/d8ahazard/sd_dreambooth_extension/blob/main/dreambooth/db_config.py @@ -1471,7 +1494,7 @@ def extract_checkpoint( return False logger.debug("trying to load: %s", original_config_file) - original_config = load_config(original_config_file) + original_config = Config(load_config(original_config_file)) num_train_timesteps = original_config.model.params.timesteps beta_start = original_config.model.params.linear_start