1
0
Fork 0

feat(api): remove Flask app from global scope

This commit is contained in:
Sean Sube 2023-02-26 10:15:12 -06:00
parent 943281feb5
commit 06c74a7a96
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
15 changed files with 1044 additions and 895 deletions

1
api/.gitignore vendored
View File

@ -1,6 +1,7 @@
.coverage .coverage
coverage.xml coverage.xml
*.log
*.swp *.swp
*.pyc *.pyc

View File

@ -25,4 +25,4 @@ python3 -m onnx_web.convert \
--token=${HF_TOKEN:-} --token=${HF_TOKEN:-}
echo "Launching API server..." echo "Launching API server..."
flask --app=onnx_web.serve run --host=0.0.0.0 flask --app='onnx_web.main:main()' run --host=0.0.0.0

View File

@ -24,4 +24,4 @@ python3 -m onnx_web.convert \
--token=${HF_TOKEN:-} --token=${HF_TOKEN:-}
echo "Launching API server..." echo "Launching API server..."
flask --app=onnx_web.serve run --host=0.0.0.0 flask --app='onnx_web.main:main()' run --host=0.0.0.0

View File

@ -1,5 +1,10 @@
from . import logging from . import logging
from .chain import correct_gfpgan, upscale_resrgan, upscale_stable_diffusion from .chain import (
correct_codeformer,
correct_gfpgan,
upscale_resrgan,
upscale_stable_diffusion,
)
from .diffusion.load import get_latents_from_seed, load_pipeline, optimize_pipeline from .diffusion.load import get_latents_from_seed, load_pipeline, optimize_pipeline
from .diffusion.run import ( from .diffusion.run import (
run_blend_pipeline, run_blend_pipeline,

View File

@ -13,3 +13,20 @@ from .source_txt2img import source_txt2img
from .upscale_outpaint import upscale_outpaint from .upscale_outpaint import upscale_outpaint
from .upscale_resrgan import upscale_resrgan from .upscale_resrgan import upscale_resrgan
from .upscale_stable_diffusion import upscale_stable_diffusion from .upscale_stable_diffusion import upscale_stable_diffusion
CHAIN_STAGES = {
"blend-img2img": blend_img2img,
"blend-inpaint": blend_inpaint,
"blend-mask": blend_mask,
"correct-codeformer": correct_codeformer,
"correct-gfpgan": correct_gfpgan,
"persist-disk": persist_disk,
"persist-s3": persist_s3,
"reduce-crop": reduce_crop,
"reduce-thumbnail": reduce_thumbnail,
"source-noise": source_noise,
"source-txt2img": source_txt2img,
"upscale-outpaint": upscale_outpaint,
"upscale-resrgan": upscale_resrgan,
"upscale-stable-diffusion": upscale_stable_diffusion,
}

View File

@ -4,9 +4,11 @@ from typing import Any, Optional, Tuple
import numpy as np import numpy as np
from diffusers import ( from diffusers import (
DiffusionPipeline,
OnnxRuntimeModel,
StableDiffusionPipeline,
DDIMScheduler, DDIMScheduler,
DDPMScheduler, DDPMScheduler,
DiffusionPipeline,
DPMSolverMultistepScheduler, DPMSolverMultistepScheduler,
DPMSolverSinglestepScheduler, DPMSolverSinglestepScheduler,
EulerAncestralDiscreteScheduler, EulerAncestralDiscreteScheduler,
@ -17,15 +19,13 @@ from diffusers import (
KDPM2AncestralDiscreteScheduler, KDPM2AncestralDiscreteScheduler,
KDPM2DiscreteScheduler, KDPM2DiscreteScheduler,
LMSDiscreteScheduler, LMSDiscreteScheduler,
OnnxRuntimeModel,
PNDMScheduler, PNDMScheduler,
StableDiffusionPipeline,
) )
try: try:
from diffusers import DEISMultistepScheduler from diffusers import DEISMultistepScheduler
except ImportError: except ImportError:
from .stub_scheduler import StubScheduler as DEISMultistepScheduler from ..diffusion.stub_scheduler import StubScheduler as DEISMultistepScheduler
from ..params import DeviceParams, Size from ..params import DeviceParams, Size
from ..server import ServerContext from ..server import ServerContext
@ -54,6 +54,10 @@ pipeline_schedulers = {
} }
def get_pipeline_schedulers():
return pipeline_schedulers
def get_scheduler_name(scheduler: Any) -> Optional[str]: def get_scheduler_name(scheduler: Any) -> Optional[str]:
for k, v in pipeline_schedulers.items(): for k, v in pipeline_schedulers.items():
if scheduler == v or scheduler == v.__name__: if scheduler == v or scheduler == v.__name__:
@ -137,13 +141,14 @@ def load_pipeline(
server: ServerContext, server: ServerContext,
pipeline: DiffusionPipeline, pipeline: DiffusionPipeline,
model: str, model: str,
scheduler_type: Any, scheduler_name: str,
device: DeviceParams, device: DeviceParams,
lpw: bool, lpw: bool,
inversion: Optional[str], inversion: Optional[str],
): ):
pipe_key = (pipeline, model, device.device, device.provider, lpw, inversion) pipe_key = (pipeline, model, device.device, device.provider, lpw, inversion)
scheduler_key = (scheduler_type, model) scheduler_key = (scheduler_name, model)
scheduler_type = get_pipeline_schedulers()[scheduler_name]
cache_pipe = server.cache.get("diffusion", pipe_key) cache_pipe = server.cache.get("diffusion", pipe_key)

53
api/onnx_web/main.py Normal file
View File

@ -0,0 +1,53 @@
import gc
from diffusers.utils.logging import disable_progress_bar
from flask import Flask
from flask_cors import CORS
from huggingface_hub.utils.tqdm import disable_progress_bars
from .server.api import register_api_routes
from .server.static import register_static_routes
from .server.config import get_available_platforms, load_models, load_params, load_platforms
from .server.utils import check_paths
from .server.context import ServerContext
from .server.hacks import apply_patches
from .utils import (
is_debug,
)
from .worker import DevicePoolExecutor
def main():
context = ServerContext.from_environ()
apply_patches(context)
check_paths(context)
load_models(context)
load_params(context)
load_platforms(context)
if is_debug():
gc.set_debug(gc.DEBUG_STATS)
if not context.show_progress:
disable_progress_bar()
disable_progress_bars()
app = Flask(__name__)
CORS(app, origins=context.cors_origin)
# any is a fake device, should not be in the pool
pool = DevicePoolExecutor([p for p in get_available_platforms() if p.device != "any"])
# register routes
register_static_routes(app, context, pool)
register_api_routes(app, context, pool)
return app #, context, pool
if __name__ == "__main__":
# app, context, pool = main()
app = main()
app.run("0.0.0.0", 5000, debug=is_debug())
# pool.join()

View File

@ -8,7 +8,6 @@ from typing import Any, List, Optional, Tuple
from PIL import Image from PIL import Image
from .diffusion.load import get_scheduler_name
from .params import Border, ImageParams, Param, Size, UpscaleParams from .params import Border, ImageParams, Param, Size, UpscaleParams
from .server import ServerContext from .server import ServerContext
from .utils import base_join from .utils import base_join
@ -44,7 +43,7 @@ def json_params(
} }
json["params"]["model"] = path.basename(params.model) json["params"]["model"] = path.basename(params.model)
json["params"]["scheduler"] = get_scheduler_name(params.scheduler) json["params"]["scheduler"] = params.scheduler
if border is not None: if border is not None:
json["border"] = border.tojson() json["border"] = border.tojson()
@ -71,7 +70,7 @@ def make_output_name(
hash_value(sha, mode) hash_value(sha, mode)
hash_value(sha, params.model) hash_value(sha, params.model)
hash_value(sha, params.scheduler.__name__) hash_value(sha, params.scheduler)
hash_value(sha, params.prompt) hash_value(sha, params.prompt)
hash_value(sha, params.negative_prompt) hash_value(sha, params.negative_prompt)
hash_value(sha, params.cfg) hash_value(sha, params.cfg)

View File

@ -148,7 +148,7 @@ class ImageParams:
def __init__( def __init__(
self, self,
model: str, model: str,
scheduler: Any, scheduler: str,
prompt: str, prompt: str,
cfg: float, cfg: float,
steps: int, steps: int,
@ -174,7 +174,7 @@ class ImageParams:
def tojson(self) -> Dict[str, Optional[Param]]: def tojson(self) -> Dict[str, Optional[Param]]:
return { return {
"model": self.model, "model": self.model,
"scheduler": self.scheduler.__name__, "scheduler": self.scheduler,
"prompt": self.prompt, "prompt": self.prompt,
"negative_prompt": self.negative_prompt, "negative_prompt": self.negative_prompt,
"cfg": self.cfg, "cfg": self.cfg,

View File

@ -1,881 +0,0 @@
import gc
from functools import cmp_to_key
from glob import glob
from io import BytesIO
from logging import getLogger
from os import makedirs, path
from typing import Dict, List, Tuple, Union
import numpy as np
import torch
import yaml
from diffusers.utils.logging import disable_progress_bar
from flask import Flask, jsonify, make_response, request, send_from_directory, url_for
from flask_cors import CORS
from huggingface_hub.utils.tqdm import disable_progress_bars
from jsonschema import validate
from onnxruntime import get_available_providers
from PIL import Image
from .chain import (
ChainPipeline,
blend_img2img,
blend_inpaint,
correct_codeformer,
correct_gfpgan,
persist_disk,
persist_s3,
reduce_crop,
reduce_thumbnail,
source_noise,
source_txt2img,
upscale_outpaint,
upscale_resrgan,
upscale_stable_diffusion,
)
from .diffusion.load import pipeline_schedulers
from .diffusion.run import (
run_blend_pipeline,
run_img2img_pipeline,
run_inpaint_pipeline,
run_txt2img_pipeline,
run_upscale_pipeline,
)
from .image import ( # mask filters; noise sources
mask_filter_gaussian_multiply,
mask_filter_gaussian_screen,
mask_filter_none,
noise_source_fill_edge,
noise_source_fill_mask,
noise_source_gaussian,
noise_source_histogram,
noise_source_normal,
noise_source_uniform,
valid_image,
)
from .output import json_params, make_output_name
from .params import (
Border,
DeviceParams,
ImageParams,
Size,
StageParams,
TileOrder,
UpscaleParams,
)
from .server import ServerContext, apply_patches
from .transformers import run_txt2txt_pipeline
from .utils import (
base_join,
get_and_clamp_float,
get_and_clamp_int,
get_from_list,
get_from_map,
get_not_empty,
get_size,
is_debug,
)
from .worker import DevicePoolExecutor
logger = getLogger(__name__)
# config caching
config_params: Dict[str, Dict[str, Union[float, int, str]]] = {}
# pipeline params
platform_providers = {
"cpu": "CPUExecutionProvider",
"cuda": "CUDAExecutionProvider",
"directml": "DmlExecutionProvider",
"rocm": "ROCMExecutionProvider",
}
noise_sources = {
"fill-edge": noise_source_fill_edge,
"fill-mask": noise_source_fill_mask,
"gaussian": noise_source_gaussian,
"histogram": noise_source_histogram,
"normal": noise_source_normal,
"uniform": noise_source_uniform,
}
mask_filters = {
"none": mask_filter_none,
"gaussian-multiply": mask_filter_gaussian_multiply,
"gaussian-screen": mask_filter_gaussian_screen,
}
chain_stages = {
"blend-img2img": blend_img2img,
"blend-inpaint": blend_inpaint,
"correct-codeformer": correct_codeformer,
"correct-gfpgan": correct_gfpgan,
"persist-disk": persist_disk,
"persist-s3": persist_s3,
"reduce-crop": reduce_crop,
"reduce-thumbnail": reduce_thumbnail,
"source-noise": source_noise,
"source-txt2img": source_txt2img,
"upscale-outpaint": upscale_outpaint,
"upscale-resrgan": upscale_resrgan,
"upscale-stable-diffusion": upscale_stable_diffusion,
}
# Available ORT providers
available_platforms: List[DeviceParams] = []
# loaded from model_path
correction_models: List[str] = []
diffusion_models: List[str] = []
inversion_models: List[str] = []
upscaling_models: List[str] = []
def get_config_value(key: str, subkey: str = "default", default=None):
return config_params.get(key, {}).get(subkey, default)
def url_from_rule(rule) -> str:
options = {}
for arg in rule.arguments:
options[arg] = ":%s" % (arg)
return url_for(rule.endpoint, **options)
def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
user = request.remote_addr
# platform stuff
device = None
device_name = request.args.get("platform")
if device_name is not None and device_name != "any":
for platform in available_platforms:
if platform.device == device_name:
device = platform
# pipeline stuff
lpw = get_not_empty(request.args, "lpw", "false") == "true"
model = get_not_empty(request.args, "model", get_config_value("model"))
model_path = get_model_path(model)
scheduler = get_from_map(
request.args, "scheduler", pipeline_schedulers, get_config_value("scheduler")
)
inversion = request.args.get("inversion", None)
inversion_path = None
if inversion is not None and inversion.strip() != "":
inversion_path = get_model_path(inversion)
# image params
prompt = get_not_empty(request.args, "prompt", get_config_value("prompt"))
negative_prompt = request.args.get("negativePrompt", None)
if negative_prompt is not None and negative_prompt.strip() == "":
negative_prompt = None
batch = get_and_clamp_int(
request.args,
"batch",
get_config_value("batch"),
get_config_value("batch", "max"),
get_config_value("batch", "min"),
)
cfg = get_and_clamp_float(
request.args,
"cfg",
get_config_value("cfg"),
get_config_value("cfg", "max"),
get_config_value("cfg", "min"),
)
eta = get_and_clamp_float(
request.args,
"eta",
get_config_value("eta"),
get_config_value("eta", "max"),
get_config_value("eta", "min"),
)
steps = get_and_clamp_int(
request.args,
"steps",
get_config_value("steps"),
get_config_value("steps", "max"),
get_config_value("steps", "min"),
)
height = get_and_clamp_int(
request.args,
"height",
get_config_value("height"),
get_config_value("height", "max"),
get_config_value("height", "min"),
)
width = get_and_clamp_int(
request.args,
"width",
get_config_value("width"),
get_config_value("width", "max"),
get_config_value("width", "min"),
)
seed = int(request.args.get("seed", -1))
if seed == -1:
# this one can safely use np.random because it produces a single value
seed = np.random.randint(np.iinfo(np.int32).max)
logger.info(
"request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s",
user,
steps,
scheduler.__name__,
model_path,
device or "any device",
width,
height,
cfg,
seed,
prompt,
)
params = ImageParams(
model_path,
scheduler,
prompt,
cfg,
steps,
seed,
eta=eta,
lpw=lpw,
negative_prompt=negative_prompt,
batch=batch,
inversion=inversion_path,
)
size = Size(width, height)
return (device, params, size)
def border_from_request() -> Border:
left = get_and_clamp_int(
request.args, "left", 0, get_config_value("width", "max"), 0
)
right = get_and_clamp_int(
request.args, "right", 0, get_config_value("width", "max"), 0
)
top = get_and_clamp_int(
request.args, "top", 0, get_config_value("height", "max"), 0
)
bottom = get_and_clamp_int(
request.args, "bottom", 0, get_config_value("height", "max"), 0
)
return Border(left, right, top, bottom)
def upscale_from_request() -> UpscaleParams:
denoise = get_and_clamp_float(request.args, "denoise", 0.5, 1.0, 0.0)
scale = get_and_clamp_int(request.args, "scale", 1, 4, 1)
outscale = get_and_clamp_int(request.args, "outscale", 1, 4, 1)
upscaling = get_from_list(request.args, "upscaling", upscaling_models)
correction = get_from_list(request.args, "correction", correction_models)
faces = get_not_empty(request.args, "faces", "false") == "true"
face_outscale = get_and_clamp_int(request.args, "faceOutscale", 1, 4, 1)
face_strength = get_and_clamp_float(request.args, "faceStrength", 0.5, 1.0, 0.0)
upscale_order = request.args.get("upscaleOrder", "correction-first")
return UpscaleParams(
upscaling,
correction_model=correction,
denoise=denoise,
faces=faces,
face_outscale=face_outscale,
face_strength=face_strength,
format="onnx",
outscale=outscale,
scale=scale,
upscale_order=upscale_order,
)
def check_paths(context: ServerContext) -> None:
if not path.exists(context.model_path):
raise RuntimeError("model path must exist")
if not path.exists(context.output_path):
makedirs(context.output_path)
def get_model_name(model: str) -> str:
base = path.basename(model)
(file, _ext) = path.splitext(base)
return file
def load_models(context: ServerContext) -> None:
global correction_models
global diffusion_models
global inversion_models
global upscaling_models
diffusion_models = [
get_model_name(f) for f in glob(path.join(context.model_path, "diffusion-*"))
]
diffusion_models.extend(
[
get_model_name(f)
for f in glob(path.join(context.model_path, "stable-diffusion-*"))
]
)
diffusion_models = list(set(diffusion_models))
diffusion_models.sort()
correction_models = [
get_model_name(f) for f in glob(path.join(context.model_path, "correction-*"))
]
correction_models = list(set(correction_models))
correction_models.sort()
inversion_models = [
get_model_name(f) for f in glob(path.join(context.model_path, "inversion-*"))
]
inversion_models = list(set(inversion_models))
inversion_models.sort()
upscaling_models = [
get_model_name(f) for f in glob(path.join(context.model_path, "upscaling-*"))
]
upscaling_models = list(set(upscaling_models))
upscaling_models.sort()
def load_params(context: ServerContext) -> None:
global config_params
params_file = path.join(context.params_path, "params.json")
with open(params_file, "r") as f:
config_params = yaml.safe_load(f)
if "platform" in config_params and context.default_platform is not None:
logger.info(
"Overriding default platform from environment: %s",
context.default_platform,
)
config_platform = config_params.get("platform", {})
config_platform["default"] = context.default_platform
def load_platforms(context: ServerContext) -> None:
global available_platforms
providers = list(get_available_providers())
for potential in platform_providers:
if (
platform_providers[potential] in providers
and potential not in context.block_platforms
):
if potential == "cuda":
for i in range(torch.cuda.device_count()):
available_platforms.append(
DeviceParams(
potential,
platform_providers[potential],
{
"device_id": i,
},
context.optimizations,
)
)
else:
available_platforms.append(
DeviceParams(
potential,
platform_providers[potential],
None,
context.optimizations,
)
)
if context.any_platform:
# the platform should be ignored when the job is scheduled, but set to CPU just in case
available_platforms.append(
DeviceParams(
"any",
platform_providers["cpu"],
None,
context.optimizations,
)
)
# make sure CPU is last on the list
def any_first_cpu_last(a: DeviceParams, b: DeviceParams):
if a.device == b.device:
return 0
# any should be first, if it's available
if a.device == "any":
return -1
# cpu should be last, if it's available
if a.device == "cpu":
return 1
return -1
available_platforms = sorted(
available_platforms, key=cmp_to_key(any_first_cpu_last)
)
logger.info(
"available acceleration platforms: %s",
", ".join([str(p) for p in available_platforms]),
)
context = ServerContext.from_environ()
apply_patches(context)
check_paths(context)
load_models(context)
load_params(context)
load_platforms(context)
if not context.show_progress:
disable_progress_bar()
disable_progress_bars()
app = Flask(__name__)
CORS(app, origins=context.cors_origin)
# any is a fake device, should not be in the pool
executor = DevicePoolExecutor([p for p in available_platforms if p.device != "any"])
if is_debug():
gc.set_debug(gc.DEBUG_STATS)
def ready_reply(ready: bool, progress: int = 0):
return jsonify(
{
"progress": progress,
"ready": ready,
}
)
def error_reply(err: str):
response = make_response(
jsonify(
{
"error": err,
}
)
)
response.status_code = 400
return response
def get_model_path(model: str):
return base_join(context.model_path, model)
def serve_bundle_file(filename="index.html"):
return send_from_directory(path.join("..", context.bundle_path), filename)
# routes
@app.route("/")
def index():
return serve_bundle_file()
@app.route("/<path:filename>")
def index_path(filename):
return serve_bundle_file(filename)
@app.route("/api")
def introspect():
return {
"name": "onnx-web",
"routes": [
{"path": url_from_rule(rule), "methods": list(rule.methods).sort()}
for rule in app.url_map.iter_rules()
],
}
@app.route("/api/settings/masks")
def list_mask_filters():
return jsonify(list(mask_filters.keys()))
@app.route("/api/settings/models")
def list_models():
return jsonify(
{
"correction": correction_models,
"diffusion": diffusion_models,
"inversion": inversion_models,
"upscaling": upscaling_models,
}
)
@app.route("/api/settings/noises")
def list_noise_sources():
return jsonify(list(noise_sources.keys()))
@app.route("/api/settings/params")
def list_params():
return jsonify(config_params)
@app.route("/api/settings/platforms")
def list_platforms():
return jsonify([p.device for p in available_platforms])
@app.route("/api/settings/schedulers")
def list_schedulers():
return jsonify(list(pipeline_schedulers.keys()))
@app.route("/api/img2img", methods=["POST"])
def img2img():
if "source" not in request.files:
return error_reply("source image is required")
source_file = request.files.get("source")
source = Image.open(BytesIO(source_file.read())).convert("RGB")
device, params, size = pipeline_from_request()
upscale = upscale_from_request()
strength = get_and_clamp_float(
request.args,
"strength",
get_config_value("strength"),
get_config_value("strength", "max"),
get_config_value("strength", "min"),
)
output = make_output_name(context, "img2img", params, size, extras=(strength,))
job_name = output[0]
logger.info("img2img job queued for: %s", job_name)
source = valid_image(source, min_dims=size, max_dims=size)
executor.submit(
job_name,
run_img2img_pipeline,
context,
params,
output,
upscale,
source,
strength,
needs_device=device,
)
return jsonify(json_params(output, params, size, upscale=upscale))
@app.route("/api/txt2img", methods=["POST"])
def txt2img():
device, params, size = pipeline_from_request()
upscale = upscale_from_request()
output = make_output_name(context, "txt2img", params, size)
job_name = output[0]
logger.info("txt2img job queued for: %s", job_name)
executor.submit(
job_name,
run_txt2img_pipeline,
context,
params,
size,
output,
upscale,
needs_device=device,
)
return jsonify(json_params(output, params, size, upscale=upscale))
@app.route("/api/inpaint", methods=["POST"])
def inpaint():
if "source" not in request.files:
return error_reply("source image is required")
if "mask" not in request.files:
return error_reply("mask image is required")
source_file = request.files.get("source")
source = Image.open(BytesIO(source_file.read())).convert("RGB")
mask_file = request.files.get("mask")
mask = Image.open(BytesIO(mask_file.read())).convert("RGB")
device, params, size = pipeline_from_request()
expand = border_from_request()
upscale = upscale_from_request()
fill_color = get_not_empty(request.args, "fillColor", "white")
mask_filter = get_from_map(request.args, "filter", mask_filters, "none")
noise_source = get_from_map(request.args, "noise", noise_sources, "histogram")
tile_order = get_from_list(
request.args, "tileOrder", [TileOrder.grid, TileOrder.kernel, TileOrder.spiral]
)
output = make_output_name(
context,
"inpaint",
params,
size,
extras=(
expand.left,
expand.right,
expand.top,
expand.bottom,
mask_filter.__name__,
noise_source.__name__,
fill_color,
tile_order,
),
)
job_name = output[0]
logger.info("inpaint job queued for: %s", job_name)
source = valid_image(source, min_dims=size, max_dims=size)
mask = valid_image(mask, min_dims=size, max_dims=size)
executor.submit(
job_name,
run_inpaint_pipeline,
context,
params,
size,
output,
upscale,
source,
mask,
expand,
noise_source,
mask_filter,
fill_color,
tile_order,
needs_device=device,
)
return jsonify(json_params(output, params, size, upscale=upscale, border=expand))
@app.route("/api/upscale", methods=["POST"])
def upscale():
if "source" not in request.files:
return error_reply("source image is required")
source_file = request.files.get("source")
source = Image.open(BytesIO(source_file.read())).convert("RGB")
device, params, size = pipeline_from_request()
upscale = upscale_from_request()
output = make_output_name(context, "upscale", params, size)
job_name = output[0]
logger.info("upscale job queued for: %s", job_name)
source = valid_image(source, min_dims=size, max_dims=size)
executor.submit(
job_name,
run_upscale_pipeline,
context,
params,
size,
output,
upscale,
source,
needs_device=device,
)
return jsonify(json_params(output, params, size, upscale=upscale))
@app.route("/api/chain", methods=["POST"])
def chain():
logger.debug(
"chain pipeline request: %s, %s", request.form.keys(), request.files.keys()
)
body = request.form.get("chain") or request.files.get("chain")
if body is None:
return error_reply("chain pipeline must have a body")
data = yaml.safe_load(body)
with open("./schemas/chain.yaml", "r") as f:
schema = yaml.safe_load(f.read())
logger.debug("validating chain request: %s against %s", data, schema)
validate(data, schema)
# get defaults from the regular parameters
device, params, size = pipeline_from_request()
output = make_output_name(context, "chain", params, size)
job_name = output[0]
pipeline = ChainPipeline()
for stage_data in data.get("stages", []):
callback = chain_stages[stage_data.get("type")]
kwargs = stage_data.get("params", {})
logger.info("request stage: %s, %s", callback.__name__, kwargs)
stage = StageParams(
stage_data.get("name", callback.__name__),
tile_size=get_size(kwargs.get("tile_size")),
outscale=get_and_clamp_int(kwargs, "outscale", 1, 4),
)
if "border" in kwargs:
border = Border.even(int(kwargs.get("border")))
kwargs["border"] = border
if "upscale" in kwargs:
upscale = UpscaleParams(kwargs.get("upscale"))
kwargs["upscale"] = upscale
stage_source_name = "source:%s" % (stage.name)
stage_mask_name = "mask:%s" % (stage.name)
if stage_source_name in request.files:
logger.debug(
"loading source image %s for pipeline stage %s",
stage_source_name,
stage.name,
)
source_file = request.files.get(stage_source_name)
source = Image.open(BytesIO(source_file.read())).convert("RGB")
source = valid_image(source, max_dims=(size.width, size.height))
kwargs["stage_source"] = source
if stage_mask_name in request.files:
logger.debug(
"loading mask image %s for pipeline stage %s",
stage_mask_name,
stage.name,
)
mask_file = request.files.get(stage_mask_name)
mask = Image.open(BytesIO(mask_file.read())).convert("RGB")
mask = valid_image(mask, max_dims=(size.width, size.height))
kwargs["stage_mask"] = mask
pipeline.append((callback, stage, kwargs))
logger.info("running chain pipeline with %s stages", len(pipeline.stages))
# build and run chain pipeline
empty_source = Image.new("RGB", (size.width, size.height))
executor.submit(
job_name,
pipeline,
context,
params,
empty_source,
output=output[0],
size=size,
needs_device=device,
)
return jsonify(json_params(output, params, size))
@app.route("/api/blend", methods=["POST"])
def blend():
if "mask" not in request.files:
return error_reply("mask image is required")
mask_file = request.files.get("mask")
mask = Image.open(BytesIO(mask_file.read())).convert("RGBA")
mask = valid_image(mask)
max_sources = 2
sources = []
for i in range(max_sources):
source_file = request.files.get("source:%s" % (i))
source = Image.open(BytesIO(source_file.read())).convert("RGBA")
source = valid_image(source, mask.size, mask.size)
sources.append(source)
device, params, size = pipeline_from_request()
upscale = upscale_from_request()
output = make_output_name(context, "upscale", params, size)
job_name = output[0]
logger.info("upscale job queued for: %s", job_name)
executor.submit(
job_name,
run_blend_pipeline,
context,
params,
size,
output,
upscale,
sources,
mask,
needs_device=device,
)
return jsonify(json_params(output, params, size, upscale=upscale))
@app.route("/api/txt2txt", methods=["POST"])
def txt2txt():
device, params, size = pipeline_from_request()
output = make_output_name(context, "upscale", params, size)
logger.info("upscale job queued for: %s", output)
executor.submit(
output,
run_txt2txt_pipeline,
context,
params,
size,
output,
needs_device=device,
)
return jsonify(json_params(output, params, size))
@app.route("/api/cancel", methods=["PUT"])
def cancel():
output_file = request.args.get("output", None)
cancel = executor.cancel(output_file)
return ready_reply(cancel)
@app.route("/api/ready")
def ready():
output_file = request.args.get("output", None)
done, progress = executor.done(output_file)
if done is None:
output = base_join(context.output_path, output_file)
if path.exists(output):
return ready_reply(True)
return ready_reply(done, progress=progress)
@app.route("/api/status")
def status():
return jsonify(executor.status())
@app.route("/output/<path:filename>")
def output(filename: str):
return send_from_directory(
path.join("..", context.output_path), filename, as_attachment=False
)

477
api/onnx_web/server/api.py Normal file
View File

@ -0,0 +1,477 @@
from io import BytesIO
from logging import getLogger
from os import path
import yaml
from flask import Flask, jsonify, make_response, request, url_for
from jsonschema import validate
from PIL import Image
from .context import ServerContext
from .utils import wrap_route
from ..worker.pool import DevicePoolExecutor
from .config import (
get_available_platforms,
get_config_params,
get_config_value,
get_correction_models,
get_diffusion_models,
get_inversion_models,
get_mask_filters,
get_noise_sources,
get_upscaling_models,
)
from .params import border_from_request, pipeline_from_request, upscale_from_request
from ..chain import (
CHAIN_STAGES,
ChainPipeline,
)
from ..diffusion.load import get_pipeline_schedulers
from ..diffusion.run import (
run_blend_pipeline,
run_img2img_pipeline,
run_inpaint_pipeline,
run_txt2img_pipeline,
run_upscale_pipeline,
)
from ..image import ( # mask filters; noise sources
valid_image,
)
from ..output import json_params, make_output_name
from ..params import (
Border,
StageParams,
TileOrder,
UpscaleParams,
)
from ..transformers import run_txt2txt_pipeline
from ..utils import (
base_join,
get_and_clamp_float,
get_and_clamp_int,
get_from_list,
get_from_map,
get_not_empty,
get_size,
)
logger = getLogger(__name__)
def ready_reply(ready: bool, progress: int = 0):
return jsonify(
{
"progress": progress,
"ready": ready,
}
)
def error_reply(err: str):
response = make_response(
jsonify(
{
"error": err,
}
)
)
response.status_code = 400
return response
def url_from_rule(rule) -> str:
options = {}
for arg in rule.arguments:
options[arg] = ":%s" % (arg)
return url_for(rule.endpoint, **options)
def introspect(context: ServerContext, app: Flask):
return {
"name": "onnx-web",
"routes": [
{"path": url_from_rule(rule), "methods": list(rule.methods).sort()}
for rule in app.url_map.iter_rules()
],
}
def list_mask_filters(context: ServerContext):
return jsonify(list(get_mask_filters().keys()))
def list_models(context: ServerContext):
return jsonify(
{
"correction": get_correction_models(),
"diffusion": get_diffusion_models(),
"inversion": get_inversion_models(),
"upscaling": get_upscaling_models(),
}
)
def list_noise_sources(context: ServerContext):
return jsonify(list(get_noise_sources().keys()))
def list_params(context: ServerContext):
return jsonify(get_config_params())
def list_platforms(context: ServerContext):
return jsonify([p.device for p in get_available_platforms()])
def list_schedulers(context: ServerContext):
return jsonify(list(get_pipeline_schedulers().keys()))
def img2img(context: ServerContext, pool: DevicePoolExecutor):
if "source" not in request.files:
return error_reply("source image is required")
source_file = request.files.get("source")
source = Image.open(BytesIO(source_file.read())).convert("RGB")
device, params, size = pipeline_from_request(context)
upscale = upscale_from_request()
strength = get_and_clamp_float(
request.args,
"strength",
get_config_value("strength"),
get_config_value("strength", "max"),
get_config_value("strength", "min"),
)
output = make_output_name(context, "img2img", params, size, extras=(strength,))
job_name = output[0]
logger.info("img2img job queued for: %s", job_name)
source = valid_image(source, min_dims=size, max_dims=size)
pool.submit(
job_name,
run_img2img_pipeline,
context,
params,
output,
upscale,
source,
strength,
needs_device=device,
)
return jsonify(json_params(output, params, size, upscale=upscale))
def txt2img(context: ServerContext, pool: DevicePoolExecutor):
device, params, size = pipeline_from_request(context)
upscale = upscale_from_request()
output = make_output_name(context, "txt2img", params, size)
job_name = output[0]
logger.info("txt2img job queued for: %s", job_name)
pool.submit(
job_name,
run_txt2img_pipeline,
context,
params,
size,
output,
upscale,
needs_device=device,
)
return jsonify(json_params(output, params, size, upscale=upscale))
def inpaint(context: ServerContext, pool: DevicePoolExecutor):
if "source" not in request.files:
return error_reply("source image is required")
if "mask" not in request.files:
return error_reply("mask image is required")
source_file = request.files.get("source")
source = Image.open(BytesIO(source_file.read())).convert("RGB")
mask_file = request.files.get("mask")
mask = Image.open(BytesIO(mask_file.read())).convert("RGB")
device, params, size = pipeline_from_request(context)
expand = border_from_request()
upscale = upscale_from_request()
fill_color = get_not_empty(request.args, "fillColor", "white")
mask_filter = get_from_map(request.args, "filter", get_mask_filters(), "none")
noise_source = get_from_map(request.args, "noise", get_noise_sources(), "histogram")
tile_order = get_from_list(
request.args, "tileOrder", [TileOrder.grid, TileOrder.kernel, TileOrder.spiral]
)
output = make_output_name(
context,
"inpaint",
params,
size,
extras=(
expand.left,
expand.right,
expand.top,
expand.bottom,
mask_filter.__name__,
noise_source.__name__,
fill_color,
tile_order,
),
)
job_name = output[0]
logger.info("inpaint job queued for: %s", job_name)
source = valid_image(source, min_dims=size, max_dims=size)
mask = valid_image(mask, min_dims=size, max_dims=size)
pool.submit(
job_name,
run_inpaint_pipeline,
context,
params,
size,
output,
upscale,
source,
mask,
expand,
noise_source,
mask_filter,
fill_color,
tile_order,
needs_device=device,
)
return jsonify(json_params(output, params, size, upscale=upscale, border=expand))
def upscale(context: ServerContext, pool: DevicePoolExecutor):
if "source" not in request.files:
return error_reply("source image is required")
source_file = request.files.get("source")
source = Image.open(BytesIO(source_file.read())).convert("RGB")
device, params, size = pipeline_from_request(context)
upscale = upscale_from_request()
output = make_output_name(context, "upscale", params, size)
job_name = output[0]
logger.info("upscale job queued for: %s", job_name)
source = valid_image(source, min_dims=size, max_dims=size)
pool.submit(
job_name,
run_upscale_pipeline,
context,
params,
size,
output,
upscale,
source,
needs_device=device,
)
return jsonify(json_params(output, params, size, upscale=upscale))
def chain(context: ServerContext, pool: DevicePoolExecutor):
logger.debug(
"chain pipeline request: %s, %s", request.form.keys(), request.files.keys()
)
body = request.form.get("chain") or request.files.get("chain")
if body is None:
return error_reply("chain pipeline must have a body")
data = yaml.safe_load(body)
with open("./schemas/chain.yaml", "r") as f:
schema = yaml.safe_load(f.read())
logger.debug("validating chain request: %s against %s", data, schema)
validate(data, schema)
# get defaults from the regular parameters
device, params, size = pipeline_from_request(context)
output = make_output_name(context, "chain", params, size)
job_name = output[0]
pipeline = ChainPipeline()
for stage_data in data.get("stages", []):
callback = CHAIN_STAGES[stage_data.get("type")]
kwargs = stage_data.get("params", {})
logger.info("request stage: %s, %s", callback.__name__, kwargs)
stage = StageParams(
stage_data.get("name", callback.__name__),
tile_size=get_size(kwargs.get("tile_size")),
outscale=get_and_clamp_int(kwargs, "outscale", 1, 4),
)
if "border" in kwargs:
border = Border.even(int(kwargs.get("border")))
kwargs["border"] = border
if "upscale" in kwargs:
upscale = UpscaleParams(kwargs.get("upscale"))
kwargs["upscale"] = upscale
stage_source_name = "source:%s" % (stage.name)
stage_mask_name = "mask:%s" % (stage.name)
if stage_source_name in request.files:
logger.debug(
"loading source image %s for pipeline stage %s",
stage_source_name,
stage.name,
)
source_file = request.files.get(stage_source_name)
source = Image.open(BytesIO(source_file.read())).convert("RGB")
source = valid_image(source, max_dims=(size.width, size.height))
kwargs["stage_source"] = source
if stage_mask_name in request.files:
logger.debug(
"loading mask image %s for pipeline stage %s",
stage_mask_name,
stage.name,
)
mask_file = request.files.get(stage_mask_name)
mask = Image.open(BytesIO(mask_file.read())).convert("RGB")
mask = valid_image(mask, max_dims=(size.width, size.height))
kwargs["stage_mask"] = mask
pipeline.append((callback, stage, kwargs))
logger.info("running chain pipeline with %s stages", len(pipeline.stages))
# build and run chain pipeline
empty_source = Image.new("RGB", (size.width, size.height))
pool.submit(
job_name,
pipeline,
context,
params,
empty_source,
output=output[0],
size=size,
needs_device=device,
)
return jsonify(json_params(output, params, size))
def blend(context: ServerContext, pool: DevicePoolExecutor):
if "mask" not in request.files:
return error_reply("mask image is required")
mask_file = request.files.get("mask")
mask = Image.open(BytesIO(mask_file.read())).convert("RGBA")
mask = valid_image(mask)
max_sources = 2
sources = []
for i in range(max_sources):
source_file = request.files.get("source:%s" % (i))
source = Image.open(BytesIO(source_file.read())).convert("RGBA")
source = valid_image(source, mask.size, mask.size)
sources.append(source)
device, params, size = pipeline_from_request(context)
upscale = upscale_from_request()
output = make_output_name(context, "upscale", params, size)
job_name = output[0]
logger.info("upscale job queued for: %s", job_name)
pool.submit(
job_name,
run_blend_pipeline,
context,
params,
size,
output,
upscale,
sources,
mask,
needs_device=device,
)
return jsonify(json_params(output, params, size, upscale=upscale))
def txt2txt(context: ServerContext, pool: DevicePoolExecutor):
device, params, size = pipeline_from_request(context)
output = make_output_name(context, "upscale", params, size)
logger.info("upscale job queued for: %s", output)
pool.submit(
output,
run_txt2txt_pipeline,
context,
params,
size,
output,
needs_device=device,
)
return jsonify(json_params(output, params, size))
def cancel(context: ServerContext, pool: DevicePoolExecutor):
output_file = request.args.get("output", None)
cancel = pool.cancel(output_file)
return ready_reply(cancel)
def ready(context: ServerContext, pool: DevicePoolExecutor):
output_file = request.args.get("output", None)
done, progress = pool.done(output_file)
if done is None:
output = base_join(context.output_path, output_file)
if path.exists(output):
return ready_reply(True)
return ready_reply(done, progress=progress)
def status(context: ServerContext, pool: DevicePoolExecutor):
return jsonify(pool.status())
def register_api_routes(app: Flask, context: ServerContext, pool: DevicePoolExecutor):
return [
app.route("/api")(wrap_route(introspect, context, app=app)),
app.route("/api/settings/masks")(wrap_route(list_mask_filters, context)),
app.route("/api/settings/models")(wrap_route(list_models, context)),
app.route("/api/settings/noises")(wrap_route(list_noise_sources, context)),
app.route("/api/settings/params")(wrap_route(list_params, context)),
app.route("/api/settings/platforms")(wrap_route(list_platforms, context)),
app.route("/api/settings/schedulers")(wrap_route(list_schedulers, context)),
app.route("/api/img2img", methods=["POST"])(wrap_route(img2img, context, pool=pool)),
app.route("/api/txt2img", methods=["POST"])(wrap_route(txt2img, context, pool=pool)),
app.route("/api/txt2txt", methods=["POST"])(wrap_route(txt2txt, context, pool=pool)),
app.route("/api/inpaint", methods=["POST"])(wrap_route(inpaint, context, pool=pool)),
app.route("/api/upscale", methods=["POST"])(wrap_route(upscale, context, pool=pool)),
app.route("/api/chain", methods=["POST"])(wrap_route(chain, context, pool=pool)),
app.route("/api/blend", methods=["POST"])(wrap_route(blend, context, pool=pool)),
app.route("/api/cancel", methods=["PUT"])(wrap_route(cancel, context, pool=pool)),
app.route("/api/ready")(wrap_route(ready, context, pool=pool)),
app.route("/api/status")(wrap_route(status, context, pool=pool)),
]

View File

@ -0,0 +1,224 @@
from functools import cmp_to_key
from glob import glob
from logging import getLogger
from os import path
from typing import Dict, List, Optional, Union
import torch
import yaml
from onnxruntime import get_available_providers
from .context import ServerContext
from ..image import ( # mask filters; noise sources
mask_filter_gaussian_multiply,
mask_filter_gaussian_screen,
mask_filter_none,
noise_source_fill_edge,
noise_source_fill_mask,
noise_source_gaussian,
noise_source_histogram,
noise_source_normal,
noise_source_uniform,
)
from ..params import (
DeviceParams,
)
logger = getLogger(__name__)
# config caching
config_params: Dict[str, Dict[str, Union[float, int, str]]] = {}
# pipeline params
platform_providers = {
"cpu": "CPUExecutionProvider",
"cuda": "CUDAExecutionProvider",
"directml": "DmlExecutionProvider",
"rocm": "ROCMExecutionProvider",
}
noise_sources = {
"fill-edge": noise_source_fill_edge,
"fill-mask": noise_source_fill_mask,
"gaussian": noise_source_gaussian,
"histogram": noise_source_histogram,
"normal": noise_source_normal,
"uniform": noise_source_uniform,
}
mask_filters = {
"none": mask_filter_none,
"gaussian-multiply": mask_filter_gaussian_multiply,
"gaussian-screen": mask_filter_gaussian_screen,
}
# Available ORT providers
available_platforms: List[DeviceParams] = []
# loaded from model_path
correction_models: List[str] = []
diffusion_models: List[str] = []
inversion_models: List[str] = []
upscaling_models: List[str] = []
def get_config_params():
return config_params
def get_available_platforms():
return available_platforms
def get_correction_models():
return correction_models
def get_diffusion_models():
return diffusion_models
def get_inversion_models():
return inversion_models
def get_upscaling_models():
return upscaling_models
def get_mask_filters():
return mask_filters
def get_noise_sources():
return noise_sources
def get_config_value(key: str, subkey: str = "default", default=None):
return config_params.get(key, {}).get(subkey, default)
def get_model_name(model: str) -> str:
base = path.basename(model)
(file, _ext) = path.splitext(base)
return file
def load_models(context: ServerContext) -> None:
global correction_models
global diffusion_models
global inversion_models
global upscaling_models
diffusion_models = [
get_model_name(f) for f in glob(path.join(context.model_path, "diffusion-*"))
]
diffusion_models.extend(
[
get_model_name(f)
for f in glob(path.join(context.model_path, "stable-diffusion-*"))
]
)
diffusion_models = list(set(diffusion_models))
diffusion_models.sort()
correction_models = [
get_model_name(f) for f in glob(path.join(context.model_path, "correction-*"))
]
correction_models = list(set(correction_models))
correction_models.sort()
inversion_models = [
get_model_name(f) for f in glob(path.join(context.model_path, "inversion-*"))
]
inversion_models = list(set(inversion_models))
inversion_models.sort()
upscaling_models = [
get_model_name(f) for f in glob(path.join(context.model_path, "upscaling-*"))
]
upscaling_models = list(set(upscaling_models))
upscaling_models.sort()
def load_params(context: ServerContext) -> None:
global config_params
params_file = path.join(context.params_path, "params.json")
with open(params_file, "r") as f:
config_params = yaml.safe_load(f)
if "platform" in config_params and context.default_platform is not None:
logger.info(
"Overriding default platform from environment: %s",
context.default_platform,
)
config_platform = config_params.get("platform", {})
config_platform["default"] = context.default_platform
def load_platforms(context: ServerContext) -> None:
global available_platforms
providers = list(get_available_providers())
for potential in platform_providers:
if (
platform_providers[potential] in providers
and potential not in context.block_platforms
):
if potential == "cuda":
for i in range(torch.cuda.device_count()):
available_platforms.append(
DeviceParams(
potential,
platform_providers[potential],
{
"device_id": i,
},
context.optimizations,
)
)
else:
available_platforms.append(
DeviceParams(
potential,
platform_providers[potential],
None,
context.optimizations,
)
)
if context.any_platform:
# the platform should be ignored when the job is scheduled, but set to CPU just in case
available_platforms.append(
DeviceParams(
"any",
platform_providers["cpu"],
None,
context.optimizations,
)
)
# make sure CPU is last on the list
def any_first_cpu_last(a: DeviceParams, b: DeviceParams):
if a.device == b.device:
return 0
# any should be first, if it's available
if a.device == "any":
return -1
# cpu should be last, if it's available
if a.device == "cpu":
return 1
return -1
available_platforms = sorted(
available_platforms, key=cmp_to_key(any_first_cpu_last)
)
logger.info(
"available acceleration platforms: %s",
", ".join([str(p) for p in available_platforms]),
)

View File

@ -0,0 +1,183 @@
from logging import getLogger
from typing import Tuple
import numpy as np
from flask import request
from .context import ServerContext
from .config import get_available_platforms, get_config_value, get_correction_models, get_upscaling_models
from .utils import get_model_path
from ..diffusion.load import pipeline_schedulers
from ..params import (
Border,
DeviceParams,
ImageParams,
Size,
UpscaleParams,
)
from ..utils import (
get_and_clamp_float,
get_and_clamp_int,
get_from_list,
get_not_empty,
)
logger = getLogger(__name__)
def pipeline_from_request(context: ServerContext) -> Tuple[DeviceParams, ImageParams, Size]:
user = request.remote_addr
# platform stuff
device = None
device_name = request.args.get("platform")
if device_name is not None and device_name != "any":
for platform in get_available_platforms():
if platform.device == device_name:
device = platform
# pipeline stuff
lpw = get_not_empty(request.args, "lpw", "false") == "true"
model = get_not_empty(request.args, "model", get_config_value("model"))
model_path = get_model_path(context, model)
scheduler = get_from_list(
request.args, "scheduler", pipeline_schedulers.keys()
)
if scheduler is None:
scheduler = get_config_value("scheduler")
inversion = request.args.get("inversion", None)
inversion_path = None
if inversion is not None and inversion.strip() != "":
inversion_path = get_model_path(context, inversion)
# image params
prompt = get_not_empty(request.args, "prompt", get_config_value("prompt"))
negative_prompt = request.args.get("negativePrompt", None)
if negative_prompt is not None and negative_prompt.strip() == "":
negative_prompt = None
batch = get_and_clamp_int(
request.args,
"batch",
get_config_value("batch"),
get_config_value("batch", "max"),
get_config_value("batch", "min"),
)
cfg = get_and_clamp_float(
request.args,
"cfg",
get_config_value("cfg"),
get_config_value("cfg", "max"),
get_config_value("cfg", "min"),
)
eta = get_and_clamp_float(
request.args,
"eta",
get_config_value("eta"),
get_config_value("eta", "max"),
get_config_value("eta", "min"),
)
steps = get_and_clamp_int(
request.args,
"steps",
get_config_value("steps"),
get_config_value("steps", "max"),
get_config_value("steps", "min"),
)
height = get_and_clamp_int(
request.args,
"height",
get_config_value("height"),
get_config_value("height", "max"),
get_config_value("height", "min"),
)
width = get_and_clamp_int(
request.args,
"width",
get_config_value("width"),
get_config_value("width", "max"),
get_config_value("width", "min"),
)
seed = int(request.args.get("seed", -1))
if seed == -1:
# this one can safely use np.random because it produces a single value
seed = np.random.randint(np.iinfo(np.int32).max)
logger.info(
"request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s",
user,
steps,
scheduler,
model_path,
device or "any device",
width,
height,
cfg,
seed,
prompt,
)
params = ImageParams(
model_path,
scheduler,
prompt,
cfg,
steps,
seed,
eta=eta,
lpw=lpw,
negative_prompt=negative_prompt,
batch=batch,
inversion=inversion_path,
)
size = Size(width, height)
return (device, params, size)
def border_from_request() -> Border:
left = get_and_clamp_int(
request.args, "left", 0, get_config_value("width", "max"), 0
)
right = get_and_clamp_int(
request.args, "right", 0, get_config_value("width", "max"), 0
)
top = get_and_clamp_int(
request.args, "top", 0, get_config_value("height", "max"), 0
)
bottom = get_and_clamp_int(
request.args, "bottom", 0, get_config_value("height", "max"), 0
)
return Border(left, right, top, bottom)
def upscale_from_request() -> UpscaleParams:
denoise = get_and_clamp_float(request.args, "denoise", 0.5, 1.0, 0.0)
scale = get_and_clamp_int(request.args, "scale", 1, 4, 1)
outscale = get_and_clamp_int(request.args, "outscale", 1, 4, 1)
upscaling = get_from_list(request.args, "upscaling", get_upscaling_models())
correction = get_from_list(request.args, "correction", get_correction_models())
faces = get_not_empty(request.args, "faces", "false") == "true"
face_outscale = get_and_clamp_int(request.args, "faceOutscale", 1, 4, 1)
face_strength = get_and_clamp_float(request.args, "faceStrength", 0.5, 1.0, 0.0)
upscale_order = request.args.get("upscaleOrder", "correction-first")
return UpscaleParams(
upscaling,
correction_model=correction,
denoise=denoise,
faces=faces,
face_outscale=face_outscale,
face_strength=face_strength,
format="onnx",
outscale=outscale,
scale=scale,
upscale_order=upscale_order,
)

View File

@ -0,0 +1,34 @@
from os import path
from flask import Flask, send_from_directory
from .utils import wrap_route
from .context import ServerContext
from ..worker.pool import DevicePoolExecutor
def serve_bundle_file(context: ServerContext, filename="index.html"):
return send_from_directory(path.join("..", context.bundle_path), filename)
# non-API routes
def index(context: ServerContext):
return serve_bundle_file(context)
def index_path(context: ServerContext, filename: str):
return serve_bundle_file(context, filename)
def output(context: ServerContext, filename: str):
return send_from_directory(
path.join("..", context.output_path), filename, as_attachment=False
)
def register_static_routes(app: Flask, context: ServerContext, pool: DevicePoolExecutor):
return [
app.route("/")(wrap_route(index, context)),
app.route("/<path:filename>")(wrap_route(index_path, context)),
app.route("/output/<path:filename>")(wrap_route(output, context)),
]

View File

@ -0,0 +1,32 @@
from os import makedirs, path
from typing import Callable, Dict, List, Tuple
from functools import partial, update_wrapper
from flask import Flask
from onnx_web.utils import base_join
from onnx_web.worker.pool import DevicePoolExecutor
from .context import ServerContext
def check_paths(context: ServerContext) -> None:
if not path.exists(context.model_path):
raise RuntimeError("model path must exist")
if not path.exists(context.output_path):
makedirs(context.output_path)
def get_model_path(context: ServerContext, model: str):
return base_join(context.model_path, model)
def register_routes(app: Flask, context: ServerContext, pool: DevicePoolExecutor, routes: List[Tuple[str, Dict, Callable]]):
pass
def wrap_route(func, *args, **kwargs):
partial_func = partial(func, *args, **kwargs)
update_wrapper(partial_func, func)
return partial_func