fix(gui): add prompt tokens to correct tab (#296)
This commit is contained in:
parent
93fcfd1422
commit
d19bbfc1d3
|
@ -13,32 +13,11 @@ import { Inpaint } from './tab/Inpaint.js';
|
|||
import { Settings } from './tab/Settings.js';
|
||||
import { Txt2Img } from './tab/Txt2Img.js';
|
||||
import { Upscale } from './tab/Upscale.js';
|
||||
|
||||
const REMOVE_HASH = /^#?(.*)$/;
|
||||
const TAB_LABELS = [
|
||||
'txt2img',
|
||||
'img2img',
|
||||
'inpaint',
|
||||
'upscale',
|
||||
'blend',
|
||||
'settings',
|
||||
];
|
||||
import { getTab, TAB_LABELS } from './utils.js';
|
||||
|
||||
export function OnnxWeb() {
|
||||
const [hash, setHash] = useHash();
|
||||
|
||||
function tab(): string {
|
||||
const match = hash.match(REMOVE_HASH);
|
||||
if (doesExist(match)) {
|
||||
const [_full, route] = Array.from(match);
|
||||
if (route.length > 0) {
|
||||
return route;
|
||||
}
|
||||
}
|
||||
|
||||
return TAB_LABELS[0];
|
||||
}
|
||||
|
||||
return (
|
||||
<Container>
|
||||
<Box sx={{ my: 4 }}>
|
||||
|
@ -47,7 +26,7 @@ export function OnnxWeb() {
|
|||
<Box sx={{ mx: 4, my: 4 }}>
|
||||
<ModelControl />
|
||||
</Box>
|
||||
<TabContext value={tab()}>
|
||||
<TabContext value={getTab(hash)}>
|
||||
<Box sx={{ borderBottom: 1, borderColor: 'divider' }}>
|
||||
<TabList onChange={(_e, idx) => {
|
||||
setHash(idx);
|
||||
|
|
|
@ -4,12 +4,14 @@ import * as React from 'react';
|
|||
import { useContext } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useQuery } from 'react-query';
|
||||
import { useHash } from 'react-use/lib/useHash';
|
||||
import { useStore } from 'zustand';
|
||||
|
||||
import { STALE_TIME } from '../../config.js';
|
||||
import { ClientContext, StateContext } from '../../state.js';
|
||||
import { QueryList } from '../input/QueryList.js';
|
||||
import { QueryMenu } from '../input/QueryMenu.js';
|
||||
import { getTab } from '../utils.js';
|
||||
|
||||
export function ModelControl() {
|
||||
const client = mustExist(useContext(ClientContext));
|
||||
|
@ -26,6 +28,33 @@ export function ModelControl() {
|
|||
staleTime: STALE_TIME,
|
||||
});
|
||||
|
||||
const [hash, _setHash] = useHash();
|
||||
|
||||
function addToken(type: string, name: string, weight = 1.0) {
|
||||
const tab = getTab(hash);
|
||||
const current = state.getState();
|
||||
|
||||
|
||||
switch (tab) {
|
||||
case 'txt2img': {
|
||||
const { prompt } = current.txt2img;
|
||||
current.setTxt2Img({
|
||||
prompt: `<${type}:${name}:1.0> ${prompt}`,
|
||||
});
|
||||
break;
|
||||
}
|
||||
case 'img2img': {
|
||||
const { prompt } = current.img2img;
|
||||
current.setImg2Img({
|
||||
prompt: `<${type}:${name}:1.0> ${prompt}`,
|
||||
});
|
||||
break;
|
||||
}
|
||||
default:
|
||||
// not supported yet
|
||||
}
|
||||
}
|
||||
|
||||
return <Stack direction='column' spacing={2}>
|
||||
<Stack direction='row' spacing={2}>
|
||||
<QueryList
|
||||
|
@ -110,12 +139,7 @@ export function ModelControl() {
|
|||
selector: (result) => result.networks.filter((network) => network.type === 'inversion').map((network) => network.name),
|
||||
}}
|
||||
onSelect={(name) => {
|
||||
const current = state.getState();
|
||||
const { prompt } = current.txt2img;
|
||||
|
||||
current.setTxt2Img({
|
||||
prompt: `<inversion:${name}:1.0> ${prompt}`,
|
||||
});
|
||||
addToken('inversion', name);
|
||||
}}
|
||||
/>
|
||||
<QueryMenu
|
||||
|
@ -127,12 +151,7 @@ export function ModelControl() {
|
|||
selector: (result) => result.networks.filter((network) => network.type === 'lora').map((network) => network.name),
|
||||
}}
|
||||
onSelect={(name) => {
|
||||
const current = state.getState();
|
||||
const { prompt } = current.txt2img;
|
||||
|
||||
current.setTxt2Img({
|
||||
prompt: `<lora:${name}:1.0> ${prompt}`,
|
||||
});
|
||||
addToken('lora', name);
|
||||
}}
|
||||
/>
|
||||
</Stack>
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
import { trimHash } from '../utils.js';
|
||||
|
||||
export const TAB_LABELS = [
|
||||
'txt2img',
|
||||
'img2img',
|
||||
'inpaint',
|
||||
'upscale',
|
||||
'blend',
|
||||
'settings',
|
||||
] as const;
|
||||
|
||||
export function getTab(hash: string): string {
|
||||
const route = trimHash(hash);
|
||||
if (route.length > 0) {
|
||||
return route;
|
||||
}
|
||||
|
||||
return TAB_LABELS[0];
|
||||
}
|
|
@ -18,3 +18,11 @@ export function range(max: number): Array<number> {
|
|||
export function visibleIndex(idx: number): string {
|
||||
return (idx + 1).toFixed(0);
|
||||
}
|
||||
|
||||
export function trimHash(val: string): string {
|
||||
if (val[0] === '#') {
|
||||
return val.slice(1);
|
||||
}
|
||||
|
||||
return val;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue