1
0
Fork 0

add endpoint for multiple image generation

This commit is contained in:
Sean Sube 2023-09-10 16:35:16 -05:00
parent 956a260db6
commit 1a732d54b6
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 211 additions and 3 deletions

View File

@@ -158,6 +158,7 @@ def make_output_name(
size: Size,
extras: Optional[List[Optional[Param]]] = None,
count: Optional[int] = None,
offset: int = 0,
) -> List[str]:
count = count or params.batch
now = int(time())
@@ -183,7 +184,7 @@ def make_output_name(
return [
f"{mode}_{params.seed}_{sha.hexdigest()}_{now}_{i}.{server.image_format}"
-for i in range(count)
+for i in range(offset, count + offset)
]

View File

@@ -51,6 +51,7 @@ from .load import (
from .params import (
border_from_request,
highres_from_request,
pipeline_from_json,
pipeline_from_request,
upscale_from_request,
)
@@ -221,7 +222,7 @@ def txt2img(server: ServerContext, pool: DevicePoolExecutor):
replace_wildcards(params, get_wildcard_data())
-output = make_output_name(server, "txt2img", params, size)
+output = make_output_name(server, "txt2img", params, size, count=params.batch)
job_name = output[0]
pool.submit(
@@ -514,6 +515,61 @@ def txt2txt(server: ServerContext, pool: DevicePoolExecutor):
return jsonify(json_params(output, params, size))
def generate(server: ServerContext, pool: DevicePoolExecutor):
    """Queue a batch of txt2img and/or img2img jobs from one JSON request.

    Expects a JSON body matching ./schemas/generate.yaml, with optional
    "txt2img" and "img2img" lists of job parameter objects. Each job is
    converted to pipeline parameters via pipeline_from_json and submitted
    to the worker pool with a batch-unique output name.

    Returns an error reply when the request is not JSON; raises
    jsonschema.ValidationError when the body does not match the schema.
    """
    # fix: Flask's Request.is_json is a property, not a method — the
    # original `request.is_json()` raised TypeError on every request.
    if not request.is_json:
        return error_reply("generate endpoint requires JSON parameters")

    # TODO: should this accept YAML as well?
    data = request.get_json()
    schema = load_config("./schemas/generate.yaml")

    logger.debug("validating generate request: %s against %s", data, schema)
    validate(data, schema)

    jobs = []
    if "txt2img" in data:
        for job in data.get("txt2img"):
            device, params, size = pipeline_from_json(server, job, "txt2img")
            jobs.append((
                f"generate-txt2img-{len(jobs)}",
                run_txt2img_pipeline,
                server,
                params,
                size,
                # offset by the number of jobs queued so far so output
                # names stay unique across the whole batch
                make_output_name(server, "txt2img", params, size, offset=len(jobs)),
                None,
                None,
                device,
            ))

    if "img2img" in data:
        for job in data.get("img2img"):
            device, params, size = pipeline_from_json(server, job, "img2img")
            jobs.append((
                f"generate-img2img-{len(jobs)}",
                run_img2img_pipeline,
                server,
                params,
                size,
                # fix: a missing comma after this argument made the tuple
                # a syntax error in the original
                make_output_name(server, "img2img", params, size, offset=len(jobs)),
                None,
                None,
                device,
            ))

    for job in jobs:
        pool.submit(*job)

    # NOTE(review): this view returns None, which Flask rejects as a
    # response — confirm whether a jsonify(...) of the queued job params
    # should be returned here, as the other endpoints in this file do.
    # TODO: collect results
    # this is the hard part. once all of the jobs are done, the last job or some dedicated job
    # needs to collect the previous outputs and put them on a grid. jobs write their own
    # output to disk and do not return it, so that may need to read the images based on the
    # output names assigned to each job. knowing when the jobs are done is the first problem.
    # TODO: assemble grid
def cancel(server: ServerContext, pool: DevicePoolExecutor):
output_file = request.args.get("output", None)
if output_file is None:

View File

@@ -1,5 +1,5 @@
from logging import getLogger
from typing import Tuple
from typing import Any, Dict, Tuple
import numpy as np
from flask import request
@@ -34,6 +34,157 @@ from .utils import get_model_path
logger = getLogger(__name__)
def pipeline_from_json(
    server: ServerContext,
    data: Dict[str, Any],
    default_pipeline: str = "txt2img",
) -> Tuple[DeviceParams, ImageParams, Size]:
    """Build device, image params, and size from a JSON job description.

    Reads each field from *data*, falling back to the server config defaults
    and clamping numeric values into their configured min/max ranges.
    Returns a (device, params, size) tuple; device is None when no explicit
    platform was requested (meaning "any device").
    """

    # Small closures over `data` for the repeated clamp-to-config pattern.
    def clamp_int(key: str):
        return get_and_clamp_int(
            data,
            key,
            get_config_value(key),
            get_config_value(key, "max"),
            get_config_value(key, "min"),
        )

    def clamp_float(key: str):
        return get_and_clamp_float(
            data,
            key,
            get_config_value(key),
            get_config_value(key, "max"),
            get_config_value(key, "min"),
        )

    # explicit platform selection; "any" leaves device as None
    device = None
    device_name = data.get("platform")
    if device_name is not None and device_name != "any":
        for platform in get_available_platforms():
            if platform.device == device_name:
                device = platform

    # diffusion model
    model = get_not_empty(data, "model", get_config_value("model"))
    model_path = get_model_path(server, model)

    # pipeline stuff
    pipeline = get_from_list(
        data, "pipeline", get_available_pipelines(), default_pipeline
    )
    scheduler = get_from_list(data, "scheduler", get_pipeline_schedulers())
    if scheduler is None:
        scheduler = get_config_value("scheduler")

    # prompt does not come from config
    prompt = data.get("prompt", "")
    negative_prompt = data.get("negativePrompt", None)
    if negative_prompt is not None and negative_prompt.strip() == "":
        negative_prompt = None

    # image params, all clamped into their configured ranges
    batch = clamp_int("batch")
    cfg = clamp_float("cfg")
    eta = clamp_float("eta")
    loopback = clamp_int("loopback")
    steps = clamp_int("steps")
    height = clamp_int("height")
    width = clamp_int("width")

    # tiled VAE params
    tiled_vae = get_boolean(data, "tiledVAE", get_config_value("tiledVAE"))
    tiles = clamp_int("tiles")
    overlap = clamp_float("overlap")
    stride = clamp_int("stride")
    if stride > tiles:
        logger.info("limiting stride to tile size, %s > %s", stride, tiles)
        stride = tiles

    seed = int(data.get("seed", -1))
    if seed == -1:
        # this one can safely use np.random because it produces a single value
        seed = np.random.randint(np.iinfo(np.int32).max)

    logger.debug(
        "parsed parameters for %s steps of %s using %s in %s on %s, %sx%s, %s, %s - %s",
        steps,
        scheduler,
        model_path,
        pipeline,
        device or "any device",
        width,
        height,
        cfg,
        seed,
        prompt,
    )

    params = ImageParams(
        model_path,
        pipeline,
        scheduler,
        prompt,
        cfg,
        steps,
        seed,
        eta=eta,
        negative_prompt=negative_prompt,
        batch=batch,
        # TODO: control=control,
        loopback=loopback,
        tiled_vae=tiled_vae,
        tiles=tiles,
        overlap=overlap,
        stride=stride,
    )
    size = Size(width, height)
    return (device, params, size)
def pipeline_from_request(
server: ServerContext,
default_pipeline: str = "txt2img",