feat(gui): get platforms and schedulers from server
This commit is contained in:
parent
c70728d501
commit
ce06837fbc
|
@ -16,6 +16,8 @@ export interface ApiResponse {
|
|||
}
|
||||
|
||||
export interface ApiClient {
|
||||
platforms(): Promise<Array<string>>;
|
||||
schedulers(): Promise<Array<string>>;
|
||||
txt2img(params: Txt2ImgParams): Promise<ApiResponse>;
|
||||
}
|
||||
|
||||
|
@ -38,6 +40,16 @@ export function makeClient(root: string, f = fetch): ApiClient {
|
|||
let pending: Promise<ApiResponse> | undefined;
|
||||
|
||||
return {
|
||||
async schedulers(): Promise<Array<string>> {
|
||||
const path = new URL('/settings/schedulers', root);
|
||||
const res = await f(path);
|
||||
return await res.json() as Array<string>;
|
||||
},
|
||||
async platforms(): Promise<Array<string>> {
|
||||
const path = new URL('/settings/platforms', root);
|
||||
const res = await f(path);
|
||||
return await res.json() as Array<string>;
|
||||
},
|
||||
async txt2img(params: Txt2ImgParams): Promise<ApiResponse> {
|
||||
if (doesExist(pending)) {
|
||||
return pending;
|
||||
|
|
|
@ -1,12 +1,31 @@
|
|||
import { mustExist } from '@apextoaster/js-utils';
|
||||
import { Box, Button, MenuItem, Select, Stack, TextField } from '@mui/material';
|
||||
import * as React from 'react';
|
||||
import { useMutation } from 'react-query';
|
||||
import { useMutation, useQuery } from 'react-query';
|
||||
|
||||
import { ApiClient } from '../api/client.js';
|
||||
import { ImageControl, ImageParams } from './ImageControl.js';
|
||||
|
||||
const { useState } = React;
|
||||
|
||||
const STALE_TIME = 3_000;
|
||||
|
||||
// TODO: set up i18next
|
||||
const PLATFORM_NAMES: Record<string, string> = {
|
||||
amd: 'AMD GPU',
|
||||
cpu: 'CPU',
|
||||
};
|
||||
|
||||
const SCHEDULER_NAMES: Record<string, string> = {
|
||||
'ddim': 'DDIM',
|
||||
'ddpm': 'DDPM',
|
||||
'dpm-multi': 'DPM Multistep',
|
||||
'euler': 'Euler',
|
||||
'euler-a': 'Euler Ancestral',
|
||||
'lms-discrete': 'LMS Discrete',
|
||||
'pndm': 'PNDM',
|
||||
};
|
||||
|
||||
export interface Txt2ImgProps {
|
||||
client: ApiClient;
|
||||
}
|
||||
|
@ -19,6 +38,12 @@ export function Txt2Img(props: Txt2ImgProps) {
|
|||
}
|
||||
|
||||
const generate = useMutation(generateImage);
|
||||
const platforms = useQuery('platforms', async () => client.platforms(), {
|
||||
staleTime: STALE_TIME,
|
||||
});
|
||||
const schedulers = useQuery('schedulers', async () => client.schedulers(), {
|
||||
staleTime: STALE_TIME,
|
||||
});
|
||||
|
||||
const [prompt, setPrompt] = useState('an astronaut eating a hamburger');
|
||||
const [params, setParams] = useState<ImageParams>({
|
||||
|
@ -28,6 +53,7 @@ export function Txt2Img(props: Txt2ImgProps) {
|
|||
height: 512,
|
||||
});
|
||||
const [scheduler, setScheduler] = useState('euler-a');
|
||||
const [platform, setPlatform] = useState('cpu');
|
||||
|
||||
function renderImage() {
|
||||
switch (generate.status) {
|
||||
|
@ -37,6 +63,8 @@ export function Txt2Img(props: Txt2ImgProps) {
|
|||
} else {
|
||||
return <div>Unknown error generating image.</div>;
|
||||
}
|
||||
case 'loading':
|
||||
return <div>Generating...</div>;
|
||||
case 'success':
|
||||
return <img src={generate.data.output} />;
|
||||
default:
|
||||
|
@ -44,23 +72,54 @@ export function Txt2Img(props: Txt2ImgProps) {
|
|||
}
|
||||
}
|
||||
|
||||
function renderSchedulers() {
|
||||
switch (schedulers.status) {
|
||||
case 'error':
|
||||
return <MenuItem value='error'>Error</MenuItem>;
|
||||
case 'loading':
|
||||
return <MenuItem value='loading'>Loading</MenuItem>;
|
||||
case 'success':
|
||||
return mustExist(schedulers.data).map((name) => <MenuItem key={name} value={name}>{SCHEDULER_NAMES[name]}</MenuItem>);
|
||||
default:
|
||||
return <MenuItem value='error'>Unknown Error</MenuItem>;
|
||||
}
|
||||
}
|
||||
|
||||
function renderPlatforms() {
|
||||
switch (platforms.status) {
|
||||
case 'error':
|
||||
return <MenuItem value='error'>Error</MenuItem>;
|
||||
case 'loading':
|
||||
return <MenuItem value='loading'>Loading</MenuItem>;
|
||||
case 'success':
|
||||
return mustExist(platforms.data).map((name) => <MenuItem key={name} value={name}>{PLATFORM_NAMES[name]}</MenuItem>);
|
||||
default:
|
||||
return <MenuItem value='error'>Unknown Error</MenuItem>;
|
||||
}
|
||||
}
|
||||
|
||||
return <Box>
|
||||
<Stack spacing={2}>
|
||||
<Select
|
||||
value={scheduler}
|
||||
label="Scheduler"
|
||||
onChange={(event) => {
|
||||
setScheduler(event.target.value);
|
||||
}}
|
||||
>
|
||||
<MenuItem value='ddim'>DDIM</MenuItem>
|
||||
<MenuItem value='ddpm'>DDPM</MenuItem>
|
||||
<MenuItem value='dpm-multi'>DPM Multistep</MenuItem>
|
||||
<MenuItem value='euler'>Euler</MenuItem>
|
||||
<MenuItem value='euler-a'>Euler Ancestral</MenuItem>
|
||||
<MenuItem value='lms-discrete'>LMS Discrete</MenuItem>
|
||||
<MenuItem value='pndm'>PNDM</MenuItem>
|
||||
</Select>
|
||||
<Stack direction='row' spacing={2}>
|
||||
<Select
|
||||
value={scheduler}
|
||||
label="Scheduler"
|
||||
onChange={(event) => {
|
||||
setScheduler(event.target.value);
|
||||
}}
|
||||
>
|
||||
{renderSchedulers()}
|
||||
</Select>
|
||||
<Select
|
||||
value={platform}
|
||||
label="Platform"
|
||||
onChange={(event) => {
|
||||
setPlatform(event.target.value);
|
||||
}}
|
||||
>
|
||||
{renderPlatforms()}
|
||||
</Select>
|
||||
</Stack>
|
||||
<ImageControl params={params} onChange={(newParams) => {
|
||||
setParams(newParams);
|
||||
}} />
|
||||
|
|
Loading…
Reference in New Issue