From 0bbdc3b96bfc81d83fd4db1d628a04da80ab958d Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 17 Feb 2024 15:50:04 -0600 Subject: [PATCH] implement experimental parameters --- api/onnx_web/chain/text_prompt.py | 17 +++++++++++------ api/onnx_web/diffusers/run.py | 2 +- gui/src/client/api.ts | 13 +++++++++---- gui/src/client/base.ts | 21 ++++++++++++++++----- gui/src/components/tab/Img2Img.tsx | 2 +- gui/src/components/tab/Inpaint.tsx | 4 ++-- gui/src/components/tab/Txt2Img.tsx | 7 ++++--- 7 files changed, 44 insertions(+), 22 deletions(-) diff --git a/api/onnx_web/chain/text_prompt.py b/api/onnx_web/chain/text_prompt.py index 3c3f8733..bfa16237 100644 --- a/api/onnx_web/chain/text_prompt.py +++ b/api/onnx_web/chain/text_prompt.py @@ -31,15 +31,16 @@ class TextPromptStage(BaseStage): sources: StageResult, *, callback: Optional[ProgressCallback] = None, - prompt_model: str = "Gustavosta/MagicPrompt-Stable-Diffusion", - exclude_tokens: Optional[str] = None, + prompt_filter: str = "Gustavosta/MagicPrompt-Stable-Diffusion", + remove_tokens: Optional[str] = None, + add_suffix: Optional[str] = None, min_length: int = 75, **kwargs, ) -> StageResult: device = worker.device.torch_str() text_pipe = pipeline( "text-generation", - model=prompt_model, + model=prompt_filter, device=device, framework="pt", ) @@ -64,11 +65,11 @@ class TextPromptStage(BaseStage): ) prompt = result[0]["generated_text"].strip() - if exclude_tokens: + if remove_tokens: logger.debug( - "removing excluded tokens from prompt: %s", exclude_tokens + "removing excluded tokens from prompt: %s", remove_tokens ) - prompt = sub(exclude_tokens, "", prompt) + prompt = sub(remove_tokens, "", prompt) if retries >= RETRY_LIMIT: logger.warning( @@ -78,6 +79,10 @@ class TextPromptStage(BaseStage): prompt, ) + if add_suffix: + prompt = f"{prompt}, {add_suffix}" + logger.trace("adding suffix to prompt: %s", prompt) + prompt_results.append(prompt) complete_prompt = " || ".join(prompt_results) diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py index 675ae0ec..fac07cef 100644 --- a/api/onnx_web/diffusers/run.py +++ b/api/onnx_web/diffusers/run.py @@ -94,7 +94,7 @@ def run_txt2img_pipeline( # prepare the chain pipeline and first stage chain = ChainPipeline() - add_prompt_filter(server, chain) + add_prompt_filter(server, chain, request.experimental) chain.stage( SourceTxt2ImgStage(), diff --git a/gui/src/client/api.ts b/gui/src/client/api.ts index 56a690c3..ed17b461 100644 --- a/gui/src/client/api.ts +++ b/gui/src/client/api.ts @@ -1,3 +1,4 @@ +/* eslint-disable max-params */ /* eslint-disable camelcase */ /* eslint-disable max-lines */ import { doesExist, InvalidArgumentError, Maybe } from '@apextoaster/js-utils'; @@ -351,7 +352,7 @@ export function makeClient(root: string, batchInterval: number, token: Maybe; }, - async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise { + async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams, experimental?: ExperimentalParams): Promise { const url = makeApiURL(root, 'img2img'); const json = makeImageJSON({ model, @@ -359,6 +360,7 @@ export function makeClient(root: string, batchInterval: number, token: Maybe { + async txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams, experimental?: ExperimentalParams): Promise { const url = makeApiURL(root, 'txt2img'); const json = makeImageJSON({ model, @@ -387,6 +389,7 @@ export function makeClient(root: string, batchInterval: number, token: Maybe { + async inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams, experimental?: ExperimentalParams): Promise { const url = makeApiURL(root, 'inpaint'); const json = makeImageJSON({ model, @@ -415,6 +418,7 @@ export function makeClient(root: string, batchInterval: number, token: Maybe { + async outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams, experimental?: ExperimentalParams): Promise { const url = makeApiURL(root, 'inpaint'); const json = makeImageJSON({ model, @@ -444,6 +448,7 @@ export function makeClient(root: string, batchInterval: number, token: Maybe; + txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams, experimental?: ExperimentalParams): Promise; /** * Start an im2img pipeline. */ - img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise; + img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams, experimental?: ExperimentalParams): Promise; /** * Start an inpaint pipeline. */ - inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise; + inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams, experimental?: ExperimentalParams): Promise; /** * Start an outpaint pipeline. */ - outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise; + outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams, experimental?: ExperimentalParams): Promise; /** * Start an upscale pipeline. diff --git a/gui/src/components/tab/Img2Img.tsx b/gui/src/components/tab/Img2Img.tsx index 9cd34a40..66ad7814 100644 --- a/gui/src/components/tab/Img2Img.tsx +++ b/gui/src/components/tab/Img2Img.tsx @@ -32,7 +32,7 @@ export function Img2Img() { const { job, retry } = await client.img2img(model, { ...img2img, source: mustExist(img2img.source), // TODO: show an error if this doesn't exist - }, selectUpscale(state), selectHighres(state)); + }, selectUpscale(state), selectHighres(state), selectExperimental(state)); pushHistory(job, retry); } diff --git a/gui/src/components/tab/Inpaint.tsx b/gui/src/components/tab/Inpaint.tsx index a0c614b3..a36acc72 100644 --- a/gui/src/components/tab/Inpaint.tsx +++ b/gui/src/components/tab/Inpaint.tsx @@ -46,7 +46,7 @@ export function Inpaint() { ...outpaint, mask: mustExist(mask), source: mustExist(source), - }, selectUpscale(state), selectHighres(state)); + }, selectUpscale(state), selectHighres(state), selectExperimental(state)); pushHistory(job, retry); } else { @@ -54,7 +54,7 @@ export function Inpaint() { ...inpaint, mask: mustExist(mask), source: mustExist(source), - }, selectUpscale(state), selectHighres(state)); + }, selectUpscale(state), selectHighres(state), selectExperimental(state)); pushHistory(job, retry); } diff --git a/gui/src/components/tab/Txt2Img.tsx b/gui/src/components/tab/Txt2Img.tsx index 2de95424..2447f7ec 100644 --- a/gui/src/components/tab/Txt2Img.tsx +++ b/gui/src/components/tab/Txt2Img.tsx @@ -63,16 +63,17 @@ export function Txt2Img() { async function generateImage() { const state = store.getState(); const grid = selectVariable(state); - const params2 = selectParams(state); + const params = selectParams(state); const upscale = selectUpscale(state); const highres = selectHighres(state); + const experimental = selectExperimental(state); if (grid.enabled) { - const chain = makeTxt2ImgGridPipeline(grid, model, params2, upscale, highres); + const chain = makeTxt2ImgGridPipeline(grid, model, params, upscale, highres); const image = await client.chain(model, chain); pushHistory(image); } else { - const { job, retry } = await client.txt2img(model, params2, upscale, highres); + const { job, retry } = await client.txt2img(model, params, upscale, highres, experimental); pushHistory(job, retry); } }