implement experimental parameters
This commit is contained in:
parent
48d7a51666
commit
0bbdc3b96b
|
@ -31,15 +31,16 @@ class TextPromptStage(BaseStage):
|
|||
sources: StageResult,
|
||||
*,
|
||||
callback: Optional[ProgressCallback] = None,
|
||||
prompt_model: str = "Gustavosta/MagicPrompt-Stable-Diffusion",
|
||||
exclude_tokens: Optional[str] = None,
|
||||
prompt_filter: str = "Gustavosta/MagicPrompt-Stable-Diffusion",
|
||||
remove_tokens: Optional[str] = None,
|
||||
add_suffix: Optional[str] = None,
|
||||
min_length: int = 75,
|
||||
**kwargs,
|
||||
) -> StageResult:
|
||||
device = worker.device.torch_str()
|
||||
text_pipe = pipeline(
|
||||
"text-generation",
|
||||
model=prompt_model,
|
||||
model=prompt_filter,
|
||||
device=device,
|
||||
framework="pt",
|
||||
)
|
||||
|
@ -64,11 +65,11 @@ class TextPromptStage(BaseStage):
|
|||
)
|
||||
prompt = result[0]["generated_text"].strip()
|
||||
|
||||
if exclude_tokens:
|
||||
if remove_tokens:
|
||||
logger.debug(
|
||||
"removing excluded tokens from prompt: %s", exclude_tokens
|
||||
"removing excluded tokens from prompt: %s", remove_tokens
|
||||
)
|
||||
prompt = sub(exclude_tokens, "", prompt)
|
||||
prompt = sub(remove_tokens, "", prompt)
|
||||
|
||||
if retries >= RETRY_LIMIT:
|
||||
logger.warning(
|
||||
|
@ -78,6 +79,10 @@ class TextPromptStage(BaseStage):
|
|||
prompt,
|
||||
)
|
||||
|
||||
if add_suffix:
|
||||
prompt = f"{prompt}, {add_suffix}"
|
||||
logger.trace("adding suffix to prompt: %s", prompt)
|
||||
|
||||
prompt_results.append(prompt)
|
||||
|
||||
complete_prompt = " || ".join(prompt_results)
|
||||
|
|
|
@ -94,7 +94,7 @@ def run_txt2img_pipeline(
|
|||
|
||||
# prepare the chain pipeline and first stage
|
||||
chain = ChainPipeline()
|
||||
add_prompt_filter(server, chain)
|
||||
add_prompt_filter(server, chain, request.experimental)
|
||||
|
||||
chain.stage(
|
||||
SourceTxt2ImgStage(),
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
/* eslint-disable max-params */
|
||||
/* eslint-disable camelcase */
|
||||
/* eslint-disable max-lines */
|
||||
import { doesExist, InvalidArgumentError, Maybe } from '@apextoaster/js-utils';
|
||||
|
@ -351,7 +352,7 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
|
|||
const res = await f(path);
|
||||
return await res.json() as Array<string>;
|
||||
},
|
||||
async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry> {
|
||||
async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams, experimental?: ExperimentalParams): Promise<JobResponseWithRetry> {
|
||||
const url = makeApiURL(root, 'img2img');
|
||||
const json = makeImageJSON({
|
||||
model,
|
||||
|
@ -359,6 +360,7 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
|
|||
upscale,
|
||||
highres,
|
||||
img2img: params,
|
||||
experimental,
|
||||
});
|
||||
|
||||
const form = new FormData();
|
||||
|
@ -379,7 +381,7 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
|
|||
},
|
||||
};
|
||||
},
|
||||
async txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry> {
|
||||
async txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams, experimental?: ExperimentalParams): Promise<JobResponseWithRetry> {
|
||||
const url = makeApiURL(root, 'txt2img');
|
||||
const json = makeImageJSON({
|
||||
model,
|
||||
|
@ -387,6 +389,7 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
|
|||
size: params,
|
||||
upscale,
|
||||
highres,
|
||||
experimental,
|
||||
});
|
||||
|
||||
const form = new FormData();
|
||||
|
@ -407,7 +410,7 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
|
|||
},
|
||||
};
|
||||
},
|
||||
async inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry> {
|
||||
async inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams, experimental?: ExperimentalParams): Promise<JobResponseWithRetry> {
|
||||
const url = makeApiURL(root, 'inpaint');
|
||||
const json = makeImageJSON({
|
||||
model,
|
||||
|
@ -415,6 +418,7 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
|
|||
upscale,
|
||||
highres,
|
||||
inpaint: params,
|
||||
experimental,
|
||||
});
|
||||
|
||||
const form = new FormData();
|
||||
|
@ -436,7 +440,7 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
|
|||
},
|
||||
};
|
||||
},
|
||||
async outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry> {
|
||||
async outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams, experimental?: ExperimentalParams): Promise<JobResponseWithRetry> {
|
||||
const url = makeApiURL(root, 'inpaint');
|
||||
const json = makeImageJSON({
|
||||
model,
|
||||
|
@ -444,6 +448,7 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
|
|||
upscale,
|
||||
highres,
|
||||
inpaint: params,
|
||||
experimental,
|
||||
});
|
||||
|
||||
const form = new FormData();
|
||||
|
|
|
@ -4,7 +4,18 @@ import { JobResponse, JobResponseWithRetry, SuccessJobResponse } from '../types/
|
|||
import { FilterResponse, ModelResponse, RetryParams, WriteExtrasResponse } from '../types/api.js';
|
||||
import { ChainPipeline } from '../types/chain.js';
|
||||
import { ExtrasFile } from '../types/model.js';
|
||||
import { BlendParams, HighresParams, Img2ImgParams, InpaintParams, ModelParams, OutpaintParams, Txt2ImgParams, UpscaleParams, UpscaleReqParams } from '../types/params.js';
|
||||
import {
|
||||
BlendParams,
|
||||
ExperimentalParams,
|
||||
HighresParams,
|
||||
Img2ImgParams,
|
||||
InpaintParams,
|
||||
ModelParams,
|
||||
OutpaintParams,
|
||||
Txt2ImgParams,
|
||||
UpscaleParams,
|
||||
UpscaleReqParams,
|
||||
} from '../types/params.js';
|
||||
|
||||
export interface ApiClient {
|
||||
/**
|
||||
|
@ -67,22 +78,22 @@ export interface ApiClient {
|
|||
/**
|
||||
* Start a txt2img pipeline.
|
||||
*/
|
||||
txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry>;
|
||||
txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams, experimental?: ExperimentalParams): Promise<JobResponseWithRetry>;
|
||||
|
||||
/**
|
||||
* Start an im2img pipeline.
|
||||
*/
|
||||
img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry>;
|
||||
img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams, experimental?: ExperimentalParams): Promise<JobResponseWithRetry>;
|
||||
|
||||
/**
|
||||
* Start an inpaint pipeline.
|
||||
*/
|
||||
inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry>;
|
||||
inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams, experimental?: ExperimentalParams): Promise<JobResponseWithRetry>;
|
||||
|
||||
/**
|
||||
* Start an outpaint pipeline.
|
||||
*/
|
||||
outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry>;
|
||||
outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams, experimental?: ExperimentalParams): Promise<JobResponseWithRetry>;
|
||||
|
||||
/**
|
||||
* Start an upscale pipeline.
|
||||
|
|
|
@ -32,7 +32,7 @@ export function Img2Img() {
|
|||
const { job, retry } = await client.img2img(model, {
|
||||
...img2img,
|
||||
source: mustExist(img2img.source), // TODO: show an error if this doesn't exist
|
||||
}, selectUpscale(state), selectHighres(state));
|
||||
}, selectUpscale(state), selectHighres(state), selectExperimental(state));
|
||||
|
||||
pushHistory(job, retry);
|
||||
}
|
||||
|
|
|
@ -46,7 +46,7 @@ export function Inpaint() {
|
|||
...outpaint,
|
||||
mask: mustExist(mask),
|
||||
source: mustExist(source),
|
||||
}, selectUpscale(state), selectHighres(state));
|
||||
}, selectUpscale(state), selectHighres(state), selectExperimental(state));
|
||||
|
||||
pushHistory(job, retry);
|
||||
} else {
|
||||
|
@ -54,7 +54,7 @@ export function Inpaint() {
|
|||
...inpaint,
|
||||
mask: mustExist(mask),
|
||||
source: mustExist(source),
|
||||
}, selectUpscale(state), selectHighres(state));
|
||||
}, selectUpscale(state), selectHighres(state), selectExperimental(state));
|
||||
|
||||
pushHistory(job, retry);
|
||||
}
|
||||
|
|
|
@ -63,16 +63,17 @@ export function Txt2Img() {
|
|||
async function generateImage() {
|
||||
const state = store.getState();
|
||||
const grid = selectVariable(state);
|
||||
const params2 = selectParams(state);
|
||||
const params = selectParams(state);
|
||||
const upscale = selectUpscale(state);
|
||||
const highres = selectHighres(state);
|
||||
const experimental = selectExperimental(state);
|
||||
|
||||
if (grid.enabled) {
|
||||
const chain = makeTxt2ImgGridPipeline(grid, model, params2, upscale, highres);
|
||||
const chain = makeTxt2ImgGridPipeline(grid, model, params, upscale, highres);
|
||||
const image = await client.chain(model, chain);
|
||||
pushHistory(image);
|
||||
} else {
|
||||
const { job, retry } = await client.txt2img(model, params2, upscale, highres);
|
||||
const { job, retry } = await client.txt2img(model, params, upscale, highres, experimental);
|
||||
pushHistory(job, retry);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue