diff --git a/api/onnx_web/chain/text_prompt.py b/api/onnx_web/chain/text_prompt.py
index 1dc69ee4..fb0bd523 100644
--- a/api/onnx_web/chain/text_prompt.py
+++ b/api/onnx_web/chain/text_prompt.py
@@ -1,6 +1,6 @@
from logging import getLogger
from random import randint
-from re import sub
+from re import match, sub
from typing import Optional
from transformers import pipeline
@@ -34,7 +34,7 @@ class TextPromptStage(BaseStage):
prompt_filter: str,
remove_tokens: Optional[str] = None,
add_suffix: Optional[str] = None,
- min_length: int = 150,
+ min_length: int = 80,
**kwargs,
) -> StageResult:
device = worker.device.torch_str()
@@ -69,7 +69,11 @@ class TextPromptStage(BaseStage):
logger.debug(
"removing excluded tokens from prompt: %s", remove_tokens
)
- prompt = sub(remove_tokens, "", prompt)
+
+ remove_limit = 3
+ while remove_limit > 0 and match(remove_tokens, prompt):
+ prompt = sub(remove_tokens, "", prompt)
+ remove_limit -= 1
if retries >= RETRY_LIMIT:
logger.warning(
diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py
index 9e0cc6f5..99b2a0f6 100644
--- a/api/onnx_web/diffusers/run.py
+++ b/api/onnx_web/diffusers/run.py
@@ -74,10 +74,10 @@ def add_prompt_filter(
pipeline.stage(
TextPromptStage(),
StageParams(),
- prompt_filter=experimental.prompt_editing.model,
- remove_tokens=experimental.prompt_editing.remove_tokens,
add_suffix=experimental.prompt_editing.add_suffix,
- # TODO: add min length to experimental params
+ min_length=experimental.prompt_editing.min_length,
+ prompt_filter=experimental.prompt_editing.filter,
+ remove_tokens=experimental.prompt_editing.remove_tokens,
)
else:
logger.warning("prompt editing is not supported by the server")
diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py
index 4d5ef00b..38fe535a 100644
--- a/api/onnx_web/params.py
+++ b/api/onnx_web/params.py
@@ -613,6 +613,7 @@ class PromptEditingParams:
filter: str
remove_tokens: str
add_suffix: str
+ min_length: int
def __init__(
self,
@@ -620,11 +621,13 @@ class PromptEditingParams:
filter: str,
remove_tokens: str,
add_suffix: str,
+ min_length: int,
) -> None:
self.enabled = enabled
self.filter = filter
self.remove_tokens = remove_tokens
self.add_suffix = add_suffix
+ self.min_length = min_length
class ExperimentalParams:
diff --git a/api/onnx_web/server/params.py b/api/onnx_web/server/params.py
index 6dc36726..0e2e7c98 100644
--- a/api/onnx_web/server/params.py
+++ b/api/onnx_web/server/params.py
@@ -395,8 +395,17 @@ def build_prompt_editing(
prompt_filter = data.get("promptFilter", "")
remove_tokens = data.get("removeTokens", "")
add_suffix = data.get("addSuffix", "")
+ min_length = get_and_clamp_int(
+ data,
+ "minLength",
+ get_config_value("minLength"),
+ get_config_value("minLength", "max"),
+ get_config_value("minLength", "min"),
+ )
- return PromptEditingParams(enabled, prompt_filter, remove_tokens, add_suffix)
+ return PromptEditingParams(
+ enabled, prompt_filter, remove_tokens, add_suffix, min_length
+ )
def build_experimental(
diff --git a/api/onnx_web/utils.py b/api/onnx_web/utils.py
index a888133d..5d36f918 100644
--- a/api/onnx_web/utils.py
+++ b/api/onnx_web/utils.py
@@ -1,5 +1,6 @@
import importlib
import json
+from functools import reduce
from hashlib import sha256
from json import JSONDecodeError
from logging import getLogger
@@ -29,6 +30,10 @@ def is_debug() -> bool:
return get_boolean(environ, "DEBUG", False)
+def recursive_get(d, *keys):
+ return reduce(lambda c, k: c.get(k, {}), keys, d)
+
+
def get_boolean(args: Any, key: str, default_value: bool) -> bool:
val = args.get(key, str(default_value))
diff --git a/api/params.json b/api/params.json
index 20424e99..ae7db175 100644
--- a/api/params.json
+++ b/api/params.json
@@ -153,6 +153,12 @@
"max": 10,
"step": 1
},
+ "minLength": {
+ "default": 150,
+ "min": 1,
+ "max": 1000,
+ "step": 1
+ },
"model": {
"default": "stable-diffusion-onnx-v1-5",
"keys": []
diff --git a/gui/src/client/api.ts b/gui/src/client/api.ts
index ed17b461..041d5ef8 100644
--- a/gui/src/client/api.ts
+++ b/gui/src/client/api.ts
@@ -125,7 +125,8 @@ export function makeImageJSON(params: ImageJSON): string {
}
if (doesExist(size)) {
- body.size = {
+ body.params = {
+ ...body.params,
width: size.width,
height: size.height,
};
@@ -174,9 +175,10 @@ export function makeImageJSON(params: ImageJSON): string {
},
promptEditing: {
enabled: experimental.promptEditing.enabled,
+ addSuffix: experimental.promptEditing.addSuffix,
+ minLength: experimental.promptEditing.minLength,
promptFilter: experimental.promptEditing.filter,
removeTokens: experimental.promptEditing.removeTokens,
- addSuffix: experimental.promptEditing.addSuffix,
},
};
}
diff --git a/gui/src/components/control/ExperimentalControl.tsx b/gui/src/components/control/ExperimentalControl.tsx
index 8f705ca9..089ce848 100644
--- a/gui/src/components/control/ExperimentalControl.tsx
+++ b/gui/src/components/control/ExperimentalControl.tsx
@@ -32,6 +32,87 @@ export function ExperimentalControl(props: ExperimentalControlProps) {
});
return
+
+ {
+ setExperimental({
+ promptEditing: {
+ ...experimental.promptEditing,
+ enabled: experimental.promptEditing.enabled === false,
+ },
+ });
+ }}
+ />}
+ />
+ f.prompt,
+ }}
+ value={mustDefault(experimental.promptEditing.filter, '')}
+ onChange={(prompt_filter) => {
+ setExperimental({
+ promptEditing: {
+ ...experimental.promptEditing,
+ filter: prompt_filter,
+ },
+ });
+ }}
+ />
+ {
+ setExperimental({
+ promptEditing: {
+ ...experimental.promptEditing,
+ removeTokens: event.target.value,
+ },
+ });
+ }}
+ />
+ {
+ setExperimental({
+ promptEditing: {
+ ...experimental.promptEditing,
+ addSuffix: event.target.value,
+ },
+ });
+ }}
+ />
+ {
+ setExperimental({
+ promptEditing: {
+ ...experimental.promptEditing,
+ minLength: prompt_editing_min_length,
+ },
+ });
+ }}
+ />
+
-
- {
- setExperimental({
- promptEditing: {
- ...experimental.promptEditing,
- enabled: experimental.promptEditing.enabled === false,
- },
- });
- }}
- />}
- />
- f.prompt,
- }}
- value={mustDefault(experimental.promptEditing.filter, '')}
- onChange={(prompt_filter) => {
- setExperimental({
- promptEditing: {
- ...experimental.promptEditing,
- filter: prompt_filter,
- },
- });
- }}
- />
- {
- setExperimental({
- promptEditing: {
- ...experimental.promptEditing,
- removeTokens: event.target.value,
- },
- });
- }}
- />
- {
- setExperimental({
- promptEditing: {
- ...experimental.promptEditing,
- addSuffix: event.target.value,
- },
- });
- }}
- />
-
;
}
diff --git a/gui/src/config.json b/gui/src/config.json
index 34dc5ff8..c4ca04a5 100644
--- a/gui/src/config.json
+++ b/gui/src/config.json
@@ -68,6 +68,7 @@
"default": "none",
"keys": []
},
+
"height": {
"default": 512,
"min": 128,
@@ -113,6 +114,29 @@
"default": "",
"keys": []
},
+ "latentSymmetry": {
+ "enabled": {
+ "default": false
+ },
+ "gradientStart": {
+ "default": 0.0,
+ "min": 0,
+ "max": 0.5,
+ "step": 0.01
+ },
+ "gradientEnd": {
+ "default": 0.25,
+ "min": 0.0,
+ "max": 0.5,
+ "step": 0.01
+ },
+ "lineOfSymmetry": {
+ "default": 0.5,
+ "min": 0,
+ "max": 1,
+ "step": 0.01
+ }
+ },
"left": {
"default": 0,
"min": 0,
@@ -159,6 +183,25 @@
"default": "an astronaut eating a hamburger",
"keys": []
},
+ "promptEditing": {
+ "default": false,
+ "addSuffix": {
+ "default": ""
+ },
+ "filter": {
+ "default": "none",
+ "keys": []
+ },
+ "minLength": {
+ "default": 80,
+ "min": 1,
+ "max": 200,
+ "step": 1
+ },
+ "removeTokens": {
+ "default": ""
+ }
+ },
"right": {
"default": 0,
"min": 0,
diff --git a/gui/src/config.ts b/gui/src/config.ts
index 97e454e3..b4b57cd5 100644
--- a/gui/src/config.ts
+++ b/gui/src/config.ts
@@ -47,7 +47,12 @@ export type ConfigFiles = {
* Map numbers and strings to their corresponding config types and drop the rest of the fields.
*/
export type ConfigRanges = {
- [K in KeyFilter]: T[K] extends boolean ? ConfigBoolean : T[K] extends number ? ConfigNumber : T[K] extends string ? ConfigString : never;
+ [K in KeyFilter]:
+ T[K] extends boolean ? ConfigBoolean :
+ T[K] extends number ? ConfigNumber :
+ T[K] extends string ? ConfigString :
+ T[K] extends object ? ConfigRanges :
+ never;
};
/**
diff --git a/gui/src/state/full.ts b/gui/src/state/full.ts
index 17aec4c0..a73ce139 100644
--- a/gui/src/state/full.ts
+++ b/gui/src/state/full.ts
@@ -129,15 +129,16 @@ export function createStateSlices(server: ServerParams) {
const defaultExperimental: ExperimentalParams = {
promptEditing: {
enabled: false,
- filter: '',
- addSuffix : '',
- removeTokens: '',
+ filter: server.promptEditing.filter.default,
+ addSuffix: server.promptEditing.addSuffix.default,
+ removeTokens: server.promptEditing.removeTokens.default,
+ minLength: server.promptEditing.minLength.default,
},
latentSymmetry: {
enabled: false,
- gradientStart: 0,
- gradientEnd: 0,
- lineOfSymmetry: 0,
+ gradientStart: server.latentSymmetry.gradientStart.default,
+ gradientEnd: server.latentSymmetry.gradientEnd.default,
+ lineOfSymmetry: server.latentSymmetry.lineOfSymmetry.default,
},
};
const defaultGrid: PipelineGrid = {
diff --git a/gui/src/state/migration/default.ts b/gui/src/state/migration/default.ts
index 06efcc3b..34791d3c 100644
--- a/gui/src/state/migration/default.ts
+++ b/gui/src/state/migration/default.ts
@@ -122,49 +122,45 @@ export function migrateV7ToV11(params: ServerParams, previousState: OnnxStateV7)
export function migrateV11ToV13(params: ServerParams, previousState: OnnxStateV11): CurrentState {
// add any missing keys
+ const defaultLatentSymmetry = {
+ enabled: params.latentSymmetry.enabled.default,
+ gradientStart: params.latentSymmetry.gradientStart.default,
+ gradientEnd: params.latentSymmetry.gradientEnd.default,
+ lineOfSymmetry: params.latentSymmetry.lineOfSymmetry.default,
+ };
+ const defaultPromptEditing = {
+ enabled: params.promptEditing.enabled.default,
+ filter: params.promptEditing.filter.default,
+ addSuffix: params.promptEditing.addSuffix.default,
+ minLength: params.promptEditing.minLength.default,
+ removeTokens: params.promptEditing.removeTokens.default,
+ };
+
const result: CurrentState = {
...params,
...previousState,
txt2imgExperimental: {
latentSymmetry: {
- enabled: false,
- gradientStart: 0,
- gradientEnd: 0,
- lineOfSymmetry: 0,
+ ...defaultLatentSymmetry,
},
promptEditing: {
- enabled: false,
- filter: '',
- addSuffix: '',
- removeTokens: '',
+ ...defaultPromptEditing,
},
},
img2imgExperimental: {
latentSymmetry: {
- enabled: false,
- gradientStart: 0,
- gradientEnd: 0,
- lineOfSymmetry: 0,
+ ...defaultLatentSymmetry,
},
promptEditing: {
- enabled: false,
- filter: '',
- addSuffix: '',
- removeTokens: '',
+ ...defaultPromptEditing,
},
},
inpaintExperimental: {
latentSymmetry: {
- enabled: false,
- gradientStart: 0,
- gradientEnd: 0,
- lineOfSymmetry: 0,
+ ...defaultLatentSymmetry,
},
promptEditing: {
- enabled: false,
- filter: '',
- addSuffix: '',
- removeTokens: '',
+ ...defaultPromptEditing,
},
},
};
diff --git a/gui/src/types/params.ts b/gui/src/types/params.ts
index c009ba2d..b09a0bfe 100644
--- a/gui/src/types/params.ts
+++ b/gui/src/types/params.ts
@@ -181,5 +181,6 @@ export interface ExperimentalParams {
filter: string;
removeTokens: string;
addSuffix: string;
+ minLength: number;
};
}