1
0
Fork 0

implement experimental parameters

This commit is contained in:
Sean Sube 2024-02-17 15:50:04 -06:00
parent 48d7a51666
commit 0bbdc3b96b
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
7 changed files with 44 additions and 22 deletions

View File

@ -31,15 +31,16 @@ class TextPromptStage(BaseStage):
sources: StageResult, sources: StageResult,
*, *,
callback: Optional[ProgressCallback] = None, callback: Optional[ProgressCallback] = None,
prompt_model: str = "Gustavosta/MagicPrompt-Stable-Diffusion", prompt_filter: str = "Gustavosta/MagicPrompt-Stable-Diffusion",
exclude_tokens: Optional[str] = None, remove_tokens: Optional[str] = None,
add_suffix: Optional[str] = None,
min_length: int = 75, min_length: int = 75,
**kwargs, **kwargs,
) -> StageResult: ) -> StageResult:
device = worker.device.torch_str() device = worker.device.torch_str()
text_pipe = pipeline( text_pipe = pipeline(
"text-generation", "text-generation",
model=prompt_model, model=prompt_filter,
device=device, device=device,
framework="pt", framework="pt",
) )
@ -64,11 +65,11 @@ class TextPromptStage(BaseStage):
) )
prompt = result[0]["generated_text"].strip() prompt = result[0]["generated_text"].strip()
if exclude_tokens: if remove_tokens:
logger.debug( logger.debug(
"removing excluded tokens from prompt: %s", exclude_tokens "removing excluded tokens from prompt: %s", remove_tokens
) )
prompt = sub(exclude_tokens, "", prompt) prompt = sub(remove_tokens, "", prompt)
if retries >= RETRY_LIMIT: if retries >= RETRY_LIMIT:
logger.warning( logger.warning(
@ -78,6 +79,10 @@ class TextPromptStage(BaseStage):
prompt, prompt,
) )
if add_suffix:
prompt = f"{prompt}, {add_suffix}"
logger.trace("adding suffix to prompt: %s", prompt)
prompt_results.append(prompt) prompt_results.append(prompt)
complete_prompt = " || ".join(prompt_results) complete_prompt = " || ".join(prompt_results)

View File

