From 4548a44ca347bc2fcc8a8ef5a1955ad6fadfad0e Mon Sep 17 00:00:00 2001
From: Sean Sube
Date: Thu, 5 Jan 2023 00:44:28 -0600
Subject: [PATCH] add seed and latents stuff

---
 api/serve.py | 34 +++++++++++++++++++++++++---------
 1 file changed, 25 insertions(+), 9 deletions(-)

diff --git a/api/serve.py b/api/serve.py
index 9be67436..7060dbcc 100644
--- a/api/serve.py
+++ b/api/serve.py
@@ -12,18 +12,19 @@
 from flask import Flask, make_response, request, send_file
 from stringcase import spinalcase
 from io import BytesIO
 from os import environ, path, makedirs
+import numpy as np
 
 # defaults
 default_prompt = "a photo of an astronaut eating a hamburger"
+default_cfg = 8
+default_steps = 20
 default_height = 512
 default_width = 512
-default_steps = 20
-default_cfg = 8
+max_cfg = 30
+max_steps = 150
 max_height = 512
 max_width = 512
-max_steps = 150
-max_cfg = 30
 
 # paths
 model_path = environ.get('ONNX_WEB_MODEL_PATH', "../../stable_diffusion_onnx")
@@ -50,6 +51,14 @@ def get_from_map(args, key, values, default):
     else:
         return values[default]
 
+def get_latents_from_seed(seed: int, width: int, height: int) -> np.ndarray:
+    # 1 is batch size
+    latents_shape = (1, 4, height // 8, width // 8)
+    # Gotta use numpy instead of torch, because torch's randn() doesn't support DML
+    rng = np.random.default_rng(seed)
+    image_latents = rng.standard_normal(latents_shape).astype(np.float32)
+    return image_latents
+
 # setup
 if not path.exists(model_path):
     raise RuntimeError('model path must exist')
@@ -68,14 +77,20 @@ def hello():
 def txt2img():
     user = request.remote_addr
 
-    cfg = get_and_clamp(request.args, 'cfg', default_cfg, max_cfg)
-    height = get_and_clamp(request.args, 'height', default_height, max_height)
     prompt = request.args.get('prompt', default_prompt)
-    steps = get_and_clamp(request.args, 'steps', default_steps, max_steps)
     scheduler = get_from_map(request.args, 'scheduler', scheduler_list, 'euler-a')
+    cfg = get_and_clamp(request.args, 'cfg', default_cfg, max_cfg, 0)
+    steps = get_and_clamp(request.args, 'steps', default_steps, max_steps)
+    height = get_and_clamp(request.args, 'height', default_height, max_height)
     width = get_and_clamp(request.args, 'width', default_width, max_width)
 
-    print("txt2img from %s: %s/%s, %sx%s, %s" % (user, cfg, steps, width, height, prompt))
+    seed = int(request.args.get('seed', -1))
+    if seed == -1:
+        seed = np.random.randint(np.iinfo(np.int32).max)
+
+    latents = get_latents_from_seed(seed, width, height)
+
+    print("txt2img from %s: %s/%s, %sx%s, %s, %s" % (user, cfg, steps, width, height, seed, prompt))
 
     pipe = OnnxStableDiffusionPipeline.from_pretrained(
         model_path,
@@ -88,7 +103,8 @@
         height,
         width,
         num_inference_steps=steps,
-        guidance_scale=cfg
+        guidance_scale=cfg,
+        latents=latents
     ).images[0]
 
     output = '%s/txt2img_%s.png' % (output_path, spinalcase(prompt[0:64]))
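
Note (not part of the patch): a minimal sketch of the seed behaviour this change introduces, assuming only numpy. get_latents_from_seed is deterministic for a given seed and image size, so repeating a request with the same explicit seed value reproduces the same starting latents; at the 512x512 defaults the latents are a (1, 4, 64, 64) float32 array. The helper below mirrors the one added in the patch, and the seed value 42 is an arbitrary example.

import numpy as np

# Mirror of the helper added in this patch, reproduced for illustration:
# batch size 1, 4 latent channels, spatial dims are 1/8 of the requested size.
def get_latents_from_seed(seed: int, width: int, height: int) -> np.ndarray:
    latents_shape = (1, 4, height // 8, width // 8)
    rng = np.random.default_rng(seed)
    return rng.standard_normal(latents_shape).astype(np.float32)

# The same seed and size always produce identical latents, which is what lets
# an explicit seed query parameter reproduce an earlier image.
a = get_latents_from_seed(42, 512, 512)
b = get_latents_from_seed(42, 512, 512)
assert a.shape == (1, 4, 64, 64)
assert np.array_equal(a, b)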