add tokens to network response
This commit is contained in:
parent
d79af78ef0
commit
a4bf4ac651
|
@ -1,18 +1,21 @@
|
||||||
from typing import Literal
|
from typing import List, Literal
|
||||||
|
|
||||||
NetworkType = Literal["inversion", "lora"]
|
NetworkType = Literal["inversion", "lora"]
|
||||||
|
|
||||||
|
|
||||||
class NetworkModel:
|
class NetworkModel:
|
||||||
name: str
|
name: str
|
||||||
|
tokens: List[str]
|
||||||
type: NetworkType
|
type: NetworkType
|
||||||
|
|
||||||
def __init__(self, name: str, type: NetworkType) -> None:
|
def __init__(self, name: str, type: NetworkType, tokens = None) -> None:
|
||||||
self.name = name
|
self.name = name
|
||||||
|
self.tokens = tokens or []
|
||||||
self.type = type
|
self.type = type
|
||||||
|
|
||||||
def tojson(self):
|
def tojson(self):
|
||||||
return {
|
return {
|
||||||
"name": self.name,
|
"name": self.name,
|
||||||
|
"tokens": self.tokens,
|
||||||
"type": self.type,
|
"type": self.type,
|
||||||
}
|
}
|
||||||
|
|
|
@ -96,6 +96,7 @@ wildcard_data: Dict[str, List[str]] = defaultdict(list)
|
||||||
# Loaded from extra_models
|
# Loaded from extra_models
|
||||||
extra_hashes: Dict[str, str] = {}
|
extra_hashes: Dict[str, str] = {}
|
||||||
extra_strings: Dict[str, Any] = {}
|
extra_strings: Dict[str, Any] = {}
|
||||||
|
extra_tokens: Dict[str, List[str]] = {}
|
||||||
|
|
||||||
|
|
||||||
def get_config_params():
|
def get_config_params():
|
||||||
|
@ -160,6 +161,7 @@ def load_extras(server: ServerContext):
|
||||||
"""
|
"""
|
||||||
global extra_hashes
|
global extra_hashes
|
||||||
global extra_strings
|
global extra_strings
|
||||||
|
global extra_tokens
|
||||||
|
|
||||||
labels = {}
|
labels = {}
|
||||||
strings = {}
|
strings = {}
|
||||||
|
@ -210,6 +212,10 @@ def load_extras(server: ServerContext):
|
||||||
else:
|
else:
|
||||||
labels[model_name] = model["label"]
|
labels[model_name] = model["label"]
|
||||||
|
|
||||||
|
if "tokens" in model:
|
||||||
|
logger.debug("collecting tokens for model %s from %s", model_name, file)
|
||||||
|
extra_tokens[model_name] = model["tokens"]
|
||||||
|
|
||||||
if "inversions" in model:
|
if "inversions" in model:
|
||||||
for inversion in model["inversions"]:
|
for inversion in model["inversions"]:
|
||||||
if "label" in inversion:
|
if "label" in inversion:
|
||||||
|
@ -353,7 +359,7 @@ def load_models(server: ServerContext) -> None:
|
||||||
)
|
)
|
||||||
logger.debug("loaded Textual Inversion models from disk: %s", inversion_models)
|
logger.debug("loaded Textual Inversion models from disk: %s", inversion_models)
|
||||||
network_models.extend(
|
network_models.extend(
|
||||||
[NetworkModel(model, "inversion") for model in inversion_models]
|
[NetworkModel(model, "inversion", tokens=extra_tokens.get(model, [])) for model in inversion_models]
|
||||||
)
|
)
|
||||||
|
|
||||||
lora_models = list_model_globs(
|
lora_models = list_model_globs(
|
||||||
|
@ -364,7 +370,7 @@ def load_models(server: ServerContext) -> None:
|
||||||
base_path=path.join(server.model_path, "lora"),
|
base_path=path.join(server.model_path, "lora"),
|
||||||
)
|
)
|
||||||
logger.debug("loaded LoRA models from disk: %s", lora_models)
|
logger.debug("loaded LoRA models from disk: %s", lora_models)
|
||||||
network_models.extend([NetworkModel(model, "lora") for model in lora_models])
|
network_models.extend([NetworkModel(model, "lora", tokens=extra_tokens.get(model, [])) for model in lora_models])
|
||||||
|
|
||||||
|
|
||||||
def load_params(server: ServerContext) -> None:
|
def load_params(server: ServerContext) -> None:
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
import { Maybe, doesExist, mustExist } from '@apextoaster/js-utils';
|
import { Maybe, doesExist, mustDefault, mustExist } from '@apextoaster/js-utils';
|
||||||
import { Chip, TextField } from '@mui/material';
|
import { Chip, TextField } from '@mui/material';
|
||||||
import { Stack } from '@mui/system';
|
import { Stack } from '@mui/system';
|
||||||
import { useQuery } from '@tanstack/react-query';
|
import { useQuery } from '@tanstack/react-query';
|
||||||
|
@ -120,7 +120,7 @@ export function PromptInput(props: PromptInputProps) {
|
||||||
</Stack>;
|
</Stack>;
|
||||||
}
|
}
|
||||||
|
|
||||||
export const ANY_TOKEN = /<([^>])+>/g;
|
export const ANY_TOKEN = /<([^>]+)>/g;
|
||||||
|
|
||||||
export type TokenList = Array<[string, number]>;
|
export type TokenList = Array<[string, number]>;
|
||||||
|
|
||||||
|
@ -167,10 +167,10 @@ export function getNetworkTokens(models: Maybe<ModelResponse>, networks: PromptN
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (const [name, weight] of networks.inversion) {
|
for (const [name, weight] of networks.lora) {
|
||||||
const model = models.networks.find((it) => it.type === 'lora' && it.name === name);
|
const model = models.networks.find((it) => it.type === 'lora' && it.name === name);
|
||||||
if (doesExist(model) && model.type === 'lora') {
|
if (doesExist(model) && model.type === 'lora') {
|
||||||
for (const token of model.tokens) {
|
for (const token of mustDefault(model.tokens, [])) {
|
||||||
tokens.push([token, weight]);
|
tokens.push([token, weight]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue