2023-01-05 00:25:00 +00:00
|
|
|
from diffusers import OnnxStableDiffusionPipeline
|
2023-01-05 05:39:50 +00:00
|
|
|
from diffusers import (
|
|
|
|
DDIMScheduler,
|
2023-01-05 23:23:37 +00:00
|
|
|
DDPMScheduler,
|
|
|
|
DPMSolverMultistepScheduler,
|
2023-01-05 05:39:50 +00:00
|
|
|
EulerDiscreteScheduler,
|
|
|
|
EulerAncestralDiscreteScheduler,
|
2023-01-05 23:23:37 +00:00
|
|
|
LMSDiscreteScheduler,
|
|
|
|
PNDMScheduler,
|
2023-01-05 05:39:50 +00:00
|
|
|
)
|
2023-01-05 23:24:33 +00:00
|
|
|
from flask import Flask, make_response, request, send_file, send_from_directory
|
2023-01-05 05:44:16 +00:00
|
|
|
from stringcase import spinalcase
|
2023-01-05 01:42:37 +00:00
|
|
|
from io import BytesIO
|
|
|
|
from os import environ, path, makedirs
|
2023-01-05 06:44:28 +00:00
|
|
|
import numpy as np
|
2023-01-05 00:25:00 +00:00
|
|
|
|
2023-01-05 01:42:37 +00:00
|
|
|
# defaults used when a request omits the corresponding query parameter
default_prompt = "a photo of an astronaut eating a hamburger"
default_cfg, default_steps = 8, 20
default_height = default_width = 512

# upper bounds that request values are clamped to
max_cfg, max_steps = 30, 150
max_height = max_width = 512
|
|
|
|
|
2023-01-05 01:42:37 +00:00
|
|
|
# paths, each overridable through an environment variable
model_path, output_path = (
    environ.get('ONNX_WEB_MODEL_PATH', "../models/stable-diffusion-onnx-v1-5"),
    environ.get('ONNX_WEB_OUTPUT_PATH', "../outputs"),
)
|
2023-01-05 01:42:37 +00:00
|
|
|
|
2023-01-05 23:24:14 +00:00
|
|
|
# platforms: maps the ?provider= query value to the execution provider name
# handed to the ONNX pipeline
platform_providers = dict(
    amd='DmlExecutionProvider',
    cpu='CPUExecutionProvider',
)
|
|
|
|
|
2023-01-05 05:39:50 +00:00
|
|
|
# schedulers
# These are loaded at import time, before the setup checks further down run,
# so verify the model path here too: otherwise a missing directory surfaces as
# a less specific error from inside from_pretrained (NOTE(review): diffusers
# may even try to resolve it as a remote model id).
if not path.exists(model_path):
    raise RuntimeError('model path must exist')

# every scheduler reads its configuration from the model's "scheduler" folder;
# the keys are the values accepted by the ?scheduler= query parameter
pipeline_schedulers = {
    name: scheduler_type.from_pretrained(model_path, subfolder="scheduler")
    for name, scheduler_type in (
        ('ddim', DDIMScheduler),
        ('ddpm', DDPMScheduler),
        ('dpm-multi', DPMSolverMultistepScheduler),
        ('euler', EulerDiscreteScheduler),
        ('euler-a', EulerAncestralDiscreteScheduler),
        ('lms-discrete', LMSDiscreteScheduler),
        ('pndm', PNDMScheduler),
    )
}
|
|
|
|
|
2023-01-05 17:19:42 +00:00
|
|
|
|
2023-01-05 05:59:45 +00:00
|
|
|
def get_and_clamp(args, key, default_value, max_value, min_value=1, *, parse=int):
    """Read ``args[key]``, convert it with ``parse`` and clamp it to a range.

    Args:
        args: mapping of request arguments (e.g. ``request.args``).
        key: name of the argument to read.
        default_value: used when ``key`` is absent; it is parsed and clamped
            like a supplied value.
        max_value: inclusive upper bound.
        min_value: inclusive lower bound (1 by default).
        parse: converter applied to the raw value. ``int`` by default; pass
            ``float`` for fractional settings such as a guidance scale of 7.5,
            which ``int()`` would reject.

    Returns:
        The converted value limited to ``[min_value, max_value]``.

    Raises:
        ValueError: if the raw value cannot be converted by ``parse``.
    """
    value = parse(args.get(key, default_value))
    return min(max(value, min_value), max_value)
|
|
|
|
|
2023-01-05 05:59:45 +00:00
|
|
|
|
|
|
|
def get_from_map(args, key, values, default):
    """Return the entry of ``values`` selected by ``args[key]``.

    Falls back to ``values[default]`` when the argument is missing or names
    an entry that ``values`` does not contain.
    """
    choice = args.get(key, default)
    return values[choice if choice in values else default]
