1
0
Fork 0

feat: allow users to add their own labels for models (#144)

This commit is contained in:
Sean Sube 2023-03-04 22:57:31 -06:00
parent 628812fb0b
commit 5d459ab17c
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
8 changed files with 123 additions and 10 deletions

View File

@ -353,20 +353,26 @@ def main() -> int:
logger.info("converting base models")
convert_models(ctx, args, base_models)
for file in args.extras:
extras = []
extras.extend(ctx.extra_models)
extras.extend(args.extras)
extras = list(set(extras))
extras.sort()
logger.debug("loading extra files: %s", extras)
with open("./schemas/extras.yaml", "r") as f:
extra_schema = safe_load(f.read())
for file in extras:
if file is not None and file != "":
logger.info("loading extra models from %s", file)
try:
with open(file, "r") as f:
data = safe_load(f.read())
with open("./schemas/extras.yaml", "r") as f:
schema = safe_load(f.read())
logger.debug("validating chain request: %s against %s", data, schema)
logger.debug("validating extras file %s", data)
try:
validate(data, schema)
validate(data, extra_schema)
logger.info("converting extra models")
convert_models(ctx, args, data)
except ValidationError as err:

View File

@ -88,6 +88,10 @@ def introspect(context: ServerContext, app: Flask):
}
def get_extra_strings(context: ServerContext):
return jsonify(get_extra_strings())
def list_mask_filters(context: ServerContext):
return jsonify(list(get_mask_filters().keys()))
@ -464,6 +468,7 @@ def register_api_routes(app: Flask, context: ServerContext, pool: DevicePoolExec
app.route("/api/settings/params")(wrap_route(list_params, context)),
app.route("/api/settings/platforms")(wrap_route(list_platforms, context)),
app.route("/api/settings/schedulers")(wrap_route(list_schedulers, context)),
app.route("/api/settings/strings")(wrap_route(get_extra_strings, context)),
app.route("/api/img2img", methods=["POST"])(
wrap_route(img2img, context, pool=pool)
),

View File

@ -25,6 +25,7 @@ class ServerContext:
cache_path: Optional[str] = None,
show_progress: bool = True,
optimizations: Optional[List[str]] = None,
extra_models: Optional[List[str]] = None,
) -> None:
self.bundle_path = bundle_path
self.model_path = model_path
@ -40,6 +41,7 @@ class ServerContext:
self.cache_path = cache_path or path.join(model_path, ".cache")
self.show_progress = show_progress
self.optimizations = optimizations or []
self.extra_models = extra_models or []
@classmethod
def from_environ(cls):
@ -63,4 +65,5 @@ class ServerContext:
cache=ModelCache(limit=cache_limit),
show_progress=get_boolean(environ, "ONNX_WEB_SHOW_PROGRESS", True),
optimizations=environ.get("ONNX_WEB_OPTIMIZATIONS", "").split(","),
extra_models=environ.get("ONNX_WEB_EXTRA_MODELS", "").split(","),
)

View File

@ -2,11 +2,14 @@ from functools import cmp_to_key
from glob import glob
from logging import getLogger
from os import path
from typing import Dict, List, Union
from typing import Any, Dict, List, Union
from jsonschema import ValidationError, validate
import torch
import yaml
from yaml import safe_load
from ..utils import merge
from ..image import ( # mask filters; noise sources
mask_filter_gaussian_multiply,
mask_filter_gaussian_screen,
@ -58,6 +61,9 @@ diffusion_models: List[str] = []
inversion_models: List[str] = []
upscaling_models: List[str] = []
# Loaded from extra_models
extra_strings: Dict[str, Any] = {}
def get_config_params():
return config_params
@ -83,6 +89,10 @@ def get_upscaling_models():
return upscaling_models
def get_extra_strings():
return extra_strings
def get_mask_filters():
return mask_filters
@ -101,6 +111,59 @@ def get_model_name(model: str) -> str:
return file
def load_extras(context: ServerContext):
"""
Load the extras file(s) and collect the relevant parts for the server: labels and strings
"""
global extra_strings
labels = {}
strings = {}
with open("./schemas/extras.yaml", "r") as f:
extra_schema = safe_load(f.read())
for file in context.extra_models:
if file is not None and file != "":
logger.info("loading extra models from %s", file)
try:
with open(file, "r") as f:
data = safe_load(f.read())
logger.debug("validating extras file %s", data)
try:
validate(data, extra_schema)
except ValidationError as err:
logger.error("invalid data in extras file: %s", err)
continue
if "strings" in data:
logger.debug("collecting strings from %s", file)
merge(strings, data["strings"])
for model_type in ["diffusion", "correction", "upscaling"]:
if model_type in data:
for model in data[model_type]:
if "label" in model:
model_name = model["name"]
logger.debug("collecting label for model %s from %s", model_name, file)
labels[model_name] = model["label"]
except Exception as err:
logger.error("error loading extras file: %s", err)
logger.debug("adding labels to strings: %s", labels)
merge(strings, {
"en": {
"translation": {
"model": labels,
}
}
})
extra_strings = strings
def list_model_globs(context: ServerContext, globs: List[str]) -> List[str]:
models = []
for pattern in globs:

View File

@ -106,3 +106,19 @@ def run_gc(devices: Optional[List[DeviceParams]] = None):
def sanitize_name(name):
return "".join(x for x in name if (x.isalnum() or x in SAFE_CHARS))
def merge(a, b, path=None):
"merges b into a"
if path is None: path = []
for key in b:
if key in a:
if isinstance(a[key], dict) and isinstance(b[key], dict):
merge(a[key], b[key], path + [str(key)])
elif a[key] == b[key]:
pass # same leaf value
else:
raise Exception("Conflict at %s" % '.'.join(path + [str(key)]))
else:
a[key] = b[key]
return a

View File

@ -21,6 +21,8 @@ $defs:
format:
type: string
enum: [concept, embeddings]
label:
type: string
token:
type: string
@ -33,6 +35,8 @@ $defs:
enum: [onnx, pth, ckpt, safetensors]
half:
type: boolean
label:
type: string
name:
type: string
opset:
@ -104,4 +108,7 @@ properties:
items:
oneOf:
- $ref: "#/$defs/legacy_tuple"
- $ref: "#/$defs/source_model"
- $ref: "#/$defs/source_model"
strings:
type: object
# /\w{2}/: translation: {}

View File

@ -220,6 +220,11 @@ export interface ApiClient {
*/
schedulers(): Promise<Array<string>>;
/**
* Load extra strings from the server.
*/
strings(): Promise<Record<string, unknown>>;
/**
* Start a txt2img pipeline.
*/
@ -389,6 +394,11 @@ export function makeClient(root: string, f = fetch): ApiClient {
const res = await f(path);
return await res.json() as Array<string>;
},
async strings(): Promise<Record<string, unknown>> {
const path = makeApiUrl(root, 'settings', 'strings');
const res = await f(path);
return await res.json() as Record<string, unknown>;
},
async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams): Promise<ImageResponse> {
const url = makeImageURL(root, 'img2img', params);
appendModelToURL(url, model);

View File

@ -64,7 +64,10 @@ export async function main() {
returnEmptyString: false,
});
i18n.addResourceBundle(i18n.resolvedLanguage, 'model', params.model.keys);
const strings = await client.strings();
for (const [lang, data] of Object.entries(strings)) {
i18n.addResourceBundle(lang, 'translation', data, true);
}
// prep zustand with a slice for each tab, using local storage
const {