1
0
Fork 0

add endpoint for multiple image generation

This commit is contained in:
Sean Sube 2023-09-10 16:35:16 -05:00
parent 956a260db6
commit 1a732d54b6
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 211 additions and 3 deletions

View File

@ -158,6 +158,7 @@ def make_output_name(
size: Size, size: Size,
extras: Optional[List[Optional[Param]]] = None, extras: Optional[List[Optional[Param]]] = None,
count: Optional[int] = None, count: Optional[int] = None,
offset: int = 0,
) -> List[str]: ) -> List[str]:
count = count or params.batch count = count or params.batch
now = int(time()) now = int(time())
@ -183,7 +184,7 @@ def make_output_name(
return [ return [
f"{mode}_{params.seed}_{sha.hexdigest()}_{now}_{i}.{server.image_format}" f"{mode}_{params.seed}_{sha.hexdigest()}_{now}_{i}.{server.image_format}"
for i in range(count) for i in range(offset, count + offset)
] ]

View File

@ -51,6 +51,7 @@ from .load import (
from .params import ( from .params import (
border_from_request, border_from_request,
highres_from_request, highres_from_request,
pipeline_from_json,
pipeline_from_request, pipeline_from_request,
upscale_from_request, upscale_from_request,
) )
@ -221,7 +222,7 @@ def txt2img(server: ServerContext, pool: DevicePoolExecutor):
replace_wildcards(params, get_wildcard_data()) replace_wildcards(params, get_wildcard_data())
output = make_output_name(server, "txt2img", params, size) output = make_output_name(server, "txt2img", params, size, count=params.batch)
job_name = output[0] job_name = output[0]
pool.submit( pool.submit(
@ -514,6 +515,61 @@ def txt2txt(server: ServerContext, pool: DevicePoolExecutor):
return jsonify(json_params(output, params, size)) return jsonify(json_params(output, params, size))
def generate(server: ServerContext, pool: DevicePoolExecutor):
if not request.is_json():
return error_reply("generate endpoint requires JSON parameters")
# TODO: should this accept YAML as well?
data = request.get_json()
schema = load_config("./schemas/generate.yaml")
logger.debug("validating generate request: %s against %s", data, schema)
validate(data, schema)
jobs = []
if "txt2img" in data:
for job in data.get("txt2img"):
device, params, size = pipeline_from_json(server, job, "txt2img")
jobs.append((
f"generate-txt2img-{len(jobs)}",
run_txt2img_pipeline,
server,
params,
size,
make_output_name(server, "txt2img", params, size, offset=len(jobs)),
None,
None,
device,
))
if "img2img" in data:
for job in data.get("img2img"):
device, params, size = pipeline_from_json(server, job, "img2img")
jobs.append((
f"generate-img2img-{len(jobs)}",
run_img2img_pipeline,
server,
params,
size,
make_output_name(server, "img2img", params, size, offset=len(jobs))
None,
None,
device,
))
for job in jobs:
pool.submit(*job)
# TODO: collect results
# this is the hard part. once all of the jobs are done, the last job or some dedicated job
# needs to collect the previous outputs and put them on a grid. jobs write their own
# output to disk and do not return it, so that may need to read the images based on the
# output names assigned to each job. knowing when the jobs are done is the first problem.
# TODO: assemble grid
def cancel(server: ServerContext, pool: DevicePoolExecutor): def cancel(server: ServerContext, pool: DevicePoolExecutor):
output_file = request.args.get("output", None) output_file = request.args.get("output", None)
if output_file is None: if output_file is None:

View File

@ -1,5 +1,5 @@
from logging import getLogger from logging import getLogger
from typing import Tuple from typing import Any, Dict, Tuple
import numpy as np import numpy as np
from flask import request from flask import request
@ -34,6 +34,157 @@ from .utils import get_model_path
logger = getLogger(__name__) logger = getLogger(__name__)
def pipeline_from_json(
server: ServerContext,
data: Dict[str, Any],
default_pipeline: str = "txt2img",
) -> Tuple[DeviceParams, ImageParams, Size]:
device = None
device_name = data.get("platform")
if device_name is not None and device_name != "any":
for platform in get_available_platforms():
if platform.device == device_name:
device = platform
# diffusion model
model = get_not_empty(data, "model", get_config_value("model"))
model_path = get_model_path(server, model)
# pipeline stuff
pipeline = get_from_list(
data, "pipeline", get_available_pipelines(), default_pipeline
)
scheduler = get_from_list(data, "scheduler", get_pipeline_schedulers())
if scheduler is None:
scheduler = get_config_value("scheduler")
# prompt does not come from config
prompt = data.get("prompt", "")
negative_prompt = data.get("negativePrompt", None)
if negative_prompt is not None and negative_prompt.strip() == "":
negative_prompt = None
# image params
batch = get_and_clamp_int(
data,
"batch",
get_config_value("batch"),
get_config_value("batch", "max"),
get_config_value("batch", "min"),
)
cfg = get_and_clamp_float(
data,
"cfg",
get_config_value("cfg"),
get_config_value("cfg", "max"),
get_config_value("cfg", "min"),
)
eta = get_and_clamp_float(
data,
"eta",
get_config_value("eta"),
get_config_value("eta", "max"),
get_config_value("eta", "min"),
)
loopback = get_and_clamp_int(
data,
"loopback",
get_config_value("loopback"),
get_config_value("loopback", "max"),
get_config_value("loopback", "min"),
)
steps = get_and_clamp_int(
data,
"steps",
get_config_value("steps"),
get_config_value("steps", "max"),
get_config_value("steps", "min"),
)
height = get_and_clamp_int(
data,
"height",
get_config_value("height"),
get_config_value("height", "max"),
get_config_value("height", "min"),
)
width = get_and_clamp_int(
data,
"width",
get_config_value("width"),
get_config_value("width", "max"),
get_config_value("width", "min"),
)
tiled_vae = get_boolean(data, "tiledVAE", get_config_value("tiledVAE"))
tiles = get_and_clamp_int(
data,
"tiles",
get_config_value("tiles"),
get_config_value("tiles", "max"),
get_config_value("tiles", "min"),
)
overlap = get_and_clamp_float(
data,
"overlap",
get_config_value("overlap"),
get_config_value("overlap", "max"),
get_config_value("overlap", "min"),
)
stride = get_and_clamp_int(
data,
"stride",
get_config_value("stride"),
get_config_value("stride", "max"),
get_config_value("stride", "min"),
)
if stride > tiles:
logger.info("limiting stride to tile size, %s > %s", stride, tiles)
stride = tiles
seed = int(data.get("seed", -1))
if seed == -1:
# this one can safely use np.random because it produces a single value
seed = np.random.randint(np.iinfo(np.int32).max)
logger.debug(
"parsed parameters for %s steps of %s using %s in %s on %s, %sx%s, %s, %s - %s",
steps,
scheduler,
model_path,
pipeline,
device or "any device",
width,
height,
cfg,
seed,
prompt,
)
params = ImageParams(
model_path,
pipeline,
scheduler,
prompt,
cfg,
steps,
seed,
eta=eta,
negative_prompt=negative_prompt,
batch=batch,
# TODO: control=control,
loopback=loopback,
tiled_vae=tiled_vae,
tiles=tiles,
overlap=overlap,
stride=stride,
)
size = Size(width, height)
return (device, params, size)
def pipeline_from_request( def pipeline_from_request(
server: ServerContext, server: ServerContext,
default_pipeline: str = "txt2img", default_pipeline: str = "txt2img",