From 5d459ab17c6f6ede6e563e11d4294591c426d90c Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 4 Mar 2023 22:57:31 -0600 Subject: [PATCH] feat: allow users to add their own labels for models (#144) --- api/onnx_web/convert/__main__.py | 20 ++++++---- api/onnx_web/server/api.py | 5 +++ api/onnx_web/server/context.py | 3 ++ api/onnx_web/server/load.py | 65 +++++++++++++++++++++++++++++++- api/onnx_web/utils.py | 16 ++++++++ api/schemas/extras.yaml | 9 ++++- gui/src/client.ts | 10 +++++ gui/src/main.tsx | 5 ++- 8 files changed, 123 insertions(+), 10 deletions(-) diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py index ccfef00f..8133f07a 100644 --- a/api/onnx_web/convert/__main__.py +++ b/api/onnx_web/convert/__main__.py @@ -353,20 +353,26 @@ def main() -> int: logger.info("converting base models") convert_models(ctx, args, base_models) - for file in args.extras: + extras = [] + extras.extend(ctx.extra_models) + extras.extend(args.extras) + extras = list(set(extras)) + extras.sort() + logger.debug("loading extra files: %s", extras) + + with open("./schemas/extras.yaml", "r") as f: + extra_schema = safe_load(f.read()) + + for file in extras: if file is not None and file != "": logger.info("loading extra models from %s", file) try: with open(file, "r") as f: data = safe_load(f.read()) - with open("./schemas/extras.yaml", "r") as f: - schema = safe_load(f.read()) - - logger.debug("validating chain request: %s against %s", data, schema) - + logger.debug("validating extras file %s", data) try: - validate(data, schema) + validate(data, extra_schema) logger.info("converting extra models") convert_models(ctx, args, data) except ValidationError as err: diff --git a/api/onnx_web/server/api.py b/api/onnx_web/server/api.py index d9b926ad..47c5773a 100644 --- a/api/onnx_web/server/api.py +++ b/api/onnx_web/server/api.py @@ -88,6 +88,10 @@ def introspect(context: ServerContext, app: Flask): } +def get_extra_strings(context: ServerContext): + return jsonify(get_extra_strings()) + + def list_mask_filters(context: ServerContext): return jsonify(list(get_mask_filters().keys())) @@ -464,6 +468,7 @@ def register_api_routes(app: Flask, context: ServerContext, pool: DevicePoolExec app.route("/api/settings/params")(wrap_route(list_params, context)), app.route("/api/settings/platforms")(wrap_route(list_platforms, context)), app.route("/api/settings/schedulers")(wrap_route(list_schedulers, context)), + app.route("/api/settings/strings")(wrap_route(get_extra_strings, context)), app.route("/api/img2img", methods=["POST"])( wrap_route(img2img, context, pool=pool) ), diff --git a/api/onnx_web/server/context.py b/api/onnx_web/server/context.py index 3107a79b..2e60bd41 100644 --- a/api/onnx_web/server/context.py +++ b/api/onnx_web/server/context.py @@ -25,6 +25,7 @@ class ServerContext: cache_path: Optional[str] = None, show_progress: bool = True, optimizations: Optional[List[str]] = None, + extra_models: Optional[List[str]] = None, ) -> None: self.bundle_path = bundle_path self.model_path = model_path @@ -40,6 +41,7 @@ class ServerContext: self.cache_path = cache_path or path.join(model_path, ".cache") self.show_progress = show_progress self.optimizations = optimizations or [] + self.extra_models = extra_models or [] @classmethod def from_environ(cls): @@ -63,4 +65,5 @@ class ServerContext: cache=ModelCache(limit=cache_limit), show_progress=get_boolean(environ, "ONNX_WEB_SHOW_PROGRESS", True), optimizations=environ.get("ONNX_WEB_OPTIMIZATIONS", "").split(","), + extra_models=environ.get("ONNX_WEB_EXTRA_MODELS", "").split(","), ) diff --git a/api/onnx_web/server/load.py b/api/onnx_web/server/load.py index 18ee2cf5..83095912 100644 --- a/api/onnx_web/server/load.py +++ b/api/onnx_web/server/load.py @@ -2,11 +2,14 @@ from functools import cmp_to_key from glob import glob from logging import getLogger from os import path -from typing import Dict, List, Union +from typing import Any, Dict, List, Union +from jsonschema import ValidationError, validate import torch import yaml +from yaml import safe_load +from ..utils import merge from ..image import ( # mask filters; noise sources mask_filter_gaussian_multiply, mask_filter_gaussian_screen, @@ -58,6 +61,9 @@ diffusion_models: List[str] = [] inversion_models: List[str] = [] upscaling_models: List[str] = [] +# Loaded from extra_models +extra_strings: Dict[str, Any] = {} + def get_config_params(): return config_params @@ -83,6 +89,10 @@ def get_upscaling_models(): return upscaling_models +def get_extra_strings(): + return extra_strings + + def get_mask_filters(): return mask_filters @@ -101,6 +111,59 @@ def get_model_name(model: str) -> str: return file +def load_extras(context: ServerContext): + """ + Load the extras file(s) and collect the relevant parts for the server: labels and strings + """ + global extra_strings + + labels = {} + strings = {} + + with open("./schemas/extras.yaml", "r") as f: + extra_schema = safe_load(f.read()) + + for file in context.extra_models: + if file is not None and file != "": + logger.info("loading extra models from %s", file) + try: + with open(file, "r") as f: + data = safe_load(f.read()) + + logger.debug("validating extras file %s", data) + try: + validate(data, extra_schema) + except ValidationError as err: + logger.error("invalid data in extras file: %s", err) + continue + + if "strings" in data: + logger.debug("collecting strings from %s", file) + merge(strings, data["strings"]) + + for model_type in ["diffusion", "correction", "upscaling"]: + if model_type in data: + for model in data[model_type]: + if "label" in model: + model_name = model["name"] + logger.debug("collecting label for model %s from %s", model_name, file) + labels[model_name] = model["label"] + + except Exception as err: + logger.error("error loading extras file: %s", err) + + logger.debug("adding labels to strings: %s", labels) + merge(strings, { + "en": { + "translation": { + "model": labels, + } + } + }) + + extra_strings = strings + + def list_model_globs(context: ServerContext, globs: List[str]) -> List[str]: models = [] for pattern in globs: diff --git a/api/onnx_web/utils.py b/api/onnx_web/utils.py index 74f998f8..52445c6d 100644 --- a/api/onnx_web/utils.py +++ b/api/onnx_web/utils.py @@ -106,3 +106,19 @@ def run_gc(devices: Optional[List[DeviceParams]] = None): def sanitize_name(name): return "".join(x for x in name if (x.isalnum() or x in SAFE_CHARS)) + + +def merge(a, b, path=None): + "merges b into a" + if path is None: path = [] + for key in b: + if key in a: + if isinstance(a[key], dict) and isinstance(b[key], dict): + merge(a[key], b[key], path + [str(key)]) + elif a[key] == b[key]: + pass # same leaf value + else: + raise Exception("Conflict at %s" % '.'.join(path + [str(key)])) + else: + a[key] = b[key] + return a diff --git a/api/schemas/extras.yaml b/api/schemas/extras.yaml index 6d8b8203..97844441 100644 --- a/api/schemas/extras.yaml +++ b/api/schemas/extras.yaml @@ -21,6 +21,8 @@ $defs: format: type: string enum: [concept, embeddings] + label: + type: string token: type: string @@ -33,6 +35,8 @@ $defs: enum: [onnx, pth, ckpt, safetensors] half: type: boolean + label: + type: string name: type: string opset: @@ -104,4 +108,7 @@ properties: items: oneOf: - $ref: "#/$defs/legacy_tuple" - - $ref: "#/$defs/source_model" \ No newline at end of file + - $ref: "#/$defs/source_model" + strings: + type: object + # /\w{2}/: translation: {} \ No newline at end of file diff --git a/gui/src/client.ts b/gui/src/client.ts index d16e70c9..bd67b480 100644 --- a/gui/src/client.ts +++ b/gui/src/client.ts @@ -220,6 +220,11 @@ export interface ApiClient { */ schedulers(): Promise>; + /** + * Load extra strings from the server. + */ + strings(): Promise>; + /** * Start a txt2img pipeline. */ @@ -389,6 +394,11 @@ export function makeClient(root: string, f = fetch): ApiClient { const res = await f(path); return await res.json() as Array; }, + async strings(): Promise> { + const path = makeApiUrl(root, 'settings', 'strings'); + const res = await f(path); + return await res.json() as Record; + }, async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams): Promise { const url = makeImageURL(root, 'img2img', params); appendModelToURL(url, model); diff --git a/gui/src/main.tsx b/gui/src/main.tsx index ed93988e..1214eafa 100644 --- a/gui/src/main.tsx +++ b/gui/src/main.tsx @@ -64,7 +64,10 @@ export async function main() { returnEmptyString: false, }); - i18n.addResourceBundle(i18n.resolvedLanguage, 'model', params.model.keys); + const strings = await client.strings(); + for (const [lang, data] of Object.entries(strings)) { + i18n.addResourceBundle(lang, 'translation', data, true); + } // prep zustand with a slice for each tab, using local storage const {