1
0
Fork 0
onnx-web/api/onnx_web/serve.py

878 lines
24 KiB
Python
Raw Normal View History

2023-02-05 13:53:26 +00:00
import gc
from functools import cmp_to_key
from glob import glob
from io import BytesIO
from logging import getLogger
from os import makedirs, path
from typing import Dict, List, Tuple, Union
2023-02-05 13:53:26 +00:00
import numpy as np
import torch
import yaml
from diffusers.utils.logging import disable_progress_bar
from flask import Flask, jsonify, make_response, request, send_from_directory, url_for
2023-01-14 16:18:53 +00:00
from flask_cors import CORS
2023-02-18 15:37:27 +00:00
from huggingface_hub.utils.tqdm import disable_progress_bars
from jsonschema import validate
from onnxruntime import get_available_providers
2023-02-05 13:53:26 +00:00
from PIL import Image
2023-01-28 23:09:19 +00:00
from .chain import (
2023-02-05 13:53:26 +00:00
ChainPipeline,
blend_img2img,
blend_inpaint,
correct_codeformer,
correct_gfpgan,
2023-01-28 15:08:59 +00:00
persist_disk,
persist_s3,
reduce_crop,
2023-02-05 13:53:26 +00:00
reduce_thumbnail,
source_noise,
source_txt2img,
upscale_outpaint,
upscale_resrgan,
upscale_stable_diffusion,
)
from .diffusion.load import pipeline_schedulers
from .diffusion.run import (
run_blend_pipeline,
run_img2img_pipeline,
run_inpaint_pipeline,
run_txt2img_pipeline,
run_upscale_pipeline,
)
2023-02-05 13:53:26 +00:00
from .image import ( # mask filters; noise sources
mask_filter_gaussian_multiply,
mask_filter_gaussian_screen,
mask_filter_none,
noise_source_fill_edge,
noise_source_fill_mask,
noise_source_gaussian,
noise_source_histogram,
noise_source_normal,
noise_source_uniform,
valid_image,
)
2023-02-05 13:53:26 +00:00
from .output import json_params, make_output_name
2023-02-12 00:10:36 +00:00
from .params import (
Border,
DeviceParams,
ImageParams,
Size,
StageParams,
TileOrder,
UpscaleParams,
)
2023-02-19 02:28:21 +00:00
from .server import DevicePoolExecutor, ServerContext, apply_patches
2023-02-14 13:40:06 +00:00
from .transformers import run_txt2txt_pipeline
2023-01-16 00:54:20 +00:00
from .utils import (
2023-02-02 14:33:33 +00:00
base_join,
2023-01-16 00:54:20 +00:00
get_and_clamp_float,
get_and_clamp_int,
2023-01-17 02:10:52 +00:00
get_from_list,
2023-01-16 00:54:20 +00:00
get_from_map,
get_not_empty,
2023-01-29 05:06:25 +00:00
get_size,
2023-02-05 13:53:26 +00:00
is_debug,
2023-01-16 00:04:10 +00:00
)
2023-01-28 23:09:19 +00:00
logger = getLogger(__name__)

# config caching
# filled by load_params() from params.json; maps a parameter name to sub-keys such as
# "default", "min" and "max", which are read through get_config_value()
config_params: Dict[str, Dict[str, Union[float, int, str]]] = {}

