1
0
Fork 0

feat(gui): move platform selector outside of mode tabs

This commit is contained in:
Sean Sube 2023-01-06 11:12:32 -06:00
parent 6d839eaa29
commit 45a097a56b
4 changed files with 46 additions and 39 deletions

View File

@ -1,10 +1,11 @@
import { TabContext, TabList, TabPanel } from '@mui/lab';
import { Box, Container, Tab, Typography } from '@mui/material';
import { Box, Container, Stack, Tab, Typography } from '@mui/material';
import * as React from 'react';
import { useQuery } from 'react-query';
import { ApiClient } from '../api/client.js';
import { Config } from '../config.js';
import { MODEL_LABELS, PLATFORM_LABELS } from '../strings.js';
import { QueryList } from './QueryList.js';
import { STALE_TIME, Txt2Img } from './Txt2Img.js';
@ -15,17 +16,17 @@ export interface OnnxWebProps {
config: Config;
}
const MODEL_LABELS = {
'stable-diffusion-onnx-v1-5': 'Stable Diffusion v1.5',
};
export function OnnxWeb(props: OnnxWebProps) {
const { client, config } = props;
const [tab, setTab] = useState('1');
const [model, setModel] = useState(config.default.model);
const [platform, setPlatform] = useState(config.default.platform);
const models = useQuery('models', async () => props.client.models(), {
const models = useQuery('models', async () => client.models(), {
staleTime: STALE_TIME,
});
const platforms = useQuery('platforms', async () => client.platforms(), {
staleTime: STALE_TIME,
});
@ -38,9 +39,18 @@ export function OnnxWeb(props: OnnxWebProps) {
</Typography>
</Box>
<Box sx={{ my: 4 }}>
<QueryList result={models} labels={MODEL_LABELS} value={model} onChange={(value) => {
setModel(value);
}} />
<Stack direction='row' spacing={2}>
<QueryList result={models} labels={MODEL_LABELS} value={model}
onChange={(value) => {
setModel(value);
}}
/>
<QueryList result={platforms} labels={PLATFORM_LABELS} value={platform}
onChange={(value) => {
setPlatform(value);
}}
/>
</Stack>
</Box>
<TabContext value={tab}>
<Box sx={{ borderBottom: 1, borderColor: 'divider' }}>
@ -53,7 +63,7 @@ export function OnnxWeb(props: OnnxWebProps) {
</TabList>
</Box>
<TabPanel value="1">
<Txt2Img client={client} config={config} model={model} />
<Txt2Img client={client} config={config} model={model} platform={platform} />
</TabPanel>
<TabPanel value="2">
<Box>

View File

@ -4,6 +4,7 @@ import { useMutation, useQuery } from 'react-query';
import { ApiClient } from '../api/client.js';
import { Config } from '../config.js';
import { SCHEDULER_LABELS } from '../strings.js';
import { ImageCard } from './ImageCard.js';
import { ImageControl, ImageParams } from './ImageControl.js';
import { MutationHistory } from './MutationHistory.js';
@ -12,45 +13,28 @@ import { QueryList } from './QueryList.js';
const { useState } = React;
export const STALE_TIME = 3_000;
// TODO: set up i18next
const PLATFORM_LABELS: Record<string, string> = {
amd: 'AMD GPU',
cpu: 'CPU',
};
const SCHEDULER_LABELS: Record<string, string> = {
'ddim': 'DDIM',
'ddpm': 'DDPM',
'dpm-multi': 'DPM Multistep',
'euler': 'Euler',
'euler-a': 'Euler Ancestral',
'lms-discrete': 'LMS Discrete',
'pndm': 'PNDM',
};
export interface Txt2ImgProps {
client: ApiClient;
config: Config;
model: string;
platform: string;
}
export function Txt2Img(props: Txt2ImgProps) {
const { client, model } = props;
const { client, config, model, platform } = props;
async function generateImage() {
return client.txt2img({
...params,
model,
platform,
prompt,
scheduler,
});
}
const generate = useMutation(generateImage);
const platforms = useQuery('platforms', async () => client.platforms(), {
staleTime: STALE_TIME,
});
const schedulers = useQuery('schedulers', async () => client.schedulers(), {
staleTime: STALE_TIME,
});
@ -61,9 +45,8 @@ export function Txt2Img(props: Txt2ImgProps) {
width: 512,
height: 512,
});
const [prompt, setPrompt] = useState(props.config.default.prompt);
const [platform, setPlatform] = useState(props.config.default.platform);
const [scheduler, setScheduler] = useState(props.config.default.scheduler);
const [prompt, setPrompt] = useState(config.default.prompt);
const [scheduler, setScheduler] = useState(config.default.scheduler);
return <Box>
<Stack spacing={2}>
@ -73,11 +56,6 @@ export function Txt2Img(props: Txt2ImgProps) {
setScheduler(value);
}}
/>
<QueryList result={platforms} value={platform} labels={PLATFORM_LABELS}
onChange={(value) => {
setPlatform(value);
}}
/>
</Stack>
<ImageControl params={params} onChange={(newParams) => {
setParams(newParams);

19
gui/src/strings.ts Normal file
View File

@ -0,0 +1,19 @@
// TODO: set up i18next
export const MODEL_LABELS = {
'stable-diffusion-onnx-v1-5': 'Stable Diffusion v1.5',
};
export const PLATFORM_LABELS: Record<string, string> = {
amd: 'AMD GPU',
cpu: 'CPU',
};
export const SCHEDULER_LABELS: Record<string, string> = {
'ddim': 'DDIM',
'ddpm': 'DDPM',
'dpm-multi': 'DPM Multistep',
'euler': 'Euler',
'euler-a': 'Euler Ancestral',
'lms-discrete': 'LMS Discrete',
'pndm': 'PNDM',
};

View File