
add tokens to network response

Sean Sube 2023-11-12 15:36:51 -06:00
parent d79af78ef0
commit a4bf4ac651
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 17 additions and 8 deletions

View File

@@ -1,18 +1,21 @@
-from typing import Literal
+from typing import List, Literal
 
 NetworkType = Literal["inversion", "lora"]
 
 
 class NetworkModel:
     name: str
+    tokens: List[str]
     type: NetworkType
 
-    def __init__(self, name: str, type: NetworkType) -> None:
+    def __init__(self, name: str, type: NetworkType, tokens = None) -> None:
         self.name = name
+        self.tokens = tokens or []
         self.type = type
 
     def tojson(self):
         return {
             "name": self.name,
+            "tokens": self.tokens,
             "type": self.type,
         }
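
A minimal sketch of the updated serializer in isolation: a trimmed copy of the class from the hunk above, with hypothetical model names, showing that omitted tokens default to an empty list and that tokens now appear in the JSON payload.

from typing import List, Literal

NetworkType = Literal["inversion", "lora"]

# trimmed copy of the NetworkModel class from the hunk above, just enough to run
class NetworkModel:
    def __init__(self, name: str, type: NetworkType, tokens = None) -> None:
        self.name = name
        self.tokens = tokens or []  # omitted tokens become an empty list
        self.type = type

    def tojson(self):
        return {"name": self.name, "tokens": self.tokens, "type": self.type}

# hypothetical model names, for illustration only
print(NetworkModel("example-inversion", "inversion", tokens=["token-a"]).tojson())
# {'name': 'example-inversion', 'tokens': ['token-a'], 'type': 'inversion'}
print(NetworkModel("example-lora", "lora").tojson())
# {'name': 'example-lora', 'tokens': [], 'type': 'lora'}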

View File

@@ -96,6 +96,7 @@ wildcard_data: Dict[str, List[str]] = defaultdict(list)
 # Loaded from extra_models
 extra_hashes: Dict[str, str] = {}
 extra_strings: Dict[str, Any] = {}
+extra_tokens: Dict[str, List[str]] = {}
 
 
 def get_config_params():
@@ -160,6 +161,7 @@ def load_extras(server: ServerContext):
     """
     global extra_hashes
     global extra_strings
+    global extra_tokens
 
     labels = {}
     strings = {}
@@ -210,6 +212,10 @@ def load_extras(server: ServerContext):
                 else:
                     labels[model_name] = model["label"]
 
+            if "tokens" in model:
+                logger.debug("collecting tokens for model %s from %s", model_name, file)
+                extra_tokens[model_name] = model["tokens"]
+
             if "inversions" in model:
                 for inversion in model["inversions"]:
                     if "label" in inversion:
@@ -353,7 +359,7 @@ def load_models(server: ServerContext) -> None:
     )
     logger.debug("loaded Textual Inversion models from disk: %s", inversion_models)
     network_models.extend(
-        [NetworkModel(model, "inversion") for model in inversion_models]
+        [NetworkModel(model, "inversion", tokens=extra_tokens.get(model, [])) for model in inversion_models]
     )
 
     lora_models = list_model_globs(
@@ -364,7 +370,7 @@ def load_models(server: ServerContext) -> None:
         base_path=path.join(server.model_path, "lora"),
     )
     logger.debug("loaded LoRA models from disk: %s", lora_models)
-    network_models.extend([NetworkModel(model, "lora") for model in lora_models])
+    network_models.extend([NetworkModel(model, "lora", tokens=extra_tokens.get(model, [])) for model in lora_models])
 
 
 def load_params(server: ServerContext) -> None:
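
A condensed sketch of the data flow introduced above: load_extras collects any "tokens" list into the module-level extra_tokens cache, and load_models later attaches those tokens to the matching NetworkModel. The extras entry below is hypothetical; only the "tokens" key and the extra_tokens.get(..., []) fallback are taken from the hunks above.

from typing import Dict, List

# module-level cache, as declared in the first hunk of this file
extra_tokens: Dict[str, List[str]] = {}

# hypothetical extras entry; the "name" key is assumed for illustration
model = {"name": "example-lora", "tokens": ["trigger-word", "style-word"]}
model_name = model["name"]

# load_extras: collect tokens keyed by model name
if "tokens" in model:
    extra_tokens[model_name] = model["tokens"]

# load_models: pass the collected tokens along, defaulting to an empty list
# for models that have no extras entry
print(extra_tokens.get(model_name, []))      # ['trigger-word', 'style-word']
print(extra_tokens.get("other-model", []))   # []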

View File

@@ -1,4 +1,4 @@
-import { Maybe, doesExist, mustExist } from '@apextoaster/js-utils';
+import { Maybe, doesExist, mustDefault, mustExist } from '@apextoaster/js-utils';
 import { Chip, TextField } from '@mui/material';
 import { Stack } from '@mui/system';
 import { useQuery } from '@tanstack/react-query';
@@ -120,7 +120,7 @@ export function PromptInput(props: PromptInputProps) {
   </Stack>;
 }
 
-export const ANY_TOKEN = /<([^>])+>/g;
+export const ANY_TOKEN = /<([^>]+)>/g;
 
 export type TokenList = Array<[string, number]>;
 
@@ -167,10 +167,10 @@ export function getNetworkTokens(models: Maybe<ModelResponse>, networks: PromptN
     }
   }
 
-  for (const [name, weight] of networks.inversion) {
+  for (const [name, weight] of networks.lora) {
     const model = models.networks.find((it) => it.type === 'lora' && it.name === name);
     if (doesExist(model) && model.type === 'lora') {
-      for (const token of model.tokens) {
+      for (const token of mustDefault(model.tokens, [])) {
        tokens.push([token, weight]);
       }
     }
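
The tightened ANY_TOKEN pattern moves the repetition inside the capture group, so the whole token name between angle brackets is captured instead of only its last character. A quick Python sketch of the difference, with an illustrative prompt string:

import re

prompt = "a photo of <token-name>, best quality"

# old pattern: the group matches a single character and is repeated outside it,
# so only the last character survives in the capture
print(re.findall(r"<([^>])+>", prompt))  # ['e']

# new pattern: the repetition sits inside the group, capturing the full name
print(re.findall(r"<([^>]+)>", prompt))  # ['token-name']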