# pipeline params
# platform name (the "platform" request arg) -> ONNX Runtime execution provider name;
# load_platforms() probes the providers in this insertion order
platform_providers = {
    "cpu": "CPUExecutionProvider",
    "cuda": "CUDAExecutionProvider",
    "directml": "DmlExecutionProvider",
    "rocm": "ROCMExecutionProvider",
}
# noise source name (the inpaint "noise" request arg) -> noise source function
noise_sources = {
    "fill-edge": noise_source_fill_edge,
    "fill-mask": noise_source_fill_mask,
    "gaussian": noise_source_gaussian,
    "histogram": noise_source_histogram,
    "normal": noise_source_normal,
    "uniform": noise_source_uniform,
}
# mask filter name (the inpaint "filter" request arg) -> mask filter function
mask_filters = {
    "none": mask_filter_none,
    "gaussian-multiply": mask_filter_gaussian_multiply,
    "gaussian-screen": mask_filter_gaussian_screen,
}
# chain stage type (the "type" of each stage in a chain request) -> stage callback
chain_stages = {
    "blend-img2img": blend_img2img,
    "blend-inpaint": blend_inpaint,
    "correct-codeformer": correct_codeformer,
    "correct-gfpgan": correct_gfpgan,
    "persist-disk": persist_disk,
    "persist-s3": persist_s3,
    "reduce-crop": reduce_crop,
    "reduce-thumbnail": reduce_thumbnail,
    "source-noise": source_noise,
    "source-txt2img": source_txt2img,
    "upscale-outpaint": upscale_outpaint,
    "upscale-resrgan": upscale_resrgan,
    "upscale-stable-diffusion": upscale_stable_diffusion,
}

# Available ORT providers
# filled by load_platforms() at startup
available_platforms: List[DeviceParams] = []

# loaded from model_path
# filled by load_models() with the sorted, de-duplicated base names of matching files
correction_models: List[str] = []
diffusion_models: List[str] = []
inversion_models: List[str] = []
upscaling_models: List[str] = []
2023-01-16 20:52:56 +00:00
2023-02-05 23:55:04 +00:00
def get_config_value(key: str, subkey: str = "default", default=None):
    """Look up ``config_params[key][subkey]``, returning ``default`` when either level is missing."""
    section = config_params.get(key, {})
    return section.get(subkey, default)
2023-01-16 01:33:40 +00:00
def url_from_rule(rule) -> str:
    """Build a display URL for a routing rule, rendering each argument as ``:name``."""
    placeholders = {arg: ":%s" % (arg) for arg in rule.arguments}
    return url_for(rule.endpoint, **placeholders)
2023-01-11 05:00:18 +00:00
def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
    """Collect the shared pipeline parameters from the current request's query string.

    Each value is taken from ``request.args`` when present, otherwise from the
    defaults loaded from ``params.json``. Numeric values are clamped to the
    configured min/max.

    Returns:
        ``(device, params, size)``. ``device`` is ``None`` unless a specific platform
        (not ``any``) was requested and matched one of ``available_platforms``.
    """
    user = request.remote_addr

    # platform stuff
    device = None
    device_name = request.args.get("platform")

    if device_name is not None and device_name != "any":
        # no break: when several platforms share a device name (load_platforms adds one
        # "cuda" entry per CUDA device), the last match is used
        for platform in available_platforms:
            if platform.device == device_name:
                device = platform

    # pipeline stuff
    lpw = get_not_empty(request.args, "lpw", "false") == "true"
    model = get_not_empty(request.args, "model", get_config_value("model"))
    model_path = get_model_path(model)
    scheduler = get_from_map(
        request.args, "scheduler", pipeline_schedulers, get_config_value("scheduler")
    )

    # the inversion is optional: only resolve it to a path when a name was actually
    # given, rather than passing a missing or blank name through get_model_path
    inversion = get_not_empty(request.args, "inversion", get_config_value("inversion"))
    inversion_path = None
    if inversion is not None and str(inversion).strip() != "":
        inversion_path = get_model_path(inversion)

    # image params
    prompt = get_not_empty(request.args, "prompt", get_config_value("prompt"))
    negative_prompt = request.args.get("negativePrompt", None)

    # a blank negative prompt is treated the same as a missing one
    if negative_prompt is not None and negative_prompt.strip() == "":
        negative_prompt = None

    batch = get_and_clamp_int(
        request.args,
        "batch",
        get_config_value("batch"),
        get_config_value("batch", "max"),
        get_config_value("batch", "min"),
    )
    cfg = get_and_clamp_float(
        request.args,
        "cfg",
        get_config_value("cfg"),
        get_config_value("cfg", "max"),
        get_config_value("cfg", "min"),
    )
    eta = get_and_clamp_float(
        request.args,
        "eta",
        get_config_value("eta"),
        get_config_value("eta", "max"),
        get_config_value("eta", "min"),
    )
    steps = get_and_clamp_int(
        request.args,
        "steps",
        get_config_value("steps"),
        get_config_value("steps", "max"),
        get_config_value("steps", "min"),
    )
    height = get_and_clamp_int(
        request.args,
        "height",
        get_config_value("height"),
        get_config_value("height", "max"),
        get_config_value("height", "min"),
    )
    width = get_and_clamp_int(
        request.args,
        "width",
        get_config_value("width"),
        get_config_value("width", "max"),
        get_config_value("width", "min"),
    )

    seed = int(request.args.get("seed", -1))
    if seed == -1:
        # this one can safely use np.random because it produces a single value
        seed = np.random.randint(np.iinfo(np.int32).max)

    logger.info(
        "request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s",
        user,
        steps,
        scheduler.__name__,
        model_path,
        device or "any device",
        width,
        height,
        cfg,
        seed,
        prompt,
    )

    params = ImageParams(
        model_path,
        scheduler,
        prompt,
        cfg,
        steps,
        seed,
        eta=eta,
        lpw=lpw,
        negative_prompt=negative_prompt,
        batch=batch,
        inversion=inversion_path,
    )
    size = Size(width, height)
    return (device, params, size)
