diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py index d85aacdf..0c3adb5f 100644 --- a/api/onnx_web/serve.py +++ b/api/onnx_web/serve.py @@ -233,6 +233,8 @@ def pipeline_from_request(pipeline): # image params prompt = request.args.get('prompt', default_prompt) + negative_prompt = request.args.get('negative', None); + cfg = get_and_clamp_int(request.args, 'cfg', default_cfg, max_cfg, 0) steps = get_and_clamp_int(request.args, 'steps', default_steps, max_steps) height = get_and_clamp_int(request.args, 'height', default_height, max_height) @@ -246,7 +248,7 @@ def pipeline_from_request(pipeline): (user, steps, scheduler.__name__, model, provider, width, height, cfg, seed, prompt)) pipe = load_pipeline(pipeline, model, provider, scheduler) - return (model, provider, scheduler, prompt, cfg, steps, height, width, seed, pipe) + return (model, provider, scheduler, prompt, negative_prompt, cfg, steps, height, width, seed, pipe) @app.route('/img2img', methods=['POST']) @@ -257,17 +259,18 @@ def img2img(): strength = get_and_clamp_float(request.args, 'strength', 0.5, 1.0) - (model, provider, scheduler, prompt, cfg, steps, height, + (model, provider, scheduler, prompt, negative_prompt, cfg, steps, height, width, seed, pipe) = pipeline_from_request(OnnxStableDiffusionImg2ImgPipeline) rng = np.random.RandomState(seed) image = pipe( - prompt=prompt, - image=input_image, - num_inference_steps=steps, - guidance_scale=cfg, - strength=strength, + prompt, generator=rng, + guidance_scale=cfg, + image=input_image, + negative_prompt=negative_prompt, + num_inference_steps=steps, + strength=strength, ).images[0] (output_file, output_full) = make_output_path('img2img', (prompt, cfg, steps, height, width, seed)) @@ -286,13 +289,14 @@ def img2img(): 'width': default_width, 'prompt': prompt, 'seed': seed, + 'negativePrompt': negative_prompt, } }) @app.route('/txt2img', methods=['POST']) def txt2img(): - (model, provider, scheduler, prompt, cfg, steps, height, + (model, provider, scheduler, prompt, negative_prompt, cfg, steps, height, width, seed, pipe) = pipeline_from_request(OnnxStableDiffusionPipeline) latents = get_latents_from_seed(seed, width, height) @@ -302,10 +306,11 @@ def txt2img(): prompt, height, width, - num_inference_steps=steps, + generator=rng, guidance_scale=cfg, latents=latents, - generator=rng, + negative_prompt=negative_prompt, + num_inference_steps=steps, ).images[0] (output_file, output_full) = make_output_path('txt2img', (prompt, cfg, steps, height, width, seed)) @@ -324,6 +329,7 @@ def txt2img(): 'width': width, 'prompt': prompt, 'seed': seed, + 'negativePrompt': negative_prompt, } }) diff --git a/gui/src/api/client.ts b/gui/src/api/client.ts index 7577dc69..76524a44 100644 --- a/gui/src/api/client.ts +++ b/gui/src/api/client.ts @@ -1,42 +1,41 @@ import { doesExist } from '@apextoaster/js-utils'; -export interface Img2ImgParams { +export interface BaseImgParams { + /** + * Which ONNX model to use. + */ model?: string; + + /** + * Hardware accelerator or CPU mode. + */ platform?: string; + + /** + * Scheduling algorithm. + */ scheduler?: string; prompt: string; + negativePrompt?: string; + cfg: number; steps: number; + seed: number; +} - seed?: number; - +export interface Img2ImgParams extends BaseImgParams { source: File; } -export interface Txt2ImgParams { - model?: string; - platform?: string; - scheduler?: string; - - prompt: string; - cfg: number; - steps: number; +export type Img2ImgResponse = Required>; +export interface Txt2ImgParams extends BaseImgParams { width?: number; height?: number; - seed?: number; } -export interface Txt2ImgResponse extends Txt2ImgParams { - model: string; - platform: string; - scheduler: string; - - width: number; - height: number; - seed: number; -} +export type Txt2ImgResponse = Required; export interface ApiResponse { output: string; @@ -71,6 +70,37 @@ export async function imageFromResponse(root: string, res: Response): Promise | undefined; @@ -95,27 +125,7 @@ export function makeClient(root: string, f = fetch): ApiClient { return pending; } - const url = new URL('img2img', root); - url.searchParams.append('cfg', params.cfg.toFixed(0)); - url.searchParams.append('steps', params.steps.toFixed(0)); - - if (doesExist(params.model)) { - url.searchParams.append('model', params.model); - } - - if (doesExist(params.platform)) { - url.searchParams.append('platform', params.platform); - } - - if (doesExist(params.scheduler)) { - url.searchParams.append('scheduler', params.scheduler); - } - - if (doesExist(params.seed)) { - url.searchParams.append('seed', params.seed.toFixed(0)); - } - - url.searchParams.append('prompt', params.prompt); + const url = makeImageURL(root, 'img2img', params); const body = new FormData(); body.append('source', params.source, 'source'); @@ -135,9 +145,7 @@ export function makeClient(root: string, f = fetch): ApiClient { return pending; } - const url = new URL('txt2img', root); - url.searchParams.append('cfg', params.cfg.toFixed(0)); - url.searchParams.append('steps', params.steps.toFixed(0)); + const url = makeImageURL(root, 'txt2img', params); if (doesExist(params.width)) { url.searchParams.append('width', params.width.toFixed(0)); @@ -147,24 +155,6 @@ export function makeClient(root: string, f = fetch): ApiClient { url.searchParams.append('height', params.height.toFixed(0)); } - if (doesExist(params.seed)) { - url.searchParams.append('seed', params.seed.toFixed(0)); - } - - if (doesExist(params.model)) { - url.searchParams.append('model', params.model); - } - - if (doesExist(params.platform)) { - url.searchParams.append('platform', params.platform); - } - - if (doesExist(params.scheduler)) { - url.searchParams.append('scheduler', params.scheduler); - } - - url.searchParams.append('prompt', params.prompt); - pending = f(url, { method: 'POST', }).then((res) => imageFromResponse(root, res)).finally(() => { diff --git a/gui/src/components/ImageControl.tsx b/gui/src/components/ImageControl.tsx index a9169c95..8a8ac64f 100644 --- a/gui/src/components/ImageControl.tsx +++ b/gui/src/components/ImageControl.tsx @@ -1,21 +1,14 @@ import { doesExist } from '@apextoaster/js-utils'; -import { IconButton, Stack } from '@mui/material'; +import { IconButton, Stack, TextField } from '@mui/material'; import { Casino } from '@mui/icons-material'; import * as React from 'react'; import { NumericField } from './NumericField.js'; - -export interface ImageParams { - cfg: number; - seed: number; - steps: number; - width: number; - height: number; -} +import { BaseImgParams } from '../api/client.js'; export interface ImageControlProps { - params: ImageParams; - onChange?: (params: ImageParams) => void; + params: BaseImgParams; + onChange?: (params: BaseImgParams) => void; } export function ImageControl(props: ImageControlProps) { @@ -54,38 +47,6 @@ export function ImageControl(props: ImageControlProps) { }} /> - - { - if (doesExist(props.onChange)) { - props.onChange({ - ...params, - width, - }); - } - }} - /> - { - if (doesExist(props.onChange)) { - props.onChange({ - ...params, - height, - }); - } - }} - /> - + { + if (doesExist(props.onChange)) { + props.onChange({ + ...params, + prompt: event.target.value, + }); + } + }} /> + { + if (doesExist(props.onChange)) { + props.onChange({ + ...params, + negativePrompt: event.target.value, + }); + } + }} /> ; } diff --git a/gui/src/components/Img2Img.tsx b/gui/src/components/Img2Img.tsx index 6654c48e..61e8297a 100644 --- a/gui/src/components/Img2Img.tsx +++ b/gui/src/components/Img2Img.tsx @@ -1,13 +1,13 @@ import { doesExist, mustExist } from '@apextoaster/js-utils'; -import { Box, Button, Stack, TextField } from '@mui/material'; +import { Box, Button, Stack } from '@mui/material'; import * as React from 'react'; import { useMutation, useQuery } from 'react-query'; -import { ApiClient } from '../api/client.js'; +import { ApiClient, BaseImgParams } from '../api/client.js'; import { Config } from '../config.js'; import { SCHEDULER_LABELS } from '../strings.js'; import { ImageCard } from './ImageCard.js'; -import { ImageControl, ImageParams } from './ImageControl.js'; +import { ImageControl } from './ImageControl.js'; import { MutationHistory } from './MutationHistory.js'; import { QueryList } from './QueryList.js'; @@ -30,7 +30,6 @@ export function Img2Img(props: Img2ImgProps) { ...params, model, platform, - prompt, scheduler, source: mustExist(source), // TODO: show an error if this doesn't exist }); @@ -51,14 +50,12 @@ export function Img2Img(props: Img2ImgProps) { }); const [source, setSource] = useState(); - const [params, setParams] = useState({ + const [params, setParams] = useState({ cfg: 6, seed: -1, steps: 25, - width: 512, - height: 512, + prompt: config.default.prompt, }); - const [prompt, setPrompt] = useState(config.default.prompt); const [scheduler, setScheduler] = useState(config.default.scheduler); return @@ -79,9 +76,6 @@ export function Img2Img(props: Img2ImgProps) { { setParams(newParams); }} /> - { - setPrompt(event.target.value); - }} /> a.output === b.output} diff --git a/gui/src/components/Txt2Img.tsx b/gui/src/components/Txt2Img.tsx index 8d1dd2e4..c68b40f7 100644 --- a/gui/src/components/Txt2Img.tsx +++ b/gui/src/components/Txt2Img.tsx @@ -1,13 +1,14 @@ -import { Box, Button, Stack, TextField } from '@mui/material'; +import { Box, Button, Stack } from '@mui/material'; import * as React from 'react'; import { useMutation, useQuery } from 'react-query'; -import { ApiClient } from '../api/client.js'; +import { ApiClient, BaseImgParams } from '../api/client.js'; import { Config } from '../config.js'; import { SCHEDULER_LABELS } from '../strings.js'; import { ImageCard } from './ImageCard.js'; -import { ImageControl, ImageParams } from './ImageControl.js'; +import { ImageControl } from './ImageControl.js'; import { MutationHistory } from './MutationHistory.js'; +import { NumericField } from './NumericField.js'; import { QueryList } from './QueryList.js'; const { useState } = React; @@ -29,8 +30,9 @@ export function Txt2Img(props: Txt2ImgProps) { ...params, model, platform, - prompt, scheduler, + height, + width, }); } @@ -39,14 +41,14 @@ export function Txt2Img(props: Txt2ImgProps) { staleTime: STALE_TIME, }); - const [params, setParams] = useState({ + const [height, setHeight] = useState(512); + const [width, setWidth] = useState(512); + const [params, setParams] = useState({ cfg: 6, seed: -1, steps: 25, - width: 512, - height: 512, + prompt: config.default.prompt, }); - const [prompt, setPrompt] = useState(config.default.prompt); const [scheduler, setScheduler] = useState(config.default.scheduler); return @@ -66,9 +68,28 @@ export function Txt2Img(props: Txt2ImgProps) { { setParams(newParams); }} /> - { - setPrompt(event.target.value); - }} /> + + { + setWidth(value); + }} + /> + { + setHeight(value); + }} + /> + a.output === b.output}