1
0
Fork 0

use experimental defaults from server

This commit is contained in:
Sean Sube 2024-02-24 11:06:32 -06:00
parent cdef20ffb6
commit 7743e8d738
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
13 changed files with 196 additions and 105 deletions

View File

@ -1,6 +1,6 @@
from logging import getLogger from logging import getLogger
from random import randint from random import randint
from re import sub from re import match, sub
from typing import Optional from typing import Optional
from transformers import pipeline from transformers import pipeline
@ -34,7 +34,7 @@ class TextPromptStage(BaseStage):
prompt_filter: str, prompt_filter: str,
remove_tokens: Optional[str] = None, remove_tokens: Optional[str] = None,
add_suffix: Optional[str] = None, add_suffix: Optional[str] = None,
min_length: int = 150, min_length: int = 80,
**kwargs, **kwargs,
) -> StageResult: ) -> StageResult:
device = worker.device.torch_str() device = worker.device.torch_str()
@ -69,7 +69,11 @@ class TextPromptStage(BaseStage):
logger.debug( logger.debug(
"removing excluded tokens from prompt: %s", remove_tokens "removing excluded tokens from prompt: %s", remove_tokens
) )
prompt = sub(remove_tokens, "", prompt)
remove_limit = 3
while remove_limit > 0 and match(remove_tokens, prompt):
prompt = sub(remove_tokens, "", prompt)
remove_limit -= 1
if retries >= RETRY_LIMIT: if retries >= RETRY_LIMIT:
logger.warning( logger.warning(

View File

@ -74,10 +74,10 @@ def add_prompt_filter(
pipeline.stage( pipeline.stage(
TextPromptStage(), TextPromptStage(),
StageParams(), StageParams(),
prompt_filter=experimental.prompt_editing.model,
remove_tokens=experimental.prompt_editing.remove_tokens,
add_suffix=experimental.prompt_editing.add_suffix, add_suffix=experimental.prompt_editing.add_suffix,
# TODO: add min length to experimental params min_length=experimental.prompt_editing.min_length,
prompt_filter=experimental.prompt_editing.filter,
remove_tokens=experimental.prompt_editing.remove_tokens,
) )
else: else:
logger.warning("prompt editing is not supported by the server") logger.warning("prompt editing is not supported by the server")

View File

@ -613,6 +613,7 @@ class PromptEditingParams:
filter: str filter: str
remove_tokens: str remove_tokens: str
add_suffix: str add_suffix: str
min_length: int
def __init__( def __init__(
self, self,
@ -620,11 +621,13 @@ class PromptEditingParams:
filter: str, filter: str,
remove_tokens: str, remove_tokens: str,
add_suffix: str, add_suffix: str,
min_length: int,
) -> None: ) -> None:
self.enabled = enabled self.enabled = enabled
self.filter = filter self.filter = filter
self.remove_tokens = remove_tokens self.remove_tokens = remove_tokens
self.add_suffix = add_suffix self.add_suffix = add_suffix
self.min_length = min_length
class ExperimentalParams: class ExperimentalParams:

View File

@ -395,8 +395,17 @@ def build_prompt_editing(
prompt_filter = data.get("promptFilter", "") prompt_filter = data.get("promptFilter", "")
remove_tokens = data.get("removeTokens", "") remove_tokens = data.get("removeTokens", "")
add_suffix = data.get("addSuffix", "") add_suffix = data.get("addSuffix", "")
min_length = get_and_clamp_int(
data,
"minLength",
get_config_value("minLength"),
get_config_value("minLength", "max"),
get_config_value("minLength", "min"),
)
return PromptEditingParams(enabled, prompt_filter, remove_tokens, add_suffix) return PromptEditingParams(
enabled, prompt_filter, remove_tokens, add_suffix, min_length
)
def build_experimental( def build_experimental(

View File

@ -1,5 +1,6 @@
import importlib import importlib
import json import json
from functools import reduce
from hashlib import sha256 from hashlib import sha256
from json import JSONDecodeError from json import JSONDecodeError
from logging import getLogger from logging import getLogger
@ -29,6 +30,10 @@ def is_debug() -> bool:
return get_boolean(environ, "DEBUG", False) return get_boolean(environ, "DEBUG", False)
def recursive_get(d, *keys):
return reduce(lambda c, k: c.get(k, {}), keys, d)
def get_boolean(args: Any, key: str, default_value: bool) -> bool: def get_boolean(args: Any, key: str, default_value: bool) -> bool:
val = args.get(key, str(default_value)) val = args.get(key, str(default_value))

View File

@ -153,6 +153,12 @@
"max": 10, "max": 10,
"step": 1 "step": 1
}, },
"minLength": {
"default": 150,
"min": 1,
"max": 1000,
"step": 1
},
"model": { "model": {
"default": "stable-diffusion-onnx-v1-5", "default": "stable-diffusion-onnx-v1-5",
"keys": [] "keys": []

View File

@ -125,7 +125,8 @@ export function makeImageJSON(params: ImageJSON): string {
} }
if (doesExist(size)) { if (doesExist(size)) {
body.size = { body.params = {
...body.params,
width: size.width, width: size.width,
height: size.height, height: size.height,
}; };
@ -174,9 +175,10 @@ export function makeImageJSON(params: ImageJSON): string {
}, },
promptEditing: { promptEditing: {
enabled: experimental.promptEditing.enabled, enabled: experimental.promptEditing.enabled,
addSuffix: experimental.promptEditing.addSuffix,
minLength: experimental.promptEditing.minLength,
promptFilter: experimental.promptEditing.filter, promptFilter: experimental.promptEditing.filter,
removeTokens: experimental.promptEditing.removeTokens, removeTokens: experimental.promptEditing.removeTokens,
addSuffix: experimental.promptEditing.addSuffix,
}, },
}; };
} }

View File

@ -32,6 +32,87 @@ export function ExperimentalControl(props: ExperimentalControlProps) {
}); });
return <Stack spacing={STANDARD_SPACING}> return <Stack spacing={STANDARD_SPACING}>
<Stack direction='row' spacing={STANDARD_SPACING}>
<FormControlLabel
label={t('experimental.prompt_editing.label')}
control={
<Checkbox
checked={experimental.promptEditing.enabled}
value='check'
onChange={(event) => {
setExperimental({
promptEditing: {
...experimental.promptEditing,
enabled: experimental.promptEditing.enabled === false,
},
});
}}
/>}
/>
<QueryList
disabled={experimental.promptEditing.enabled === false}
id='prompt_filters'
labelKey='model.prompt'
name={t('experimental.prompt_editing.filter')}
query={{
result: filters,
selector: (f) => f.prompt,
}}
value={mustDefault(experimental.promptEditing.filter, '')}
onChange={(prompt_filter) => {
setExperimental({
promptEditing: {
...experimental.promptEditing,
filter: prompt_filter,
},
});
}}
/>
<TextField
disabled={experimental.promptEditing.enabled === false}
label={t('experimental.prompt_editing.remove_tokens')}
variant='outlined'
value={experimental.promptEditing.removeTokens}
onChange={(event) => {
setExperimental({
promptEditing: {
...experimental.promptEditing,
removeTokens: event.target.value,
},
});
}}
/>
<TextField
disabled={experimental.promptEditing.enabled === false}
label={t('experimental.prompt_editing.add_suffix')}
variant='outlined'
value={experimental.promptEditing.addSuffix}
onChange={(event) => {
setExperimental({
promptEditing: {
...experimental.promptEditing,
addSuffix: event.target.value,
},
});
}}
/>
<NumericField
disabled={experimental.promptEditing.enabled === false}
label={t('experimental.prompt_editing.min_length')}
min={1}
max={1000}
step={1}
value={experimental.promptEditing.minLength}
onChange={(prompt_editing_min_length) => {
setExperimental({
promptEditing: {
...experimental.promptEditing,
minLength: prompt_editing_min_length,
},
});
}}
/>
</Stack>
<Stack direction='row' spacing={STANDARD_SPACING}> <Stack direction='row' spacing={STANDARD_SPACING}>
<FormControlLabel <FormControlLabel
label={t('experimental.latent_symmetry.label')} label={t('experimental.latent_symmetry.label')}
@ -101,70 +182,5 @@ export function ExperimentalControl(props: ExperimentalControlProps) {
}} }}
/> />
</Stack> </Stack>
<Stack direction='row' spacing={STANDARD_SPACING}>
<FormControlLabel
label={t('experimental.prompt_editing.label')}
control={
<Checkbox
checked={experimental.promptEditing.enabled}
value='check'
onChange={(event) => {
setExperimental({
promptEditing: {
...experimental.promptEditing,
enabled: experimental.promptEditing.enabled === false,
},
});
}}
/>}
/>
<QueryList
disabled={experimental.promptEditing.enabled === false}
id='prompt_filters'
labelKey='model.prompt'
name={t('experimental.prompt_editing.filter')}
query={{
result: filters,
selector: (f) => f.prompt,
}}
value={mustDefault(experimental.promptEditing.filter, '')}
onChange={(prompt_filter) => {
setExperimental({
promptEditing: {
...experimental.promptEditing,
filter: prompt_filter,
},
});
}}
/>
<TextField
disabled={experimental.promptEditing.enabled === false}
label={t('experimental.prompt_editing.remove_tokens')}
variant='outlined'
value={experimental.promptEditing.removeTokens}
onChange={(event) => {
setExperimental({
promptEditing: {
...experimental.promptEditing,
removeTokens: event.target.value,
},
});
}}
/>
<TextField
disabled={experimental.promptEditing.enabled === false}
label={t('experimental.prompt_editing.add_suffix')}
variant='outlined'
value={experimental.promptEditing.addSuffix}
onChange={(event) => {
setExperimental({
promptEditing: {
...experimental.promptEditing,
addSuffix: event.target.value,
},
});
}}
/>
</Stack>
</Stack>; </Stack>;
} }

