add endpoint for multiple image generation
This commit is contained in:
parent
956a260db6
commit
1a732d54b6
|
@ -158,6 +158,7 @@ def make_output_name(
|
||||||
size: Size,
|
size: Size,
|
||||||
extras: Optional[List[Optional[Param]]] = None,
|
extras: Optional[List[Optional[Param]]] = None,
|
||||||
count: Optional[int] = None,
|
count: Optional[int] = None,
|
||||||
|
offset: int = 0,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
count = count or params.batch
|
count = count or params.batch
|
||||||
now = int(time())
|
now = int(time())
|
||||||
|
@ -183,7 +184,7 @@ def make_output_name(
|
||||||
|
|
||||||
return [
|
return [
|
||||||
f"{mode}_{params.seed}_{sha.hexdigest()}_{now}_{i}.{server.image_format}"
|
f"{mode}_{params.seed}_{sha.hexdigest()}_{now}_{i}.{server.image_format}"
|
||||||
for i in range(count)
|
for i in range(offset, count + offset)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -51,6 +51,7 @@ from .load import (
|
||||||
from .params import (
|
from .params import (
|
||||||
border_from_request,
|
border_from_request,
|
||||||
highres_from_request,
|
highres_from_request,
|
||||||
|
pipeline_from_json,
|
||||||
pipeline_from_request,
|
pipeline_from_request,
|
||||||
upscale_from_request,
|
upscale_from_request,
|
||||||
)
|
)
|
||||||
|
@ -221,7 +222,7 @@ def txt2img(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
|
||||||
replace_wildcards(params, get_wildcard_data())
|
replace_wildcards(params, get_wildcard_data())
|
||||||
|
|
||||||
output = make_output_name(server, "txt2img", params, size)
|
output = make_output_name(server, "txt2img", params, size, count=params.batch)
|
||||||
|
|
||||||
job_name = output[0]
|
job_name = output[0]
|
||||||
pool.submit(
|
pool.submit(
|
||||||
|
@ -514,6 +515,61 @@ def txt2txt(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
return jsonify(json_params(output, params, size))
|
return jsonify(json_params(output, params, size))
|
||||||
|
|
||||||
|
|
||||||
|
def generate(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
if not request.is_json():
|
||||||
|
return error_reply("generate endpoint requires JSON parameters")
|
||||||
|
|
||||||
|
# TODO: should this accept YAML as well?
|
||||||
|
data = request.get_json()
|
||||||
|
schema = load_config("./schemas/generate.yaml")
|
||||||
|
|
||||||
|
logger.debug("validating generate request: %s against %s", data, schema)
|
||||||
|
validate(data, schema)
|
||||||
|
|
||||||
|
jobs = []
|
||||||
|
|
||||||
|
if "txt2img" in data:
|
||||||
|
for job in data.get("txt2img"):
|
||||||
|
device, params, size = pipeline_from_json(server, job, "txt2img")
|
||||||
|
jobs.append((
|
||||||
|
f"generate-txt2img-{len(jobs)}",
|
||||||
|
run_txt2img_pipeline,
|
||||||
|
server,
|
||||||
|
params,
|
||||||
|
size,
|
||||||
|
make_output_name(server, "txt2img", params, size, offset=len(jobs)),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
device,
|
||||||
|
))
|
||||||
|
|
||||||
|
if "img2img" in data:
|
||||||
|
for job in data.get("img2img"):
|
||||||
|
device, params, size = pipeline_from_json(server, job, "img2img")
|
||||||
|
jobs.append((
|
||||||
|
f"generate-img2img-{len(jobs)}",
|
||||||
|
run_img2img_pipeline,
|
||||||
|
server,
|
||||||
|
params,
|
||||||
|
size,
|
||||||
|
make_output_name(server, "img2img", params, size, offset=len(jobs))
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
device,
|
||||||
|
))
|
||||||
|
|
||||||
|
for job in jobs:
|
||||||
|
pool.submit(*job)
|
||||||
|
|
||||||
|
# TODO: collect results
|
||||||
|
# this is the hard part. once all of the jobs are done, the last job or some dedicated job
|
||||||
|
# needs to collect the previous outputs and put them on a grid. jobs write their own
|
||||||
|
# output to disk and do not return it, so that may need to read the images based on the
|
||||||
|
# output names assigned to each job. knowing when the jobs are done is the first problem.
|
||||||
|
|
||||||
|
# TODO: assemble grid
|
||||||
|
|
||||||
|
|
||||||
def cancel(server: ServerContext, pool: DevicePoolExecutor):
|
def cancel(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
output_file = request.args.get("output", None)
|
output_file = request.args.get("output", None)
|
||||||
if output_file is None:
|
if output_file is None:
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import Tuple
|
from typing import Any, Dict, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from flask import request
|
from flask import request
|
||||||
|
@ -34,6 +34,157 @@ from .utils import get_model_path
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def pipeline_from_json(
|
||||||
|
server: ServerContext,
|
||||||
|
data: Dict[str, Any],
|
||||||
|
default_pipeline: str = "txt2img",
|
||||||
|
) -> Tuple[DeviceParams, ImageParams, Size]:
|
||||||
|
device = None
|
||||||
|
device_name = data.get("platform")
|
||||||
|
|
||||||
|
if device_name is not None and device_name != "any":
|
||||||
|
for platform in get_available_platforms():
|
||||||
|
if platform.device == device_name:
|
||||||
|
device = platform
|
||||||
|
|
||||||
|
# diffusion model
|
||||||
|
model = get_not_empty(data, "model", get_config_value("model"))
|
||||||
|
model_path = get_model_path(server, model)
|
||||||
|
|
||||||
|
# pipeline stuff
|
||||||
|
pipeline = get_from_list(
|
||||||
|
data, "pipeline", get_available_pipelines(), default_pipeline
|
||||||
|
)
|
||||||
|
scheduler = get_from_list(data, "scheduler", get_pipeline_schedulers())
|
||||||
|
|
||||||
|
if scheduler is None:
|
||||||
|
scheduler = get_config_value("scheduler")
|
||||||
|
|
||||||
|
# prompt does not come from config
|
||||||
|
prompt = data.get("prompt", "")
|
||||||
|
negative_prompt = data.get("negativePrompt", None)
|
||||||
|
|
||||||
|
if negative_prompt is not None and negative_prompt.strip() == "":
|
||||||
|
negative_prompt = None
|
||||||
|
|
||||||
|
# image params
|
||||||
|
batch = get_and_clamp_int(
|
||||||
|
data,
|
||||||
|
"batch",
|
||||||
|
get_config_value("batch"),
|
||||||
|
get_config_value("batch", "max"),
|
||||||
|
get_config_value("batch", "min"),
|
||||||
|
)
|
||||||
|
cfg = get_and_clamp_float(
|
||||||
|
data,
|
||||||
|
"cfg",
|
||||||
|
get_config_value("cfg"),
|
||||||
|
get_config_value("cfg", "max"),
|
||||||
|
get_config_value("cfg", "min"),
|
||||||
|
)
|
||||||
|
eta = get_and_clamp_float(
|
||||||
|
data,
|
||||||
|
"eta",
|
||||||
|
get_config_value("eta"),
|
||||||
|
get_config_value("eta", "max"),
|
||||||
|
get_config_value("eta", "min"),
|
||||||
|
)
|
||||||
|
loopback = get_and_clamp_int(
|
||||||
|
data,
|
||||||
|
"loopback",
|
||||||
|
get_config_value("loopback"),
|
||||||
|
get_config_value("loopback", "max"),
|
||||||
|
get_config_value("loopback", "min"),
|
||||||
|
)
|
||||||
|
steps = get_and_clamp_int(
|
||||||
|
data,
|
||||||
|
"steps",
|
||||||
|
get_config_value("steps"),
|
||||||
|
get_config_value("steps", "max"),
|
||||||
|
get_config_value("steps", "min"),
|
||||||
|
)
|
||||||
|
height = get_and_clamp_int(
|
||||||
|
data,
|
||||||
|
"height",
|
||||||
|
get_config_value("height"),
|
||||||
|
get_config_value("height", "max"),
|
||||||
|
get_config_value("height", "min"),
|
||||||
|
)
|
||||||
|
width = get_and_clamp_int(
|
||||||
|
data,
|
||||||
|
"width",
|
||||||
|
get_config_value("width"),
|
||||||
|
get_config_value("width", "max"),
|
||||||
|
get_config_value("width", "min"),
|
||||||
|
)
|
||||||
|
tiled_vae = get_boolean(data, "tiledVAE", get_config_value("tiledVAE"))
|
||||||
|
tiles = get_and_clamp_int(
|
||||||
|
data,
|
||||||
|
"tiles",
|
||||||
|
get_config_value("tiles"),
|
||||||
|
get_config_value("tiles", "max"),
|
||||||
|
get_config_value("tiles", "min"),
|
||||||
|
)
|
||||||
|
overlap = get_and_clamp_float(
|
||||||
|
data,
|
||||||
|
"overlap",
|
||||||
|
get_config_value("overlap"),
|
||||||
|
get_config_value("overlap", "max"),
|
||||||
|
get_config_value("overlap", "min"),
|
||||||
|
)
|
||||||
|
stride = get_and_clamp_int(
|
||||||
|
data,
|
||||||
|
"stride",
|
||||||
|
get_config_value("stride"),
|
||||||
|
get_config_value("stride", "max"),
|
||||||
|
get_config_value("stride", "min"),
|
||||||
|
)
|
||||||
|
|
||||||
|
if stride > tiles:
|
||||||
|
logger.info("limiting stride to tile size, %s > %s", stride, tiles)
|
||||||
|
stride = tiles
|
||||||
|
|
||||||
|
seed = int(data.get("seed", -1))
|
||||||
|
if seed == -1:
|
||||||
|
# this one can safely use np.random because it produces a single value
|
||||||
|
seed = np.random.randint(np.iinfo(np.int32).max)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"parsed parameters for %s steps of %s using %s in %s on %s, %sx%s, %s, %s - %s",
|
||||||
|
steps,
|
||||||
|
scheduler,
|
||||||
|
model_path,
|
||||||
|
pipeline,
|
||||||
|
device or "any device",
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
cfg,
|
||||||
|
seed,
|
||||||
|
prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
params = ImageParams(
|
||||||
|
model_path,
|
||||||
|
pipeline,
|
||||||
|
scheduler,
|
||||||
|
prompt,
|
||||||
|
cfg,
|
||||||
|
steps,
|
||||||
|
seed,
|
||||||
|
eta=eta,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
batch=batch,
|
||||||
|
# TODO: control=control,
|
||||||
|
loopback=loopback,
|
||||||
|
tiled_vae=tiled_vae,
|
||||||
|
tiles=tiles,
|
||||||
|
overlap=overlap,
|
||||||
|
stride=stride,
|
||||||
|
)
|
||||||
|
size = Size(width, height)
|
||||||
|
return (device, params, size)
|
||||||
|
|
||||||
|
|
||||||
def pipeline_from_request(
|
def pipeline_from_request(
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
default_pipeline: str = "txt2img",
|
default_pipeline: str = "txt2img",
|
||||||
|
|
Loading…
Reference in New Issue