1
0
Fork 0

add seed and latents stuff

This commit is contained in:
Sean Sube 2023-01-05 00:44:28 -06:00
parent ea9f929bf1
commit 4548a44ca3
1 changed files with 25 additions and 9 deletions

View File

@ -12,18 +12,19 @@ from flask import Flask, make_response, request, send_file
from stringcase import spinalcase
from io import BytesIO
from os import environ, path, makedirs
import numpy as np
# defaults
default_prompt = "a photo of an astronaut eating a hamburger"
default_cfg = 8
default_steps = 20
default_height = 512
default_width = 512
default_steps = 20
default_cfg = 8
max_cfg = 30
max_steps = 150
max_height = 512
max_width = 512
max_steps = 150
max_cfg = 30
# paths
model_path = environ.get('ONNX_WEB_MODEL_PATH', "../../stable_diffusion_onnx")
@ -50,6 +51,14 @@ def get_from_map(args, key, values, default):
else:
return values[default]
def get_latents_from_seed(seed: int, width: int, height: int) -> np.ndarray:
# 1 is batch size
latents_shape = (1, 4, height // 8, width // 8)
# Gotta use numpy instead of torch, because torch's randn() doesn't support DML
rng = np.random.default_rng(seed)
image_latents = rng.standard_normal(latents_shape).astype(np.float32)
return image_latents
# setup
if not path.exists(model_path):
raise RuntimeError('model path must exist')
@ -68,14 +77,20 @@ def hello():
def txt2img():
user = request.remote_addr
cfg = get_and_clamp(request.args, 'cfg', default_cfg, max_cfg)
height = get_and_clamp(request.args, 'height', default_height, max_height)
prompt = request.args.get('prompt', default_prompt)
steps = get_and_clamp(request.args, 'steps', default_steps, max_steps)
scheduler = get_from_map(request.args, 'scheduler', scheduler_list, 'euler-a')
cfg = get_and_clamp(request.args, 'cfg', default_cfg, max_cfg, 0)
steps = get_and_clamp(request.args, 'steps', default_steps, max_steps)
height = get_and_clamp(request.args, 'height', default_height, max_height)
width = get_and_clamp(request.args, 'width', default_width, max_width)
print("txt2img from %s: %s/%s, %sx%s, %s" % (user, cfg, steps, width, height, prompt))
seed = int(request.args.get('seed', -1))
if seed == -1:
seed = np.random.randint(np.iinfo(np.int32).max)
latents = get_latents_from_seed(seed, width, height)
print("txt2img from %s: %s/%s, %sx%s, %s, %s" % (user, cfg, steps, width, height, seed, prompt))
pipe = OnnxStableDiffusionPipeline.from_pretrained(
model_path,
@ -88,7 +103,8 @@ def txt2img():
height,
width,
num_inference_steps=steps,
guidance_scale=cfg
guidance_scale=cfg,
latents=latents
).images[0]
output = '%s/txt2img_%s.png' % (output_path, spinalcase(prompt[0:64]))