1
0
Fork 0

use nested config keys for experimental features

This commit is contained in:
Sean Sube 2024-02-24 11:21:51 -06:00
parent c05d74bda8
commit 500423424e
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 49 additions and 29 deletions

View File

@ -35,7 +35,7 @@ from ..image import ( # mask filters; noise sources
from ..models.meta import NetworkModel from ..models.meta import NetworkModel
from ..params import DeviceParams from ..params import DeviceParams
from ..torch_before_ort import get_available_providers from ..torch_before_ort import get_available_providers
from ..utils import load_config, merge from ..utils import load_config, merge, recursive_get
from .context import ServerContext from .context import ServerContext
logger = getLogger(__name__) logger = getLogger(__name__)
@ -163,7 +163,8 @@ def get_source_filters():
def get_config_value(key: str, subkey: str = "default", default=None): def get_config_value(key: str, subkey: str = "default", default=None):
return config_params.get(key, {}).get(subkey, default) val = recursive_get(config_params, key.split("."), default_value={})
return val.get(subkey, default)
def load_extras(server: ServerContext): def load_extras(server: ServerContext):

View File

@ -355,30 +355,32 @@ def build_latent_symmetry(
if data is None: if data is None:
data = request.args data = request.args
enabled = get_boolean(data, "enabled", get_config_value("latentSymmetry")) enabled = get_boolean(
data, "latentSymmetry.enabled", get_config_value("latentSymmetry.enabled")
)
gradient_start = get_and_clamp_float( gradient_start = get_and_clamp_float(
data, data,
"gradientStart", "latentSymmetry.gradientStart",
get_config_value("gradientStart"), get_config_value("latentSymmetry.gradientStart"),
get_config_value("gradientStart", "max"), get_config_value("latentSymmetry.gradientStart", "max"),
get_config_value("gradientStart", "min"), get_config_value("latentSymmetry.gradientStart", "min"),
) )
gradient_end = get_and_clamp_float( gradient_end = get_and_clamp_float(
data, data,
"gradientEnd", "latentSymmetry.gradientEnd",
get_config_value("gradientEnd"), get_config_value("latentSymmetry.gradientEnd"),
get_config_value("gradientEnd", "max"), get_config_value("latentSymmetry.gradientEnd", "max"),
get_config_value("gradientEnd", "min"), get_config_value("latentSymmetry.gradientEnd", "min"),
) )
line_of_symmetry = get_and_clamp_float( line_of_symmetry = get_and_clamp_float(
data, data,
"lineOfSymmetry", "latentSymmetry.lineOfSymmetry",
get_config_value("lineOfSymmetry"), get_config_value("latentSymmetry.lineOfSymmetry"),
get_config_value("lineOfSymmetry", "max"), get_config_value("latentSymmetry.lineOfSymmetry", "max"),
get_config_value("lineOfSymmetry", "min"), get_config_value("latentSymmetry.lineOfSymmetry", "min"),
) )
return LatentSymmetryParams(enabled, gradient_start, gradient_end, line_of_symmetry) return LatentSymmetryParams(enabled, gradient_start, gradient_end, line_of_symmetry)
@ -390,17 +392,25 @@ def build_prompt_editing(
if data is None: if data is None:
data = request.args data = request.args
enabled = get_boolean(data, "enabled", get_config_value("promptEditing")) enabled = get_boolean(
data, "promptEditing.enabled", get_config_value("promptEditing.enabled")
)
prompt_filter = data.get("promptFilter", "") prompt_filter = data.get(
remove_tokens = data.get("removeTokens", "") "promptEditing.filter", get_config_value("promptEditing.filter")
add_suffix = data.get("addSuffix", "") )
remove_tokens = data.get(
"promptEditing.removeTokens", get_config_value("promptEditing.removeTokens")
)
add_suffix = data.get(
"promptEditing.addSuffix", get_config_value("promptEditing.addSuffix")
)
min_length = get_and_clamp_int( min_length = get_and_clamp_int(
data, data,
"minLength", "promptEditing.minLength",
get_config_value("minLength"), get_config_value("promptEditing.minLength"),
get_config_value("minLength", "max"), get_config_value("promptEditing.minLength", "max"),
get_config_value("minLength", "min"), get_config_value("promptEditing.minLength", "min"),
) )
return PromptEditingParams( return PromptEditingParams(

View File

@ -30,12 +30,18 @@ def is_debug() -> bool:
return get_boolean(environ, "DEBUG", False) return get_boolean(environ, "DEBUG", False)
def recursive_get(d, *keys): def recursive_get(d, keys, default_value=None):
return reduce(lambda c, k: c.get(k, {}), keys, d) empty_dict = {}
val = reduce(lambda c, k: c.get(k, empty_dict), keys, d)
if val == empty_dict:
return default_value
return val
def get_boolean(args: Any, key: str, default_value: bool) -> bool: def get_boolean(args: Any, key: str, default_value: bool) -> bool:
val = args.get(key, str(default_value)) val = recursive_get(args, key.split("."), default_value=str(default_value))
if isinstance(val, bool): if isinstance(val, bool):
return val return val
@ -44,19 +50,22 @@ def get_boolean(args: Any, key: str, default_value: bool) -> bool:
def get_list(args: Any, key: str, default="") -> List[str]: def get_list(args: Any, key: str, default="") -> List[str]:
return split_list(args.get(key, default)) val = recursive_get(args, key.split("."), default=default)
return split_list(val)
def get_and_clamp_float( def get_and_clamp_float(
args: Any, key: str, default_value: float, max_value: float, min_value=0.0 args: Any, key: str, default_value: float, max_value: float, min_value=0.0
) -> float: ) -> float:
return min(max(float(args.get(key, default_value)), min_value), max_value) val = recursive_get(args, key.split("."), default=default_value)
return min(max(float(val), min_value), max_value)
def get_and_clamp_int( def get_and_clamp_int(
args: Any, key: str, default_value: int, max_value: int, min_value=1 args: Any, key: str, default_value: int, max_value: int, min_value=1
) -> int: ) -> int:
return min(max(int(args.get(key, default_value)), min_value), max_value) val = recursive_get(args, key.split("."), default=default_value)
return min(max(int(val), min_value), max_value)
TElem = TypeVar("TElem") TElem = TypeVar("TElem")