From b62c7d3742a014f44ce199b78b6f70dc8e29844b Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 22 Jan 2023 22:38:03 -0600 Subject: [PATCH] fix(api): return structured error when image parameters are missing (fixes #76) --- api/onnx_web/serve.py | 37 ++++++++++++++++++++++++++++++------- 1 file changed, 30 insertions(+), 7 deletions(-) diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py index d3c31539..fc25c393 100644 --- a/api/onnx_web/serve.py +++ b/api/onnx_web/serve.py @@ -13,7 +13,7 @@ from diffusers import ( LMSDiscreteScheduler, PNDMScheduler, ) -from flask import Flask, jsonify, request, send_from_directory, url_for +from flask import Flask, jsonify, make_response, request, send_from_directory, url_for from flask_cors import CORS from flask_executor import Executor from glob import glob @@ -292,6 +292,20 @@ def get_model_path(model: str): return safer_join(context.model_path, model) +def ready_reply(ready: bool): + return jsonify({ + 'ready': ready, + }) + + +def error_reply(err: str): + response = make_response(jsonify({ + 'error': err, + })) + response.status_code = 400 + return response + + def serve_bundle_file(filename='index.html'): return send_from_directory(path.join('..', context.bundle_path), filename) @@ -356,6 +370,9 @@ def list_schedulers(): @app.route('/api/img2img', methods=['POST']) def img2img(): + if 'source' not in request.files: + return error_reply('source image is required') + source_file = request.files.get('source') source_image = Image.open(BytesIO(source_file.read())).convert('RGB') @@ -410,6 +427,12 @@ def txt2img(): @app.route('/api/inpaint', methods=['POST']) def inpaint(): + if 'source' not in request.files: + return error_reply('source image is required') + + if 'mask' not in request.files: + return error_reply('mask image is required') + source_file = request.files.get('source') source_image = Image.open(BytesIO(source_file.read())).convert('RGB') @@ -475,6 +498,9 @@ def inpaint(): @app.route('/api/upscale', methods=['POST']) def upscale(): + if 'source' not in request.files: + return error_reply('source image is required') + source_file = request.files.get('source') source_image = Image.open(BytesIO(source_file.read())).convert('RGB') @@ -507,15 +533,12 @@ def ready(): if done is None: file = safer_join(context.output_path, output_file) if path.exists(file): - return jsonify({ - 'ready': True, - }) + return ready_reply(True) + elif done == True: executor.futures.pop(output_file) - return jsonify({ - 'ready': done, - }) + return ready_reply(done) @app.route('/output/')