add tokens to network response
Parent: d79af78ef0
Commit: a4bf4ac651
Server (Python):

@@ -1,18 +1,21 @@
-from typing import Literal
+from typing import List, Literal

 NetworkType = Literal["inversion", "lora"]


 class NetworkModel:
     name: str
+    tokens: List[str]
     type: NetworkType

-    def __init__(self, name: str, type: NetworkType) -> None:
+    def __init__(self, name: str, type: NetworkType, tokens = None) -> None:
         self.name = name
+        self.tokens = tokens or []
         self.type = type

     def tojson(self):
         return {
             "name": self.name,
+            "tokens": self.tokens,
             "type": self.type,
         }
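This first hunk adds an optional token list to each network model and includes it in the serialized response. A minimal sketch of the resulting payload, assuming the NetworkModel class above is in scope; the model name and token strings are invented for illustration:

# Invented example values; NetworkModel is the class defined in the hunk above.
model = NetworkModel("cyberpunk-style", "lora", tokens=["cyberpunk", "neon"])
print(model.tojson())
# {'name': 'cyberpunk-style', 'tokens': ['cyberpunk', 'neon'], 'type': 'lora'}

# Models constructed without tokens still serialize cleanly:
print(NetworkModel("some-inversion", "inversion").tojson())
# {'name': 'some-inversion', 'tokens': [], 'type': 'inversion'}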
@@ -96,6 +96,7 @@ wildcard_data: Dict[str, List[str]] = defaultdict(list)
 # Loaded from extra_models
 extra_hashes: Dict[str, str] = {}
 extra_strings: Dict[str, Any] = {}
+extra_tokens: Dict[str, List[str]] = {}


 def get_config_params():
@@ -160,6 +161,7 @@ def load_extras(server: ServerContext):
     """
     global extra_hashes
     global extra_strings
+    global extra_tokens

     labels = {}
     strings = {}
@@ -210,6 +212,10 @@ def load_extras(server: ServerContext):
             else:
                 labels[model_name] = model["label"]

+            if "tokens" in model:
+                logger.debug("collecting tokens for model %s from %s", model_name, file)
+                extra_tokens[model_name] = model["tokens"]
+
             if "inversions" in model:
                 for inversion in model["inversions"]:
                     if "label" in inversion:
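The tokens read here come from the extra_models file, alongside the existing "label" and "inversions" keys. A hedged sketch of the collection step with an invented entry shaped like what this hunk expects:

from typing import Dict, List

extra_tokens: Dict[str, List[str]] = {}

# Invented extras entry; only the "tokens" key matters for this hunk.
model_name = "cyberpunk-style"
model = {
    "label": "Cyberpunk Style",
    "tokens": ["cyberpunk", "neon"],
}

if "tokens" in model:
    extra_tokens[model_name] = model["tokens"]

print(extra_tokens)  # {'cyberpunk-style': ['cyberpunk', 'neon']}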
@@ -353,7 +359,7 @@ def load_models(server: ServerContext) -> None:
     )
     logger.debug("loaded Textual Inversion models from disk: %s", inversion_models)
     network_models.extend(
-        [NetworkModel(model, "inversion") for model in inversion_models]
+        [NetworkModel(model, "inversion", tokens=extra_tokens.get(model, [])) for model in inversion_models]
     )

     lora_models = list_model_globs(
@@ -364,7 +370,7 @@
         base_path=path.join(server.model_path, "lora"),
     )
     logger.debug("loaded LoRA models from disk: %s", lora_models)
-    network_models.extend([NetworkModel(model, "lora") for model in lora_models])
+    network_models.extend([NetworkModel(model, "lora", tokens=extra_tokens.get(model, [])) for model in lora_models])


 def load_params(server: ServerContext) -> None:
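Both loaders now pass the collected tokens through to NetworkModel, with extra_tokens.get(model, []) falling back to an empty list for models that have no extras entry. A short sketch of that fallback, reusing the NetworkModel class from the first hunk; the model names are invented:

# Invented model names; "watercolor" has no extras entry, so its
# token list defaults to [].
extra_tokens = {"cyberpunk-style": ["cyberpunk"]}
lora_models = ["cyberpunk-style", "watercolor"]

network_models = [
    NetworkModel(model, "lora", tokens=extra_tokens.get(model, []))
    for model in lora_models
]
print([m.tojson() for m in network_models])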
Client (TypeScript):

@@ -1,4 +1,4 @@
-import { Maybe, doesExist, mustExist } from '@apextoaster/js-utils';
+import { Maybe, doesExist, mustDefault, mustExist } from '@apextoaster/js-utils';
 import { Chip, TextField } from '@mui/material';
 import { Stack } from '@mui/system';
 import { useQuery } from '@tanstack/react-query';
@@ -120,7 +120,7 @@ export function PromptInput(props: PromptInputProps) {
   </Stack>;
 }

-export const ANY_TOKEN = /<([^>])+>/g;
+export const ANY_TOKEN = /<([^>]+)>/g;

 export type TokenList = Array<[string, number]>;
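The regex change is a genuine fix: in the old pattern the quantifier sits outside the capture group, so the group only retains the last character of a token like <lora:name:1.0>; moving the + inside the group captures the full name. Python's re module treats repeated groups the same way, so the difference can be shown there (the prompt string is invented):

import re

prompt = "a portrait <lora:cyberpunk:1.0>"

# Old pattern: + outside the group, so the group keeps only its
# last repetition, a single character.
print(re.findall(r"<([^>])+>", prompt))  # ['0']

# Fixed pattern: + inside the group captures the whole token body.
print(re.findall(r"<([^>]+)>", prompt))  # ['lora:cyberpunk:1.0']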
@@ -167,10 +167,10 @@ export function getNetworkTokens(models: Maybe<ModelResponse>, networks: PromptN
     }
   }

-  for (const [name, weight] of networks.inversion) {
+  for (const [name, weight] of networks.lora) {
     const model = models.networks.find((it) => it.type === 'lora' && it.name === name);
     if (doesExist(model) && model.type === 'lora') {
-      for (const token of model.tokens) {
+      for (const token of mustDefault(model.tokens, [])) {
         tokens.push([token, weight]);
       }
     }
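This last hunk fixes two issues at once: the LoRA branch was iterating networks.inversion while filtering for models of type 'lora', so it walked the wrong list, and model.tokens can be undefined in responses from servers without the new field, which mustDefault now guards with an empty list. A rough Python analog of the corrected pairing logic, with invented names, weights, and tokens:

# Rough analog of the corrected getNetworkTokens loop for LoRA networks;
# all values are invented for illustration.
networks = {"lora": [("cyberpunk-style", 0.8)]}
models = [{"name": "cyberpunk-style", "type": "lora", "tokens": ["cyberpunk"]}]

tokens = []
for name, weight in networks["lora"]:
    model = next((m for m in models if m["type"] == "lora" and m["name"] == name), None)
    if model is not None:
        # Mirrors mustDefault(model.tokens, []) in the TypeScript version.
        for token in model.get("tokens") or []:
            tokens.append((token, weight))

print(tokens)  # [('cyberpunk', 0.8)]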