feat: add eta parameter (fixes #194)
This commit is contained in:
parent
f8cfc18479
commit
c1189aad96
|
@ -91,6 +91,7 @@ def blend_inpaint(
|
|||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
width=size.width,
|
||||
eta=params.eta,
|
||||
callback=callback,
|
||||
)
|
||||
else:
|
||||
|
@ -106,6 +107,7 @@ def blend_inpaint(
|
|||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
width=size.width,
|
||||
eta=params.eta,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
|
|
|
@ -88,5 +88,6 @@ def upscale_stable_diffusion(
|
|||
source,
|
||||
generator=generator,
|
||||
num_inference_steps=params.steps,
|
||||
eta=params.eta,
|
||||
callback=callback,
|
||||
).images[0]
|
||||
|
|
|
@ -51,6 +51,7 @@ def run_txt2img_pipeline(
|
|||
latents=latents,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
eta=params.eta,
|
||||
callback=progress,
|
||||
)
|
||||
else:
|
||||
|
@ -64,6 +65,7 @@ def run_txt2img_pipeline(
|
|||
latents=latents,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
eta=params.eta,
|
||||
callback=progress,
|
||||
)
|
||||
|
||||
|
@ -119,6 +121,7 @@ def run_img2img_pipeline(
|
|||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
strength=strength,
|
||||
eta=params.eta,
|
||||
callback=progress,
|
||||
)
|
||||
else:
|
||||
|
@ -131,6 +134,7 @@ def run_img2img_pipeline(
|
|||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
strength=strength,
|
||||
eta=params.eta,
|
||||
callback=progress,
|
||||
)
|
||||
|
||||
|
|
|
@ -154,7 +154,8 @@ class ImageParams:
|
|||
steps: int,
|
||||
seed: int,
|
||||
negative_prompt: Optional[str] = None,
|
||||
lpw: Optional[bool] = False,
|
||||
lpw: bool = False,
|
||||
eta: float = 0.0,
|
||||
) -> None:
|
||||
self.model = model
|
||||
self.scheduler = scheduler
|
||||
|
@ -164,6 +165,7 @@ class ImageParams:
|
|||
self.seed = seed
|
||||
self.steps = steps
|
||||
self.lpw = lpw or False
|
||||
self.eta = eta
|
||||
|
||||
def tojson(self) -> Dict[str, Optional[Param]]:
|
||||
return {
|
||||
|
@ -175,6 +177,7 @@ class ImageParams:
|
|||
"seed": self.seed,
|
||||
"steps": self.steps,
|
||||
"lpw": self.lpw,
|
||||
"eta": self.eta,
|
||||
}
|
||||
|
||||
def with_args(self, **kwargs):
|
||||
|
@ -187,6 +190,7 @@ class ImageParams:
|
|||
kwargs.get("seed", self.seed),
|
||||
kwargs.get("negative_prompt", self.negative_prompt),
|
||||
kwargs.get("lpw", self.lpw),
|
||||
kwargs.get("eta", self.eta),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -172,6 +172,13 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
|
|||
get_config_value("cfg", "max"),
|
||||
get_config_value("cfg", "min"),
|
||||
)
|
||||
eta = get_and_clamp_float(
|
||||
request.args,
|
||||
"eta",
|
||||
get_config_value("eta"),
|
||||
get_config_value("eta", "max"),
|
||||
get_config_value("eta", "min"),
|
||||
)
|
||||
steps = get_and_clamp_int(
|
||||
request.args,
|
||||
"steps",
|
||||
|
@ -220,6 +227,7 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
|
|||
cfg,
|
||||
steps,
|
||||
seed,
|
||||
eta=eta,
|
||||
lpw=lpw,
|
||||
negative_prompt=negative_prompt,
|
||||
)
|
||||
|
|
|
@ -22,6 +22,12 @@
|
|||
"max": 1,
|
||||
"step": 0.1
|
||||
},
|
||||
"eta": {
|
||||
"default": 0.0,
|
||||
"min": 0,
|
||||
"max": 1,
|
||||
"step": 0.1
|
||||
},
|
||||
"faceOutscale": {
|
||||
"default": 1,
|
||||
"min": 1,
|
||||
|
|
|
@ -45,6 +45,7 @@ export interface BaseImgParams {
|
|||
cfg: number;
|
||||
steps: number;
|
||||
seed: number;
|
||||
eta: number;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -279,6 +280,7 @@ export function makeApiUrl(root: string, ...path: Array<string>) {
|
|||
export function makeImageURL(root: string, type: string, params: BaseImgParams): URL {
|
||||
const url = makeApiUrl(root, type);
|
||||
url.searchParams.append('cfg', params.cfg.toFixed(FIXED_FLOAT));
|
||||
url.searchParams.append('eta', params.eta.toFixed(FIXED_FLOAT));
|
||||
url.searchParams.append('steps', params.steps.toFixed(FIXED_INTEGER));
|
||||
|
||||
if (doesExist(params.scheduler)) {
|
||||
|
|
|
@ -34,23 +34,41 @@ export function ImageControl(props: ImageControlProps) {
|
|||
});
|
||||
|
||||
return <Stack spacing={2}>
|
||||
<QueryList
|
||||
id='schedulers'
|
||||
labels={SCHEDULER_LABELS}
|
||||
name='Scheduler'
|
||||
query={{
|
||||
result: schedulers,
|
||||
}}
|
||||
value={mustDefault(controlState.scheduler, '')}
|
||||
onChange={(value) => {
|
||||
if (doesExist(props.onChange)) {
|
||||
props.onChange({
|
||||
...controlState,
|
||||
scheduler: value,
|
||||
});
|
||||
}
|
||||
}}
|
||||
/>
|
||||
<Stack direction='row' spacing={4}>
|
||||
<QueryList
|
||||
id='schedulers'
|
||||
labels={SCHEDULER_LABELS}
|
||||
name='Scheduler'
|
||||
query={{
|
||||
result: schedulers,
|
||||
}}
|
||||
value={mustDefault(controlState.scheduler, '')}
|
||||
onChange={(value) => {
|
||||
if (doesExist(props.onChange)) {
|
||||
props.onChange({
|
||||
...controlState,
|
||||
scheduler: value,
|
||||
});
|
||||
}
|
||||
}}
|
||||
/>
|
||||
<NumericField
|
||||
decimal
|
||||
label='Eta'
|
||||
min={params.eta.min}
|
||||
max={params.eta.max}
|
||||
step={params.eta.step}
|
||||
value={controlState.eta}
|
||||
onChange={(eta) => {
|
||||
if (doesExist(props.onChange)) {
|
||||
props.onChange({
|
||||
...controlState,
|
||||
eta,
|
||||
});
|
||||
}
|
||||
}}
|
||||
/>
|
||||
</Stack>
|
||||
<Stack direction='row' spacing={4}>
|
||||
<NumericField
|
||||
decimal
|
||||
|
|
|
@ -200,6 +200,7 @@ export const DEFAULT_HISTORY = {
|
|||
export function baseParamsFromServer(defaults: ServerParams): Required<BaseImgParams> {
|
||||
return {
|
||||
cfg: defaults.cfg.default,
|
||||
eta: defaults.eta.default,
|
||||
negativePrompt: defaults.negativePrompt.default,
|
||||
prompt: defaults.prompt.default,
|
||||
scheduler: defaults.scheduler.default,
|
||||
|
|
Loading…
Reference in New Issue