clamp API params, limit filename length, remove queueing thing
This commit is contained in:
parent
45cc2784cd
commit
ea9f929bf1
50
api/serve.py
50
api/serve.py
|
@ -14,19 +14,21 @@ from io import BytesIO
|
||||||
from os import environ, path, makedirs
|
from os import environ, path, makedirs
|
||||||
|
|
||||||
# defaults
|
# defaults
|
||||||
empty_prompt = "a photo of an astronaut eating a hamburger"
|
default_prompt = "a photo of an astronaut eating a hamburger"
|
||||||
|
default_height = 512
|
||||||
|
default_width = 512
|
||||||
|
default_steps = 20
|
||||||
|
default_cfg = 8
|
||||||
|
|
||||||
max_height = 512
|
max_height = 512
|
||||||
max_width = 512
|
max_width = 512
|
||||||
max_steps = 50
|
max_steps = 150
|
||||||
max_cfg = 8
|
max_cfg = 30
|
||||||
|
|
||||||
# paths
|
# paths
|
||||||
model_path = environ.get('ONNX_WEB_MODEL_PATH', "../../stable_diffusion_onnx")
|
model_path = environ.get('ONNX_WEB_MODEL_PATH', "../../stable_diffusion_onnx")
|
||||||
output_path = environ.get('ONNX_WEB_OUTPUT_PATH', "../../web_output")
|
output_path = environ.get('ONNX_WEB_OUTPUT_PATH', "../../web_output")
|
||||||
|
|
||||||
# queue
|
|
||||||
image_queue = set()
|
|
||||||
|
|
||||||
# schedulers
|
# schedulers
|
||||||
scheduler_list = {
|
scheduler_list = {
|
||||||
'ddpm': DDPMScheduler.from_pretrained(model_path, subfolder="scheduler"),
|
'ddpm': DDPMScheduler.from_pretrained(model_path, subfolder="scheduler"),
|
||||||
|
@ -38,14 +40,23 @@ scheduler_list = {
|
||||||
'dpm-multi': DPMSolverMultistepScheduler.from_pretrained(model_path, subfolder="scheduler"),
|
'dpm-multi': DPMSolverMultistepScheduler.from_pretrained(model_path, subfolder="scheduler"),
|
||||||
}
|
}
|
||||||
|
|
||||||
def setup():
|
def get_and_clamp(args, key, default_value, max_value, min_value=1):
|
||||||
|
return min(max(int(args.get(key, default_value)), min_value), max_value)
|
||||||
|
|
||||||
|
def get_from_map(args, key, values, default):
|
||||||
|
selected = args.get(key, default)
|
||||||
|
if selected in values:
|
||||||
|
return values[selected]
|
||||||
|
else:
|
||||||
|
return values[default]
|
||||||
|
|
||||||
|
# setup
|
||||||
if not path.exists(model_path):
|
if not path.exists(model_path):
|
||||||
raise RuntimeError('model path must exist')
|
raise RuntimeError('model path must exist')
|
||||||
|
|
||||||
if not path.exists(output_path):
|
if not path.exists(output_path):
|
||||||
makedirs(output_path)
|
makedirs(output_path)
|
||||||
|
|
||||||
# setup
|
|
||||||
setup()
|
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
|
|
||||||
# routes
|
# routes
|
||||||
|
@ -55,19 +66,16 @@ def hello():
|
||||||
|
|
||||||
@app.route('/txt2img')
|
@app.route('/txt2img')
|
||||||
def txt2img():
|
def txt2img():
|
||||||
if len(image_queue) > 0:
|
|
||||||
return 'Queue full: %s' % (image_queue)
|
|
||||||
|
|
||||||
user = request.remote_addr
|
user = request.remote_addr
|
||||||
prompt = request.args.get('prompt', empty_prompt)
|
|
||||||
height = request.args.get('height', max_height)
|
cfg = get_and_clamp(request.args, 'cfg', default_cfg, max_cfg)
|
||||||
width = request.args.get('width', max_width)
|
height = get_and_clamp(request.args, 'height', default_height, max_height)
|
||||||
steps = int(request.args.get('steps', max_steps))
|
prompt = request.args.get('prompt', default_prompt)
|
||||||
cfg = int(request.args.get('cfg', max_cfg))
|
steps = get_and_clamp(request.args, 'steps', default_steps, max_steps)
|
||||||
scheduler = scheduler_list[request.args.get('scheduler', 'euler-a')]
|
scheduler = get_from_map(request.args, 'scheduler', scheduler_list, 'euler-a')
|
||||||
|
width = get_and_clamp(request.args, 'width', default_width, max_width)
|
||||||
|
|
||||||
print("txt2img from %s: %s/%s, %sx%s, %s" % (user, cfg, steps, width, height, prompt))
|
print("txt2img from %s: %s/%s, %sx%s, %s" % (user, cfg, steps, width, height, prompt))
|
||||||
image_queue.add(user)
|
|
||||||
|
|
||||||
pipe = OnnxStableDiffusionPipeline.from_pretrained(
|
pipe = OnnxStableDiffusionPipeline.from_pretrained(
|
||||||
model_path,
|
model_path,
|
||||||
|
@ -83,7 +91,7 @@ def txt2img():
|
||||||
guidance_scale=cfg
|
guidance_scale=cfg
|
||||||
).images[0]
|
).images[0]
|
||||||
|
|
||||||
output = '%s/txt2img_%s.png' % (output_path, spinalcase(prompt))
|
output = '%s/txt2img_%s.png' % (output_path, spinalcase(prompt[0:64]))
|
||||||
print("txt2img output: %s" % (output))
|
print("txt2img output: %s" % (output))
|
||||||
image.save(output)
|
image.save(output)
|
||||||
|
|
||||||
|
@ -91,8 +99,6 @@ def txt2img():
|
||||||
image.save(img_io, 'PNG', quality=100)
|
image.save(img_io, 'PNG', quality=100)
|
||||||
img_io.seek(0)
|
img_io.seek(0)
|
||||||
|
|
||||||
image_queue.remove(user)
|
|
||||||
|
|
||||||
res = make_response(send_file(img_io, mimetype='image/png'))
|
res = make_response(send_file(img_io, mimetype='image/png'))
|
||||||
res.headers.add('Access-Control-Allow-Origin', '*')
|
res.headers.add('Access-Control-Allow-Origin', '*')
|
||||||
return res
|
return res
|
Loading…
Reference in New Issue