feat: add eta parameter (fixes #194)
This commit is contained in:
parent
f8cfc18479
commit
c1189aad96
|
@ -91,6 +91,7 @@ def blend_inpaint(
|
||||||
negative_prompt=params.negative_prompt,
|
negative_prompt=params.negative_prompt,
|
||||||
num_inference_steps=params.steps,
|
num_inference_steps=params.steps,
|
||||||
width=size.width,
|
width=size.width,
|
||||||
|
eta=params.eta,
|
||||||
callback=callback,
|
callback=callback,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
@ -106,6 +107,7 @@ def blend_inpaint(
|
||||||
negative_prompt=params.negative_prompt,
|
negative_prompt=params.negative_prompt,
|
||||||
num_inference_steps=params.steps,
|
num_inference_steps=params.steps,
|
||||||
width=size.width,
|
width=size.width,
|
||||||
|
eta=params.eta,
|
||||||
callback=callback,
|
callback=callback,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -88,5 +88,6 @@ def upscale_stable_diffusion(
|
||||||
source,
|
source,
|
||||||
generator=generator,
|
generator=generator,
|
||||||
num_inference_steps=params.steps,
|
num_inference_steps=params.steps,
|
||||||
|
eta=params.eta,
|
||||||
callback=callback,
|
callback=callback,
|
||||||
).images[0]
|
).images[0]
|
||||||
|
|
|
@ -51,6 +51,7 @@ def run_txt2img_pipeline(
|
||||||
latents=latents,
|
latents=latents,
|
||||||
negative_prompt=params.negative_prompt,
|
negative_prompt=params.negative_prompt,
|
||||||
num_inference_steps=params.steps,
|
num_inference_steps=params.steps,
|
||||||
|
eta=params.eta,
|
||||||
callback=progress,
|
callback=progress,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
@ -64,6 +65,7 @@ def run_txt2img_pipeline(
|
||||||
latents=latents,
|
latents=latents,
|
||||||
negative_prompt=params.negative_prompt,
|
negative_prompt=params.negative_prompt,
|
||||||
num_inference_steps=params.steps,
|
num_inference_steps=params.steps,
|
||||||
|
eta=params.eta,
|
||||||
callback=progress,
|
callback=progress,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -119,6 +121,7 @@ def run_img2img_pipeline(
|
||||||
negative_prompt=params.negative_prompt,
|
negative_prompt=params.negative_prompt,
|
||||||
num_inference_steps=params.steps,
|
num_inference_steps=params.steps,
|
||||||
strength=strength,
|
strength=strength,
|
||||||
|
eta=params.eta,
|
||||||
callback=progress,
|
callback=progress,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
@ -131,6 +134,7 @@ def run_img2img_pipeline(
|
||||||
negative_prompt=params.negative_prompt,
|
negative_prompt=params.negative_prompt,
|
||||||
num_inference_steps=params.steps,
|
num_inference_steps=params.steps,
|
||||||
strength=strength,
|
strength=strength,
|
||||||
|
eta=params.eta,
|
||||||
callback=progress,
|
callback=progress,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -154,7 +154,8 @@ class ImageParams:
|
||||||
steps: int,
|
steps: int,
|
||||||
seed: int,
|
seed: int,
|
||||||
negative_prompt: Optional[str] = None,
|
negative_prompt: Optional[str] = None,
|
||||||
lpw: Optional[bool] = False,
|
lpw: bool = False,
|
||||||
|
eta: float = 0.0,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.model = model
|
self.model = model
|
||||||
self.scheduler = scheduler
|
self.scheduler = scheduler
|
||||||
|
@ -164,6 +165,7 @@ class ImageParams:
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
self.steps = steps
|
self.steps = steps
|
||||||
self.lpw = lpw or False
|
self.lpw = lpw or False
|
||||||
|
self.eta = eta
|
||||||
|
|
||||||
def tojson(self) -> Dict[str, Optional[Param]]:
|
def tojson(self) -> Dict[str, Optional[Param]]:
|
||||||
return {
|
return {
|
||||||
|
@ -175,6 +177,7 @@ class ImageParams:
|
||||||
"seed": self.seed,
|
"seed": self.seed,
|
||||||
"steps": self.steps,
|
"steps": self.steps,
|
||||||
"lpw": self.lpw,
|
"lpw": self.lpw,
|
||||||
|
"eta": self.eta,
|
||||||
}
|
}
|
||||||
|
|
||||||
def with_args(self, **kwargs):
|
def with_args(self, **kwargs):
|
||||||
|
@ -187,6 +190,7 @@ class ImageParams:
|
||||||
kwargs.get("seed", self.seed),
|
kwargs.get("seed", self.seed),
|
||||||
kwargs.get("negative_prompt", self.negative_prompt),
|
kwargs.get("negative_prompt", self.negative_prompt),
|
||||||
kwargs.get("lpw", self.lpw),
|
kwargs.get("lpw", self.lpw),
|
||||||
|
kwargs.get("eta", self.eta),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -172,6 +172,13 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
|
||||||
get_config_value("cfg", "max"),
|
get_config_value("cfg", "max"),
|
||||||
get_config_value("cfg", "min"),
|
get_config_value("cfg", "min"),
|
||||||
)
|
)
|
||||||
|
eta = get_and_clamp_float(
|
||||||
|
request.args,
|
||||||
|
"eta",
|
||||||
|
get_config_value("eta"),
|
||||||
|
get_config_value("eta", "max"),
|
||||||
|
get_config_value("eta", "min"),
|
||||||
|
)
|
||||||
steps = get_and_clamp_int(
|
steps = get_and_clamp_int(
|
||||||
request.args,
|
request.args,
|
||||||
"steps",
|
"steps",
|
||||||
|
@ -220,6 +227,7 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
|
||||||
cfg,
|
cfg,
|
||||||
steps,
|
steps,
|
||||||
seed,
|
seed,
|
||||||
|
eta=eta,
|
||||||
lpw=lpw,
|
lpw=lpw,
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
)
|
)
|
||||||
|
|
|
@ -22,6 +22,12 @@
|
||||||
"max": 1,
|
"max": 1,
|
||||||
"step": 0.1
|
"step": 0.1
|
||||||
},
|
},
|
||||||
|
"eta": {
|
||||||
|
"default": 0.0,
|
||||||
|
"min": 0,
|
||||||
|
"max": 1,
|
||||||
|
"step": 0.1
|
||||||
|
},
|
||||||
"faceOutscale": {
|
"faceOutscale": {
|
||||||
"default": 1,
|
"default": 1,
|
||||||
"min": 1,
|
"min": 1,
|
||||||
|
|
|
@ -45,6 +45,7 @@ export interface BaseImgParams {
|
||||||
cfg: number;
|
cfg: number;
|
||||||
steps: number;
|
steps: number;
|
||||||
seed: number;
|
seed: number;
|
||||||
|
eta: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -279,6 +280,7 @@ export function makeApiUrl(root: string, ...path: Array<string>) {
|
||||||
export function makeImageURL(root: string, type: string, params: BaseImgParams): URL {
|
export function makeImageURL(root: string, type: string, params: BaseImgParams): URL {
|
||||||
const url = makeApiUrl(root, type);
|
const url = makeApiUrl(root, type);
|
||||||
url.searchParams.append('cfg', params.cfg.toFixed(FIXED_FLOAT));
|
url.searchParams.append('cfg', params.cfg.toFixed(FIXED_FLOAT));
|
||||||
|
url.searchParams.append('eta', params.eta.toFixed(FIXED_FLOAT));
|
||||||
url.searchParams.append('steps', params.steps.toFixed(FIXED_INTEGER));
|
url.searchParams.append('steps', params.steps.toFixed(FIXED_INTEGER));
|
||||||
|
|
||||||
if (doesExist(params.scheduler)) {
|
if (doesExist(params.scheduler)) {
|
||||||
|
|
|
@ -34,23 +34,41 @@ export function ImageControl(props: ImageControlProps) {
|
||||||
});
|
});
|
||||||
|
|
||||||
return <Stack spacing={2}>
|
return <Stack spacing={2}>
|
||||||
<QueryList
|
<Stack direction='row' spacing={4}>
|
||||||
id='schedulers'
|
<QueryList
|
||||||
labels={SCHEDULER_LABELS}
|
id='schedulers'
|
||||||
name='Scheduler'
|
labels={SCHEDULER_LABELS}
|
||||||
query={{
|
name='Scheduler'
|
||||||
result: schedulers,
|
query={{
|
||||||
}}
|
result: schedulers,
|
||||||
value={mustDefault(controlState.scheduler, '')}
|
}}
|
||||||
onChange={(value) => {
|
value={mustDefault(controlState.scheduler, '')}
|
||||||
if (doesExist(props.onChange)) {
|
onChange={(value) => {
|
||||||
props.onChange({
|
if (doesExist(props.onChange)) {
|
||||||
...controlState,
|
props.onChange({
|
||||||
scheduler: value,
|
...controlState,
|
||||||
});
|
scheduler: value,
|
||||||
}
|
});
|
||||||
}}
|
}
|
||||||
/>
|
}}
|
||||||
|
/>
|
||||||
|
<NumericField
|
||||||
|
decimal
|
||||||
|
label='Eta'
|
||||||
|
min={params.eta.min}
|
||||||
|
max={params.eta.max}
|
||||||
|
step={params.eta.step}
|
||||||
|
value={controlState.eta}
|
||||||
|
onChange={(eta) => {
|
||||||
|
if (doesExist(props.onChange)) {
|
||||||
|
props.onChange({
|
||||||
|
...controlState,
|
||||||
|
eta,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</Stack>
|
||||||
<Stack direction='row' spacing={4}>
|
<Stack direction='row' spacing={4}>
|
||||||
<NumericField
|
<NumericField
|
||||||
decimal
|
decimal
|
||||||
|
|
|
@ -200,6 +200,7 @@ export const DEFAULT_HISTORY = {
|
||||||
export function baseParamsFromServer(defaults: ServerParams): Required<BaseImgParams> {
|
export function baseParamsFromServer(defaults: ServerParams): Required<BaseImgParams> {
|
||||||
return {
|
return {
|
||||||
cfg: defaults.cfg.default,
|
cfg: defaults.cfg.default,
|
||||||
|
eta: defaults.eta.default,
|
||||||
negativePrompt: defaults.negativePrompt.default,
|
negativePrompt: defaults.negativePrompt.default,
|
||||||
prompt: defaults.prompt.default,
|
prompt: defaults.prompt.default,
|
||||||
scheduler: defaults.scheduler.default,
|
scheduler: defaults.scheduler.default,
|
||||||
|
|
Loading…
Reference in New Issue