feat(api): add batch size to txt2img and img2img pipelines (#195)
This commit is contained in:
parent
0deaa8898d
commit
5f3b84827b
|
@ -1,4 +1,5 @@
|
|||
from logging import getLogger
|
||||
from typing import List
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
@ -16,12 +17,12 @@ def persist_disk(
|
|||
_params: ImageParams,
|
||||
source: Image.Image,
|
||||
*,
|
||||
output: str,
|
||||
output: List[str],
|
||||
stage_source: Image.Image,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
source = stage_source or source
|
||||
|
||||
dest = save_image(server, output, source)
|
||||
dest = save_image(server, output[0], source)
|
||||
logger.info("saved image to %s", dest)
|
||||
return source
|
||||
|
|
|
@ -25,7 +25,7 @@ def run_txt2img_pipeline(
|
|||
server: ServerContext,
|
||||
params: ImageParams,
|
||||
size: Size,
|
||||
output: str,
|
||||
outputs: List[str],
|
||||
upscale: UpscaleParams,
|
||||
) -> None:
|
||||
latents = get_latents_from_seed(params.seed, size)
|
||||
|
@ -50,6 +50,7 @@ def run_txt2img_pipeline(
|
|||
guidance_scale=params.cfg,
|
||||
latents=latents,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_images_per_prompt=params.batch,
|
||||
num_inference_steps=params.steps,
|
||||
eta=params.eta,
|
||||
callback=progress,
|
||||
|
@ -64,12 +65,13 @@ def run_txt2img_pipeline(
|
|||
guidance_scale=params.cfg,
|
||||
latents=latents,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_images_per_prompt=params.batch,
|
||||
num_inference_steps=params.steps,
|
||||
eta=params.eta,
|
||||
callback=progress,
|
||||
)
|
||||
|
||||
image = result.images[0]
|
||||
for image, output in zip(result.images, outputs):
|
||||
image = run_upscale_correction(
|
||||
job,
|
||||
server,
|
||||
|
@ -84,7 +86,6 @@ def run_txt2img_pipeline(
|
|||
save_params(server, output, params, size, upscale=upscale)
|
||||
|
||||
del pipe
|
||||
del image
|
||||
del result
|
||||
|
||||
run_gc([job.get_device()])
|
||||
|
@ -96,7 +97,7 @@ def run_img2img_pipeline(
|
|||
job: JobContext,
|
||||
server: ServerContext,
|
||||
params: ImageParams,
|
||||
output: str,
|
||||
outputs: List[str],
|
||||
upscale: UpscaleParams,
|
||||
source: Image.Image,
|
||||
strength: float,
|
||||
|
@ -119,6 +120,7 @@ def run_img2img_pipeline(
|
|||
generator=rng,
|
||||
guidance_scale=params.cfg,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_images_per_prompt=params.batch,
|
||||
num_inference_steps=params.steps,
|
||||
strength=strength,
|
||||
eta=params.eta,
|
||||
|
@ -132,13 +134,14 @@ def run_img2img_pipeline(
|
|||
generator=rng,
|
||||
guidance_scale=params.cfg,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_images_per_prompt=params.batch,
|
||||
num_inference_steps=params.steps,
|
||||
strength=strength,
|
||||
eta=params.eta,
|
||||
callback=progress,
|
||||
)
|
||||
|
||||
image = result.images[0]
|
||||
for image, output in zip(result.images, outputs):
|
||||
image = run_upscale_correction(
|
||||
job,
|
||||
server,
|
||||
|
@ -154,7 +157,6 @@ def run_img2img_pipeline(
|
|||
save_params(server, output, params, size, upscale=upscale)
|
||||
|
||||
del pipe
|
||||
del image
|
||||
del result
|
||||
|
||||
run_gc([job.get_device()])
|
||||
|
@ -167,7 +169,7 @@ def run_inpaint_pipeline(
|
|||
server: ServerContext,
|
||||
params: ImageParams,
|
||||
size: Size,
|
||||
output: str,
|
||||
outputs: List[str],
|
||||
upscale: UpscaleParams,
|
||||
source: Image.Image,
|
||||
mask: Image.Image,
|
||||
|
@ -202,8 +204,8 @@ def run_inpaint_pipeline(
|
|||
job, server, stage, params, image, upscale=upscale, callback=progress
|
||||
)
|
||||
|
||||
dest = save_image(server, output, image)
|
||||
save_params(server, output, params, size, upscale=upscale, border=border)
|
||||
dest = save_image(server, outputs[0], image)
|
||||
save_params(server, outputs[0], params, size, upscale=upscale, border=border)
|
||||
|
||||
del image
|
||||
|
||||
|
@ -217,7 +219,7 @@ def run_upscale_pipeline(
|
|||
server: ServerContext,
|
||||
params: ImageParams,
|
||||
size: Size,
|
||||
output: str,
|
||||
outputs: List[str],
|
||||
upscale: UpscaleParams,
|
||||
source: Image.Image,
|
||||
) -> None:
|
||||
|
@ -228,8 +230,8 @@ def run_upscale_pipeline(
|
|||
job, server, stage, params, source, upscale=upscale, callback=progress
|
||||
)
|
||||
|
||||
dest = save_image(server, output, image)
|
||||
save_params(server, output, params, size, upscale=upscale)
|
||||
dest = save_image(server, outputs[0], image)
|
||||
save_params(server, outputs[0], params, size, upscale=upscale)
|
||||
|
||||
del image
|
||||
|
||||
|
@ -243,7 +245,7 @@ def run_blend_pipeline(
|
|||
server: ServerContext,
|
||||
params: ImageParams,
|
||||
size: Size,
|
||||
output: str,
|
||||
outputs: List[str],
|
||||
upscale: UpscaleParams,
|
||||
sources: List[Image.Image],
|
||||
mask: Image.Image,
|
||||
|
@ -266,8 +268,8 @@ def run_blend_pipeline(
|
|||
job, server, stage, params, image, upscale=upscale, callback=progress
|
||||
)
|
||||
|
||||
dest = save_image(server, output, image)
|
||||
save_params(server, output, params, size, upscale=upscale)
|
||||
dest = save_image(server, outputs[0], image)
|
||||
save_params(server, outputs[0], params, size, upscale=upscale)
|
||||
|
||||
del image
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ from logging import getLogger
|
|||
from os import path
|
||||
from struct import pack
|
||||
from time import time
|
||||
from typing import Any, Optional, Tuple
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
@ -63,7 +63,7 @@ def make_output_name(
|
|||
params: ImageParams,
|
||||
size: Size,
|
||||
extras: Optional[Tuple[Param]] = None,
|
||||
) -> str:
|
||||
) -> List[str]:
|
||||
now = int(time())
|
||||
sha = sha256()
|
||||
|
||||
|
@ -82,13 +82,10 @@ def make_output_name(
|
|||
for param in extras:
|
||||
hash_value(sha, param)
|
||||
|
||||
return "%s_%s_%s_%s.%s" % (
|
||||
mode,
|
||||
params.seed,
|
||||
sha.hexdigest(),
|
||||
now,
|
||||
ctx.image_format,
|
||||
)
|
||||
return [
|
||||
f"{mode}_{params.seed}_{sha.hexdigest()}_{now}_{i}.{ctx.image_format}"
|
||||
for i in range(params.batch)
|
||||
]
|
||||
|
||||
|
||||
def save_image(ctx: ServerContext, output: str, image: Image.Image) -> str:
|
||||
|
@ -106,7 +103,7 @@ def save_params(
|
|||
upscale: Optional[UpscaleParams] = None,
|
||||
border: Optional[Border] = None,
|
||||
) -> str:
|
||||
path = base_join(ctx.output_path, "%s.json" % (output))
|
||||
path = base_join(ctx.output_path, f"{output}.json")
|
||||
json = json_params(output, params, size, upscale=upscale, border=border)
|
||||
with open(path, "w") as f:
|
||||
f.write(dumps(json))
|
||||
|
|
|
@ -156,6 +156,7 @@ class ImageParams:
|
|||
negative_prompt: Optional[str] = None,
|
||||
lpw: bool = False,
|
||||
eta: float = 0.0,
|
||||
batch: int = 1,
|
||||
) -> None:
|
||||
self.model = model
|
||||
self.scheduler = scheduler
|
||||
|
@ -166,6 +167,7 @@ class ImageParams:
|
|||
self.steps = steps
|
||||
self.lpw = lpw or False
|
||||
self.eta = eta
|
||||
self.batch = batch
|
||||
|
||||
def tojson(self) -> Dict[str, Optional[Param]]:
|
||||
return {
|
||||
|
@ -178,6 +180,7 @@ class ImageParams:
|
|||
"steps": self.steps,
|
||||
"lpw": self.lpw,
|
||||
"eta": self.eta,
|
||||
"batch": self.batch,
|
||||
}
|
||||
|
||||
def with_args(self, **kwargs):
|
||||
|
@ -191,6 +194,7 @@ class ImageParams:
|
|||
kwargs.get("negative_prompt", self.negative_prompt),
|
||||
kwargs.get("lpw", self.lpw),
|
||||
kwargs.get("eta", self.eta),
|
||||
kwargs.get("batch", self.batch),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -1,5 +1,11 @@
|
|||
{
|
||||
"version": "0.7.1",
|
||||
"batch": {
|
||||
"default": 1,
|
||||
"min": 1,
|
||||
"max": 5,
|
||||
"step": 1
|
||||
},
|
||||
"bottom": {
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
|
|
|
@ -42,6 +42,7 @@ export interface BaseImgParams {
|
|||
prompt: string;
|
||||
negativePrompt?: string;
|
||||
|
||||
batch: number;
|
||||
cfg: number;
|
||||
steps: number;
|
||||
seed: number;
|
||||
|
|
|
@ -68,6 +68,21 @@ export function ImageControl(props: ImageControlProps) {
|
|||
}
|
||||
}}
|
||||
/>
|
||||
<NumericField
|
||||
label='Batch Size'
|
||||
min={params.batch.min}
|
||||
max={params.batch.max}
|
||||
step={params.batch.step}
|
||||
value={controlState.batch}
|
||||
onChange={(batch) => {
|
||||
if (doesExist(props.onChange)) {
|
||||
props.onChange({
|
||||
...controlState,
|
||||
batch,
|
||||
});
|
||||
}
|
||||
}}
|
||||
/>
|
||||
</Stack>
|
||||
<Stack direction='row' spacing={4}>
|
||||
<NumericField
|
||||
|
|
|
@ -199,6 +199,7 @@ export const DEFAULT_HISTORY = {
|
|||
|
||||
export function baseParamsFromServer(defaults: ServerParams): Required<BaseImgParams> {
|
||||
return {
|
||||
batch: defaults.batch.default,
|
||||
cfg: defaults.cfg.default,
|
||||
eta: defaults.eta.default,
|
||||
negativePrompt: defaults.negativePrompt.default,
|
||||
|
|
Loading…
Reference in New Issue