2023-01-05 00:25:00 +00:00
|
|
|
from diffusers import OnnxStableDiffusionPipeline
|
2023-01-05 04:54:17 +00:00
|
|
|
from flask import Flask, make_response, request, send_file
|
2023-01-05 01:42:37 +00:00
|
|
|
from io import BytesIO
|
|
|
|
from os import environ, path, makedirs
|
2023-01-05 00:25:00 +00:00
|
|
|
|
2023-01-05 01:42:37 +00:00
|
|
|
# Request defaults: each value below is substituted when the client leaves
# the matching query parameter out.
empty_prompt = "a photo of an astronaut eating a hamburger"
max_height, max_width = 512, 512
max_steps, max_cfg = 50, 8

# Filesystem locations, each overridable through an environment variable.
model_path = environ.get(
    'ONNX_WEB_MODEL_PATH',
    "../../stable_diffusion_onnx",
)
output_path = environ.get(
    'ONNX_WEB_OUTPUT_PATH',
    "../../web_output",
)

# Remote addresses of the clients whose images are currently being rendered.
image_queue = set()
|
|
|
|
|
2023-01-05 01:42:37 +00:00
|
|
|
def setup():
|
|
|
|
if not path.exists(model_path):
|
|
|
|
raise RuntimeError('model path must exist')
|
|
|
|
if not path.exists(output_path):
|
|
|
|
makedirs(output_path)
|
2023-01-05 00:25:00 +00:00
|
|
|
|
2023-01-05 01:42:37 +00:00
|
|
|
# setup
setup()
app = Flask(__name__)

# ONNX Runtime execution provider used to run the model. The default keeps the
# original DirectML provider; set ONNX_WEB_PROVIDER to another ONNX Runtime
# provider name (for example CPUExecutionProvider) on machines without it.
onnx_provider = environ.get('ONNX_WEB_PROVIDER', 'DmlExecutionProvider')

# safety_checker=None: the pipeline is loaded without a safety checker.
pipe = OnnxStableDiffusionPipeline.from_pretrained(
    model_path,
    provider=onnx_provider,
    safety_checker=None,
)
|
2023-01-05 00:25:00 +00:00
|
|
|
|
2023-01-05 01:42:37 +00:00
|
|
|
# routes
|
2023-01-05 00:25:00 +00:00
|
|
|
@app.route('/')
def hello():
    """Return a plain-text greeting that includes this module's ``__name__``."""
    greeting = 'Hello, {}'.format(__name__)
    return greeting
|
2023-01-05 00:25:00 +00:00
|
|
|
|
|
|
|
def _clamped_arg(name, default, cast, upper, lower):
    """Read query parameter *name* from the current request, convert and clamp it.

    Args:
        name: query-string key to read.
        default: value used when the key is absent.
        cast: converter applied to the raw value (``int`` or ``float``).
        upper: largest allowed value; larger requests are reduced to it.
        lower: smallest allowed value; smaller requests are raised to it.

    Returns:
        The converted value, limited to the inclusive range ``[lower, upper]``.

    Raises:
        ValueError: if the value cannot be converted by *cast*, or is NaN.
    """
    value = cast(request.args.get(name, default))
    # NaN compares false against everything, so min/max would let it through.
    if value != value:
        raise ValueError('%s must be a number' % name)
    return min(max(value, lower), upper)


@app.route('/txt2img')
def txt2img():
    """Render one image from query-string parameters and return it as a PNG.

    Query parameters (all optional): ``prompt``, ``height``, ``width``,
    ``steps`` and ``cfg``. Numeric values are capped at the module's
    ``max_*`` settings, which are also used as the defaults.

    Only one render runs at a time. While one is in progress, other requests
    receive a plain-text "Queue full" reply. Malformed numeric parameters
    produce a 400 response.
    """
    prompt = request.args.get('prompt', empty_prompt)
    try:
        # Query-string values arrive as strings: convert them before they
        # reach the pipeline, and cap them so a client cannot request an
        # arbitrarily large render.
        height = _clamped_arg('height', max_height, int, max_height, 1)
        width = _clamped_arg('width', max_width, int, max_width, 1)
        steps = _clamped_arg('steps', max_steps, int, max_steps, 1)
        # Parsed as float so fractional guidance values such as 7.5 work.
        cfg = _clamped_arg('cfg', max_cfg, float, max_cfg, 0)
    except ValueError as err:
        return 'Invalid parameter: %s' % err, 400

    # NOTE(review): this reply reveals the remote addresses of other clients.
    if image_queue:
        return 'Queue full: %s' % (image_queue)

    user = request.remote_addr
    # Claim the render slot immediately after checking it, to keep the
    # check-then-add window small. NOTE(review): this is still not atomic
    # under a threaded server; a threading.Lock would close the gap.
    image_queue.add(user)
    try:
        print("txt2img from %s: %s/%s, %sx%s, %s" % (user, cfg, steps, width, height, prompt))
        image = pipe(
            prompt,
            height=height,
            width=width,
            num_inference_steps=steps,
            guidance_scale=cfg,
        ).images[0]
    finally:
        # Always release the slot. Otherwise a single failed render would
        # leave every later request stuck on "Queue full".
        image_queue.discard(user)

    img_io = BytesIO()
    image.save(img_io, 'PNG')
    img_io.seek(0)

    res = make_response(send_file(img_io, mimetype='image/png'))
    res.headers.add('Access-Control-Allow-Origin', '*')
    return res
|