def border_from_request() -> Border:
    """Read the outpainting border from the request's query string.

    ``left``/``right`` default to 0 and are clamped to ``[0, width max]``;
    ``top``/``bottom`` default to 0 and are clamped to ``[0, height max]``.
    """
    width_limit = get_config_value("width", "max")
    height_limit = get_config_value("height", "max")

    left = get_and_clamp_int(request.args, "left", 0, width_limit, 0)
    right = get_and_clamp_int(request.args, "right", 0, width_limit, 0)
    top = get_and_clamp_int(request.args, "top", 0, height_limit, 0)
    bottom = get_and_clamp_int(request.args, "bottom", 0, height_limit, 0)

    return Border(left, right, top, bottom)
def upscale_from_request() -> UpscaleParams:
    """Read the upscaling and face-correction options from the request's query string.

    Numeric options are clamped to fixed ranges. The model names are looked up in
    ``upscaling_models`` / ``correction_models`` via ``get_from_list``. The output
    format is always ``onnx``.
    """
    args = request.args

    upscaling = get_from_list(args, "upscaling", upscaling_models)
    correction = get_from_list(args, "correction", correction_models)

    denoise = get_and_clamp_float(args, "denoise", 0.5, 1.0, 0.0)
    scale = get_and_clamp_int(args, "scale", 1, 4, 1)
    outscale = get_and_clamp_int(args, "outscale", 1, 4, 1)

    faces = get_not_empty(args, "faces", "false") == "true"
    face_outscale = get_and_clamp_int(args, "faceOutscale", 1, 4, 1)
    face_strength = get_and_clamp_float(args, "faceStrength", 0.5, 1.0, 0.0)

    order = args.get("upscaleOrder", "correction-first")

    return UpscaleParams(
        upscaling,
        correction_model=correction,
        denoise=denoise,
        faces=faces,
        face_outscale=face_outscale,
        face_strength=face_strength,
        format="onnx",
        outscale=outscale,
        scale=scale,
        upscale_order=order,
    )
def check_paths(context: "ServerContext") -> None:
    """Validate the server's filesystem layout before startup.

    The model directory must already exist. The output directory, including any
    missing parents, is created when needed.

    Raises:
        RuntimeError: if ``context.model_path`` does not exist.
        FileExistsError: if ``context.output_path`` exists but is not a directory.
    """
    if not path.exists(context.model_path):
        raise RuntimeError("model path must exist")

    # exist_ok avoids the race between checking for the directory and creating it
    makedirs(context.output_path, exist_ok=True)
