1
0
Fork 0

support multiple schedulers, save output images locally

This commit is contained in:
Sean Sube 2023-01-04 23:39:50 -06:00
parent 8bbcdc175b
commit 73027bb813
1 changed files with 34 additions and 2 deletions

View File

@ -1,5 +1,15 @@
from diffusers import OnnxStableDiffusionPipeline
from diffusers import (
DDPMScheduler,
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
)
from flask import Flask, make_response, request, send_file
from stringcase import snakecase
from io import BytesIO
from os import environ, path, makedirs
@ -17,6 +27,17 @@ output_path = environ.get('ONNX_WEB_OUTPUT_PATH', "../../web_output")
# queue
image_queue = set()
# schedulers
scheduler_list = {
'ddpm': DDPMScheduler.from_pretrained(model_path, subfolder="scheduler"),
'ddim': DDIMScheduler.from_pretrained(model_path, subfolder="scheduler"),
'pndm': PNDMScheduler.from_pretrained(model_path, subfolder="scheduler"),
'lms-discrete': LMSDiscreteScheduler.from_pretrained(model_path, subfolder="scheduler"),
'euler-a': EulerAncestralDiscreteScheduler.from_pretrained(model_path, subfolder="scheduler"),
'euler': EulerDiscreteScheduler.from_pretrained(model_path, subfolder="scheduler"),
'dpm-multi': DPMSolverMultistepScheduler.from_pretrained(model_path, subfolder="scheduler"),
}
def setup():
if not path.exists(model_path):
raise RuntimeError('model path must exist')
@ -44,12 +65,23 @@ def txt2img():
width = request.args.get('width', max_width)
steps = int(request.args.get('steps', max_steps))
cfg = int(request.args.get('cfg', max_cfg))
scheduler = scheduler_list[request.args.get('scheduler', 'euler-a')]
print("txt2img from %s: %s/%s, %sx%s, %s" % (user, cfg, steps, width, height, prompt))
image_queue.add(user)
image = pipe(prompt, height, width, num_inference_steps=steps, guidance_scale=cfg).images[0]
# image.save("astronaut_rides_horse.png")
image = pipe(
prompt,
height,
width,
num_inference_steps=steps,
guidance_scale=cfg,
scheduler=scheduler
).images[0]
output = '%s/txt2img-%s' % (output_path, snakecase(prompt))
print("txt2img output: %s" % (output))
image.save(output)
img_io = BytesIO()
image.save(img_io, 'PNG', quality=100)