1
0
Fork 0

feat(api): add img2img endpoint

This commit is contained in:
Sean Sube 2023-01-07 15:05:29 -06:00
parent 9973bf1bfc
commit 09ce6546be
1 changed files with 79 additions and 21 deletions

View File

@ -1,5 +1,5 @@
from diffusers import OnnxStableDiffusionPipeline
from diffusers import (
# schedulers
DDIMScheduler,
DDPMScheduler,
DPMSolverMultistepScheduler,
@ -7,16 +7,21 @@ from diffusers import (
EulerAncestralDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
# onnx
OnnxStableDiffusionPipeline,
OnnxStableDiffusionImg2ImgPipeline,
)
from flask import Flask, jsonify, request, send_from_directory, url_for
from io import BytesIO
from PIL import Image
from stringcase import spinalcase
from os import environ, makedirs, path, scandir
import numpy as np
# defaults
default_model = "stable-diffusion-onnx-v1-5"
default_platform = "amd"
default_scheduler = "euler-a"
default_model = 'stable-diffusion-onnx-v1-5'
default_platform = 'amd'
default_scheduler = 'euler-a'
default_prompt = "a photo of an astronaut eating a hamburger"
default_cfg = 8
default_steps = 20
@ -29,14 +34,14 @@ max_height = 512
max_width = 512
# paths
model_path = environ.get('ONNX_WEB_MODEL_PATH', "../models")
output_path = environ.get('ONNX_WEB_OUTPUT_PATH', "../outputs")
model_path = environ.get('ONNX_WEB_MODEL_PATH', '../models')
output_path = environ.get('ONNX_WEB_OUTPUT_PATH', '../outputs')
# pipeline caching
available_models = []
last_pipeline_instance = None
last_pipeline_options = (None, None)
last_pipeline_options = (None, None, None)
last_pipeline_scheduler = None
# pipeline params
@ -68,6 +73,10 @@ def get_from_map(args, key, values, default):
return values[default]
def get_model_path(model):
return safer_join(model_path, model)
# from https://www.travelneil.com/stable-diffusion-updates.html
def get_latents_from_seed(seed: int, width: int, height: int) -> np.ndarray:
# 1 is batch size
@ -78,22 +87,23 @@ def get_latents_from_seed(seed: int, width: int, height: int) -> np.ndarray:
return image_latents
def load_pipeline(model, provider, scheduler):
def load_pipeline(pipeline, model, provider, scheduler):
global last_pipeline_instance
global last_pipeline_scheduler
global last_pipeline_options
options = (model, provider)
options = (pipeline, model, provider)
if last_pipeline_instance != None and last_pipeline_options == options:
print('reusing existing pipeline')
pipe = last_pipeline_instance
else:
print('loading different pipeline')
pipe = OnnxStableDiffusionPipeline.from_pretrained(
# pipe = OnnxStableDiffusionPipeline.from_pretrained(
pipe = pipeline.from_pretrained(
model,
provider=provider,
safety_checker=None,
scheduler=scheduler.from_pretrained(model, subfolder="scheduler")
scheduler=scheduler.from_pretrained(model, subfolder='scheduler')
)
last_pipeline_instance = pipe
last_pipeline_options = options
@ -102,7 +112,7 @@ def load_pipeline(model, provider, scheduler):
if last_pipeline_scheduler != scheduler:
print('changing pipeline scheduler')
pipe.scheduler = scheduler.from_pretrained(
model, subfolder="scheduler")
model, subfolder='scheduler')
last_pipeline_scheduler = scheduler
return pipe
@ -141,6 +151,8 @@ def check_paths():
def load_models():
global available_models
available_models = [f.name for f in scandir(model_path) if f.is_dir()]
load_pipeline(OnnxStableDiffusionPipeline, get_model_path(available_models[0]), platform_providers.get(
default_platform), pipeline_schedulers.get(default_scheduler))
check_paths()
@ -176,12 +188,11 @@ def list_schedulers():
return json_with_cors(list(pipeline_schedulers.keys()))
@app.route('/txt2img')
def txt2img():
def pipeline_from_request(pipeline):
user = request.remote_addr
# pipeline stuff
model = safer_join(model_path, request.args.get('model', default_model))
model = get_model_path(request.args.get('model', default_model))
provider = get_from_map(request.args, 'platform',
platform_providers, default_platform)
scheduler = get_from_map(request.args, 'scheduler',
@ -198,12 +209,59 @@ def txt2img():
if seed == -1:
seed = np.random.randint(np.iinfo(np.int32).max)
latents = get_latents_from_seed(seed, width, height)
print("txt2img from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s" %
print("request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s" %
(user, steps, scheduler.__name__, model, provider, width, height, cfg, seed, prompt))
pipe = load_pipeline(model, provider, scheduler)
pipe = load_pipeline(pipeline, model, provider, scheduler)
return (model, provider, scheduler, prompt, cfg, steps, height, width, seed, pipe)
@app.route('/img2img', methods=['POST'])
def img2img():
input_file = request.files.get('source')
input_image = Image.open(BytesIO(input_file.read())).convert('RGB')
input_image.thumbnail((default_width, default_height))
strength = get_and_clamp(request.args, 'strength', 1.0, 1.0, 0.0)
(model, provider, scheduler, prompt, cfg, steps, height,
width, seed, pipe) = pipeline_from_request(OnnxStableDiffusionImg2ImgPipeline)
image = pipe(
prompt=prompt,
image=input_image,
num_inference_steps=steps,
guidance_scale=cfg,
strength=strength,
).images[0]
output_file = 'img2img_%s_%s.png' % (seed, spinalcase(prompt[0:64]))
output_full = safer_join(output_path, output_file)
print("img2img output: %s" % output_full)
image.save(output_full)
return json_with_cors({
'output': output_file,
'params': {
'model': model,
'provider': provider,
'scheduler': scheduler.__name__,
'cfg': cfg,
'steps': steps,
'height': default_height,
'width': default_width,
'prompt': prompt,
'seed': seed,
}
})
@app.route('/txt2img', methods=['POST'])
def txt2img():
(model, provider, scheduler, prompt, cfg, steps, height,
width, seed, pipe) = pipeline_from_request(OnnxStableDiffusionPipeline)
latents = get_latents_from_seed(seed, width, height)
image = pipe(
prompt,
height,
@ -213,7 +271,7 @@ def txt2img():
latents=latents
).images[0]
output_file = "txt2img_%s_%s.png" % (seed, spinalcase(prompt[0:64]))
output_file = 'txt2img_%s_%s.png' % (seed, spinalcase(prompt[0:64]))
output_full = safer_join(output_path, output_file)
print("txt2img output: %s" % output_full)
image.save(output_full)
@ -229,7 +287,7 @@ def txt2img():
'height': height,
'width': width,
'prompt': prompt,
'seed': seed
'seed': seed,
}
})