|
|
|
|
|
2023-01-05 00:25:00 +00:00
|
|
|
|
2023-01-05 06:44:28 +00:00
|
|
|
def get_latents_from_seed(seed: int, width: int, height: int) -> np.ndarray:
    """Return a reproducible float32 noise tensor derived from ``seed``.

    The shape is ``(1, 4, height // 8, width // 8)``: a batch of one with
    four channels at 1/8 of the requested pixel size.
    """
    # numpy's generator is used because torch's randn() doesn't support DML
    generator = np.random.default_rng(seed)
    noise = generator.standard_normal((1, 4, height // 8, width // 8))
    return noise.astype(np.float32)
|
|
|
|
|
2023-01-05 06:44:28 +00:00
|
|
|
|
2023-01-05 01:42:37 +00:00
|
|
|
# setup
if not path.exists(model_path):
    raise RuntimeError('model path must exist')

# exist_ok=True creates the directory without a separate exists() check, so
# there is no race with another process creating it; if output_path exists
# but is not a directory, makedirs raises here at startup instead of the
# first image save failing later
makedirs(output_path, exist_ok=True)

app = Flask(__name__)
|
2023-01-05 00:25:00 +00:00
|
|
|
|
2023-01-05 01:42:37 +00:00
|
|
|
# routes
|
2023-01-05 17:19:42 +00:00
|
|
|
|
|
|
|
|
2023-01-05 00:25:00 +00:00
|
|
|
@app.route('/')
def hello():
    """Return a greeting containing the module name (a quick reachability check)."""
    return f'Hello, {__name__}'
|
2023-01-05 00:25:00 +00:00
|
|
|
|
2023-01-05 17:19:42 +00:00
|
|
|
|
2023-01-05 00:25:00 +00:00
|
|
|
@app.route('/txt2img')
def txt2img():
    """Render one image from the query string and return it as a PNG response.

    Optional query parameters: ``prompt``, ``provider`` (a key of
    platform_providers), ``scheduler`` (a key of pipeline_schedulers), ``cfg``,
    ``steps``, ``height``, ``width`` and ``seed``. Numeric values are clamped
    to the module-level ``max_*`` limits; a missing or negative seed is
    replaced by a random one. The image is also written to ``output_path``.
    """
    user = request.remote_addr

    prompt = request.args.get('prompt', default_prompt)
    provider = get_from_map(request.args, 'provider', platform_providers, 'amd')
    scheduler = get_from_map(request.args, 'scheduler',
                             pipeline_schedulers, 'euler-a')

    # the guidance scale is commonly fractional (e.g. 7.5), so parse it as a
    # float instead of going through get_and_clamp's int() conversion, which
    # rejects such values; NaN would slip through min/max, so fall back to
    # the default for it
    cfg = float(request.args.get('cfg', default_cfg))
    if np.isnan(cfg):
        cfg = float(default_cfg)
    cfg = min(max(cfg, 0), max_cfg)

    steps = get_and_clamp(request.args, 'steps', default_steps, max_steps)

    # the latents are sized height // 8 x width // 8, so snap both dimensions
    # down to a multiple of 8 (at least 8): this keeps image and latent sizes
    # consistent and avoids an empty latent tensor for values below 8
    height = get_and_clamp(request.args, 'height', default_height, max_height)
    width = get_and_clamp(request.args, 'width', default_width, max_width)
    height = max(8, height - height % 8)
    width = max(8, width - width % 8)

    seed = int(request.args.get('seed', -1))
    if seed < 0:
        # np.random.default_rng does not accept negative seeds, so any
        # negative value means "pick one at random"
        seed = np.random.randint(np.iinfo(np.int32).max)

    latents = get_latents_from_seed(seed, width, height)

    print("txt2img from %s: %s/%s, %sx%s, %s, %s" %
          (user, cfg, steps, width, height, seed, prompt))

    pipe = OnnxStableDiffusionPipeline.from_pretrained(
        model_path,
        provider=provider,
        safety_checker=None,
        scheduler=scheduler
    )
    # sizes are passed by keyword so they cannot be mis-assigned positionally
    image = pipe(
        prompt,
        height=height,
        width=width,
        num_inference_steps=steps,
        guidance_scale=cfg,
        latents=latents
    ).images[0]

    # spinalcase() is not relied on to remove path separators or characters
    # that file systems reject (e.g. '/', ':', '?'); replace them so a prompt
    # cannot redirect the save into another directory or make it fail
    unsafe_chars = '\\/:*?"<>|' + ''.join(map(chr, range(32)))
    safe_prompt = spinalcase(prompt[0:64]).translate(
        str.maketrans(unsafe_chars, '-' * len(unsafe_chars)))
    output_file = path.join(output_path,
                            'txt2img_%s_%s.png' % (seed, safe_prompt))
    print("txt2img output: %s" % (output_file))
    image.save(output_file)

    # also send the image back directly from memory
    img_io = BytesIO()
    image.save(img_io, 'PNG', quality=100)
    img_io.seek(0)

    res = make_response(send_file(img_io, mimetype='image/png'))
    res.headers.add('Access-Control-Allow-Origin', '*')
    return res
|
2023-01-05 23:24:33 +00:00
|
|
|
|
|
|
|
@app.route('/output/<path:filename>')
def output(filename):
    """Serve a previously saved file from output_path inline (not as a download)."""
    # NOTE(review): relies on send_from_directory to keep ``filename`` inside
    # the directory — confirm for the Flask version in use
    directory = output_path
    return send_from_directory(directory, filename, as_attachment=False)
|