1
0
Fork 0

feat: add eta parameter (fixes #194)

This commit is contained in:
Sean Sube 2023-02-19 23:29:26 -06:00
parent f8cfc18479
commit c1189aad96
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
9 changed files with 64 additions and 18 deletions

View File

@ -91,6 +91,7 @@ def blend_inpaint(
negative_prompt=params.negative_prompt, negative_prompt=params.negative_prompt,
num_inference_steps=params.steps, num_inference_steps=params.steps,
width=size.width, width=size.width,
eta=params.eta,
callback=callback, callback=callback,
) )
else: else:
@ -106,6 +107,7 @@ def blend_inpaint(
negative_prompt=params.negative_prompt, negative_prompt=params.negative_prompt,
num_inference_steps=params.steps, num_inference_steps=params.steps,
width=size.width, width=size.width,
eta=params.eta,
callback=callback, callback=callback,
) )

View File

@ -88,5 +88,6 @@ def upscale_stable_diffusion(
source, source,
generator=generator, generator=generator,
num_inference_steps=params.steps, num_inference_steps=params.steps,
eta=params.eta,
callback=callback, callback=callback,
).images[0] ).images[0]

View File

@ -51,6 +51,7 @@ def run_txt2img_pipeline(
latents=latents, latents=latents,
negative_prompt=params.negative_prompt, negative_prompt=params.negative_prompt,
num_inference_steps=params.steps, num_inference_steps=params.steps,
eta=params.eta,
callback=progress, callback=progress,
) )
else: else:
@ -64,6 +65,7 @@ def run_txt2img_pipeline(
latents=latents, latents=latents,
negative_prompt=params.negative_prompt, negative_prompt=params.negative_prompt,
num_inference_steps=params.steps, num_inference_steps=params.steps,
eta=params.eta,
callback=progress, callback=progress,
) )
@ -119,6 +121,7 @@ def run_img2img_pipeline(
negative_prompt=params.negative_prompt, negative_prompt=params.negative_prompt,
num_inference_steps=params.steps, num_inference_steps=params.steps,
strength=strength, strength=strength,
eta=params.eta,
callback=progress, callback=progress,
) )
else: else:
@ -131,6 +134,7 @@ def run_img2img_pipeline(
negative_prompt=params.negative_prompt, negative_prompt=params.negative_prompt,
num_inference_steps=params.steps, num_inference_steps=params.steps,
strength=strength, strength=strength,
eta=params.eta,
callback=progress, callback=progress,
) )

View File

@ -154,7 +154,8 @@ class ImageParams:
steps: int, steps: int,
seed: int, seed: int,
negative_prompt: Optional[str] = None, negative_prompt: Optional[str] = None,
lpw: Optional[bool] = False, lpw: bool = False,
eta: float = 0.0,
) -> None: ) -> None:
self.model = model self.model = model
self.scheduler = scheduler self.scheduler = scheduler
@ -164,6 +165,7 @@ class ImageParams:
self.seed = seed self.seed = seed
self.steps = steps self.steps = steps
self.lpw = lpw or False self.lpw = lpw or False
self.eta = eta
def tojson(self) -> Dict[str, Optional[Param]]: def tojson(self) -> Dict[str, Optional[Param]]:
return { return {
@ -175,6 +177,7 @@ class ImageParams:
"seed": self.seed, "seed": self.seed,
"steps": self.steps, "steps": self.steps,
"lpw": self.lpw, "lpw": self.lpw,
"eta": self.eta,
} }
def with_args(self, **kwargs): def with_args(self, **kwargs):
@ -187,6 +190,7 @@ class ImageParams:
kwargs.get("seed", self.seed), kwargs.get("seed", self.seed),
kwargs.get("negative_prompt", self.negative_prompt), kwargs.get("negative_prompt", self.negative_prompt),
kwargs.get("lpw", self.lpw), kwargs.get("lpw", self.lpw),
kwargs.get("eta", self.eta),
) )

View File

