feat(api): add img2img endpoint
This commit is contained in:
parent
9973bf1bfc
commit
09ce6546be
|
@ -1,5 +1,5 @@
|
|||
from diffusers import OnnxStableDiffusionPipeline
|
||||
from diffusers import (
|
||||
# schedulers
|
||||
DDIMScheduler,
|
||||
DDPMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
|
@ -7,16 +7,21 @@ from diffusers import (
|
|||
EulerAncestralDiscreteScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
# onnx
|
||||
OnnxStableDiffusionPipeline,
|
||||
OnnxStableDiffusionImg2ImgPipeline,
|
||||
)
|
||||
from flask import Flask, jsonify, request, send_from_directory, url_for
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
from stringcase import spinalcase
|
||||
from os import environ, makedirs, path, scandir
|
||||
import numpy as np
|
||||
|
||||
# defaults
|
||||
default_model = "stable-diffusion-onnx-v1-5"
|
||||
default_platform = "amd"
|
||||
default_scheduler = "euler-a"
|
||||
default_model = 'stable-diffusion-onnx-v1-5'
|
||||
default_platform = 'amd'
|
||||
default_scheduler = 'euler-a'
|
||||
default_prompt = "a photo of an astronaut eating a hamburger"
|
||||
default_cfg = 8
|
||||
default_steps = 20
|
||||
|
@ -29,14 +34,14 @@ max_height = 512
|
|||
max_width = 512
|
||||
|
||||
# paths
|
||||
model_path = environ.get('ONNX_WEB_MODEL_PATH', "../models")
|
||||
output_path = environ.get('ONNX_WEB_OUTPUT_PATH', "../outputs")
|
||||
model_path = environ.get('ONNX_WEB_MODEL_PATH', '../models')
|
||||
output_path = environ.get('ONNX_WEB_OUTPUT_PATH', '../outputs')
|
||||
|
||||
|
||||
# pipeline caching
|
||||
available_models = []
|
||||
last_pipeline_instance = None
|
||||
last_pipeline_options = (None, None)
|
||||
last_pipeline_options = (None, None, None)
|
||||
last_pipeline_scheduler = None
|
||||
|
||||
# pipeline params
|
||||
|
@ -68,6 +73,10 @@ def get_from_map(args, key, values, default):
|
|||
return values[default]
|
||||
|
||||
|
||||
def get_model_path(model):
|
||||
return safer_join(model_path, model)
|
||||
|
||||
|
||||
# from https://www.travelneil.com/stable-diffusion-updates.html
|
||||
def get_latents_from_seed(seed: int, width: int, height: int) -> np.ndarray:
|
||||
# 1 is batch size
|
||||
|
@ -78,22 +87,23 @@ def get_latents_from_seed(seed: int, width: int, height: int) -> np.ndarray:
|
|||
return image_latents
|
||||
|
||||
|
||||
def load_pipeline(model, provider, scheduler):
|
||||
def load_pipeline(pipeline, model, provider, scheduler):
|
||||
global last_pipeline_instance
|
||||
global last_pipeline_scheduler
|
||||
global last_pipeline_options
|
||||
|
||||
options = (model, provider)
|
||||
options = (pipeline, model, provider)
|
||||
if last_pipeline_instance != None and last_pipeline_options == options:
|
||||
print('reusing existing pipeline')
|
||||
pipe = last_pipeline_instance
|
||||
else:
|
||||
print('loading different pipeline')
|
||||
pipe = OnnxStableDiffusionPipeline.from_pretrained(
|
||||
# pipe = OnnxStableDiffusionPipeline.from_pretrained(
|
||||
pipe = pipeline.from_pretrained(
|
||||
model,
|
||||
provider=provider,
|
||||
safety_checker=None,
|
||||
scheduler=scheduler.from_pretrained(model, subfolder="scheduler")
|
||||
scheduler=scheduler.from_pretrained(model, subfolder='scheduler')
|
||||
)
|
||||
last_pipeline_instance = pipe
|
||||
last_pipeline_options = options
|
||||
|
@ -102,7 +112,7 @@ def load_pipeline(model, provider, scheduler):
|
|||
if last_pipeline_scheduler != scheduler:
|
||||
print('changing pipeline scheduler')
|
||||
pipe.scheduler = scheduler.from_pretrained(
|
||||
model, subfolder="scheduler")
|
||||
model, subfolder='scheduler')
|
||||
last_pipeline_scheduler = scheduler
|
||||
|
||||
return pipe
|
||||
|
@ -141,6 +151,8 @@ def check_paths():
|
|||
def load_models():
|
||||
global available_models
|
||||
available_models = [f.name for f in scandir(model_path) if f.is_dir()]
|
||||
load_pipeline(OnnxStableDiffusionPipeline, get_model_path(available_models[0]), platform_providers.get(
|
||||
default_platform), pipeline_schedulers.get(default_scheduler))
|
||||
|
||||
|
||||
check_paths()
|
||||
|
@ -176,12 +188,11 @@ def list_schedulers():
|
|||
return json_with_cors(list(pipeline_schedulers.keys()))
|
||||
|
||||
|
||||
@app.route('/txt2img')
|
||||
def txt2img():
|
||||
def pipeline_from_request(pipeline):
|
||||
user = request.remote_addr
|
||||
|
||||
# pipeline stuff
|
||||
model = safer_join(model_path, request.args.get('model', default_model))
|
||||
model = get_model_path(request.args.get('model', default_model))
|
||||
provider = get_from_map(request.args, 'platform',
|
||||
platform_providers, default_platform)
|
||||
scheduler = get_from_map(request.args, 'scheduler',
|
||||
|
@ -198,12 +209,59 @@ def txt2img():
|
|||
if seed == -1:
|
||||
seed = np.random.randint(np.iinfo(np.int32).max)
|
||||
|
||||
latents = get_latents_from_seed(seed, width, height)
|
||||
|
||||
print("txt2img from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s" %
|
||||
print("request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s" %
|
||||
(user, steps, scheduler.__name__, model, provider, width, height, cfg, seed, prompt))
|
||||
|
||||
pipe = load_pipeline(model, provider, scheduler)
|
||||
pipe = load_pipeline(pipeline, model, provider, scheduler)
|
||||
return (model, provider, scheduler, prompt, cfg, steps, height, width, seed, pipe)
|
||||
|
||||
|
||||
@app.route('/img2img', methods=['POST'])
|
||||
def img2img():
|
||||
input_file = request.files.get('source')
|
||||
input_image = Image.open(BytesIO(input_file.read())).convert('RGB')
|
||||
input_image.thumbnail((default_width, default_height))
|
||||
|
||||
strength = get_and_clamp(request.args, 'strength', 1.0, 1.0, 0.0)
|
||||
(model, provider, scheduler, prompt, cfg, steps, height,
|
||||
width, seed, pipe) = pipeline_from_request(OnnxStableDiffusionImg2ImgPipeline)
|
||||
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
image=input_image,
|
||||
num_inference_steps=steps,
|
||||
guidance_scale=cfg,
|
||||
strength=strength,
|
||||
).images[0]
|
||||
|
||||
output_file = 'img2img_%s_%s.png' % (seed, spinalcase(prompt[0:64]))
|
||||
output_full = safer_join(output_path, output_file)
|
||||
print("img2img output: %s" % output_full)
|
||||
image.save(output_full)
|
||||
|
||||
return json_with_cors({
|
||||
'output': output_file,
|
||||
'params': {
|
||||
'model': model,
|
||||
'provider': provider,
|
||||
'scheduler': scheduler.__name__,
|
||||
'cfg': cfg,
|
||||
'steps': steps,
|
||||
'height': default_height,
|
||||
'width': default_width,
|
||||
'prompt': prompt,
|
||||
'seed': seed,
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
@app.route('/txt2img', methods=['POST'])
|
||||
def txt2img():
|
||||
(model, provider, scheduler, prompt, cfg, steps, height,
|
||||
width, seed, pipe) = pipeline_from_request(OnnxStableDiffusionPipeline)
|
||||
|
||||
latents = get_latents_from_seed(seed, width, height)
|
||||
|
||||
image = pipe(
|
||||
prompt,
|
||||
height,
|
||||
|
@ -213,7 +271,7 @@ def txt2img():
|
|||
latents=latents
|
||||
).images[0]
|
||||
|
||||
output_file = "txt2img_%s_%s.png" % (seed, spinalcase(prompt[0:64]))
|
||||
output_file = 'txt2img_%s_%s.png' % (seed, spinalcase(prompt[0:64]))
|
||||
output_full = safer_join(output_path, output_file)
|
||||
print("txt2img output: %s" % output_full)
|
||||
image.save(output_full)
|
||||
|
@ -229,7 +287,7 @@ def txt2img():
|
|||
'height': height,
|
||||
'width': width,
|
||||
'prompt': prompt,
|
||||
'seed': seed
|
||||
'seed': seed,
|
||||
}
|
||||
})
|
||||
|
||||
|
|
Loading…
Reference in New Issue