# onnx-web/api/onnx_web/server/api.py
from io import BytesIO
from logging import getLogger
from os import path
from typing import Any, Dict, List, Optional
from flask import Flask, jsonify, make_response, request, url_for
from jsonschema import validate
from PIL import Image
from ..chain import CHAIN_STAGES, ChainPipeline
from ..chain.result import ImageMetadata, StageResult
from ..diffusers.load import get_available_pipelines, get_pipeline_schedulers
from ..diffusers.run import (
run_blend_pipeline,
run_img2img_pipeline,
run_inpaint_pipeline,
run_txt2img_pipeline,
run_upscale_pipeline,
)
from ..diffusers.utils import replace_wildcards
from ..output import make_job_name
from ..params import Size, StageParams, TileOrder, get_size
from ..transformers.run import run_txt2txt_pipeline
from ..utils import (
base_join,
get_and_clamp_float,
get_and_clamp_int,
get_boolean,
get_from_list,
get_from_map,
get_list,
get_not_empty,
load_config,
load_config_str,
sanitize_name,
)
from ..worker.command import JobStatus, JobType, Progress
from ..worker.pool import DevicePoolExecutor
from .context import ServerContext
from .load import (
get_available_platforms,
get_config_params,
get_config_value,
get_correction_models,
get_diffusion_models,
get_extra_strings,
get_mask_filters,
get_network_models,
get_noise_sources,
get_prompt_filters,
get_source_filters,
get_upscaling_models,
get_wildcard_data,
)
from .params import (
build_border,
build_upscale,
get_request_data,
get_request_params,
pipeline_from_json,
)
from .utils import wrap_route


logger = getLogger(__name__)


def ready_reply(
ready: bool = False,
cancelled: bool = False,
failed: bool = False,
pending: bool = False,
progress: int = 0,
):
return jsonify(
{
"cancelled": cancelled,
"failed": failed,
"pending": pending,
"progress": progress,
"ready": ready,
}
)


def error_reply(err: str):
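    """Wrap an error message in a JSON response with a 400 status code."""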
response = make_response(
jsonify(
{
"error": err,
}
)
)
response.status_code = 400
return response
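

# shared zero-progress placeholder used before a job reports any progress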
EMPTY_PROGRESS = Progress(0, 0)


def job_reply(name: str, queue: int = 0):
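    """Reply for a newly queued job: pending status with empty progress counters."""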
return jsonify(
{
"name": name,
"queue": Progress(queue, queue).tojson(),
"status": JobStatus.PENDING,
"stages": EMPTY_PROGRESS.tojson(),
"steps": EMPTY_PROGRESS.tojson(),
"tiles": EMPTY_PROGRESS.tojson(),
}
)


def image_reply(
server: ServerContext,
name: str,
status: str,
    queue: Optional[Progress] = None,
    stages: Optional[Progress] = None,
    steps: Optional[Progress] = None,
    tiles: Optional[Progress] = None,
    metadata: Optional[List[ImageMetadata]] = None,
    outputs: Optional[List[str]] = None,
    thumbnails: Optional[List[str]] = None,
    reason: Optional[str] = None,
) -> Dict[str, Any]:
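    """Build a JSON-serializable status record for a job.

    Outputs require matching metadata, and thumbnails, when present, must
    match the outputs one-to-one.
    """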
if queue is None:
queue = EMPTY_PROGRESS
if stages is None:
stages = EMPTY_PROGRESS
if steps is None:
steps = EMPTY_PROGRESS
if tiles is None:
tiles = EMPTY_PROGRESS
data = {
"name": name,
"status": status,
"queue": queue.tojson(),
"stages": stages.tojson(),
"steps": steps.tojson(),
"tiles": tiles.tojson(),
}
if reason is not None:
data["reason"] = reason
if outputs is not None:
if metadata is None:
logger.error("metadata is required with outputs")
return error_reply("metadata is required with outputs")
if len(metadata) != len(outputs):
logger.error("metadata and outputs must be the same length")
return error_reply("metadata and outputs must be the same length")
data["metadata"] = [m.tojson(server, [o]) for m, o in zip(metadata, outputs)]
data["outputs"] = outputs
if thumbnails is not None:
if len(thumbnails) != len(outputs):
logger.error("thumbnails and outputs must be the same length")
return error_reply("thumbnails and outputs must be the same length")
data["thumbnails"] = thumbnails
return data


def multi_image_reply(results: List[Dict[str, Any]]):
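    """Wrap a list of per-job results in a single JSON response."""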
return jsonify(
{
"results": results,
}
)


def url_from_rule(rule) -> str:
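    """Convert a Flask URL rule into a path with :arg placeholders."""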
options = {}
for arg in rule.arguments:
options[arg] = ":%s" % (arg)
return url_for(rule.endpoint, **options)


def introspect(server: ServerContext, app: Flask):
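    """Describe the server and the routes it exposes."""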
return {
"name": "onnx-web",
"routes": [
{
"path": url_from_rule(rule),
"methods": list(rule.methods or []),
}
for rule in app.url_map.iter_rules()
],
}


def list_extra_strings(server: ServerContext):
return jsonify(get_extra_strings())


def list_filters(server: ServerContext):
mask_filters = list(get_mask_filters().keys())
prompt_filters = get_prompt_filters()
source_filters = list(get_source_filters().keys())
return jsonify(
{
"mask": mask_filters,
"prompt": prompt_filters,
"source": source_filters,
}
)


def list_mask_filters(server: ServerContext):
logger.info("dedicated list endpoint for mask filters is deprecated")
return jsonify(list(get_mask_filters().keys()))


def list_models(server: ServerContext):
return jsonify(
{
"correction": get_correction_models(),
"diffusion": get_diffusion_models(),
"networks": [model.tojson() for model in get_network_models()],
"upscaling": get_upscaling_models(),
}
)


def list_noise_sources(server: ServerContext):
return jsonify(list(get_noise_sources().keys()))


def list_params(server: ServerContext):
return jsonify(get_config_params())


def list_pipelines(server: ServerContext):
return jsonify(get_available_pipelines())


def list_platforms(server: ServerContext):
return jsonify([p.device for p in get_available_platforms()])


def list_schedulers(server: ServerContext):
return jsonify(get_pipeline_schedulers())


def list_wildcards(server: ServerContext):
return jsonify(list(get_wildcard_data().keys()))


def img2img(server: ServerContext, pool: DevicePoolExecutor):
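    """Queue an img2img job for an uploaded source image."""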
source_file = request.files.get("source")
if source_file is None:
return error_reply("source image is required")
source = Image.open(BytesIO(source_file.read())).convert("RGB")
data = get_request_data()
data_params = data.get("params", data)
source_filter = get_from_list(
data_params, "sourceFilter", list(get_source_filters().keys())
)
strength = get_and_clamp_float(
data_params,
"strength",
get_config_value("strength"),
get_config_value("strength", "max"),
get_config_value("strength", "min"),
)
params = get_request_params(server, JobType.IMG2IMG.value)
params.size = Size(source.width, source.height)
replace_wildcards(params.image, get_wildcard_data())
job_name = make_job_name(
JobType.IMG2IMG.value, params, params.size, extras=[strength]
)
queue = pool.submit(
job_name,
JobType.IMG2IMG,
run_img2img_pipeline,
server,
params,
source,
strength,
needs_device=params.device,
source_filter=source_filter,
)
logger.info("img2img job queued for: %s", job_name)
return job_reply(job_name, queue=queue)


def txt2img(server: ServerContext, pool: DevicePoolExecutor):
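    """Queue a txt2img job for the given prompt parameters."""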
params = get_request_params(server, JobType.TXT2IMG.value)
replace_wildcards(params.image, get_wildcard_data())
job_name = make_job_name(JobType.TXT2IMG.value, params.image, params.size)
queue = pool.submit(
job_name,
JobType.TXT2IMG,
run_txt2img_pipeline,
server,
params,
needs_device=params.device,
)
logger.info("txt2img job queued for: %s", job_name)
return job_reply(job_name, queue=queue)


def inpaint(server: ServerContext, pool: DevicePoolExecutor):
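    """Queue an inpaint job for an uploaded source image and mask."""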
source_file = request.files.get("source")
if source_file is None:
return error_reply("source image is required")
mask_file = request.files.get("mask")
if mask_file is None:
return error_reply("mask image is required")
source = Image.open(BytesIO(source_file.read())).convert("RGBA")
size = Size(source.width, source.height)
mask_top_layer = Image.open(BytesIO(mask_file.read())).convert("RGBA")
mask = Image.new("RGBA", mask_top_layer.size, color=(0, 0, 0, 255))
mask.alpha_composite(mask_top_layer)
    # convert() returns a new image, so reassign to actually flatten to grayscale
    mask = mask.convert(mode="L")
data = get_request_data()
data_params = data.get("params", data)
full_res_inpaint = get_boolean(
data_params, "fullresInpaint", get_config_value("fullresInpaint")
)
full_res_inpaint_padding = get_and_clamp_float(
data_params,
"fullresInpaintPadding",
get_config_value("fullresInpaintPadding"),
get_config_value("fullresInpaintPadding", "max"),
get_config_value("fullresInpaintPadding", "min"),
)
params = get_request_params(server, JobType.INPAINT.value)
replace_wildcards(params.image, get_wildcard_data())
fill_color = get_not_empty(data_params, "fillColor", "white")
mask_filter = get_from_map(data_params, "filter", get_mask_filters(), "none")
noise_source = get_from_map(data_params, "noise", get_noise_sources(), "histogram")
tile_order = get_from_list(
data_params,
"tileOrder",
[TileOrder.grid, TileOrder.kernel, TileOrder.spiral],
)
job_name = make_job_name(
JobType.INPAINT.value,
params,
size,
extras=[
params.border.left,
params.border.right,
params.border.top,
params.border.bottom,
mask_filter.__name__,
noise_source.__name__,
fill_color,
tile_order,
],
)
queue = pool.submit(
job_name,
JobType.INPAINT,
run_inpaint_pipeline,
server,
params,
source,
mask,
noise_source,
mask_filter,
fill_color,
tile_order,
full_res_inpaint,
full_res_inpaint_padding,
needs_device=params.device,
)
logger.info("inpaint job queued for: %s", job_name)
return job_reply(job_name, queue=queue)


def upscale(server: ServerContext, pool: DevicePoolExecutor):
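    """Queue an upscale job for an uploaded source image."""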
source_file = request.files.get("source")
if source_file is None:
return error_reply("source image is required")
source = Image.open(BytesIO(source_file.read())).convert("RGB")
params = get_request_params(server, JobType.UPSCALE.value)
replace_wildcards(params.image, get_wildcard_data())
job_name = make_job_name("upscale", params.image, params.size)
queue = pool.submit(
job_name,
JobType.UPSCALE,
run_upscale_pipeline,
server,
params,
source,
needs_device=params.device,
)
logger.info("upscale job queued for: %s", job_name)
return job_reply(job_name, queue=queue)


# keys that are specially parsed by the params helpers and should not be
# passed through in the stage kwargs
CHAIN_POP_KEYS = ["model", "control"]


def chain(server: ServerContext, pool: DevicePoolExecutor):
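    """Queue a chain pipeline job described by a JSON or form-encoded body.

    A minimal sketch of a request body, using an illustrative stage name,
    type, and prompt (the valid type values are the keys of CHAIN_STAGES):

        {
            "defaults": {"prompt": "a photo of an astronaut"},
            "stages": [
                {
                    "name": "generate",
                    "type": "source-txt2img",
                    "params": {"prompt": "a photo of an astronaut riding a horse"}
                }
            ]
        }

    Per-stage images can be attached as files named source:<stage name> and
    mask:<stage name>.
    """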
if request.is_json:
logger.debug("chain pipeline request with JSON body")
data = request.get_json()
else:
logger.debug(
"chain pipeline request: %s, %s", request.form.keys(), request.files.keys()
)
body = request.form.get("chain") or request.files.get("chain")
if body is None:
return error_reply("chain pipeline must have a body")
data = load_config_str(body)
schema = load_config("./schemas/chain.yaml")
logger.debug("validating chain request: %s against %s", data, schema)
validate(data, schema)
device, base_params, base_size = pipeline_from_json(
server, data=data.get("defaults")
)
# start building the pipeline
pipeline = ChainPipeline()
for stage_data in data.get("stages", []):
stage_class = CHAIN_STAGES[stage_data.get("type")]
kwargs: Dict[str, Any] = stage_data.get("params", {})
logger.info("request stage: %s, %s", stage_class.__name__, kwargs)
# TODO: combine base params with stage params
_device, params, size = pipeline_from_json(server, data=kwargs)
replace_wildcards(params, get_wildcard_data())
# remove parsed keys, like model names (which become paths)
for pop_key in CHAIN_POP_KEYS:
if pop_key in kwargs:
kwargs.pop(pop_key)
if "seed" in kwargs and kwargs["seed"] == -1:
kwargs.pop("seed")
# replace kwargs with parsed versions
kwargs["params"] = params
kwargs["size"] = size
border = build_border(kwargs)
kwargs["border"] = border
upscale = build_upscale(kwargs)
kwargs["upscale"] = upscale
# prepare the stage metadata
stage = StageParams(
stage_data.get("name", stage_class.__name__),
tile_size=get_size(kwargs.get("tiles")),
outscale=get_and_clamp_int(kwargs, "outscale", 1, 4),
)
# load any images related to this stage
stage_source_name = "source:%s" % (stage.name)
stage_mask_name = "mask:%s" % (stage.name)
if stage_source_name in request.files:
logger.debug(
"loading source image %s for pipeline stage %s",
stage_source_name,
stage.name,
)
source_file = request.files.get(stage_source_name)
if source_file is not None:
source = Image.open(BytesIO(source_file.read())).convert("RGB")
kwargs["stage_source"] = source
if stage_mask_name in request.files:
logger.debug(
"loading mask image %s for pipeline stage %s",
stage_mask_name,
stage.name,
)
mask_file = request.files.get(stage_mask_name)
if mask_file is not None:
mask = Image.open(BytesIO(mask_file.read())).convert("RGB")
kwargs["stage_mask"] = mask
pipeline.append((stage_class(), stage, kwargs))
logger.info("running chain pipeline with %s stages", len(pipeline.stages))
job_name = make_job_name("chain", base_params, base_size)
# build and run chain pipeline
queue = pool.submit(
job_name,
JobType.CHAIN,
pipeline,
server,
base_params,
StageResult.empty(),
size=base_size,
needs_device=device,
)
return job_reply(job_name, queue=queue)


def blend(server: ServerContext, pool: DevicePoolExecutor):
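    """Queue a blend job combining up to two source images under a mask."""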
mask_file = request.files.get("mask")
if mask_file is None:
return error_reply("mask image is required")
mask = Image.open(BytesIO(mask_file.read())).convert("RGBA")
max_sources = 2
sources = []
for i in range(max_sources):
source_file = request.files.get("source:%s" % (i))
if source_file is None:
logger.warning("missing source %s", i)
else:
source = Image.open(BytesIO(source_file.read())).convert("RGB")
sources.append(source)
params = get_request_params(server)
job_name = make_job_name("blend", params.image, params.size)
queue = pool.submit(
job_name,
JobType.BLEND,
run_blend_pipeline,
server,
params,
sources,
mask,
needs_device=params.device,
)
logger.info("upscale job queued for: %s", job_name)
return job_reply(job_name, queue=queue)


def txt2txt(server: ServerContext, pool: DevicePoolExecutor):
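    """Queue a txt2txt job for the given prompt parameters."""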
params = get_request_params(server)
job_name = make_job_name("txt2txt", params.image, params.size)
logger.info("upscale job queued for: %s", job_name)
queue = pool.submit(
job_name,
JobType.TXT2TXT,
run_txt2txt_pipeline,
server,
params,
needs_device=params.device,
)
    logger.info("txt2txt job queued for: %s", job_name)
    return job_reply(job_name, queue=queue)


def cancel(server: ServerContext, pool: DevicePoolExecutor):
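    """Cancel a job by its output name. Deprecated in favor of job_cancel."""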
output_file = request.args.get("output", None)
if output_file is None:
return error_reply("output name is required")
output_file = sanitize_name(output_file)
cancelled = pool.cancel(output_file)
return ready_reply(cancelled=cancelled)


def ready(server: ServerContext, pool: DevicePoolExecutor):
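    """Check whether a job's output is ready. Deprecated in favor of job_status."""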
output_file = request.args.get("output", None)
if output_file is None:
return error_reply("output name is required")
output_file = sanitize_name(output_file)
status, progress, _queue = pool.status(output_file)
if status == JobStatus.PENDING:
return ready_reply(pending=True)
if progress is None:
output = base_join(server.output_path, output_file)
if path.exists(output):
return ready_reply(ready=True)
        else:
            # a missing output image is treated as a failure so the
            # client will show the retry button
            return ready_reply(
                ready=True,
                failed=True,
            )
return ready_reply(
ready=(status == JobStatus.SUCCESS),
progress=progress.steps.current,
failed=(status == JobStatus.FAILED),
cancelled=(status == JobStatus.CANCELLED),
)


def job_create(server: ServerContext, pool: DevicePoolExecutor):
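    """Create a new job; currently an alias for the chain pipeline endpoint."""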
return chain(server, pool)


def job_cancel(server: ServerContext, pool: DevicePoolExecutor):
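    """Cancel up to 10 jobs, named by the jobs list or the legacy job argument."""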
legacy_job_name = request.args.get("job", None)
job_list = get_list(request.args, "jobs")
if legacy_job_name is not None:
job_list.append(legacy_job_name)
if len(job_list) == 0:
return error_reply("at least one job name is required")
elif len(job_list) > 10:
return error_reply("too many jobs")
results: List[Dict[str, str]] = []
for job_name in job_list:
job_name = sanitize_name(job_name)
cancelled = pool.cancel(job_name)
results.append(
{
"name": job_name,
"status": JobStatus.CANCELLED if cancelled else JobStatus.PENDING,
}
)
return multi_image_reply(results)


def job_status(server: ServerContext, pool: DevicePoolExecutor):
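    """Report status and progress for up to 10 jobs, named as in job_cancel."""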
legacy_job_name = request.args.get("job", None)
job_list = get_list(request.args, "jobs")
if legacy_job_name is not None:
job_list.append(legacy_job_name)
if len(job_list) == 0:
return error_reply("at least one job name is required")
elif len(job_list) > 10:
return error_reply("too many jobs")
records = []
for job_name in job_list:
job_name = sanitize_name(job_name)
status, progress, queue = pool.status(job_name)
if progress is not None:
metadata = None
outputs = None
thumbnails = None
if progress.result is not None:
metadata = progress.result.metadata
outputs = progress.result.outputs
thumbnails = progress.result.thumbnails
records.append(
image_reply(
server,
job_name,
status,
stages=progress.stages,
steps=progress.steps,
tiles=progress.tiles,
metadata=metadata,
outputs=outputs,
thumbnails=thumbnails,
reason=progress.reason,
)
)
else:
records.append(image_reply(server, job_name, status, queue=queue))
return jsonify(records)


def register_api_routes(app: Flask, server: ServerContext, pool: DevicePoolExecutor):
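    """Register the API routes on the Flask app and return the new routes."""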
return [
app.route("/api")(wrap_route(introspect, server, app=app)),
# job routes
app.route("/api/job", methods=["POST"])(
wrap_route(job_create, server, pool=pool)
),
app.route("/api/job/cancel", methods=["PUT"])(
wrap_route(job_cancel, server, pool=pool)
),
app.route("/api/job/status")(wrap_route(job_status, server, pool=pool)),
# settings routes
app.route("/api/settings/filters")(wrap_route(list_filters, server)),
app.route("/api/settings/masks")(wrap_route(list_mask_filters, server)),
app.route("/api/settings/models")(wrap_route(list_models, server)),
app.route("/api/settings/noises")(wrap_route(list_noise_sources, server)),
app.route("/api/settings/params")(wrap_route(list_params, server)),
app.route("/api/settings/pipelines")(wrap_route(list_pipelines, server)),
app.route("/api/settings/platforms")(wrap_route(list_platforms, server)),
app.route("/api/settings/schedulers")(wrap_route(list_schedulers, server)),
app.route("/api/settings/strings")(wrap_route(list_extra_strings, server)),
app.route("/api/settings/wildcards")(wrap_route(list_wildcards, server)),
# legacy job routes
app.route("/api/img2img", methods=["POST"])(
wrap_route(img2img, server, pool=pool)
),
app.route("/api/txt2img", methods=["POST"])(
wrap_route(txt2img, server, pool=pool)
),
app.route("/api/txt2txt", methods=["POST"])(
wrap_route(txt2txt, server, pool=pool)
),
app.route("/api/inpaint", methods=["POST"])(
wrap_route(inpaint, server, pool=pool)
),
app.route("/api/upscale", methods=["POST"])(
wrap_route(upscale, server, pool=pool)
),
app.route("/api/chain", methods=["POST"])(wrap_route(chain, server, pool=pool)),
app.route("/api/blend", methods=["POST"])(wrap_route(blend, server, pool=pool)),
# deprecated routes
app.route("/api/cancel", methods=["PUT"])(
wrap_route(cancel, server, pool=pool)
),
app.route("/api/ready")(wrap_route(ready, server, pool=pool)),
]