1
0
Fork 0

feat: implement negative prompts

This commit is contained in:
Sean Sube 2023-01-08 13:05:02 -06:00
parent 0d4c0a5942
commit f2e2b20f18
5 changed files with 126 additions and 138 deletions

View File

@ -233,6 +233,8 @@ def pipeline_from_request(pipeline):
# image params
prompt = request.args.get('prompt', default_prompt)
negative_prompt = request.args.get('negative', None);
cfg = get_and_clamp_int(request.args, 'cfg', default_cfg, max_cfg, 0)
steps = get_and_clamp_int(request.args, 'steps', default_steps, max_steps)
height = get_and_clamp_int(request.args, 'height', default_height, max_height)
@ -246,7 +248,7 @@ def pipeline_from_request(pipeline):
(user, steps, scheduler.__name__, model, provider, width, height, cfg, seed, prompt))
pipe = load_pipeline(pipeline, model, provider, scheduler)
return (model, provider, scheduler, prompt, cfg, steps, height, width, seed, pipe)
return (model, provider, scheduler, prompt, negative_prompt, cfg, steps, height, width, seed, pipe)
@app.route('/img2img', methods=['POST'])
@ -257,17 +259,18 @@ def img2img():
strength = get_and_clamp_float(request.args, 'strength', 0.5, 1.0)
(model, provider, scheduler, prompt, cfg, steps, height,
(model, provider, scheduler, prompt, negative_prompt, cfg, steps, height,
width, seed, pipe) = pipeline_from_request(OnnxStableDiffusionImg2ImgPipeline)
rng = np.random.RandomState(seed)
image = pipe(
prompt=prompt,
image=input_image,
num_inference_steps=steps,
guidance_scale=cfg,
strength=strength,
prompt,
generator=rng,
guidance_scale=cfg,
image=input_image,
negative_prompt=negative_prompt,
num_inference_steps=steps,
strength=strength,
).images[0]
(output_file, output_full) = make_output_path('img2img', (prompt, cfg, steps, height, width, seed))
@ -286,13 +289,14 @@ def img2img():
'width': default_width,
'prompt': prompt,
'seed': seed,
'negativePrompt': negative_prompt,
}
})
@app.route('/txt2img', methods=['POST'])
def txt2img():
(model, provider, scheduler, prompt, cfg, steps, height,
(model, provider, scheduler, prompt, negative_prompt, cfg, steps, height,
width, seed, pipe) = pipeline_from_request(OnnxStableDiffusionPipeline)
latents = get_latents_from_seed(seed, width, height)
@ -302,10 +306,11 @@ def txt2img():
prompt,
height,
width,
num_inference_steps=steps,
generator=rng,
guidance_scale=cfg,
latents=latents,
generator=rng,
negative_prompt=negative_prompt,
num_inference_steps=steps,
).images[0]
(output_file, output_full) = make_output_path('txt2img', (prompt, cfg, steps, height, width, seed))
@ -324,6 +329,7 @@ def txt2img():
'width': width,
'prompt': prompt,
'seed': seed,
'negativePrompt': negative_prompt,
}
})

View File

