1
0
Fork 0

fix(gui): add prompt tokens to correct tab (#296)

This commit is contained in:
Sean Sube 2023-03-28 17:51:40 -05:00
parent 93fcfd1422
commit d19bbfc1d3
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 60 additions and 35 deletions

View File

@ -13,32 +13,11 @@ import { Inpaint } from './tab/Inpaint.js';
import { Settings } from './tab/Settings.js'; import { Settings } from './tab/Settings.js';
import { Txt2Img } from './tab/Txt2Img.js'; import { Txt2Img } from './tab/Txt2Img.js';
import { Upscale } from './tab/Upscale.js'; import { Upscale } from './tab/Upscale.js';
import { getTab, TAB_LABELS } from './utils.js';
const REMOVE_HASH = /^#?(.*)$/;
const TAB_LABELS = [
'txt2img',
'img2img',
'inpaint',
'upscale',
'blend',
'settings',
];
export function OnnxWeb() { export function OnnxWeb() {
const [hash, setHash] = useHash(); const [hash, setHash] = useHash();
function tab(): string {
const match = hash.match(REMOVE_HASH);
if (doesExist(match)) {
const [_full, route] = Array.from(match);
if (route.length > 0) {
return route;
}
}
return TAB_LABELS[0];
}
return ( return (
<Container> <Container>
<Box sx={{ my: 4 }}> <Box sx={{ my: 4 }}>
@ -47,7 +26,7 @@ export function OnnxWeb() {
<Box sx={{ mx: 4, my: 4 }}> <Box sx={{ mx: 4, my: 4 }}>
<ModelControl /> <ModelControl />
</Box> </Box>
<TabContext value={tab()}> <TabContext value={getTab(hash)}>
<Box sx={{ borderBottom: 1, borderColor: 'divider' }}> <Box sx={{ borderBottom: 1, borderColor: 'divider' }}>
<TabList onChange={(_e, idx) => { <TabList onChange={(_e, idx) => {
setHash(idx); setHash(idx);

View File

@ -4,12 +4,14 @@ import * as React from 'react';
import { useContext } from 'react'; import { useContext } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useQuery } from 'react-query'; import { useQuery } from 'react-query';
import { useHash } from 'react-use/lib/useHash';
import { useStore } from 'zustand'; import { useStore } from 'zustand';
import { STALE_TIME } from '../../config.js'; import { STALE_TIME } from '../../config.js';
import { ClientContext, StateContext } from '../../state.js'; import { ClientContext, StateContext } from '../../state.js';
import { QueryList } from '../input/QueryList.js'; import { QueryList } from '../input/QueryList.js';
import { QueryMenu } from '../input/QueryMenu.js'; import { QueryMenu } from '../input/QueryMenu.js';
import { getTab } from '../utils.js';
export function ModelControl() { export function ModelControl() {
const client = mustExist(useContext(ClientContext)); const client = mustExist(useContext(ClientContext));
@ -26,6 +28,33 @@ export function ModelControl() {
staleTime: STALE_TIME, staleTime: STALE_TIME,
}); });
const [hash, _setHash] = useHash();
function addToken(type: string, name: string, weight = 1.0) {
const tab = getTab(hash);
const current = state.getState();
switch (tab) {
case 'txt2img': {
const { prompt } = current.txt2img;
current.setTxt2Img({
prompt: `<${type}:${name}:1.0> ${prompt}`,
});
break;
}
case 'img2img': {
const { prompt } = current.img2img;
current.setImg2Img({
prompt: `<${type}:${name}:1.0> ${prompt}`,
});
break;
}
default:
// not supported yet
}
}
return <Stack direction='column' spacing={2}> return <Stack direction='column' spacing={2}>
<Stack direction='row' spacing={2}> <Stack direction='row' spacing={2}>
<QueryList <QueryList
@ -110,12 +139,7 @@ export function ModelControl() {
selector: (result) => result.networks.filter((network) => network.type === 'inversion').map((network) => network.name), selector: (result) => result.networks.filter((network) => network.type === 'inversion').map((network) => network.name),
}} }}
onSelect={(name) => { onSelect={(name) => {
const current = state.getState(); addToken('inversion', name);
const { prompt } = current.txt2img;
current.setTxt2Img({
prompt: `<inversion:${name}:1.0> ${prompt}`,
});
}} }}
/> />
<QueryMenu <QueryMenu
@ -127,12 +151,7 @@ export function ModelControl() {
selector: (result) => result.networks.filter((network) => network.type === 'lora').map((network) => network.name), selector: (result) => result.networks.filter((network) => network.type === 'lora').map((network) => network.name),
}} }}
onSelect={(name) => { onSelect={(name) => {
const current = state.getState(); addToken('lora', name);
const { prompt } = current.txt2img;
current.setTxt2Img({
prompt: `<lora:${name}:1.0> ${prompt}`,
});
}} }}
/> />
</Stack> </Stack>

View File

@ -0,0 +1,19 @@
import { trimHash } from '../utils.js';
export const TAB_LABELS = [
'txt2img',
'img2img',
'inpaint',
'upscale',
'blend',
'settings',
] as const;
export function getTab(hash: string): string {
const route = trimHash(hash);
if (route.length > 0) {
return route;
}
return TAB_LABELS[0];
}

View File

@ -18,3 +18,11 @@ export function range(max: number): Array<number> {
export function visibleIndex(idx: number): string { export function visibleIndex(idx: number): string {
return (idx + 1).toFixed(0); return (idx + 1).toFixed(0);
} }
export function trimHash(val: string): string {
if (val[0] === '#') {
return val.slice(1);
}
return val;
}