feat(api): add img2img endpoint
This commit is contained in:
parent
9973bf1bfc
commit
09ce6546be
|
@ -1,5 +1,5 @@
|
||||||
from diffusers import OnnxStableDiffusionPipeline
|
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
|
# schedulers
|
||||||
DDIMScheduler,
|
DDIMScheduler,
|
||||||
DDPMScheduler,
|
DDPMScheduler,
|
||||||
DPMSolverMultistepScheduler,
|
DPMSolverMultistepScheduler,
|
||||||
|
@ -7,16 +7,21 @@ from diffusers import (
|
||||||
EulerAncestralDiscreteScheduler,
|
EulerAncestralDiscreteScheduler,
|
||||||
LMSDiscreteScheduler,
|
LMSDiscreteScheduler,
|
||||||
PNDMScheduler,
|
PNDMScheduler,
|
||||||
|
# onnx
|
||||||
|
OnnxStableDiffusionPipeline,
|
||||||
|
OnnxStableDiffusionImg2ImgPipeline,
|
||||||
)
|
)
|
||||||
from flask import Flask, jsonify, request, send_from_directory, url_for
|
from flask import Flask, jsonify, request, send_from_directory, url_for
|
||||||
|
from io import BytesIO
|
||||||
|
from PIL import Image
|
||||||
from stringcase import spinalcase
|
from stringcase import spinalcase
|
||||||
from os import environ, makedirs, path, scandir
|
from os import environ, makedirs, path, scandir
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
# defaults
|
# defaults
|
||||||
default_model = "stable-diffusion-onnx-v1-5"
|
default_model = 'stable-diffusion-onnx-v1-5'
|
||||||
default_platform = "amd"
|
default_platform = 'amd'
|
||||||
default_scheduler = "euler-a"
|
default_scheduler = 'euler-a'
|
||||||
default_prompt = "a photo of an astronaut eating a hamburger"
|
default_prompt = "a photo of an astronaut eating a hamburger"
|
||||||
default_cfg = 8
|
default_cfg = 8
|
||||||
default_steps = 20
|
default_steps = 20
|
||||||
|
@ -29,14 +34,14 @@ max_height = 512
|
||||||
max_width = 512
|
max_width = 512
|
||||||
|
|
||||||
# paths
|
# paths
|
||||||
model_path = environ.get('ONNX_WEB_MODEL_PATH', "../models")
|
model_path = environ.get('ONNX_WEB_MODEL_PATH', '../models')
|
||||||
output_path = environ.get('ONNX_WEB_OUTPUT_PATH', "../outputs")
|
output_path = environ.get('ONNX_WEB_OUTPUT_PATH', '../outputs')
|
||||||
|
|
||||||
|
|
||||||
# pipeline caching
available_models = []
# most recently constructed pipeline and the (pipeline, model, provider)
# options it was built with; load_pipeline() reuses it when options match
last_pipeline_instance = None
last_pipeline_options = (None, None, None)
last_pipeline_scheduler = None
|
||||||
|
|
||||||
# pipeline params
|
# pipeline params
|
||||||
|
@ -68,6 +73,10 @@ def get_from_map(args, key, values, default):
|
||||||
return values[default]
|
return values[default]
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_path(model):
    """Resolve a model name to a path under the models directory."""
    return safer_join(model_path, model)
|
||||||
|
|
||||||
|
|
||||||
# from https://www.travelneil.com/stable-diffusion-updates.html
|
# from https://www.travelneil.com/stable-diffusion-updates.html
|
||||||
def get_latents_from_seed(seed: int, width: int, height: int) -> np.ndarray:
|
def get_latents_from_seed(seed: int, width: int, height: int) -> np.ndarray:
|
||||||
# 1 is batch size
|
# 1 is batch size
|
||||||
|
@ -78,22 +87,23 @@ def get_latents_from_seed(seed: int, width: int, height: int) -> np.ndarray:
|
||||||
return image_latents
|
return image_latents
|
||||||
|
|
||||||
|
|
||||||
def load_pipeline(pipeline, model, provider, scheduler):
    """Load a diffusers ONNX pipeline, reusing the cached instance when possible.

    The last pipeline is cached at module level, keyed on the
    (pipeline class, model path, provider) tuple. When only the scheduler
    differs, it is swapped on the cached pipeline instead of reloading.

    Returns the ready-to-call pipeline instance.
    """
    global last_pipeline_instance
    global last_pipeline_scheduler
    global last_pipeline_options

    options = (pipeline, model, provider)
    # `is not None` rather than `!= None`: identity test per PEP 8
    if last_pipeline_instance is not None and last_pipeline_options == options:
        print('reusing existing pipeline')
        pipe = last_pipeline_instance
    else:
        print('loading different pipeline')
        pipe = pipeline.from_pretrained(
            model,
            provider=provider,
            # safety checker disabled deliberately; NOTE(review): confirm
            # this is intended for all deployments
            safety_checker=None,
            scheduler=scheduler.from_pretrained(model, subfolder='scheduler')
        )
        last_pipeline_instance = pipe
        last_pipeline_options = options

    # scheduler can change without reloading the whole pipeline
    if last_pipeline_scheduler != scheduler:
        print('changing pipeline scheduler')
        pipe.scheduler = scheduler.from_pretrained(
            model, subfolder='scheduler')
        last_pipeline_scheduler = scheduler

    return pipe
|
||||||
|
@ -141,6 +151,8 @@ def check_paths():
|
||||||
def load_models():
    """Scan the model directory and warm up the default pipeline.

    Populates the module-level `available_models` list with every
    directory under `model_path`.
    """
    global available_models
    available_models = [f.name for f in scandir(model_path) if f.is_dir()]
    # guard against an empty models directory: indexing [0] would raise
    # IndexError at startup; skip the warm-up instead
    if available_models:
        load_pipeline(OnnxStableDiffusionPipeline,
                      get_model_path(available_models[0]),
                      platform_providers.get(default_platform),
                      pipeline_schedulers.get(default_scheduler))