@ -1,42 +1,41 @@
import { doesExist } from '@apextoaster/js-utils';
export interface Img2ImgParams {
export interface BaseImgParams {
/**
* Which ONNX model to use.
*/
model?: string;
/**
* Hardware accelerator or CPU mode.
*/
platform?: string;
/**
* Scheduling algorithm.
*/
scheduler?: string;
prompt: string;
negativePrompt?: string;
cfg: number;
steps: number;
seed: number;
}
seed?: number;
export interface Img2ImgParams extends BaseImgParams {
source: File;
}
export interface Txt2ImgParams {
model?: string;
platform?: string;
scheduler?: string;
prompt: string;
cfg: number;
steps: number;
export type Img2ImgResponse = Required<Omit<Img2ImgParams, 'file'>>;
export interface Txt2ImgParams extends BaseImgParams {
width?: number;
height?: number;
seed?: number;
}
export interface Txt2ImgResponse extends Txt2ImgParams {
model: string;
platform: string;
scheduler: string;
width: number;
height: number;
seed: number;
}
export type Txt2ImgResponse = Required<Txt2ImgParams>;
export interface ApiResponse {
output: string;
@ -71,6 +70,37 @@ export async function imageFromResponse(root: string, res: Response): Promise<Ap
}
}
export function makeImageURL(root: string, type: string, params: BaseImgParams): URL {
const url = new URL(type, root);
url.searchParams.append('cfg', params.cfg.toFixed(0));
url.searchParams.append('steps', params.steps.toFixed(0));
if (doesExist(params.model)) {
url.searchParams.append('model', params.model);
}
if (doesExist(params.platform)) {
url.searchParams.append('platform', params.platform);
}
if (doesExist(params.scheduler)) {
url.searchParams.append('scheduler', params.scheduler);
}
if (doesExist(params.seed)) {
url.searchParams.append('seed', params.seed.toFixed(0));
}
// put prompt last, in case a load balancer decides to truncate the URL
url.searchParams.append('prompt', params.prompt);
if (doesExist(params.negativePrompt)) {
url.searchParams.append('negativePrompt', params.negativePrompt);
}
return url;
}
export function makeClient(root: string, f = fetch): ApiClient {
let pending: Promise<ApiResponse> | undefined;
@ -95,27 +125,7 @@ export function makeClient(root: string, f = fetch): ApiClient {
return pending;
}
const url = new URL('img2img', root);
url.searchParams.append('cfg', params.cfg.toFixed(0));
url.searchParams.append('steps', params.steps.toFixed(0));
if (doesExist(params.model)) {
url.searchParams.append('model', params.model);
}
if (doesExist(params.platform)) {
url.searchParams.append('platform', params.platform);
}
if (doesExist(params.scheduler)) {
url.searchParams.append('scheduler', params.scheduler);
}
if (doesExist(params.seed)) {
url.searchParams.append('seed', params.seed.toFixed(0));
}
url.searchParams.append('prompt', params.prompt);
const url = makeImageURL(root, 'img2img', params);
const body = new FormData();
body.append('source', params.source, 'source');
@ -135,9 +145,7 @@ export function makeClient(root: string, f = fetch): ApiClient {
return pending;
}
const url = new URL('txt2img', root);
url.searchParams.append('cfg', params.cfg.toFixed(0));
url.searchParams.append('steps', params.steps.toFixed(0));
const url = makeImageURL(root, 'txt2img', params);
if (doesExist(params.width)) {
url.searchParams.append('width', params.width.toFixed(0));
@ -147,24 +155,6 @@ export function makeClient(root: string, f = fetch): ApiClient {
url.searchParams.append('height', params.height.toFixed(0));
}
if (doesExist(params.seed)) {
url.searchParams.append('seed', params.seed.toFixed(0));
}
if (doesExist(params.model)) {
url.searchParams.append('model', params.model);
}
if (doesExist(params.platform)) {
url.searchParams.append('platform', params.platform);
}
if (doesExist(params.scheduler)) {
url.searchParams.append('scheduler', params.scheduler);
}
url.searchParams.append('prompt', params.prompt);
pending = f(url, {
method: 'POST',
}).then((res) => imageFromResponse(root, res)).finally(() => {

View File

@ -1,21 +1,14 @@
import { doesExist } from '@apextoaster/js-utils';
import { IconButton, Stack } from '@mui/material';
import { IconButton, Stack, TextField } from '@mui/material';
import { Casino } from '@mui/icons-material';
import * as React from 'react';
import { NumericField } from './NumericField.js';
export interface ImageParams {
cfg: number;
seed: number;
steps: number;
width: number;
height: number;
}
import { BaseImgParams } from '../api/client.js';
export interface ImageControlProps {
params: ImageParams;
onChange?: (params: ImageParams) => void;
params: BaseImgParams;
onChange?: (params: BaseImgParams) => void;
}
export function ImageControl(props: ImageControlProps) {
@ -54,38 +47,6 @@ export function ImageControl(props: ImageControlProps) {
}}
/>
</Stack>
<Stack direction='row' spacing={4}>
<NumericField
label='Width'
min={8}
max={512}
step={8}
value={params.width}
onChange={(width) => {
if (doesExist(props.onChange)) {
props.onChange({
...params,
width,
});
}
}}
/>
<NumericField
label='Height'
min={8}
max={512}
step={8}
value={params.height}
onChange={(height) => {
if (doesExist(props.onChange)) {
props.onChange({
...params,
height,
});
}
}}
/>
</Stack>
<Stack direction='row' spacing={4}>
<NumericField
label='Seed'
@ -114,5 +75,21 @@ export function ImageControl(props: ImageControlProps) {
<Casino />
</IconButton>
</Stack>
<TextField label='Prompt' variant='outlined' value={params.prompt} onChange={(event) => {
if (doesExist(props.onChange)) {
props.onChange({
...params,
prompt: event.target.value,
});
}
}} />
<TextField label='Negative Prompt' variant='outlined' value={params.negativePrompt} onChange={(event) => {
if (doesExist(props.onChange)) {
props.onChange({
...params,
negativePrompt: event.target.value,
});
}
}} />
</Stack>;
}

View File

@ -1,13 +1,13 @@
import { doesExist, mustExist } from '@apextoaster/js-utils';
import { Box, Button, Stack, TextField } from '@mui/material';
import { Box, Button, Stack } from '@mui/material';
import * as React from 'react';
import { useMutation, useQuery } from 'react-query';
import { ApiClient } from '../api/client.js';
import { ApiClient, BaseImgParams } from '../api/client.js';
import { Config } from '../config.js';
import { SCHEDULER_LABELS } from '../strings.js';
import { ImageCard } from './ImageCard.js';
import { ImageControl, ImageParams } from './ImageControl.js';
import { ImageControl } from './ImageControl.js';
import { MutationHistory } from './MutationHistory.js';
import { QueryList } from './QueryList.js';
@ -30,7 +30,6 @@ export function Img2Img(props: Img2ImgProps) {
...params,
model,
platform,
prompt,
scheduler,
source: mustExist(source), // TODO: show an error if this doesn't exist
});
@ -51,14 +50,12 @@ export function Img2Img(props: Img2ImgProps) {
});
const [source, setSource] = useState<File>();
const [params, setParams] = useState<ImageParams>({
const [params, setParams] = useState<BaseImgParams>({
cfg: 6,
seed: -1,
steps: 25,
width: 512,
height: 512,
prompt: config.default.prompt,
});
const [prompt, setPrompt] = useState(config.default.prompt);
const [scheduler, setScheduler] = useState(config.default.scheduler);
return <Box>
@ -79,9 +76,6 @@ export function Img2Img(props: Img2ImgProps) {
<ImageControl params={params} onChange={(newParams) => {
setParams(newParams);
}} />
<TextField label='Prompt' variant='outlined' value={prompt} onChange={(event) => {
setPrompt(event.target.value);
}} />
<Button onClick={() => upload.mutate()}>Generate</Button>
<MutationHistory result={upload} limit={4} element={ImageCard}
isEqual={(a, b) => a.output === b.output}

View File

@ -1,13 +1,14 @@
import { Box, Button, Stack, TextField } from '@mui/material';
import { Box, Button, Stack } from '@mui/material';
import * as React from 'react';
import { useMutation, useQuery } from 'react-query';
import { ApiClient } from '../api/client.js';
import { ApiClient, BaseImgParams } from '../api/client.js';
import { Config } from '../config.js';
import { SCHEDULER_LABELS } from '../strings.js';
import { ImageCard } from './ImageCard.js';
import { ImageControl, ImageParams } from './ImageControl.js';
import { ImageControl } from './ImageControl.js';
import { MutationHistory } from './MutationHistory.js';
import { NumericField } from './NumericField.js';
import { QueryList } from './QueryList.js';
const { useState } = React;
@ -29,8 +30,9 @@ export function Txt2Img(props: Txt2ImgProps) {
...params,
model,
platform,
prompt,
scheduler,
height,
width,
});
}
@ -39,14 +41,14 @@ export function Txt2Img(props: Txt2ImgProps) {
staleTime: STALE_TIME,
});
const [params, setParams] = useState<ImageParams>({
const [height, setHeight] = useState(512);
const [width, setWidth] = useState(512);
const [params, setParams] = useState<BaseImgParams>({
cfg: 6,
seed: -1,
steps: 25,
width: 512,
height: 512,
prompt: config.default.prompt,
});
const [prompt, setPrompt] = useState(config.default.prompt);
const [scheduler, setScheduler] = useState(config.default.scheduler);
return <Box>
@ -66,9 +68,28 @@ export function Txt2Img(props: Txt2ImgProps) {
<ImageControl params={params} onChange={(newParams) => {
setParams(newParams);
}} />
<TextField label='Prompt' variant='outlined' value={prompt} onChange={(event) => {
setPrompt(event.target.value);
}} />
<Stack direction='row' spacing={4}>
<NumericField
label='Width'
min={8}
max={512}
step={8}
value={width}
onChange={(value) => {
setWidth(value);
}}
/>
<NumericField
label='Height'
min={8}
max={512}
step={8}
value={height}
onChange={(value) => {
setHeight(value);
}}
/>
</Stack>
<Button onClick={() => generate.mutate()}>Generate</Button>
<MutationHistory result={generate} limit={4} element={ImageCard}
isEqual={(a, b) => a.output === b.output}