feat: allow users to add their own labels for models (#144)
This commit is contained in:
parent
628812fb0b
commit
5d459ab17c
|
@ -353,20 +353,26 @@ def main() -> int:
|
||||||
logger.info("converting base models")
|
logger.info("converting base models")
|
||||||
convert_models(ctx, args, base_models)
|
convert_models(ctx, args, base_models)
|
||||||
|
|
||||||
for file in args.extras:
|
extras = []
|
||||||
|
extras.extend(ctx.extra_models)
|
||||||
|
extras.extend(args.extras)
|
||||||
|
extras = list(set(extras))
|
||||||
|
extras.sort()
|
||||||
|
logger.debug("loading extra files: %s", extras)
|
||||||
|
|
||||||
|
with open("./schemas/extras.yaml", "r") as f:
|
||||||
|
extra_schema = safe_load(f.read())
|
||||||
|
|
||||||
|
for file in extras:
|
||||||
if file is not None and file != "":
|
if file is not None and file != "":
|
||||||
logger.info("loading extra models from %s", file)
|
logger.info("loading extra models from %s", file)
|
||||||
try:
|
try:
|
||||||
with open(file, "r") as f:
|
with open(file, "r") as f:
|
||||||
data = safe_load(f.read())
|
data = safe_load(f.read())
|
||||||
|
|
||||||
with open("./schemas/extras.yaml", "r") as f:
|
logger.debug("validating extras file %s", data)
|
||||||
schema = safe_load(f.read())
|
|
||||||
|
|
||||||
logger.debug("validating chain request: %s against %s", data, schema)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
validate(data, schema)
|
validate(data, extra_schema)
|
||||||
logger.info("converting extra models")
|
logger.info("converting extra models")
|
||||||
convert_models(ctx, args, data)
|
convert_models(ctx, args, data)
|
||||||
except ValidationError as err:
|
except ValidationError as err:
|
||||||
|
|
|
@ -88,6 +88,10 @@ def introspect(context: ServerContext, app: Flask):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_extra_strings(context: ServerContext):
|
||||||
|
return jsonify(get_extra_strings())
|
||||||
|
|
||||||
|
|
||||||
def list_mask_filters(context: ServerContext):
|
def list_mask_filters(context: ServerContext):
|
||||||
return jsonify(list(get_mask_filters().keys()))
|
return jsonify(list(get_mask_filters().keys()))
|
||||||
|
|
||||||
|
@ -464,6 +468,7 @@ def register_api_routes(app: Flask, context: ServerContext, pool: DevicePoolExec
|
||||||
app.route("/api/settings/params")(wrap_route(list_params, context)),
|
app.route("/api/settings/params")(wrap_route(list_params, context)),
|
||||||
app.route("/api/settings/platforms")(wrap_route(list_platforms, context)),
|
app.route("/api/settings/platforms")(wrap_route(list_platforms, context)),
|
||||||
app.route("/api/settings/schedulers")(wrap_route(list_schedulers, context)),
|
app.route("/api/settings/schedulers")(wrap_route(list_schedulers, context)),
|
||||||
|
app.route("/api/settings/strings")(wrap_route(get_extra_strings, context)),
|
||||||
app.route("/api/img2img", methods=["POST"])(
|
app.route("/api/img2img", methods=["POST"])(
|
||||||
wrap_route(img2img, context, pool=pool)
|
wrap_route(img2img, context, pool=pool)
|
||||||
),
|
),
|
||||||
|
|
|
@ -25,6 +25,7 @@ class ServerContext:
|
||||||
cache_path: Optional[str] = None,
|
cache_path: Optional[str] = None,
|
||||||
show_progress: bool = True,
|
show_progress: bool = True,
|
||||||
optimizations: Optional[List[str]] = None,
|
optimizations: Optional[List[str]] = None,
|
||||||
|
extra_models: Optional[List[str]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.bundle_path = bundle_path
|
self.bundle_path = bundle_path
|
||||||
self.model_path = model_path
|
self.model_path = model_path
|
||||||
|
@ -40,6 +41,7 @@ class ServerContext:
|
||||||
self.cache_path = cache_path or path.join(model_path, ".cache")
|
self.cache_path = cache_path or path.join(model_path, ".cache")
|
||||||
self.show_progress = show_progress
|
self.show_progress = show_progress
|
||||||
self.optimizations = optimizations or []
|
self.optimizations = optimizations or []
|
||||||
|
self.extra_models = extra_models or []
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_environ(cls):
|
def from_environ(cls):
|
||||||
|
@ -63,4 +65,5 @@ class ServerContext:
|
||||||
cache=ModelCache(limit=cache_limit),
|
cache=ModelCache(limit=cache_limit),
|
||||||
show_progress=get_boolean(environ, "ONNX_WEB_SHOW_PROGRESS", True),
|
show_progress=get_boolean(environ, "ONNX_WEB_SHOW_PROGRESS", True),
|
||||||
optimizations=environ.get("ONNX_WEB_OPTIMIZATIONS", "").split(","),
|
optimizations=environ.get("ONNX_WEB_OPTIMIZATIONS", "").split(","),
|
||||||
|
extra_models=environ.get("ONNX_WEB_EXTRA_MODELS", "").split(","),
|
||||||
)
|
)
|
||||||
|
|
|
@ -2,11 +2,14 @@ from functools import cmp_to_key
|
||||||
from glob import glob
|
from glob import glob
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import path
|
from os import path
|
||||||
from typing import Dict, List, Union
|
from typing import Any, Dict, List, Union
|
||||||
|
from jsonschema import ValidationError, validate
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import yaml
|
import yaml
|
||||||
|
from yaml import safe_load
|
||||||
|
|
||||||
|
from ..utils import merge
|
||||||
from ..image import ( # mask filters; noise sources
|
from ..image import ( # mask filters; noise sources
|
||||||
mask_filter_gaussian_multiply,
|
mask_filter_gaussian_multiply,
|
||||||
mask_filter_gaussian_screen,
|
mask_filter_gaussian_screen,
|
||||||
|
@ -58,6 +61,9 @@ diffusion_models: List[str] = []
|
||||||
inversion_models: List[str] = []
|
inversion_models: List[str] = []
|
||||||
upscaling_models: List[str] = []
|
upscaling_models: List[str] = []
|
||||||
|
|
||||||
|
# Loaded from extra_models
|
||||||
|
extra_strings: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
|
||||||
def get_config_params():
|
def get_config_params():
|
||||||
return config_params
|
return config_params
|
||||||
|
@ -83,6 +89,10 @@ def get_upscaling_models():
|
||||||
return upscaling_models
|
return upscaling_models
|
||||||
|
|
||||||
|
|
||||||
|
def get_extra_strings():
|
||||||
|
return extra_strings
|
||||||
|
|
||||||
|
|
||||||
def get_mask_filters():
|
def get_mask_filters():
|
||||||
return mask_filters
|
return mask_filters
|
||||||
|
|
||||||
|
@ -101,6 +111,59 @@ def get_model_name(model: str) -> str:
|
||||||
return file
|
return file
|
||||||
|
|
||||||
|
|
||||||
|
def load_extras(context: ServerContext):
|
||||||
|
"""
|
||||||
|
Load the extras file(s) and collect the relevant parts for the server: labels and strings
|
||||||
|
"""
|
||||||
|
global extra_strings
|
||||||
|
|
||||||
|
labels = {}
|
||||||
|
strings = {}
|
||||||
|
|
||||||
|
with open("./schemas/extras.yaml", "r") as f:
|
||||||
|
extra_schema = safe_load(f.read())
|
||||||
|
|
||||||
|
for file in context.extra_models:
|
||||||
|
if file is not None and file != "":
|
||||||
|
logger.info("loading extra models from %s", file)
|
||||||
|
try:
|
||||||
|
with open(file, "r") as f:
|
||||||
|
data = safe_load(f.read())
|
||||||
|
|
||||||
|
logger.debug("validating extras file %s", data)
|
||||||
|
try:
|
||||||
|
validate(data, extra_schema)
|
||||||
|
except ValidationError as err:
|
||||||
|
logger.error("invalid data in extras file: %s", err)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if "strings" in data:
|
||||||
|
logger.debug("collecting strings from %s", file)
|
||||||
|
merge(strings, data["strings"])
|
||||||
|
|
||||||
|
for model_type in ["diffusion", "correction", "upscaling"]:
|
||||||
|
if model_type in data:
|
||||||
|
for model in data[model_type]:
|
||||||
|
if "label" in model:
|
||||||
|
model_name = model["name"]
|
||||||
|
logger.debug("collecting label for model %s from %s", model_name, file)
|
||||||
|
labels[model_name] = model["label"]
|
||||||
|
|
||||||
|
except Exception as err:
|
||||||
|
logger.error("error loading extras file: %s", err)
|
||||||
|
|
||||||
|
logger.debug("adding labels to strings: %s", labels)
|
||||||
|
merge(strings, {
|
||||||
|
"en": {
|
||||||
|
"translation": {
|
||||||
|
"model": labels,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
extra_strings = strings
|
||||||
|
|
||||||
|
|
||||||
def list_model_globs(context: ServerContext, globs: List[str]) -> List[str]:
|
def list_model_globs(context: ServerContext, globs: List[str]) -> List[str]:
|
||||||
models = []
|
models = []
|
||||||
for pattern in globs:
|
for pattern in globs:
|
||||||
|
|
|
@ -106,3 +106,19 @@ def run_gc(devices: Optional[List[DeviceParams]] = None):
|
||||||
|
|
||||||
def sanitize_name(name):
|
def sanitize_name(name):
|
||||||
return "".join(x for x in name if (x.isalnum() or x in SAFE_CHARS))
|
return "".join(x for x in name if (x.isalnum() or x in SAFE_CHARS))
|
||||||
|
|
||||||
|
|
||||||
|
def merge(a, b, path=None):
|
||||||
|
"merges b into a"
|
||||||
|
if path is None: path = []
|
||||||
|
for key in b:
|
||||||
|
if key in a:
|
||||||
|
if isinstance(a[key], dict) and isinstance(b[key], dict):
|
||||||
|
merge(a[key], b[key], path + [str(key)])
|
||||||
|
elif a[key] == b[key]:
|
||||||
|
pass # same leaf value
|
||||||
|
else:
|
||||||
|
raise Exception("Conflict at %s" % '.'.join(path + [str(key)]))
|
||||||
|
else:
|
||||||
|
a[key] = b[key]
|
||||||
|
return a
|
||||||
|
|
|
@ -21,6 +21,8 @@ $defs:
|
||||||
format:
|
format:
|
||||||
type: string
|
type: string
|
||||||
enum: [concept, embeddings]
|
enum: [concept, embeddings]
|
||||||
|
label:
|
||||||
|
type: string
|
||||||
token:
|
token:
|
||||||
type: string
|
type: string
|
||||||
|
|
||||||
|
@ -33,6 +35,8 @@ $defs:
|
||||||
enum: [onnx, pth, ckpt, safetensors]
|
enum: [onnx, pth, ckpt, safetensors]
|
||||||
half:
|
half:
|
||||||
type: boolean
|
type: boolean
|
||||||
|
label:
|
||||||
|
type: string
|
||||||
name:
|
name:
|
||||||
type: string
|
type: string
|
||||||
opset:
|
opset:
|
||||||
|
@ -104,4 +108,7 @@ properties:
|
||||||
items:
|
items:
|
||||||
oneOf:
|
oneOf:
|
||||||
- $ref: "#/$defs/legacy_tuple"
|
- $ref: "#/$defs/legacy_tuple"
|
||||||
- $ref: "#/$defs/source_model"
|
- $ref: "#/$defs/source_model"
|
||||||
|
strings:
|
||||||
|
type: object
|
||||||
|
# /\w{2}/: translation: {}
|
|
@ -220,6 +220,11 @@ export interface ApiClient {
|
||||||
*/
|
*/
|
||||||
schedulers(): Promise<Array<string>>;
|
schedulers(): Promise<Array<string>>;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Load extra strings from the server.
|
||||||
|
*/
|
||||||
|
strings(): Promise<Record<string, unknown>>;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Start a txt2img pipeline.
|
* Start a txt2img pipeline.
|
||||||
*/
|
*/
|
||||||
|
@ -389,6 +394,11 @@ export function makeClient(root: string, f = fetch): ApiClient {
|
||||||
const res = await f(path);
|
const res = await f(path);
|
||||||
return await res.json() as Array<string>;
|
return await res.json() as Array<string>;
|
||||||
},
|
},
|
||||||
|
async strings(): Promise<Record<string, unknown>> {
|
||||||
|
const path = makeApiUrl(root, 'settings', 'strings');
|
||||||
|
const res = await f(path);
|
||||||
|
return await res.json() as Record<string, unknown>;
|
||||||
|
},
|
||||||
async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams): Promise<ImageResponse> {
|
async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams): Promise<ImageResponse> {
|
||||||
const url = makeImageURL(root, 'img2img', params);
|
const url = makeImageURL(root, 'img2img', params);
|
||||||
appendModelToURL(url, model);
|
appendModelToURL(url, model);
|
||||||
|
|
|
@ -64,7 +64,10 @@ export async function main() {
|
||||||
returnEmptyString: false,
|
returnEmptyString: false,
|
||||||
});
|
});
|
||||||
|
|
||||||
i18n.addResourceBundle(i18n.resolvedLanguage, 'model', params.model.keys);
|
const strings = await client.strings();
|
||||||
|
for (const [lang, data] of Object.entries(strings)) {
|
||||||
|
i18n.addResourceBundle(lang, 'translation', data, true);
|
||||||
|
}
|
||||||
|
|
||||||
// prep zustand with a slice for each tab, using local storage
|
// prep zustand with a slice for each tab, using local storage
|
||||||
const {
|
const {
|
||||||
|
|
Loading…
Reference in New Issue