|
||||||
|
|
||||||
|
|
||||||
check_paths()
|
check_paths()
|
||||||
|
@ -176,12 +188,11 @@ def list_schedulers():
|
||||||
return json_with_cors(list(pipeline_schedulers.keys()))
|
return json_with_cors(list(pipeline_schedulers.keys()))
|
||||||
|
|
||||||
|
|
||||||
@app.route('/txt2img')
|
def pipeline_from_request(pipeline):
|
||||||
def txt2img():
|
|
||||||
user = request.remote_addr
|
user = request.remote_addr
|
||||||
|
|
||||||
# pipeline stuff
|
# pipeline stuff
|
||||||
model = safer_join(model_path, request.args.get('model', default_model))
|
model = get_model_path(request.args.get('model', default_model))
|
||||||
provider = get_from_map(request.args, 'platform',
|
provider = get_from_map(request.args, 'platform',
|
||||||
platform_providers, default_platform)
|
platform_providers, default_platform)
|
||||||
scheduler = get_from_map(request.args, 'scheduler',
|
scheduler = get_from_map(request.args, 'scheduler',
|
||||||
|
@ -198,12 +209,59 @@ def txt2img():
|
||||||
if seed == -1:
|
if seed == -1:
|
||||||
seed = np.random.randint(np.iinfo(np.int32).max)
|
seed = np.random.randint(np.iinfo(np.int32).max)
|
||||||
|
|
||||||
latents = get_latents_from_seed(seed, width, height)
|
print("request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s" %
|
||||||
|
|
||||||
print("txt2img from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s" %
|
|
||||||
(user, steps, scheduler.__name__, model, provider, width, height, cfg, seed, prompt))
|
(user, steps, scheduler.__name__, model, provider, width, height, cfg, seed, prompt))
|
||||||
|
|
||||||
pipe = load_pipeline(model, provider, scheduler)
|
pipe = load_pipeline(pipeline, model, provider, scheduler)
|
||||||
|
return (model, provider, scheduler, prompt, cfg, steps, height, width, seed, pipe)
|
||||||
|
|
||||||
|
|
||||||
|
@app.route('/img2img', methods=['POST'])
def img2img():
    """Run the ONNX img2img pipeline on an uploaded `source` image.

    Query params mirror /txt2img plus `strength`; the source image is
    taken from the multipart form field `source`. Returns JSON with the
    output filename and the parameters used.
    """
    input_file = request.files.get('source')
    if input_file is None:
        # previously crashed with AttributeError (HTTP 500) when the
        # upload was missing; reject the bad request explicitly instead
        return 'missing source image', 400

    input_image = Image.open(BytesIO(input_file.read())).convert('RGB')
    # shrink in place to fit the default bounds, preserving aspect ratio
    input_image.thumbnail((default_width, default_height))

    strength = get_and_clamp(request.args, 'strength', 1.0, 1.0, 0.0)
    (model, provider, scheduler, prompt, cfg, steps, height,
     width, seed, pipe) = pipeline_from_request(OnnxStableDiffusionImg2ImgPipeline)

    image = pipe(
        prompt=prompt,
        image=input_image,
        num_inference_steps=steps,
        guidance_scale=cfg,
        strength=strength,
    ).images[0]

    output_file = 'img2img_%s_%s.png' % (seed, spinalcase(prompt[0:64]))
    output_full = safer_join(output_path, output_file)
    print("img2img output: %s" % output_full)
    image.save(output_full)

    return json_with_cors({
        'output': output_file,
        'params': {
            'model': model,
            'provider': provider,
            'scheduler': scheduler.__name__,
            'cfg': cfg,
            'steps': steps,
            # NOTE(review): thumbnail() preserves aspect ratio, so the real
            # output size may be smaller than these defaults — confirm
            'height': default_height,
            'width': default_width,
            'prompt': prompt,
            'seed': seed,
        }
    })
|
||||||
|
|
||||||
|
|
||||||
|
@app.route('/txt2img', methods=['POST'])
|
||||||
|
def txt2img():
|
||||||
|
(model, provider, scheduler, prompt, cfg, steps, height,
|
||||||
|
width, seed, pipe) = pipeline_from_request(OnnxStableDiffusionPipeline)
|
||||||
|
|
||||||
|
latents = get_latents_from_seed(seed, width, height)
|
||||||
|
|
||||||
image = pipe(
|
image = pipe(
|
||||||
prompt,
|
prompt,
|
||||||
height,
|
height,
|
||||||
|
@ -213,7 +271,7 @@ def txt2img():
|
||||||
latents=latents
|
latents=latents
|
||||||
).images[0]
|
).images[0]
|
||||||
|
|
||||||
output_file = "txt2img_%s_%s.png" % (seed, spinalcase(prompt[0:64]))
|
output_file = 'txt2img_%s_%s.png' % (seed, spinalcase(prompt[0:64]))
|
||||||
output_full = safer_join(output_path, output_file)
|
output_full = safer_join(output_path, output_file)
|
||||||
print("txt2img output: %s" % output_full)
|
print("txt2img output: %s" % output_full)
|
||||||
image.save(output_full)
|
image.save(output_full)
|
||||||
|
@ -229,7 +287,7 @@ def txt2img():
|
||||||
'height': height,
|
'height': height,
|
||||||
'width': width,
|
'width': width,
|
||||||
'prompt': prompt,
|
'prompt': prompt,
|
||||||
'seed': seed
|
'seed': seed,
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue