1
0
Fork 0

update API entrypoints for multi-image

This commit is contained in:
Sean Sube 2023-07-04 13:56:02 -05:00
parent e1fcbb9093
commit 99a073aed2
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 50 additions and 47 deletions

View File

@ -72,14 +72,11 @@ class ChainPipeline:
job: WorkerContext,
server: ServerContext,
params: ImageParams,
source: Optional[Image.Image],
sources: List[Image.Image],
callback: Optional[ProgressCallback],
**kwargs
) -> Image.Image:
"""
TODO: handle List[Image] inputs and outputs
"""
return self(job, server, params, source=source, callback=callback, **kwargs)
) -> List[Image.Image]:
return self(job, server, params, sources=sources, callback=callback, **kwargs)
def stage(self, callback: BaseStage, params: StageParams, **kwargs):
self.stages.append((callback, params, kwargs))
@ -161,10 +158,10 @@ class ChainPipeline:
server,
stage_params,
params,
source_tile,
[source_tile],
callback=callback,
**kwargs,
)
)[0]
if is_debug():
save_image(server, "last-tile.png", output_tile)
@ -176,7 +173,7 @@ class ChainPipeline:
source,
tile,
stage_params.outscale,
[stage_tile],
stage_tile,
**kwargs,
)
stage_outputs.append(output)

View File

@ -80,20 +80,22 @@ def run_txt2img_pipeline(
# run and save
progress = job.get_progress_callback()
image = chain(job, server, params, None, callback=progress)
images = chain(job, server, params, [], callback=progress)
_prompt_pairs, loras, inversions = parse_prompt(params)
dest = save_image(
server,
outputs[0],
image,
params,
size,
upscale=upscale,
highres=highres,
inversions=inversions,
loras=loras,
)
for image, output in zip(images, outputs):
dest = save_image(
server,
output,
image,
params,
size,
upscale=upscale,
highres=highres,
inversions=inversions,
loras=loras,
)
# clean up
run_gc([job.get_device()])
@ -170,7 +172,7 @@ def run_img2img_pipeline(
# run and append the filtered source
progress = job.get_progress_callback()
images = [
chain(job, server, params, source, callback=progress),
chain(job, server, params, [source], callback=progress),
]
if source_filter is not None and source_filter != "none":
@ -261,20 +263,21 @@ def run_inpaint_pipeline(
# run and save
progress = job.get_progress_callback()
image = chain(job, server, params, source, callback=progress)
images = chain(job, server, params, [source], callback=progress)
_prompt_pairs, loras, inversions = parse_prompt(params)
dest = save_image(
server,
outputs[0],
image,
params,
size,
upscale=upscale,
border=border,
inversions=inversions,
loras=loras,
)
for image, output in zip(images, outputs):
dest = save_image(
server,
output,
image,
params,
size,
upscale=upscale,
border=border,
inversions=inversions,
loras=loras,
)
# clean up
del image
@ -328,19 +331,20 @@ def run_upscale_pipeline(
# run and save
progress = job.get_progress_callback()
image = chain(job, server, params, source, callback=progress)
images = chain(job, server, params, [source], callback=progress)
_prompt_pairs, loras, inversions = parse_prompt(params)
dest = save_image(
server,
outputs[0],
image,
params,
size,
upscale=upscale,
inversions=inversions,
loras=loras,
)
for image, output in zip(images, outputs):
dest = save_image(
server,
output,
image,
params,
size,
upscale=upscale,
inversions=inversions,
loras=loras,
)
# clean up
del image
@ -377,8 +381,10 @@ def run_blend_pipeline(
# run and save
progress = job.get_progress_callback()
image = chain(job, server, params, sources[0], callback=progress)
dest = save_image(server, outputs[0], image, params, size, upscale=upscale)
images = chain(job, server, params, sources, callback=progress)
for image, output in zip(images, outputs):
dest = save_image(server, output, image, params, size, upscale=upscale)
# clean up
del image