feat: implement negative prompts
This commit is contained in:
parent
0d4c0a5942
commit
f2e2b20f18
|
@ -233,6 +233,8 @@ def pipeline_from_request(pipeline):
|
|||
|
||||
# image params
|
||||
prompt = request.args.get('prompt', default_prompt)
|
||||
negative_prompt = request.args.get('negative', None);
|
||||
|
||||
cfg = get_and_clamp_int(request.args, 'cfg', default_cfg, max_cfg, 0)
|
||||
steps = get_and_clamp_int(request.args, 'steps', default_steps, max_steps)
|
||||
height = get_and_clamp_int(request.args, 'height', default_height, max_height)
|
||||
|
@ -246,7 +248,7 @@ def pipeline_from_request(pipeline):
|
|||
(user, steps, scheduler.__name__, model, provider, width, height, cfg, seed, prompt))
|
||||
|
||||
pipe = load_pipeline(pipeline, model, provider, scheduler)
|
||||
return (model, provider, scheduler, prompt, cfg, steps, height, width, seed, pipe)
|
||||
return (model, provider, scheduler, prompt, negative_prompt, cfg, steps, height, width, seed, pipe)
|
||||
|
||||
|
||||
@app.route('/img2img', methods=['POST'])
|
||||
|
@ -257,17 +259,18 @@ def img2img():
|
|||
|
||||
strength = get_and_clamp_float(request.args, 'strength', 0.5, 1.0)
|
||||
|
||||
(model, provider, scheduler, prompt, cfg, steps, height,
|
||||
(model, provider, scheduler, prompt, negative_prompt, cfg, steps, height,
|
||||
width, seed, pipe) = pipeline_from_request(OnnxStableDiffusionImg2ImgPipeline)
|
||||
|
||||
rng = np.random.RandomState(seed)
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
image=input_image,
|
||||
num_inference_steps=steps,
|
||||
guidance_scale=cfg,
|
||||
strength=strength,
|
||||
prompt,
|
||||
generator=rng,
|
||||
guidance_scale=cfg,
|
||||
image=input_image,
|
||||
negative_prompt=negative_prompt,
|
||||
num_inference_steps=steps,
|
||||
strength=strength,
|
||||
).images[0]
|
||||
|
||||
(output_file, output_full) = make_output_path('img2img', (prompt, cfg, steps, height, width, seed))
|
||||
|
@ -286,13 +289,14 @@ def img2img():
|
|||
'width': default_width,
|
||||
'prompt': prompt,
|
||||
'seed': seed,
|
||||
'negativePrompt': negative_prompt,
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
@app.route('/txt2img', methods=['POST'])
|
||||
def txt2img():
|
||||
(model, provider, scheduler, prompt, cfg, steps, height,
|
||||
(model, provider, scheduler, prompt, negative_prompt, cfg, steps, height,
|
||||
width, seed, pipe) = pipeline_from_request(OnnxStableDiffusionPipeline)
|
||||
|
||||
latents = get_latents_from_seed(seed, width, height)
|
||||
|
@ -302,10 +306,11 @@ def txt2img():
|
|||
prompt,
|
||||
height,
|
||||
width,
|
||||
num_inference_steps=steps,
|
||||
generator=rng,
|
||||
guidance_scale=cfg,
|
||||
latents=latents,
|
||||
generator=rng,
|
||||
negative_prompt=negative_prompt,
|
||||
num_inference_steps=steps,
|
||||
).images[0]
|
||||
|
||||
(output_file, output_full) = make_output_path('txt2img', (prompt, cfg, steps, height, width, seed))
|
||||
|
@ -324,6 +329,7 @@ def txt2img():
|
|||
'width': width,
|
||||
'prompt': prompt,
|
||||
'seed': seed,
|
||||
'negativePrompt': negative_prompt,
|
||||
}
|
||||
})
|
||||
|
||||
|
|
|
@ -1,42 +1,41 @@
|
|||
import { doesExist } from '@apextoaster/js-utils';
|
||||
|
||||
export interface Img2ImgParams {
|
||||
export interface BaseImgParams {
|
||||
/**
|
||||
* Which ONNX model to use.
|
||||
*/
|
||||
model?: string;
|
||||
|
||||
/**
|
||||
* Hardware accelerator or CPU mode.
|
||||
*/
|
||||
platform?: string;
|
||||
|
||||
/**
|
||||
* Scheduling algorithm.
|
||||
*/
|
||||
scheduler?: string;
|
||||
|
||||
prompt: string;
|
||||
negativePrompt?: string;
|
||||
|
||||
cfg: number;
|
||||
steps: number;
|
||||
seed: number;
|
||||
}
|
||||
|
||||
seed?: number;
|
||||
|
||||
export interface Img2ImgParams extends BaseImgParams {
|
||||
source: File;
|
||||
}
|
||||
|
||||
export interface Txt2ImgParams {
|
||||
model?: string;
|
||||
platform?: string;
|
||||
scheduler?: string;
|
||||
|
||||
prompt: string;
|
||||
cfg: number;
|
||||
steps: number;
|
||||
export type Img2ImgResponse = Required<Omit<Img2ImgParams, 'file'>>;
|
||||
|
||||
export interface Txt2ImgParams extends BaseImgParams {
|
||||
width?: number;
|
||||
height?: number;
|
||||
seed?: number;
|
||||
}
|
||||
|
||||
export interface Txt2ImgResponse extends Txt2ImgParams {
|
||||
model: string;
|
||||
platform: string;
|
||||
scheduler: string;
|
||||
|
||||
width: number;
|
||||
height: number;
|
||||
seed: number;
|
||||
}
|
||||
export type Txt2ImgResponse = Required<Txt2ImgParams>;
|
||||
|
||||
export interface ApiResponse {
|
||||
output: string;
|
||||
|
@ -71,6 +70,37 @@ export async function imageFromResponse(root: string, res: Response): Promise<Ap
|
|||
}
|
||||
}
|
||||
|
||||
export function makeImageURL(root: string, type: string, params: BaseImgParams): URL {
|
||||
const url = new URL(type, root);
|
||||
url.searchParams.append('cfg', params.cfg.toFixed(0));
|
||||
url.searchParams.append('steps', params.steps.toFixed(0));
|
||||
|
||||
if (doesExist(params.model)) {
|
||||
url.searchParams.append('model', params.model);
|
||||
}
|
||||
|
||||
if (doesExist(params.platform)) {
|
||||
url.searchParams.append('platform', params.platform);
|
||||
}
|
||||
|
||||
if (doesExist(params.scheduler)) {
|
||||
url.searchParams.append('scheduler', params.scheduler);
|
||||
}
|
||||
|
||||
if (doesExist(params.seed)) {
|
||||
url.searchParams.append('seed', params.seed.toFixed(0));
|
||||
}
|
||||
|
||||
// put prompt last, in case a load balancer decides to truncate the URL
|
||||
url.searchParams.append('prompt', params.prompt);
|
||||
|
||||
if (doesExist(params.negativePrompt)) {
|
||||
url.searchParams.append('negativePrompt', params.negativePrompt);
|
||||
}
|
||||
|
||||
return url;
|
||||
}
|
||||
|
||||
export function makeClient(root: string, f = fetch): ApiClient {
|
||||
let pending: Promise<ApiResponse> | undefined;
|
||||
|
||||
|
@ -95,27 +125,7 @@ export function makeClient(root: string, f = fetch): ApiClient {
|
|||
return pending;
|
||||
}
|
||||
|
||||
const url = new URL('img2img', root);
|
||||
url.searchParams.append('cfg', params.cfg.toFixed(0));
|
||||
url.searchParams.append('steps', params.steps.toFixed(0));
|
||||
|
||||
if (doesExist(params.model)) {
|
||||
url.searchParams.append('model', params.model);
|
||||
}
|
||||
|
||||
if (doesExist(params.platform)) {
|
||||
url.searchParams.append('platform', params.platform);
|
||||
}
|
||||
|
||||
if (doesExist(params.scheduler)) {
|
||||
url.searchParams.append('scheduler', params.scheduler);
|
||||
}
|
||||
|
||||
if (doesExist(params.seed)) {
|
||||
url.searchParams.append('seed', params.seed.toFixed(0));
|
||||
}
|
||||
|
||||
url.searchParams.append('prompt', params.prompt);
|
||||
const url = makeImageURL(root, 'img2img', params);
|
||||
|
||||
const body = new FormData();
|
||||
body.append('source', params.source, 'source');
|
||||
|
@ -135,9 +145,7 @@ export function makeClient(root: string, f = fetch): ApiClient {
|
|||
return pending;
|
||||
}
|
||||
|
||||
const url = new URL('txt2img', root);
|
||||
url.searchParams.append('cfg', params.cfg.toFixed(0));
|
||||
url.searchParams.append('steps', params.steps.toFixed(0));
|
||||
const url = makeImageURL(root, 'txt2img', params);
|
||||
|
||||
if (doesExist(params.width)) {
|
||||
url.searchParams.append('width', params.width.toFixed(0));
|
||||
|
@ -147,24 +155,6 @@ export function makeClient(root: string, f = fetch): ApiClient {
|
|||
url.searchParams.append('height', params.height.toFixed(0));
|
||||
}
|
||||
|
||||
if (doesExist(params.seed)) {
|
||||
url.searchParams.append('seed', params.seed.toFixed(0));
|
||||
}
|
||||
|
||||
if (doesExist(params.model)) {
|
||||
url.searchParams.append('model', params.model);
|
||||
}
|
||||
|
||||
if (doesExist(params.platform)) {
|
||||
url.searchParams.append('platform', params.platform);
|
||||
}
|
||||
|
||||
if (doesExist(params.scheduler)) {
|
||||
url.searchParams.append('scheduler', params.scheduler);
|
||||
}
|
||||
|
||||
url.searchParams.append('prompt', params.prompt);
|
||||
|
||||
pending = f(url, {
|
||||
method: 'POST',
|
||||
}).then((res) => imageFromResponse(root, res)).finally(() => {
|
||||
|
|
|
@ -1,21 +1,14 @@
|
|||
import { doesExist } from '@apextoaster/js-utils';
|
||||
import { IconButton, Stack } from '@mui/material';
|
||||
import { IconButton, Stack, TextField } from '@mui/material';
|
||||
import { Casino } from '@mui/icons-material';
|
||||
import * as React from 'react';
|
||||
|
||||
import { NumericField } from './NumericField.js';
|
||||
|
||||
export interface ImageParams {
|
||||
cfg: number;
|
||||
seed: number;
|
||||
steps: number;
|
||||
width: number;
|
||||
height: number;
|
||||
}
|
||||
import { BaseImgParams } from '../api/client.js';
|
||||
|
||||
export interface ImageControlProps {
|
||||
params: ImageParams;
|
||||
onChange?: (params: ImageParams) => void;
|
||||
params: BaseImgParams;
|
||||
onChange?: (params: BaseImgParams) => void;
|
||||
}
|
||||
|
||||
export function ImageControl(props: ImageControlProps) {
|
||||
|
@ -54,38 +47,6 @@ export function ImageControl(props: ImageControlProps) {
|
|||
}}
|
||||
/>
|
||||
</Stack>
|
||||
<Stack direction='row' spacing={4}>
|
||||
<NumericField
|
||||
label='Width'
|
||||
min={8}
|
||||
max={512}
|
||||
step={8}
|
||||
value={params.width}
|
||||
onChange={(width) => {
|
||||
if (doesExist(props.onChange)) {
|
||||
props.onChange({
|
||||
...params,
|
||||
width,
|
||||
});
|
||||
}
|
||||
}}
|
||||
/>
|
||||
<NumericField
|
||||
label='Height'
|
||||
min={8}
|
||||
max={512}
|
||||
step={8}
|
||||
value={params.height}
|
||||
onChange={(height) => {
|
||||
if (doesExist(props.onChange)) {
|
||||
props.onChange({
|
||||
...params,
|
||||
height,
|
||||
});
|
||||
}
|
||||
}}
|
||||
/>
|
||||
</Stack>
|
||||
<Stack direction='row' spacing={4}>
|
||||
<NumericField
|
||||
label='Seed'
|
||||
|
@ -114,5 +75,21 @@ export function ImageControl(props: ImageControlProps) {
|
|||
<Casino />
|
||||
</IconButton>
|
||||
</Stack>
|
||||
<TextField label='Prompt' variant='outlined' value={params.prompt} onChange={(event) => {
|
||||
if (doesExist(props.onChange)) {
|
||||
props.onChange({
|
||||
...params,
|
||||
prompt: event.target.value,
|
||||
});
|
||||
}
|
||||
}} />
|
||||
<TextField label='Negative Prompt' variant='outlined' value={params.negativePrompt} onChange={(event) => {
|
||||
if (doesExist(props.onChange)) {
|
||||
props.onChange({
|
||||
...params,
|
||||
negativePrompt: event.target.value,
|
||||
});
|
||||
}
|
||||
}} />
|
||||
</Stack>;
|
||||
}
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
import { doesExist, mustExist } from '@apextoaster/js-utils';
|
||||
import { Box, Button, Stack, TextField } from '@mui/material';
|
||||
import { Box, Button, Stack } from '@mui/material';
|
||||
import * as React from 'react';
|
||||
import { useMutation, useQuery } from 'react-query';
|
||||
|
||||
import { ApiClient } from '../api/client.js';
|
||||
import { ApiClient, BaseImgParams } from '../api/client.js';
|
||||
import { Config } from '../config.js';
|
||||
import { SCHEDULER_LABELS } from '../strings.js';
|
||||
import { ImageCard } from './ImageCard.js';
|
||||
import { ImageControl, ImageParams } from './ImageControl.js';
|
||||
import { ImageControl } from './ImageControl.js';
|
||||
import { MutationHistory } from './MutationHistory.js';
|
||||
import { QueryList } from './QueryList.js';
|
||||
|
||||
|
@ -30,7 +30,6 @@ export function Img2Img(props: Img2ImgProps) {
|
|||
...params,
|
||||
model,
|
||||
platform,
|
||||
prompt,
|
||||
scheduler,
|
||||
source: mustExist(source), // TODO: show an error if this doesn't exist
|
||||
});
|
||||
|
@ -51,14 +50,12 @@ export function Img2Img(props: Img2ImgProps) {
|
|||
});
|
||||
|
||||
const [source, setSource] = useState<File>();
|
||||
const [params, setParams] = useState<ImageParams>({
|
||||
const [params, setParams] = useState<BaseImgParams>({
|
||||
cfg: 6,
|
||||
seed: -1,
|
||||
steps: 25,
|
||||
width: 512,
|
||||
height: 512,
|
||||
prompt: config.default.prompt,
|
||||
});
|
||||
const [prompt, setPrompt] = useState(config.default.prompt);
|
||||
const [scheduler, setScheduler] = useState(config.default.scheduler);
|
||||
|
||||
return <Box>
|
||||
|
@ -79,9 +76,6 @@ export function Img2Img(props: Img2ImgProps) {
|
|||
<ImageControl params={params} onChange={(newParams) => {
|
||||
setParams(newParams);
|
||||
}} />
|
||||
<TextField label='Prompt' variant='outlined' value={prompt} onChange={(event) => {
|
||||
setPrompt(event.target.value);
|
||||
}} />
|
||||
<Button onClick={() => upload.mutate()}>Generate</Button>
|
||||
<MutationHistory result={upload} limit={4} element={ImageCard}
|
||||
isEqual={(a, b) => a.output === b.output}
|
||||
|
|
|
@ -1,13 +1,14 @@
|
|||
import { Box, Button, Stack, TextField } from '@mui/material';
|
||||
import { Box, Button, Stack } from '@mui/material';
|
||||
import * as React from 'react';
|
||||
import { useMutation, useQuery } from 'react-query';
|
||||
|
||||
import { ApiClient } from '../api/client.js';
|
||||
import { ApiClient, BaseImgParams } from '../api/client.js';
|
||||
import { Config } from '../config.js';
|
||||
import { SCHEDULER_LABELS } from '../strings.js';
|
||||
import { ImageCard } from './ImageCard.js';
|
||||
import { ImageControl, ImageParams } from './ImageControl.js';
|
||||
import { ImageControl } from './ImageControl.js';
|
||||
import { MutationHistory } from './MutationHistory.js';
|
||||
import { NumericField } from './NumericField.js';
|
||||
import { QueryList } from './QueryList.js';
|
||||
|
||||
const { useState } = React;
|
||||
|
@ -29,8 +30,9 @@ export function Txt2Img(props: Txt2ImgProps) {
|
|||
...params,
|
||||
model,
|
||||
platform,
|
||||
prompt,
|
||||
scheduler,
|
||||
height,
|
||||
width,
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -39,14 +41,14 @@ export function Txt2Img(props: Txt2ImgProps) {
|
|||
staleTime: STALE_TIME,
|
||||
});
|
||||
|
||||
const [params, setParams] = useState<ImageParams>({
|
||||
const [height, setHeight] = useState(512);
|
||||
const [width, setWidth] = useState(512);
|
||||
const [params, setParams] = useState<BaseImgParams>({
|
||||
cfg: 6,
|
||||
seed: -1,
|
||||
steps: 25,
|
||||
width: 512,
|
||||
height: 512,
|
||||
prompt: config.default.prompt,
|
||||
});
|
||||
const [prompt, setPrompt] = useState(config.default.prompt);
|
||||
const [scheduler, setScheduler] = useState(config.default.scheduler);
|
||||
|
||||
return <Box>
|
||||
|
@ -66,9 +68,28 @@ export function Txt2Img(props: Txt2ImgProps) {
|
|||
<ImageControl params={params} onChange={(newParams) => {
|
||||
setParams(newParams);
|
||||
}} />
|
||||
<TextField label='Prompt' variant='outlined' value={prompt} onChange={(event) => {
|
||||
setPrompt(event.target.value);
|
||||
}} />
|
||||
<Stack direction='row' spacing={4}>
|
||||
<NumericField
|
||||
label='Width'
|
||||
min={8}
|
||||
max={512}
|
||||
step={8}
|
||||
value={width}
|
||||
onChange={(value) => {
|
||||
setWidth(value);
|
||||
}}
|
||||
/>
|
||||
<NumericField
|
||||
label='Height'
|
||||
min={8}
|
||||
max={512}
|
||||
step={8}
|
||||
value={height}
|
||||
onChange={(value) => {
|
||||
setHeight(value);
|
||||
}}
|
||||
/>
|
||||
</Stack>
|
||||
<Button onClick={() => generate.mutate()}>Generate</Button>
|
||||
<MutationHistory result={generate} limit={4} element={ImageCard}
|
||||
isEqual={(a, b) => a.output === b.output}
|
||||
|
|
Loading…
Reference in New Issue