diff --git a/README.md b/README.md
index 19e8f227..76d6daba 100644
--- a/README.md
+++ b/README.md
@@ -17,7 +17,7 @@ with a CPU fallback capable of running on laptop-class machines.
 
 Please check out [the setup guide to get started](docs/setup-guide.md) and [the user guide for more details](https://github.com/ssube/onnx-web/blob/main/docs/user-guide.md).
 
-![txt2img with detailed knollingcase renders of a soldier in a cloudy alien jungle](./docs/readme-preview.png)
+![preview of txt2img tab using SDXL to generate ghostly astronauts eating weird hamburgers on an abandoned space station](./docs/readme-sdxl.png)
 
 ## Features
diff --git a/api/onnx_web/chain/tile.py b/api/onnx_web/chain/tile.py
index bc3a57ab..eb4aebc4 100644
--- a/api/onnx_web/chain/tile.py
+++ b/api/onnx_web/chain/tile.py
@@ -115,7 +115,12 @@ def make_tile_mask(
 
     # build gradients
     edge_t, edge_l, edge_b, edge_r = edges
-    grad_x, grad_y = [int(not edge_l), 1, 1, int(not edge_r)], [int(not edge_t), 1, 1, int(not edge_b)]
+    grad_x, grad_y = [int(not edge_l), 1, 1, int(not edge_r)], [
+        int(not edge_t),
+        1,
+        1,
+        int(not edge_b),
+    ]
     logger.debug("tile gradients: %s, %s, %s, %s", points_w, points_h, grad_x, grad_y)
 
     mult_x = [np.interp(i, points_w, grad_x) for i in range(tile_w)]
diff --git a/api/onnx_web/diffusers/pipelines/panorama.py b/api/onnx_web/diffusers/pipelines/panorama.py
index fb5ea0e8..a33fed75 100644
--- a/api/onnx_web/diffusers/pipelines/panorama.py
+++ b/api/onnx_web/diffusers/pipelines/panorama.py
@@ -660,8 +660,10 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
                         region_noise_pred_uncond, region_noise_pred_text = np.split(
                             region_noise_pred, 2
                         )
-                        region_noise_pred = region_noise_pred_uncond + guidance_scale * (
-                            region_noise_pred_text - region_noise_pred_uncond
+                        region_noise_pred = (
+                            region_noise_pred_uncond
+                            + guidance_scale
+                            * (region_noise_pred_text - region_noise_pred_uncond)
                         )
 
                     # compute the previous noisy sample x_t -> x_t-1
diff --git a/api/onnx_web/diffusers/pipelines/panorama_xl.py b/api/onnx_web/diffusers/pipelines/panorama_xl.py
index 650ed17a..c9b970a4 100644
--- a/api/onnx_web/diffusers/pipelines/panorama_xl.py
+++ b/api/onnx_web/diffusers/pipelines/panorama_xl.py
@@ -502,8 +502,10 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
                         region_noise_pred_uncond, region_noise_pred_text = np.split(
                             region_noise_pred, 2
                         )
-                        region_noise_pred = region_noise_pred_uncond + guidance_scale * (
-                            region_noise_pred_text - region_noise_pred_uncond
+                        region_noise_pred = (
+                            region_noise_pred_uncond
+                            + guidance_scale
+                            * (region_noise_pred_text - region_noise_pred_uncond)
                         )
                         if guidance_rescale > 0.0:
                             # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
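Both panorama hunks above are pure re-wraps of the same classifier-free guidance step, with no change in behavior: the noise prediction for a batched `[uncond, text]` input is split in half, and the result extrapolates from the unconditional half toward the text-conditioned half. A minimal numpy sketch of that step, using a stand-in array rather than a real UNet output:

```python
import numpy as np

# Stand-in for a UNet noise prediction over a [uncond, text] batch; the
# shape and values here are illustrative only.
region_noise_pred = np.ones((2, 4, 64, 64), dtype=np.float32)
guidance_scale = 7.5

# Split the batch into unconditional and text-conditioned halves, then
# extrapolate from uncond toward text by guidance_scale.
region_noise_pred_uncond, region_noise_pred_text = np.split(region_noise_pred, 2)
region_noise_pred = (
    region_noise_pred_uncond
    + guidance_scale * (region_noise_pred_text - region_noise_pred_uncond)
)
```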
diff --git a/api/onnx_web/diffusers/utils.py b/api/onnx_web/diffusers/utils.py
index 9686a48e..651e861a 100644
--- a/api/onnx_web/diffusers/utils.py
+++ b/api/onnx_web/diffusers/utils.py
@@ -379,6 +379,9 @@ def encode_prompt(
     num_images_per_prompt: int = 1,
     do_classifier_free_guidance: bool = True,
 ) -> List[np.ndarray]:
+    """
+    TODO: does not work with SDXL, fix or turn into a pipeline patch
+    """
     return [
         pipe._encode_prompt(
             remove_tokens(prompt),
@@ -456,7 +459,9 @@ def slice_prompt(prompt: str, slice: int) -> str:
         return prompt
 
 
-Region = Tuple[int, int, int, int, float, Tuple[float, Tuple[bool, bool, bool, bool]], str]
+Region = Tuple[
+    int, int, int, int, float, Tuple[float, Tuple[bool, bool, bool, bool]], str
+]
 
 
 def parse_region_group(group: Tuple[str, ...]) -> Region:
@@ -475,12 +480,15 @@ def parse_region_group(group: Tuple[str, ...]) -> Region:
         int(bottom),
         int(right),
         float(weight),
-        (float(feather_radius), (
-            "T" in feather_edges,
-            "L" in feather_edges,
-            "B" in feather_edges,
-            "R" in feather_edges,
-        )),
+        (
+            float(feather_radius),
+            (
+                "T" in feather_edges,
+                "L" in feather_edges,
+                "B" in feather_edges,
+                "R" in feather_edges,
+            ),
+        ),
         prompt,
     )
diff --git a/api/onnx_web/models/meta.py b/api/onnx_web/models/meta.py
index dcd43c25..fd8b1297 100644
--- a/api/onnx_web/models/meta.py
+++ b/api/onnx_web/models/meta.py
@@ -1,18 +1,21 @@
-from typing import Literal
+from typing import List, Literal
 
 NetworkType = Literal["inversion", "lora"]
 
 
 class NetworkModel:
     name: str
+    tokens: List[str]
     type: NetworkType
 
-    def __init__(self, name: str, type: NetworkType) -> None:
+    def __init__(self, name: str, type: NetworkType, tokens=None) -> None:
         self.name = name
+        self.tokens = tokens or []
         self.type = type
 
     def tojson(self):
         return {
             "name": self.name,
+            "tokens": self.tokens,
             "type": self.type,
         }
diff --git a/api/onnx_web/server/load.py b/api/onnx_web/server/load.py
index a3759ff6..0444cc82 100644
--- a/api/onnx_web/server/load.py
+++ b/api/onnx_web/server/load.py
@@ -96,6 +96,7 @@ wildcard_data: Dict[str, List[str]] = defaultdict(list)
 # Loaded from extra_models
 extra_hashes: Dict[str, str] = {}
 extra_strings: Dict[str, Any] = {}
+extra_tokens: Dict[str, List[str]] = {}
 
 
 def get_config_params():
@@ -160,6 +161,7 @@ def load_extras(server: ServerContext):
     """
     global extra_hashes
     global extra_strings
+    global extra_tokens
 
     labels = {}
     strings = {}
@@ -210,6 +212,14 @@ def load_extras(server: ServerContext):
             else:
                 labels[model_name] = model["label"]
 
+            if "tokens" in model:
+                logger.debug(
+                    "collecting tokens for model %s from %s",
+                    model_name,
+                    file,
+                )
+                extra_tokens[model_name] = model["tokens"]
+
             if "inversions" in model:
                 for inversion in model["inversions"]:
                     if "label" in inversion:
@@ -353,7 +363,10 @@ def load_models(server: ServerContext) -> None:
     )
     logger.debug("loaded Textual Inversion models from disk: %s", inversion_models)
     network_models.extend(
-        [NetworkModel(model, "inversion") for model in inversion_models]
+        [
+            NetworkModel(model, "inversion", tokens=extra_tokens.get(model, []))
+            for model in inversion_models
+        ]
     )
 
     lora_models = list_model_globs(
@@ -364,7 +377,12 @@
         server,
         "**/*.{safetensors,ckpt}",
         base_path=path.join(server.model_path, "lora"),
     )
     logger.debug("loaded LoRA models from disk: %s", lora_models)
-    network_models.extend([NetworkModel(model, "lora") for model in lora_models])
+    network_models.extend(
+        [
+            NetworkModel(model, "lora", tokens=extra_tokens.get(model, []))
+            for model in lora_models
+        ]
+    )
 
 
 def load_params(server: ServerContext) -> None:
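The tokens collected into `extra_tokens` by `load_extras` are attached to each `NetworkModel` in `load_models`, and `tojson()` is the shape the client ultimately receives. A short sketch of that round trip, with a hypothetical model name and token:

```python
from onnx_web.models.meta import NetworkModel

# Without a matching extras entry, tokens falls back to an empty list.
plain = NetworkModel("some-lora", "lora")
print(plain.tojson())
# {'name': 'some-lora', 'tokens': [], 'type': 'lora'}

# Tokens collected from an extras file are passed through verbatim.
tagged = NetworkModel("some-lora", "lora", tokens=["trigger phrase"])
print(tagged.tojson())
# {'name': 'some-lora', 'tokens': ['trigger phrase'], 'type': 'lora'}
```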
diff --git a/api/schemas/extras.yaml b/api/schemas/extras.yaml
index 30886f2c..518023e0 100644
--- a/api/schemas/extras.yaml
+++ b/api/schemas/extras.yaml
@@ -10,34 +10,53 @@ $defs:
       - type: number
       - type: string
 
-  lora_network:
+  tensor_format:
+    type: string
+    enum: [bin, ckpt, onnx, pt, pth, safetensors]
+
+  embedding_network:
     type: object
     required: [name, source]
     properties:
-      name:
-        type: string
-      source:
-        type: string
+      format:
+        $ref: "#/$defs/tensor_format"
       label:
         type: string
-      weight:
-        type: number
-
-  textual_inversion_network:
-    type: object
-    required: [name, source]
-    properties:
-      name:
-        type: string
-      source:
-        type: string
-      format:
+      model:
         type: string
         enum: [concept, embeddings]
-      label:
+      name:
+        type: string
+      source:
         type: string
       token:
         type: string
+      type:
+        type: string
+        const: inversion # TODO: add embedding
+      weight:
+        type: number
+
+  lora_network:
+    type: object
+    required: [name, source, type]
+    properties:
+      label:
+        type: string
+      model:
+        type: string
+        enum: [cloneofsimo, sd-scripts]
+      name:
+        type: string
+      source:
+        type: string
+      tokens:
+        type: array
+        items:
+          type: string
+      type:
+        type: string
+        const: lora
       weight:
         type: number
 
@@ -46,8 +65,7 @@ $defs:
     required: [name, source]
     properties:
       format:
-        type: string
-        enum: [bin, ckpt, onnx, pt, pth, safetensors]
+        $ref: "#/$defs/tensor_format"
       half:
        type: boolean
       label:
        type: string
@@ -85,7 +103,7 @@ $defs:
       inversions:
         type: array
         items:
-          $ref: "#/$defs/textual_inversion_network"
+          $ref: "#/$defs/embedding_network"
       loras:
         type: array
         items:
@@ -142,31 +160,6 @@ $defs:
       source:
         type: string
 
-  source_network:
-    type: object
-    required: [name, source, type]
-    properties:
-      format:
-        type: string
-        enum: [bin, ckpt, onnx, pt, pth, safetensors]
-      model:
-        type: string
-        enum: [
-          # inversion
-          concept,
-          embeddings,
-          # lora
-          cloneofsimo,
-          sd-scripts
-        ]
-      name:
-        type: string
-      source:
-        type: string
-      type:
-        type: string
-        enum: [inversion, lora]
-
   translation:
     type: object
     additionalProperties: False
@@ -194,7 +187,9 @@ properties:
   networks:
     type: array
     items:
-      $ref: "#/$defs/source_network"
+      oneOf:
+        - $ref: "#/$defs/lora_network"
+        - $ref: "#/$defs/embedding_network"
   sources:
     type: array
     items:
diff --git a/docs/readme-sdxl.png b/docs/readme-sdxl.png
new file mode 100644
index 00000000..eb0b2223
--- /dev/null
+++ b/docs/readme-sdxl.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:145b8a98ecf5cfd4948d5ab17d28b34a8fc63cbb3b2c5e3f94b4411538733a59
+size 1633570
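After this split, every entry in `networks` must carry a literal `type` so the `oneOf` can discriminate between the two definitions. A hypothetical entry that satisfies the new `lora_network` definition, shown here as a Python dict mirroring the YAML/JSON structure (the name, source URL, and token are invented for illustration):

```python
# Hypothetical extras entry matching the lora_network definition above;
# "type" is the const discriminator that oneOf matches on.
lora_entry = {
    "name": "detail-tweaker",  # invented model name
    "source": "https://example.com/detail-tweaker.safetensors",  # invented URL
    "type": "lora",
    "tokens": ["add-detail"],  # surfaced as clickable chips in the client
    "weight": 1.0,
}
```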
diff --git a/gui/src/components/input/PromptInput.tsx b/gui/src/components/input/PromptInput.tsx
index 2fc52ca3..6fbc7f09 100644
--- a/gui/src/components/input/PromptInput.tsx
+++ b/gui/src/components/input/PromptInput.tsx
@@ -1,5 +1,5 @@
-import { mustExist } from '@apextoaster/js-utils';
-import { TextField } from '@mui/material';
+import { Maybe, doesExist, mustDefault, mustExist } from '@apextoaster/js-utils';
+import { Chip, TextField } from '@mui/material';
 import { Stack } from '@mui/system';
 import { useQuery } from '@tanstack/react-query';
 import * as React from 'react';
@@ -10,6 +10,7 @@ import { shallow } from 'zustand/shallow';
 import { STALE_TIME } from '../../config.js';
 import { ClientContext, OnnxState, StateContext } from '../../state.js';
 import { QueryMenu } from '../input/QueryMenu.js';
+import { ModelResponse } from '../../types/api.js';
 
 const { useContext } = React;
 
@@ -48,26 +49,27 @@ export function PromptInput(props: PromptInputProps) {
     staleTime: STALE_TIME,
   });
 
-  const tokens = splitPrompt(prompt);
-  const groups = Math.ceil(tokens.length / PROMPT_GROUP);
-
-  const { t } = useTranslation();
-  const helper = t('input.prompt.tokens', {
-    groups,
-    tokens: tokens.length,
-  });
-
-  function addToken(type: string, name: string, weight = 1.0) {
+  function addNetwork(type: string, name: string, weight = 1.0) {
     onChange({
       prompt: `<${type}:${name}:1.0> ${prompt}`,
       negativePrompt,
     });
   }
 
+  function addToken(name: string) {
+    onChange({
+      prompt: `${prompt}, ${name}`,
+    });
+  }
+
+  const networks = extractNetworks(prompt);
+  const tokens = getNetworkTokens(models.data, networks);
+
   return <Stack spacing={2}>
     <TextField
@@ -77,6 +79,13 @@ export function PromptInput(props: PromptInputProps) {
         });
       }}
     />
+    <Stack direction='row' spacing={2}>
+      {tokens.map(([token, _weight]) => <Chip
+        key={token}
+        label={token}
+        onClick={() => addToken(token)}
+      />)}
+    </Stack>
     <QueryMenu
       id='inversion'
       labelKey='model.inversion'
@@ -89,7 +98,7 @@ export function PromptInput(props: PromptInputProps) {
         result.networks.filter((network) => network.type === 'inversion').map((network) => network.name),
       }}
       onSelect={(name) => {
-        addToken('inversion', name);
+        addNetwork('inversion', name);
       }}
     />
@@ -101,8 +110,68 @@ export function PromptInput(props: PromptInputProps) {
         result.networks.filter((network) => network.type === 'lora').map((network) => network.name),
       }}
       onSelect={(name) => {
-        addToken('lora', name);
+        addNetwork('lora', name);
       }}
     />
   </Stack>;
 }
+
+export const ANY_TOKEN = /<([^>]+)>/g;
+
+export type TokenList = Array<[string, number]>;
+
+export interface PromptNetworks {
+  inversion: TokenList;
+  lora: TokenList;
+}
+
+export function extractNetworks(prompt: string): PromptNetworks {
+  const inversion: TokenList = [];
+  const lora: TokenList = [];
+
+  for (const token of prompt.matchAll(ANY_TOKEN)) {
+    const [_whole, match] = Array.from(token);
+    const [type, name, weight, ..._rest] = match.split(':');
+
+    switch (type) {
+      case 'inversion':
+        inversion.push([name, parseFloat(weight)]);
+        break;
+      case 'lora':
+        lora.push([name, parseFloat(weight)]);
+        break;
+      default:
+        // ignore others
+    }
+  }
+
+  return {
+    inversion,
+    lora,
+  };
+}
+
+// eslint-disable-next-line sonarjs/cognitive-complexity
+export function getNetworkTokens(models: Maybe<ModelResponse>, networks: PromptNetworks): TokenList {
+  const tokens: TokenList = [];
+
+  if (doesExist(models)) {
+    for (const [name, weight] of networks.inversion) {
+      const model = models.networks.find((it) => it.type === 'inversion' && it.name === name);
+      if (doesExist(model) && model.type === 'inversion') {
+        tokens.push([model.token, weight]);
+      }
+    }
+
+    for (const [name, weight] of networks.lora) {
+      const model = models.networks.find((it) => it.type === 'lora' && it.name === name);
+      if (doesExist(model) && model.type === 'lora') {
+        for (const token of mustDefault(model.tokens, [])) {
+          tokens.push([token, weight]);
+        }
+      }
+    }
+  }
+
+  return tokens;
+}
diff --git a/gui/src/types/api.ts b/gui/src/types/api.ts
index 70b99f7a..65179280 100644
--- a/gui/src/types/api.ts
+++ b/gui/src/types/api.ts
@@ -39,13 +39,28 @@ export interface ReadyResponse {
   ready: boolean;
 }
 
-export interface NetworkModel {
+export interface ControlNetwork {
   name: string;
-  type: 'control' | 'inversion' | 'lora';
-  // TODO: add token
-  // TODO: add layer/token count
+  type: 'control';
 }
 
+export interface EmbeddingNetwork {
+  label: string;
+  name: string;
+  token: string;
+  type: 'inversion';
+  // TODO: add layer count
+}
+
+export interface LoraNetwork {
+  name: string;
+  label: string;
+  tokens: Array<string>;
+  type: 'lora';
+}
+
+export type NetworkModel = EmbeddingNetwork | LoraNetwork | ControlNetwork;
+
 export interface FilterResponse {
   mask: Array<string>;
   source: Array<string>;
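`extractNetworks` depends only on the `ANY_TOKEN` pattern and a `:` split, so `<lora:cyberpunk:1.2>` parses to type `lora`, name `cyberpunk`, weight `1.2`. The same parsing rules, re-expressed as a small Python sketch for quick experimentation (the example prompt is invented):

```python
import re

# Python equivalent of the TypeScript ANY_TOKEN pattern above.
ANY_TOKEN = re.compile(r"<([^>]+)>")

def extract_networks(prompt: str) -> dict:
    # Mirror of extractNetworks: find <...> groups, split on ':', and
    # bucket by the leading type tag, ignoring unknown or short tokens.
    networks = {"inversion": [], "lora": []}
    for match in ANY_TOKEN.finditer(prompt):
        parts = match.group(1).split(":")
        if len(parts) >= 3 and parts[0] in networks:
            networks[parts[0]].append((parts[1], float(parts[2])))
    return networks

print(extract_networks("<lora:cyberpunk:1.2> a neon city, <inversion:goblin:1.0>"))
# {'inversion': [('goblin', 1.0)], 'lora': [('cyberpunk', 1.2)]}
```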