from diffusers import OnnxStableDiffusionPipeline
from diffusers import (
    DDIMScheduler,
    DDPMScheduler,
    DPMSolverMultistepScheduler,
    EulerDiscreteScheduler,
    EulerAncestralDiscreteScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
)
from flask import Flask, jsonify, request, send_from_directory, url_for
from stringcase import spinalcase
from os import environ, makedirs, path, scandir

import numpy as np

# defaults
default_prompt = "a photo of an astronaut eating a hamburger"
default_cfg = 8
default_steps = 20
default_height = 512
default_width = 512

max_cfg = 30
max_steps = 150
max_height = 512
max_width = 512

# paths
model_path = environ.get('ONNX_WEB_MODEL_PATH', "../models")
output_path = environ.get('ONNX_WEB_OUTPUT_PATH', "../outputs")

# pipeline caching
available_models = []
last_pipeline_instance = None
last_pipeline_options = (None, None)
last_pipeline_scheduler = None

# pipeline params
platform_providers = {
    'amd': 'DmlExecutionProvider',
    'cpu': 'CPUExecutionProvider',
}
pipeline_schedulers = {
    'ddim': DDIMScheduler,
    'ddpm': DDPMScheduler,
    'dpm-multi': DPMSolverMultistepScheduler,
    'euler': EulerDiscreteScheduler,
    'euler-a': EulerAncestralDiscreteScheduler,
    'lms-discrete': LMSDiscreteScheduler,
    'pndm': PNDMScheduler,
}


def get_and_clamp(args, key, default_value, max_value, min_value=1):
    return min(max(int(args.get(key, default_value)), min_value), max_value)


def get_from_map(args, key, values, default):
    selected = args.get(key, default)
    if selected in values:
        return values[selected]
    else:
        return values[default]
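
# Usage sketch for the helpers above (argument values here are hypothetical):
# get_and_clamp({'steps': '500'}, 'steps', default_steps, max_steps) parses and
# clamps the value to 150, while get_from_map({'scheduler': 'unknown'},
# 'scheduler', pipeline_schedulers, 'euler-a') falls back to the default entry
# for names that are not in the map.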


# from https://www.travelneil.com/stable-diffusion-updates.html
def get_latents_from_seed(seed: int, width: int, height: int) -> np.ndarray:
    # 1 is batch size
    latents_shape = (1, 4, height // 8, width // 8)
    # Gotta use numpy instead of torch, because torch's randn() doesn't support DML
    rng = np.random.default_rng(seed)
    image_latents = rng.standard_normal(latents_shape).astype(np.float32)
    return image_latents
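
# Shape sketch, assuming Stable Diffusion v1's 8x VAE downsampling: a 512x512
# request yields latents of shape (1, 4, 64, 64), and a given seed always
# produces the same latents, so images are reproducible across requests.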


def load_pipeline(model, provider, scheduler):
    global last_pipeline_instance
    global last_pipeline_scheduler
    global last_pipeline_options

    options = (model, provider)
    if last_pipeline_instance is not None and last_pipeline_options == options:
        print('reusing existing pipeline')
        pipe = last_pipeline_instance
    else:
        print('loading different pipeline')
        pipe = OnnxStableDiffusionPipeline.from_pretrained(
            model,
            provider=provider,
            safety_checker=None,
            scheduler=scheduler.from_pretrained(model, subfolder="scheduler")
        )
        last_pipeline_instance = pipe
        last_pipeline_options = options
        last_pipeline_scheduler = scheduler

    if last_pipeline_scheduler != scheduler:
        print('changing pipeline scheduler')
        pipe.scheduler = scheduler.from_pretrained(model, subfolder="scheduler")
        last_pipeline_scheduler = scheduler

    return pipe
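
# Design note: the cache above holds a single pipeline at a time, keyed by
# (model, provider). Requesting a different model or provider reloads the ONNX
# weights from disk, while changing only the scheduler swaps it in place on
# the cached pipeline, which is much cheaper than a full reload.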


def json_with_cors(data, origin='*'):
    """Build a JSON response with CORS headers allowing `origin`"""
    res = jsonify(data)
    res.access_control_allow_origin = origin
    return res


def safer_join(base, tail):
    safer_path = path.relpath(path.normpath(path.join('/', tail)), '/')
    return path.join(base, safer_path)
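
# Traversal sketch (POSIX paths, hypothetical input): safer_join('../outputs',
# '../../etc/passwd') normalizes the tail to 'etc/passwd' and returns
# '../outputs/etc/passwd', keeping the result inside the base directory.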


def url_from_rule(rule):
    options = {}
    for arg in rule.arguments:
        options[arg] = ":%s" % (arg)

    return url_for(rule.endpoint, **options)


# setup

def check_paths():
    if not path.exists(model_path):
        raise RuntimeError('model path must exist')

    if not path.exists(output_path):
        makedirs(output_path)


def load_models():
    global available_models
    available_models = [f.name for f in scandir(model_path) if f.is_dir()]


check_paths()
load_models()
app = Flask(__name__)


# routes

@app.route('/')
def index():
    return {
        'name': 'onnx-web',
        'routes': [{
            'path': url_from_rule(rule),
            # sorted() returns the sorted list; list.sort() sorts in place and
            # returns None, which would serialize as null in the JSON response
            'methods': sorted(rule.methods)
        } for rule in app.url_map.iter_rules()]
    }


@app.route('/settings/models')
def list_models():
    return json_with_cors(available_models)


@app.route('/settings/platforms')
def list_platforms():
    return json_with_cors(list(platform_providers.keys()))


@app.route('/settings/schedulers')
def list_schedulers():
    return json_with_cors(list(pipeline_schedulers.keys()))


@app.route('/txt2img')
def txt2img():
    user = request.remote_addr

    # pipeline stuff
    model = safer_join(model_path, request.args.get(
        'model', 'stable-diffusion-onnx-v1-5'))
    provider = get_from_map(request.args, 'platform',
                            platform_providers, 'amd')
    scheduler = get_from_map(request.args, 'scheduler',
                             pipeline_schedulers, 'euler-a')

    # image params
    prompt = request.args.get('prompt', default_prompt)
    cfg = get_and_clamp(request.args, 'cfg', default_cfg, max_cfg, 0)
    steps = get_and_clamp(request.args, 'steps', default_steps, max_steps)
    height = get_and_clamp(request.args, 'height', default_height, max_height)
    width = get_and_clamp(request.args, 'width', default_width, max_width)

    seed = int(request.args.get('seed', -1))
    if seed == -1:
        seed = np.random.randint(np.iinfo(np.int32).max)

    latents = get_latents_from_seed(seed, width, height)

    print("txt2img from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s" %
          (user, steps, scheduler.__name__, model, provider, width, height, cfg, seed, prompt))

    pipe = load_pipeline(model, provider, scheduler)
    image = pipe(
        prompt,
        height,
        width,
        num_inference_steps=steps,
        guidance_scale=cfg,
        latents=latents
    ).images[0]

    output_file = "txt2img_%s_%s.png" % (seed, spinalcase(prompt[0:64]))
    output_full = safer_join(output_path, output_file)
    print("txt2img output: %s" % output_full)
    image.save(output_full)

    return json_with_cors({
        'output': output_file,
        'params': {
            'model': model,
            'provider': provider,
            'scheduler': scheduler.__name__,
            'cfg': cfg,
            'steps': steps,
            'height': height,
            'width': width,
            'prompt': prompt,
            'seed': seed
        }
    })


@app.route('/output/<path:filename>')
def output(filename):
    return send_from_directory(output_path, filename, as_attachment=False)
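
# Example request, assuming the API is served at 127.0.0.1:5000 (host, port,
# and parameter values here are illustrative):
#   GET /txt2img?prompt=a+watercolor+fox&steps=25&cfg=8&scheduler=euler-a&platform=cpu
# The JSON response includes the output filename, which can then be fetched
# from the /output/<filename> route.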