feat(gui): move platform selector outside of mode tabs
This commit is contained in:
parent
6d839eaa29
commit
45a097a56b
|
@ -1,10 +1,11 @@
|
|||
import { TabContext, TabList, TabPanel } from '@mui/lab';
|
||||
import { Box, Container, Tab, Typography } from '@mui/material';
|
||||
import { Box, Container, Stack, Tab, Typography } from '@mui/material';
|
||||
import * as React from 'react';
|
||||
import { useQuery } from 'react-query';
|
||||
|
||||
import { ApiClient } from '../api/client.js';
|
||||
import { Config } from '../config.js';
|
||||
import { MODEL_LABELS, PLATFORM_LABELS } from '../strings.js';
|
||||
import { QueryList } from './QueryList.js';
|
||||
import { STALE_TIME, Txt2Img } from './Txt2Img.js';
|
||||
|
||||
|
@ -15,17 +16,17 @@ export interface OnnxWebProps {
|
|||
config: Config;
|
||||
}
|
||||
|
||||
const MODEL_LABELS = {
|
||||
'stable-diffusion-onnx-v1-5': 'Stable Diffusion v1.5',
|
||||
};
|
||||
|
||||
export function OnnxWeb(props: OnnxWebProps) {
|
||||
const { client, config } = props;
|
||||
|
||||
const [tab, setTab] = useState('1');
|
||||
const [model, setModel] = useState(config.default.model);
|
||||
const [platform, setPlatform] = useState(config.default.platform);
|
||||
|
||||
const models = useQuery('models', async () => props.client.models(), {
|
||||
const models = useQuery('models', async () => client.models(), {
|
||||
staleTime: STALE_TIME,
|
||||
});
|
||||
const platforms = useQuery('platforms', async () => client.platforms(), {
|
||||
staleTime: STALE_TIME,
|
||||
});
|
||||
|
||||
|
@ -38,9 +39,18 @@ export function OnnxWeb(props: OnnxWebProps) {
|
|||
</Typography>
|
||||
</Box>
|
||||
<Box sx={{ my: 4 }}>
|
||||
<QueryList result={models} labels={MODEL_LABELS} value={model} onChange={(value) => {
|
||||
setModel(value);
|
||||
}} />
|
||||
<Stack direction='row' spacing={2}>
|
||||
<QueryList result={models} labels={MODEL_LABELS} value={model}
|
||||
onChange={(value) => {
|
||||
setModel(value);
|
||||
}}
|
||||
/>
|
||||
<QueryList result={platforms} labels={PLATFORM_LABELS} value={platform}
|
||||
onChange={(value) => {
|
||||
setPlatform(value);
|
||||
}}
|
||||
/>
|
||||
</Stack>
|
||||
</Box>
|
||||
<TabContext value={tab}>
|
||||
<Box sx={{ borderBottom: 1, borderColor: 'divider' }}>
|
||||
|
@ -53,7 +63,7 @@ export function OnnxWeb(props: OnnxWebProps) {
|
|||
</TabList>
|
||||
</Box>
|
||||
<TabPanel value="1">
|
||||
<Txt2Img client={client} config={config} model={model} />
|
||||
<Txt2Img client={client} config={config} model={model} platform={platform} />
|
||||
</TabPanel>
|
||||
<TabPanel value="2">
|
||||
<Box>
|
||||
|
|
|
@ -4,6 +4,7 @@ import { useMutation, useQuery } from 'react-query';
|
|||
|
||||
import { ApiClient } from '../api/client.js';
|
||||
import { Config } from '../config.js';
|
||||
import { SCHEDULER_LABELS } from '../strings.js';
|
||||
import { ImageCard } from './ImageCard.js';
|
||||
import { ImageControl, ImageParams } from './ImageControl.js';
|
||||
import { MutationHistory } from './MutationHistory.js';
|
||||
|
@ -12,45 +13,28 @@ import { QueryList } from './QueryList.js';
|
|||
const { useState } = React;
|
||||
|
||||
export const STALE_TIME = 3_000;
|
||||
|
||||
// TODO: set up i18next
|
||||
const PLATFORM_LABELS: Record<string, string> = {
|
||||
amd: 'AMD GPU',
|
||||
cpu: 'CPU',
|
||||
};
|
||||
|
||||
const SCHEDULER_LABELS: Record<string, string> = {
|
||||
'ddim': 'DDIM',
|
||||
'ddpm': 'DDPM',
|
||||
'dpm-multi': 'DPM Multistep',
|
||||
'euler': 'Euler',
|
||||
'euler-a': 'Euler Ancestral',
|
||||
'lms-discrete': 'LMS Discrete',
|
||||
'pndm': 'PNDM',
|
||||
};
|
||||
|
||||
export interface Txt2ImgProps {
|
||||
client: ApiClient;
|
||||
config: Config;
|
||||
|
||||
model: string;
|
||||
platform: string;
|
||||
}
|
||||
|
||||
export function Txt2Img(props: Txt2ImgProps) {
|
||||
const { client, model } = props;
|
||||
const { client, config, model, platform } = props;
|
||||
|
||||
async function generateImage() {
|
||||
return client.txt2img({
|
||||
...params,
|
||||
model,
|
||||
platform,
|
||||
prompt,
|
||||
scheduler,
|
||||
});
|
||||
}
|
||||
|
||||
const generate = useMutation(generateImage);
|
||||
const platforms = useQuery('platforms', async () => client.platforms(), {
|
||||
staleTime: STALE_TIME,
|
||||
});
|
||||
const schedulers = useQuery('schedulers', async () => client.schedulers(), {
|
||||
staleTime: STALE_TIME,
|
||||
});
|
||||
|
@ -61,9 +45,8 @@ export function Txt2Img(props: Txt2ImgProps) {
|
|||
width: 512,
|
||||
height: 512,
|
||||
});
|
||||
const [prompt, setPrompt] = useState(props.config.default.prompt);
|
||||
const [platform, setPlatform] = useState(props.config.default.platform);
|
||||
const [scheduler, setScheduler] = useState(props.config.default.scheduler);
|
||||
const [prompt, setPrompt] = useState(config.default.prompt);
|
||||
const [scheduler, setScheduler] = useState(config.default.scheduler);
|
||||
|
||||
return <Box>
|
||||
<Stack spacing={2}>
|
||||
|
@ -73,11 +56,6 @@ export function Txt2Img(props: Txt2ImgProps) {
|
|||
setScheduler(value);
|
||||
}}
|
||||
/>
|
||||
<QueryList result={platforms} value={platform} labels={PLATFORM_LABELS}
|
||||
onChange={(value) => {
|
||||
setPlatform(value);
|
||||
}}
|
||||
/>
|
||||
</Stack>
|
||||
<ImageControl params={params} onChange={(newParams) => {
|
||||
setParams(newParams);
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
// TODO: set up i18next
|
||||
export const MODEL_LABELS = {
|
||||
'stable-diffusion-onnx-v1-5': 'Stable Diffusion v1.5',
|
||||
};
|
||||
|
||||
export const PLATFORM_LABELS: Record<string, string> = {
|
||||
amd: 'AMD GPU',
|
||||
cpu: 'CPU',
|
||||
};
|
||||
|
||||
export const SCHEDULER_LABELS: Record<string, string> = {
|
||||
'ddim': 'DDIM',
|
||||
'ddpm': 'DDPM',
|
||||
'dpm-multi': 'DPM Multistep',
|
||||
'euler': 'Euler',
|
||||
'euler-a': 'Euler Ancestral',
|
||||
'lms-discrete': 'LMS Discrete',
|
||||
'pndm': 'PNDM',
|
||||
};
|
Loading…
Reference in New Issue