def get_model_name(model: str) -> str:
    """Return the final path component of ``model`` with its last extension removed."""
    stem, _extension = path.splitext(path.basename(model))
    return stem
def load_models(context: ServerContext) -> None:
    """Scan ``context.model_path`` and refresh the module-level model name lists.

    Each list holds the sorted, de-duplicated base names (see ``get_model_name``)
    of the entries matching its glob pattern(s).
    """
    global correction_models
    global diffusion_models
    global inversion_models
    global upscaling_models

    def find_models(*patterns: str) -> List[str]:
        # get_model_name strips the extension, so several files can map to one name
        names = {
            get_model_name(match)
            for pattern in patterns
            for match in glob(path.join(context.model_path, pattern))
        }
        return sorted(names)

    diffusion_models = find_models("diffusion-*", "stable-diffusion-*")
    correction_models = find_models("correction-*")
    inversion_models = find_models("inversion-*")
    upscaling_models = find_models("upscaling-*")
def load_params(context: ServerContext) -> None:
    """Load ``params.json`` from ``context.params_path`` into ``config_params``.

    When the file has a ``platform`` section and the environment provided a default
    platform, that environment value replaces the file's ``platform.default``.
    """
    global config_params

    params_file = path.join(context.params_path, "params.json")
    with open(params_file, "r") as params_stream:
        config_params = yaml.safe_load(params_stream)

    if "platform" not in config_params or context.default_platform is None:
        return

    logger.info(
        "Overriding default platform from environment: %s",
        context.default_platform,
    )
    config_params["platform"]["default"] = context.default_platform
def load_platforms(context: ServerContext) -> None:
    """Populate ``available_platforms`` from the available ONNX Runtime providers.

    A platform is added when its provider is available and its name is not listed
    in ``context.block_platforms``. CUDA gets one entry per device reported by
    ``torch.cuda.device_count()``. When ``context.any_platform`` is set, an extra
    ``any`` pseudo-platform is added. The list is then ordered with ``any`` first,
    ``cpu`` last, and everything else in between in discovery order.
    """
    global available_platforms

    providers = list(get_available_providers())

    for potential in platform_providers:
        if (
            platform_providers[potential] in providers
            and potential not in context.block_platforms
        ):
            if potential == "cuda":
                for i in range(torch.cuda.device_count()):
                    available_platforms.append(
                        DeviceParams(
                            potential,
                            platform_providers[potential],
                            {
                                "device_id": i,
                            },
                            context.optimizations,
                        )
                    )
            else:
                available_platforms.append(
                    DeviceParams(
                        potential,
                        platform_providers[potential],
                        None,
                        context.optimizations,
                    )
                )

    if context.any_platform:
        # the platform should be ignored when the job is scheduled, but set to CPU just in case
        available_platforms.append(
            DeviceParams(
                "any",
                platform_providers["cpu"],
                None,
                context.optimizations,
            )
        )

    # make sure "any" is first and CPU is last on the list, if they are available.
    # A key function gives a consistent total order (unlike a pairwise comparator that
    # only inspects its left argument), and sorted() is stable, so the remaining
    # platforms keep their discovery order.
    def platform_priority(platform: DeviceParams) -> int:
        if platform.device == "any":
            return 0
        if platform.device == "cpu":
            return 2
        return 1

    available_platforms = sorted(available_platforms, key=platform_priority)

    logger.info(
        "available acceleration platforms: %s",
        ", ".join([str(p) for p in available_platforms]),
    )
# server startup: build the context from the environment, validate paths, then load
# models, params and platforms; the executor below is built from the platforms loaded here
context = ServerContext.from_environ()
apply_patches(context)
check_paths(context)
load_models(context)
load_params(context)
load_platforms(context)

if not context.show_progress:
    # silence both the diffusers and the huggingface_hub progress bars
    disable_progress_bar()
    disable_progress_bars()

app = Flask(__name__)
CORS(app, origins=context.cors_origin)

