1
0
Fork 0

feat(api): add img2img endpoint

This commit is contained in:
Sean Sube 2023-01-07 15:05:29 -06:00
parent 9973bf1bfc
commit 09ce6546be
1 changed files with 79 additions and 21 deletions

View File

@ -1,5 +1,5 @@
from diffusers import OnnxStableDiffusionPipeline
from diffusers import ( from diffusers import (
# schedulers
DDIMScheduler, DDIMScheduler,
DDPMScheduler, DDPMScheduler,
DPMSolverMultistepScheduler, DPMSolverMultistepScheduler,
@ -7,16 +7,21 @@ from diffusers import (
EulerAncestralDiscreteScheduler, EulerAncestralDiscreteScheduler,
LMSDiscreteScheduler, LMSDiscreteScheduler,
PNDMScheduler, PNDMScheduler,
# onnx
OnnxStableDiffusionPipeline,
OnnxStableDiffusionImg2ImgPipeline,
) )
from flask import Flask, jsonify, request, send_from_directory, url_for from flask import Flask, jsonify, request, send_from_directory, url_for
from io import BytesIO
from PIL import Image
from stringcase import spinalcase from stringcase import spinalcase
from os import environ, makedirs, path, scandir from os import environ, makedirs, path, scandir
import numpy as np import numpy as np
# defaults # defaults
default_model = "stable-diffusion-onnx-v1-5" default_model = 'stable-diffusion-onnx-v1-5'
default_platform = "amd" default_platform = 'amd'
default_scheduler = "euler-a" default_scheduler = 'euler-a'
default_prompt = "a photo of an astronaut eating a hamburger" default_prompt = "a photo of an astronaut eating a hamburger"
default_cfg = 8 default_cfg = 8
default_steps = 20 default_steps = 20
@ -29,14 +34,14 @@ max_height = 512
max_width = 512 max_width = 512
# paths # paths
model_path = environ.get('ONNX_WEB_MODEL_PATH', "../models") model_path = environ.get('ONNX_WEB_MODEL_PATH', '../models')
output_path = environ.get('ONNX_WEB_OUTPUT_PATH', "../outputs") output_path = environ.get('ONNX_WEB_OUTPUT_PATH', '../outputs')
# pipeline caching # pipeline caching
available_models = [] available_models = []
last_pipeline_instance = None last_pipeline_instance = None
last_pipeline_options = (None, None) last_pipeline_options = (None, None, None)
last_pipeline_scheduler = None last_pipeline_scheduler = None
# pipeline params # pipeline params
@ -68,6 +73,10 @@ def get_from_map(args, key, values, default):
return values[default] return values[default]
def get_model_path(model):
return safer_join(model_path, model)
# from https://www.travelneil.com/stable-diffusion-updates.html # from https://www.travelneil.com/stable-diffusion-updates.html
def get_latents_from_seed(seed: int, width: int, height: int) -> np.ndarray: def get_latents_from_seed(seed: int, width: int, height: int) -> np.ndarray:
# 1 is batch size # 1 is batch size
@ -78,22 +87,23 @@ def get_latents_from_seed(seed: int, width: int, height: int) -> np.ndarray:
return image_latents return image_latents
def load_pipeline(model, provider, scheduler): def load_pipeline(pipeline, model, provider, scheduler):
global last_pipeline_instance global last_pipeline_instance
global last_pipeline_scheduler global last_pipeline_scheduler
global last_pipeline_options global last_pipeline_options
options = (model, provider) options = (pipeline, model, provider)
if last_pipeline_instance != None and last_pipeline_options == options: if last_pipeline_instance != None and last_pipeline_options == options:
print('reusing existing pipeline') print('reusing existing pipeline')
pipe = last_pipeline_instance pipe = last_pipeline_instance
else: else:
print('loading different pipeline') print('loading different pipeline')
pipe = OnnxStableDiffusionPipeline.from_pretrained( # pipe = OnnxStableDiffusionPipeline.from_pretrained(
pipe = pipeline.from_pretrained(
model, model,
provider=provider, provider=provider,
safety_checker=None, safety_checker=None,
scheduler=scheduler.from_pretrained(model, subfolder="scheduler") scheduler=scheduler.from_pretrained(model, subfolder='scheduler')
) )
last_pipeline_instance = pipe last_pipeline_instance = pipe
last_pipeline_options = options last_pipeline_options = options
@ -102,7 +112,7 @@ def load_pipeline(model, provider, scheduler):
if last_pipeline_scheduler != scheduler: if last_pipeline_scheduler != scheduler:
print('changing pipeline scheduler') print('changing pipeline scheduler')
pipe.scheduler = scheduler.from_pretrained( pipe.scheduler = scheduler.from_pretrained(
model, subfolder="scheduler") model, subfolder='scheduler')
last_pipeline_scheduler = scheduler last_pipeline_scheduler = scheduler
return pipe return pipe
@ -141,6 +151,8 @@ def check_paths():
def load_models(): def load_models():
global available_models global available_models
available_models = [f.name for f in scandir(model_path) if f.is_dir()] available_models = [f.name for f in scandir(model_path) if f.is_dir()]
load_pipeline(OnnxStableDiffusionPipeline, get_model_path(available_models[0]), platform_providers.get(
default_platform), pipeline_schedulers.get(default_scheduler))
check_paths() check_paths()
@ -176,12 +188,11 @@ def list_schedulers():
return json_with_cors(list(pipeline_schedulers.keys())) return json_with_cors(list(pipeline_schedulers.keys()))
@app.route('/txt2img') def pipeline_from_request(pipeline):
def txt2img():
user = request.remote_addr user = request.remote_addr
# pipeline stuff # pipeline stuff
model = safer_join(model_path, request.args.get('model', default_model)) model = get_model_path(request.args.get('model', default_model))
provider = get_from_map(request.args, 'platform', provider = get_from_map(request.args, 'platform',
platform_providers, default_platform) platform_providers, default_platform)
scheduler = get_from_map(request.args, 'scheduler', scheduler = get_from_map(request.args, 'scheduler',
@ -198,12 +209,59 @@ def txt2img():
if seed == -1: if seed == -1:
seed = np.random.randint(np.iinfo(np.int32).max) seed = np.random.randint(np.iinfo(np.int32).max)
latents = get_latents_from_seed(seed, width, height) print("request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s" %
print("txt2img from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s" %
(user, steps, scheduler.__name__, model, provider, width, height, cfg, seed, prompt)) (user, steps, scheduler.__name__, model, provider, width, height, cfg, seed, prompt))
pipe = load_pipeline(model, provider, scheduler) pipe = load_pipeline(pipeline, model, provider, scheduler)
return (model, provider, scheduler, prompt, cfg, steps, height, width, seed, pipe)
@app.route('/img2img', methods=['POST'])
def img2img():
input_file = request.files.get('source')
input_image = Image.open(BytesIO(input_file.read())).convert('RGB')
input_image.thumbnail((default_width, default_height))
strength = get_and_clamp(request.args, 'strength', 1.0, 1.0, 0.0)
(model, provider, scheduler, prompt, cfg, steps, height,
width, seed, pipe) = pipeline_from_request(OnnxStableDiffusionImg2ImgPipeline)
image = pipe(
prompt=prompt,
image=input_image,
num_inference_steps=steps,
guidance_scale=cfg,
strength=strength,
).images[0]
output_file = 'img2img_%s_%s.png' % (seed, spinalcase(prompt[0:64]))
output_full = safer_join(output_path, output_file)
print("img2img output: %s" % output_full)
image.save(output_full)
return json_with_cors({
'output': output_file,
'params': {
'model': model,
'provider': provider,
'scheduler': scheduler.__name__,
'cfg': cfg,
'steps': steps,
'height': default_height,
'width': default_width,
'prompt': prompt,
'seed': seed,
}
})
@app.route('/txt2img', methods=['POST'])
def txt2img():
(model, provider, scheduler, prompt, cfg, steps, height,
width, seed, pipe) = pipeline_from_request(OnnxStableDiffusionPipeline)
latents = get_latents_from_seed(seed, width, height)
image = pipe( image = pipe(
prompt, prompt,
height, height,
@ -213,7 +271,7 @@ def txt2img():
latents=latents latents=latents
).images[0] ).images[0]
output_file = "txt2img_%s_%s.png" % (seed, spinalcase(prompt[0:64])) output_file = 'txt2img_%s_%s.png' % (seed, spinalcase(prompt[0:64]))
output_full = safer_join(output_path, output_file) output_full = safer_join(output_path, output_file)
print("txt2img output: %s" % output_full) print("txt2img output: %s" % output_full)
image.save(output_full) image.save(output_full)
@ -229,7 +287,7 @@ def txt2img():
'height': height, 'height': height,
'width': width, 'width': width,
'prompt': prompt, 'prompt': prompt,
'seed': seed 'seed': seed,
} }
}) })