1
0
Fork 0

replace pydantic polyfill

This commit is contained in:
Sean Sube 2023-05-20 20:19:41 -05:00
parent eb53238fd5
commit 1562f4011e
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 24 additions and 1 deletions

View File

@ -57,6 +57,29 @@ from ..utils import ConversionContext, load_tensor
logger = getLogger(__name__)
class Config(object):
"""
Shim for pydantic-style config.
"""
def __init__(self, kwargs):
self.__dict__.update(kwargs)
for k, v in self.__dict__.items():
Config.config_from_key(self, k, v)
def __iter__(self):
for k in self.__dict__.keys():
yield k
@classmethod
def config_from_key(cls, target, k, v):
if isinstance(v, dict):
tmp = Config(v)
setattr(target, k, tmp)
else:
setattr(target, k, v)
class TrainingConfig:
"""
From https://github.com/d8ahazard/sd_dreambooth_extension/blob/main/dreambooth/db_config.py
@ -1471,7 +1494,7 @@ def extract_checkpoint(
return False
logger.debug("trying to load: %s", original_config_file)
original_config = load_config(original_config_file)
original_config = Config(load_config(original_config_file))
num_train_timesteps = original_config.model.params.timesteps
beta_start = original_config.model.params.linear_start