View File

@ -68,6 +68,7 @@
"default": "none", "default": "none",
"keys": [] "keys": []
}, },
"height": { "height": {
"default": 512, "default": 512,
"min": 128, "min": 128,
@ -113,6 +114,29 @@
"default": "", "default": "",
"keys": [] "keys": []
}, },
"latentSymmetry": {
"enabled": {
"default": false
},
"gradientStart": {
"default": 0.0,
"min": 0,
"max": 0.5,
"step": 0.01
},
"gradientEnd": {
"default": 0.25,
"min": 0.0,
"max": 0.5,
"step": 0.01
},
"lineOfSymmetry": {
"default": 0.5,
"min": 0,
"max": 1,
"step": 0.01
}
},
"left": { "left": {
"default": 0, "default": 0,
"min": 0, "min": 0,
@ -159,6 +183,25 @@
"default": "an astronaut eating a hamburger", "default": "an astronaut eating a hamburger",
"keys": [] "keys": []
}, },
"promptEditing": {
"default": false,
"addSuffix": {
"default": ""
},
"filter": {
"default": "none",
"keys": []
},
"minLength": {
"default": 80,
"min": 1,
"max": 200,
"step": 1
},
"removeTokens": {
"default": ""
}
},
"right": { "right": {
"default": 0, "default": 0,
"min": 0, "min": 0,

View File

@ -47,7 +47,12 @@ export type ConfigFiles<T extends object> = {
* Map numbers and strings to their corresponding config types and drop the rest of the fields. * Map numbers and strings to their corresponding config types and drop the rest of the fields.
*/ */
export type ConfigRanges<T extends object> = { export type ConfigRanges<T extends object> = {
[K in KeyFilter<T>]: T[K] extends boolean ? ConfigBoolean : T[K] extends number ? ConfigNumber : T[K] extends string ? ConfigString : never; [K in KeyFilter<T, boolean | number | string | object>]:
T[K] extends boolean ? ConfigBoolean :
T[K] extends number ? ConfigNumber :
T[K] extends string ? ConfigString :
T[K] extends object ? ConfigRanges<T[K]> :
never;
}; };
/** /**

View File

@ -129,15 +129,16 @@ export function createStateSlices(server: ServerParams) {
const defaultExperimental: ExperimentalParams = { const defaultExperimental: ExperimentalParams = {
promptEditing: { promptEditing: {
enabled: false, enabled: false,
filter: '', filter: server.promptEditing.filter.default,
addSuffix : '', addSuffix: server.promptEditing.addSuffix.default,
removeTokens: '', removeTokens: server.promptEditing.removeTokens.default,
minLength: server.promptEditing.minLength.default,
}, },
latentSymmetry: { latentSymmetry: {
enabled: false, enabled: false,
gradientStart: 0, gradientStart: server.latentSymmetry.gradientStart.default,
gradientEnd: 0, gradientEnd: server.latentSymmetry.gradientEnd.default,
lineOfSymmetry: 0, lineOfSymmetry: server.latentSymmetry.lineOfSymmetry.default,
}, },
}; };
const defaultGrid: PipelineGrid = { const defaultGrid: PipelineGrid = {

View File

@ -122,49 +122,45 @@ export function migrateV7ToV11(params: ServerParams, previousState: OnnxStateV7)
export function migrateV11ToV13(params: ServerParams, previousState: OnnxStateV11): CurrentState { export function migrateV11ToV13(params: ServerParams, previousState: OnnxStateV11): CurrentState {
// add any missing keys // add any missing keys
const defaultLatentSymmetry = {
enabled: params.latentSymmetry.enabled.default,
gradientStart: params.latentSymmetry.gradientStart.default,
gradientEnd: params.latentSymmetry.gradientEnd.default,
lineOfSymmetry: params.latentSymmetry.lineOfSymmetry.default,
};
const defaultPromptEditing = {
enabled: params.promptEditing.enabled.default,
filter: params.promptEditing.filter.default,
addSuffix: params.promptEditing.addSuffix.default,
minLength: params.promptEditing.minLength.default,
removeTokens: params.promptEditing.removeTokens.default,
};
const result: CurrentState = { const result: CurrentState = {
...params, ...params,
...previousState, ...previousState,
txt2imgExperimental: { txt2imgExperimental: {
latentSymmetry: { latentSymmetry: {
enabled: false, ...defaultLatentSymmetry,
gradientStart: 0,
gradientEnd: 0,
lineOfSymmetry: 0,
}, },
promptEditing: { promptEditing: {
enabled: false, ...defaultPromptEditing,
filter: '',
addSuffix: '',
removeTokens: '',
}, },
}, },
img2imgExperimental: { img2imgExperimental: {
latentSymmetry: { latentSymmetry: {
enabled: false, ...defaultLatentSymmetry,
gradientStart: 0,
gradientEnd: 0,
lineOfSymmetry: 0,
}, },
promptEditing: { promptEditing: {
enabled: false, ...defaultPromptEditing,
filter: '',
addSuffix: '',
removeTokens: '',
}, },
}, },
inpaintExperimental: { inpaintExperimental: {
latentSymmetry: { latentSymmetry: {
enabled: false, ...defaultLatentSymmetry,
gradientStart: 0,
gradientEnd: 0,
lineOfSymmetry: 0,
}, },
promptEditing: { promptEditing: {
enabled: false, ...defaultPromptEditing,
filter: '',
addSuffix: '',
removeTokens: '',
}, },
}, },
}; };

View File

@ -181,5 +181,6 @@ export interface ExperimentalParams {
filter: string; filter: string;
removeTokens: string; removeTokens: string;
addSuffix: string; addSuffix: string;
minLength: number;
}; };
} }