1
0
Fork 0

fix(api): send CORS more consistently

This commit is contained in:
Sean Sube 2023-01-14 10:18:53 -06:00
parent 56ac6c6bc7
commit fa82ac18ab
3 changed files with 15 additions and 18 deletions

View File

@ -169,10 +169,10 @@ Install the following packages for AI:
> pip install accelerate diffusers ftfy onnx onnxruntime spacy scipy transformers > pip install accelerate diffusers ftfy onnx onnxruntime spacy scipy transformers
``` ```
Install the following packages for the web UI: Install the following packages for the API:
```shell ```shell
> pip install flask > pip install flask flask-cors flask_executor
``` ```
_Or_ install all of these packages at once using [the `requirements.txt` file](./api/requirements.txt): _Or_ install all of these packages at once using [the `requirements.txt` file](./api/requirements.txt):

View File

@ -20,6 +20,7 @@ from diffusers import (
DiffusionPipeline, DiffusionPipeline,
) )
from flask import Flask, jsonify, request, send_file, send_from_directory, url_for from flask import Flask, jsonify, request, send_file, send_from_directory, url_for
from flask_cors import CORS
from flask_executor import Executor from flask_executor import Executor
from hashlib import sha256 from hashlib import sha256
from io import BytesIO from io import BytesIO
@ -55,7 +56,8 @@ output_path = environ.get('ONNX_WEB_OUTPUT_PATH', path.join('..', 'outputs'))
params_path = environ.get('ONNX_WEB_PARAMS_PATH', 'params.json') params_path = environ.get('ONNX_WEB_PARAMS_PATH', 'params.json')
# options # options
num_workers = int(environ.get('ONNX_WEB_NUM_WORKERS', 2)) cors_origin = environ.get('ONNX_WEB_CORS_ORIGIN', '*').split(',')
num_workers = int(environ.get('ONNX_WEB_NUM_WORKERS', 1))
# pipeline caching # pipeline caching
available_models = [] available_models = []
@ -146,13 +148,6 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, schedu
return pipe return pipe
def json_with_cors(data, origin='*'):
"""Build a JSON response with CORS headers allowing `origin`"""
res = jsonify(data)
res.access_control_allow_origin = origin
return res
def serve_bundle_file(filename='index.html'): def serve_bundle_file(filename='index.html'):
return send_from_directory(path.join('..', bundle_path), filename) return send_from_directory(path.join('..', bundle_path), filename)
@ -323,6 +318,7 @@ load_params()
app = Flask(__name__) app = Flask(__name__)
app.config['EXECUTOR_MAX_WORKERS'] = num_workers app.config['EXECUTOR_MAX_WORKERS'] = num_workers
CORS(app, origins=cors_origin)
executor = Executor(app) executor = Executor(app)
# routes # routes
@ -351,22 +347,22 @@ def introspect():
@app.route('/api/settings/models') @app.route('/api/settings/models')
def list_models(): def list_models():
return json_with_cors(available_models) return jsonify(available_models)
@app.route('/api/settings/params') @app.route('/api/settings/params')
def list_params(): def list_params():
return json_with_cors(config_params) return jsonify(config_params)
@app.route('/api/settings/platforms') @app.route('/api/settings/platforms')
def list_platforms(): def list_platforms():
return json_with_cors(list(platform_providers.keys())) return jsonify(list(platform_providers.keys()))
@app.route('/api/settings/schedulers') @app.route('/api/settings/schedulers')
def list_schedulers(): def list_schedulers():
return json_with_cors(list(pipeline_schedulers.keys())) return jsonify(list(pipeline_schedulers.keys()))
@app.route('/api/img2img', methods=['POST']) @app.route('/api/img2img', methods=['POST'])
@ -387,7 +383,7 @@ def img2img():
executor.submit_stored(output_file, run_img2img_pipeline, model, provider, executor.submit_stored(output_file, run_img2img_pipeline, model, provider,
scheduler, prompt, negative_prompt, cfg, steps, seed, output_full, strength, input_image) scheduler, prompt, negative_prompt, cfg, steps, seed, output_full, strength, input_image)
return json_with_cors({ return jsonify({
'output': output_file, 'output': output_file,
'params': { 'params': {
'model': model, 'model': model,
@ -416,7 +412,7 @@ def txt2img():
executor.submit_stored(output_file, run_txt2img_pipeline, model, executor.submit_stored(output_file, run_txt2img_pipeline, model,
provider, scheduler, prompt, negative_prompt, cfg, steps, seed, output_full, height, width) provider, scheduler, prompt, negative_prompt, cfg, steps, seed, output_full, height, width)
return json_with_cors({ return jsonify({
'output': output_file, 'output': output_file,
'params': { 'params': {
'model': model, 'model': model,
@ -453,7 +449,7 @@ def inpaint():
executor.submit_stored(output_file, run_inpaint_pipeline, model, provider, scheduler, prompt, negative_prompt, executor.submit_stored(output_file, run_inpaint_pipeline, model, provider, scheduler, prompt, negative_prompt,
cfg, steps, seed, output_full, height, width, source_image, mask_image) cfg, steps, seed, output_full, height, width, source_image, mask_image)
return json_with_cors({ return jsonify({
'output': output_file, 'output': output_file,
'params': { 'params': {
'model': model, 'model': model,
@ -474,7 +470,7 @@ def inpaint():
def ready(): def ready():
output_file = request.args.get('output', None) output_file = request.args.get('output', None)
return json_with_cors({ return jsonify({
'ready': executor.futures.done(output_file), 'ready': executor.futures.done(output_file),
}) })

View File

@ -9,4 +9,5 @@ transformers
### Server packages ### ### Server packages ###
flask flask
flask-cors
flask_executor flask_executor