1
0
Fork 0

feat: return json struct with output path instead of image data, load images from outputs endpoint

This commit is contained in:
Sean Sube 2023-01-05 20:32:46 -06:00
parent 50221af55a
commit 466884113f
4 changed files with 45 additions and 30 deletions

View File

@ -77,7 +77,9 @@ Install Git and Python 3.10 for your environment:
- https://www.python.org/downloads/ - https://www.python.org/downloads/
- https://gitforwindows.org/ - https://gitforwindows.org/
The latest version of git should be fine. Python must be 3.10 or earlier, 3.10 seems to work well. The latest version of git should be fine. Python must be 3.10 or earlier, 3.10 seems to work well. If you already have
Python installed for another form of Stable Diffusion, that should work, but make sure to verify the version in the next
step.
### Create a virtual environment ### Create a virtual environment

View File

@ -8,9 +8,8 @@ from diffusers import (
LMSDiscreteScheduler, LMSDiscreteScheduler,
PNDMScheduler, PNDMScheduler,
) )
from flask import Flask, make_response, request, send_file, send_from_directory from flask import Flask, jsonify, request, send_from_directory
from stringcase import spinalcase from stringcase import spinalcase
from io import BytesIO
from os import environ, path, makedirs from os import environ, path, makedirs
import numpy as np import numpy as np
@ -60,6 +59,7 @@ def get_from_map(args, key, values, default):
return values[default] return values[default]
# TODO: credit this function
def get_latents_from_seed(seed: int, width: int, height: int) -> np.ndarray: def get_latents_from_seed(seed: int, width: int, height: int) -> np.ndarray:
# 1 is batch size # 1 is batch size
latents_shape = (1, 4, height // 8, width // 8) latents_shape = (1, 4, height // 8, width // 8)
@ -123,19 +123,26 @@ def txt2img():
latents=latents latents=latents
).images[0] ).images[0]
output = '%s/txt2img_%s_%s.png' % (output_path, output_file = "txt2img_%s_%s.png" % (seed, spinalcase(prompt[0:64]))
seed, spinalcase(prompt[0:64])) output_full = '%s/%s' % (output_path, output_file)
print("txt2img output: %s" % (output)) print("txt2img output: %s" % output_full)
image.save(output) image.save(output_full)
img_io = BytesIO() res = jsonify({
image.save(img_io, 'PNG', quality=100) 'output': output_file,
img_io.seek(0) 'params': {
'cfg': cfg,
res = make_response(send_file(img_io, mimetype='image/png')) 'steps': steps,
'height': height,
'width': width,
'prompt': prompt,
'seed': seed
}
})
res.headers.add('Access-Control-Allow-Origin', '*') res.headers.add('Access-Control-Allow-Origin', '*')
return res return res
@app.route('/output/<path:filename>') @app.route('/output/<path:filename>')
def output(filename): def output(filename):
return send_from_directory(output_path, filename, as_attachment=False) return send_from_directory(output_path, filename, as_attachment=False)

View File

