From cdbdd9b4e25fc8b0c51d58f95c38292cf388dc99 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 16 Dec 2023 15:17:28 -0600 Subject: [PATCH] feat(gui): add wildcard menu to web UI --- api/onnx_web/server/api.py | 5 ++++ gui/src/client/api.ts | 5 ++++ gui/src/client/base.ts | 2 ++ gui/src/client/local.ts | 3 +++ gui/src/components/input/PromptInput.tsx | 31 +++++++++++++++++------- gui/src/strings/de.ts | 1 + gui/src/strings/en.ts | 1 + gui/src/strings/es.ts | 1 + gui/src/strings/fr.ts | 1 + 9 files changed, 41 insertions(+), 9 deletions(-) diff --git a/api/onnx_web/server/api.py b/api/onnx_web/server/api.py index a9162f0f..00738184 100644 --- a/api/onnx_web/server/api.py +++ b/api/onnx_web/server/api.py @@ -161,6 +161,10 @@ def list_schedulers(server: ServerContext): return jsonify(get_pipeline_schedulers()) +def list_wildcards(server: ServerContext): + return jsonify(list(get_wildcard_data().keys())) + + def img2img(server: ServerContext, pool: DevicePoolExecutor): source_file = request.files.get("source") if source_file is None: @@ -597,6 +601,7 @@ def register_api_routes(app: Flask, server: ServerContext, pool: DevicePoolExecu app.route("/api/settings/platforms")(wrap_route(list_platforms, server)), app.route("/api/settings/schedulers")(wrap_route(list_schedulers, server)), app.route("/api/settings/strings")(wrap_route(list_extra_strings, server)), + app.route("/api/settings/wildcards")(wrap_route(list_wildcards, server)), app.route("/api/img2img", methods=["POST"])( wrap_route(img2img, server, pool=pool) ), diff --git a/gui/src/client/api.ts b/gui/src/client/api.ts index 64fc89dc..e2c974f6 100644 --- a/gui/src/client/api.ts +++ b/gui/src/client/api.ts @@ -212,6 +212,11 @@ export function makeClient(root: string, token: Maybe = undefined, f = f translation: Record; }>; }, + async wildcards(): Promise> { + const path = makeApiUrl(root, 'settings', 'wildcards'); + const res = await f(path); + return await res.json() as Array; + }, async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise { const url = makeImageURL(root, 'img2img', params); appendModelToURL(url, model); diff --git a/gui/src/client/base.ts b/gui/src/client/base.ts index 70e96706..62ef440a 100644 --- a/gui/src/client/base.ts +++ b/gui/src/client/base.ts @@ -51,6 +51,8 @@ export interface ApiClient { translation: Record; }>>; + wildcards(): Promise>; + /** * Start a txt2img pipeline. */ diff --git a/gui/src/client/local.ts b/gui/src/client/local.ts index 561417a3..06dcd6b0 100644 --- a/gui/src/client/local.ts +++ b/gui/src/client/local.ts @@ -72,6 +72,9 @@ export const LOCAL_CLIENT = { async strings() { return {}; }, + async wildcards() { + throw new NoServerError(); + }, async restart() { throw new NoServerError(); }, diff --git a/gui/src/components/input/PromptInput.tsx b/gui/src/components/input/PromptInput.tsx index bc522569..a8081222 100644 --- a/gui/src/components/input/PromptInput.tsx +++ b/gui/src/components/input/PromptInput.tsx @@ -29,14 +29,6 @@ export interface PromptInputProps { export const PROMPT_GROUP = 75; -function splitPrompt(prompt: string): Array { - return prompt - .split(',') - .flatMap((phrase) => phrase.split(' ')) - .map((word) => word.trim()) - .filter((word) => word.length > 0); -} - export function PromptInput(props: PromptInputProps) { // eslint-disable-next-line @typescript-eslint/unbound-method const { selector, onChange } = props; @@ -48,12 +40,15 @@ export function PromptInput(props: PromptInputProps) { const models = useQuery(['models'], async () => client.models(), { staleTime: STALE_TIME, }); + const wildcards = useQuery(['wildcards'], async () => client.wildcards(), { + staleTime: STALE_TIME, + }); const { t } = useTranslation(); function addNetwork(type: string, name: string, weight = 1.0) { onChange({ - prompt: `<${type}:${name}:1.0> ${prompt}`, + prompt: `<${type}:${name}:${weight.toFixed(2)}> ${prompt}`, negativePrompt, }); } @@ -64,6 +59,12 @@ export function PromptInput(props: PromptInputProps) { }); } + function addWildcard(name: string) { + onChange({ + prompt: `${prompt}, __${name}__`, + }); + } + const tokens = useMemo(() => { const networks = extractNetworks(prompt); return getNetworkTokens(models.data, networks); @@ -124,6 +125,18 @@ export function PromptInput(props: PromptInputProps) { addNetwork('lora', name); }} /> + result, + }} + onSelect={(name) => { + addWildcard(name); + }} + /> ; } diff --git a/gui/src/strings/de.ts b/gui/src/strings/de.ts index f5dd8d2c..01c1b02a 100644 --- a/gui/src/strings/de.ts +++ b/gui/src/strings/de.ts @@ -257,6 +257,7 @@ export const I18N_STRINGS_DE = { 'correction-first': 'Korrektur zuerst', 'correction-last': 'Korrektur zuletzt', }, + wildcard: '', }, }, }; diff --git a/gui/src/strings/en.ts b/gui/src/strings/en.ts index 304c3335..6b4e7acc 100644 --- a/gui/src/strings/en.ts +++ b/gui/src/strings/en.ts @@ -339,6 +339,7 @@ export const I18N_STRINGS_EN = { 'correction-first': 'Correction First', 'correction-last': 'Correction Last', }, + wildcard: 'Wildcard', } }, }; diff --git a/gui/src/strings/es.ts b/gui/src/strings/es.ts index 6630ef69..aed1dee8 100644 --- a/gui/src/strings/es.ts +++ b/gui/src/strings/es.ts @@ -257,6 +257,7 @@ export const I18N_STRINGS_ES = { 'correction-first': 'corrección primero', 'correction-last': 'última corrección', }, + wildcard: '', }, }, }; diff --git a/gui/src/strings/fr.ts b/gui/src/strings/fr.ts index 825e57fa..0cebf366 100644 --- a/gui/src/strings/fr.ts +++ b/gui/src/strings/fr.ts @@ -257,6 +257,7 @@ export const I18N_STRINGS_FR = { 'correction-first': '', 'correction-last': '', }, + wildcard: '', }, }, };