support multiple schedulers, save output images locally
This commit is contained in:
parent
8bbcdc175b
commit
73027bb813
36
api/serve.py
36
api/serve.py
|
@ -1,5 +1,15 @@
|
||||||
from diffusers import OnnxStableDiffusionPipeline
|
from diffusers import OnnxStableDiffusionPipeline
|
||||||
|
from diffusers import (
|
||||||
|
DDPMScheduler,
|
||||||
|
DDIMScheduler,
|
||||||
|
PNDMScheduler,
|
||||||
|
LMSDiscreteScheduler,
|
||||||
|
EulerDiscreteScheduler,
|
||||||
|
EulerAncestralDiscreteScheduler,
|
||||||
|
DPMSolverMultistepScheduler,
|
||||||
|
)
|
||||||
from flask import Flask, make_response, request, send_file
|
from flask import Flask, make_response, request, send_file
|
||||||
|
from stringcase import snakecase
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from os import environ, path, makedirs
|
from os import environ, path, makedirs
|
||||||
|
|
||||||
|
@ -17,6 +27,17 @@ output_path = environ.get('ONNX_WEB_OUTPUT_PATH', "../../web_output")
|
||||||
# queue
|
# queue
|
||||||
image_queue = set()
|
image_queue = set()
|
||||||
|
|
||||||
|
# schedulers
|
||||||
|
scheduler_list = {
|
||||||
|
'ddpm': DDPMScheduler.from_pretrained(model_path, subfolder="scheduler"),
|
||||||
|
'ddim': DDIMScheduler.from_pretrained(model_path, subfolder="scheduler"),
|
||||||
|
'pndm': PNDMScheduler.from_pretrained(model_path, subfolder="scheduler"),
|
||||||
|
'lms-discrete': LMSDiscreteScheduler.from_pretrained(model_path, subfolder="scheduler"),
|
||||||
|
'euler-a': EulerAncestralDiscreteScheduler.from_pretrained(model_path, subfolder="scheduler"),
|
||||||
|
'euler': EulerDiscreteScheduler.from_pretrained(model_path, subfolder="scheduler"),
|
||||||
|
'dpm-multi': DPMSolverMultistepScheduler.from_pretrained(model_path, subfolder="scheduler"),
|
||||||
|
}
|
||||||
|
|
||||||
def setup():
|
def setup():
|
||||||
if not path.exists(model_path):
|
if not path.exists(model_path):
|
||||||
raise RuntimeError('model path must exist')
|
raise RuntimeError('model path must exist')
|
||||||
|
@ -44,12 +65,23 @@ def txt2img():
|
||||||
width = request.args.get('width', max_width)
|
width = request.args.get('width', max_width)
|
||||||
steps = int(request.args.get('steps', max_steps))
|
steps = int(request.args.get('steps', max_steps))
|
||||||
cfg = int(request.args.get('cfg', max_cfg))
|
cfg = int(request.args.get('cfg', max_cfg))
|
||||||
|
scheduler = scheduler_list[request.args.get('scheduler', 'euler-a')]
|
||||||
|
|
||||||
print("txt2img from %s: %s/%s, %sx%s, %s" % (user, cfg, steps, width, height, prompt))
|
print("txt2img from %s: %s/%s, %sx%s, %s" % (user, cfg, steps, width, height, prompt))
|
||||||
image_queue.add(user)
|
image_queue.add(user)
|
||||||
|
|
||||||
image = pipe(prompt, height, width, num_inference_steps=steps, guidance_scale=cfg).images[0]
|
image = pipe(
|
||||||
# image.save("astronaut_rides_horse.png")
|
prompt,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
num_inference_steps=steps,
|
||||||
|
guidance_scale=cfg,
|
||||||
|
scheduler=scheduler
|
||||||
|
).images[0]
|
||||||
|
|
||||||
|
output = '%s/txt2img-%s' % (output_path, snakecase(prompt))
|
||||||
|
print("txt2img output: %s" % (output))
|
||||||
|
image.save(output)
|
||||||
|
|
||||||
img_io = BytesIO()
|
img_io = BytesIO()
|
||||||
image.save(img_io, 'PNG', quality=100)
|
image.save(img_io, 'PNG', quality=100)
|
||||||
|
|
Loading…
Reference in New Issue