
feat(api): add batch size to txt2img and img2img pipelines (#195)

Sean Sube 2023-02-20 08:35:18 -06:00
parent 0deaa8898d
commit 5f3b84827b
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
8 changed files with 77 additions and 50 deletions
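In rough terms, the change threads a new batch count from the client through ImageParams and into the diffusers call as num_images_per_prompt, and widens each pipeline's single output filename into a list with one entry per generated image. A minimal sketch of the intended flow, using only names that appear in this diff and assuming an existing params, server, size, and pipe from the surrounding server code (the leading arguments of make_output_name are elided in the hunks below and assumed here):

    # illustrative sketch, not part of the commit
    params = params.with_args(batch=4)                            # ask for four images per prompt
    outputs = make_output_name(server, "txt2img", params, size)   # leading arguments assumed
    assert len(outputs) == params.batch                           # one filename per image

    result = pipe(
        params.prompt,
        num_images_per_prompt=params.batch,
        num_inference_steps=params.steps,
    )
    for image, output in zip(result.images, outputs):
        save_image(server, output, image)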

View File

@@ -1,4 +1,5 @@
from logging import getLogger
from typing import List
from PIL import Image
@@ -16,12 +17,12 @@ def persist_disk(
_params: ImageParams,
source: Image.Image,
*,
output: str,
output: List[str],
stage_source: Image.Image,
**kwargs,
) -> Image.Image:
source = stage_source or source
dest = save_image(server, output, source)
dest = save_image(server, output[0], source)
logger.info("saved image to %s", dest)
return source
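For the chain stage this is mostly a type-level change: persist_disk still writes a single image, so it takes the first name from the list. A hypothetical call, with the leading job/server/stage arguments that the hunk above elides assumed from the surrounding chain code:

    # sketch only: leading arguments and keyword defaults are assumptions
    persist_disk(job, server, stage, params, source, output=outputs)   # writes source to outputs[0]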

View File

@@ -25,7 +25,7 @@ def run_txt2img_pipeline(
server: ServerContext,
params: ImageParams,
size: Size,
output: str,
outputs: List[str],
upscale: UpscaleParams,
) -> None:
latents = get_latents_from_seed(params.seed, size)
@@ -50,6 +50,7 @@ def run_txt2img_pipeline(
guidance_scale=params.cfg,
latents=latents,
negative_prompt=params.negative_prompt,
num_images_per_prompt=params.batch,
num_inference_steps=params.steps,
eta=params.eta,
callback=progress,
@@ -64,27 +65,27 @@ def run_txt2img_pipeline(
guidance_scale=params.cfg,
latents=latents,
negative_prompt=params.negative_prompt,
num_images_per_prompt=params.batch,
num_inference_steps=params.steps,
eta=params.eta,
callback=progress,
)
image = result.images[0]
image = run_upscale_correction(
job,
server,
StageParams(),
params,
image,
upscale=upscale,
callback=progress,
)
for image, output in zip(result.images, outputs):
image = run_upscale_correction(
job,
server,
StageParams(),
params,
image,
upscale=upscale,
callback=progress,
)
dest = save_image(server, output, image)
save_params(server, output, params, size, upscale=upscale)
dest = save_image(server, output, image)
save_params(server, output, params, size, upscale=upscale)
del pipe
del image
del result
run_gc([job.get_device()])
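One detail of the new per-image loop is worth calling out: zip() pairs results with filenames positionally and stops at the shorter sequence, so if the pipeline ever returns fewer images than make_output_name produced (or the reverse), the extras are dropped silently instead of raising. Plain Python, for illustration:

    list(zip(["img0", "img1"], ["out0", "out1", "out2"]))
    # -> [("img0", "out0"), ("img1", "out1")]   # "out2" is ignored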
@@ -96,7 +97,7 @@ def run_img2img_pipeline(
job: JobContext,
server: ServerContext,
params: ImageParams,
output: str,
outputs: List[str],
upscale: UpscaleParams,
source: Image.Image,
strength: float,
@@ -119,6 +120,7 @@ def run_img2img_pipeline(
generator=rng,
guidance_scale=params.cfg,
negative_prompt=params.negative_prompt,
num_images_per_prompt=params.batch,
num_inference_steps=params.steps,
strength=strength,
eta=params.eta,
@@ -132,29 +134,29 @@ def run_img2img_pipeline(
generator=rng,
guidance_scale=params.cfg,
negative_prompt=params.negative_prompt,
num_images_per_prompt=params.batch,
num_inference_steps=params.steps,
strength=strength,
eta=params.eta,
callback=progress,
)
image = result.images[0]
image = run_upscale_correction(
job,
server,
StageParams(),
params,
image,
upscale=upscale,
callback=progress,
)
for image, output in zip(result.images, outputs):
image = run_upscale_correction(
job,
server,
StageParams(),
params,
image,
upscale=upscale,
callback=progress,
)
dest = save_image(server, output, image)
size = Size(*source.size)
save_params(server, output, params, size, upscale=upscale)
dest = save_image(server, output, image)
size = Size(*source.size)
save_params(server, output, params, size, upscale=upscale)
del pipe
del image
del result
run_gc([job.get_device()])
@@ -167,7 +169,7 @@ def run_inpaint_pipeline(
server: ServerContext,
params: ImageParams,
size: Size,
output: str,
outputs: List[str],
upscale: UpscaleParams,
source: Image.Image,
mask: Image.Image,
@@ -202,8 +204,8 @@ def run_inpaint_pipeline(
job, server, stage, params, image, upscale=upscale, callback=progress
)
dest = save_image(server, output, image)
save_params(server, output, params, size, upscale=upscale, border=border)
dest = save_image(server, outputs[0], image)
save_params(server, outputs[0], params, size, upscale=upscale, border=border)
del image
@@ -217,7 +219,7 @@ def run_upscale_pipeline(
server: ServerContext,
params: ImageParams,
size: Size,
output: str,
outputs: List[str],
upscale: UpscaleParams,
source: Image.Image,
) -> None:
@@ -228,8 +230,8 @@ def run_upscale_pipeline(
job, server, stage, params, source, upscale=upscale, callback=progress
)
dest = save_image(server, output, image)
save_params(server, output, params, size, upscale=upscale)
dest = save_image(server, outputs[0], image)
save_params(server, outputs[0], params, size, upscale=upscale)
del image
@@ -243,7 +245,7 @@ def run_blend_pipeline(
server: ServerContext,
params: ImageParams,
size: Size,
output: str,
outputs: List[str],
upscale: UpscaleParams,
sources: List[Image.Image],
mask: Image.Image,
@@ -266,8 +268,8 @@ def run_blend_pipeline(
job, server, stage, params, image, upscale=upscale, callback=progress
)
dest = save_image(server, output, image)
save_params(server, output, params, size, upscale=upscale)
dest = save_image(server, outputs[0], image)
save_params(server, outputs[0], params, size, upscale=upscale)
del image

View File

@@ -4,7 +4,7 @@ from logging import getLogger
from os import path
from struct import pack
from time import time
from typing import Any, Optional, Tuple
from typing import Any, List, Optional, Tuple
from PIL import Image
@@ -63,7 +63,7 @@ def make_output_name(
params: ImageParams,
size: Size,
extras: Optional[Tuple[Param]] = None,
) -> str:
) -> List[str]:
now = int(time())
sha = sha256()
@@ -82,13 +82,10 @@ def make_output_name(
for param in extras:
hash_value(sha, param)
return "%s_%s_%s_%s.%s" % (
mode,
params.seed,
sha.hexdigest(),
now,
ctx.image_format,
)
return [
f"{mode}_{params.seed}_{sha.hexdigest()}_{now}_{i}.{ctx.image_format}"
for i in range(params.batch)
]
def save_image(ctx: ServerContext, output: str, image: Image.Image) -> str:
@@ -106,7 +103,7 @@ def save_params(
upscale: Optional[UpscaleParams] = None,
border: Optional[Border] = None,
) -> str:
path = base_join(ctx.output_path, "%s.json" % (output))
path = base_join(ctx.output_path, f"{output}.json")
json = json_params(output, params, size, upscale=upscale, border=border)
with open(path, "w") as f:
f.write(dumps(json))
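For a batch of two, the new naming scheme therefore yields one entry per image, identical except for the trailing index. The seed, hash, timestamp, and extension below are made up for illustration, and the leading arguments of make_output_name are assumed as above:

    make_output_name(server, "txt2img", params, size)
    # -> ["txt2img_424089_1a2b3c4d_1676903718_0.png",
    #     "txt2img_424089_1a2b3c4d_1676903718_1.png"]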

View File

@@ -156,6 +156,7 @@ class ImageParams:
negative_prompt: Optional[str] = None,
lpw: bool = False,
eta: float = 0.0,
batch: int = 1,
) -> None:
self.model = model
self.scheduler = scheduler
@@ -166,6 +167,7 @@ class ImageParams:
self.steps = steps
self.lpw = lpw or False
self.eta = eta
self.batch = batch
def tojson(self) -> Dict[str, Optional[Param]]:
return {
@@ -178,6 +180,7 @@ class ImageParams:
"steps": self.steps,
"lpw": self.lpw,
"eta": self.eta,
"batch": self.batch,
}
def with_args(self, **kwargs):
@@ -191,6 +194,7 @@ class ImageParams:
kwargs.get("negative_prompt", self.negative_prompt),
kwargs.get("lpw", self.lpw),
kwargs.get("eta", self.eta),
kwargs.get("batch", self.batch),
)
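Because batch defaults to 1, existing callers keep their single-image behaviour, and with_args appears to build a fresh ImageParams with only the passed fields overridden. A hypothetical round trip:

    # sketch only
    batched = params.with_args(batch=3)
    batched.tojson()["batch"]   # -> 3
    params.batch                # -> unchanged; the original object is not modified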

View File

@@ -1,5 +1,11 @@
{
"version": "0.7.1",
"batch": {
"default": 1,
"min": 1,
"max": 5,
"step": 1
},
"bottom": {
"default": 0,
"min": 0,

View File

@@ -42,6 +42,7 @@ export interface BaseImgParams {
prompt: string;
negativePrompt?: string;
batch: number;
cfg: number;
steps: number;
seed: number;

View File

@@ -68,6 +68,21 @@ export function ImageControl(props: ImageControlProps) {
}
}}
/>
<NumericField
label='Batch Size'
min={params.batch.min}
max={params.batch.max}
step={params.batch.step}
value={controlState.batch}
onChange={(batch) => {
if (doesExist(props.onChange)) {
props.onChange({
...controlState,
batch,
});
}
}}
/>
</Stack>
<Stack direction='row' spacing={4}>
<NumericField

View File

@@ -199,6 +199,7 @@ export const DEFAULT_HISTORY = {
export function baseParamsFromServer(defaults: ServerParams): Required<BaseImgParams> {
return {
batch: defaults.batch.default,
cfg: defaults.cfg.default,
eta: defaults.eta.default,
negativePrompt: defaults.negativePrompt.default,