From ea9f929bf18cf8f999402b338bfc09756e560bbb Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Wed, 4 Jan 2023 23:59:45 -0600 Subject: [PATCH] clamp API params, limit filename length, remove queueing thing --- api/serve.py | 56 +++++++++++++++++++++++++++++----------------------- 1 file changed, 31 insertions(+), 25 deletions(-) diff --git a/api/serve.py b/api/serve.py index 6bd4fe6a..9be67436 100644 --- a/api/serve.py +++ b/api/serve.py @@ -14,19 +14,21 @@ from io import BytesIO from os import environ, path, makedirs # defaults -empty_prompt = "a photo of an astronaut eating a hamburger" +default_prompt = "a photo of an astronaut eating a hamburger" +default_height = 512 +default_width = 512 +default_steps = 20 +default_cfg = 8 + max_height = 512 max_width = 512 -max_steps = 50 -max_cfg = 8 +max_steps = 150 +max_cfg = 30 # paths model_path = environ.get('ONNX_WEB_MODEL_PATH', "../../stable_diffusion_onnx") output_path = environ.get('ONNX_WEB_OUTPUT_PATH', "../../web_output") -# queue -image_queue = set() - # schedulers scheduler_list = { 'ddpm': DDPMScheduler.from_pretrained(model_path, subfolder="scheduler"), @@ -38,14 +40,23 @@ scheduler_list = { 'dpm-multi': DPMSolverMultistepScheduler.from_pretrained(model_path, subfolder="scheduler"), } -def setup(): - if not path.exists(model_path): - raise RuntimeError('model path must exist') - if not path.exists(output_path): - makedirs(output_path) +def get_and_clamp(args, key, default_value, max_value, min_value=1): + return min(max(int(args.get(key, default_value)), min_value), max_value) + +def get_from_map(args, key, values, default): + selected = args.get(key, default) + if selected in values: + return values[selected] + else: + return values[default] # setup -setup() +if not path.exists(model_path): + raise RuntimeError('model path must exist') + +if not path.exists(output_path): + makedirs(output_path) + app = Flask(__name__) # routes @@ -55,19 +66,16 @@ def hello(): @app.route('/txt2img') def txt2img(): - if len(image_queue) > 0: - return 'Queue full: %s' % (image_queue) - user = request.remote_addr - prompt = request.args.get('prompt', empty_prompt) - height = request.args.get('height', max_height) - width = request.args.get('width', max_width) - steps = int(request.args.get('steps', max_steps)) - cfg = int(request.args.get('cfg', max_cfg)) - scheduler = scheduler_list[request.args.get('scheduler', 'euler-a')] + + cfg = get_and_clamp(request.args, 'cfg', default_cfg, max_cfg) + height = get_and_clamp(request.args, 'height', default_height, max_height) + prompt = request.args.get('prompt', default_prompt) + steps = get_and_clamp(request.args, 'steps', default_steps, max_steps) + scheduler = get_from_map(request.args, 'scheduler', scheduler_list, 'euler-a') + width = get_and_clamp(request.args, 'width', default_width, max_width) print("txt2img from %s: %s/%s, %sx%s, %s" % (user, cfg, steps, width, height, prompt)) - image_queue.add(user) pipe = OnnxStableDiffusionPipeline.from_pretrained( model_path, @@ -83,7 +91,7 @@ def txt2img(): guidance_scale=cfg ).images[0] - output = '%s/txt2img_%s.png' % (output_path, spinalcase(prompt)) + output = '%s/txt2img_%s.png' % (output_path, spinalcase(prompt[0:64])) print("txt2img output: %s" % (output)) image.save(output) @@ -91,8 +99,6 @@ def txt2img(): image.save(img_io, 'PNG', quality=100) img_io.seek(0) - image_queue.remove(user) - res = make_response(send_file(img_io, mimetype='image/png')) res.headers.add('Access-Control-Allow-Origin', '*') return res \ No newline at end of file