1
0
Fork 0

use experimental defaults from server

This commit is contained in:
Sean Sube 2024-02-24 11:06:32 -06:00
parent cdef20ffb6
commit 7743e8d738
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
13 changed files with 196 additions and 105 deletions

View File

@ -1,6 +1,6 @@
from logging import getLogger
from random import randint
from re import sub
from re import match, sub
from typing import Optional
from transformers import pipeline
@ -34,7 +34,7 @@ class TextPromptStage(BaseStage):
prompt_filter: str,
remove_tokens: Optional[str] = None,
add_suffix: Optional[str] = None,
min_length: int = 150,
min_length: int = 80,
**kwargs,
) -> StageResult:
device = worker.device.torch_str()
@ -69,7 +69,11 @@ class TextPromptStage(BaseStage):
logger.debug(
"removing excluded tokens from prompt: %s", remove_tokens
)
prompt = sub(remove_tokens, "", prompt)
remove_limit = 3
while remove_limit > 0 and match(remove_tokens, prompt):
prompt = sub(remove_tokens, "", prompt)
remove_limit -= 1
if retries >= RETRY_LIMIT:
logger.warning(

View File

@ -74,10 +74,10 @@ def add_prompt_filter(
pipeline.stage(
TextPromptStage(),
StageParams(),
prompt_filter=experimental.prompt_editing.model,
remove_tokens=experimental.prompt_editing.remove_tokens,
add_suffix=experimental.prompt_editing.add_suffix,
# TODO: add min length to experimental params
min_length=experimental.prompt_editing.min_length,
prompt_filter=experimental.prompt_editing.filter,
remove_tokens=experimental.prompt_editing.remove_tokens,
)
else:
logger.warning("prompt editing is not supported by the server")

View File

@ -613,6 +613,7 @@ class PromptEditingParams:
filter: str
remove_tokens: str
add_suffix: str
min_length: int
def __init__(
self,
@ -620,11 +621,13 @@ class PromptEditingParams:
filter: str,
remove_tokens: str,
add_suffix: str,
min_length: int,
) -> None:
self.enabled = enabled
self.filter = filter
self.remove_tokens = remove_tokens
self.add_suffix = add_suffix
self.min_length = min_length
class ExperimentalParams:

View File

@ -395,8 +395,17 @@ def build_prompt_editing(
prompt_filter = data.get("promptFilter", "")
remove_tokens = data.get("removeTokens", "")
add_suffix = data.get("addSuffix", "")
min_length = get_and_clamp_int(
data,
"minLength",
get_config_value("minLength"),
get_config_value("minLength", "max"),
get_config_value("minLength", "min"),
)
return PromptEditingParams(enabled, prompt_filter, remove_tokens, add_suffix)
return PromptEditingParams(
enabled, prompt_filter, remove_tokens, add_suffix, min_length
)
def build_experimental(

View File

@ -1,5 +1,6 @@
import importlib
import json
from functools import reduce
from hashlib import sha256
from json import JSONDecodeError
from logging import getLogger
@ -29,6 +30,10 @@ def is_debug() -> bool:
return get_boolean(environ, "DEBUG", False)
def recursive_get(d, *keys):
return reduce(lambda c, k: c.get(k, {}), keys, d)
def get_boolean(args: Any, key: str, default_value: bool) -> bool:
val = args.get(key, str(default_value))

View File

@ -153,6 +153,12 @@
"max": 10,
"step": 1
},
"minLength": {
"default": 150,
"min": 1,
"max": 1000,
"step": 1
},
"model": {
"default": "stable-diffusion-onnx-v1-5",
"keys": []

View File

@ -125,7 +125,8 @@ export function makeImageJSON(params: ImageJSON): string {
}
if (doesExist(size)) {
body.size = {
body.params = {
...body.params,
width: size.width,
height: size.height,
};
@ -174,9 +175,10 @@ export function makeImageJSON(params: ImageJSON): string {
},
promptEditing: {
enabled: experimental.promptEditing.enabled,
addSuffix: experimental.promptEditing.addSuffix,
minLength: experimental.promptEditing.minLength,
promptFilter: experimental.promptEditing.filter,
removeTokens: experimental.promptEditing.removeTokens,
addSuffix: experimental.promptEditing.addSuffix,
},
};
}

View File

@ -32,6 +32,87 @@ export function ExperimentalControl(props: ExperimentalControlProps) {
});
return <Stack spacing={STANDARD_SPACING}>
<Stack direction='row' spacing={STANDARD_SPACING}>
<FormControlLabel
label={t('experimental.prompt_editing.label')}
control={
<Checkbox
checked={experimental.promptEditing.enabled}
value='check'
onChange={(event) => {
setExperimental({
promptEditing: {
...experimental.promptEditing,
enabled: experimental.promptEditing.enabled === false,
},
});
}}
/>}
/>
<QueryList
disabled={experimental.promptEditing.enabled === false}
id='prompt_filters'
labelKey='model.prompt'
name={t('experimental.prompt_editing.filter')}
query={{
result: filters,
selector: (f) => f.prompt,
}}
value={mustDefault(experimental.promptEditing.filter, '')}
onChange={(prompt_filter) => {
setExperimental({
promptEditing: {
...experimental.promptEditing,
filter: prompt_filter,
},
});
}}
/>
<TextField
disabled={experimental.promptEditing.enabled === false}
label={t('experimental.prompt_editing.remove_tokens')}
variant='outlined'
value={experimental.promptEditing.removeTokens}
onChange={(event) => {
setExperimental({
promptEditing: {
...experimental.promptEditing,
removeTokens: event.target.value,
},
});
}}
/>
<TextField
disabled={experimental.promptEditing.enabled === false}
label={t('experimental.prompt_editing.add_suffix')}
variant='outlined'
value={experimental.promptEditing.addSuffix}
onChange={(event) => {
setExperimental({
promptEditing: {
...experimental.promptEditing,
addSuffix: event.target.value,
},
});
}}
/>
<NumericField
disabled={experimental.promptEditing.enabled === false}
label={t('experimental.prompt_editing.min_length')}
min={1}
max={1000}
step={1}
value={experimental.promptEditing.minLength}
onChange={(prompt_editing_min_length) => {
setExperimental({
promptEditing: {
...experimental.promptEditing,
minLength: prompt_editing_min_length,
},
});
}}
/>
</Stack>
<Stack direction='row' spacing={STANDARD_SPACING}>
<FormControlLabel
label={t('experimental.latent_symmetry.label')}
@ -101,70 +182,5 @@ export function ExperimentalControl(props: ExperimentalControlProps) {
}}
/>
</Stack>
<Stack direction='row' spacing={STANDARD_SPACING}>
<FormControlLabel
label={t('experimental.prompt_editing.label')}
control={
<Checkbox
checked={experimental.promptEditing.enabled}
value='check'
onChange={(event) => {
setExperimental({
promptEditing: {
...experimental.promptEditing,
enabled: experimental.promptEditing.enabled === false,
},
});
}}
/>}
/>
<QueryList
disabled={experimental.promptEditing.enabled === false}
id='prompt_filters'
labelKey='model.prompt'
name={t('experimental.prompt_editing.filter')}
query={{
result: filters,
selector: (f) => f.prompt,
}}
value={mustDefault(experimental.promptEditing.filter, '')}
onChange={(prompt_filter) => {
setExperimental({
promptEditing: {
...experimental.promptEditing,
filter: prompt_filter,
},
});
}}
/>
<TextField
disabled={experimental.promptEditing.enabled === false}
label={t('experimental.prompt_editing.remove_tokens')}
variant='outlined'
value={experimental.promptEditing.removeTokens}
onChange={(event) => {
setExperimental({
promptEditing: {
...experimental.promptEditing,
removeTokens: event.target.value,
},
});
}}
/>
<TextField
disabled={experimental.promptEditing.enabled === false}
label={t('experimental.prompt_editing.add_suffix')}
variant='outlined'
value={experimental.promptEditing.addSuffix}
onChange={(event) => {
setExperimental({
promptEditing: {
...experimental.promptEditing,
addSuffix: event.target.value,
},
});
}}
/>
</Stack>
</Stack>;
}

View File

@ -68,6 +68,7 @@
"default": "none",
"keys": []
},
"height": {
"default": 512,
"min": 128,
@ -113,6 +114,29 @@
"default": "",
"keys": []
},
"latentSymmetry": {
"enabled": {
"default": false
},
"gradientStart": {
"default": 0.0,
"min": 0,
"max": 0.5,
"step": 0.01
},
"gradientEnd": {
"default": 0.25,
"min": 0.0,
"max": 0.5,
"step": 0.01
},
"lineOfSymmetry": {
"default": 0.5,
"min": 0,
"max": 1,
"step": 0.01
}
},
"left": {
"default": 0,
"min": 0,
@ -159,6 +183,25 @@
"default": "an astronaut eating a hamburger",
"keys": []
},
"promptEditing": {
"default": false,
"addSuffix": {
"default": ""
},
"filter": {
"default": "none",
"keys": []
},
"minLength": {
"default": 80,
"min": 1,
"max": 200,
"step": 1
},
"removeTokens": {
"default": ""
}
},
"right": {
"default": 0,
"min": 0,

View File

@ -47,7 +47,12 @@ export type ConfigFiles<T extends object> = {
* Map numbers and strings to their corresponding config types and drop the rest of the fields.
*/
export type ConfigRanges<T extends object> = {
[K in KeyFilter<T>]: T[K] extends boolean ? ConfigBoolean : T[K] extends number ? ConfigNumber : T[K] extends string ? ConfigString : never;
[K in KeyFilter<T, boolean | number | string | object>]:
T[K] extends boolean ? ConfigBoolean :
T[K] extends number ? ConfigNumber :
T[K] extends string ? ConfigString :
T[K] extends object ? ConfigRanges<T[K]> :
never;
};
/**

View File

@ -129,15 +129,16 @@ export function createStateSlices(server: ServerParams) {
const defaultExperimental: ExperimentalParams = {
promptEditing: {
enabled: false,
filter: '',
addSuffix : '',
removeTokens: '',
filter: server.promptEditing.filter.default,
addSuffix: server.promptEditing.addSuffix.default,
removeTokens: server.promptEditing.removeTokens.default,
minLength: server.promptEditing.minLength.default,
},
latentSymmetry: {
enabled: false,
gradientStart: 0,
gradientEnd: 0,
lineOfSymmetry: 0,
gradientStart: server.latentSymmetry.gradientStart.default,
gradientEnd: server.latentSymmetry.gradientEnd.default,
lineOfSymmetry: server.latentSymmetry.lineOfSymmetry.default,
},
};
const defaultGrid: PipelineGrid = {

View File

@ -122,49 +122,45 @@ export function migrateV7ToV11(params: ServerParams, previousState: OnnxStateV7)
export function migrateV11ToV13(params: ServerParams, previousState: OnnxStateV11): CurrentState {
// add any missing keys
const defaultLatentSymmetry = {
enabled: params.latentSymmetry.enabled.default,
gradientStart: params.latentSymmetry.gradientStart.default,
gradientEnd: params.latentSymmetry.gradientEnd.default,
lineOfSymmetry: params.latentSymmetry.lineOfSymmetry.default,
};
const defaultPromptEditing = {
enabled: params.promptEditing.enabled.default,
filter: params.promptEditing.filter.default,
addSuffix: params.promptEditing.addSuffix.default,
minLength: params.promptEditing.minLength.default,
removeTokens: params.promptEditing.removeTokens.default,
};
const result: CurrentState = {
...params,
...previousState,
txt2imgExperimental: {
latentSymmetry: {
enabled: false,
gradientStart: 0,
gradientEnd: 0,
lineOfSymmetry: 0,
...defaultLatentSymmetry,
},
promptEditing: {
enabled: false,
filter: '',
addSuffix: '',
removeTokens: '',
...defaultPromptEditing,
},
},
img2imgExperimental: {
latentSymmetry: {
enabled: false,
gradientStart: 0,
gradientEnd: 0,
lineOfSymmetry: 0,
...defaultLatentSymmetry,
},
promptEditing: {
enabled: false,
filter: '',
addSuffix: '',
removeTokens: '',
...defaultPromptEditing,
},
},
inpaintExperimental: {
latentSymmetry: {
enabled: false,
gradientStart: 0,
gradientEnd: 0,
lineOfSymmetry: 0,
...defaultLatentSymmetry,
},
promptEditing: {
enabled: false,
filter: '',
addSuffix: '',
removeTokens: '',
...defaultPromptEditing,
},
},
};

View File

@ -181,5 +181,6 @@ export interface ExperimentalParams {
filter: string;
removeTokens: string;
addSuffix: string;
minLength: number;
};
}