1
0
Fork 0

use nested config keys for experimental features

This commit is contained in:
Sean Sube 2024-02-24 11:21:51 -06:00
parent c05d74bda8
commit 500423424e
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 49 additions and 29 deletions

View File

@ -35,7 +35,7 @@ from ..image import ( # mask filters; noise sources
from ..models.meta import NetworkModel
from ..params import DeviceParams
from ..torch_before_ort import get_available_providers
from ..utils import load_config, merge
from ..utils import load_config, merge, recursive_get
from .context import ServerContext
logger = getLogger(__name__)
@ -163,7 +163,8 @@ def get_source_filters():
def get_config_value(key: str, subkey: str = "default", default=None):
return config_params.get(key, {}).get(subkey, default)
val = recursive_get(config_params, key.split("."), default_value={})
return val.get(subkey, default)
def load_extras(server: ServerContext):

View File

@ -355,30 +355,32 @@ def build_latent_symmetry(
if data is None:
data = request.args
enabled = get_boolean(data, "enabled", get_config_value("latentSymmetry"))
enabled = get_boolean(
data, "latentSymmetry.enabled", get_config_value("latentSymmetry.enabled")
)
gradient_start = get_and_clamp_float(
data,
"gradientStart",
get_config_value("gradientStart"),
get_config_value("gradientStart", "max"),
get_config_value("gradientStart", "min"),
"latentSymmetry.gradientStart",
get_config_value("latentSymmetry.gradientStart"),
get_config_value("latentSymmetry.gradientStart", "max"),
get_config_value("latentSymmetry.gradientStart", "min"),
)
gradient_end = get_and_clamp_float(
data,
"gradientEnd",
get_config_value("gradientEnd"),
get_config_value("gradientEnd", "max"),
get_config_value("gradientEnd", "min"),
"latentSymmetry.gradientEnd",
get_config_value("latentSymmetry.gradientEnd"),
get_config_value("latentSymmetry.gradientEnd", "max"),
get_config_value("latentSymmetry.gradientEnd", "min"),
)
line_of_symmetry = get_and_clamp_float(
data,
"lineOfSymmetry",
get_config_value("lineOfSymmetry"),
get_config_value("lineOfSymmetry", "max"),
get_config_value("lineOfSymmetry", "min"),
"latentSymmetry.lineOfSymmetry",
get_config_value("latentSymmetry.lineOfSymmetry"),
get_config_value("latentSymmetry.lineOfSymmetry", "max"),
get_config_value("latentSymmetry.lineOfSymmetry", "min"),
)
return LatentSymmetryParams(enabled, gradient_start, gradient_end, line_of_symmetry)
@ -390,17 +392,25 @@ def build_prompt_editing(
if data is None:
data = request.args
enabled = get_boolean(data, "enabled", get_config_value("promptEditing"))
enabled = get_boolean(
data, "promptEditing.enabled", get_config_value("promptEditing.enabled")
)
prompt_filter = data.get("promptFilter", "")
remove_tokens = data.get("removeTokens", "")
add_suffix = data.get("addSuffix", "")
prompt_filter = data.get(
"promptEditing.filter", get_config_value("promptEditing.filter")
)
remove_tokens = data.get(
"promptEditing.removeTokens", get_config_value("promptEditing.removeTokens")
)
add_suffix = data.get(
"promptEditing.addSuffix", get_config_value("promptEditing.addSuffix")
)
min_length = get_and_clamp_int(
data,
"minLength",
get_config_value("minLength"),
get_config_value("minLength", "max"),
get_config_value("minLength", "min"),
"promptEditing.minLength",
get_config_value("promptEditing.minLength"),
get_config_value("promptEditing.minLength", "max"),
get_config_value("promptEditing.minLength", "min"),
)
return PromptEditingParams(

View File

@ -30,12 +30,18 @@ def is_debug() -> bool:
return get_boolean(environ, "DEBUG", False)
def recursive_get(d, *keys):
return reduce(lambda c, k: c.get(k, {}), keys, d)
def recursive_get(d, keys, default_value=None):
empty_dict = {}
val = reduce(lambda c, k: c.get(k, empty_dict), keys, d)
if val == empty_dict:
return default_value
return val
def get_boolean(args: Any, key: str, default_value: bool) -> bool:
val = args.get(key, str(default_value))
val = recursive_get(args, key.split("."), default_value=str(default_value))
if isinstance(val, bool):
return val
@ -44,19 +50,22 @@ def get_boolean(args: Any, key: str, default_value: bool) -> bool:
def get_list(args: Any, key: str, default="") -> List[str]:
return split_list(args.get(key, default))
val = recursive_get(args, key.split("."), default=default)
return split_list(val)
def get_and_clamp_float(
args: Any, key: str, default_value: float, max_value: float, min_value=0.0
) -> float:
return min(max(float(args.get(key, default_value)), min_value), max_value)
val = recursive_get(args, key.split("."), default=default_value)
return min(max(float(val), min_value), max_value)
def get_and_clamp_int(
args: Any, key: str, default_value: int, max_value: int, min_value=1
) -> int:
return min(max(int(args.get(key, default_value)), min_value), max_value)
val = recursive_get(args, key.split("."), default=default_value)
return min(max(int(val), min_value), max_value)
TElem = TypeVar("TElem")