use nested config keys for experimental features
This commit is contained in:
parent
c05d74bda8
commit
500423424e
|
@ -35,7 +35,7 @@ from ..image import ( # mask filters; noise sources
|
||||||
from ..models.meta import NetworkModel
|
from ..models.meta import NetworkModel
|
||||||
from ..params import DeviceParams
|
from ..params import DeviceParams
|
||||||
from ..torch_before_ort import get_available_providers
|
from ..torch_before_ort import get_available_providers
|
||||||
from ..utils import load_config, merge
|
from ..utils import load_config, merge, recursive_get
|
||||||
from .context import ServerContext
|
from .context import ServerContext
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
@ -163,7 +163,8 @@ def get_source_filters():
|
||||||
|
|
||||||
|
|
||||||
def get_config_value(key: str, subkey: str = "default", default=None):
|
def get_config_value(key: str, subkey: str = "default", default=None):
|
||||||
return config_params.get(key, {}).get(subkey, default)
|
val = recursive_get(config_params, key.split("."), default_value={})
|
||||||
|
return val.get(subkey, default)
|
||||||
|
|
||||||
|
|
||||||
def load_extras(server: ServerContext):
|
def load_extras(server: ServerContext):
|
||||||
|
|
|
@ -355,30 +355,32 @@ def build_latent_symmetry(
|
||||||
if data is None:
|
if data is None:
|
||||||
data = request.args
|
data = request.args
|
||||||
|
|
||||||
enabled = get_boolean(data, "enabled", get_config_value("latentSymmetry"))
|
enabled = get_boolean(
|
||||||
|
data, "latentSymmetry.enabled", get_config_value("latentSymmetry.enabled")
|
||||||
|
)
|
||||||
|
|
||||||
gradient_start = get_and_clamp_float(
|
gradient_start = get_and_clamp_float(
|
||||||
data,
|
data,
|
||||||
"gradientStart",
|
"latentSymmetry.gradientStart",
|
||||||
get_config_value("gradientStart"),
|
get_config_value("latentSymmetry.gradientStart"),
|
||||||
get_config_value("gradientStart", "max"),
|
get_config_value("latentSymmetry.gradientStart", "max"),
|
||||||
get_config_value("gradientStart", "min"),
|
get_config_value("latentSymmetry.gradientStart", "min"),
|
||||||
)
|
)
|
||||||
|
|
||||||
gradient_end = get_and_clamp_float(
|
gradient_end = get_and_clamp_float(
|
||||||
data,
|
data,
|
||||||
"gradientEnd",
|
"latentSymmetry.gradientEnd",
|
||||||
get_config_value("gradientEnd"),
|
get_config_value("latentSymmetry.gradientEnd"),
|
||||||
get_config_value("gradientEnd", "max"),
|
get_config_value("latentSymmetry.gradientEnd", "max"),
|
||||||
get_config_value("gradientEnd", "min"),
|
get_config_value("latentSymmetry.gradientEnd", "min"),
|
||||||
)
|
)
|
||||||
|
|
||||||
line_of_symmetry = get_and_clamp_float(
|
line_of_symmetry = get_and_clamp_float(
|
||||||
data,
|
data,
|
||||||
"lineOfSymmetry",
|
"latentSymmetry.lineOfSymmetry",
|
||||||
get_config_value("lineOfSymmetry"),
|
get_config_value("latentSymmetry.lineOfSymmetry"),
|
||||||
get_config_value("lineOfSymmetry", "max"),
|
get_config_value("latentSymmetry.lineOfSymmetry", "max"),
|
||||||
get_config_value("lineOfSymmetry", "min"),
|
get_config_value("latentSymmetry.lineOfSymmetry", "min"),
|
||||||
)
|
)
|
||||||
|
|
||||||
return LatentSymmetryParams(enabled, gradient_start, gradient_end, line_of_symmetry)
|
return LatentSymmetryParams(enabled, gradient_start, gradient_end, line_of_symmetry)
|
||||||
|
@ -390,17 +392,25 @@ def build_prompt_editing(
|
||||||
if data is None:
|
if data is None:
|
||||||
data = request.args
|
data = request.args
|
||||||
|
|
||||||
enabled = get_boolean(data, "enabled", get_config_value("promptEditing"))
|
enabled = get_boolean(
|
||||||
|
data, "promptEditing.enabled", get_config_value("promptEditing.enabled")
|
||||||
|
)
|
||||||
|
|
||||||
prompt_filter = data.get("promptFilter", "")
|
prompt_filter = data.get(
|
||||||
remove_tokens = data.get("removeTokens", "")
|
"promptEditing.filter", get_config_value("promptEditing.filter")
|
||||||
add_suffix = data.get("addSuffix", "")
|
)
|
||||||
|
remove_tokens = data.get(
|
||||||
|
"promptEditing.removeTokens", get_config_value("promptEditing.removeTokens")
|
||||||
|
)
|
||||||
|
add_suffix = data.get(
|
||||||
|
"promptEditing.addSuffix", get_config_value("promptEditing.addSuffix")
|
||||||
|
)
|
||||||
min_length = get_and_clamp_int(
|
min_length = get_and_clamp_int(
|
||||||
data,
|
data,
|
||||||
"minLength",
|
"promptEditing.minLength",
|
||||||
get_config_value("minLength"),
|
get_config_value("promptEditing.minLength"),
|
||||||
get_config_value("minLength", "max"),
|
get_config_value("promptEditing.minLength", "max"),
|
||||||
get_config_value("minLength", "min"),
|
get_config_value("promptEditing.minLength", "min"),
|
||||||
)
|
)
|
||||||
|
|
||||||
return PromptEditingParams(
|
return PromptEditingParams(
|
||||||
|
|
|
@ -30,12 +30,18 @@ def is_debug() -> bool:
|
||||||
return get_boolean(environ, "DEBUG", False)
|
return get_boolean(environ, "DEBUG", False)
|
||||||
|
|
||||||
|
|
||||||
def recursive_get(d, *keys):
|
def recursive_get(d, keys, default_value=None):
|
||||||
return reduce(lambda c, k: c.get(k, {}), keys, d)
|
empty_dict = {}
|
||||||
|
val = reduce(lambda c, k: c.get(k, empty_dict), keys, d)
|
||||||
|
|
||||||
|
if val == empty_dict:
|
||||||
|
return default_value
|
||||||
|
|
||||||
|
return val
|
||||||
|
|
||||||
|
|
||||||
def get_boolean(args: Any, key: str, default_value: bool) -> bool:
|
def get_boolean(args: Any, key: str, default_value: bool) -> bool:
|
||||||
val = args.get(key, str(default_value))
|
val = recursive_get(args, key.split("."), default_value=str(default_value))
|
||||||
|
|
||||||
if isinstance(val, bool):
|
if isinstance(val, bool):
|
||||||
return val
|
return val
|
||||||
|
@ -44,19 +50,22 @@ def get_boolean(args: Any, key: str, default_value: bool) -> bool:
|
||||||
|
|
||||||
|
|
||||||
def get_list(args: Any, key: str, default="") -> List[str]:
|
def get_list(args: Any, key: str, default="") -> List[str]:
|
||||||
return split_list(args.get(key, default))
|
val = recursive_get(args, key.split("."), default=default)
|
||||||
|
return split_list(val)
|
||||||
|
|
||||||
|
|
||||||
def get_and_clamp_float(
|
def get_and_clamp_float(
|
||||||
args: Any, key: str, default_value: float, max_value: float, min_value=0.0
|
args: Any, key: str, default_value: float, max_value: float, min_value=0.0
|
||||||
) -> float:
|
) -> float:
|
||||||
return min(max(float(args.get(key, default_value)), min_value), max_value)
|
val = recursive_get(args, key.split("."), default=default_value)
|
||||||
|
return min(max(float(val), min_value), max_value)
|
||||||
|
|
||||||
|
|
||||||
def get_and_clamp_int(
|
def get_and_clamp_int(
|
||||||
args: Any, key: str, default_value: int, max_value: int, min_value=1
|
args: Any, key: str, default_value: int, max_value: int, min_value=1
|
||||||
) -> int:
|
) -> int:
|
||||||
return min(max(int(args.get(key, default_value)), min_value), max_value)
|
val = recursive_get(args, key.split("."), default=default_value)
|
||||||
|
return min(max(int(val), min_value), max_value)
|
||||||
|
|
||||||
|
|
||||||
TElem = TypeVar("TElem")
|
TElem = TypeVar("TElem")
|
||||||
|
|
Loading…
Reference in New Issue