1
0
Fork 0

implement experimental parameters

This commit is contained in:
Sean Sube 2024-02-17 15:50:04 -06:00
parent 48d7a51666
commit 0bbdc3b96b
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
7 changed files with 44 additions and 22 deletions

View File

@ -31,15 +31,16 @@ class TextPromptStage(BaseStage):
sources: StageResult,
*,
callback: Optional[ProgressCallback] = None,
prompt_model: str = "Gustavosta/MagicPrompt-Stable-Diffusion",
exclude_tokens: Optional[str] = None,
prompt_filter: str = "Gustavosta/MagicPrompt-Stable-Diffusion",
remove_tokens: Optional[str] = None,
add_suffix: Optional[str] = None,
min_length: int = 75,
**kwargs,
) -> StageResult:
device = worker.device.torch_str()
text_pipe = pipeline(
"text-generation",
model=prompt_model,
model=prompt_filter,
device=device,
framework="pt",
)
@ -64,11 +65,11 @@ class TextPromptStage(BaseStage):
)
prompt = result[0]["generated_text"].strip()
if exclude_tokens:
if remove_tokens:
logger.debug(
"removing excluded tokens from prompt: %s", exclude_tokens
"removing excluded tokens from prompt: %s", remove_tokens
)
prompt = sub(exclude_tokens, "", prompt)
prompt = sub(remove_tokens, "", prompt)
if retries >= RETRY_LIMIT:
logger.warning(
@ -78,6 +79,10 @@ class TextPromptStage(BaseStage):
prompt,
)
if add_suffix:
prompt = f"{prompt}, {add_suffix}"
logger.trace("adding suffix to prompt: %s", prompt)
prompt_results.append(prompt)
complete_prompt = " || ".join(prompt_results)

View File

@ -94,7 +94,7 @@ def run_txt2img_pipeline(
# prepare the chain pipeline and first stage
chain = ChainPipeline()
add_prompt_filter(server, chain)
add_prompt_filter(server, chain, request.experimental)
chain.stage(
SourceTxt2ImgStage(),

View File

@ -1,3 +1,4 @@
/* eslint-disable max-params */
/* eslint-disable camelcase */
/* eslint-disable max-lines */
import { doesExist, InvalidArgumentError, Maybe } from '@apextoaster/js-utils';
@ -351,7 +352,7 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
const res = await f(path);
return await res.json() as Array<string>;
},
async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry> {
async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams, experimental?: ExperimentalParams): Promise<JobResponseWithRetry> {
const url = makeApiURL(root, 'img2img');
const json = makeImageJSON({
model,
@ -359,6 +360,7 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
upscale,
highres,
img2img: params,
experimental,
});
const form = new FormData();
@ -379,7 +381,7 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
},
};
},
async txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry> {
async txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams, experimental?: ExperimentalParams): Promise<JobResponseWithRetry> {
const url = makeApiURL(root, 'txt2img');
const json = makeImageJSON({
model,
@ -387,6 +389,7 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
size: params,
upscale,
highres,
experimental,
});
const form = new FormData();
@ -407,7 +410,7 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
},
};
},
async inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry> {
async inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams, experimental?: ExperimentalParams): Promise<JobResponseWithRetry> {
const url = makeApiURL(root, 'inpaint');
const json = makeImageJSON({
model,
@ -415,6 +418,7 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
upscale,
highres,
inpaint: params,
experimental,
});
const form = new FormData();
@ -436,7 +440,7 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
},
};
},
async outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry> {
async outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams, experimental?: ExperimentalParams): Promise<JobResponseWithRetry> {
const url = makeApiURL(root, 'inpaint');
const json = makeImageJSON({
model,
@ -444,6 +448,7 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
upscale,
highres,
inpaint: params,
experimental,
});
const form = new FormData();

View File

@ -4,7 +4,18 @@ import { JobResponse, JobResponseWithRetry, SuccessJobResponse } from '../types/
import { FilterResponse, ModelResponse, RetryParams, WriteExtrasResponse } from '../types/api.js';
import { ChainPipeline } from '../types/chain.js';
import { ExtrasFile } from '../types/model.js';
import { BlendParams, HighresParams, Img2ImgParams, InpaintParams, ModelParams, OutpaintParams, Txt2ImgParams, UpscaleParams, UpscaleReqParams } from '../types/params.js';
import {
BlendParams,
ExperimentalParams,
HighresParams,
Img2ImgParams,
InpaintParams,
ModelParams,
OutpaintParams,
Txt2ImgParams,
UpscaleParams,
UpscaleReqParams,
} from '../types/params.js';
export interface ApiClient {
/**
@ -67,22 +78,22 @@ export interface ApiClient {
/**
* Start a txt2img pipeline.
*/
txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry>;
txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams, experimental?: ExperimentalParams): Promise<JobResponseWithRetry>;
/**
* Start an im2img pipeline.
*/
img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry>;
img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams, experimental?: ExperimentalParams): Promise<JobResponseWithRetry>;
/**
* Start an inpaint pipeline.
*/
inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry>;
inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams, experimental?: ExperimentalParams): Promise<JobResponseWithRetry>;
/**
* Start an outpaint pipeline.
*/
outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry>;
outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams, experimental?: ExperimentalParams): Promise<JobResponseWithRetry>;
/**
* Start an upscale pipeline.

View File

@ -32,7 +32,7 @@ export function Img2Img() {
const { job, retry } = await client.img2img(model, {
...img2img,
source: mustExist(img2img.source), // TODO: show an error if this doesn't exist
}, selectUpscale(state), selectHighres(state));
}, selectUpscale(state), selectHighres(state), selectExperimental(state));
pushHistory(job, retry);
}

View File

@ -46,7 +46,7 @@ export function Inpaint() {
...outpaint,
mask: mustExist(mask),
source: mustExist(source),
}, selectUpscale(state), selectHighres(state));
}, selectUpscale(state), selectHighres(state), selectExperimental(state));
pushHistory(job, retry);
} else {
@ -54,7 +54,7 @@ export function Inpaint() {
...inpaint,
mask: mustExist(mask),
source: mustExist(source),
}, selectUpscale(state), selectHighres(state));
}, selectUpscale(state), selectHighres(state), selectExperimental(state));
pushHistory(job, retry);
}

View File

@ -63,16 +63,17 @@ export function Txt2Img() {
async function generateImage() {
const state = store.getState();
const grid = selectVariable(state);
const params2 = selectParams(state);
const params = selectParams(state);
const upscale = selectUpscale(state);
const highres = selectHighres(state);
const experimental = selectExperimental(state);
if (grid.enabled) {
const chain = makeTxt2ImgGridPipeline(grid, model, params2, upscale, highres);
const chain = makeTxt2ImgGridPipeline(grid, model, params, upscale, highres);
const image = await client.chain(model, chain);
pushHistory(image);
} else {
const { job, retry } = await client.txt2img(model, params2, upscale, highres);
const { job, retry } = await client.txt2img(model, params, upscale, highres, experimental);
pushHistory(job, retry);
}
}