feat: allow users to add their own labels for models (#144)
This commit is contained in:
parent
628812fb0b
commit
5d459ab17c
|
@ -353,20 +353,26 @@ def main() -> int:
|
|||
logger.info("converting base models")
|
||||
convert_models(ctx, args, base_models)
|
||||
|
||||
for file in args.extras:
|
||||
extras = []
|
||||
extras.extend(ctx.extra_models)
|
||||
extras.extend(args.extras)
|
||||
extras = list(set(extras))
|
||||
extras.sort()
|
||||
logger.debug("loading extra files: %s", extras)
|
||||
|
||||
with open("./schemas/extras.yaml", "r") as f:
|
||||
extra_schema = safe_load(f.read())
|
||||
|
||||
for file in extras:
|
||||
if file is not None and file != "":
|
||||
logger.info("loading extra models from %s", file)
|
||||
try:
|
||||
with open(file, "r") as f:
|
||||
data = safe_load(f.read())
|
||||
|
||||
with open("./schemas/extras.yaml", "r") as f:
|
||||
schema = safe_load(f.read())
|
||||
|
||||
logger.debug("validating chain request: %s against %s", data, schema)
|
||||
|
||||
logger.debug("validating extras file %s", data)
|
||||
try:
|
||||
validate(data, schema)
|
||||
validate(data, extra_schema)
|
||||
logger.info("converting extra models")
|
||||
convert_models(ctx, args, data)
|
||||
except ValidationError as err:
|
||||
|
|
|
@ -88,6 +88,10 @@ def introspect(context: ServerContext, app: Flask):
|
|||
}
|
||||
|
||||
|
||||
def get_extra_strings(context: ServerContext):
|
||||
return jsonify(get_extra_strings())
|
||||
|
||||
|
||||
def list_mask_filters(context: ServerContext):
|
||||
return jsonify(list(get_mask_filters().keys()))
|
||||
|
||||
|
@ -464,6 +468,7 @@ def register_api_routes(app: Flask, context: ServerContext, pool: DevicePoolExec
|
|||
app.route("/api/settings/params")(wrap_route(list_params, context)),
|
||||
app.route("/api/settings/platforms")(wrap_route(list_platforms, context)),
|
||||
app.route("/api/settings/schedulers")(wrap_route(list_schedulers, context)),
|
||||
app.route("/api/settings/strings")(wrap_route(get_extra_strings, context)),
|
||||
app.route("/api/img2img", methods=["POST"])(
|
||||
wrap_route(img2img, context, pool=pool)
|
||||
),
|
||||
|
|
|
@ -25,6 +25,7 @@ class ServerContext:
|
|||
cache_path: Optional[str] = None,
|
||||
show_progress: bool = True,
|
||||
optimizations: Optional[List[str]] = None,
|
||||
extra_models: Optional[List[str]] = None,
|
||||
) -> None:
|
||||
self.bundle_path = bundle_path
|
||||
self.model_path = model_path
|
||||
|
@ -40,6 +41,7 @@ class ServerContext:
|
|||
self.cache_path = cache_path or path.join(model_path, ".cache")
|
||||
self.show_progress = show_progress
|
||||
self.optimizations = optimizations or []
|
||||
self.extra_models = extra_models or []
|
||||
|
||||
@classmethod
|
||||
def from_environ(cls):
|
||||
|
@ -63,4 +65,5 @@ class ServerContext:
|
|||
cache=ModelCache(limit=cache_limit),
|
||||
show_progress=get_boolean(environ, "ONNX_WEB_SHOW_PROGRESS", True),
|
||||
optimizations=environ.get("ONNX_WEB_OPTIMIZATIONS", "").split(","),
|
||||
extra_models=environ.get("ONNX_WEB_EXTRA_MODELS", "").split(","),
|
||||
)
|
||||
|
|
|
@ -2,11 +2,14 @@ from functools import cmp_to_key
|
|||
from glob import glob
|
||||
from logging import getLogger
|
||||
from os import path
|
||||
from typing import Dict, List, Union
|
||||
from typing import Any, Dict, List, Union
|
||||
from jsonschema import ValidationError, validate
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
from yaml import safe_load
|
||||
|
||||
from ..utils import merge
|
||||
from ..image import ( # mask filters; noise sources
|
||||
mask_filter_gaussian_multiply,
|
||||
mask_filter_gaussian_screen,
|
||||
|
@ -58,6 +61,9 @@ diffusion_models: List[str] = []
|
|||
inversion_models: List[str] = []
|
||||
upscaling_models: List[str] = []
|
||||
|
||||
# Loaded from extra_models
|
||||
extra_strings: Dict[str, Any] = {}
|
||||
|
||||
|
||||
def get_config_params():
|
||||
return config_params
|
||||
|
@ -83,6 +89,10 @@ def get_upscaling_models():
|
|||
return upscaling_models
|
||||
|
||||
|
||||
def get_extra_strings():
|
||||
return extra_strings
|
||||
|
||||
|
||||
def get_mask_filters():
|
||||
return mask_filters
|
||||
|
||||
|
@ -101,6 +111,59 @@ def get_model_name(model: str) -> str:
|
|||
return file
|
||||
|
||||
|
||||
def load_extras(context: ServerContext):
|
||||
"""
|
||||
Load the extras file(s) and collect the relevant parts for the server: labels and strings
|
||||
"""
|
||||
global extra_strings
|
||||
|
||||
labels = {}
|
||||
strings = {}
|
||||
|
||||
with open("./schemas/extras.yaml", "r") as f:
|
||||
extra_schema = safe_load(f.read())
|
||||
|
||||
for file in context.extra_models:
|
||||
if file is not None and file != "":
|
||||
logger.info("loading extra models from %s", file)
|
||||
try:
|
||||
with open(file, "r") as f:
|
||||
data = safe_load(f.read())
|
||||
|
||||
logger.debug("validating extras file %s", data)
|
||||
try:
|
||||
validate(data, extra_schema)
|
||||
except ValidationError as err:
|
||||
logger.error("invalid data in extras file: %s", err)
|
||||
continue
|
||||
|
||||
if "strings" in data:
|
||||
logger.debug("collecting strings from %s", file)
|
||||
merge(strings, data["strings"])
|
||||
|
||||
for model_type in ["diffusion", "correction", "upscaling"]:
|
||||
if model_type in data:
|
||||
for model in data[model_type]:
|
||||
if "label" in model:
|
||||
model_name = model["name"]
|
||||
logger.debug("collecting label for model %s from %s", model_name, file)
|
||||
labels[model_name] = model["label"]
|
||||
|
||||
except Exception as err:
|
||||
logger.error("error loading extras file: %s", err)
|
||||
|
||||
logger.debug("adding labels to strings: %s", labels)
|
||||
merge(strings, {
|
||||
"en": {
|
||||
"translation": {
|
||||
"model": labels,
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
extra_strings = strings
|
||||
|
||||
|
||||
def list_model_globs(context: ServerContext, globs: List[str]) -> List[str]:
|
||||
models = []
|
||||
for pattern in globs:
|
||||
|
|
|
@ -106,3 +106,19 @@ def run_gc(devices: Optional[List[DeviceParams]] = None):
|
|||
|
||||
def sanitize_name(name):
|
||||
return "".join(x for x in name if (x.isalnum() or x in SAFE_CHARS))
|
||||
|
||||
|
||||
def merge(a, b, path=None):
|
||||
"merges b into a"
|
||||
if path is None: path = []
|
||||
for key in b:
|
||||
if key in a:
|
||||
if isinstance(a[key], dict) and isinstance(b[key], dict):
|
||||
merge(a[key], b[key], path + [str(key)])
|
||||
elif a[key] == b[key]:
|
||||
pass # same leaf value
|
||||
else:
|
||||
raise Exception("Conflict at %s" % '.'.join(path + [str(key)]))
|
||||
else:
|
||||
a[key] = b[key]
|
||||
return a
|
||||
|
|
|
@ -21,6 +21,8 @@ $defs:
|
|||
format:
|
||||
type: string
|
||||
enum: [concept, embeddings]
|
||||
label:
|
||||
type: string
|
||||
token:
|
||||
type: string
|
||||
|
||||
|
@ -33,6 +35,8 @@ $defs:
|
|||
enum: [onnx, pth, ckpt, safetensors]
|
||||
half:
|
||||
type: boolean
|
||||
label:
|
||||
type: string
|
||||
name:
|
||||
type: string
|
||||
opset:
|
||||
|
@ -105,3 +109,6 @@ properties:
|
|||
oneOf:
|
||||
- $ref: "#/$defs/legacy_tuple"
|
||||
- $ref: "#/$defs/source_model"
|
||||
strings:
|
||||
type: object
|
||||
# /\w{2}/: translation: {}
|
|
@ -220,6 +220,11 @@ export interface ApiClient {
|
|||
*/
|
||||
schedulers(): Promise<Array<string>>;
|
||||
|
||||
/**
|
||||
* Load extra strings from the server.
|
||||
*/
|
||||
strings(): Promise<Record<string, unknown>>;
|
||||
|
||||
/**
|
||||
* Start a txt2img pipeline.
|
||||
*/
|
||||
|
@ -389,6 +394,11 @@ export function makeClient(root: string, f = fetch): ApiClient {
|
|||
const res = await f(path);
|
||||
return await res.json() as Array<string>;
|
||||
},
|
||||
async strings(): Promise<Record<string, unknown>> {
|
||||
const path = makeApiUrl(root, 'settings', 'strings');
|
||||
const res = await f(path);
|
||||
return await res.json() as Record<string, unknown>;
|
||||
},
|
||||
async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams): Promise<ImageResponse> {
|
||||
const url = makeImageURL(root, 'img2img', params);
|
||||
appendModelToURL(url, model);
|
||||
|
|
|
@ -64,7 +64,10 @@ export async function main() {
|
|||
returnEmptyString: false,
|
||||
});
|
||||
|
||||
i18n.addResourceBundle(i18n.resolvedLanguage, 'model', params.model.keys);
|
||||
const strings = await client.strings();
|
||||
for (const [lang, data] of Object.entries(strings)) {
|
||||
i18n.addResourceBundle(lang, 'translation', data, true);
|
||||
}
|
||||
|
||||
// prep zustand with a slice for each tab, using local storage
|
||||
const {
|
||||
|
|
Loading…
Reference in New Issue