1
0
Fork 0

feat(gui): get platforms and schedulers from server

This commit is contained in:
Sean Sube 2023-01-05 21:55:05 -06:00
parent c70728d501
commit ce06837fbc
2 changed files with 87 additions and 16 deletions

View File

@ -16,6 +16,8 @@ export interface ApiResponse {
}
export interface ApiClient {
platforms(): Promise<Array<string>>;
schedulers(): Promise<Array<string>>;
txt2img(params: Txt2ImgParams): Promise<ApiResponse>;
}
@ -38,6 +40,16 @@ export function makeClient(root: string, f = fetch): ApiClient {
let pending: Promise<ApiResponse> | undefined;
return {
async schedulers(): Promise<Array<string>> {
const path = new URL('/settings/schedulers', root);
const res = await f(path);
return await res.json() as Array<string>;
},
async platforms(): Promise<Array<string>> {
const path = new URL('/settings/platforms', root);
const res = await f(path);
return await res.json() as Array<string>;
},
async txt2img(params: Txt2ImgParams): Promise<ApiResponse> {
if (doesExist(pending)) {
return pending;

View File

@ -1,12 +1,31 @@
import { mustExist } from '@apextoaster/js-utils';
import { Box, Button, MenuItem, Select, Stack, TextField } from '@mui/material';
import * as React from 'react';
import { useMutation } from 'react-query';
import { useMutation, useQuery } from 'react-query';
import { ApiClient } from '../api/client.js';
import { ImageControl, ImageParams } from './ImageControl.js';
const { useState } = React;
const STALE_TIME = 3_000;
// TODO: set up i18next
const PLATFORM_NAMES: Record<string, string> = {
amd: 'AMD GPU',
cpu: 'CPU',
};
const SCHEDULER_NAMES: Record<string, string> = {
'ddim': 'DDIM',
'ddpm': 'DDPM',
'dpm-multi': 'DPM Multistep',
'euler': 'Euler',
'euler-a': 'Euler Ancestral',
'lms-discrete': 'LMS Discrete',
'pndm': 'PNDM',
};
export interface Txt2ImgProps {
client: ApiClient;
}
@ -19,6 +38,12 @@ export function Txt2Img(props: Txt2ImgProps) {
}
const generate = useMutation(generateImage);
const platforms = useQuery('platforms', async () => client.platforms(), {
staleTime: STALE_TIME,
});
const schedulers = useQuery('schedulers', async () => client.schedulers(), {
staleTime: STALE_TIME,
});
const [prompt, setPrompt] = useState('an astronaut eating a hamburger');
const [params, setParams] = useState<ImageParams>({
@ -28,6 +53,7 @@ export function Txt2Img(props: Txt2ImgProps) {
height: 512,
});
const [scheduler, setScheduler] = useState('euler-a');
const [platform, setPlatform] = useState('cpu');
function renderImage() {
switch (generate.status) {
@ -37,6 +63,8 @@ export function Txt2Img(props: Txt2ImgProps) {
} else {
return <div>Unknown error generating image.</div>;
}
case 'loading':
return <div>Generating...</div>;
case 'success':
return <img src={generate.data.output} />;
default:
@ -44,23 +72,54 @@ export function Txt2Img(props: Txt2ImgProps) {
}
}
function renderSchedulers() {
switch (schedulers.status) {
case 'error':
return <MenuItem value='error'>Error</MenuItem>;
case 'loading':
return <MenuItem value='loading'>Loading</MenuItem>;
case 'success':
return mustExist(schedulers.data).map((name) => <MenuItem key={name} value={name}>{SCHEDULER_NAMES[name]}</MenuItem>);
default:
return <MenuItem value='error'>Unknown Error</MenuItem>;
}
}
function renderPlatforms() {
switch (platforms.status) {
case 'error':
return <MenuItem value='error'>Error</MenuItem>;
case 'loading':
return <MenuItem value='loading'>Loading</MenuItem>;
case 'success':
return mustExist(platforms.data).map((name) => <MenuItem key={name} value={name}>{PLATFORM_NAMES[name]}</MenuItem>);
default:
return <MenuItem value='error'>Unknown Error</MenuItem>;
}
}
return <Box>
<Stack spacing={2}>
<Select
value={scheduler}
label="Scheduler"
onChange={(event) => {
setScheduler(event.target.value);
}}
>
<MenuItem value='ddim'>DDIM</MenuItem>
<MenuItem value='ddpm'>DDPM</MenuItem>
<MenuItem value='dpm-multi'>DPM Multistep</MenuItem>
<MenuItem value='euler'>Euler</MenuItem>
<MenuItem value='euler-a'>Euler Ancestral</MenuItem>
<MenuItem value='lms-discrete'>LMS Discrete</MenuItem>
<MenuItem value='pndm'>PNDM</MenuItem>
</Select>
<Stack direction='row' spacing={2}>
<Select
value={scheduler}
label="Scheduler"
onChange={(event) => {
setScheduler(event.target.value);
}}
>
{renderSchedulers()}
</Select>
<Select
value={platform}
label="Platform"
onChange={(event) => {
setPlatform(event.target.value);
}}
>
{renderPlatforms()}
</Select>
</Stack>
<ImageControl params={params} onChange={(newParams) => {
setParams(newParams);
}} />