From 06c74a7a96b73facd19a7ed575252679748eac41 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 26 Feb 2023 10:15:12 -0600 Subject: [PATCH] feat(api): remove Flask app from global scope --- api/.gitignore | 1 + api/launch-extras.sh | 2 +- api/launch.sh | 2 +- api/onnx_web/__init__.py | 7 +- api/onnx_web/chain/__init__.py | 17 + api/onnx_web/diffusion/load.py | 17 +- api/onnx_web/main.py | 53 ++ api/onnx_web/output.py | 5 +- api/onnx_web/params.py | 4 +- api/onnx_web/serve.py | 881 --------------------------------- api/onnx_web/server/api.py | 477 ++++++++++++++++++ api/onnx_web/server/config.py | 224 +++++++++ api/onnx_web/server/params.py | 183 +++++++ api/onnx_web/server/static.py | 34 ++ api/onnx_web/server/utils.py | 32 ++ 15 files changed, 1044 insertions(+), 895 deletions(-) create mode 100644 api/onnx_web/main.py delete mode 100644 api/onnx_web/serve.py create mode 100644 api/onnx_web/server/api.py create mode 100644 api/onnx_web/server/config.py create mode 100644 api/onnx_web/server/params.py create mode 100644 api/onnx_web/server/static.py create mode 100644 api/onnx_web/server/utils.py diff --git a/api/.gitignore b/api/.gitignore index a315070e..2ba1650c 100644 --- a/api/.gitignore +++ b/api/.gitignore @@ -1,6 +1,7 @@ .coverage coverage.xml +*.log *.swp *.pyc diff --git a/api/launch-extras.sh b/api/launch-extras.sh index 96db9bb2..f18e14c0 100755 --- a/api/launch-extras.sh +++ b/api/launch-extras.sh @@ -25,4 +25,4 @@ python3 -m onnx_web.convert \ --token=${HF_TOKEN:-} echo "Launching API server..." -flask --app=onnx_web.serve run --host=0.0.0.0 +flask --app='onnx_web.main:main()' run --host=0.0.0.0 diff --git a/api/launch.sh b/api/launch.sh index 50863ba8..983e0930 100755 --- a/api/launch.sh +++ b/api/launch.sh @@ -24,4 +24,4 @@ python3 -m onnx_web.convert \ --token=${HF_TOKEN:-} echo "Launching API server..." -flask --app=onnx_web.serve run --host=0.0.0.0 +flask --app='onnx_web.main:main()' run --host=0.0.0.0 diff --git a/api/onnx_web/__init__.py b/api/onnx_web/__init__.py index 7316bb87..b019bb3d 100644 --- a/api/onnx_web/__init__.py +++ b/api/onnx_web/__init__.py @@ -1,5 +1,10 @@ from . import logging -from .chain import correct_gfpgan, upscale_resrgan, upscale_stable_diffusion +from .chain import ( + correct_codeformer, + correct_gfpgan, + upscale_resrgan, + upscale_stable_diffusion, +) from .diffusion.load import get_latents_from_seed, load_pipeline, optimize_pipeline from .diffusion.run import ( run_blend_pipeline, diff --git a/api/onnx_web/chain/__init__.py b/api/onnx_web/chain/__init__.py index 5aa56c56..44fdd6c1 100644 --- a/api/onnx_web/chain/__init__.py +++ b/api/onnx_web/chain/__init__.py @@ -13,3 +13,20 @@ from .source_txt2img import source_txt2img from .upscale_outpaint import upscale_outpaint from .upscale_resrgan import upscale_resrgan from .upscale_stable_diffusion import upscale_stable_diffusion + +CHAIN_STAGES = { + "blend-img2img": blend_img2img, + "blend-inpaint": blend_inpaint, + "blend-mask": blend_mask, + "correct-codeformer": correct_codeformer, + "correct-gfpgan": correct_gfpgan, + "persist-disk": persist_disk, + "persist-s3": persist_s3, + "reduce-crop": reduce_crop, + "reduce-thumbnail": reduce_thumbnail, + "source-noise": source_noise, + "source-txt2img": source_txt2img, + "upscale-outpaint": upscale_outpaint, + "upscale-resrgan": upscale_resrgan, + "upscale-stable-diffusion": upscale_stable_diffusion, +} \ No newline at end of file diff --git a/api/onnx_web/diffusion/load.py b/api/onnx_web/diffusion/load.py index 47c35fcd..1fff3f1d 100644 --- a/api/onnx_web/diffusion/load.py +++ b/api/onnx_web/diffusion/load.py @@ -4,9 +4,11 @@ from typing import Any, Optional, Tuple import numpy as np from diffusers import ( + DiffusionPipeline, + OnnxRuntimeModel, + StableDiffusionPipeline, DDIMScheduler, DDPMScheduler, - DiffusionPipeline, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler, EulerAncestralDiscreteScheduler, @@ -17,15 +19,13 @@ from diffusers import ( KDPM2AncestralDiscreteScheduler, KDPM2DiscreteScheduler, LMSDiscreteScheduler, - OnnxRuntimeModel, PNDMScheduler, - StableDiffusionPipeline, ) try: from diffusers import DEISMultistepScheduler except ImportError: - from .stub_scheduler import StubScheduler as DEISMultistepScheduler + from ..diffusion.stub_scheduler import StubScheduler as DEISMultistepScheduler from ..params import DeviceParams, Size from ..server import ServerContext @@ -54,6 +54,10 @@ pipeline_schedulers = { } +def get_pipeline_schedulers(): + return pipeline_schedulers + + def get_scheduler_name(scheduler: Any) -> Optional[str]: for k, v in pipeline_schedulers.items(): if scheduler == v or scheduler == v.__name__: @@ -137,13 +141,14 @@ def load_pipeline( server: ServerContext, pipeline: DiffusionPipeline, model: str, - scheduler_type: Any, + scheduler_name: str, device: DeviceParams, lpw: bool, inversion: Optional[str], ): pipe_key = (pipeline, model, device.device, device.provider, lpw, inversion) - scheduler_key = (scheduler_type, model) + scheduler_key = (scheduler_name, model) + scheduler_type = get_pipeline_schedulers()[scheduler_name] cache_pipe = server.cache.get("diffusion", pipe_key) diff --git a/api/onnx_web/main.py b/api/onnx_web/main.py new file mode 100644 index 00000000..6381a241 --- /dev/null +++ b/api/onnx_web/main.py @@ -0,0 +1,53 @@ +import gc + +from diffusers.utils.logging import disable_progress_bar +from flask import Flask +from flask_cors import CORS +from huggingface_hub.utils.tqdm import disable_progress_bars + +from .server.api import register_api_routes +from .server.static import register_static_routes +from .server.config import get_available_platforms, load_models, load_params, load_platforms +from .server.utils import check_paths +from .server.context import ServerContext +from .server.hacks import apply_patches +from .utils import ( + is_debug, +) +from .worker import DevicePoolExecutor + + +def main(): + context = ServerContext.from_environ() + apply_patches(context) + check_paths(context) + load_models(context) + load_params(context) + load_platforms(context) + + if is_debug(): + gc.set_debug(gc.DEBUG_STATS) + + if not context.show_progress: + disable_progress_bar() + disable_progress_bars() + + app = Flask(__name__) + CORS(app, origins=context.cors_origin) + + # any is a fake device, should not be in the pool + pool = DevicePoolExecutor([p for p in get_available_platforms() if p.device != "any"]) + + # register routes + register_static_routes(app, context, pool) + register_api_routes(app, context, pool) + + return app #, context, pool + + +if __name__ == "__main__": + # app, context, pool = main() + app = main() + app.run("0.0.0.0", 5000, debug=is_debug()) + # pool.join() + diff --git a/api/onnx_web/output.py b/api/onnx_web/output.py index f60b3941..749008a2 100644 --- a/api/onnx_web/output.py +++ b/api/onnx_web/output.py @@ -8,7 +8,6 @@ from typing import Any, List, Optional, Tuple from PIL import Image -from .diffusion.load import get_scheduler_name from .params import Border, ImageParams, Param, Size, UpscaleParams from .server import ServerContext from .utils import base_join @@ -44,7 +43,7 @@ def json_params( } json["params"]["model"] = path.basename(params.model) - json["params"]["scheduler"] = get_scheduler_name(params.scheduler) + json["params"]["scheduler"] = params.scheduler if border is not None: json["border"] = border.tojson() @@ -71,7 +70,7 @@ def make_output_name( hash_value(sha, mode) hash_value(sha, params.model) - hash_value(sha, params.scheduler.__name__) + hash_value(sha, params.scheduler) hash_value(sha, params.prompt) hash_value(sha, params.negative_prompt) hash_value(sha, params.cfg) diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index db23414a..07cf082a 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -148,7 +148,7 @@ class ImageParams: def __init__( self, model: str, - scheduler: Any, + scheduler: str, prompt: str, cfg: float, steps: int, @@ -174,7 +174,7 @@ class ImageParams: def tojson(self) -> Dict[str, Optional[Param]]: return { "model": self.model, - "scheduler": self.scheduler.__name__, + "scheduler": self.scheduler, "prompt": self.prompt, "negative_prompt": self.negative_prompt, "cfg": self.cfg, diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py deleted file mode 100644 index 4d20f296..00000000 --- a/api/onnx_web/serve.py +++ /dev/null @@ -1,881 +0,0 @@ -import gc -from functools import cmp_to_key -from glob import glob -from io import BytesIO -from logging import getLogger -from os import makedirs, path -from typing import Dict, List, Tuple, Union - -import numpy as np -import torch -import yaml -from diffusers.utils.logging import disable_progress_bar -from flask import Flask, jsonify, make_response, request, send_from_directory, url_for -from flask_cors import CORS -from huggingface_hub.utils.tqdm import disable_progress_bars -from jsonschema import validate -from onnxruntime import get_available_providers -from PIL import Image - -from .chain import ( - ChainPipeline, - blend_img2img, - blend_inpaint, - correct_codeformer, - correct_gfpgan, - persist_disk, - persist_s3, - reduce_crop, - reduce_thumbnail, - source_noise, - source_txt2img, - upscale_outpaint, - upscale_resrgan, - upscale_stable_diffusion, -) -from .diffusion.load import pipeline_schedulers -from .diffusion.run import ( - run_blend_pipeline, - run_img2img_pipeline, - run_inpaint_pipeline, - run_txt2img_pipeline, - run_upscale_pipeline, -) -from .image import ( # mask filters; noise sources - mask_filter_gaussian_multiply, - mask_filter_gaussian_screen, - mask_filter_none, - noise_source_fill_edge, - noise_source_fill_mask, - noise_source_gaussian, - noise_source_histogram, - noise_source_normal, - noise_source_uniform, - valid_image, -) -from .output import json_params, make_output_name -from .params import ( - Border, - DeviceParams, - ImageParams, - Size, - StageParams, - TileOrder, - UpscaleParams, -) -from .server import ServerContext, apply_patches -from .transformers import run_txt2txt_pipeline -from .utils import ( - base_join, - get_and_clamp_float, - get_and_clamp_int, - get_from_list, - get_from_map, - get_not_empty, - get_size, - is_debug, -) -from .worker import DevicePoolExecutor - -logger = getLogger(__name__) - -# config caching -config_params: Dict[str, Dict[str, Union[float, int, str]]] = {} - -# pipeline params -platform_providers = { - "cpu": "CPUExecutionProvider", - "cuda": "CUDAExecutionProvider", - "directml": "DmlExecutionProvider", - "rocm": "ROCMExecutionProvider", -} - -noise_sources = { - "fill-edge": noise_source_fill_edge, - "fill-mask": noise_source_fill_mask, - "gaussian": noise_source_gaussian, - "histogram": noise_source_histogram, - "normal": noise_source_normal, - "uniform": noise_source_uniform, -} -mask_filters = { - "none": mask_filter_none, - "gaussian-multiply": mask_filter_gaussian_multiply, - "gaussian-screen": mask_filter_gaussian_screen, -} -chain_stages = { - "blend-img2img": blend_img2img, - "blend-inpaint": blend_inpaint, - "correct-codeformer": correct_codeformer, - "correct-gfpgan": correct_gfpgan, - "persist-disk": persist_disk, - "persist-s3": persist_s3, - "reduce-crop": reduce_crop, - "reduce-thumbnail": reduce_thumbnail, - "source-noise": source_noise, - "source-txt2img": source_txt2img, - "upscale-outpaint": upscale_outpaint, - "upscale-resrgan": upscale_resrgan, - "upscale-stable-diffusion": upscale_stable_diffusion, -} - -# Available ORT providers -available_platforms: List[DeviceParams] = [] - -# loaded from model_path -correction_models: List[str] = [] -diffusion_models: List[str] = [] -inversion_models: List[str] = [] -upscaling_models: List[str] = [] - - -def get_config_value(key: str, subkey: str = "default", default=None): - return config_params.get(key, {}).get(subkey, default) - - -def url_from_rule(rule) -> str: - options = {} - for arg in rule.arguments: - options[arg] = ":%s" % (arg) - - return url_for(rule.endpoint, **options) - - -def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]: - user = request.remote_addr - - # platform stuff - device = None - device_name = request.args.get("platform") - - if device_name is not None and device_name != "any": - for platform in available_platforms: - if platform.device == device_name: - device = platform - - # pipeline stuff - lpw = get_not_empty(request.args, "lpw", "false") == "true" - model = get_not_empty(request.args, "model", get_config_value("model")) - model_path = get_model_path(model) - scheduler = get_from_map( - request.args, "scheduler", pipeline_schedulers, get_config_value("scheduler") - ) - - inversion = request.args.get("inversion", None) - inversion_path = None - if inversion is not None and inversion.strip() != "": - inversion_path = get_model_path(inversion) - - # image params - prompt = get_not_empty(request.args, "prompt", get_config_value("prompt")) - negative_prompt = request.args.get("negativePrompt", None) - - if negative_prompt is not None and negative_prompt.strip() == "": - negative_prompt = None - - batch = get_and_clamp_int( - request.args, - "batch", - get_config_value("batch"), - get_config_value("batch", "max"), - get_config_value("batch", "min"), - ) - cfg = get_and_clamp_float( - request.args, - "cfg", - get_config_value("cfg"), - get_config_value("cfg", "max"), - get_config_value("cfg", "min"), - ) - eta = get_and_clamp_float( - request.args, - "eta", - get_config_value("eta"), - get_config_value("eta", "max"), - get_config_value("eta", "min"), - ) - steps = get_and_clamp_int( - request.args, - "steps", - get_config_value("steps"), - get_config_value("steps", "max"), - get_config_value("steps", "min"), - ) - height = get_and_clamp_int( - request.args, - "height", - get_config_value("height"), - get_config_value("height", "max"), - get_config_value("height", "min"), - ) - width = get_and_clamp_int( - request.args, - "width", - get_config_value("width"), - get_config_value("width", "max"), - get_config_value("width", "min"), - ) - - seed = int(request.args.get("seed", -1)) - if seed == -1: - # this one can safely use np.random because it produces a single value - seed = np.random.randint(np.iinfo(np.int32).max) - - logger.info( - "request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s", - user, - steps, - scheduler.__name__, - model_path, - device or "any device", - width, - height, - cfg, - seed, - prompt, - ) - - params = ImageParams( - model_path, - scheduler, - prompt, - cfg, - steps, - seed, - eta=eta, - lpw=lpw, - negative_prompt=negative_prompt, - batch=batch, - inversion=inversion_path, - ) - size = Size(width, height) - return (device, params, size) - - -def border_from_request() -> Border: - left = get_and_clamp_int( - request.args, "left", 0, get_config_value("width", "max"), 0 - ) - right = get_and_clamp_int( - request.args, "right", 0, get_config_value("width", "max"), 0 - ) - top = get_and_clamp_int( - request.args, "top", 0, get_config_value("height", "max"), 0 - ) - bottom = get_and_clamp_int( - request.args, "bottom", 0, get_config_value("height", "max"), 0 - ) - - return Border(left, right, top, bottom) - - -def upscale_from_request() -> UpscaleParams: - denoise = get_and_clamp_float(request.args, "denoise", 0.5, 1.0, 0.0) - scale = get_and_clamp_int(request.args, "scale", 1, 4, 1) - outscale = get_and_clamp_int(request.args, "outscale", 1, 4, 1) - upscaling = get_from_list(request.args, "upscaling", upscaling_models) - correction = get_from_list(request.args, "correction", correction_models) - faces = get_not_empty(request.args, "faces", "false") == "true" - face_outscale = get_and_clamp_int(request.args, "faceOutscale", 1, 4, 1) - face_strength = get_and_clamp_float(request.args, "faceStrength", 0.5, 1.0, 0.0) - upscale_order = request.args.get("upscaleOrder", "correction-first") - - return UpscaleParams( - upscaling, - correction_model=correction, - denoise=denoise, - faces=faces, - face_outscale=face_outscale, - face_strength=face_strength, - format="onnx", - outscale=outscale, - scale=scale, - upscale_order=upscale_order, - ) - - -def check_paths(context: ServerContext) -> None: - if not path.exists(context.model_path): - raise RuntimeError("model path must exist") - - if not path.exists(context.output_path): - makedirs(context.output_path) - - -def get_model_name(model: str) -> str: - base = path.basename(model) - (file, _ext) = path.splitext(base) - return file - - -def load_models(context: ServerContext) -> None: - global correction_models - global diffusion_models - global inversion_models - global upscaling_models - - diffusion_models = [ - get_model_name(f) for f in glob(path.join(context.model_path, "diffusion-*")) - ] - diffusion_models.extend( - [ - get_model_name(f) - for f in glob(path.join(context.model_path, "stable-diffusion-*")) - ] - ) - diffusion_models = list(set(diffusion_models)) - diffusion_models.sort() - - correction_models = [ - get_model_name(f) for f in glob(path.join(context.model_path, "correction-*")) - ] - correction_models = list(set(correction_models)) - correction_models.sort() - - inversion_models = [ - get_model_name(f) for f in glob(path.join(context.model_path, "inversion-*")) - ] - inversion_models = list(set(inversion_models)) - inversion_models.sort() - - upscaling_models = [ - get_model_name(f) for f in glob(path.join(context.model_path, "upscaling-*")) - ] - upscaling_models = list(set(upscaling_models)) - upscaling_models.sort() - - -def load_params(context: ServerContext) -> None: - global config_params - params_file = path.join(context.params_path, "params.json") - with open(params_file, "r") as f: - config_params = yaml.safe_load(f) - - if "platform" in config_params and context.default_platform is not None: - logger.info( - "Overriding default platform from environment: %s", - context.default_platform, - ) - config_platform = config_params.get("platform", {}) - config_platform["default"] = context.default_platform - - -def load_platforms(context: ServerContext) -> None: - global available_platforms - - providers = list(get_available_providers()) - - for potential in platform_providers: - if ( - platform_providers[potential] in providers - and potential not in context.block_platforms - ): - if potential == "cuda": - for i in range(torch.cuda.device_count()): - available_platforms.append( - DeviceParams( - potential, - platform_providers[potential], - { - "device_id": i, - }, - context.optimizations, - ) - ) - else: - available_platforms.append( - DeviceParams( - potential, - platform_providers[potential], - None, - context.optimizations, - ) - ) - - if context.any_platform: - # the platform should be ignored when the job is scheduled, but set to CPU just in case - available_platforms.append( - DeviceParams( - "any", - platform_providers["cpu"], - None, - context.optimizations, - ) - ) - - # make sure CPU is last on the list - def any_first_cpu_last(a: DeviceParams, b: DeviceParams): - if a.device == b.device: - return 0 - - # any should be first, if it's available - if a.device == "any": - return -1 - - # cpu should be last, if it's available - if a.device == "cpu": - return 1 - - return -1 - - available_platforms = sorted( - available_platforms, key=cmp_to_key(any_first_cpu_last) - ) - - logger.info( - "available acceleration platforms: %s", - ", ".join([str(p) for p in available_platforms]), - ) - - -context = ServerContext.from_environ() -apply_patches(context) -check_paths(context) -load_models(context) -load_params(context) -load_platforms(context) - -if not context.show_progress: - disable_progress_bar() - disable_progress_bars() - -app = Flask(__name__) -CORS(app, origins=context.cors_origin) - -# any is a fake device, should not be in the pool -executor = DevicePoolExecutor([p for p in available_platforms if p.device != "any"]) - -if is_debug(): - gc.set_debug(gc.DEBUG_STATS) - - -def ready_reply(ready: bool, progress: int = 0): - return jsonify( - { - "progress": progress, - "ready": ready, - } - ) - - -def error_reply(err: str): - response = make_response( - jsonify( - { - "error": err, - } - ) - ) - response.status_code = 400 - return response - - -def get_model_path(model: str): - return base_join(context.model_path, model) - - -def serve_bundle_file(filename="index.html"): - return send_from_directory(path.join("..", context.bundle_path), filename) - - -# routes - - -@app.route("/") -def index(): - return serve_bundle_file() - - -@app.route("/") -def index_path(filename): - return serve_bundle_file(filename) - - -@app.route("/api") -def introspect(): - return { - "name": "onnx-web", - "routes": [ - {"path": url_from_rule(rule), "methods": list(rule.methods).sort()} - for rule in app.url_map.iter_rules() - ], - } - - -@app.route("/api/settings/masks") -def list_mask_filters(): - return jsonify(list(mask_filters.keys())) - - -@app.route("/api/settings/models") -def list_models(): - return jsonify( - { - "correction": correction_models, - "diffusion": diffusion_models, - "inversion": inversion_models, - "upscaling": upscaling_models, - } - ) - - -@app.route("/api/settings/noises") -def list_noise_sources(): - return jsonify(list(noise_sources.keys())) - - -@app.route("/api/settings/params") -def list_params(): - return jsonify(config_params) - - -@app.route("/api/settings/platforms") -def list_platforms(): - return jsonify([p.device for p in available_platforms]) - - -@app.route("/api/settings/schedulers") -def list_schedulers(): - return jsonify(list(pipeline_schedulers.keys())) - - -@app.route("/api/img2img", methods=["POST"]) -def img2img(): - if "source" not in request.files: - return error_reply("source image is required") - - source_file = request.files.get("source") - source = Image.open(BytesIO(source_file.read())).convert("RGB") - - device, params, size = pipeline_from_request() - upscale = upscale_from_request() - - strength = get_and_clamp_float( - request.args, - "strength", - get_config_value("strength"), - get_config_value("strength", "max"), - get_config_value("strength", "min"), - ) - - output = make_output_name(context, "img2img", params, size, extras=(strength,)) - job_name = output[0] - logger.info("img2img job queued for: %s", job_name) - - source = valid_image(source, min_dims=size, max_dims=size) - executor.submit( - job_name, - run_img2img_pipeline, - context, - params, - output, - upscale, - source, - strength, - needs_device=device, - ) - - return jsonify(json_params(output, params, size, upscale=upscale)) - - -@app.route("/api/txt2img", methods=["POST"]) -def txt2img(): - device, params, size = pipeline_from_request() - upscale = upscale_from_request() - - output = make_output_name(context, "txt2img", params, size) - job_name = output[0] - logger.info("txt2img job queued for: %s", job_name) - - executor.submit( - job_name, - run_txt2img_pipeline, - context, - params, - size, - output, - upscale, - needs_device=device, - ) - - return jsonify(json_params(output, params, size, upscale=upscale)) - - -@app.route("/api/inpaint", methods=["POST"]) -def inpaint(): - if "source" not in request.files: - return error_reply("source image is required") - - if "mask" not in request.files: - return error_reply("mask image is required") - - source_file = request.files.get("source") - source = Image.open(BytesIO(source_file.read())).convert("RGB") - - mask_file = request.files.get("mask") - mask = Image.open(BytesIO(mask_file.read())).convert("RGB") - - device, params, size = pipeline_from_request() - expand = border_from_request() - upscale = upscale_from_request() - - fill_color = get_not_empty(request.args, "fillColor", "white") - mask_filter = get_from_map(request.args, "filter", mask_filters, "none") - noise_source = get_from_map(request.args, "noise", noise_sources, "histogram") - tile_order = get_from_list( - request.args, "tileOrder", [TileOrder.grid, TileOrder.kernel, TileOrder.spiral] - ) - - output = make_output_name( - context, - "inpaint", - params, - size, - extras=( - expand.left, - expand.right, - expand.top, - expand.bottom, - mask_filter.__name__, - noise_source.__name__, - fill_color, - tile_order, - ), - ) - job_name = output[0] - logger.info("inpaint job queued for: %s", job_name) - - source = valid_image(source, min_dims=size, max_dims=size) - mask = valid_image(mask, min_dims=size, max_dims=size) - executor.submit( - job_name, - run_inpaint_pipeline, - context, - params, - size, - output, - upscale, - source, - mask, - expand, - noise_source, - mask_filter, - fill_color, - tile_order, - needs_device=device, - ) - - return jsonify(json_params(output, params, size, upscale=upscale, border=expand)) - - -@app.route("/api/upscale", methods=["POST"]) -def upscale(): - if "source" not in request.files: - return error_reply("source image is required") - - source_file = request.files.get("source") - source = Image.open(BytesIO(source_file.read())).convert("RGB") - - device, params, size = pipeline_from_request() - upscale = upscale_from_request() - - output = make_output_name(context, "upscale", params, size) - job_name = output[0] - logger.info("upscale job queued for: %s", job_name) - - source = valid_image(source, min_dims=size, max_dims=size) - executor.submit( - job_name, - run_upscale_pipeline, - context, - params, - size, - output, - upscale, - source, - needs_device=device, - ) - - return jsonify(json_params(output, params, size, upscale=upscale)) - - -@app.route("/api/chain", methods=["POST"]) -def chain(): - logger.debug( - "chain pipeline request: %s, %s", request.form.keys(), request.files.keys() - ) - body = request.form.get("chain") or request.files.get("chain") - if body is None: - return error_reply("chain pipeline must have a body") - - data = yaml.safe_load(body) - with open("./schemas/chain.yaml", "r") as f: - schema = yaml.safe_load(f.read()) - - logger.debug("validating chain request: %s against %s", data, schema) - validate(data, schema) - - # get defaults from the regular parameters - device, params, size = pipeline_from_request() - output = make_output_name(context, "chain", params, size) - job_name = output[0] - - pipeline = ChainPipeline() - for stage_data in data.get("stages", []): - callback = chain_stages[stage_data.get("type")] - kwargs = stage_data.get("params", {}) - logger.info("request stage: %s, %s", callback.__name__, kwargs) - - stage = StageParams( - stage_data.get("name", callback.__name__), - tile_size=get_size(kwargs.get("tile_size")), - outscale=get_and_clamp_int(kwargs, "outscale", 1, 4), - ) - - if "border" in kwargs: - border = Border.even(int(kwargs.get("border"))) - kwargs["border"] = border - - if "upscale" in kwargs: - upscale = UpscaleParams(kwargs.get("upscale")) - kwargs["upscale"] = upscale - - stage_source_name = "source:%s" % (stage.name) - stage_mask_name = "mask:%s" % (stage.name) - - if stage_source_name in request.files: - logger.debug( - "loading source image %s for pipeline stage %s", - stage_source_name, - stage.name, - ) - source_file = request.files.get(stage_source_name) - source = Image.open(BytesIO(source_file.read())).convert("RGB") - source = valid_image(source, max_dims=(size.width, size.height)) - kwargs["stage_source"] = source - - if stage_mask_name in request.files: - logger.debug( - "loading mask image %s for pipeline stage %s", - stage_mask_name, - stage.name, - ) - mask_file = request.files.get(stage_mask_name) - mask = Image.open(BytesIO(mask_file.read())).convert("RGB") - mask = valid_image(mask, max_dims=(size.width, size.height)) - kwargs["stage_mask"] = mask - - pipeline.append((callback, stage, kwargs)) - - logger.info("running chain pipeline with %s stages", len(pipeline.stages)) - - # build and run chain pipeline - empty_source = Image.new("RGB", (size.width, size.height)) - executor.submit( - job_name, - pipeline, - context, - params, - empty_source, - output=output[0], - size=size, - needs_device=device, - ) - - return jsonify(json_params(output, params, size)) - - -@app.route("/api/blend", methods=["POST"]) -def blend(): - if "mask" not in request.files: - return error_reply("mask image is required") - - mask_file = request.files.get("mask") - mask = Image.open(BytesIO(mask_file.read())).convert("RGBA") - mask = valid_image(mask) - - max_sources = 2 - sources = [] - - for i in range(max_sources): - source_file = request.files.get("source:%s" % (i)) - source = Image.open(BytesIO(source_file.read())).convert("RGBA") - source = valid_image(source, mask.size, mask.size) - sources.append(source) - - device, params, size = pipeline_from_request() - upscale = upscale_from_request() - - output = make_output_name(context, "upscale", params, size) - job_name = output[0] - logger.info("upscale job queued for: %s", job_name) - - executor.submit( - job_name, - run_blend_pipeline, - context, - params, - size, - output, - upscale, - sources, - mask, - needs_device=device, - ) - - return jsonify(json_params(output, params, size, upscale=upscale)) - - -@app.route("/api/txt2txt", methods=["POST"]) -def txt2txt(): - device, params, size = pipeline_from_request() - - output = make_output_name(context, "upscale", params, size) - logger.info("upscale job queued for: %s", output) - - executor.submit( - output, - run_txt2txt_pipeline, - context, - params, - size, - output, - needs_device=device, - ) - - return jsonify(json_params(output, params, size)) - - -@app.route("/api/cancel", methods=["PUT"]) -def cancel(): - output_file = request.args.get("output", None) - - cancel = executor.cancel(output_file) - - return ready_reply(cancel) - - -@app.route("/api/ready") -def ready(): - output_file = request.args.get("output", None) - - done, progress = executor.done(output_file) - - if done is None: - output = base_join(context.output_path, output_file) - if path.exists(output): - return ready_reply(True) - - return ready_reply(done, progress=progress) - - -@app.route("/api/status") -def status(): - return jsonify(executor.status()) - - -@app.route("/output/") -def output(filename: str): - return send_from_directory( - path.join("..", context.output_path), filename, as_attachment=False - ) diff --git a/api/onnx_web/server/api.py b/api/onnx_web/server/api.py new file mode 100644 index 00000000..ed7dc8c0 --- /dev/null +++ b/api/onnx_web/server/api.py @@ -0,0 +1,477 @@ +from io import BytesIO +from logging import getLogger +from os import path + +import yaml +from flask import Flask, jsonify, make_response, request, url_for +from jsonschema import validate +from PIL import Image + +from .context import ServerContext +from .utils import wrap_route +from ..worker.pool import DevicePoolExecutor + +from .config import ( + get_available_platforms, + get_config_params, + get_config_value, + get_correction_models, + get_diffusion_models, + get_inversion_models, + get_mask_filters, + get_noise_sources, + get_upscaling_models, +) +from .params import border_from_request, pipeline_from_request, upscale_from_request + +from ..chain import ( + CHAIN_STAGES, + ChainPipeline, +) +from ..diffusion.load import get_pipeline_schedulers +from ..diffusion.run import ( + run_blend_pipeline, + run_img2img_pipeline, + run_inpaint_pipeline, + run_txt2img_pipeline, + run_upscale_pipeline, +) +from ..image import ( # mask filters; noise sources + valid_image, +) +from ..output import json_params, make_output_name +from ..params import ( + Border, + StageParams, + TileOrder, + UpscaleParams, +) +from ..transformers import run_txt2txt_pipeline +from ..utils import ( + base_join, + get_and_clamp_float, + get_and_clamp_int, + get_from_list, + get_from_map, + get_not_empty, + get_size, +) + +logger = getLogger(__name__) + + +def ready_reply(ready: bool, progress: int = 0): + return jsonify( + { + "progress": progress, + "ready": ready, + } + ) + + +def error_reply(err: str): + response = make_response( + jsonify( + { + "error": err, + } + ) + ) + response.status_code = 400 + return response + + +def url_from_rule(rule) -> str: + options = {} + for arg in rule.arguments: + options[arg] = ":%s" % (arg) + + return url_for(rule.endpoint, **options) + + +def introspect(context: ServerContext, app: Flask): + return { + "name": "onnx-web", + "routes": [ + {"path": url_from_rule(rule), "methods": list(rule.methods).sort()} + for rule in app.url_map.iter_rules() + ], + } + + +def list_mask_filters(context: ServerContext): + return jsonify(list(get_mask_filters().keys())) + + +def list_models(context: ServerContext): + return jsonify( + { + "correction": get_correction_models(), + "diffusion": get_diffusion_models(), + "inversion": get_inversion_models(), + "upscaling": get_upscaling_models(), + } + ) + + +def list_noise_sources(context: ServerContext): + return jsonify(list(get_noise_sources().keys())) + + +def list_params(context: ServerContext): + return jsonify(get_config_params()) + + +def list_platforms(context: ServerContext): + return jsonify([p.device for p in get_available_platforms()]) + + +def list_schedulers(context: ServerContext): + return jsonify(list(get_pipeline_schedulers().keys())) + + +def img2img(context: ServerContext, pool: DevicePoolExecutor): + if "source" not in request.files: + return error_reply("source image is required") + + source_file = request.files.get("source") + source = Image.open(BytesIO(source_file.read())).convert("RGB") + + device, params, size = pipeline_from_request(context) + upscale = upscale_from_request() + + strength = get_and_clamp_float( + request.args, + "strength", + get_config_value("strength"), + get_config_value("strength", "max"), + get_config_value("strength", "min"), + ) + + output = make_output_name(context, "img2img", params, size, extras=(strength,)) + job_name = output[0] + logger.info("img2img job queued for: %s", job_name) + + source = valid_image(source, min_dims=size, max_dims=size) + pool.submit( + job_name, + run_img2img_pipeline, + context, + params, + output, + upscale, + source, + strength, + needs_device=device, + ) + + return jsonify(json_params(output, params, size, upscale=upscale)) + + +def txt2img(context: ServerContext, pool: DevicePoolExecutor): + device, params, size = pipeline_from_request(context) + upscale = upscale_from_request() + + output = make_output_name(context, "txt2img", params, size) + job_name = output[0] + logger.info("txt2img job queued for: %s", job_name) + + pool.submit( + job_name, + run_txt2img_pipeline, + context, + params, + size, + output, + upscale, + needs_device=device, + ) + + return jsonify(json_params(output, params, size, upscale=upscale)) + + +def inpaint(context: ServerContext, pool: DevicePoolExecutor): + if "source" not in request.files: + return error_reply("source image is required") + + if "mask" not in request.files: + return error_reply("mask image is required") + + source_file = request.files.get("source") + source = Image.open(BytesIO(source_file.read())).convert("RGB") + + mask_file = request.files.get("mask") + mask = Image.open(BytesIO(mask_file.read())).convert("RGB") + + device, params, size = pipeline_from_request(context) + expand = border_from_request() + upscale = upscale_from_request() + + fill_color = get_not_empty(request.args, "fillColor", "white") + mask_filter = get_from_map(request.args, "filter", get_mask_filters(), "none") + noise_source = get_from_map(request.args, "noise", get_noise_sources(), "histogram") + tile_order = get_from_list( + request.args, "tileOrder", [TileOrder.grid, TileOrder.kernel, TileOrder.spiral] + ) + + output = make_output_name( + context, + "inpaint", + params, + size, + extras=( + expand.left, + expand.right, + expand.top, + expand.bottom, + mask_filter.__name__, + noise_source.__name__, + fill_color, + tile_order, + ), + ) + job_name = output[0] + logger.info("inpaint job queued for: %s", job_name) + + source = valid_image(source, min_dims=size, max_dims=size) + mask = valid_image(mask, min_dims=size, max_dims=size) + pool.submit( + job_name, + run_inpaint_pipeline, + context, + params, + size, + output, + upscale, + source, + mask, + expand, + noise_source, + mask_filter, + fill_color, + tile_order, + needs_device=device, + ) + + return jsonify(json_params(output, params, size, upscale=upscale, border=expand)) + + +def upscale(context: ServerContext, pool: DevicePoolExecutor): + if "source" not in request.files: + return error_reply("source image is required") + + source_file = request.files.get("source") + source = Image.open(BytesIO(source_file.read())).convert("RGB") + + device, params, size = pipeline_from_request(context) + upscale = upscale_from_request() + + output = make_output_name(context, "upscale", params, size) + job_name = output[0] + logger.info("upscale job queued for: %s", job_name) + + source = valid_image(source, min_dims=size, max_dims=size) + pool.submit( + job_name, + run_upscale_pipeline, + context, + params, + size, + output, + upscale, + source, + needs_device=device, + ) + + return jsonify(json_params(output, params, size, upscale=upscale)) + + +def chain(context: ServerContext, pool: DevicePoolExecutor): + logger.debug( + "chain pipeline request: %s, %s", request.form.keys(), request.files.keys() + ) + body = request.form.get("chain") or request.files.get("chain") + if body is None: + return error_reply("chain pipeline must have a body") + + data = yaml.safe_load(body) + with open("./schemas/chain.yaml", "r") as f: + schema = yaml.safe_load(f.read()) + + logger.debug("validating chain request: %s against %s", data, schema) + validate(data, schema) + + # get defaults from the regular parameters + device, params, size = pipeline_from_request(context) + output = make_output_name(context, "chain", params, size) + job_name = output[0] + + pipeline = ChainPipeline() + for stage_data in data.get("stages", []): + callback = CHAIN_STAGES[stage_data.get("type")] + kwargs = stage_data.get("params", {}) + logger.info("request stage: %s, %s", callback.__name__, kwargs) + + stage = StageParams( + stage_data.get("name", callback.__name__), + tile_size=get_size(kwargs.get("tile_size")), + outscale=get_and_clamp_int(kwargs, "outscale", 1, 4), + ) + + if "border" in kwargs: + border = Border.even(int(kwargs.get("border"))) + kwargs["border"] = border + + if "upscale" in kwargs: + upscale = UpscaleParams(kwargs.get("upscale")) + kwargs["upscale"] = upscale + + stage_source_name = "source:%s" % (stage.name) + stage_mask_name = "mask:%s" % (stage.name) + + if stage_source_name in request.files: + logger.debug( + "loading source image %s for pipeline stage %s", + stage_source_name, + stage.name, + ) + source_file = request.files.get(stage_source_name) + source = Image.open(BytesIO(source_file.read())).convert("RGB") + source = valid_image(source, max_dims=(size.width, size.height)) + kwargs["stage_source"] = source + + if stage_mask_name in request.files: + logger.debug( + "loading mask image %s for pipeline stage %s", + stage_mask_name, + stage.name, + ) + mask_file = request.files.get(stage_mask_name) + mask = Image.open(BytesIO(mask_file.read())).convert("RGB") + mask = valid_image(mask, max_dims=(size.width, size.height)) + kwargs["stage_mask"] = mask + + pipeline.append((callback, stage, kwargs)) + + logger.info("running chain pipeline with %s stages", len(pipeline.stages)) + + # build and run chain pipeline + empty_source = Image.new("RGB", (size.width, size.height)) + pool.submit( + job_name, + pipeline, + context, + params, + empty_source, + output=output[0], + size=size, + needs_device=device, + ) + + return jsonify(json_params(output, params, size)) + + +def blend(context: ServerContext, pool: DevicePoolExecutor): + if "mask" not in request.files: + return error_reply("mask image is required") + + mask_file = request.files.get("mask") + mask = Image.open(BytesIO(mask_file.read())).convert("RGBA") + mask = valid_image(mask) + + max_sources = 2 + sources = [] + + for i in range(max_sources): + source_file = request.files.get("source:%s" % (i)) + source = Image.open(BytesIO(source_file.read())).convert("RGBA") + source = valid_image(source, mask.size, mask.size) + sources.append(source) + + device, params, size = pipeline_from_request(context) + upscale = upscale_from_request() + + output = make_output_name(context, "upscale", params, size) + job_name = output[0] + logger.info("upscale job queued for: %s", job_name) + + pool.submit( + job_name, + run_blend_pipeline, + context, + params, + size, + output, + upscale, + sources, + mask, + needs_device=device, + ) + + return jsonify(json_params(output, params, size, upscale=upscale)) + + +def txt2txt(context: ServerContext, pool: DevicePoolExecutor): + device, params, size = pipeline_from_request(context) + + output = make_output_name(context, "upscale", params, size) + logger.info("upscale job queued for: %s", output) + + pool.submit( + output, + run_txt2txt_pipeline, + context, + params, + size, + output, + needs_device=device, + ) + + return jsonify(json_params(output, params, size)) + + +def cancel(context: ServerContext, pool: DevicePoolExecutor): + output_file = request.args.get("output", None) + + cancel = pool.cancel(output_file) + + return ready_reply(cancel) + + +def ready(context: ServerContext, pool: DevicePoolExecutor): + output_file = request.args.get("output", None) + + done, progress = pool.done(output_file) + + if done is None: + output = base_join(context.output_path, output_file) + if path.exists(output): + return ready_reply(True) + + return ready_reply(done, progress=progress) + + +def status(context: ServerContext, pool: DevicePoolExecutor): + return jsonify(pool.status()) + + +def register_api_routes(app: Flask, context: ServerContext, pool: DevicePoolExecutor): + return [ + app.route("/api")(wrap_route(introspect, context, app=app)), + app.route("/api/settings/masks")(wrap_route(list_mask_filters, context)), + app.route("/api/settings/models")(wrap_route(list_models, context)), + app.route("/api/settings/noises")(wrap_route(list_noise_sources, context)), + app.route("/api/settings/params")(wrap_route(list_params, context)), + app.route("/api/settings/platforms")(wrap_route(list_platforms, context)), + app.route("/api/settings/schedulers")(wrap_route(list_schedulers, context)), + app.route("/api/img2img", methods=["POST"])(wrap_route(img2img, context, pool=pool)), + app.route("/api/txt2img", methods=["POST"])(wrap_route(txt2img, context, pool=pool)), + app.route("/api/txt2txt", methods=["POST"])(wrap_route(txt2txt, context, pool=pool)), + app.route("/api/inpaint", methods=["POST"])(wrap_route(inpaint, context, pool=pool)), + app.route("/api/upscale", methods=["POST"])(wrap_route(upscale, context, pool=pool)), + app.route("/api/chain", methods=["POST"])(wrap_route(chain, context, pool=pool)), + app.route("/api/blend", methods=["POST"])(wrap_route(blend, context, pool=pool)), + app.route("/api/cancel", methods=["PUT"])(wrap_route(cancel, context, pool=pool)), + app.route("/api/ready")(wrap_route(ready, context, pool=pool)), + app.route("/api/status")(wrap_route(status, context, pool=pool)), + ] diff --git a/api/onnx_web/server/config.py b/api/onnx_web/server/config.py new file mode 100644 index 00000000..01110548 --- /dev/null +++ b/api/onnx_web/server/config.py @@ -0,0 +1,224 @@ +from functools import cmp_to_key +from glob import glob +from logging import getLogger +from os import path +from typing import Dict, List, Optional, Union + +import torch +import yaml +from onnxruntime import get_available_providers + +from .context import ServerContext +from ..image import ( # mask filters; noise sources + mask_filter_gaussian_multiply, + mask_filter_gaussian_screen, + mask_filter_none, + noise_source_fill_edge, + noise_source_fill_mask, + noise_source_gaussian, + noise_source_histogram, + noise_source_normal, + noise_source_uniform, +) +from ..params import ( + DeviceParams, +) + +logger = getLogger(__name__) + +# config caching +config_params: Dict[str, Dict[str, Union[float, int, str]]] = {} + +# pipeline params +platform_providers = { + "cpu": "CPUExecutionProvider", + "cuda": "CUDAExecutionProvider", + "directml": "DmlExecutionProvider", + "rocm": "ROCMExecutionProvider", +} +noise_sources = { + "fill-edge": noise_source_fill_edge, + "fill-mask": noise_source_fill_mask, + "gaussian": noise_source_gaussian, + "histogram": noise_source_histogram, + "normal": noise_source_normal, + "uniform": noise_source_uniform, +} +mask_filters = { + "none": mask_filter_none, + "gaussian-multiply": mask_filter_gaussian_multiply, + "gaussian-screen": mask_filter_gaussian_screen, +} + + +# Available ORT providers +available_platforms: List[DeviceParams] = [] + +# loaded from model_path +correction_models: List[str] = [] +diffusion_models: List[str] = [] +inversion_models: List[str] = [] +upscaling_models: List[str] = [] + + +def get_config_params(): + return config_params + + +def get_available_platforms(): + return available_platforms + + +def get_correction_models(): + return correction_models + + +def get_diffusion_models(): + return diffusion_models + + +def get_inversion_models(): + return inversion_models + + +def get_upscaling_models(): + return upscaling_models + + +def get_mask_filters(): + return mask_filters + + +def get_noise_sources(): + return noise_sources + + +def get_config_value(key: str, subkey: str = "default", default=None): + return config_params.get(key, {}).get(subkey, default) + + +def get_model_name(model: str) -> str: + base = path.basename(model) + (file, _ext) = path.splitext(base) + return file + + +def load_models(context: ServerContext) -> None: + global correction_models + global diffusion_models + global inversion_models + global upscaling_models + + diffusion_models = [ + get_model_name(f) for f in glob(path.join(context.model_path, "diffusion-*")) + ] + diffusion_models.extend( + [ + get_model_name(f) + for f in glob(path.join(context.model_path, "stable-diffusion-*")) + ] + ) + diffusion_models = list(set(diffusion_models)) + diffusion_models.sort() + + correction_models = [ + get_model_name(f) for f in glob(path.join(context.model_path, "correction-*")) + ] + correction_models = list(set(correction_models)) + correction_models.sort() + + inversion_models = [ + get_model_name(f) for f in glob(path.join(context.model_path, "inversion-*")) + ] + inversion_models = list(set(inversion_models)) + inversion_models.sort() + + upscaling_models = [ + get_model_name(f) for f in glob(path.join(context.model_path, "upscaling-*")) + ] + upscaling_models = list(set(upscaling_models)) + upscaling_models.sort() + + +def load_params(context: ServerContext) -> None: + global config_params + params_file = path.join(context.params_path, "params.json") + with open(params_file, "r") as f: + config_params = yaml.safe_load(f) + + if "platform" in config_params and context.default_platform is not None: + logger.info( + "Overriding default platform from environment: %s", + context.default_platform, + ) + config_platform = config_params.get("platform", {}) + config_platform["default"] = context.default_platform + + +def load_platforms(context: ServerContext) -> None: + global available_platforms + + providers = list(get_available_providers()) + + for potential in platform_providers: + if ( + platform_providers[potential] in providers + and potential not in context.block_platforms + ): + if potential == "cuda": + for i in range(torch.cuda.device_count()): + available_platforms.append( + DeviceParams( + potential, + platform_providers[potential], + { + "device_id": i, + }, + context.optimizations, + ) + ) + else: + available_platforms.append( + DeviceParams( + potential, + platform_providers[potential], + None, + context.optimizations, + ) + ) + + if context.any_platform: + # the platform should be ignored when the job is scheduled, but set to CPU just in case + available_platforms.append( + DeviceParams( + "any", + platform_providers["cpu"], + None, + context.optimizations, + ) + ) + + # make sure CPU is last on the list + def any_first_cpu_last(a: DeviceParams, b: DeviceParams): + if a.device == b.device: + return 0 + + # any should be first, if it's available + if a.device == "any": + return -1 + + # cpu should be last, if it's available + if a.device == "cpu": + return 1 + + return -1 + + available_platforms = sorted( + available_platforms, key=cmp_to_key(any_first_cpu_last) + ) + + logger.info( + "available acceleration platforms: %s", + ", ".join([str(p) for p in available_platforms]), + ) + diff --git a/api/onnx_web/server/params.py b/api/onnx_web/server/params.py new file mode 100644 index 00000000..b70ef79a --- /dev/null +++ b/api/onnx_web/server/params.py @@ -0,0 +1,183 @@ +from logging import getLogger +from typing import Tuple + +import numpy as np +from flask import request + +from .context import ServerContext + +from .config import get_available_platforms, get_config_value, get_correction_models, get_upscaling_models +from .utils import get_model_path + +from ..diffusion.load import pipeline_schedulers +from ..params import ( + Border, + DeviceParams, + ImageParams, + Size, + UpscaleParams, +) +from ..utils import ( + get_and_clamp_float, + get_and_clamp_int, + get_from_list, + get_not_empty, +) + +logger = getLogger(__name__) + + +def pipeline_from_request(context: ServerContext) -> Tuple[DeviceParams, ImageParams, Size]: + user = request.remote_addr + + # platform stuff + device = None + device_name = request.args.get("platform") + + if device_name is not None and device_name != "any": + for platform in get_available_platforms(): + if platform.device == device_name: + device = platform + + # pipeline stuff + lpw = get_not_empty(request.args, "lpw", "false") == "true" + model = get_not_empty(request.args, "model", get_config_value("model")) + model_path = get_model_path(context, model) + scheduler = get_from_list( + request.args, "scheduler", pipeline_schedulers.keys() + ) + + if scheduler is None: + scheduler = get_config_value("scheduler") + + inversion = request.args.get("inversion", None) + inversion_path = None + if inversion is not None and inversion.strip() != "": + inversion_path = get_model_path(context, inversion) + + # image params + prompt = get_not_empty(request.args, "prompt", get_config_value("prompt")) + negative_prompt = request.args.get("negativePrompt", None) + + if negative_prompt is not None and negative_prompt.strip() == "": + negative_prompt = None + + batch = get_and_clamp_int( + request.args, + "batch", + get_config_value("batch"), + get_config_value("batch", "max"), + get_config_value("batch", "min"), + ) + cfg = get_and_clamp_float( + request.args, + "cfg", + get_config_value("cfg"), + get_config_value("cfg", "max"), + get_config_value("cfg", "min"), + ) + eta = get_and_clamp_float( + request.args, + "eta", + get_config_value("eta"), + get_config_value("eta", "max"), + get_config_value("eta", "min"), + ) + steps = get_and_clamp_int( + request.args, + "steps", + get_config_value("steps"), + get_config_value("steps", "max"), + get_config_value("steps", "min"), + ) + height = get_and_clamp_int( + request.args, + "height", + get_config_value("height"), + get_config_value("height", "max"), + get_config_value("height", "min"), + ) + width = get_and_clamp_int( + request.args, + "width", + get_config_value("width"), + get_config_value("width", "max"), + get_config_value("width", "min"), + ) + + seed = int(request.args.get("seed", -1)) + if seed == -1: + # this one can safely use np.random because it produces a single value + seed = np.random.randint(np.iinfo(np.int32).max) + + logger.info( + "request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s", + user, + steps, + scheduler, + model_path, + device or "any device", + width, + height, + cfg, + seed, + prompt, + ) + + params = ImageParams( + model_path, + scheduler, + prompt, + cfg, + steps, + seed, + eta=eta, + lpw=lpw, + negative_prompt=negative_prompt, + batch=batch, + inversion=inversion_path, + ) + size = Size(width, height) + return (device, params, size) + + +def border_from_request() -> Border: + left = get_and_clamp_int( + request.args, "left", 0, get_config_value("width", "max"), 0 + ) + right = get_and_clamp_int( + request.args, "right", 0, get_config_value("width", "max"), 0 + ) + top = get_and_clamp_int( + request.args, "top", 0, get_config_value("height", "max"), 0 + ) + bottom = get_and_clamp_int( + request.args, "bottom", 0, get_config_value("height", "max"), 0 + ) + + return Border(left, right, top, bottom) + + +def upscale_from_request() -> UpscaleParams: + denoise = get_and_clamp_float(request.args, "denoise", 0.5, 1.0, 0.0) + scale = get_and_clamp_int(request.args, "scale", 1, 4, 1) + outscale = get_and_clamp_int(request.args, "outscale", 1, 4, 1) + upscaling = get_from_list(request.args, "upscaling", get_upscaling_models()) + correction = get_from_list(request.args, "correction", get_correction_models()) + faces = get_not_empty(request.args, "faces", "false") == "true" + face_outscale = get_and_clamp_int(request.args, "faceOutscale", 1, 4, 1) + face_strength = get_and_clamp_float(request.args, "faceStrength", 0.5, 1.0, 0.0) + upscale_order = request.args.get("upscaleOrder", "correction-first") + + return UpscaleParams( + upscaling, + correction_model=correction, + denoise=denoise, + faces=faces, + face_outscale=face_outscale, + face_strength=face_strength, + format="onnx", + outscale=outscale, + scale=scale, + upscale_order=upscale_order, + ) diff --git a/api/onnx_web/server/static.py b/api/onnx_web/server/static.py new file mode 100644 index 00000000..296c8deb --- /dev/null +++ b/api/onnx_web/server/static.py @@ -0,0 +1,34 @@ +from os import path + +from flask import Flask, send_from_directory + +from .utils import wrap_route +from .context import ServerContext +from ..worker.pool import DevicePoolExecutor + + +def serve_bundle_file(context: ServerContext, filename="index.html"): + return send_from_directory(path.join("..", context.bundle_path), filename) + + +# non-API routes +def index(context: ServerContext): + return serve_bundle_file(context) + + +def index_path(context: ServerContext, filename: str): + return serve_bundle_file(context, filename) + + +def output(context: ServerContext, filename: str): + return send_from_directory( + path.join("..", context.output_path), filename, as_attachment=False + ) + + +def register_static_routes(app: Flask, context: ServerContext, pool: DevicePoolExecutor): + return [ + app.route("/")(wrap_route(index, context)), + app.route("/")(wrap_route(index_path, context)), + app.route("/output/")(wrap_route(output, context)), + ] diff --git a/api/onnx_web/server/utils.py b/api/onnx_web/server/utils.py new file mode 100644 index 00000000..8dd359a1 --- /dev/null +++ b/api/onnx_web/server/utils.py @@ -0,0 +1,32 @@ +from os import makedirs, path +from typing import Callable, Dict, List, Tuple +from functools import partial, update_wrapper + +from flask import Flask + +from onnx_web.utils import base_join +from onnx_web.worker.pool import DevicePoolExecutor + +from .context import ServerContext + + +def check_paths(context: ServerContext) -> None: + if not path.exists(context.model_path): + raise RuntimeError("model path must exist") + + if not path.exists(context.output_path): + makedirs(context.output_path) + + +def get_model_path(context: ServerContext, model: str): + return base_join(context.model_path, model) + + +def register_routes(app: Flask, context: ServerContext, pool: DevicePoolExecutor, routes: List[Tuple[str, Dict, Callable]]): + pass + + +def wrap_route(func, *args, **kwargs): + partial_func = partial(func, *args, **kwargs) + update_wrapper(partial_func, func) + return partial_func