diff --git a/README.md b/README.md
index 19e8f227..76d6daba 100644
--- a/README.md
+++ b/README.md
@@ -17,7 +17,7 @@ with a CPU fallback capable of running on laptop-class machines.
Please check out [the setup guide to get started](docs/setup-guide.md) and [the user guide for more
details](https://github.com/ssube/onnx-web/blob/main/docs/user-guide.md).
-![txt2img with detailed knollingcase renders of a soldier in a cloudy alien jungle](./docs/readme-preview.png)
+![preview of txt2img tab using SDXL to generate ghostly astronauts eating weird hamburgers on an abandoned space station](./docs/readme-sdxl.png)
## Features
diff --git a/api/onnx_web/chain/tile.py b/api/onnx_web/chain/tile.py
index bc3a57ab..eb4aebc4 100644
--- a/api/onnx_web/chain/tile.py
+++ b/api/onnx_web/chain/tile.py
@@ -115,7 +115,12 @@ def make_tile_mask(
# build gradients
edge_t, edge_l, edge_b, edge_r = edges
- grad_x, grad_y = [int(not edge_l), 1, 1, int(not edge_r)], [int(not edge_t), 1, 1, int(not edge_b)]
+ grad_x, grad_y = [int(not edge_l), 1, 1, int(not edge_r)], [
+ int(not edge_t),
+ 1,
+ 1,
+ int(not edge_b),
+ ]
logger.debug("tile gradients: %s, %s, %s, %s", points_w, points_h, grad_x, grad_y)
mult_x = [np.interp(i, points_w, grad_x) for i in range(tile_w)]
diff --git a/api/onnx_web/diffusers/pipelines/panorama.py b/api/onnx_web/diffusers/pipelines/panorama.py
index fb5ea0e8..a33fed75 100644
--- a/api/onnx_web/diffusers/pipelines/panorama.py
+++ b/api/onnx_web/diffusers/pipelines/panorama.py
@@ -660,8 +660,10 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
region_noise_pred_uncond, region_noise_pred_text = np.split(
region_noise_pred, 2
)
- region_noise_pred = region_noise_pred_uncond + guidance_scale * (
- region_noise_pred_text - region_noise_pred_uncond
+ region_noise_pred = (
+ region_noise_pred_uncond
+ + guidance_scale
+ * (region_noise_pred_text - region_noise_pred_uncond)
)
# compute the previous noisy sample x_t -> x_t-1
diff --git a/api/onnx_web/diffusers/pipelines/panorama_xl.py b/api/onnx_web/diffusers/pipelines/panorama_xl.py
index 650ed17a..c9b970a4 100644
--- a/api/onnx_web/diffusers/pipelines/panorama_xl.py
+++ b/api/onnx_web/diffusers/pipelines/panorama_xl.py
@@ -502,8 +502,10 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
region_noise_pred_uncond, region_noise_pred_text = np.split(
region_noise_pred, 2
)
- region_noise_pred = region_noise_pred_uncond + guidance_scale * (
- region_noise_pred_text - region_noise_pred_uncond
+ region_noise_pred = (
+ region_noise_pred_uncond
+ + guidance_scale
+ * (region_noise_pred_text - region_noise_pred_uncond)
)
if guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
diff --git a/api/onnx_web/diffusers/utils.py b/api/onnx_web/diffusers/utils.py
index 9686a48e..651e861a 100644
--- a/api/onnx_web/diffusers/utils.py
+++ b/api/onnx_web/diffusers/utils.py
@@ -379,6 +379,9 @@ def encode_prompt(
num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True,
) -> List[np.ndarray]:
+ """
+ TODO: does not work with SDXL, fix or turn into a pipeline patch
+ """
return [
pipe._encode_prompt(
remove_tokens(prompt),
@@ -456,7 +459,9 @@ def slice_prompt(prompt: str, slice: int) -> str:
return prompt
-Region = Tuple[int, int, int, int, float, Tuple[float, Tuple[bool, bool, bool, bool]], str]
+Region = Tuple[
+ int, int, int, int, float, Tuple[float, Tuple[bool, bool, bool, bool]], str
+]
def parse_region_group(group: Tuple[str, ...]) -> Region:
@@ -475,12 +480,15 @@ def parse_region_group(group: Tuple[str, ...]) -> Region:
int(bottom),
int(right),
float(weight),
- (float(feather_radius), (
- "T" in feather_edges,
- "L" in feather_edges,
- "B" in feather_edges,
- "R" in feather_edges,
- )),
+ (
+ float(feather_radius),
+ (
+ "T" in feather_edges,
+ "L" in feather_edges,
+ "B" in feather_edges,
+ "R" in feather_edges,
+ ),
+ ),
prompt,
)
diff --git a/api/onnx_web/models/meta.py b/api/onnx_web/models/meta.py
index dcd43c25..fd8b1297 100644
--- a/api/onnx_web/models/meta.py
+++ b/api/onnx_web/models/meta.py
@@ -1,18 +1,21 @@
-from typing import Literal
+from typing import List, Literal
NetworkType = Literal["inversion", "lora"]
class NetworkModel:
name: str
+ tokens: List[str]
type: NetworkType
- def __init__(self, name: str, type: NetworkType) -> None:
+ def __init__(self, name: str, type: NetworkType, tokens=None) -> None:
self.name = name
+ self.tokens = tokens or []
self.type = type
def tojson(self):
return {
"name": self.name,
+ "tokens": self.tokens,
"type": self.type,
}
diff --git a/api/onnx_web/server/load.py b/api/onnx_web/server/load.py
index a3759ff6..0444cc82 100644
--- a/api/onnx_web/server/load.py
+++ b/api/onnx_web/server/load.py
@@ -96,6 +96,7 @@ wildcard_data: Dict[str, List[str]] = defaultdict(list)
# Loaded from extra_models
extra_hashes: Dict[str, str] = {}
extra_strings: Dict[str, Any] = {}
+extra_tokens: Dict[str, List[str]] = {}
def get_config_params():
@@ -160,6 +161,7 @@ def load_extras(server: ServerContext):
"""
global extra_hashes
global extra_strings
+ global extra_tokens
labels = {}
strings = {}
@@ -210,6 +212,14 @@ def load_extras(server: ServerContext):
else:
labels[model_name] = model["label"]
+ if "tokens" in model:
+ logger.debug(
+ "collecting tokens for model %s from %s",
+ model_name,
+ file,
+ )
+ extra_tokens[model_name] = model["tokens"]
+
if "inversions" in model:
for inversion in model["inversions"]:
if "label" in inversion:
@@ -353,7 +363,10 @@ def load_models(server: ServerContext) -> None:
)
logger.debug("loaded Textual Inversion models from disk: %s", inversion_models)
network_models.extend(
- [NetworkModel(model, "inversion") for model in inversion_models]
+ [
+ NetworkModel(model, "inversion", tokens=extra_tokens.get(model, []))
+ for model in inversion_models
+ ]
)
lora_models = list_model_globs(
@@ -364,7 +377,12 @@ def load_models(server: ServerContext) -> None:
base_path=path.join(server.model_path, "lora"),
)
logger.debug("loaded LoRA models from disk: %s", lora_models)
- network_models.extend([NetworkModel(model, "lora") for model in lora_models])
+ network_models.extend(
+ [
+ NetworkModel(model, "lora", tokens=extra_tokens.get(model, []))
+ for model in lora_models
+ ]
+ )
def load_params(server: ServerContext) -> None:
diff --git a/api/schemas/extras.yaml b/api/schemas/extras.yaml
index 30886f2c..518023e0 100644
--- a/api/schemas/extras.yaml
+++ b/api/schemas/extras.yaml
@@ -10,34 +10,53 @@ $defs:
- type: number
- type: string
- lora_network:
+ tensor_format:
+ type: string
+ enum: [bin, ckpt, onnx, pt, pth, safetensors]
+
+ embedding_network:
type: object
required: [name, source]
properties:
- name:
- type: string
- source:
- type: string
+ format:
+ $ref: "#/$defs/tensor_format"
label:
type: string
- weight:
- type: number
-
- textual_inversion_network:
- type: object
- required: [name, source]
- properties:
- name:
- type: string
- source:
- type: string
- format:
+ model:
type: string
enum: [concept, embeddings]
- label:
+ name:
+ type: string
+ source:
type: string
token:
type: string
+ type:
+ type: string
+ const: inversion # TODO: add embedding
+ weight:
+ type: number
+
+ lora_network:
+ type: object
+ required: [name, source, type]
+ properties:
+ label:
+ type: string
+ model:
+ type: string
+ enum: [cloneofsimo, sd-scripts]
+ name:
+ type: string
+ source:
+ type: string
+ tokens:
+ type: array
+ items:
+ type: string
+ type:
+ type: string
+ const: lora
weight:
type: number
@@ -46,8 +65,7 @@ $defs:
required: [name, source]
properties:
format:
- type: string
- enum: [bin, ckpt, onnx, pt, pth, safetensors]
+ $ref: "#/$defs/tensor_format"
half:
type: boolean
label:
@@ -85,7 +103,7 @@ $defs:
inversions:
type: array
items:
- $ref: "#/$defs/textual_inversion_network"
+ $ref: "#/$defs/embedding_network"
loras:
type: array
items:
@@ -142,31 +160,6 @@ $defs:
source:
type: string
- source_network:
- type: object
- required: [name, source, type]
- properties:
- format:
- type: string
- enum: [bin, ckpt, onnx, pt, pth, safetensors]
- model:
- type: string
- enum: [
- # inversion
- concept,
- embeddings,
- # lora
- cloneofsimo,
- sd-scripts
- ]
- name:
- type: string
- source:
- type: string
- type:
- type: string
- enum: [inversion, lora]
-
translation:
type: object
additionalProperties: False
@@ -194,7 +187,9 @@ properties:
networks:
type: array
items:
- $ref: "#/$defs/source_network"
+ oneOf:
+ - $ref: "#/$defs/lora_network"
+ - $ref: "#/$defs/embedding_network"
sources:
type: array
items:
diff --git a/docs/readme-sdxl.png b/docs/readme-sdxl.png
new file mode 100644
index 00000000..eb0b2223
--- /dev/null
+++ b/docs/readme-sdxl.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:145b8a98ecf5cfd4948d5ab17d28b34a8fc63cbb3b2c5e3f94b4411538733a59
+size 1633570
diff --git a/gui/src/components/input/PromptInput.tsx b/gui/src/components/input/PromptInput.tsx
index 2fc52ca3..6fbc7f09 100644
--- a/gui/src/components/input/PromptInput.tsx
+++ b/gui/src/components/input/PromptInput.tsx
@@ -1,5 +1,5 @@
-import { mustExist } from '@apextoaster/js-utils';
-import { TextField } from '@mui/material';
+import { Maybe, doesExist, mustDefault, mustExist } from '@apextoaster/js-utils';
+import { Chip, TextField } from '@mui/material';
import { Stack } from '@mui/system';
import { useQuery } from '@tanstack/react-query';
import * as React from 'react';
@@ -10,6 +10,7 @@ import { shallow } from 'zustand/shallow';
import { STALE_TIME } from '../../config.js';
import { ClientContext, OnnxState, StateContext } from '../../state.js';
import { QueryMenu } from '../input/QueryMenu.js';
+import { ModelResponse } from '../../types/api.js';
const { useContext } = React;
@@ -48,26 +49,27 @@ export function PromptInput(props: PromptInputProps) {
staleTime: STALE_TIME,
});
- const tokens = splitPrompt(prompt);
- const groups = Math.ceil(tokens.length / PROMPT_GROUP);
-
const { t } = useTranslation();
- const helper = t('input.prompt.tokens', {
- groups,
- tokens: tokens.length,
- });
- function addToken(type: string, name: string, weight = 1.0) {
+ function addNetwork(type: string, name: string, weight = 1.0) {
onChange({
prompt: `<${type}:${name}:1.0> ${prompt}`,
negativePrompt,
});
}
+ // Append a network's prompt token to the end of the current prompt.
+ function addToken(name: string) {
+   onChange({
+     // avoid a dangling ", " separator when the prompt is still empty
+     prompt: prompt.trim().length > 0 ? `${prompt}, ${name}` : name,
+     // pass the negative prompt through unchanged, the same way addNetwork does
+     negativePrompt,
+   });
+ }
+
+ const networks = extractNetworks(prompt);
+ const tokens = getNetworkTokens(models.data, networks);
+
return
{
@@ -77,6 +79,13 @@ export function PromptInput(props: PromptInputProps) {
});
}}
/>
+
+ {tokens.map(([token, _weight]) => addToken(token)}
+ />)}
+
result.networks.filter((network) => network.type === 'inversion').map((network) => network.name),
}}
onSelect={(name) => {
- addToken('inversion', name);
+ addNetwork('inversion', name);
}}
/>
result.networks.filter((network) => network.type === 'lora').map((network) => network.name),
}}
onSelect={(name) => {
- addToken('lora', name);
+ addNetwork('lora', name);
}}
/>
;
}
+
+/**
+ * Matches every `<...>` token in a prompt and captures the text between the angle brackets.
+ * The `g` flag is required because the pattern is passed to `String.prototype.matchAll`.
+ */
+export const ANY_TOKEN = /<([^>]+)>/g;
+
+/**
+ * List of `[text, weight]` pairs: `extractNetworks` fills the text with a network name,
+ * `getNetworkTokens` fills it with a prompt token.
+ */
+export type TokenList = Array<[string, number]>;
+
+/**
+ * Network tokens found in a prompt, grouped by network type.
+ */
+export interface PromptNetworks {
+ inversion: TokenList;
+ lora: TokenList;
+}
+
+/**
+ * Collect the `<type:name:weight>` network tokens found in a prompt, grouped by network type.
+ *
+ * Only `inversion` and `lora` tokens are collected, every other `<...>` token is ignored.
+ * Tokens without a name are skipped and a missing or non-numeric weight falls back to `1.0`,
+ * so the returned lists never contain an `undefined` name or a `NaN` weight.
+ *
+ * @param prompt prompt text to scan
+ * @returns the `[name, weight]` pairs for each network type, in prompt order
+ */
+export function extractNetworks(prompt: string): PromptNetworks {
+  const inversion: TokenList = [];
+  const lora: TokenList = [];
+
+  for (const token of prompt.matchAll(ANY_TOKEN)) {
+    const [_whole, match] = Array.from(token);
+    const [type, name, weight, ..._rest] = match.split(':');
+
+    // `<lora>` and `<lora:>` carry no usable name, so there is nothing to look up later
+    if (name === undefined || name.length === 0) {
+      continue;
+    }
+
+    // parseFloat yields NaN when the weight is missing (`<lora:name>`) or is not a number
+    const parsed = parseFloat(weight);
+    const value = Number.isFinite(parsed) ? parsed : 1.0;
+
+    switch (type) {
+      case 'inversion':
+        inversion.push([name, value]);
+        break;
+      case 'lora':
+        lora.push([name, value]);
+        break;
+      default:
+        // ignore others
+    }
+  }
+
+  return {
+    inversion,
+    lora,
+  };
+}
+
+/**
+ * Look up the prompt tokens that belong to the networks used in a prompt.
+ *
+ * Each network name is matched against the entries of `models.networks` with the same type.
+ * A matching inversion contributes its single `token`, a matching LoRA contributes each of
+ * its `tokens`, and every token is paired with the weight of the network it came from.
+ *
+ * @param models model list, or a nullish value while it is not available yet
+ * @param networks networks found in the prompt, as returned by `extractNetworks`
+ * @returns `[token, weight]` pairs, or an empty list when `models` does not exist
+ */
+// eslint-disable-next-line sonarjs/cognitive-complexity
+export function getNetworkTokens(models: Maybe<ModelResponse>, networks: PromptNetworks): TokenList {
+  const tokens: TokenList = [];
+
+  if (doesExist(models)) {
+    for (const [name, weight] of networks.inversion) {
+      const model = models.networks.find((it) => it.type === 'inversion' && it.name === name);
+      // the repeated type check narrows the union, `find` with a plain callback does not
+      // NOTE(review): NetworkModel.tojson in api/onnx_web/models/meta.py emits a `tokens` list,
+      // not a single `token`; confirm inversions really carry `token`, missing ones are skipped
+      if (doesExist(model) && model.type === 'inversion' && doesExist(model.token)) {
+        tokens.push([model.token, weight]);
+      }
+    }
+
+    for (const [name, weight] of networks.lora) {
+      const model = models.networks.find((it) => it.type === 'lora' && it.name === name);
+      if (doesExist(model) && model.type === 'lora') {
+        // tolerate entries that were sent without a token list
+        for (const token of mustDefault(model.tokens, [])) {
+          tokens.push([token, weight]);
+        }
+      }
+    }
+  }
+
+  return tokens;
+}
diff --git a/gui/src/types/api.ts b/gui/src/types/api.ts
index 70b99f7a..65179280 100644
--- a/gui/src/types/api.ts
+++ b/gui/src/types/api.ts
@@ -39,13 +39,28 @@ export interface ReadyResponse {
ready: boolean;
}
-export interface NetworkModel {
+export interface ControlNetwork {
name: string;
- type: 'control' | 'inversion' | 'lora';
- // TODO: add token
- // TODO: add layer/token count
+ type: 'control';
}
+/**
+ * Textual Inversion (embedding) network entry, discriminated by `type: 'inversion'`.
+ */
+export interface EmbeddingNetwork {
+ label: string;
+ name: string;
+ // prompt token for this embedding, read by getNetworkTokens in PromptInput.tsx
+ token: string;
+ type: 'inversion';
+ // TODO: add layer count
+}
+
+/**
+ * LoRA network entry, discriminated by `type: 'lora'`.
+ */
+export interface LoraNetwork {
+  name: string;
+  label: string;
+  /** prompt tokens associated with this network */
+  tokens: Array<string>;
+  type: 'lora';
+}
+
+// Any network entry; check the `type` field to narrow to one of the member interfaces.
+export type NetworkModel = EmbeddingNetwork | LoraNetwork | ControlNetwork;
+
export interface FilterResponse {
mask: Array;
source: Array;