diff --git a/README.md b/README.md index 7d670d64..f595b70b 100644 --- a/README.md +++ b/README.md @@ -169,10 +169,10 @@ Install the following packages for AI: > pip install accelerate diffusers ftfy onnx onnxruntime spacy scipy transformers ``` -Install the following packages for the web UI: +Install the following packages for the API: ```shell -> pip install flask +> pip install flask flask-cors flask_executor ``` _Or_ install all of these packages at once using [the `requirements.txt` file](./api/requirements.txt): diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py index 0abe91b8..aacc120d 100644 --- a/api/onnx_web/serve.py +++ b/api/onnx_web/serve.py @@ -20,6 +20,7 @@ from diffusers import ( DiffusionPipeline, ) from flask import Flask, jsonify, request, send_file, send_from_directory, url_for +from flask_cors import CORS from flask_executor import Executor from hashlib import sha256 from io import BytesIO @@ -55,7 +56,8 @@ output_path = environ.get('ONNX_WEB_OUTPUT_PATH', path.join('..', 'outputs')) params_path = environ.get('ONNX_WEB_PARAMS_PATH', 'params.json') # options -num_workers = int(environ.get('ONNX_WEB_NUM_WORKERS', 2)) +cors_origin = environ.get('ONNX_WEB_CORS_ORIGIN', '*').split(',') +num_workers = int(environ.get('ONNX_WEB_NUM_WORKERS', 1)) # pipeline caching available_models = [] @@ -146,13 +148,6 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, schedu return pipe -def json_with_cors(data, origin='*'): - """Build a JSON response with CORS headers allowing `origin`""" - res = jsonify(data) - res.access_control_allow_origin = origin - return res - - def serve_bundle_file(filename='index.html'): return send_from_directory(path.join('..', bundle_path), filename) @@ -323,6 +318,7 @@ load_params() app = Flask(__name__) app.config['EXECUTOR_MAX_WORKERS'] = num_workers +CORS(app, origins=cors_origin) executor = Executor(app) # routes @@ -351,22 +347,22 @@ def introspect(): @app.route('/api/settings/models') def list_models(): - return json_with_cors(available_models) + return jsonify(available_models) @app.route('/api/settings/params') def list_params(): - return json_with_cors(config_params) + return jsonify(config_params) @app.route('/api/settings/platforms') def list_platforms(): - return json_with_cors(list(platform_providers.keys())) + return jsonify(list(platform_providers.keys())) @app.route('/api/settings/schedulers') def list_schedulers(): - return json_with_cors(list(pipeline_schedulers.keys())) + return jsonify(list(pipeline_schedulers.keys())) @app.route('/api/img2img', methods=['POST']) @@ -387,7 +383,7 @@ def img2img(): executor.submit_stored(output_file, run_img2img_pipeline, model, provider, scheduler, prompt, negative_prompt, cfg, steps, seed, output_full, strength, input_image) - return json_with_cors({ + return jsonify({ 'output': output_file, 'params': { 'model': model, @@ -416,7 +412,7 @@ def txt2img(): executor.submit_stored(output_file, run_txt2img_pipeline, model, provider, scheduler, prompt, negative_prompt, cfg, steps, seed, output_full, height, width) - return json_with_cors({ + return jsonify({ 'output': output_file, 'params': { 'model': model, @@ -453,7 +449,7 @@ def inpaint(): executor.submit_stored(output_file, run_inpaint_pipeline, model, provider, scheduler, prompt, negative_prompt, cfg, steps, seed, output_full, height, width, source_image, mask_image) - return json_with_cors({ + return jsonify({ 'output': output_file, 'params': { 'model': model, @@ -474,7 +470,7 @@ def inpaint(): def ready(): output_file = request.args.get('output', None) - return json_with_cors({ + return jsonify({ 'ready': executor.futures.done(output_file), }) diff --git a/api/requirements.txt b/api/requirements.txt index bf170901..d63dd9e7 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -9,4 +9,5 @@ transformers ### Server packages ### flask +flask-cors flask_executor \ No newline at end of file