1
0
Fork 0

add seed to output filename, apply pep8 to server

This commit is contained in:
Sean Sube 2023-01-05 11:19:42 -06:00
parent d8b6d7fc15
commit d93d4659fa
1 changed files with 64 additions and 54 deletions

View File

@ -32,89 +32,99 @@ output_path = environ.get('ONNX_WEB_OUTPUT_PATH', "../../web_output")
# schedulers # schedulers
scheduler_list = { scheduler_list = {
'ddpm': DDPMScheduler.from_pretrained(model_path, subfolder="scheduler"), 'ddpm': DDPMScheduler.from_pretrained(model_path, subfolder="scheduler"),
'ddim': DDIMScheduler.from_pretrained(model_path, subfolder="scheduler"), 'ddim': DDIMScheduler.from_pretrained(model_path, subfolder="scheduler"),
'pndm': PNDMScheduler.from_pretrained(model_path, subfolder="scheduler"), 'pndm': PNDMScheduler.from_pretrained(model_path, subfolder="scheduler"),
'lms-discrete': LMSDiscreteScheduler.from_pretrained(model_path, subfolder="scheduler"), 'lms-discrete': LMSDiscreteScheduler.from_pretrained(model_path, subfolder="scheduler"),
'euler-a': EulerAncestralDiscreteScheduler.from_pretrained(model_path, subfolder="scheduler"), 'euler-a': EulerAncestralDiscreteScheduler.from_pretrained(model_path, subfolder="scheduler"),
'euler': EulerDiscreteScheduler.from_pretrained(model_path, subfolder="scheduler"), 'euler': EulerDiscreteScheduler.from_pretrained(model_path, subfolder="scheduler"),
'dpm-multi': DPMSolverMultistepScheduler.from_pretrained(model_path, subfolder="scheduler"), 'dpm-multi': DPMSolverMultistepScheduler.from_pretrained(model_path, subfolder="scheduler"),
} }
def get_and_clamp(args, key, default_value, max_value, min_value=1): def get_and_clamp(args, key, default_value, max_value, min_value=1):
return min(max(int(args.get(key, default_value)), min_value), max_value) return min(max(int(args.get(key, default_value)), min_value), max_value)
def get_from_map(args, key, values, default): def get_from_map(args, key, values, default):
selected = args.get(key, default) selected = args.get(key, default)
if selected in values: if selected in values:
return values[selected] return values[selected]
else: else:
return values[default] return values[default]
def get_latents_from_seed(seed: int, width: int, height: int) -> np.ndarray: def get_latents_from_seed(seed: int, width: int, height: int) -> np.ndarray:
# 1 is batch size # 1 is batch size
latents_shape = (1, 4, height // 8, width // 8) latents_shape = (1, 4, height // 8, width // 8)
# Gotta use numpy instead of torch, because torch's randn() doesn't support DML # Gotta use numpy instead of torch, because torch's randn() doesn't support DML
rng = np.random.default_rng(seed) rng = np.random.default_rng(seed)
image_latents = rng.standard_normal(latents_shape).astype(np.float32) image_latents = rng.standard_normal(latents_shape).astype(np.float32)
return image_latents return image_latents
# setup # setup
if not path.exists(model_path): if not path.exists(model_path):
raise RuntimeError('model path must exist') raise RuntimeError('model path must exist')
if not path.exists(output_path): if not path.exists(output_path):
makedirs(output_path) makedirs(output_path)
app = Flask(__name__) app = Flask(__name__)
# routes # routes
@app.route('/') @app.route('/')
def hello(): def hello():
return 'Hello, %s' % (__name__) return 'Hello, %s' % (__name__)
@app.route('/txt2img') @app.route('/txt2img')
def txt2img(): def txt2img():
user = request.remote_addr user = request.remote_addr
prompt = request.args.get('prompt', default_prompt) prompt = request.args.get('prompt', default_prompt)
scheduler = get_from_map(request.args, 'scheduler', scheduler_list, 'euler-a') scheduler = get_from_map(request.args, 'scheduler',
cfg = get_and_clamp(request.args, 'cfg', default_cfg, max_cfg, 0) scheduler_list, 'euler-a')
steps = get_and_clamp(request.args, 'steps', default_steps, max_steps) cfg = get_and_clamp(request.args, 'cfg', default_cfg, max_cfg, 0)
height = get_and_clamp(request.args, 'height', default_height, max_height) steps = get_and_clamp(request.args, 'steps', default_steps, max_steps)
width = get_and_clamp(request.args, 'width', default_width, max_width) height = get_and_clamp(request.args, 'height', default_height, max_height)
width = get_and_clamp(request.args, 'width', default_width, max_width)
seed = int(request.args.get('seed', -1)) seed = int(request.args.get('seed', -1))
if seed == -1: if seed == -1:
seed = np.random.randint(np.iinfo(np.int32).max) seed = np.random.randint(np.iinfo(np.int32).max)
latents = get_latents_from_seed(seed, width, height) latents = get_latents_from_seed(seed, width, height)
print("txt2img from %s: %s/%s, %sx%s, %s, %s" % (user, cfg, steps, width, height, seed, prompt)) print("txt2img from %s: %s/%s, %sx%s, %s, %s" %
(user, cfg, steps, width, height, seed, prompt))
pipe = OnnxStableDiffusionPipeline.from_pretrained( pipe = OnnxStableDiffusionPipeline.from_pretrained(
model_path, model_path,
provider="DmlExecutionProvider", provider="DmlExecutionProvider",
safety_checker=None, safety_checker=None,
scheduler=scheduler scheduler=scheduler
) )
image = pipe( image = pipe(
prompt, prompt,
height, height,
width, width,
num_inference_steps=steps, num_inference_steps=steps,
guidance_scale=cfg, guidance_scale=cfg,
latents=latents latents=latents
).images[0] ).images[0]
output = '%s/txt2img_%s.png' % (output_path, spinalcase(prompt[0:64])) output = '%s/txt2img_%s_%s.png' % (output_path,
print("txt2img output: %s" % (output)) seed, spinalcase(prompt[0:64]))
image.save(output) print("txt2img output: %s" % (output))
image.save(output)
img_io = BytesIO() img_io = BytesIO()
image.save(img_io, 'PNG', quality=100) image.save(img_io, 'PNG', quality=100)
img_io.seek(0) img_io.seek(0)
res = make_response(send_file(img_io, mimetype='image/png')) res = make_response(send_file(img_io, mimetype='image/png'))
res.headers.add('Access-Control-Allow-Origin', '*') res.headers.add('Access-Control-Allow-Origin', '*')
return res return res