# any is a fake device, should not be in the pool
executor = DevicePoolExecutor([p for p in available_platforms if p.device != "any"])

if is_debug():
    # print garbage collector statistics during collection
    gc.set_debug(gc.DEBUG_STATS)
def ready_reply(ready: bool, progress: int = 0):
    """Build the JSON response reporting a job's readiness and progress."""
    body = dict(progress=progress, ready=ready)
    return jsonify(body)
def error_reply(err: str):
    """Return a 400 Bad Request response whose JSON body is ``{"error": err}``."""
    return make_response(jsonify({"error": err}), 400)
def get_model_path(model: str):
    """Resolve ``model`` against the configured model directory using ``base_join``."""
    models_root = context.model_path
    return base_join(models_root, model)
2023-02-05 13:53:26 +00:00
def serve_bundle_file(filename="index.html"):
    """Send ``filename`` from the client bundle directory, ``../<bundle_path>``."""
    bundle_dir = path.join("..", context.bundle_path)
    return send_from_directory(bundle_dir, filename)
2023-01-05 01:42:37 +00:00
# routes
2023-02-05 13:53:26 +00:00
@app.route("/")
def index():
    """Serve the client bundle's entry page."""
    return serve_bundle_file("index.html")
2023-01-13 04:10:46 +00:00
2023-02-05 13:53:26 +00:00
@app.route("/<path:filename>")
def index_path(filename):
    """Serve any other file from the client bundle by its relative path."""
    return serve_bundle_file(filename=filename)
2023-01-13 04:10:46 +00:00
2023-02-05 13:53:26 +00:00
@app.route("/api")
def introspect():
    """Describe the API: its name plus every registered route and its HTTP methods.

    Route arguments are rendered as ``:name`` by ``url_from_rule``; methods are
    listed in sorted order.
    """
    return {
        "name": "onnx-web",
        "routes": [
            # list.sort() sorts in place and returns None, so use sorted() to get the list
            {"path": url_from_rule(rule), "methods": sorted(rule.methods or [])}
            for rule in app.url_map.iter_rules()
        ],
    }
2023-01-05 00:25:00 +00:00
2023-02-05 13:53:26 +00:00
@app.route("/api/settings/masks")
def list_mask_filters():
    """List the names accepted by the inpaint ``filter`` parameter."""
    return jsonify(list(mask_filters))
2023-02-05 13:53:26 +00:00
@app.route("/api/settings/models")
def list_models():
    """List the model names found by ``load_models``, grouped by kind."""
    return jsonify(
        dict(
            correction=correction_models,
            diffusion=diffusion_models,
            inversion=inversion_models,
            upscaling=upscaling_models,
        )
    )
2023-01-06 04:01:58 +00:00
2023-02-05 13:53:26 +00:00
@app.route("/api/settings/noises")
def list_noise_sources():
    """List the names accepted by the inpaint ``noise`` parameter."""
    return jsonify(list(noise_sources))
2023-02-05 13:53:26 +00:00
@app.route("/api/settings/params")
def list_params():
    """Return the parameter defaults and limits loaded from ``params.json``."""
    loaded = config_params
    return jsonify(loaded)
2023-02-05 13:53:26 +00:00
@app.route("/api/settings/platforms")
def list_platforms():
    """List the device name of every available platform, in ``available_platforms`` order."""
    device_names = [platform.device for platform in available_platforms]
    return jsonify(device_names)
2023-02-05 13:53:26 +00:00
@app.route("/api/settings/schedulers")
def list_schedulers():
    """List the names accepted by the ``scheduler`` parameter."""
    return jsonify(list(pipeline_schedulers))
