From 3cb527c2b88981f080c9df1bd07142f585b04e80 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Wed, 4 Jan 2023 19:42:37 -0600 Subject: [PATCH] some options, configurable paths --- api/serve.py | 40 ++++++++++++++++++++++++++++++++-------- 1 file changed, 32 insertions(+), 8 deletions(-) diff --git a/api/serve.py b/api/serve.py index de00f4e6..66426c33 100644 --- a/api/serve.py +++ b/api/serve.py @@ -1,24 +1,48 @@ from diffusers import OnnxStableDiffusionPipeline -from flask import Flask +from flask import Flask, request, send_file +from io import BytesIO +from os import environ, path, makedirs +# defaults +empty_prompt = "a photo of an astronaut eating a hamburger" max_height = 512 max_width = 512 +max_steps = 50 +max_cfg = 8 +# paths +model_path = environ.get('ONNX_WEB_MODEL_PATH', "../../stable_diffusion_onnx") +output_path = environ.get('ONNX_WEB_OUTPUT_PATH', "../../web_output") + +def setup(): + if not path.exists(model_path): + raise RuntimeError('model path must exist') + if not path.exists(output_path): + makedirs(output_path) + +# setup +setup() app = Flask(__name__) -pipe = OnnxStableDiffusionPipeline.from_pretrained("./stable_diffusion_onnx", provider="DmlExecutionProvider", safety_checker=None) - +pipe = OnnxStableDiffusionPipeline.from_pretrained(model_path, provider="DmlExecutionProvider", safety_checker=None) +# routes @app.route('/') def hello(): - return 'Hello, World!' + return 'Hello, %s' % (__name__) @app.route('/txt2img') def txt2img(): + prompt = request.args.get('prompt', empty_prompt) height = request.args.get('height', max_height) width = request.args.get('width', max_width) - prompt = request.args.get('prompt', "a photo of an astronaut eating a hamburger") - steps = 50 - cfg = 8 + steps = int(request.args.get('steps', max_steps)) + cfg = int(request.args.get('cfg', max_cfg)) + print("txt2img: %s/%s, %sx%s, %s" % (cfg, steps, width, height, prompt)) image = pipe(prompt, height, width, num_inference_steps=steps, guidance_scale=cfg).images[0] - image.save("astronaut_rides_horse.png") + # image.save("astronaut_rides_horse.png") + + img_io = BytesIO() + image.save(img_io, 'PNG', quality=100) + img_io.seek(0) + return send_file(img_io, mimetype='image/png') \ No newline at end of file