From 7743e8d738c9c88d086fc4e35f6bfdef78ade943 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 24 Feb 2024 11:06:32 -0600 Subject: [PATCH] use experimental defaults from server --- api/onnx_web/chain/text_prompt.py | 10 +- api/onnx_web/diffusers/run.py | 6 +- api/onnx_web/params.py | 3 + api/onnx_web/server/params.py | 11 +- api/onnx_web/utils.py | 5 + api/params.json | 6 + gui/src/client/api.ts | 6 +- .../control/ExperimentalControl.tsx | 146 ++++++++++-------- gui/src/config.json | 43 ++++++ gui/src/config.ts | 7 +- gui/src/state/full.ts | 13 +- gui/src/state/migration/default.ts | 44 +++--- gui/src/types/params.ts | 1 + 13 files changed, 196 insertions(+), 105 deletions(-) diff --git a/api/onnx_web/chain/text_prompt.py b/api/onnx_web/chain/text_prompt.py index 1dc69ee4..fb0bd523 100644 --- a/api/onnx_web/chain/text_prompt.py +++ b/api/onnx_web/chain/text_prompt.py @@ -1,6 +1,6 @@ from logging import getLogger from random import randint -from re import sub +from re import match, sub from typing import Optional from transformers import pipeline @@ -34,7 +34,7 @@ class TextPromptStage(BaseStage): prompt_filter: str, remove_tokens: Optional[str] = None, add_suffix: Optional[str] = None, - min_length: int = 150, + min_length: int = 80, **kwargs, ) -> StageResult: device = worker.device.torch_str() @@ -69,7 +69,11 @@ class TextPromptStage(BaseStage): logger.debug( "removing excluded tokens from prompt: %s", remove_tokens ) - prompt = sub(remove_tokens, "", prompt) + + remove_limit = 3 + while remove_limit > 0 and match(remove_tokens, prompt): + prompt = sub(remove_tokens, "", prompt) + remove_limit -= 1 if retries >= RETRY_LIMIT: logger.warning( diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py index 9e0cc6f5..99b2a0f6 100644 --- a/api/onnx_web/diffusers/run.py +++ b/api/onnx_web/diffusers/run.py @@ -74,10 +74,10 @@ def add_prompt_filter( pipeline.stage( TextPromptStage(), StageParams(), - prompt_filter=experimental.prompt_editing.model, - remove_tokens=experimental.prompt_editing.remove_tokens, add_suffix=experimental.prompt_editing.add_suffix, - # TODO: add min length to experimental params + min_length=experimental.prompt_editing.min_length, + prompt_filter=experimental.prompt_editing.filter, + remove_tokens=experimental.prompt_editing.remove_tokens, ) else: logger.warning("prompt editing is not supported by the server") diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index 4d5ef00b..38fe535a 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -613,6 +613,7 @@ class PromptEditingParams: filter: str remove_tokens: str add_suffix: str + min_length: int def __init__( self, @@ -620,11 +621,13 @@ class PromptEditingParams: filter: str, remove_tokens: str, add_suffix: str, + min_length: int, ) -> None: self.enabled = enabled self.filter = filter self.remove_tokens = remove_tokens self.add_suffix = add_suffix + self.min_length = min_length class ExperimentalParams: diff --git a/api/onnx_web/server/params.py b/api/onnx_web/server/params.py index 6dc36726..0e2e7c98 100644 --- a/api/onnx_web/server/params.py +++ b/api/onnx_web/server/params.py @@ -395,8 +395,17 @@ def build_prompt_editing( prompt_filter = data.get("promptFilter", "") remove_tokens = data.get("removeTokens", "") add_suffix = data.get("addSuffix", "") + min_length = get_and_clamp_int( + data, + "minLength", + get_config_value("minLength"), + get_config_value("minLength", "max"), + get_config_value("minLength", "min"), + ) - return PromptEditingParams(enabled, prompt_filter, remove_tokens, add_suffix) + return PromptEditingParams( + enabled, prompt_filter, remove_tokens, add_suffix, min_length + ) def build_experimental( diff --git a/api/onnx_web/utils.py b/api/onnx_web/utils.py index a888133d..5d36f918 100644 --- a/api/onnx_web/utils.py +++ b/api/onnx_web/utils.py @@ -1,5 +1,6 @@ import importlib import json +from functools import reduce from hashlib import sha256 from json import JSONDecodeError from logging import getLogger @@ -29,6 +30,10 @@ def is_debug() -> bool: return get_boolean(environ, "DEBUG", False) +def recursive_get(d, *keys): + return reduce(lambda c, k: c.get(k, {}), keys, d) + + def get_boolean(args: Any, key: str, default_value: bool) -> bool: val = args.get(key, str(default_value)) diff --git a/api/params.json b/api/params.json index 20424e99..ae7db175 100644 --- a/api/params.json +++ b/api/params.json @@ -153,6 +153,12 @@ "max": 10, "step": 1 }, + "minLength": { + "default": 150, + "min": 1, + "max": 1000, + "step": 1 + }, "model": { "default": "stable-diffusion-onnx-v1-5", "keys": [] diff --git a/gui/src/client/api.ts b/gui/src/client/api.ts index ed17b461..041d5ef8 100644 --- a/gui/src/client/api.ts +++ b/gui/src/client/api.ts @@ -125,7 +125,8 @@ export function makeImageJSON(params: ImageJSON): string { } if (doesExist(size)) { - body.size = { + body.params = { + ...body.params, width: size.width, height: size.height, }; @@ -174,9 +175,10 @@ export function makeImageJSON(params: ImageJSON): string { }, promptEditing: { enabled: experimental.promptEditing.enabled, + addSuffix: experimental.promptEditing.addSuffix, + minLength: experimental.promptEditing.minLength, promptFilter: experimental.promptEditing.filter, removeTokens: experimental.promptEditing.removeTokens, - addSuffix: experimental.promptEditing.addSuffix, }, }; } diff --git a/gui/src/components/control/ExperimentalControl.tsx b/gui/src/components/control/ExperimentalControl.tsx index 8f705ca9..089ce848 100644 --- a/gui/src/components/control/ExperimentalControl.tsx +++ b/gui/src/components/control/ExperimentalControl.tsx @@ -32,6 +32,87 @@ export function ExperimentalControl(props: ExperimentalControlProps) { }); return + + { + setExperimental({ + promptEditing: { + ...experimental.promptEditing, + enabled: experimental.promptEditing.enabled === false, + }, + }); + }} + />} + /> + f.prompt, + }} + value={mustDefault(experimental.promptEditing.filter, '')} + onChange={(prompt_filter) => { + setExperimental({ + promptEditing: { + ...experimental.promptEditing, + filter: prompt_filter, + }, + }); + }} + /> + { + setExperimental({ + promptEditing: { + ...experimental.promptEditing, + removeTokens: event.target.value, + }, + }); + }} + /> + { + setExperimental({ + promptEditing: { + ...experimental.promptEditing, + addSuffix: event.target.value, + }, + }); + }} + /> + { + setExperimental({ + promptEditing: { + ...experimental.promptEditing, + minLength: prompt_editing_min_length, + }, + }); + }} + /> + - - { - setExperimental({ - promptEditing: { - ...experimental.promptEditing, - enabled: experimental.promptEditing.enabled === false, - }, - }); - }} - />} - /> - f.prompt, - }} - value={mustDefault(experimental.promptEditing.filter, '')} - onChange={(prompt_filter) => { - setExperimental({ - promptEditing: { - ...experimental.promptEditing, - filter: prompt_filter, - }, - }); - }} - /> - { - setExperimental({ - promptEditing: { - ...experimental.promptEditing, - removeTokens: event.target.value, - }, - }); - }} - /> - { - setExperimental({ - promptEditing: { - ...experimental.promptEditing, - addSuffix: event.target.value, - }, - }); - }} - /> - ; } diff --git a/gui/src/config.json b/gui/src/config.json index 34dc5ff8..c4ca04a5 100644 --- a/gui/src/config.json +++ b/gui/src/config.json @@ -68,6 +68,7 @@ "default": "none", "keys": [] }, + "height": { "default": 512, "min": 128, @@ -113,6 +114,29 @@ "default": "", "keys": [] }, + "latentSymmetry": { + "enabled": { + "default": false + }, + "gradientStart": { + "default": 0.0, + "min": 0, + "max": 0.5, + "step": 0.01 + }, + "gradientEnd": { + "default": 0.25, + "min": 0.0, + "max": 0.5, + "step": 0.01 + }, + "lineOfSymmetry": { + "default": 0.5, + "min": 0, + "max": 1, + "step": 0.01 + } + }, "left": { "default": 0, "min": 0, @@ -159,6 +183,25 @@ "default": "an astronaut eating a hamburger", "keys": [] }, + "promptEditing": { + "default": false, + "addSuffix": { + "default": "" + }, + "filter": { + "default": "none", + "keys": [] + }, + "minLength": { + "default": 80, + "min": 1, + "max": 200, + "step": 1 + }, + "removeTokens": { + "default": "" + } + }, "right": { "default": 0, "min": 0, diff --git a/gui/src/config.ts b/gui/src/config.ts index 97e454e3..b4b57cd5 100644 --- a/gui/src/config.ts +++ b/gui/src/config.ts @@ -47,7 +47,12 @@ export type ConfigFiles = { * Map numbers and strings to their corresponding config types and drop the rest of the fields. */ export type ConfigRanges = { - [K in KeyFilter]: T[K] extends boolean ? ConfigBoolean : T[K] extends number ? ConfigNumber : T[K] extends string ? ConfigString : never; + [K in KeyFilter]: + T[K] extends boolean ? ConfigBoolean : + T[K] extends number ? ConfigNumber : + T[K] extends string ? ConfigString : + T[K] extends object ? ConfigRanges : + never; }; /** diff --git a/gui/src/state/full.ts b/gui/src/state/full.ts index 17aec4c0..a73ce139 100644 --- a/gui/src/state/full.ts +++ b/gui/src/state/full.ts @@ -129,15 +129,16 @@ export function createStateSlices(server: ServerParams) { const defaultExperimental: ExperimentalParams = { promptEditing: { enabled: false, - filter: '', - addSuffix : '', - removeTokens: '', + filter: server.promptEditing.filter.default, + addSuffix: server.promptEditing.addSuffix.default, + removeTokens: server.promptEditing.removeTokens.default, + minLength: server.promptEditing.minLength.default, }, latentSymmetry: { enabled: false, - gradientStart: 0, - gradientEnd: 0, - lineOfSymmetry: 0, + gradientStart: server.latentSymmetry.gradientStart.default, + gradientEnd: server.latentSymmetry.gradientEnd.default, + lineOfSymmetry: server.latentSymmetry.lineOfSymmetry.default, }, }; const defaultGrid: PipelineGrid = { diff --git a/gui/src/state/migration/default.ts b/gui/src/state/migration/default.ts index 06efcc3b..34791d3c 100644 --- a/gui/src/state/migration/default.ts +++ b/gui/src/state/migration/default.ts @@ -122,49 +122,45 @@ export function migrateV7ToV11(params: ServerParams, previousState: OnnxStateV7) export function migrateV11ToV13(params: ServerParams, previousState: OnnxStateV11): CurrentState { // add any missing keys + const defaultLatentSymmetry = { + enabled: params.latentSymmetry.enabled.default, + gradientStart: params.latentSymmetry.gradientStart.default, + gradientEnd: params.latentSymmetry.gradientEnd.default, + lineOfSymmetry: params.latentSymmetry.lineOfSymmetry.default, + }; + const defaultPromptEditing = { + enabled: params.promptEditing.enabled.default, + filter: params.promptEditing.filter.default, + addSuffix: params.promptEditing.addSuffix.default, + minLength: params.promptEditing.minLength.default, + removeTokens: params.promptEditing.removeTokens.default, + }; + const result: CurrentState = { ...params, ...previousState, txt2imgExperimental: { latentSymmetry: { - enabled: false, - gradientStart: 0, - gradientEnd: 0, - lineOfSymmetry: 0, + ...defaultLatentSymmetry, }, promptEditing: { - enabled: false, - filter: '', - addSuffix: '', - removeTokens: '', + ...defaultPromptEditing, }, }, img2imgExperimental: { latentSymmetry: { - enabled: false, - gradientStart: 0, - gradientEnd: 0, - lineOfSymmetry: 0, + ...defaultLatentSymmetry, }, promptEditing: { - enabled: false, - filter: '', - addSuffix: '', - removeTokens: '', + ...defaultPromptEditing, }, }, inpaintExperimental: { latentSymmetry: { - enabled: false, - gradientStart: 0, - gradientEnd: 0, - lineOfSymmetry: 0, + ...defaultLatentSymmetry, }, promptEditing: { - enabled: false, - filter: '', - addSuffix: '', - removeTokens: '', + ...defaultPromptEditing, }, }, }; diff --git a/gui/src/types/params.ts b/gui/src/types/params.ts index c009ba2d..b09a0bfe 100644 --- a/gui/src/types/params.ts +++ b/gui/src/types/params.ts @@ -181,5 +181,6 @@ export interface ExperimentalParams { filter: string; removeTokens: string; addSuffix: string; + minLength: number; }; }