2023-02-05 13:53:26 +00:00
@app.route("/api/img2img", methods=["POST"])
def img2img():
    """Queue an img2img job for the uploaded ``source`` image.

    ``strength`` is read from the query string and clamped to its configured range.
    The source is validated against the requested size before the job is submitted.
    Replies with a 400 error when no ``source`` file was uploaded.
    """
    if "source" not in request.files:
        return error_reply("source image is required")

    source_bytes = request.files.get("source").read()
    source_image = Image.open(BytesIO(source_bytes)).convert("RGB")

    device, params, size = pipeline_from_request()
    upscale_params = upscale_from_request()
    strength = get_and_clamp_float(
        request.args,
        "strength",
        get_config_value("strength"),
        get_config_value("strength", "max"),
        get_config_value("strength", "min"),
    )

    output = make_output_name(context, "img2img", params, size, extras=(strength,))
    job_name = output[0]
    logger.info("img2img job queued for: %s", job_name)

    checked_source = valid_image(source_image, min_dims=size, max_dims=size)
    executor.submit(
        job_name,
        run_img2img_pipeline,
        context,
        params,
        output,
        upscale_params,
        checked_source,
        strength,
        needs_device=device,
    )

    response_body = json_params(output, params, size, upscale=upscale_params)
    return jsonify(response_body)
2023-01-07 21:05:29 +00:00
2023-02-05 13:53:26 +00:00
@app.route("/api/txt2img", methods=["POST"])
def txt2img():
    """Queue a txt2img job built from the request's query parameters.

    The job is handed to the executor, and the reply carries the output names and
    the parameters used.
    """
    device, params, size = pipeline_from_request()
    upscale_params = upscale_from_request()

    output = make_output_name(context, "txt2img", params, size)
    job_name = output[0]
    logger.info("txt2img job queued for: %s", job_name)

    executor.submit(
        job_name,
        run_txt2img_pipeline,
        context,
        params,
        size,
        output,
        upscale_params,
        needs_device=device,
    )

    response_body = json_params(output, params, size, upscale=upscale_params)
    return jsonify(response_body)
2023-02-05 13:53:26 +00:00
@app.route("/api/inpaint", methods=["POST"])
def inpaint():
    """Queue an inpaint job for the uploaded ``source`` and ``mask`` images.

    The border to expand by, fill color, mask filter, noise source and tile order
    come from the query string. Both images are validated against the requested
    size. Replies with a 400 error when either upload is missing.
    """
    if "source" not in request.files:
        return error_reply("source image is required")

    if "mask" not in request.files:
        return error_reply("mask image is required")

    source_image = Image.open(BytesIO(request.files.get("source").read())).convert("RGB")
    mask_image = Image.open(BytesIO(request.files.get("mask").read())).convert("RGB")

    device, params, size = pipeline_from_request()
    border = border_from_request()
    upscale_params = upscale_from_request()

    fill_color = get_not_empty(request.args, "fillColor", "white")
    mask_filter = get_from_map(request.args, "filter", mask_filters, "none")
    noise_source = get_from_map(request.args, "noise", noise_sources, "histogram")
    tile_order = get_from_list(
        request.args, "tileOrder", [TileOrder.grid, TileOrder.kernel, TileOrder.spiral]
    )

    name_extras = (
        border.left,
        border.right,
        border.top,
        border.bottom,
        mask_filter.__name__,
        noise_source.__name__,
        fill_color,
        tile_order,
    )
    output = make_output_name(context, "inpaint", params, size, extras=name_extras)
    job_name = output[0]
    logger.info("inpaint job queued for: %s", job_name)

    checked_source = valid_image(source_image, min_dims=size, max_dims=size)
    checked_mask = valid_image(mask_image, min_dims=size, max_dims=size)
    executor.submit(
        job_name,
        run_inpaint_pipeline,
        context,
        params,
        size,
        output,
        upscale_params,
        checked_source,
        checked_mask,
        border,
        noise_source,
        mask_filter,
        fill_color,
        tile_order,
        needs_device=device,
    )

    response_body = json_params(
        output, params, size, upscale=upscale_params, border=border
    )
    return jsonify(response_body)
2023-01-09 00:11:34 +00:00
2023-02-05 13:53:26 +00:00
@app.route("/api/upscale", methods=["POST"])
def upscale():
    """Queue an upscale job for the uploaded ``source`` image.

    The source is validated against the requested size before the job is submitted.
    Replies with a 400 error when no ``source`` file was uploaded.
    """
    if "source" not in request.files:
        return error_reply("source image is required")

    source_bytes = request.files.get("source").read()
    source_image = Image.open(BytesIO(source_bytes)).convert("RGB")

    device, params, size = pipeline_from_request()
    upscale_params = upscale_from_request()

    output = make_output_name(context, "upscale", params, size)
    job_name = output[0]
    logger.info("upscale job queued for: %s", job_name)

    checked_source = valid_image(source_image, min_dims=size, max_dims=size)
    executor.submit(
        job_name,
        run_upscale_pipeline,
        context,
        params,
        size,
        output,
        upscale_params,
        checked_source,
        needs_device=device,
    )

    response_body = json_params(output, params, size, upscale=upscale_params)
    return jsonify(response_body)
2023-01-17 05:45:54 +00:00
2023-02-05 13:53:26 +00:00
@app.route("/api/chain", methods=["POST"])
def chain():
    """Queue a multi-stage chain pipeline described by the ``chain`` document.

    The document may arrive as a form field or an uploaded file. It is parsed with
    ``yaml.safe_load`` and validated against ``./schemas/chain.yaml``. Each stage may
    have its own ``source:<stage name>`` and ``mask:<stage name>`` uploads, which are
    limited to the requested output size. Replies with a 400 error when there is no
    ``chain`` body.
    """
    logger.debug(
        "chain pipeline request: %s, %s", request.form.keys(), request.files.keys()
    )
    body = request.form.get("chain") or request.files.get("chain")
    if body is None:
        return error_reply("chain pipeline must have a body")

    data = yaml.safe_load(body)
    with open("./schemas/chain.yaml", "r") as schema_file:
        schema = yaml.safe_load(schema_file.read())

    logger.debug("validating chain request: %s against %s", data, schema)
    validate(data, schema)

    # get defaults from the regular parameters
    device, params, size = pipeline_from_request()
    output = make_output_name(context, "chain", params, size)
    job_name = output[0]

    # per-stage images may not be larger than the requested output
    stage_max_dims = (size.width, size.height)

    pipeline = ChainPipeline()
    for stage_data in data.get("stages", []):
        callback = chain_stages[stage_data.get("type")]
        kwargs = stage_data.get("params", {})
        logger.info("request stage: %s, %s", callback.__name__, kwargs)

        stage = StageParams(
            stage_data.get("name", callback.__name__),
            tile_size=get_size(kwargs.get("tile_size")),
            outscale=get_and_clamp_int(kwargs, "outscale", 1, 4),
        )

        # replace raw stage options with the parameter objects the stages expect
        if "border" in kwargs:
            kwargs["border"] = Border.even(int(kwargs.get("border")))

        if "upscale" in kwargs:
            kwargs["upscale"] = UpscaleParams(kwargs.get("upscale"))

        source_field = "source:%s" % (stage.name)
        mask_field = "mask:%s" % (stage.name)

        if source_field in request.files:
            logger.debug(
                "loading source image %s for pipeline stage %s",
                source_field,
                stage.name,
            )
            source_upload = request.files.get(source_field)
            stage_source = Image.open(BytesIO(source_upload.read())).convert("RGB")
            kwargs["stage_source"] = valid_image(stage_source, max_dims=stage_max_dims)

        if mask_field in request.files:
            logger.debug(
                "loading mask image %s for pipeline stage %s",
                mask_field,
                stage.name,
            )
            mask_upload = request.files.get(mask_field)
            stage_mask = Image.open(BytesIO(mask_upload.read())).convert("RGB")
            kwargs["stage_mask"] = valid_image(stage_mask, max_dims=stage_max_dims)

        pipeline.append((callback, stage, kwargs))

    logger.info("running chain pipeline with %s stages", len(pipeline.stages))

    # build and run chain pipeline
    empty_source = Image.new("RGB", (size.width, size.height))
    executor.submit(
        job_name,
        pipeline,
        context,
        params,
        empty_source,
        output=job_name,
        size=size,
        needs_device=device,
    )

    return jsonify(json_params(output, params, size))