@ -94,7 +94,7 @@ def run_txt2img_pipeline(
# prepare the chain pipeline and first stage # prepare the chain pipeline and first stage
chain = ChainPipeline() chain = ChainPipeline()
add_prompt_filter(server, chain) add_prompt_filter(server, chain, request.experimental)
chain.stage( chain.stage(
SourceTxt2ImgStage(), SourceTxt2ImgStage(),

View File

@ -1,3 +1,4 @@
/* eslint-disable max-params */
/* eslint-disable camelcase */ /* eslint-disable camelcase */
/* eslint-disable max-lines */ /* eslint-disable max-lines */
import { doesExist, InvalidArgumentError, Maybe } from '@apextoaster/js-utils'; import { doesExist, InvalidArgumentError, Maybe } from '@apextoaster/js-utils';
@ -351,7 +352,7 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
const res = await f(path); const res = await f(path);
return await res.json() as Array<string>; return await res.json() as Array<string>;
}, },
async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry> { async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams, experimental?: ExperimentalParams): Promise<JobResponseWithRetry> {
const url = makeApiURL(root, 'img2img'); const url = makeApiURL(root, 'img2img');
const json = makeImageJSON({ const json = makeImageJSON({
model, model,
@ -359,6 +360,7 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
upscale, upscale,
highres, highres,
img2img: params, img2img: params,
experimental,
}); });
const form = new FormData(); const form = new FormData();
@ -379,7 +381,7 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
}, },
}; };
}, },
async txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry> { async txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams, experimental?: ExperimentalParams): Promise<JobResponseWithRetry> {
const url = makeApiURL(root, 'txt2img'); const url = makeApiURL(root, 'txt2img');
const json = makeImageJSON({ const json = makeImageJSON({
model, model,
@ -387,6 +389,7 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
size: params, size: params,
upscale, upscale,
highres, highres,
experimental,
}); });
const form = new FormData(); const form = new FormData();
@ -407,7 +410,7 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
}, },
}; };
}, },
async inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry> { async inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams, experimental?: ExperimentalParams): Promise<JobResponseWithRetry> {
const url = makeApiURL(root, 'inpaint'); const url = makeApiURL(root, 'inpaint');
const json = makeImageJSON({ const json = makeImageJSON({
model, model,
@ -415,6 +418,7 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
upscale, upscale,
highres, highres,
inpaint: params, inpaint: params,
experimental,
}); });
const form = new FormData(); const form = new FormData();
@ -436,7 +440,7 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
}, },
}; };
}, },
async outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry> { async outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams, experimental?: ExperimentalParams): Promise<JobResponseWithRetry> {
const url = makeApiURL(root, 'inpaint'); const url = makeApiURL(root, 'inpaint');
const json = makeImageJSON({ const json = makeImageJSON({
model, model,
@ -444,6 +448,7 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
upscale, upscale,
highres, highres,
inpaint: params, inpaint: params,
experimental,
}); });
const form = new FormData(); const form = new FormData();

View File

@ -4,7 +4,18 @@ import { JobResponse, JobResponseWithRetry, SuccessJobResponse } from '../types/
import { FilterResponse, ModelResponse, RetryParams, WriteExtrasResponse } from '../types/api.js'; import { FilterResponse, ModelResponse, RetryParams, WriteExtrasResponse } from '../types/api.js';
import { ChainPipeline } from '../types/chain.js'; import { ChainPipeline } from '../types/chain.js';
import { ExtrasFile } from '../types/model.js'; import { ExtrasFile } from '../types/model.js';
import { BlendParams, HighresParams, Img2ImgParams, InpaintParams, ModelParams, OutpaintParams, Txt2ImgParams, UpscaleParams, UpscaleReqParams } from '../types/params.js'; import {
BlendParams,
ExperimentalParams,
HighresParams,
Img2ImgParams,
InpaintParams,
ModelParams,
OutpaintParams,
Txt2ImgParams,
UpscaleParams,
UpscaleReqParams,
} from '../types/params.js';
export interface ApiClient { export interface ApiClient {
/** /**
@ -67,22 +78,22 @@ export interface ApiClient {
/** /**
* Start a txt2img pipeline. * Start a txt2img pipeline.
*/ */
txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry>; txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams, experimental?: ExperimentalParams): Promise<JobResponseWithRetry>;
/** /**
* Start an im2img pipeline. * Start an im2img pipeline.
*/ */
img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry>; img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams, experimental?: ExperimentalParams): Promise<JobResponseWithRetry>;
/** /**
* Start an inpaint pipeline. * Start an inpaint pipeline.
*/ */
inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry>; inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams, experimental?: ExperimentalParams): Promise<JobResponseWithRetry>;
/** /**
* Start an outpaint pipeline. * Start an outpaint pipeline.
*/ */
outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry>; outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams, experimental?: ExperimentalParams): Promise<JobResponseWithRetry>;
/** /**
* Start an upscale pipeline. * Start an upscale pipeline.

View File

@ -32,7 +32,7 @@ export function Img2Img() {
const { job, retry } = await client.img2img(model, { const { job, retry } = await client.img2img(model, {
...img2img, ...img2img,
source: mustExist(img2img.source), // TODO: show an error if this doesn't exist source: mustExist(img2img.source), // TODO: show an error if this doesn't exist
}, selectUpscale(state), selectHighres(state)); }, selectUpscale(state), selectHighres(state), selectExperimental(state));
pushHistory(job, retry); pushHistory(job, retry);
} }

View File

@ -46,7 +46,7 @@ export function Inpaint() {
...outpaint, ...outpaint,
mask: mustExist(mask), mask: mustExist(mask),
source: mustExist(source), source: mustExist(source),
}, selectUpscale(state), selectHighres(state)); }, selectUpscale(state), selectHighres(state), selectExperimental(state));
pushHistory(job, retry); pushHistory(job, retry);
} else { } else {
@ -54,7 +54,7 @@ export function Inpaint() {
...inpaint, ...inpaint,
mask: mustExist(mask), mask: mustExist(mask),
source: mustExist(source), source: mustExist(source),
}, selectUpscale(state), selectHighres(state)); }, selectUpscale(state), selectHighres(state), selectExperimental(state));
pushHistory(job, retry); pushHistory(job, retry);
} }

View File

@ -63,16 +63,17 @@ export function Txt2Img() {
async function generateImage() { async function generateImage() {
const state = store.getState(); const state = store.getState();
const grid = selectVariable(state); const grid = selectVariable(state);
const params2 = selectParams(state); const params = selectParams(state);
const upscale = selectUpscale(state); const upscale = selectUpscale(state);
const highres = selectHighres(state); const highres = selectHighres(state);
const experimental = selectExperimental(state);
if (grid.enabled) { if (grid.enabled) {
const chain = makeTxt2ImgGridPipeline(grid, model, params2, upscale, highres); const chain = makeTxt2ImgGridPipeline(grid, model, params, upscale, highres);
const image = await client.chain(model, chain); const image = await client.chain(model, chain);
pushHistory(image); pushHistory(image);
} else { } else {
const { job, retry } = await client.txt2img(model, params2, upscale, highres); const { job, retry } = await client.txt2img(model, params, upscale, highres, experimental);
pushHistory(job, retry); pushHistory(job, retry);
} }
} }