fix(gui): add prompt tokens to correct tab (#296)
This commit is contained in:
parent
93fcfd1422
commit
d19bbfc1d3
|
@ -13,32 +13,11 @@ import { Inpaint } from './tab/Inpaint.js';
|
||||||
import { Settings } from './tab/Settings.js';
|
import { Settings } from './tab/Settings.js';
|
||||||
import { Txt2Img } from './tab/Txt2Img.js';
|
import { Txt2Img } from './tab/Txt2Img.js';
|
||||||
import { Upscale } from './tab/Upscale.js';
|
import { Upscale } from './tab/Upscale.js';
|
||||||
|
import { getTab, TAB_LABELS } from './utils.js';
|
||||||
const REMOVE_HASH = /^#?(.*)$/;
|
|
||||||
const TAB_LABELS = [
|
|
||||||
'txt2img',
|
|
||||||
'img2img',
|
|
||||||
'inpaint',
|
|
||||||
'upscale',
|
|
||||||
'blend',
|
|
||||||
'settings',
|
|
||||||
];
|
|
||||||
|
|
||||||
export function OnnxWeb() {
|
export function OnnxWeb() {
|
||||||
const [hash, setHash] = useHash();
|
const [hash, setHash] = useHash();
|
||||||
|
|
||||||
function tab(): string {
|
|
||||||
const match = hash.match(REMOVE_HASH);
|
|
||||||
if (doesExist(match)) {
|
|
||||||
const [_full, route] = Array.from(match);
|
|
||||||
if (route.length > 0) {
|
|
||||||
return route;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return TAB_LABELS[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Container>
|
<Container>
|
||||||
<Box sx={{ my: 4 }}>
|
<Box sx={{ my: 4 }}>
|
||||||
|
@ -47,7 +26,7 @@ export function OnnxWeb() {
|
||||||
<Box sx={{ mx: 4, my: 4 }}>
|
<Box sx={{ mx: 4, my: 4 }}>
|
||||||
<ModelControl />
|
<ModelControl />
|
||||||
</Box>
|
</Box>
|
||||||
<TabContext value={tab()}>
|
<TabContext value={getTab(hash)}>
|
||||||
<Box sx={{ borderBottom: 1, borderColor: 'divider' }}>
|
<Box sx={{ borderBottom: 1, borderColor: 'divider' }}>
|
||||||
<TabList onChange={(_e, idx) => {
|
<TabList onChange={(_e, idx) => {
|
||||||
setHash(idx);
|
setHash(idx);
|
||||||
|
|
|
@ -4,12 +4,14 @@ import * as React from 'react';
|
||||||
import { useContext } from 'react';
|
import { useContext } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { useQuery } from 'react-query';
|
import { useQuery } from 'react-query';
|
||||||
|
import { useHash } from 'react-use/lib/useHash';
|
||||||
import { useStore } from 'zustand';
|
import { useStore } from 'zustand';
|
||||||
|
|
||||||
import { STALE_TIME } from '../../config.js';
|
import { STALE_TIME } from '../../config.js';
|
||||||
import { ClientContext, StateContext } from '../../state.js';
|
import { ClientContext, StateContext } from '../../state.js';
|
||||||
import { QueryList } from '../input/QueryList.js';
|
import { QueryList } from '../input/QueryList.js';
|
||||||
import { QueryMenu } from '../input/QueryMenu.js';
|
import { QueryMenu } from '../input/QueryMenu.js';
|
||||||
|
import { getTab } from '../utils.js';
|
||||||
|
|
||||||
export function ModelControl() {
|
export function ModelControl() {
|
||||||
const client = mustExist(useContext(ClientContext));
|
const client = mustExist(useContext(ClientContext));
|
||||||
|
@ -26,6 +28,33 @@ export function ModelControl() {
|
||||||
staleTime: STALE_TIME,
|
staleTime: STALE_TIME,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
const [hash, _setHash] = useHash();
|
||||||
|
|
||||||
|
function addToken(type: string, name: string, weight = 1.0) {
|
||||||
|
const tab = getTab(hash);
|
||||||
|
const current = state.getState();
|
||||||
|
|
||||||
|
|
||||||
|
switch (tab) {
|
||||||
|
case 'txt2img': {
|
||||||
|
const { prompt } = current.txt2img;
|
||||||
|
current.setTxt2Img({
|
||||||
|
prompt: `<${type}:${name}:1.0> ${prompt}`,
|
||||||
|
});
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 'img2img': {
|
||||||
|
const { prompt } = current.img2img;
|
||||||
|
current.setImg2Img({
|
||||||
|
prompt: `<${type}:${name}:1.0> ${prompt}`,
|
||||||
|
});
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// not supported yet
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return <Stack direction='column' spacing={2}>
|
return <Stack direction='column' spacing={2}>
|
||||||
<Stack direction='row' spacing={2}>
|
<Stack direction='row' spacing={2}>
|
||||||
<QueryList
|
<QueryList
|
||||||
|
@ -110,12 +139,7 @@ export function ModelControl() {
|
||||||
selector: (result) => result.networks.filter((network) => network.type === 'inversion').map((network) => network.name),
|
selector: (result) => result.networks.filter((network) => network.type === 'inversion').map((network) => network.name),
|
||||||
}}
|
}}
|
||||||
onSelect={(name) => {
|
onSelect={(name) => {
|
||||||
const current = state.getState();
|
addToken('inversion', name);
|
||||||
const { prompt } = current.txt2img;
|
|
||||||
|
|
||||||
current.setTxt2Img({
|
|
||||||
prompt: `<inversion:${name}:1.0> ${prompt}`,
|
|
||||||
});
|
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
<QueryMenu
|
<QueryMenu
|
||||||
|
@ -127,12 +151,7 @@ export function ModelControl() {
|
||||||
selector: (result) => result.networks.filter((network) => network.type === 'lora').map((network) => network.name),
|
selector: (result) => result.networks.filter((network) => network.type === 'lora').map((network) => network.name),
|
||||||
}}
|
}}
|
||||||
onSelect={(name) => {
|
onSelect={(name) => {
|
||||||
const current = state.getState();
|
addToken('lora', name);
|
||||||
const { prompt } = current.txt2img;
|
|
||||||
|
|
||||||
current.setTxt2Img({
|
|
||||||
prompt: `<lora:${name}:1.0> ${prompt}`,
|
|
||||||
});
|
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
</Stack>
|
</Stack>
|
||||||
|
|
|
@ -0,0 +1,19 @@
|
||||||
|
import { trimHash } from '../utils.js';
|
||||||
|
|
||||||
|
export const TAB_LABELS = [
|
||||||
|
'txt2img',
|
||||||
|
'img2img',
|
||||||
|
'inpaint',
|
||||||
|
'upscale',
|
||||||
|
'blend',
|
||||||
|
'settings',
|
||||||
|
] as const;
|
||||||
|
|
||||||
|
export function getTab(hash: string): string {
|
||||||
|
const route = trimHash(hash);
|
||||||
|
if (route.length > 0) {
|
||||||
|
return route;
|
||||||
|
}
|
||||||
|
|
||||||
|
return TAB_LABELS[0];
|
||||||
|
}
|
|
@ -18,3 +18,11 @@ export function range(max: number): Array<number> {
|
||||||
export function visibleIndex(idx: number): string {
|
export function visibleIndex(idx: number): string {
|
||||||
return (idx + 1).toFixed(0);
|
return (idx + 1).toFixed(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function trimHash(val: string): string {
|
||||||
|
if (val[0] === '#') {
|
||||||
|
return val.slice(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
return val;
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue