1
0
Fork 0

feat: implement negative prompts

This commit is contained in:
Sean Sube 2023-01-08 13:05:02 -06:00
parent 0d4c0a5942
commit f2e2b20f18
5 changed files with 126 additions and 138 deletions

View File

@ -233,6 +233,8 @@ def pipeline_from_request(pipeline):
# image params # image params
prompt = request.args.get('prompt', default_prompt) prompt = request.args.get('prompt', default_prompt)
negative_prompt = request.args.get('negative', None);
cfg = get_and_clamp_int(request.args, 'cfg', default_cfg, max_cfg, 0) cfg = get_and_clamp_int(request.args, 'cfg', default_cfg, max_cfg, 0)
steps = get_and_clamp_int(request.args, 'steps', default_steps, max_steps) steps = get_and_clamp_int(request.args, 'steps', default_steps, max_steps)
height = get_and_clamp_int(request.args, 'height', default_height, max_height) height = get_and_clamp_int(request.args, 'height', default_height, max_height)
@ -246,7 +248,7 @@ def pipeline_from_request(pipeline):
(user, steps, scheduler.__name__, model, provider, width, height, cfg, seed, prompt)) (user, steps, scheduler.__name__, model, provider, width, height, cfg, seed, prompt))
pipe = load_pipeline(pipeline, model, provider, scheduler) pipe = load_pipeline(pipeline, model, provider, scheduler)
return (model, provider, scheduler, prompt, cfg, steps, height, width, seed, pipe) return (model, provider, scheduler, prompt, negative_prompt, cfg, steps, height, width, seed, pipe)
@app.route('/img2img', methods=['POST']) @app.route('/img2img', methods=['POST'])
@ -257,17 +259,18 @@ def img2img():
strength = get_and_clamp_float(request.args, 'strength', 0.5, 1.0) strength = get_and_clamp_float(request.args, 'strength', 0.5, 1.0)
(model, provider, scheduler, prompt, cfg, steps, height, (model, provider, scheduler, prompt, negative_prompt, cfg, steps, height,
width, seed, pipe) = pipeline_from_request(OnnxStableDiffusionImg2ImgPipeline) width, seed, pipe) = pipeline_from_request(OnnxStableDiffusionImg2ImgPipeline)
rng = np.random.RandomState(seed) rng = np.random.RandomState(seed)
image = pipe( image = pipe(
prompt=prompt, prompt,
image=input_image,
num_inference_steps=steps,
guidance_scale=cfg,
strength=strength,
generator=rng, generator=rng,
guidance_scale=cfg,
image=input_image,
negative_prompt=negative_prompt,
num_inference_steps=steps,
strength=strength,
).images[0] ).images[0]
(output_file, output_full) = make_output_path('img2img', (prompt, cfg, steps, height, width, seed)) (output_file, output_full) = make_output_path('img2img', (prompt, cfg, steps, height, width, seed))
@ -286,13 +289,14 @@ def img2img():
'width': default_width, 'width': default_width,
'prompt': prompt, 'prompt': prompt,
'seed': seed, 'seed': seed,
'negativePrompt': negative_prompt,
} }
}) })
@app.route('/txt2img', methods=['POST']) @app.route('/txt2img', methods=['POST'])
def txt2img(): def txt2img():
(model, provider, scheduler, prompt, cfg, steps, height, (model, provider, scheduler, prompt, negative_prompt, cfg, steps, height,
width, seed, pipe) = pipeline_from_request(OnnxStableDiffusionPipeline) width, seed, pipe) = pipeline_from_request(OnnxStableDiffusionPipeline)
latents = get_latents_from_seed(seed, width, height) latents = get_latents_from_seed(seed, width, height)
@ -302,10 +306,11 @@ def txt2img():
prompt, prompt,
height, height,
width, width,
num_inference_steps=steps, generator=rng,
guidance_scale=cfg, guidance_scale=cfg,
latents=latents, latents=latents,
generator=rng, negative_prompt=negative_prompt,
num_inference_steps=steps,
).images[0] ).images[0]
(output_file, output_full) = make_output_path('txt2img', (prompt, cfg, steps, height, width, seed)) (output_file, output_full) = make_output_path('txt2img', (prompt, cfg, steps, height, width, seed))
@ -324,6 +329,7 @@ def txt2img():
'width': width, 'width': width,
'prompt': prompt, 'prompt': prompt,
'seed': seed, 'seed': seed,
'negativePrompt': negative_prompt,
} }
}) })

View File

@ -1,42 +1,41 @@
import { doesExist } from '@apextoaster/js-utils'; import { doesExist } from '@apextoaster/js-utils';
export interface Img2ImgParams { export interface BaseImgParams {
/**
* Which ONNX model to use.
*/
model?: string; model?: string;
/**
* Hardware accelerator or CPU mode.
*/
platform?: string; platform?: string;
/**
* Scheduling algorithm.
*/
scheduler?: string; scheduler?: string;
prompt: string; prompt: string;
negativePrompt?: string;
cfg: number; cfg: number;
steps: number; steps: number;
seed: number;
}
seed?: number; export interface Img2ImgParams extends BaseImgParams {
source: File; source: File;
} }
export interface Txt2ImgParams { export type Img2ImgResponse = Required<Omit<Img2ImgParams, 'file'>>;
model?: string;
platform?: string;
scheduler?: string;
prompt: string;
cfg: number;
steps: number;
export interface Txt2ImgParams extends BaseImgParams {
width?: number; width?: number;
height?: number; height?: number;
seed?: number;
} }
export interface Txt2ImgResponse extends Txt2ImgParams { export type Txt2ImgResponse = Required<Txt2ImgParams>;
model: string;
platform: string;
scheduler: string;
width: number;
height: number;
seed: number;
}
export interface ApiResponse { export interface ApiResponse {
output: string; output: string;
@ -71,6 +70,37 @@ export async function imageFromResponse(root: string, res: Response): Promise<Ap
} }
} }
export function makeImageURL(root: string, type: string, params: BaseImgParams): URL {
const url = new URL(type, root);
url.searchParams.append('cfg', params.cfg.toFixed(0));
url.searchParams.append('steps', params.steps.toFixed(0));
if (doesExist(params.model)) {
url.searchParams.append('model', params.model);
}
if (doesExist(params.platform)) {
url.searchParams.append('platform', params.platform);
}
if (doesExist(params.scheduler)) {
url.searchParams.append('scheduler', params.scheduler);
}
if (doesExist(params.seed)) {
url.searchParams.append('seed', params.seed.toFixed(0));
}
// put prompt last, in case a load balancer decides to truncate the URL
url.searchParams.append('prompt', params.prompt);
if (doesExist(params.negativePrompt)) {
url.searchParams.append('negativePrompt', params.negativePrompt);
}
return url;
}
export function makeClient(root: string, f = fetch): ApiClient { export function makeClient(root: string, f = fetch): ApiClient {
let pending: Promise<ApiResponse> | undefined; let pending: Promise<ApiResponse> | undefined;
@ -95,27 +125,7 @@ export function makeClient(root: string, f = fetch): ApiClient {
return pending; return pending;
} }
const url = new URL('img2img', root); const url = makeImageURL(root, 'img2img', params);
url.searchParams.append('cfg', params.cfg.toFixed(0));
url.searchParams.append('steps', params.steps.toFixed(0));
if (doesExist(params.model)) {
url.searchParams.append('model', params.model);
}
if (doesExist(params.platform)) {
url.searchParams.append('platform', params.platform);
}
if (doesExist(params.scheduler)) {
url.searchParams.append('scheduler', params.scheduler);
}
if (doesExist(params.seed)) {
url.searchParams.append('seed', params.seed.toFixed(0));
}
url.searchParams.append('prompt', params.prompt);
const body = new FormData(); const body = new FormData();
body.append('source', params.source, 'source'); body.append('source', params.source, 'source');
@ -135,9 +145,7 @@ export function makeClient(root: string, f = fetch): ApiClient {
return pending; return pending;
} }
const url = new URL('txt2img', root); const url = makeImageURL(root, 'txt2img', params);
url.searchParams.append('cfg', params.cfg.toFixed(0));
url.searchParams.append('steps', params.steps.toFixed(0));
if (doesExist(params.width)) { if (doesExist(params.width)) {
url.searchParams.append('width', params.width.toFixed(0)); url.searchParams.append('width', params.width.toFixed(0));
@ -147,24 +155,6 @@ export function makeClient(root: string, f = fetch): ApiClient {
url.searchParams.append('height', params.height.toFixed(0)); url.searchParams.append('height', params.height.toFixed(0));
} }
if (doesExist(params.seed)) {
url.searchParams.append('seed', params.seed.toFixed(0));
}
if (doesExist(params.model)) {
url.searchParams.append('model', params.model);
}
if (doesExist(params.platform)) {
url.searchParams.append('platform', params.platform);
}
if (doesExist(params.scheduler)) {
url.searchParams.append('scheduler', params.scheduler);
}
url.searchParams.append('prompt', params.prompt);
pending = f(url, { pending = f(url, {
method: 'POST', method: 'POST',
}).then((res) => imageFromResponse(root, res)).finally(() => { }).then((res) => imageFromResponse(root, res)).finally(() => {

View File

@ -1,21 +1,14 @@
import { doesExist } from '@apextoaster/js-utils'; import { doesExist } from '@apextoaster/js-utils';
import { IconButton, Stack } from '@mui/material'; import { IconButton, Stack, TextField } from '@mui/material';
import { Casino } from '@mui/icons-material'; import { Casino } from '@mui/icons-material';
import * as React from 'react'; import * as React from 'react';
import { NumericField } from './NumericField.js'; import { NumericField } from './NumericField.js';
import { BaseImgParams } from '../api/client.js';
export interface ImageParams {
cfg: number;
seed: number;
steps: number;
width: number;
height: number;
}
export interface ImageControlProps { export interface ImageControlProps {
params: ImageParams; params: BaseImgParams;
onChange?: (params: ImageParams) => void; onChange?: (params: BaseImgParams) => void;
} }
export function ImageControl(props: ImageControlProps) { export function ImageControl(props: ImageControlProps) {
@ -54,38 +47,6 @@ export function ImageControl(props: ImageControlProps) {
}} }}
/> />
</Stack> </Stack>
<Stack direction='row' spacing={4}>
<NumericField
label='Width'
min={8}
max={512}
step={8}
value={params.width}
onChange={(width) => {
if (doesExist(props.onChange)) {
props.onChange({
...params,
width,
});
}
}}
/>
<NumericField
label='Height'
min={8}
max={512}
step={8}
value={params.height}
onChange={(height) => {
if (doesExist(props.onChange)) {
props.onChange({
...params,
height,
});
}
}}
/>
</Stack>
<Stack direction='row' spacing={4}> <Stack direction='row' spacing={4}>
<NumericField <NumericField
label='Seed' label='Seed'
@ -114,5 +75,21 @@ export function ImageControl(props: ImageControlProps) {
<Casino /> <Casino />
</IconButton> </IconButton>
</Stack> </Stack>
<TextField label='Prompt' variant='outlined' value={params.prompt} onChange={(event) => {
if (doesExist(props.onChange)) {
props.onChange({
...params,
prompt: event.target.value,
});
}
}} />
<TextField label='Negative Prompt' variant='outlined' value={params.negativePrompt} onChange={(event) => {
if (doesExist(props.onChange)) {
props.onChange({
...params,
negativePrompt: event.target.value,
});
}
}} />
</Stack>; </Stack>;
} }

View File

@ -1,13 +1,13 @@
import { doesExist, mustExist } from '@apextoaster/js-utils'; import { doesExist, mustExist } from '@apextoaster/js-utils';
import { Box, Button, Stack, TextField } from '@mui/material'; import { Box, Button, Stack } from '@mui/material';
import * as React from 'react'; import * as React from 'react';
import { useMutation, useQuery } from 'react-query'; import { useMutation, useQuery } from 'react-query';
import { ApiClient } from '../api/client.js'; import { ApiClient, BaseImgParams } from '../api/client.js';
import { Config } from '../config.js'; import { Config } from '../config.js';
import { SCHEDULER_LABELS } from '../strings.js'; import { SCHEDULER_LABELS } from '../strings.js';
import { ImageCard } from './ImageCard.js'; import { ImageCard } from './ImageCard.js';
import { ImageControl, ImageParams } from './ImageControl.js'; import { ImageControl } from './ImageControl.js';
import { MutationHistory } from './MutationHistory.js'; import { MutationHistory } from './MutationHistory.js';
import { QueryList } from './QueryList.js'; import { QueryList } from './QueryList.js';
@ -30,7 +30,6 @@ export function Img2Img(props: Img2ImgProps) {
...params, ...params,
model, model,
platform, platform,
prompt,
scheduler, scheduler,
source: mustExist(source), // TODO: show an error if this doesn't exist source: mustExist(source), // TODO: show an error if this doesn't exist
}); });
@ -51,14 +50,12 @@ export function Img2Img(props: Img2ImgProps) {
}); });
const [source, setSource] = useState<File>(); const [source, setSource] = useState<File>();
const [params, setParams] = useState<ImageParams>({ const [params, setParams] = useState<BaseImgParams>({
cfg: 6, cfg: 6,
seed: -1, seed: -1,
steps: 25, steps: 25,
width: 512, prompt: config.default.prompt,
height: 512,
}); });
const [prompt, setPrompt] = useState(config.default.prompt);
const [scheduler, setScheduler] = useState(config.default.scheduler); const [scheduler, setScheduler] = useState(config.default.scheduler);
return <Box> return <Box>
@ -79,9 +76,6 @@ export function Img2Img(props: Img2ImgProps) {
<ImageControl params={params} onChange={(newParams) => { <ImageControl params={params} onChange={(newParams) => {
setParams(newParams); setParams(newParams);
}} /> }} />
<TextField label='Prompt' variant='outlined' value={prompt} onChange={(event) => {
setPrompt(event.target.value);
}} />
<Button onClick={() => upload.mutate()}>Generate</Button> <Button onClick={() => upload.mutate()}>Generate</Button>
<MutationHistory result={upload} limit={4} element={ImageCard} <MutationHistory result={upload} limit={4} element={ImageCard}
isEqual={(a, b) => a.output === b.output} isEqual={(a, b) => a.output === b.output}

View File

@ -1,13 +1,14 @@
import { Box, Button, Stack, TextField } from '@mui/material'; import { Box, Button, Stack } from '@mui/material';
import * as React from 'react'; import * as React from 'react';
import { useMutation, useQuery } from 'react-query'; import { useMutation, useQuery } from 'react-query';
import { ApiClient } from '../api/client.js'; import { ApiClient, BaseImgParams } from '../api/client.js';
import { Config } from '../config.js'; import { Config } from '../config.js';
import { SCHEDULER_LABELS } from '../strings.js'; import { SCHEDULER_LABELS } from '../strings.js';
import { ImageCard } from './ImageCard.js'; import { ImageCard } from './ImageCard.js';
import { ImageControl, ImageParams } from './ImageControl.js'; import { ImageControl } from './ImageControl.js';
import { MutationHistory } from './MutationHistory.js'; import { MutationHistory } from './MutationHistory.js';
import { NumericField } from './NumericField.js';
import { QueryList } from './QueryList.js'; import { QueryList } from './QueryList.js';
const { useState } = React; const { useState } = React;
@ -29,8 +30,9 @@ export function Txt2Img(props: Txt2ImgProps) {
...params, ...params,
model, model,
platform, platform,
prompt,
scheduler, scheduler,
height,
width,
}); });
} }
@ -39,14 +41,14 @@ export function Txt2Img(props: Txt2ImgProps) {
staleTime: STALE_TIME, staleTime: STALE_TIME,
}); });
const [params, setParams] = useState<ImageParams>({ const [height, setHeight] = useState(512);
const [width, setWidth] = useState(512);
const [params, setParams] = useState<BaseImgParams>({
cfg: 6, cfg: 6,
seed: -1, seed: -1,
steps: 25, steps: 25,
width: 512, prompt: config.default.prompt,
height: 512,
}); });
const [prompt, setPrompt] = useState(config.default.prompt);
const [scheduler, setScheduler] = useState(config.default.scheduler); const [scheduler, setScheduler] = useState(config.default.scheduler);
return <Box> return <Box>
@ -66,9 +68,28 @@ export function Txt2Img(props: Txt2ImgProps) {
<ImageControl params={params} onChange={(newParams) => { <ImageControl params={params} onChange={(newParams) => {
setParams(newParams); setParams(newParams);
}} /> }} />
<TextField label='Prompt' variant='outlined' value={prompt} onChange={(event) => { <Stack direction='row' spacing={4}>
setPrompt(event.target.value); <NumericField
}} /> label='Width'
min={8}
max={512}
step={8}
value={width}
onChange={(value) => {
setWidth(value);
}}
/>
<NumericField
label='Height'
min={8}
max={512}
step={8}
value={height}
onChange={(value) => {
setHeight(value);
}}
/>
</Stack>
<Button onClick={() => generate.mutate()}>Generate</Button> <Button onClick={() => generate.mutate()}>Generate</Button>
<MutationHistory result={generate} limit={4} element={ImageCard} <MutationHistory result={generate} limit={4} element={ImageCard}
isEqual={(a, b) => a.output === b.output} isEqual={(a, b) => a.output === b.output}