diff --git a/api/serve.py b/api/serve.py index b49ddcb9..2c7286ea 100644 --- a/api/serve.py +++ b/api/serve.py @@ -1,5 +1,15 @@ from diffusers import OnnxStableDiffusionPipeline +from diffusers import ( + DDPMScheduler, + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, +) from flask import Flask, make_response, request, send_file +from stringcase import snakecase from io import BytesIO from os import environ, path, makedirs @@ -17,6 +27,17 @@ output_path = environ.get('ONNX_WEB_OUTPUT_PATH', "../../web_output") # queue image_queue = set() +# schedulers +scheduler_list = { + 'ddpm': DDPMScheduler.from_pretrained(model_path, subfolder="scheduler"), + 'ddim': DDIMScheduler.from_pretrained(model_path, subfolder="scheduler"), + 'pndm': PNDMScheduler.from_pretrained(model_path, subfolder="scheduler"), + 'lms-discrete': LMSDiscreteScheduler.from_pretrained(model_path, subfolder="scheduler"), + 'euler-a': EulerAncestralDiscreteScheduler.from_pretrained(model_path, subfolder="scheduler"), + 'euler': EulerDiscreteScheduler.from_pretrained(model_path, subfolder="scheduler"), + 'dpm-multi': DPMSolverMultistepScheduler.from_pretrained(model_path, subfolder="scheduler"), +} + def setup(): if not path.exists(model_path): raise RuntimeError('model path must exist') @@ -44,12 +65,23 @@ def txt2img(): width = request.args.get('width', max_width) steps = int(request.args.get('steps', max_steps)) cfg = int(request.args.get('cfg', max_cfg)) + scheduler = scheduler_list[request.args.get('scheduler', 'euler-a')] print("txt2img from %s: %s/%s, %sx%s, %s" % (user, cfg, steps, width, height, prompt)) image_queue.add(user) - image = pipe(prompt, height, width, num_inference_steps=steps, guidance_scale=cfg).images[0] - # image.save("astronaut_rides_horse.png") + image = pipe( + prompt, + height, + width, + num_inference_steps=steps, + guidance_scale=cfg, + scheduler=scheduler + ).images[0] + + output = '%s/txt2img-%s' % (output_path, snakecase(prompt)) + print("txt2img output: %s" % (output)) + image.save(output) img_io = BytesIO() image.save(img_io, 'PNG', quality=100)