diff --git a/api/onnx_web/models/meta.py b/api/onnx_web/models/meta.py index dcd43c25..453e24b2 100644 --- a/api/onnx_web/models/meta.py +++ b/api/onnx_web/models/meta.py @@ -1,18 +1,21 @@ -from typing import Literal +from typing import List, Literal NetworkType = Literal["inversion", "lora"] class NetworkModel: name: str + tokens: List[str] type: NetworkType - def __init__(self, name: str, type: NetworkType) -> None: + def __init__(self, name: str, type: NetworkType, tokens = None) -> None: self.name = name + self.tokens = tokens or [] self.type = type def tojson(self): return { "name": self.name, + "tokens": self.tokens, "type": self.type, } diff --git a/api/onnx_web/server/load.py b/api/onnx_web/server/load.py index a3759ff6..fa8f009b 100644 --- a/api/onnx_web/server/load.py +++ b/api/onnx_web/server/load.py @@ -96,6 +96,7 @@ wildcard_data: Dict[str, List[str]] = defaultdict(list) # Loaded from extra_models extra_hashes: Dict[str, str] = {} extra_strings: Dict[str, Any] = {} +extra_tokens: Dict[str, List[str]] = {} def get_config_params(): @@ -160,6 +161,7 @@ def load_extras(server: ServerContext): """ global extra_hashes global extra_strings + global extra_tokens labels = {} strings = {} @@ -210,6 +212,10 @@ def load_extras(server: ServerContext): else: labels[model_name] = model["label"] + if "tokens" in model: + logger.debug("collecting tokens for model %s from %s", model_name, file) + extra_tokens[model_name] = model["tokens"] + if "inversions" in model: for inversion in model["inversions"]: if "label" in inversion: @@ -353,7 +359,7 @@ def load_models(server: ServerContext) -> None: ) logger.debug("loaded Textual Inversion models from disk: %s", inversion_models) network_models.extend( - [NetworkModel(model, "inversion") for model in inversion_models] + [NetworkModel(model, "inversion", tokens=extra_tokens.get(model, [])) for model in inversion_models] ) lora_models = list_model_globs( @@ -364,7 +370,7 @@ def load_models(server: ServerContext) -> None: base_path=path.join(server.model_path, "lora"), ) logger.debug("loaded LoRA models from disk: %s", lora_models) - network_models.extend([NetworkModel(model, "lora") for model in lora_models]) + network_models.extend([NetworkModel(model, "lora", tokens=extra_tokens.get(model, [])) for model in lora_models]) def load_params(server: ServerContext) -> None: diff --git a/gui/src/components/input/PromptInput.tsx b/gui/src/components/input/PromptInput.tsx index 2d8c8413..fa6a7aba 100644 --- a/gui/src/components/input/PromptInput.tsx +++ b/gui/src/components/input/PromptInput.tsx @@ -1,4 +1,4 @@ -import { Maybe, doesExist, mustExist } from '@apextoaster/js-utils'; +import { Maybe, doesExist, mustDefault, mustExist } from '@apextoaster/js-utils'; import { Chip, TextField } from '@mui/material'; import { Stack } from '@mui/system'; import { useQuery } from '@tanstack/react-query'; @@ -120,7 +120,7 @@ export function PromptInput(props: PromptInputProps) { ; } -export const ANY_TOKEN = /<([^>])+>/g; +export const ANY_TOKEN = /<([^>]+)>/g; export type TokenList = Array<[string, number]>; @@ -167,10 +167,10 @@ export function getNetworkTokens(models: Maybe, networks: PromptN } } - for (const [name, weight] of networks.inversion) { + for (const [name, weight] of networks.lora) { const model = models.networks.find((it) => it.type === 'lora' && it.name === name); if (doesExist(model) && model.type === 'lora') { - for (const token of model.tokens) { + for (const token of mustDefault(model.tokens, [])) { tokens.push([token, weight]); } }