import gc
from functools import cmp_to_key
from glob import glob
from io import BytesIO
from logging import getLogger
from os import makedirs, path
from typing import List, Tuple

import numpy as np
import torch
import yaml
from diffusers import (
    DDIMScheduler,
    DDPMScheduler,
    DPMSolverMultistepScheduler,
    DPMSolverSinglestepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    HeunDiscreteScheduler,
    KarrasVeScheduler,
    KDPM2AncestralDiscreteScheduler,
    KDPM2DiscreteScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
)
from flask import Flask, jsonify, make_response, request, send_from_directory, url_for
from flask_cors import CORS
from jsonschema import validate
from onnxruntime import get_available_providers
from PIL import Image

from .chain import (
    ChainPipeline,
    blend_img2img,
    blend_inpaint,
    correct_codeformer,
    correct_gfpgan,
    persist_disk,
    persist_s3,
    reduce_crop,
    reduce_thumbnail,
    source_noise,
    source_txt2img,
    upscale_outpaint,
    upscale_resrgan,
    upscale_stable_diffusion,
)
from .device_pool import DevicePoolExecutor
from .diffusion.run import (
    run_img2img_pipeline,
    run_inpaint_pipeline,
    run_txt2img_pipeline,
    run_upscale_pipeline,
)
from .image import (  # mask filters; noise sources
    mask_filter_gaussian_multiply,
    mask_filter_gaussian_screen,
    mask_filter_none,
    noise_source_fill_edge,
    noise_source_fill_mask,
    noise_source_gaussian,
    noise_source_histogram,
    noise_source_normal,
    noise_source_uniform,
)
from .output import json_params, make_output_name
from .params import Border, DeviceParams, ImageParams, Size, StageParams, UpscaleParams
from .utils import (
    ServerContext,
    base_join,
    get_and_clamp_float,
    get_and_clamp_int,
    get_from_list,
    get_from_map,
    get_not_empty,
    get_size,
    is_debug,
)

logger = getLogger(__name__)

# config caching
config_params = {}

# pipeline params
platform_providers = {
    "amd": "DmlExecutionProvider",
    "cpu": "CPUExecutionProvider",
    "cuda": "CUDAExecutionProvider",
    "directml": "DmlExecutionProvider",
    "nvidia": "CUDAExecutionProvider",
    "rocm": "ROCMExecutionProvider",
}

pipeline_schedulers = {
    "ddim": DDIMScheduler,
    "ddpm": DDPMScheduler,
    "dpm-multi": DPMSolverMultistepScheduler,
    "dpm-single": DPMSolverSinglestepScheduler,
    "euler": EulerDiscreteScheduler,
    "euler-a": EulerAncestralDiscreteScheduler,
    "heun": HeunDiscreteScheduler,
    "k-dpm-2-a": KDPM2AncestralDiscreteScheduler,
    "k-dpm-2": KDPM2DiscreteScheduler,
    "karras-ve": KarrasVeScheduler,
    "lms-discrete": LMSDiscreteScheduler,
    "pndm": PNDMScheduler,
}

noise_sources = {
    "fill-edge": noise_source_fill_edge,
    "fill-mask": noise_source_fill_mask,
    "gaussian": noise_source_gaussian,
    "histogram": noise_source_histogram,
    "normal": noise_source_normal,
    "uniform": noise_source_uniform,
}

mask_filters = {
    "none": mask_filter_none,
    "gaussian-multiply": mask_filter_gaussian_multiply,
    "gaussian-screen": mask_filter_gaussian_screen,
}

chain_stages = {
    "blend-img2img": blend_img2img,
    "blend-inpaint": blend_inpaint,
    "correct-codeformer": correct_codeformer,
    "correct-gfpgan": correct_gfpgan,
    "persist-disk": persist_disk,
    "persist-s3": persist_s3,
    "reduce-crop": reduce_crop,
    "reduce-thumbnail": reduce_thumbnail,
    "source-noise": source_noise,
    "source-txt2img": source_txt2img,
    "upscale-outpaint": upscale_outpaint,
    "upscale-resrgan": upscale_resrgan,
    "upscale-stable-diffusion": upscale_stable_diffusion,
}

# available ORT providers
available_platforms: List[DeviceParams] = []

# loaded from model_path
diffusion_models = []
correction_models = []
upscaling_models = []


def get_config_value(key: str, subkey: str = "default"):
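    """Look up a parameter from the config, returning the "default" subkey
    unless another subkey is requested."""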
return config_params.get(key).get(subkey)


def url_from_rule(rule) -> str:
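    """Rebuild a printable URL template from a Flask routing rule."""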
options = {}
for arg in rule.arguments:
options[arg] = ":%s" % (arg)
return url_for(rule.endpoint, **options)


def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
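    """Parse the device, image, and size parameters shared by all of the
    image endpoints from the query string of the current request."""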
    user = request.remote_addr

    # platform stuff
    device_name = request.args.get("platform", available_platforms[0].device)
    device = None
    for platform in available_platforms:
        if platform.device == device_name:
            device = platform
            break

    if device is None:
        logger.warning("unknown platform: %s", device_name)
        device = available_platforms[0]

    # pipeline stuff
    lpw = get_not_empty(request.args, "lpw", "false") == "true"
    model = get_not_empty(request.args, "model", get_config_value("model"))
    model_path = get_model_path(model)
    scheduler = get_from_map(
        request.args, "scheduler", pipeline_schedulers, get_config_value("scheduler")
    )

    # image params
    prompt = get_not_empty(request.args, "prompt", get_config_value("prompt"))
    negative_prompt = request.args.get("negativePrompt", None)

    if negative_prompt is not None and negative_prompt.strip() == "":
        negative_prompt = None

    cfg = get_and_clamp_float(
        request.args,
        "cfg",
        get_config_value("cfg"),
        get_config_value("cfg", "max"),
        get_config_value("cfg", "min"),
    )
    steps = get_and_clamp_int(
        request.args,
        "steps",
        get_config_value("steps"),
        get_config_value("steps", "max"),
        get_config_value("steps", "min"),
    )
    height = get_and_clamp_int(
        request.args,
        "height",
        get_config_value("height"),
        get_config_value("height", "max"),
        get_config_value("height", "min"),
    )
    width = get_and_clamp_int(
        request.args,
        "width",
        get_config_value("width"),
        get_config_value("width", "max"),
        get_config_value("width", "min"),
    )

    seed = int(request.args.get("seed", -1))
    if seed == -1:
        # this one can safely use np.random because it produces a single value
        seed = np.random.randint(np.iinfo(np.int32).max)

    logger.info(
        "request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s",
        user,
        steps,
        scheduler.__name__,
        model_path,
        device.provider,
        width,
        height,
        cfg,
        seed,
        prompt,
    )

    params = ImageParams(
        model_path, scheduler, prompt, negative_prompt, cfg, steps, seed, lpw=lpw
    )
    size = Size(width, height)
    return (device, params, size)


def border_from_request() -> Border:
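    """Parse the outpainting border from the request, clamping each side to
    the maximum image size."""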
    left = get_and_clamp_int(
        request.args, "left", 0, get_config_value("width", "max"), 0
    )
    right = get_and_clamp_int(
        request.args, "right", 0, get_config_value("width", "max"), 0
    )
    top = get_and_clamp_int(
        request.args, "top", 0, get_config_value("height", "max"), 0
    )
    bottom = get_and_clamp_int(
        request.args, "bottom", 0, get_config_value("height", "max"), 0
    )

    return Border(left, right, top, bottom)


def upscale_from_request() -> UpscaleParams:
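    """Parse the upscaling and face correction parameters from the request."""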
    denoise = get_and_clamp_float(request.args, "denoise", 0.5, 1.0, 0.0)
    scale = get_and_clamp_int(request.args, "scale", 1, 4, 1)
    outscale = get_and_clamp_int(request.args, "outscale", 1, 4, 1)
    upscaling = get_from_list(request.args, "upscaling", upscaling_models)
    correction = get_from_list(request.args, "correction", correction_models)
    faces = get_not_empty(request.args, "faces", "false") == "true"
    face_strength = get_and_clamp_float(request.args, "faceStrength", 0.5, 1.0, 0.0)

    return UpscaleParams(
        upscaling,
        correction_model=correction,
        denoise=denoise,
        faces=faces,
        face_strength=face_strength,
        format="onnx",
        outscale=outscale,
        scale=scale,
    )


def check_paths(context: ServerContext):
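    """Make sure the model path exists and create the output path if needed."""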
    if not path.exists(context.model_path):
        raise RuntimeError("model path must exist")

    if not path.exists(context.output_path):
        makedirs(context.output_path)


def get_model_name(model: str) -> str:
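    """Reduce a model path to its bare name, without directory or extension."""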
    base = path.basename(model)
    (file, _ext) = path.splitext(base)
    return file


def load_models(context: ServerContext):
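    """Scan the model path and group the ONNX models found there by type."""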
    global diffusion_models
    global correction_models
    global upscaling_models

    diffusion_models = [
        get_model_name(f) for f in glob(path.join(context.model_path, "diffusion-*"))
    ]
    diffusion_models.extend(
        [
            get_model_name(f)
            for f in glob(path.join(context.model_path, "stable-diffusion-*"))
        ]
    )
    diffusion_models = list(set(diffusion_models))
    diffusion_models.sort()

    correction_models = [
        get_model_name(f) for f in glob(path.join(context.model_path, "correction-*"))
    ]
    correction_models = list(set(correction_models))
    correction_models.sort()

    upscaling_models = [
        get_model_name(f) for f in glob(path.join(context.model_path, "upscaling-*"))
    ]
    upscaling_models = list(set(upscaling_models))
    upscaling_models.sort()


def load_params(context: ServerContext):
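    """Load the user-visible parameter config from params.json, applying the
    default platform override from the server context when one is set."""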
    global config_params

    params_file = path.join(context.params_path, "params.json")
    with open(params_file, "r") as f:
        config_params = yaml.safe_load(f)

        if "platform" in config_params and context.default_platform is not None:
            logger.info("overriding default platform to %s", context.default_platform)
            config_platform = config_params.get("platform")
            config_platform["default"] = context.default_platform


def load_platforms():
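    """Build the list of available acceleration platforms from the ONNX
    runtime providers, keeping CPU last as the fallback device."""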
    global available_platforms

    providers = get_available_providers()
    for potential in platform_providers:
        if (
            platform_providers[potential] in providers
            and potential not in context.block_platforms
        ):
            if potential == "cuda":
                for i in range(torch.cuda.device_count()):
                    available_platforms.append(
                        DeviceParams(
                            potential,
                            platform_providers[potential],
                            {
                                "device_id": i,
                            },
                        )
                    )
            else:
                available_platforms.append(
                    DeviceParams(potential, platform_providers[potential])
                )

    # make sure CPU is last on the list
    def cpu_last(a: DeviceParams, b: DeviceParams):
        if a.device == "cpu" and b.device == "cpu":
            return 0

        if a.device == "cpu":
            return 1

        return -1

    available_platforms = sorted(available_platforms, key=cmp_to_key(cpu_last))

    logger.info(
        "available acceleration platforms: %s",
        ", ".join([str(p) for p in available_platforms]),
    )


context = ServerContext.from_environ()
check_paths(context)
load_models(context)
load_params(context)
load_platforms()

app = Flask(__name__)
CORS(app, origins=context.cors_origin)
executor = DevicePoolExecutor(available_platforms)

if is_debug():
    gc.set_debug(gc.DEBUG_STATS)


def ready_reply(ready: bool, progress: int = 0):
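    """Build the JSON progress body shared by the ready and cancel endpoints."""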
    return jsonify(
        {
            "progress": progress,
            "ready": ready,
        }
    )


def error_reply(err: str):
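    """Build a JSON error response with a 400 status code."""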
    response = make_response(
        jsonify(
            {
                "error": err,
            }
        )
    )
    response.status_code = 400
    return response


def get_model_path(model: str):
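    """Resolve a model name to a path within the server's model directory."""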
return base_join(context.model_path, model)


def serve_bundle_file(filename="index.html"):
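    """Serve a file from the client bundle directory."""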
return send_from_directory(path.join("..", context.bundle_path), filename)


# routes
@app.route("/")
def index():
return serve_bundle_file()


@app.route("/<path:filename>")
def index_path(filename):
return serve_bundle_file(filename)


@app.route("/api")
def introspect():
return {
2023-02-05 13:53:26 +00:00
"name": "onnx-web",
"routes": [
{"path": url_from_rule(rule), "methods": list(rule.methods).sort()}
for rule in app.url_map.iter_rules()
],
}


@app.route("/api/settings/masks")
def list_mask_filters():
return jsonify(list(mask_filters.keys()))


@app.route("/api/settings/models")
def list_models():
return jsonify(
{
"diffusion": diffusion_models,
"correction": correction_models,
"upscaling": upscaling_models,
}
)


@app.route("/api/settings/noises")
def list_noise_sources():
return jsonify(list(noise_sources.keys()))


@app.route("/api/settings/params")
def list_params():
return jsonify(config_params)


@app.route("/api/settings/platforms")
def list_platforms():
return jsonify([p.device for p in available_platforms])


@app.route("/api/settings/schedulers")
def list_schedulers():
return jsonify(list(pipeline_schedulers.keys()))


@app.route("/api/img2img", methods=["POST"])
def img2img():
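    """Run an img2img pipeline on an uploaded source image."""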
    if "source" not in request.files:
        return error_reply("source image is required")

    source_file = request.files.get("source")
    source_image = Image.open(BytesIO(source_file.read())).convert("RGB")

    device, params, size = pipeline_from_request()
    upscale = upscale_from_request()
    strength = get_and_clamp_float(
        request.args,
        "strength",
        get_config_value("strength"),
        get_config_value("strength", "max"),
        get_config_value("strength", "min"),
    )

    output = make_output_name(context, "img2img", params, size, extras=(strength,))
    logger.info("img2img job queued for: %s", output)

    source_image.thumbnail((size.width, size.height))
executor.submit(
output,
run_img2img_pipeline,
context,
params,
output,
upscale,
source_image,
strength,
)

    return jsonify(json_params(output, params, size, upscale=upscale))


@app.route("/api/txt2img", methods=["POST"])
def txt2img():
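    """Run a txt2img pipeline using only the prompt parameters."""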
    device, params, size = pipeline_from_request()
    upscale = upscale_from_request()

    output = make_output_name(context, "txt2img", params, size)
    logger.info("txt2img job queued for: %s", output)

    executor.submit(
        output, run_txt2img_pipeline, context, params, size, output, upscale
    )

    return jsonify(json_params(output, params, size, upscale=upscale))


@app.route("/api/inpaint", methods=["POST"])
def inpaint():
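    """Run an inpainting pipeline on an uploaded source image and mask."""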
    if "source" not in request.files:
        return error_reply("source image is required")

    if "mask" not in request.files:
        return error_reply("mask image is required")

    source_file = request.files.get("source")
    source_image = Image.open(BytesIO(source_file.read())).convert("RGB")

    mask_file = request.files.get("mask")
    mask_image = Image.open(BytesIO(mask_file.read())).convert("RGB")

    device, params, size = pipeline_from_request()
expand = border_from_request()
    upscale = upscale_from_request()

    fill_color = get_not_empty(request.args, "fillColor", "white")
    mask_filter = get_from_map(request.args, "filter", mask_filters, "none")
    noise_source = get_from_map(request.args, "noise", noise_sources, "histogram")
    strength = get_and_clamp_float(
        request.args,
        "strength",
        get_config_value("strength"),
        get_config_value("strength", "max"),
        get_config_value("strength", "min"),
    )
output = make_output_name(
context,
"inpaint",
params,
size,
extras=(
expand.left,
expand.right,
expand.top,
expand.bottom,
mask_filter.__name__,
noise_source.__name__,
strength,
fill_color,
        ),
    )
logger.info("inpaint job queued for: %s", output)
source_image.thumbnail((size.width, size.height))
mask_image.thumbnail((size.width, size.height))
executor.submit(
output,
run_inpaint_pipeline,
context,
params,
size,
output,
upscale,
source_image,
mask_image,
expand,
noise_source,
mask_filter,
strength,
fill_color,
)

    return jsonify(json_params(output, params, size, upscale=upscale, border=expand))


@app.route("/api/upscale", methods=["POST"])
def upscale():
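    """Run an upscaling pipeline, with optional face correction, on an
    uploaded source image."""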
    if "source" not in request.files:
        return error_reply("source image is required")

    source_file = request.files.get("source")
    source_image = Image.open(BytesIO(source_file.read())).convert("RGB")

    device, params, size = pipeline_from_request()
    upscale = upscale_from_request()

    output = make_output_name(context, "upscale", params, size)
    logger.info("upscale job queued for: %s", output)

    source_image.thumbnail((size.width, size.height))
executor.submit(
output,
run_upscale_pipeline,
context,
params,
size,
output,
upscale,
source_image,
)

    return jsonify(json_params(output, params, size, upscale=upscale))


@app.route("/api/chain", methods=["POST"])
def chain():
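    """Run a custom chain pipeline, parsed from the YAML or JSON body of the
    request and validated against ./schema.yaml."""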
logger.debug(
"chain pipeline request: %s, %s", request.form.keys(), request.files.keys()
)
body = request.form.get("chain") or request.files.get("chain")
    if body is None:
        return error_reply("chain pipeline must have a body")

    data = yaml.safe_load(body)
    with open("./schema.yaml", "r") as f:
        schema = yaml.safe_load(f.read())

    logger.info("validating chain request: %s against %s", data, schema)
validate(data, schema)

    # get defaults from the regular parameters
    device, params, size = pipeline_from_request()
    output = make_output_name(context, "chain", params, size)

    pipeline = ChainPipeline()

    for stage_data in data.get("stages", []):
callback = chain_stages[stage_data.get("type")]
kwargs = stage_data.get("params", {})
logger.info("request stage: %s, %s", callback.__name__, kwargs)
stage = StageParams(
stage_data.get("name", callback.__name__),
tile_size=get_size(kwargs.get("tile_size")),
outscale=get_and_clamp_int(kwargs, "outscale", 1, 4),
)

        if "border" in kwargs:
border = Border.even(int(kwargs.get("border")))
kwargs["border"] = border

        if "upscale" in kwargs:
upscale = UpscaleParams(kwargs.get("upscale"))
kwargs["upscale"] = upscale

        stage_source_name = "source:%s" % (stage.name)
stage_mask_name = "mask:%s" % (stage.name)
if stage_source_name in request.files:
logger.debug(
"loading source image %s for pipeline stage %s",
stage_source_name,
stage.name,
)
            source_file = request.files.get(stage_source_name)
            source_image = Image.open(BytesIO(source_file.read())).convert("RGB")
            source_image.thumbnail((size.width, size.height))
            kwargs["source_image"] = source_image

        if stage_mask_name in request.files:
logger.debug(
"loading mask image %s for pipeline stage %s",
stage_mask_name,
stage.name,
)
            mask_file = request.files.get(stage_mask_name)
            mask_image = Image.open(BytesIO(mask_file.read())).convert("RGB")
            mask_image.thumbnail((size.width, size.height))
            kwargs["mask_image"] = mask_image

        pipeline.append((callback, stage, kwargs))

    logger.info("running chain pipeline with %s stages", len(pipeline.stages))

    # build and run chain pipeline
    empty_source = Image.new("RGB", (size.width, size.height))
executor.submit(
output, pipeline, context, params, empty_source, output=output, size=size
)
return jsonify(json_params(output, params, size))


@app.route("/api/cancel", methods=["PUT"])
def cancel():
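    """Cancel a pending job, looked up by its output name."""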
    output_file = request.args.get("output", None)
    cancelled = executor.cancel(output_file)
    return ready_reply(cancelled)


@app.route("/api/ready")
def ready():
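    """Report whether a job has finished, falling back to the output
    directory for images saved by previous runs."""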
output_file = request.args.get("output", None)
done, progress = executor.done(output_file)
if done is None:
file = base_join(context.output_path, output_file)
if path.exists(file):
return ready_reply(True)
return ready_reply(done, progress=progress)


@app.route("/api/status")
def status():
return jsonify(executor.status())


@app.route("/output/<path:filename>")
def output(filename: str):
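    """Serve a generated image from the output directory."""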
return send_from_directory(
path.join("..", context.output_path), filename, as_attachment=False
)