diff --git a/api/onnx_web/server/load.py b/api/onnx_web/server/load.py index 78f875ae..16bbfcc8 100644 --- a/api/onnx_web/server/load.py +++ b/api/onnx_web/server/load.py @@ -35,7 +35,7 @@ from ..image import ( # mask filters; noise sources from ..models.meta import NetworkModel from ..params import DeviceParams from ..torch_before_ort import get_available_providers -from ..utils import load_config, merge +from ..utils import load_config, merge, recursive_get from .context import ServerContext logger = getLogger(__name__) @@ -163,7 +163,8 @@ def get_source_filters(): def get_config_value(key: str, subkey: str = "default", default=None): - return config_params.get(key, {}).get(subkey, default) + val = recursive_get(config_params, key.split("."), default_value={}) + return val.get(subkey, default) def load_extras(server: ServerContext): diff --git a/api/onnx_web/server/params.py b/api/onnx_web/server/params.py index 0e2e7c98..2e0cc16e 100644 --- a/api/onnx_web/server/params.py +++ b/api/onnx_web/server/params.py @@ -355,30 +355,32 @@ def build_latent_symmetry( if data is None: data = request.args - enabled = get_boolean(data, "enabled", get_config_value("latentSymmetry")) + enabled = get_boolean( + data, "latentSymmetry.enabled", get_config_value("latentSymmetry.enabled") + ) gradient_start = get_and_clamp_float( data, - "gradientStart", - get_config_value("gradientStart"), - get_config_value("gradientStart", "max"), - get_config_value("gradientStart", "min"), + "latentSymmetry.gradientStart", + get_config_value("latentSymmetry.gradientStart"), + get_config_value("latentSymmetry.gradientStart", "max"), + get_config_value("latentSymmetry.gradientStart", "min"), ) gradient_end = get_and_clamp_float( data, - "gradientEnd", - get_config_value("gradientEnd"), - get_config_value("gradientEnd", "max"), - get_config_value("gradientEnd", "min"), + "latentSymmetry.gradientEnd", + get_config_value("latentSymmetry.gradientEnd"), + get_config_value("latentSymmetry.gradientEnd", "max"), + get_config_value("latentSymmetry.gradientEnd", "min"), ) line_of_symmetry = get_and_clamp_float( data, - "lineOfSymmetry", - get_config_value("lineOfSymmetry"), - get_config_value("lineOfSymmetry", "max"), - get_config_value("lineOfSymmetry", "min"), + "latentSymmetry.lineOfSymmetry", + get_config_value("latentSymmetry.lineOfSymmetry"), + get_config_value("latentSymmetry.lineOfSymmetry", "max"), + get_config_value("latentSymmetry.lineOfSymmetry", "min"), ) return LatentSymmetryParams(enabled, gradient_start, gradient_end, line_of_symmetry) @@ -390,17 +392,25 @@ def build_prompt_editing( if data is None: data = request.args - enabled = get_boolean(data, "enabled", get_config_value("promptEditing")) + enabled = get_boolean( + data, "promptEditing.enabled", get_config_value("promptEditing.enabled") + ) - prompt_filter = data.get("promptFilter", "") - remove_tokens = data.get("removeTokens", "") - add_suffix = data.get("addSuffix", "") + prompt_filter = data.get( + "promptEditing.filter", get_config_value("promptEditing.filter") + ) + remove_tokens = data.get( + "promptEditing.removeTokens", get_config_value("promptEditing.removeTokens") + ) + add_suffix = data.get( + "promptEditing.addSuffix", get_config_value("promptEditing.addSuffix") + ) min_length = get_and_clamp_int( data, - "minLength", - get_config_value("minLength"), - get_config_value("minLength", "max"), - get_config_value("minLength", "min"), + "promptEditing.minLength", + get_config_value("promptEditing.minLength"), + get_config_value("promptEditing.minLength", "max"), + get_config_value("promptEditing.minLength", "min"), ) return PromptEditingParams( diff --git a/api/onnx_web/utils.py b/api/onnx_web/utils.py index 5d36f918..74ebbd90 100644 --- a/api/onnx_web/utils.py +++ b/api/onnx_web/utils.py @@ -30,12 +30,18 @@ def is_debug() -> bool: return get_boolean(environ, "DEBUG", False) -def recursive_get(d, *keys): - return reduce(lambda c, k: c.get(k, {}), keys, d) +def recursive_get(d, keys, default_value=None): + empty_dict = {} + val = reduce(lambda c, k: c.get(k, empty_dict), keys, d) + + if val == empty_dict: + return default_value + + return val def get_boolean(args: Any, key: str, default_value: bool) -> bool: - val = args.get(key, str(default_value)) + val = recursive_get(args, key.split("."), default_value=str(default_value)) if isinstance(val, bool): return val @@ -44,19 +50,22 @@ def get_boolean(args: Any, key: str, default_value: bool) -> bool: def get_list(args: Any, key: str, default="") -> List[str]: - return split_list(args.get(key, default)) + val = recursive_get(args, key.split("."), default=default) + return split_list(val) def get_and_clamp_float( args: Any, key: str, default_value: float, max_value: float, min_value=0.0 ) -> float: - return min(max(float(args.get(key, default_value)), min_value), max_value) + val = recursive_get(args, key.split("."), default=default_value) + return min(max(float(val), min_value), max_value) def get_and_clamp_int( args: Any, key: str, default_value: int, max_value: int, min_value=1 ) -> int: - return min(max(int(args.get(key, default_value)), min_value), max_value) + val = recursive_get(args, key.split("."), default=default_value) + return min(max(int(val), min_value), max_value) TElem = TypeVar("TElem")