add endpoint for multiple image generation
This commit is contained in:
parent
956a260db6
commit
1a732d54b6
|
@ -158,6 +158,7 @@ def make_output_name(
|
|||
size: Size,
|
||||
extras: Optional[List[Optional[Param]]] = None,
|
||||
count: Optional[int] = None,
|
||||
offset: int = 0,
|
||||
) -> List[str]:
|
||||
count = count or params.batch
|
||||
now = int(time())
|
||||
|
@ -183,7 +184,7 @@ def make_output_name(
|
|||
|
||||
return [
|
||||
f"{mode}_{params.seed}_{sha.hexdigest()}_{now}_{i}.{server.image_format}"
|
||||
for i in range(count)
|
||||
for i in range(offset, count + offset)
|
||||
]
|
||||
|
||||
|
||||
|
|
|
@ -51,6 +51,7 @@ from .load import (
|
|||
from .params import (
|
||||
border_from_request,
|
||||
highres_from_request,
|
||||
pipeline_from_json,
|
||||
pipeline_from_request,
|
||||
upscale_from_request,
|
||||
)
|
||||
|
@ -221,7 +222,7 @@ def txt2img(server: ServerContext, pool: DevicePoolExecutor):
|
|||
|
||||
replace_wildcards(params, get_wildcard_data())
|
||||
|
||||
output = make_output_name(server, "txt2img", params, size)
|
||||
output = make_output_name(server, "txt2img", params, size, count=params.batch)
|
||||
|
||||
job_name = output[0]
|
||||
pool.submit(
|
||||
|
@ -514,6 +515,61 @@ def txt2txt(server: ServerContext, pool: DevicePoolExecutor):
|
|||
return jsonify(json_params(output, params, size))
|
||||
|
||||
|
||||
def generate(server: ServerContext, pool: DevicePoolExecutor):
|
||||
if not request.is_json():
|
||||
return error_reply("generate endpoint requires JSON parameters")
|
||||
|
||||
# TODO: should this accept YAML as well?
|
||||
data = request.get_json()
|
||||
schema = load_config("./schemas/generate.yaml")
|
||||
|
||||
logger.debug("validating generate request: %s against %s", data, schema)
|
||||
validate(data, schema)
|
||||
|
||||
jobs = []
|
||||
|
||||
if "txt2img" in data:
|
||||
for job in data.get("txt2img"):
|
||||
device, params, size = pipeline_from_json(server, job, "txt2img")
|
||||
jobs.append((
|
||||
f"generate-txt2img-{len(jobs)}",
|
||||
run_txt2img_pipeline,
|
||||
server,
|
||||
params,
|
||||
size,
|
||||
make_output_name(server, "txt2img", params, size, offset=len(jobs)),
|
||||
None,
|
||||
None,
|
||||
device,
|
||||
))
|
||||
|
||||
if "img2img" in data:
|
||||
for job in data.get("img2img"):
|
||||
device, params, size = pipeline_from_json(server, job, "img2img")
|
||||
jobs.append((
|
||||
f"generate-img2img-{len(jobs)}",
|
||||
run_img2img_pipeline,
|
||||
server,
|
||||
params,
|
||||
size,
|
||||
make_output_name(server, "img2img", params, size, offset=len(jobs))
|
||||
None,
|
||||
None,
|
||||
device,
|
||||
))
|
||||
|
||||
for job in jobs:
|
||||
pool.submit(*job)
|
||||
|
||||
# TODO: collect results
|
||||
# this is the hard part. once all of the jobs are done, the last job or some dedicated job
|
||||
# needs to collect the previous outputs and put them on a grid. jobs write their own
|
||||
# output to disk and do not return it, so that may need to read the images based on the
|
||||
# output names assigned to each job. knowing when the jobs are done is the first problem.
|
||||
|
||||
# TODO: assemble grid
|
||||
|
||||
|
||||
def cancel(server: ServerContext, pool: DevicePoolExecutor):
|
||||
output_file = request.args.get("output", None)
|
||||
if output_file is None:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from logging import getLogger
|
||||
from typing import Tuple
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
import numpy as np
|
||||
from flask import request
|
||||
|
@ -34,6 +34,157 @@ from .utils import get_model_path
|
|||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def pipeline_from_json(
|
||||
server: ServerContext,
|
||||
data: Dict[str, Any],
|
||||
default_pipeline: str = "txt2img",
|
||||
) -> Tuple[DeviceParams, ImageParams, Size]:
|
||||
device = None
|
||||
device_name = data.get("platform")
|
||||
|
||||
if device_name is not None and device_name != "any":
|
||||
for platform in get_available_platforms():
|
||||
if platform.device == device_name:
|
||||
device = platform
|
||||
|
||||
# diffusion model
|
||||
model = get_not_empty(data, "model", get_config_value("model"))
|
||||
model_path = get_model_path(server, model)
|
||||
|
||||
# pipeline stuff
|
||||
pipeline = get_from_list(
|
||||
data, "pipeline", get_available_pipelines(), default_pipeline
|
||||
)
|
||||
scheduler = get_from_list(data, "scheduler", get_pipeline_schedulers())
|
||||
|
||||
if scheduler is None:
|
||||
scheduler = get_config_value("scheduler")
|
||||
|
||||
# prompt does not come from config
|
||||
prompt = data.get("prompt", "")
|
||||
negative_prompt = data.get("negativePrompt", None)
|
||||
|
||||
if negative_prompt is not None and negative_prompt.strip() == "":
|
||||
negative_prompt = None
|
||||
|
||||
# image params
|
||||
batch = get_and_clamp_int(
|
||||
data,
|
||||
"batch",
|
||||
get_config_value("batch"),
|
||||
get_config_value("batch", "max"),
|
||||
get_config_value("batch", "min"),
|
||||
)
|
||||
cfg = get_and_clamp_float(
|
||||
data,
|
||||
"cfg",
|
||||
get_config_value("cfg"),
|
||||
get_config_value("cfg", "max"),
|
||||
get_config_value("cfg", "min"),
|
||||
)
|
||||
eta = get_and_clamp_float(
|
||||
data,
|
||||
"eta",
|
||||
get_config_value("eta"),
|
||||
get_config_value("eta", "max"),
|
||||
get_config_value("eta", "min"),
|
||||
)
|
||||
loopback = get_and_clamp_int(
|
||||
data,
|
||||
"loopback",
|
||||
get_config_value("loopback"),
|
||||
get_config_value("loopback", "max"),
|
||||
get_config_value("loopback", "min"),
|
||||
)
|
||||
steps = get_and_clamp_int(
|
||||
data,
|
||||
"steps",
|
||||
get_config_value("steps"),
|
||||
get_config_value("steps", "max"),
|
||||
get_config_value("steps", "min"),
|
||||
)
|
||||
height = get_and_clamp_int(
|
||||
data,
|
||||
"height",
|
||||
get_config_value("height"),
|
||||
get_config_value("height", "max"),
|
||||
get_config_value("height", "min"),
|
||||
)
|
||||
width = get_and_clamp_int(
|
||||
data,
|
||||
"width",
|
||||
get_config_value("width"),
|
||||
get_config_value("width", "max"),
|
||||
get_config_value("width", "min"),
|
||||
)
|
||||
tiled_vae = get_boolean(data, "tiledVAE", get_config_value("tiledVAE"))
|
||||
tiles = get_and_clamp_int(
|
||||
data,
|
||||
"tiles",
|
||||
get_config_value("tiles"),
|
||||
get_config_value("tiles", "max"),
|
||||
get_config_value("tiles", "min"),
|
||||
)
|
||||
overlap = get_and_clamp_float(
|
||||
data,
|
||||
"overlap",
|
||||
get_config_value("overlap"),
|
||||
get_config_value("overlap", "max"),
|
||||
get_config_value("overlap", "min"),
|
||||
)
|
||||
stride = get_and_clamp_int(
|
||||
data,
|
||||
"stride",
|
||||
get_config_value("stride"),
|
||||
get_config_value("stride", "max"),
|
||||
get_config_value("stride", "min"),
|
||||
)
|
||||
|
||||
if stride > tiles:
|
||||
logger.info("limiting stride to tile size, %s > %s", stride, tiles)
|
||||
stride = tiles
|
||||
|
||||
seed = int(data.get("seed", -1))
|
||||
if seed == -1:
|
||||
# this one can safely use np.random because it produces a single value
|
||||
seed = np.random.randint(np.iinfo(np.int32).max)
|
||||
|
||||
logger.debug(
|
||||
"parsed parameters for %s steps of %s using %s in %s on %s, %sx%s, %s, %s - %s",
|
||||
steps,
|
||||
scheduler,
|
||||
model_path,
|
||||
pipeline,
|
||||
device or "any device",
|
||||
width,
|
||||
height,
|
||||
cfg,
|
||||
seed,
|
||||
prompt,
|
||||
)
|
||||
|
||||
params = ImageParams(
|
||||
model_path,
|
||||
pipeline,
|
||||
scheduler,
|
||||
prompt,
|
||||
cfg,
|
||||
steps,
|
||||
seed,
|
||||
eta=eta,
|
||||
negative_prompt=negative_prompt,
|
||||
batch=batch,
|
||||
# TODO: control=control,
|
||||
loopback=loopback,
|
||||
tiled_vae=tiled_vae,
|
||||
tiles=tiles,
|
||||
overlap=overlap,
|
||||
stride=stride,
|
||||
)
|
||||
size = Size(width, height)
|
||||
return (device, params, size)
|
||||
|
||||
|
||||
def pipeline_from_request(
|
||||
server: ServerContext,
|
||||
default_pipeline: str = "txt2img",
|
||||
|
|
Loading…
Reference in New Issue