implement experimental parameters
This commit is contained in:
parent
48d7a51666
commit
0bbdc3b96b
|
@ -31,15 +31,16 @@ class TextPromptStage(BaseStage):
|
||||||
sources: StageResult,
|
sources: StageResult,
|
||||||
*,
|
*,
|
||||||
callback: Optional[ProgressCallback] = None,
|
callback: Optional[ProgressCallback] = None,
|
||||||
prompt_model: str = "Gustavosta/MagicPrompt-Stable-Diffusion",
|
prompt_filter: str = "Gustavosta/MagicPrompt-Stable-Diffusion",
|
||||||
exclude_tokens: Optional[str] = None,
|
remove_tokens: Optional[str] = None,
|
||||||
|
add_suffix: Optional[str] = None,
|
||||||
min_length: int = 75,
|
min_length: int = 75,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> StageResult:
|
) -> StageResult:
|
||||||
device = worker.device.torch_str()
|
device = worker.device.torch_str()
|
||||||
text_pipe = pipeline(
|
text_pipe = pipeline(
|
||||||
"text-generation",
|
"text-generation",
|
||||||
model=prompt_model,
|
model=prompt_filter,
|
||||||
device=device,
|
device=device,
|
||||||
framework="pt",
|
framework="pt",
|
||||||
)
|
)
|
||||||
|
@ -64,11 +65,11 @@ class TextPromptStage(BaseStage):
|
||||||
)
|
)
|
||||||
prompt = result[0]["generated_text"].strip()
|
prompt = result[0]["generated_text"].strip()
|
||||||
|
|
||||||
if exclude_tokens:
|
if remove_tokens:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"removing excluded tokens from prompt: %s", exclude_tokens
|
"removing excluded tokens from prompt: %s", remove_tokens
|
||||||
)
|
)
|
||||||
prompt = sub(exclude_tokens, "", prompt)
|
prompt = sub(remove_tokens, "", prompt)
|
||||||
|
|
||||||
if retries >= RETRY_LIMIT:
|
if retries >= RETRY_LIMIT:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
@ -78,6 +79,10 @@ class TextPromptStage(BaseStage):
|
||||||
prompt,
|
prompt,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if add_suffix:
|
||||||
|
prompt = f"{prompt}, {add_suffix}"
|
||||||
|
logger.trace("adding suffix to prompt: %s", prompt)
|
||||||
|
|
||||||
prompt_results.append(prompt)
|
prompt_results.append(prompt)
|
||||||
|
|
||||||
complete_prompt = " || ".join(prompt_results)
|
complete_prompt = " || ".join(prompt_results)
|
||||||
|
|
|
@ -94,7 +94,7 @@ def run_txt2img_pipeline(
|
||||||
|
|
||||||
# prepare the chain pipeline and first stage
|
# prepare the chain pipeline and first stage
|
||||||
chain = ChainPipeline()
|
chain = ChainPipeline()
|
||||||
add_prompt_filter(server, chain)
|
add_prompt_filter(server, chain, request.experimental)
|
||||||
|
|
||||||
chain.stage(
|
chain.stage(
|
||||||
SourceTxt2ImgStage(),
|
SourceTxt2ImgStage(),
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
/* eslint-disable max-params */
|
||||||
/* eslint-disable camelcase */
|
/* eslint-disable camelcase */
|
||||||
/* eslint-disable max-lines */
|
/* eslint-disable max-lines */
|
||||||
import { doesExist, InvalidArgumentError, Maybe } from '@apextoaster/js-utils';
|
import { doesExist, InvalidArgumentError, Maybe } from '@apextoaster/js-utils';
|
||||||
|
@ -351,7 +352,7 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
|
||||||
const res = await f(path);
|
const res = await f(path);
|
||||||
return await res.json() as Array<string>;
|
return await res.json() as Array<string>;
|
||||||
},
|
},
|
||||||
async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry> {
|
async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams, experimental?: ExperimentalParams): Promise<JobResponseWithRetry> {
|
||||||
const url = makeApiURL(root, 'img2img');
|
const url = makeApiURL(root, 'img2img');
|
||||||
const json = makeImageJSON({
|
const json = makeImageJSON({
|
||||||
model,
|
model,
|
||||||
|
@ -359,6 +360,7 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
|
||||||
upscale,
|
upscale,
|
||||||
highres,
|
highres,
|
||||||
img2img: params,
|
img2img: params,
|
||||||
|
experimental,
|
||||||
});
|
});
|
||||||
|
|
||||||
const form = new FormData();
|
const form = new FormData();
|
||||||
|
@ -379,7 +381,7 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
async txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry> {
|
async txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams, experimental?: ExperimentalParams): Promise<JobResponseWithRetry> {
|
||||||
const url = makeApiURL(root, 'txt2img');
|
const url = makeApiURL(root, 'txt2img');
|
||||||
const json = makeImageJSON({
|
const json = makeImageJSON({
|
||||||
model,
|
model,
|
||||||
|
@ -387,6 +389,7 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
|
||||||
size: params,
|
size: params,
|
||||||
upscale,
|
upscale,
|
||||||
highres,
|
highres,
|
||||||
|
experimental,
|
||||||
});
|
});
|
||||||
|
|
||||||
const form = new FormData();
|
const form = new FormData();
|
||||||
|
@ -407,7 +410,7 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
async inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry> {
|
async inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams, experimental?: ExperimentalParams): Promise<JobResponseWithRetry> {
|
||||||
const url = makeApiURL(root, 'inpaint');
|
const url = makeApiURL(root, 'inpaint');
|
||||||
const json = makeImageJSON({
|
const json = makeImageJSON({
|
||||||
model,
|
model,
|
||||||
|
@ -415,6 +418,7 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
|
||||||
upscale,
|
upscale,
|
||||||
highres,
|
highres,
|
||||||
inpaint: params,
|
inpaint: params,
|
||||||
|
experimental,
|
||||||
});
|
});
|
||||||
|
|
||||||
const form = new FormData();
|
const form = new FormData();
|
||||||
|
@ -436,7 +440,7 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
async outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry> {
|
async outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams, experimental?: ExperimentalParams): Promise<JobResponseWithRetry> {
|
||||||
const url = makeApiURL(root, 'inpaint');
|
const url = makeApiURL(root, 'inpaint');
|
||||||
const json = makeImageJSON({
|
const json = makeImageJSON({
|
||||||
model,
|
model,
|
||||||
|
@ -444,6 +448,7 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
|
||||||
upscale,
|
upscale,
|
||||||
highres,
|
highres,
|
||||||
inpaint: params,
|
inpaint: params,
|
||||||
|
experimental,
|
||||||
});
|
});
|
||||||
|
|
||||||
const form = new FormData();
|
const form = new FormData();
|
||||||
|
|
|
@ -4,7 +4,18 @@ import { JobResponse, JobResponseWithRetry, SuccessJobResponse } from '../types/
|
||||||
import { FilterResponse, ModelResponse, RetryParams, WriteExtrasResponse } from '../types/api.js';
|
import { FilterResponse, ModelResponse, RetryParams, WriteExtrasResponse } from '../types/api.js';
|
||||||
import { ChainPipeline } from '../types/chain.js';
|
import { ChainPipeline } from '../types/chain.js';
|
||||||
import { ExtrasFile } from '../types/model.js';
|
import { ExtrasFile } from '../types/model.js';
|
||||||
import { BlendParams, HighresParams, Img2ImgParams, InpaintParams, ModelParams, OutpaintParams, Txt2ImgParams, UpscaleParams, UpscaleReqParams } from '../types/params.js';
|
import {
|
||||||
|
BlendParams,
|
||||||
|
ExperimentalParams,
|
||||||
|
HighresParams,
|
||||||
|
Img2ImgParams,
|
||||||
|
InpaintParams,
|
||||||
|
ModelParams,
|
||||||
|
OutpaintParams,
|
||||||
|
Txt2ImgParams,
|
||||||
|
UpscaleParams,
|
||||||
|
UpscaleReqParams,
|
||||||
|
} from '../types/params.js';
|
||||||
|
|
||||||
export interface ApiClient {
|
export interface ApiClient {
|
||||||
/**
|
/**
|
||||||
|
@ -67,22 +78,22 @@ export interface ApiClient {
|
||||||
/**
|
/**
|
||||||
* Start a txt2img pipeline.
|
* Start a txt2img pipeline.
|
||||||
*/
|
*/
|
||||||
txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry>;
|
txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams, experimental?: ExperimentalParams): Promise<JobResponseWithRetry>;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Start an im2img pipeline.
|
* Start an im2img pipeline.
|
||||||
*/
|
*/
|
||||||
img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry>;
|
img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams, experimental?: ExperimentalParams): Promise<JobResponseWithRetry>;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Start an inpaint pipeline.
|
* Start an inpaint pipeline.
|
||||||
*/
|
*/
|
||||||
inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry>;
|
inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams, experimental?: ExperimentalParams): Promise<JobResponseWithRetry>;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Start an outpaint pipeline.
|
* Start an outpaint pipeline.
|
||||||
*/
|
*/
|
||||||
outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry>;
|
outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams, experimental?: ExperimentalParams): Promise<JobResponseWithRetry>;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Start an upscale pipeline.
|
* Start an upscale pipeline.
|
||||||
|
|
|
@ -32,7 +32,7 @@ export function Img2Img() {
|
||||||
const { job, retry } = await client.img2img(model, {
|
const { job, retry } = await client.img2img(model, {
|
||||||
...img2img,
|
...img2img,
|
||||||
source: mustExist(img2img.source), // TODO: show an error if this doesn't exist
|
source: mustExist(img2img.source), // TODO: show an error if this doesn't exist
|
||||||
}, selectUpscale(state), selectHighres(state));
|
}, selectUpscale(state), selectHighres(state), selectExperimental(state));
|
||||||
|
|
||||||
pushHistory(job, retry);
|
pushHistory(job, retry);
|
||||||
}
|
}
|
||||||
|
|
|
@ -46,7 +46,7 @@ export function Inpaint() {
|
||||||
...outpaint,
|
...outpaint,
|
||||||
mask: mustExist(mask),
|
mask: mustExist(mask),
|
||||||
source: mustExist(source),
|
source: mustExist(source),
|
||||||
}, selectUpscale(state), selectHighres(state));
|
}, selectUpscale(state), selectHighres(state), selectExperimental(state));
|
||||||
|
|
||||||
pushHistory(job, retry);
|
pushHistory(job, retry);
|
||||||
} else {
|
} else {
|
||||||
|
@ -54,7 +54,7 @@ export function Inpaint() {
|
||||||
...inpaint,
|
...inpaint,
|
||||||
mask: mustExist(mask),
|
mask: mustExist(mask),
|
||||||
source: mustExist(source),
|
source: mustExist(source),
|
||||||
}, selectUpscale(state), selectHighres(state));
|
}, selectUpscale(state), selectHighres(state), selectExperimental(state));
|
||||||
|
|
||||||
pushHistory(job, retry);
|
pushHistory(job, retry);
|
||||||
}
|
}
|
||||||
|
|
|
@ -63,16 +63,17 @@ export function Txt2Img() {
|
||||||
async function generateImage() {
|
async function generateImage() {
|
||||||
const state = store.getState();
|
const state = store.getState();
|
||||||
const grid = selectVariable(state);
|
const grid = selectVariable(state);
|
||||||
const params2 = selectParams(state);
|
const params = selectParams(state);
|
||||||
const upscale = selectUpscale(state);
|
const upscale = selectUpscale(state);
|
||||||
const highres = selectHighres(state);
|
const highres = selectHighres(state);
|
||||||
|
const experimental = selectExperimental(state);
|
||||||
|
|
||||||
if (grid.enabled) {
|
if (grid.enabled) {
|
||||||
const chain = makeTxt2ImgGridPipeline(grid, model, params2, upscale, highres);
|
const chain = makeTxt2ImgGridPipeline(grid, model, params, upscale, highres);
|
||||||
const image = await client.chain(model, chain);
|
const image = await client.chain(model, chain);
|
||||||
pushHistory(image);
|
pushHistory(image);
|
||||||
} else {
|
} else {
|
||||||
const { job, retry } = await client.txt2img(model, params2, upscale, highres);
|
const { job, retry } = await client.txt2img(model, params, upscale, highres, experimental);
|
||||||
pushHistory(job, retry);
|
pushHistory(job, retry);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue