2023-02-26 16:15:12 +00:00
|
|
|
from io import BytesIO
|
|
|
|
from logging import getLogger
|
|
|
|
from os import path
|
|
|
|
|
|
|
|
import yaml
|
|
|
|
from flask import Flask, jsonify, make_response, request, url_for
|
|
|
|
from jsonschema import validate
|
|
|
|
from PIL import Image
|
|
|
|
|
2023-02-26 20:15:30 +00:00
|
|
|
from ..chain import CHAIN_STAGES, ChainPipeline
|
2023-02-26 16:15:12 +00:00
|
|
|
from ..diffusion.load import get_pipeline_schedulers
|
|
|
|
from ..diffusion.run import (
|
|
|
|
run_blend_pipeline,
|
|
|
|
run_img2img_pipeline,
|
|
|
|
run_inpaint_pipeline,
|
|
|
|
run_txt2img_pipeline,
|
|
|
|
run_upscale_pipeline,
|
|
|
|
)
|
2023-02-26 20:15:30 +00:00
|
|
|
from ..image import valid_image # mask filters; noise sources
|
2023-02-26 16:15:12 +00:00
|
|
|
from ..output import json_params, make_output_name
|
2023-02-26 20:15:30 +00:00
|
|
|
from ..params import Border, StageParams, TileOrder, UpscaleParams
|
2023-02-26 16:15:12 +00:00
|
|
|
from ..transformers import run_txt2txt_pipeline
|
|
|
|
from ..utils import (
|
|
|
|
base_join,
|
|
|
|
get_and_clamp_float,
|
|
|
|
get_and_clamp_int,
|
|
|
|
get_from_list,
|
|
|
|
get_from_map,
|
|
|
|
get_not_empty,
|
|
|
|
get_size,
|
|
|
|
)
|
2023-02-26 20:15:30 +00:00
|
|
|
from ..worker.pool import DevicePoolExecutor
|
|
|
|
from .config import (
|
|
|
|
get_available_platforms,
|
|
|
|
get_config_params,
|
|
|
|
get_config_value,
|
|
|
|
get_correction_models,
|
|
|
|
get_diffusion_models,
|
|
|
|
get_inversion_models,
|
|
|
|
get_mask_filters,
|
|
|
|
get_noise_sources,
|
|
|
|
get_upscaling_models,
|
|
|
|
)
|
|
|
|
from .context import ServerContext
|
|
|
|
from .params import border_from_request, pipeline_from_request, upscale_from_request
|
|
|
|
from .utils import wrap_route
|
2023-02-26 16:15:12 +00:00
|
|
|
|
|
|
|
# module-level logger, named after this module
logger = getLogger(name=__name__)
|
|
|
|
|
|
|
|
|
|
|
|
def ready_reply(ready: bool, progress: int = 0):
    """Build the JSON readiness payload used by the job status routes.

    :param ready: whether the job is finished.
    :param progress: progress value reported alongside ``ready``.
    :returns: a Flask JSON response with ``progress`` and ``ready`` keys.
    """
    payload = {"progress": progress, "ready": ready}
    return jsonify(payload)
|
|
|
|
|
|
|
|
|
|
|
|
def error_reply(err: str):
    """Build a 400 Bad Request response whose JSON body is ``{"error": err}``.

    :param err: human-readable error message for the client.
    :returns: a Flask response object with status code 400.
    """
    reply = make_response(jsonify(error=err))
    reply.status_code = 400
    return reply
|
|
|
|
|
|
|
|
|
|
|
|
def url_from_rule(rule) -> str:
    """Render a URL rule as a path with ``:name`` placeholders.

    Every argument of the rule is filled with ``":" + name`` so the
    generated path shows where each variable segment belongs.

    :param rule: a routing rule exposing ``endpoint`` and ``arguments``.
    :returns: the URL produced by ``url_for`` for those placeholder values.
    """
    placeholders = {name: f":{name}" for name in rule.arguments}
    return url_for(rule.endpoint, **placeholders)
|
|
|
|
|
|
|
|
|
|
|
|
def introspect(context: ServerContext, app: Flask):
    """Describe the API: its name plus every registered route.

    :param context: server context (unused, accepted for a uniform handler signature).
    :param app: the Flask app whose URL map is listed.
    :returns: a dict with ``name`` and ``routes``. Each route has a ``path``
        (placeholders rendered by ``url_from_rule``) and its HTTP ``methods``
        as a sorted list.
    """
    return {
        "name": "onnx-web",
        "routes": [
            # sorted() returns a new list; list.sort() returns None, which
            # previously made every "methods" entry null
            {"path": url_from_rule(rule), "methods": sorted(rule.methods or ())}
            for rule in app.url_map.iter_rules()
        ],
    }
|
|
|
|
|
|
|
|
|
|
|
|
def list_mask_filters(context: ServerContext):
    """Return the names (keys) of the available mask filters as a JSON list."""
    filters = get_mask_filters()
    filter_names = list(filters.keys())
    return jsonify(filter_names)
|
|
|
|
|
|
|
|
|
|
|
|
def list_models(context: ServerContext):
    """Return the available models, grouped by kind, as a JSON object.

    The keys are ``correction``, ``diffusion``, ``inversion`` and
    ``upscaling``. Each holds whatever the matching ``get_*_models``
    config helper returns.
    """
    models = {
        "correction": get_correction_models(),
        "diffusion": get_diffusion_models(),
        "inversion": get_inversion_models(),
        "upscaling": get_upscaling_models(),
    }
    return jsonify(models)
|
|
|
|
|
|
|
|
|
|
|
|
def list_noise_sources(context: ServerContext):
    """Return the names (keys) of the available noise sources as a JSON list."""
    sources = get_noise_sources()
    source_names = list(sources.keys())
    return jsonify(source_names)
|
|
|
|
|
|
|
|
|
|
|
|
def list_params(context: ServerContext):
    """Return the server's parameter configuration (``get_config_params``) as JSON."""
    config_params = get_config_params()
    return jsonify(config_params)
|
|
|
|
|
|
|
|
|
|
|
|
def list_platforms(context: ServerContext):
    """Return the ``device`` of every available platform as a JSON list."""
    devices = [platform.device for platform in get_available_platforms()]
    return jsonify(devices)
|
|
|
|
|
|
|
|
|
|
|
|
def list_schedulers(context: ServerContext):
    """Return the names (keys) of the available pipeline schedulers as a JSON list."""
    schedulers = get_pipeline_schedulers()
    scheduler_names = list(schedulers.keys())
    return jsonify(scheduler_names)
|
|
|
|
|
|
|
|
|
|
|
|
def img2img(context: ServerContext, pool: DevicePoolExecutor):
    """Queue an img2img job for the uploaded ``source`` image.

    Reads the ``source`` upload and the common pipeline and upscale
    parameters. Reads ``strength`` through ``get_and_clamp_float`` using the
    configured strength default/max/min. Submits ``run_img2img_pipeline`` to
    the device pool and replies with the JSON built by ``json_params``.

    Replies with a 400 error when no ``source`` file was uploaded.
    """
    if "source" not in request.files:
        return error_reply("source image is required")

    upload = request.files.get("source")
    source_image = Image.open(BytesIO(upload.read())).convert("RGB")

    device, params, size = pipeline_from_request(context)
    upscale = upscale_from_request()

    query = request.args
    default_strength = get_config_value("strength")
    max_strength = get_config_value("strength", "max")
    min_strength = get_config_value("strength", "min")
    strength = get_and_clamp_float(
        query, "strength", default_strength, max_strength, min_strength
    )

    output = make_output_name(context, "img2img", params, size, extras=(strength,))
    job_name = output[0]
    logger.info("img2img job queued for: %s", job_name)

    source_image = valid_image(source_image, min_dims=size, max_dims=size)
    job_args = (context, params, output, upscale, source_image, strength)
    pool.submit(job_name, run_img2img_pipeline, *job_args, needs_device=device)

    return jsonify(json_params(output, params, size, upscale=upscale))
|
|
|
|
|
|
|
|
|
|
|
|
def txt2img(context: ServerContext, pool: DevicePoolExecutor):
    """Queue a txt2img job built from the request parameters.

    Submits ``run_txt2img_pipeline`` to the device pool and replies with the
    JSON built by ``json_params`` for the queued output.
    """
    device, params, size = pipeline_from_request(context)
    upscale = upscale_from_request()

    output = make_output_name(context, "txt2img", params, size)
    job_name = output[0]
    logger.info("txt2img job queued for: %s", job_name)

    job_args = (context, params, size, output, upscale)
    pool.submit(job_name, run_txt2img_pipeline, *job_args, needs_device=device)

    return jsonify(json_params(output, params, size, upscale=upscale))
|
|
|
|
|
|
|
|
|
|
|
|
def inpaint(context: ServerContext, pool: DevicePoolExecutor):
    """Queue an inpaint job for the uploaded ``source`` and ``mask`` images.

    Besides the common pipeline, border and upscale parameters, it reads
    ``fillColor``, ``filter``, ``noise`` and ``tileOrder`` from the query
    string. Submits ``run_inpaint_pipeline`` to the device pool and replies
    with the JSON built by ``json_params`` (including the border).

    Replies with a 400 error when either upload is missing.
    """
    if "source" not in request.files:
        return error_reply("source image is required")
    if "mask" not in request.files:
        return error_reply("mask image is required")

    source_upload = request.files.get("source")
    source_image = Image.open(BytesIO(source_upload.read())).convert("RGB")
    mask_upload = request.files.get("mask")
    mask_image = Image.open(BytesIO(mask_upload.read())).convert("RGB")

    device, params, size = pipeline_from_request(context)
    expand = border_from_request()
    upscale = upscale_from_request()

    query = request.args
    fill_color = get_not_empty(query, "fillColor", "white")
    mask_filter = get_from_map(query, "filter", get_mask_filters(), "none")
    noise_source = get_from_map(query, "noise", get_noise_sources(), "histogram")
    tile_orders = [TileOrder.grid, TileOrder.kernel, TileOrder.spiral]
    tile_order = get_from_list(query, "tileOrder", tile_orders)

    # every option that affects the result is passed to make_output_name as extras
    name_extras = (
        expand.left,
        expand.right,
        expand.top,
        expand.bottom,
        mask_filter.__name__,
        noise_source.__name__,
        fill_color,
        tile_order,
    )
    output = make_output_name(context, "inpaint", params, size, extras=name_extras)
    job_name = output[0]
    logger.info("inpaint job queued for: %s", job_name)

    source_image = valid_image(source_image, min_dims=size, max_dims=size)
    mask_image = valid_image(mask_image, min_dims=size, max_dims=size)
    job_args = (
        context,
        params,
        size,
        output,
        upscale,
        source_image,
        mask_image,
        expand,
        noise_source,
        mask_filter,
        fill_color,
        tile_order,
    )
    pool.submit(job_name, run_inpaint_pipeline, *job_args, needs_device=device)

    return jsonify(json_params(output, params, size, upscale=upscale, border=expand))
|
|
|
|
|
|
|
|
|
|
|
|
def upscale(context: ServerContext, pool: DevicePoolExecutor):
    """Queue an upscale job for the uploaded ``source`` image.

    Submits ``run_upscale_pipeline`` to the device pool and replies with the
    JSON built by ``json_params`` for the queued output.

    Replies with a 400 error when no ``source`` file was uploaded.
    """
    if "source" not in request.files:
        return error_reply("source image is required")

    upload = request.files.get("source")
    source_image = Image.open(BytesIO(upload.read())).convert("RGB")

    device, params, size = pipeline_from_request(context)
    # named so it does not shadow this handler's own name
    upscale_params = upscale_from_request()

    output = make_output_name(context, "upscale", params, size)
    job_name = output[0]
    logger.info("upscale job queued for: %s", job_name)

    source_image = valid_image(source_image, min_dims=size, max_dims=size)
    job_args = (context, params, size, output, upscale_params, source_image)
    pool.submit(job_name, run_upscale_pipeline, *job_args, needs_device=device)

    return jsonify(json_params(output, params, size, upscale=upscale_params))
|
|
|
|
|
|
|
|
|
|
|
|
def chain(context: ServerContext, pool: DevicePoolExecutor):
    """Queue a multi-stage chain pipeline described in the request body.

    The definition is read from the ``chain`` form field, or else from the
    ``chain`` uploaded file. It is parsed with ``yaml.safe_load`` and
    validated against ``./schemas/chain.yaml`` with ``jsonschema.validate``.
    Each entry in ``stages`` selects a callback from ``CHAIN_STAGES`` by its
    ``type``. Per-stage images may be uploaded as ``source:<stage name>`` and
    ``mask:<stage name>``.

    Replies with a 400 error when the body is missing or a stage names an
    unknown type. Schema violations propagate as exceptions from ``validate``.
    """
    logger.debug(
        "chain pipeline request: %s, %s", request.form.keys(), request.files.keys()
    )
    body = request.form.get("chain") or request.files.get("chain")
    if body is None:
        return error_reply("chain pipeline must have a body")

    # safe_load only constructs plain YAML types, which matters for client input
    data = yaml.safe_load(body)

    # NOTE: relative path, resolved against the process working directory
    with open("./schemas/chain.yaml", "r", encoding="utf-8") as f:
        schema = yaml.safe_load(f.read())

    logger.debug("validating chain request: %s against %s", data, schema)
    validate(data, schema)

    # get defaults from the regular parameters
    device, params, size = pipeline_from_request(context)
    output = make_output_name(context, "chain", params, size)
    job_name = output[0]

    pipeline = ChainPipeline()
    for stage_data in data.get("stages", []):
        stage_type = stage_data.get("type")
        if stage_type not in CHAIN_STAGES:
            # reject unknown or missing stage types as a client error instead of a KeyError
            return error_reply("unknown chain stage type: %s" % (stage_type,))

        callback = CHAIN_STAGES[stage_type]
        kwargs = stage_data.get("params", {})
        logger.info("request stage: %s, %s", callback.__name__, kwargs)

        stage = StageParams(
            stage_data.get("name", callback.__name__),
            tile_size=get_size(kwargs.get("tile_size")),
            outscale=get_and_clamp_int(kwargs, "outscale", 1, 4),
        )

        # convert raw request values into the parameter objects the stages take
        if "border" in kwargs:
            kwargs["border"] = Border.even(int(kwargs.get("border")))

        if "upscale" in kwargs:
            # local name chosen so it does not shadow the module-level upscale handler
            stage_upscale = UpscaleParams(kwargs.get("upscale"))
            kwargs["upscale"] = stage_upscale

        stage_source_name = "source:%s" % (stage.name)
        stage_mask_name = "mask:%s" % (stage.name)

        if stage_source_name in request.files:
            logger.debug(
                "loading source image %s for pipeline stage %s",
                stage_source_name,
                stage.name,
            )
            source_file = request.files.get(stage_source_name)
            source = Image.open(BytesIO(source_file.read())).convert("RGB")
            source = valid_image(source, max_dims=(size.width, size.height))
            kwargs["stage_source"] = source

        if stage_mask_name in request.files:
            logger.debug(
                "loading mask image %s for pipeline stage %s",
                stage_mask_name,
                stage.name,
            )
            mask_file = request.files.get(stage_mask_name)
            mask = Image.open(BytesIO(mask_file.read())).convert("RGB")
            mask = valid_image(mask, max_dims=(size.width, size.height))
            kwargs["stage_mask"] = mask

        pipeline.append((callback, stage, kwargs))

    logger.info("running chain pipeline with %s stages", len(pipeline.stages))

    # build and run chain pipeline, starting from a blank image of the requested size
    empty_source = Image.new("RGB", (size.width, size.height))
    pool.submit(
        job_name,
        pipeline,
        context,
        params,
        empty_source,
        output=output[0],
        size=size,
        needs_device=device,
    )

    return jsonify(json_params(output, params, size))
|
|
|
|
|
|
|
|
|
|
|
|
def blend(context: ServerContext, pool: DevicePoolExecutor):
    """Queue a blend job that combines uploaded source images using a mask.

    Expects a ``mask`` upload plus ``source:0`` and ``source:1`` uploads.
    All images are converted to RGBA, and each source is validated against
    the mask's size. Submits ``run_blend_pipeline`` to the device pool and
    replies with the JSON built by ``json_params``.

    Replies with a 400 error when the mask or any source image is missing.
    """
    if "mask" not in request.files:
        return error_reply("mask image is required")

    mask_file = request.files.get("mask")
    mask = Image.open(BytesIO(mask_file.read())).convert("RGBA")
    mask = valid_image(mask)

    max_sources = 2
    sources = []

    for i in range(max_sources):
        source_name = "source:%s" % (i)
        source_file = request.files.get(source_name)
        if source_file is None:
            # a missing upload would otherwise raise AttributeError on .read()
            return error_reply("%s image is required" % (source_name))

        source = Image.open(BytesIO(source_file.read())).convert("RGBA")
        source = valid_image(source, mask.size, mask.size)
        sources.append(source)

    device, params, size = pipeline_from_request(context)
    upscale = upscale_from_request()

    # output name and log were copied from the upscale handler; label them as blend
    output = make_output_name(context, "blend", params, size)
    job_name = output[0]
    logger.info("blend job queued for: %s", job_name)

    pool.submit(
        job_name,
        run_blend_pipeline,
        context,
        params,
        size,
        output,
        upscale,
        sources,
        mask,
        needs_device=device,
    )

    return jsonify(json_params(output, params, size, upscale=upscale))
|
|
|
|
|
|
|
|
|
|
|
|
def txt2txt(context: ServerContext, pool: DevicePoolExecutor):
    """Queue a txt2txt job built from the request parameters.

    Submits ``run_txt2txt_pipeline`` to the device pool under the first
    output name and replies with the JSON built by ``json_params``.
    """
    device, params, size = pipeline_from_request(context)

    output = make_output_name(context, "txt2txt", params, size)
    # the pool is keyed by the first output name, as in the other handlers,
    # so that ready/cancel lookups by name can find this job
    job_name = output[0]
    logger.info("txt2txt job queued for: %s", job_name)

    pool.submit(
        job_name,
        run_txt2txt_pipeline,
        context,
        params,
        size,
        output,
        needs_device=device,
    )

    return jsonify(json_params(output, params, size))
|
|
|
|
|
|
|
|
|
|
|
|
def cancel(context: ServerContext, pool: DevicePoolExecutor):
    """Ask the device pool to cancel the job named by the ``output`` query parameter.

    Replies with ``ready_reply`` carrying the value returned by
    ``pool.cancel``. Replies with a 400 error when ``output`` is missing or
    empty.
    """
    output_file = request.args.get("output", None)
    if not output_file:
        return error_reply("output name parameter is required")

    # named so it does not shadow this handler's own name
    cancelled = pool.cancel(output_file)

    return ready_reply(cancelled)
|
|
|
|
|
|
|
|
|
|
|
|
def ready(context: ServerContext, pool: DevicePoolExecutor):
    """Report whether the job named by the ``output`` query parameter is done.

    Asks the device pool first. When the pool returns ``None`` for the
    name, it falls back to checking whether the file already exists under
    ``context.output_path``.

    Replies with a 400 error when ``output`` is missing or empty. A missing
    name made ``base_join`` fail. An empty name resolved to the output
    directory itself, which always "exists".
    """
    output_file = request.args.get("output", None)
    if not output_file:
        return error_reply("output name parameter is required")

    done, progress = pool.done(output_file)

    if done is None:
        # the pool has no status for this name; check for a finished file on disk
        output = base_join(context.output_path, output_file)
        if path.exists(output):
            return ready_reply(True)

    return ready_reply(done, progress=progress)
|
|
|
|
|
|
|
|
|
|
|
|
def status(context: ServerContext, pool: DevicePoolExecutor):
    """Return the device pool's ``status()`` report as JSON."""
    pool_status = pool.status()
    return jsonify(pool_status)
|
|
|
|
|
|
|
|
|
|
|
|
def register_api_routes(app: Flask, context: ServerContext, pool: DevicePoolExecutor):
    """Register every ``/api`` route on ``app`` and return the registered views.

    The introspection route receives the app. Settings routes receive only
    the server context. Job and status routes also receive the device pool.
    Routes are registered in the order listed, which is also the order of
    the returned list.
    """

    def settings_view(handler):
        # settings handlers only take the server context
        return wrap_route(handler, context)

    def job_view(handler):
        # job and status handlers also take the device pool
        return wrap_route(handler, context, pool=pool)

    return [
        app.route("/api")(wrap_route(introspect, context, app=app)),
        app.route("/api/settings/masks")(settings_view(list_mask_filters)),
        app.route("/api/settings/models")(settings_view(list_models)),
        app.route("/api/settings/noises")(settings_view(list_noise_sources)),
        app.route("/api/settings/params")(settings_view(list_params)),
        app.route("/api/settings/platforms")(settings_view(list_platforms)),
        app.route("/api/settings/schedulers")(settings_view(list_schedulers)),
        app.route("/api/img2img", methods=["POST"])(job_view(img2img)),
        app.route("/api/txt2img", methods=["POST"])(job_view(txt2img)),
        app.route("/api/txt2txt", methods=["POST"])(job_view(txt2txt)),
        app.route("/api/inpaint", methods=["POST"])(job_view(inpaint)),
        app.route("/api/upscale", methods=["POST"])(job_view(upscale)),
        app.route("/api/chain", methods=["POST"])(job_view(chain)),
        app.route("/api/blend", methods=["POST"])(job_view(blend)),
        app.route("/api/cancel", methods=["PUT"])(job_view(cancel)),
        app.route("/api/ready")(job_view(ready)),
        app.route("/api/status")(job_view(status)),
    ]
|