@ -172,6 +172,13 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
get_config_value("cfg", "max"), get_config_value("cfg", "max"),
get_config_value("cfg", "min"), get_config_value("cfg", "min"),
) )
eta = get_and_clamp_float(
request.args,
"eta",
get_config_value("eta"),
get_config_value("eta", "max"),
get_config_value("eta", "min"),
)
steps = get_and_clamp_int( steps = get_and_clamp_int(
request.args, request.args,
"steps", "steps",
@ -220,6 +227,7 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
cfg, cfg,
steps, steps,
seed, seed,
eta=eta,
lpw=lpw, lpw=lpw,
negative_prompt=negative_prompt, negative_prompt=negative_prompt,
) )

View File

@ -22,6 +22,12 @@
"max": 1, "max": 1,
"step": 0.1 "step": 0.1
}, },
"eta": {
"default": 0.0,
"min": 0,
"max": 1,
"step": 0.1
},
"faceOutscale": { "faceOutscale": {
"default": 1, "default": 1,
"min": 1, "min": 1,

View File

@ -45,6 +45,7 @@ export interface BaseImgParams {
cfg: number; cfg: number;
steps: number; steps: number;
seed: number; seed: number;
eta: number;
} }
/** /**
@ -279,6 +280,7 @@ export function makeApiUrl(root: string, ...path: Array<string>) {
export function makeImageURL(root: string, type: string, params: BaseImgParams): URL { export function makeImageURL(root: string, type: string, params: BaseImgParams): URL {
const url = makeApiUrl(root, type); const url = makeApiUrl(root, type);
url.searchParams.append('cfg', params.cfg.toFixed(FIXED_FLOAT)); url.searchParams.append('cfg', params.cfg.toFixed(FIXED_FLOAT));
url.searchParams.append('eta', params.eta.toFixed(FIXED_FLOAT));
url.searchParams.append('steps', params.steps.toFixed(FIXED_INTEGER)); url.searchParams.append('steps', params.steps.toFixed(FIXED_INTEGER));
if (doesExist(params.scheduler)) { if (doesExist(params.scheduler)) {

View File

@ -34,23 +34,41 @@ export function ImageControl(props: ImageControlProps) {
}); });
return <Stack spacing={2}> return <Stack spacing={2}>
<QueryList <Stack direction='row' spacing={4}>
id='schedulers' <QueryList
labels={SCHEDULER_LABELS} id='schedulers'
name='Scheduler' labels={SCHEDULER_LABELS}
query={{ name='Scheduler'
result: schedulers, query={{
}} result: schedulers,
value={mustDefault(controlState.scheduler, '')} }}
onChange={(value) => { value={mustDefault(controlState.scheduler, '')}
if (doesExist(props.onChange)) { onChange={(value) => {
props.onChange({ if (doesExist(props.onChange)) {
...controlState, props.onChange({
scheduler: value, ...controlState,
}); scheduler: value,
} });
}} }
/> }}
/>
<NumericField
decimal
label='Eta'
min={params.eta.min}
max={params.eta.max}
step={params.eta.step}
value={controlState.eta}
onChange={(eta) => {
if (doesExist(props.onChange)) {
props.onChange({
...controlState,
eta,
});
}
}}
/>
</Stack>
<Stack direction='row' spacing={4}> <Stack direction='row' spacing={4}>
<NumericField <NumericField
decimal decimal

View File

@ -200,6 +200,7 @@ export const DEFAULT_HISTORY = {
export function baseParamsFromServer(defaults: ServerParams): Required<BaseImgParams> { export function baseParamsFromServer(defaults: ServerParams): Required<BaseImgParams> {
return { return {
cfg: defaults.cfg.default, cfg: defaults.cfg.default,
eta: defaults.eta.default,
negativePrompt: defaults.negativePrompt.default, negativePrompt: defaults.negativePrompt.default,
prompt: defaults.prompt.default, prompt: defaults.prompt.default,
scheduler: defaults.scheduler.default, scheduler: defaults.scheduler.default,