
feat: allow users to add their own labels for models (#144)

commit 5d459ab17c (parent 628812fb0b)
Author: Sean Sube, 2023-03-04 22:57:31 -06:00
Signed by: ssube (GPG Key ID: 3EED7B957D362AF1)
8 changed files with 123 additions and 10 deletions

View File

@@ -353,20 +353,26 @@ def main() -> int:
     logger.info("converting base models")
     convert_models(ctx, args, base_models)
 
-    for file in args.extras:
+    extras = []
+    extras.extend(ctx.extra_models)
+    extras.extend(args.extras)
+    extras = list(set(extras))
+    extras.sort()
+    logger.debug("loading extra files: %s", extras)
+
+    with open("./schemas/extras.yaml", "r") as f:
+        extra_schema = safe_load(f.read())
+
+    for file in extras:
         if file is not None and file != "":
             logger.info("loading extra models from %s", file)
             try:
                 with open(file, "r") as f:
                     data = safe_load(f.read())
 
-                with open("./schemas/extras.yaml", "r") as f:
-                    schema = safe_load(f.read())
-
-                logger.debug("validating chain request: %s against %s", data, schema)
-
+                logger.debug("validating extras file %s", data)
                 try:
-                    validate(data, schema)
+                    validate(data, extra_schema)
                     logger.info("converting extra models")
                     convert_models(ctx, args, data)
                 except ValidationError as err:
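
For context, the new list handling merges extras from the server context with those passed on the command line, dedupes, and sorts so the load order is deterministic. A minimal sketch of that behavior (the file names here are hypothetical):

    ctx_extras = ["./extras.yaml"]                  # stands in for ctx.extra_models
    cli_extras = ["./extras.yaml", "./mine.yaml"]   # stands in for args.extras

    extras = []
    extras.extend(ctx_extras)
    extras.extend(cli_extras)
    extras = list(set(extras))  # dedupe; set iteration order is arbitrary
    extras.sort()               # restore a stable, predictable order
    print(extras)               # ['./extras.yaml', './mine.yaml']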

View File

@@ -88,6 +88,10 @@ def introspect(context: ServerContext, app: Flask):
     }
 
 
+def get_extra_strings(context: ServerContext):
+    return jsonify(get_extra_strings())
+
+
 def list_mask_filters(context: ServerContext):
     return jsonify(list(get_mask_filters().keys()))
 
@@ -464,6 +468,7 @@ def register_api_routes(app: Flask, context: ServerContext, pool: DevicePoolExecutor):
         app.route("/api/settings/params")(wrap_route(list_params, context)),
         app.route("/api/settings/platforms")(wrap_route(list_platforms, context)),
         app.route("/api/settings/schedulers")(wrap_route(list_schedulers, context)),
+        app.route("/api/settings/strings")(wrap_route(get_extra_strings, context)),
         app.route("/api/img2img", methods=["POST"])(
             wrap_route(img2img, context, pool=pool)
         ),
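
Once registered, the new endpoint can be exercised directly. A minimal sketch using only the standard library; the host and port are assumptions (Flask's defaults), not part of this commit:

    import json
    from urllib.request import urlopen

    # assumed local development address; adjust for your deployment
    with urlopen("http://127.0.0.1:5000/api/settings/strings") as res:
        strings = json.loads(res.read())
    # expected shape, per the loading changes below: {"en": {"translation": {...}}}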

View File

@@ -25,6 +25,7 @@ class ServerContext:
         cache_path: Optional[str] = None,
         show_progress: bool = True,
         optimizations: Optional[List[str]] = None,
+        extra_models: Optional[List[str]] = None,
     ) -> None:
         self.bundle_path = bundle_path
         self.model_path = model_path
@@ -40,6 +41,7 @@ class ServerContext:
         self.cache_path = cache_path or path.join(model_path, ".cache")
         self.show_progress = show_progress
         self.optimizations = optimizations or []
+        self.extra_models = extra_models or []
 
     @classmethod
     def from_environ(cls):
@@ -63,4 +65,5 @@ class ServerContext:
             cache=ModelCache(limit=cache_limit),
             show_progress=get_boolean(environ, "ONNX_WEB_SHOW_PROGRESS", True),
             optimizations=environ.get("ONNX_WEB_OPTIMIZATIONS", "").split(","),
+            extra_models=environ.get("ONNX_WEB_EXTRA_MODELS", "").split(","),
         )
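
One subtlety in from_environ: str.split(",") on an empty string returns [""], not [], so an unset ONNX_WEB_EXTRA_MODELS produces a single empty entry. That is why the loading code skips empty file names. For example (the paths are hypothetical):

    from os import environ

    environ["ONNX_WEB_EXTRA_MODELS"] = "./extras.yaml,./labels.yaml"
    extra_models = environ.get("ONNX_WEB_EXTRA_MODELS", "").split(",")
    # -> ['./extras.yaml', './labels.yaml']

    # with the variable unset, the same expression yields ['']
    print("".split(","))  # ['']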

View File

@@ -2,11 +2,14 @@ from functools import cmp_to_key
 from glob import glob
 from logging import getLogger
 from os import path
-from typing import Dict, List, Union
+from typing import Any, Dict, List, Union
 
+from jsonschema import ValidationError, validate
 import torch
 import yaml
+from yaml import safe_load
 
+from ..utils import merge
 from ..image import (  # mask filters; noise sources
     mask_filter_gaussian_multiply,
     mask_filter_gaussian_screen,
@@ -58,6 +61,9 @@ diffusion_models: List[str] = []
 inversion_models: List[str] = []
 upscaling_models: List[str] = []
 
+# Loaded from extra_models
+extra_strings: Dict[str, Any] = {}
+
 
 def get_config_params():
     return config_params
@@ -83,6 +89,10 @@ def get_upscaling_models():
     return upscaling_models
 
 
+def get_extra_strings():
+    return extra_strings
+
+
 def get_mask_filters():
     return mask_filters
@@ -101,6 +111,59 @@ def get_model_name(model: str) -> str:
     return file
 
 
+def load_extras(context: ServerContext):
+    """
+    Load the extras file(s) and collect the relevant parts for the server: labels and strings
+    """
+    global extra_strings
+
+    labels = {}
+    strings = {}
+
+    with open("./schemas/extras.yaml", "r") as f:
+        extra_schema = safe_load(f.read())
+
+    for file in context.extra_models:
+        if file is not None and file != "":
+            logger.info("loading extra models from %s", file)
+            try:
+                with open(file, "r") as f:
+                    data = safe_load(f.read())
+
+                logger.debug("validating extras file %s", data)
+                try:
+                    validate(data, extra_schema)
+                except ValidationError as err:
+                    logger.error("invalid data in extras file: %s", err)
+                    continue
+
+                if "strings" in data:
+                    logger.debug("collecting strings from %s", file)
+                    merge(strings, data["strings"])
+
+                for model_type in ["diffusion", "correction", "upscaling"]:
+                    if model_type in data:
+                        for model in data[model_type]:
+                            if "label" in model:
+                                model_name = model["name"]
+                                logger.debug("collecting label for model %s from %s", model_name, file)
+                                labels[model_name] = model["label"]
+            except Exception as err:
+                logger.error("error loading extras file: %s", err)
+
+    logger.debug("adding labels to strings: %s", labels)
+    merge(strings, {
+        "en": {
+            "translation": {
+                "model": labels,
+            }
+        }
+    })
+
+    extra_strings = strings
+
+
 def list_model_globs(context: ServerContext, globs: List[str]) -> List[str]:
     models = []
     for pattern in globs:
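
Following the merge call at the end of load_extras, the collected strings end up shaped like an i18next resource tree, with model labels under the en translation namespace. The model name and label below are hypothetical:

    extra_strings = {
        "en": {
            "translation": {
                "model": {
                    # one entry per labeled model found in the extras files
                    "diffusion-example": "My Example Model",
                },
            },
        },
    }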

View File

@@ -106,3 +106,19 @@ def run_gc(devices: Optional[List[DeviceParams]] = None):
 
 def sanitize_name(name):
     return "".join(x for x in name if (x.isalnum() or x in SAFE_CHARS))
+
+
+def merge(a, b, path=None):
+    "merges b into a"
+    if path is None: path = []
+    for key in b:
+        if key in a:
+            if isinstance(a[key], dict) and isinstance(b[key], dict):
+                merge(a[key], b[key], path + [str(key)])
+            elif a[key] == b[key]:
+                pass  # same leaf value
+            else:
+                raise Exception("Conflict at %s" % '.'.join(path + [str(key)]))
+        else:
+            a[key] = b[key]
+    return a
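
The merge helper performs a recursive dictionary merge, tolerating identical leaf values and raising on conflicts. For example:

    a = {"en": {"translation": {"model": {"a": "A"}}}}
    b = {"en": {"translation": {"model": {"b": "B"}}}}
    merge(a, b)
    # a == {"en": {"translation": {"model": {"a": "A", "b": "B"}}}}

    merge({"k": 1}, {"k": 1})  # identical leaf value: no-op
    # merge({"k": 1}, {"k": 2}) raises Exception("Conflict at k")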

View File

@@ -21,6 +21,8 @@ $defs:
     format:
       type: string
       enum: [concept, embeddings]
+    label:
+      type: string
     token:
       type: string
 
@@ -33,6 +35,8 @@ $defs:
       enum: [onnx, pth, ckpt, safetensors]
     half:
       type: boolean
+    label:
+      type: string
     name:
       type: string
     opset:
@@ -104,4 +108,7 @@ properties:
     items:
       oneOf:
         - $ref: "#/$defs/legacy_tuple"
         - $ref: "#/$defs/source_model"
+  strings:
+    type: object
+    # /\w{2}/: translation: {}
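
An extras document using the new fields might look like the following, shown as the Python value safe_load would return. The names, URL, and strings are hypothetical, and the parts of the schema not visible in this hunk may require additional keys:

    data = {
        "diffusion": [
            {
                "name": "diffusion-example",
                "source": "https://example.com/model.safetensors",  # hypothetical
                "label": "Example Diffusion",  # new optional field
            },
        ],
        "strings": {  # new top-level object, keyed by language code
            "en": {"translation": {"example-key": "Example"}},
        },
    }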

View File

@@ -220,6 +220,11 @@ export interface ApiClient {
    */
   schedulers(): Promise<Array<string>>;
 
+  /**
+   * Load extra strings from the server.
+   */
+  strings(): Promise<Record<string, unknown>>;
+
   /**
    * Start a txt2img pipeline.
    */
@@ -389,6 +394,11 @@ export function makeClient(root: string, f = fetch): ApiClient {
       const res = await f(path);
       return await res.json() as Array<string>;
     },
+    async strings(): Promise<Record<string, unknown>> {
+      const path = makeApiUrl(root, 'settings', 'strings');
+      const res = await f(path);
+      return await res.json() as Record<string, unknown>;
+    },
     async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams): Promise<ImageResponse> {
       const url = makeImageURL(root, 'img2img', params);
       appendModelToURL(url, model);

View File

@@ -64,7 +64,10 @@ export async function main() {
     returnEmptyString: false,
   });
 
-  i18n.addResourceBundle(i18n.resolvedLanguage, 'model', params.model.keys);
+  const strings = await client.strings();
+  for (const [lang, data] of Object.entries(strings)) {
+    i18n.addResourceBundle(lang, 'translation', data, true);
+  }
 
   // prep zustand with a slice for each tab, using local storage
   const {
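
Note that the true passed as the fourth argument is i18next's deep flag, so the server-provided strings are merged into any existing translation bundle for that language rather than replacing it.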