@ -11,30 +11,34 @@ export interface Txt2ImgParams {
} }
export interface ApiResponse { export interface ApiResponse {
output: string;
params: Txt2ImgParams; params: Txt2ImgParams;
path: string;
} }
export interface ApiClient { export interface ApiClient {
txt2img(params: Txt2ImgParams): Promise<string>; txt2img(params: Txt2ImgParams): Promise<ApiResponse>;
} }
export const STATUS_SUCCESS = 200; export const STATUS_SUCCESS = 200;
export async function imageFromResponse(res: Response) { export async function imageFromResponse(root: string, res: Response): Promise<ApiResponse> {
if (res.status === STATUS_SUCCESS) { if (res.status === STATUS_SUCCESS) {
const imageBlob = await res.blob(); const data = await res.json() as ApiResponse;
return URL.createObjectURL(imageBlob); const output = new URL(['output', data.output].join('/'), root).toString();
return {
output,
params: data.params,
};
} else { } else {
throw new Error('request error'); throw new Error('request error');
} }
} }
export function makeClient(root: string, f = fetch): ApiClient { export function makeClient(root: string, f = fetch): ApiClient {
let pending: Promise<string> | undefined; let pending: Promise<ApiResponse> | undefined;
return { return {
async txt2img(params: Txt2ImgParams): Promise<string> { async txt2img(params: Txt2ImgParams): Promise<ApiResponse> {
if (doesExist(pending)) { if (doesExist(pending)) {
return pending; return pending;
} }
@ -61,7 +65,7 @@ export function makeClient(root: string, f = fetch): ApiClient {
url.searchParams.append('prompt', params.prompt); url.searchParams.append('prompt', params.prompt);
pending = f(url).then((res) => imageFromResponse(res)).finally(() => { pending = f(url).then((res) => imageFromResponse(root, res)).finally(() => {
pending = undefined; pending = undefined;
}); });

View File

@ -1,7 +1,8 @@
import { doesExist } from '@apextoaster/js-utils';
import { Box, Button, MenuItem, Select, Stack, TextField } from '@mui/material'; import { Box, Button, MenuItem, Select, Stack, TextField } from '@mui/material';
import * as React from 'react'; import * as React from 'react';
import { ApiClient } from '../api/client.js'; import { ApiClient, ApiResponse } from '../api/client.js';
import { ImageControl, ImageParams } from './ImageControl.js'; import { ImageControl, ImageParams } from './ImageControl.js';
const { useState } = React; const { useState } = React;
@ -12,7 +13,6 @@ export interface Txt2ImgProps {
export function Txt2Img(props: Txt2ImgProps) { export function Txt2Img(props: Txt2ImgProps) {
const { client } = props; const { client } = props;
const [image, setImage] = useState('');
const [prompt, setPrompt] = useState('an astronaut eating a hamburger'); const [prompt, setPrompt] = useState('an astronaut eating a hamburger');
const [params, setParams] = useState<ImageParams>({ const [params, setParams] = useState<ImageParams>({
@ -23,16 +23,18 @@ export function Txt2Img(props: Txt2ImgProps) {
}); });
const [scheduler, setScheduler] = useState('euler-a'); const [scheduler, setScheduler] = useState('euler-a');
const [result, setResult] = useState<ApiResponse | undefined>();
async function getImage() { async function getImage() {
const data = await client.txt2img({ ...params, prompt, scheduler }); const data = await client.txt2img({ ...params, prompt, scheduler });
setImage(data); setResult(data);
} }
function renderImage() { function renderImage() {
if (image === '') { if (doesExist(result)) {
return <div>No image</div>; return <img src={result.output} />;
} else { } else {
return <img src={image} />; return <div>No result. Press Generate.</div>;
} }
} }
@ -47,11 +49,11 @@ export function Txt2Img(props: Txt2ImgProps) {
> >
<MenuItem value='ddim'>DDIM</MenuItem> <MenuItem value='ddim'>DDIM</MenuItem>
<MenuItem value='ddpm'>DDPM</MenuItem> <MenuItem value='ddpm'>DDPM</MenuItem>
<MenuItem value='pndm'>PNDM</MenuItem> <MenuItem value='dpm-multi'>DPM Multistep</MenuItem>
<MenuItem value='lms-discrete'>LMS</MenuItem>
<MenuItem value='euler'>Euler</MenuItem> <MenuItem value='euler'>Euler</MenuItem>
<MenuItem value='euler-a'>Euler A</MenuItem> <MenuItem value='euler-a'>Euler Ancestral</MenuItem>
<MenuItem value='dpm-multi'>DPM</MenuItem> <MenuItem value='lms-discrete'>LMS Discrete</MenuItem>
<MenuItem value='pndm'>PNDM</MenuItem>
</Select> </Select>
<ImageControl params={params} onChange={(newParams) => { <ImageControl params={params} onChange={(newParams) => {
setParams(newParams); setParams(newParams);