1
0
Fork 0

clamp API params, limit filename length, remove queueing thing

This commit is contained in:
Sean Sube 2023-01-04 23:59:45 -06:00
parent 45cc2784cd
commit ea9f929bf1
1 changed files with 31 additions and 25 deletions

View File

@ -14,19 +14,21 @@ from io import BytesIO
from os import environ, path, makedirs
# defaults
empty_prompt = "a photo of an astronaut eating a hamburger"
default_prompt = "a photo of an astronaut eating a hamburger"
default_height = 512
default_width = 512
default_steps = 20
default_cfg = 8
max_height = 512
max_width = 512
max_steps = 50
max_cfg = 8
max_steps = 150
max_cfg = 30
# paths
model_path = environ.get('ONNX_WEB_MODEL_PATH', "../../stable_diffusion_onnx")
output_path = environ.get('ONNX_WEB_OUTPUT_PATH', "../../web_output")
# queue
image_queue = set()
# schedulers
scheduler_list = {
'ddpm': DDPMScheduler.from_pretrained(model_path, subfolder="scheduler"),
@ -38,14 +40,23 @@ scheduler_list = {
'dpm-multi': DPMSolverMultistepScheduler.from_pretrained(model_path, subfolder="scheduler"),
}
def setup():
if not path.exists(model_path):
raise RuntimeError('model path must exist')
if not path.exists(output_path):
makedirs(output_path)
def get_and_clamp(args, key, default_value, max_value, min_value=1):
return min(max(int(args.get(key, default_value)), min_value), max_value)
def get_from_map(args, key, values, default):
selected = args.get(key, default)
if selected in values:
return values[selected]
else:
return values[default]
# setup
setup()
if not path.exists(model_path):
raise RuntimeError('model path must exist')
if not path.exists(output_path):
makedirs(output_path)
app = Flask(__name__)
# routes
@ -55,19 +66,16 @@ def hello():
@app.route('/txt2img')
def txt2img():
if len(image_queue) > 0:
return 'Queue full: %s' % (image_queue)
user = request.remote_addr
prompt = request.args.get('prompt', empty_prompt)
height = request.args.get('height', max_height)
width = request.args.get('width', max_width)
steps = int(request.args.get('steps', max_steps))
cfg = int(request.args.get('cfg', max_cfg))
scheduler = scheduler_list[request.args.get('scheduler', 'euler-a')]
cfg = get_and_clamp(request.args, 'cfg', default_cfg, max_cfg)
height = get_and_clamp(request.args, 'height', default_height, max_height)
prompt = request.args.get('prompt', default_prompt)
steps = get_and_clamp(request.args, 'steps', default_steps, max_steps)
scheduler = get_from_map(request.args, 'scheduler', scheduler_list, 'euler-a')
width = get_and_clamp(request.args, 'width', default_width, max_width)
print("txt2img from %s: %s/%s, %sx%s, %s" % (user, cfg, steps, width, height, prompt))
image_queue.add(user)
pipe = OnnxStableDiffusionPipeline.from_pretrained(
model_path,
@ -83,7 +91,7 @@ def txt2img():
guidance_scale=cfg
).images[0]
output = '%s/txt2img_%s.png' % (output_path, spinalcase(prompt))
output = '%s/txt2img_%s.png' % (output_path, spinalcase(prompt[0:64]))
print("txt2img output: %s" % (output))
image.save(output)
@ -91,8 +99,6 @@ def txt2img():
image.save(img_io, 'PNG', quality=100)
img_io.seek(0)
image_queue.remove(user)
res = make_response(send_file(img_io, mimetype='image/png'))
res.headers.add('Access-Control-Allow-Origin', '*')
return res