@app.route("/api/blend", methods=["POST"])
def blend():
    """Queue a job that blends the uploaded source images using a mask.

    Expects uploads named ``mask`` and ``source:0`` .. ``source:1``. All images are
    converted to RGBA, and each source is validated against the mask's size.
    Replies with a 400 error when any of the uploads is missing.
    """
    if "mask" not in request.files:
        return error_reply("mask image is required")

    mask_file = request.files.get("mask")
    mask = Image.open(BytesIO(mask_file.read())).convert("RGBA")
    mask = valid_image(mask)

    max_sources = 2
    sources = []

    for i in range(max_sources):
        source_name = "source:%s" % (i)
        # report a missing upload as a client error instead of failing on None.read()
        if source_name not in request.files:
            return error_reply("%s image is required" % (source_name))

        source_file = request.files.get(source_name)
        source = Image.open(BytesIO(source_file.read())).convert("RGBA")
        source = valid_image(source, mask.size, mask.size)
        sources.append(source)

    device, params, size = pipeline_from_request()
    upscale = upscale_from_request()

    # label the output and log with this route's own mode, not "upscale"
    output = make_output_name(context, "blend", params, size)
    job_name = output[0]
    logger.info("blend job queued for: %s", job_name)

    executor.submit(
        job_name,
        run_blend_pipeline,
        context,
        params,
        size,
        output,
        upscale,
        sources,
        mask,
        needs_device=device,
    )

    return jsonify(json_params(output, params, size, upscale=upscale))
2023-02-14 13:40:06 +00:00
@app.route("/api/txt2txt", methods=["POST"])
def txt2txt():
    """Queue a txt2txt job built from the request's query parameters.

    Like the other routes, the job is registered under the first output name, so
    that ``/api/ready`` and ``/api/cancel`` can find it by that name.
    """
    device, params, size = pipeline_from_request()

    output = make_output_name(context, "txt2txt", params, size)
    # submit under a single name (the first output), not the whole list of outputs
    job_name = output[0]
    logger.info("txt2txt job queued for: %s", job_name)

    executor.submit(
        job_name,
        run_txt2txt_pipeline,
        context,
        params,
        size,
        output,
        needs_device=device,
    )

    return jsonify(json_params(output, params, size))
2023-02-05 13:53:26 +00:00
@app.route("/api/cancel", methods=["PUT"])
def cancel():
    """Ask the executor to cancel the job named by the ``output`` query parameter.

    The executor's result is returned as the ``ready`` field of the reply.
    """
    job_name = request.args.get("output", None)
    cancelled = executor.cancel(job_name)
    return ready_reply(cancelled)
2023-02-05 13:53:26 +00:00
@app.route("/api/ready")
def ready():
    """Report whether the job named by the ``output`` query parameter has finished.

    When the executor does not know the job (``done`` is None) but the output file
    already exists on disk, the job is reported as ready. Replies with a 400 error
    when no ``output`` name was given.
    """
    output_file = request.args.get("output", None)
    if output_file is None:
        # without a name there is nothing to look up, and None cannot be joined onto
        # the output path below
        return error_reply("output name is required")

    done, progress = executor.done(output_file)

    if done is None:
        output = base_join(context.output_path, output_file)
        if path.exists(output):
            return ready_reply(True)

    return ready_reply(done, progress=progress)
2023-02-05 13:53:26 +00:00
@app.route("/api/status")
def status():
    """Return the executor's status report as JSON."""
    report = executor.status()
    return jsonify(report)
2023-02-05 13:53:26 +00:00
@app.route("/output/<path:filename>")
def output(filename: str):
    """Serve a generated file from ``../<output_path>`` inline, not as a download."""
    output_dir = path.join("..", context.output_path)
    return send_from_directory(output_dir, filename, as_attachment=False)