1
0
Fork 0

add seed and latents stuff

This commit is contained in:
Sean Sube 2023-01-05 00:44:28 -06:00
parent ea9f929bf1
commit 4548a44ca3
1 changed files with 25 additions and 9 deletions

View File

@ -12,18 +12,19 @@ from flask import Flask, make_response, request, send_file
from stringcase import spinalcase from stringcase import spinalcase
from io import BytesIO from io import BytesIO
from os import environ, path, makedirs from os import environ, path, makedirs
import numpy as np
# defaults # defaults
default_prompt = "a photo of an astronaut eating a hamburger" default_prompt = "a photo of an astronaut eating a hamburger"
default_cfg = 8
default_steps = 20
default_height = 512 default_height = 512
default_width = 512 default_width = 512
default_steps = 20
default_cfg = 8
max_cfg = 30
max_steps = 150
max_height = 512 max_height = 512
max_width = 512 max_width = 512
max_steps = 150
max_cfg = 30
# paths # paths
model_path = environ.get('ONNX_WEB_MODEL_PATH', "../../stable_diffusion_onnx") model_path = environ.get('ONNX_WEB_MODEL_PATH', "../../stable_diffusion_onnx")
@ -50,6 +51,14 @@ def get_from_map(args, key, values, default):
else: else:
return values[default] return values[default]
def get_latents_from_seed(seed: int, width: int, height: int,
                          batch_size: int = 1, channels: int = 4) -> np.ndarray:
    """Build a deterministic initial-latents tensor for Stable Diffusion.

    Parameters:
        seed: RNG seed; the same seed always yields the same latents.
        width: output image width in pixels (latent width is width // 8).
        height: output image height in pixels (latent height is height // 8).
        batch_size: number of images in the batch (defaults to 1, matching
            the original hard-coded behavior).
        channels: latent channels expected by the UNet (4 for SD models).

    Returns:
        float32 array of shape (batch_size, channels, height // 8, width // 8)
        drawn from a standard normal distribution.
    """
    latents_shape = (batch_size, channels, height // 8, width // 8)
    # Gotta use numpy instead of torch, because torch's randn() doesn't support DML
    rng = np.random.default_rng(seed)
    image_latents = rng.standard_normal(latents_shape).astype(np.float32)
    return image_latents
# setup # setup
if not path.exists(model_path): if not path.exists(model_path):
raise RuntimeError('model path must exist') raise RuntimeError('model path must exist')
@ -68,14 +77,20 @@ def hello():
def txt2img(): def txt2img():
user = request.remote_addr user = request.remote_addr
cfg = get_and_clamp(request.args, 'cfg', default_cfg, max_cfg)
height = get_and_clamp(request.args, 'height', default_height, max_height)
prompt = request.args.get('prompt', default_prompt) prompt = request.args.get('prompt', default_prompt)
steps = get_and_clamp(request.args, 'steps', default_steps, max_steps)
scheduler = get_from_map(request.args, 'scheduler', scheduler_list, 'euler-a') scheduler = get_from_map(request.args, 'scheduler', scheduler_list, 'euler-a')
cfg = get_and_clamp(request.args, 'cfg', default_cfg, max_cfg, 0)
steps = get_and_clamp(request.args, 'steps', default_steps, max_steps)
height = get_and_clamp(request.args, 'height', default_height, max_height)
width = get_and_clamp(request.args, 'width', default_width, max_width) width = get_and_clamp(request.args, 'width', default_width, max_width)
print("txt2img from %s: %s/%s, %sx%s, %s" % (user, cfg, steps, width, height, prompt)) seed = int(request.args.get('seed', -1))
if seed == -1:
seed = np.random.randint(np.iinfo(np.int32).max)
latents = get_latents_from_seed(seed, width, height)
print("txt2img from %s: %s/%s, %sx%s, %s, %s" % (user, cfg, steps, width, height, seed, prompt))
pipe = OnnxStableDiffusionPipeline.from_pretrained( pipe = OnnxStableDiffusionPipeline.from_pretrained(
model_path, model_path,
@ -88,7 +103,8 @@ def txt2img():
height, height,
width, width,
num_inference_steps=steps, num_inference_steps=steps,
guidance_scale=cfg guidance_scale=cfg,
latents=latents
).images[0] ).images[0]
output = '%s/txt2img_%s.png' % (output_path, spinalcase(prompt[0:64])) output = '%s/txt2img_%s.png' % (output_path, spinalcase(prompt[0:64]))