replace pydantic polyfill
This commit is contained in:
parent
eb53238fd5
commit
1562f4011e
|
@ -57,6 +57,29 @@ from ..utils import ConversionContext, load_tensor
|
|||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
class Config(object):
|
||||
"""
|
||||
Shim for pydantic-style config.
|
||||
"""
|
||||
|
||||
def __init__(self, kwargs):
|
||||
self.__dict__.update(kwargs)
|
||||
for k, v in self.__dict__.items():
|
||||
Config.config_from_key(self, k, v)
|
||||
|
||||
def __iter__(self):
|
||||
for k in self.__dict__.keys():
|
||||
yield k
|
||||
|
||||
@classmethod
|
||||
def config_from_key(cls, target, k, v):
|
||||
if isinstance(v, dict):
|
||||
tmp = Config(v)
|
||||
setattr(target, k, tmp)
|
||||
else:
|
||||
setattr(target, k, v)
|
||||
|
||||
|
||||
class TrainingConfig:
|
||||
"""
|
||||
From https://github.com/d8ahazard/sd_dreambooth_extension/blob/main/dreambooth/db_config.py
|
||||
|
@ -1471,7 +1494,7 @@ def extract_checkpoint(
|
|||
return False
|
||||
|
||||
logger.debug("trying to load: %s", original_config_file)
|
||||
original_config = load_config(original_config_file)
|
||||
original_config = Config(load_config(original_config_file))
|
||||
|
||||
num_train_timesteps = original_config.model.params.timesteps
|
||||
beta_start = original_config.model.params.linear_start
|
||||
|
|
Loading…
Reference in New Issue