use nested config keys for experimental features
This commit is contained in:
parent
c05d74bda8
commit
500423424e
|
@ -35,7 +35,7 @@ from ..image import ( # mask filters; noise sources
|
|||
from ..models.meta import NetworkModel
|
||||
from ..params import DeviceParams
|
||||
from ..torch_before_ort import get_available_providers
|
||||
from ..utils import load_config, merge
|
||||
from ..utils import load_config, merge, recursive_get
|
||||
from .context import ServerContext
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
@ -163,7 +163,8 @@ def get_source_filters():
|
|||
|
||||
|
||||
def get_config_value(key: str, subkey: str = "default", default=None):
|
||||
return config_params.get(key, {}).get(subkey, default)
|
||||
val = recursive_get(config_params, key.split("."), default_value={})
|
||||
return val.get(subkey, default)
|
||||
|
||||
|
||||
def load_extras(server: ServerContext):
|
||||
|
|
|
@ -355,30 +355,32 @@ def build_latent_symmetry(
|
|||
if data is None:
|
||||
data = request.args
|
||||
|
||||
enabled = get_boolean(data, "enabled", get_config_value("latentSymmetry"))
|
||||
enabled = get_boolean(
|
||||
data, "latentSymmetry.enabled", get_config_value("latentSymmetry.enabled")
|
||||
)
|
||||
|
||||
gradient_start = get_and_clamp_float(
|
||||
data,
|
||||
"gradientStart",
|
||||
get_config_value("gradientStart"),
|
||||
get_config_value("gradientStart", "max"),
|
||||
get_config_value("gradientStart", "min"),
|
||||
"latentSymmetry.gradientStart",
|
||||
get_config_value("latentSymmetry.gradientStart"),
|
||||
get_config_value("latentSymmetry.gradientStart", "max"),
|
||||
get_config_value("latentSymmetry.gradientStart", "min"),
|
||||
)
|
||||
|
||||
gradient_end = get_and_clamp_float(
|
||||
data,
|
||||
"gradientEnd",
|
||||
get_config_value("gradientEnd"),
|
||||
get_config_value("gradientEnd", "max"),
|
||||
get_config_value("gradientEnd", "min"),
|
||||
"latentSymmetry.gradientEnd",
|
||||
get_config_value("latentSymmetry.gradientEnd"),
|
||||
get_config_value("latentSymmetry.gradientEnd", "max"),
|
||||
get_config_value("latentSymmetry.gradientEnd", "min"),
|
||||
)
|
||||
|
||||
line_of_symmetry = get_and_clamp_float(
|
||||
data,
|
||||
"lineOfSymmetry",
|
||||
get_config_value("lineOfSymmetry"),
|
||||
get_config_value("lineOfSymmetry", "max"),
|
||||
get_config_value("lineOfSymmetry", "min"),
|
||||
"latentSymmetry.lineOfSymmetry",
|
||||
get_config_value("latentSymmetry.lineOfSymmetry"),
|
||||
get_config_value("latentSymmetry.lineOfSymmetry", "max"),
|
||||
get_config_value("latentSymmetry.lineOfSymmetry", "min"),
|
||||
)
|
||||
|
||||
return LatentSymmetryParams(enabled, gradient_start, gradient_end, line_of_symmetry)
|
||||
|
@ -390,17 +392,25 @@ def build_prompt_editing(
|
|||
if data is None:
|
||||
data = request.args
|
||||
|
||||
enabled = get_boolean(data, "enabled", get_config_value("promptEditing"))
|
||||
enabled = get_boolean(
|
||||
data, "promptEditing.enabled", get_config_value("promptEditing.enabled")
|
||||
)
|
||||
|
||||
prompt_filter = data.get("promptFilter", "")
|
||||
remove_tokens = data.get("removeTokens", "")
|
||||
add_suffix = data.get("addSuffix", "")
|
||||
prompt_filter = data.get(
|
||||
"promptEditing.filter", get_config_value("promptEditing.filter")
|
||||
)
|
||||
remove_tokens = data.get(
|
||||
"promptEditing.removeTokens", get_config_value("promptEditing.removeTokens")
|
||||
)
|
||||
add_suffix = data.get(
|
||||
"promptEditing.addSuffix", get_config_value("promptEditing.addSuffix")
|
||||
)
|
||||
min_length = get_and_clamp_int(
|
||||
data,
|
||||
"minLength",
|
||||
get_config_value("minLength"),
|
||||
get_config_value("minLength", "max"),
|
||||
get_config_value("minLength", "min"),
|
||||
"promptEditing.minLength",
|
||||
get_config_value("promptEditing.minLength"),
|
||||
get_config_value("promptEditing.minLength", "max"),
|
||||
get_config_value("promptEditing.minLength", "min"),
|
||||
)
|
||||
|
||||
return PromptEditingParams(
|
||||
|
|
|
@ -30,12 +30,18 @@ def is_debug() -> bool:
|
|||
return get_boolean(environ, "DEBUG", False)
|
||||
|
||||
|
||||
def recursive_get(d, *keys):
|
||||
return reduce(lambda c, k: c.get(k, {}), keys, d)
|
||||
def recursive_get(d, keys, default_value=None):
|
||||
empty_dict = {}
|
||||
val = reduce(lambda c, k: c.get(k, empty_dict), keys, d)
|
||||
|
||||
if val == empty_dict:
|
||||
return default_value
|
||||
|
||||
return val
|
||||
|
||||
|
||||
def get_boolean(args: Any, key: str, default_value: bool) -> bool:
|
||||
val = args.get(key, str(default_value))
|
||||
val = recursive_get(args, key.split("."), default_value=str(default_value))
|
||||
|
||||
if isinstance(val, bool):
|
||||
return val
|
||||
|
@ -44,19 +50,22 @@ def get_boolean(args: Any, key: str, default_value: bool) -> bool:
|
|||
|
||||
|
||||
def get_list(args: Any, key: str, default="") -> List[str]:
|
||||
return split_list(args.get(key, default))
|
||||
val = recursive_get(args, key.split("."), default=default)
|
||||
return split_list(val)
|
||||
|
||||
|
||||
def get_and_clamp_float(
|
||||
args: Any, key: str, default_value: float, max_value: float, min_value=0.0
|
||||
) -> float:
|
||||
return min(max(float(args.get(key, default_value)), min_value), max_value)
|
||||
val = recursive_get(args, key.split("."), default=default_value)
|
||||
return min(max(float(val), min_value), max_value)
|
||||
|
||||
|
||||
def get_and_clamp_int(
|
||||
args: Any, key: str, default_value: int, max_value: int, min_value=1
|
||||
) -> int:
|
||||
return min(max(int(args.get(key, default_value)), min_value), max_value)
|
||||
val = recursive_get(args, key.split("."), default=default_value)
|
||||
return min(max(int(val), min_value), max_value)
|
||||
|
||||
|
||||
TElem = TypeVar("TElem")
|
||||
|
|
Loading…
Reference in New Issue