diff --git a/api/onnx_web/pipeline.py b/api/onnx_web/pipeline.py
index 2061f244..8e789e17 100644
--- a/api/onnx_web/pipeline.py
+++ b/api/onnx_web/pipeline.py
@@ -1,9 +1,26 @@
 from diffusers import (
-    DiffusionPipeline,
+    DiffusionPipeline,
+    # onnx
+    OnnxStableDiffusionPipeline,
+    OnnxStableDiffusionImg2ImgPipeline,
+    OnnxStableDiffusionInpaintPipeline,
 )
+from os import environ
+from PIL import Image
+from typing import Any, Union
 
 import numpy as np
 
+from .image import (
+    expand_image,
+)
+from .upscale import (
+    upscale_resrgan,
+)
+from .utils import (
+    safer_join
+)
+
 last_pipeline_instance = None
 last_pipeline_options = (None, None, None)
 last_pipeline_scheduler = None
@@ -48,3 +65,105 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, schedu
     last_pipeline_scheduler = scheduler
 
     return pipe
+
+
+def run_txt2img_pipeline(model, provider, scheduler, prompt, negative_prompt, cfg, steps, seed, output, height, width):
+    pipe = load_pipeline(OnnxStableDiffusionPipeline,
+                         model, provider, scheduler)
+
+    latents = get_latents_from_seed(seed, width, height)
+    rng = np.random.RandomState(seed)
+
+    image = pipe(
+        prompt,
+        height,
+        width,
+        generator=rng,
+        guidance_scale=cfg,
+        latents=latents,
+        negative_prompt=negative_prompt,
+        num_inference_steps=steps,
+    ).images[0]
+    image = upscale_resrgan(image, model_path)
+    image.save(output)
+
+    print('saved txt2img output: %s' % (output))
+
+
+def run_img2img_pipeline(model, provider, scheduler, prompt, negative_prompt, cfg, steps, seed, output, strength, input_image):
+    pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline,
+                         model, provider, scheduler)
+
+    rng = np.random.RandomState(seed)
+
+    image = pipe(
+        prompt,
+        generator=rng,
+        guidance_scale=cfg,
+        image=input_image,
+        negative_prompt=negative_prompt,
+        num_inference_steps=steps,
+        strength=strength,
+    ).images[0]
+    image = upscale_resrgan(image, model_path)
+    image.save(output)
+
+    print('saved img2img output: %s' % (output))
+
+
+def run_inpaint_pipeline(
+    model: str,
+    provider: str,
+    scheduler: Any,
+    prompt: str,
+    negative_prompt: Union[str, None],
+    cfg: float,
+    steps: int,
+    seed: int,
+    output: str,
+    height: int,
+    width: int,
+    source_image: Image,
+    mask_image: Image,
+    left: int,
+    right: int,
+    top: int,
+    bottom: int,
+    noise_source: Any,
+    mask_filter: Any
+):
+    pipe = load_pipeline(OnnxStableDiffusionInpaintPipeline,
+                         model, provider, scheduler)
+
+    latents = get_latents_from_seed(seed, width, height)
+    rng = np.random.RandomState(seed)
+
+    print('applying mask filter and generating noise source')
+    source_image, mask_image, noise_image, _full_dims = expand_image(
+        source_image,
+        mask_image,
+        (left, right, top, bottom),
+        noise_source=noise_source,
+        mask_filter=mask_filter)
+
+    if environ.get('DEBUG') is not None:
+        source_image.save(safer_join(output_path, 'last-source.png'))
+        mask_image.save(safer_join(output_path, 'last-mask.png'))
+        noise_image.save(safer_join(output_path, 'last-noise.png'))
+
+    image = pipe(
+        prompt,
+        generator=rng,
+        guidance_scale=cfg,
+        height=height,
+        image=source_image,
+        latents=latents,
+        mask_image=mask_image,
+        negative_prompt=negative_prompt,
+        num_inference_steps=steps,
+        width=width,
+    ).images[0]
+
+    image.save(output)
+
+    print('saved inpaint output: %s' % (output))
diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py
index ff14e1cb..b189b088 100644
--- a/api/onnx_web/serve.py
+++ b/api/onnx_web/serve.py
@@ -12,10 +12,6 @@ from diffusers import (
     KarrasVeScheduler,
     LMSDiscreteScheduler,
     PNDMScheduler,
-    # onnx
-    OnnxStableDiffusionPipeline,
-    OnnxStableDiffusionImg2ImgPipeline,
-    OnnxStableDiffusionInpaintPipeline,
 )
 from flask import Flask, jsonify, request, send_from_directory, url_for
 from flask_cors import CORS
@@ -25,7 +21,7 @@ from io import BytesIO
 from PIL import Image
 from struct import pack
 from os import environ, makedirs, path, scandir
-from typing import Any, Dict, Tuple, Union
+from typing import Tuple, Union
 
 from .image import (
     expand_image,
@@ -42,12 +38,15 @@ from .image import (
     noise_source_uniform,
 )
 from .pipeline import (
-    get_latents_from_seed,
-    load_pipeline,
+    run_img2img_pipeline,
+    run_inpaint_pipeline,
+    run_txt2img_pipeline,
 )
-from .upscale import (
-    upscale_gfpgan,
-    upscale_resrgan,
+from .utils import (
+    get_and_clamp_float,
+    get_and_clamp_int,
+    get_from_map,
+    safer_join
 )
 
 import json
@@ -104,22 +103,6 @@ mask_filters = {
 }
 
 
-def get_and_clamp_float(args, key: str, default_value: float, max_value: float, min_value=0.0) -> float:
-    return min(max(float(args.get(key, default_value)), min_value), max_value)
-
-
-def get_and_clamp_int(args, key: str, default_value: int, max_value: int, min_value=1) -> int:
-    return min(max(int(args.get(key, default_value)), min_value), max_value)
-
-
-def get_from_map(args, key: str, values: Dict[str, Any], default: Any):
-    selected = args.get(key, default)
-    if selected in values:
-        return values[selected]
-    else:
-        return values[default]
-
-
 def get_model_path(model: str):
     return safer_join(model_path, model)
 
@@ -151,11 +134,6 @@ def make_output_path(mode: str, seed: int, params: Tuple[Union[str, int, float]]
     return (output_file, output_full)
 
 
-def safer_join(base, tail):
-    safer_path = path.relpath(path.normpath(path.join('/', tail)), '/')
-    return path.join(base, safer_path)
-
-
 def url_from_rule(rule):
     options = {}
     for arg in rule.arguments:
@@ -214,111 +192,6 @@ def pipeline_from_request():
     return (model, provider, scheduler, prompt, negative_prompt, cfg, steps, height, width, seed)
 
 
-def run_txt2img_pipeline(model, provider, scheduler, prompt, negative_prompt, cfg, steps, seed, output, height, width):
-    pipe = load_pipeline(OnnxStableDiffusionPipeline,
-                         model, provider, scheduler)
-
-    latents = get_latents_from_seed(seed, width, height)
-    rng = np.random.RandomState(seed)
-
-    image = pipe(
-        prompt,
-        height,
-        width,
-        generator=rng,
-        guidance_scale=cfg,
-        latents=latents,
-        negative_prompt=negative_prompt,
-        num_inference_steps=steps,
-    ).images[0]
-    image = upscale_resrgan(image, model_path)
-    image.save(output)
-
-    print('saved txt2img output: %s' % (output))
-
-
-def run_img2img_pipeline(model, provider, scheduler, prompt, negative_prompt, cfg, steps, seed, output, strength, input_image):
-    pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline,
-                         model, provider, scheduler)
-
-    rng = np.random.RandomState(seed)
-
-    image = pipe(
-        prompt,
-        generator=rng,
-        guidance_scale=cfg,
-        image=input_image,
-        negative_prompt=negative_prompt,
-        num_inference_steps=steps,
-        strength=strength,
-    ).images[0]
-    image = upscale_resrgan(image, model_path)
-    image.save(output)
-
-    print('saved img2img output: %s' % (output))
-
-
-def run_inpaint_pipeline(
-    model: str,
-    provider: str,
-    scheduler: Any,
-    prompt: str,
-    negative_prompt: Union[str, None],
-    cfg: float,
-    steps: int,
-    seed: int,
-    output: str,
-    height: int,
-    width: int,
-    source_image: Image,
-    mask_image: Image,
-    left: int,
-    right: int,
-    top: int,
-    bottom: int,
-    noise_source: Any,
-    mask_filter: Any
-):
-    pipe = load_pipeline(OnnxStableDiffusionInpaintPipeline,
-                         model, provider, scheduler)
-
-    latents = get_latents_from_seed(seed, width, height)
-    rng = np.random.RandomState(seed)
-
-    print('applying mask filter and generating noise source')
-    source_image, mask_image, noise_image, _full_dims = expand_image(
-        source_image,
-        mask_image,
-        (left, right, top, bottom),
-        noise_source=noise_source,
-        mask_filter=mask_filter)
-
-    if environ.get('DEBUG') is not None:
-        source_image.save(safer_join(output_path, 'last-source.png'))
-        mask_image.save(safer_join(output_path, 'last-mask.png'))
-        noise_image.save(safer_join(output_path, 'last-noise.png'))
-
-    image = pipe(
-        prompt,
-        generator=rng,
-        guidance_scale=cfg,
-        height=height,
-        image=source_image,
-        latents=latents,
-        mask_image=mask_image,
-        negative_prompt=negative_prompt,
-        num_inference_steps=steps,
-        width=width,
-    ).images[0]
-
-    image.save(output)
-
-    print('saved inpaint output: %s' % (output))
-
-
-# setup
-
-
 def check_paths():
     if not path.exists(model_path):
         raise RuntimeError('model path must exist')
diff --git a/api/onnx_web/utils.py b/api/onnx_web/utils.py
new file mode 100644
index 00000000..67092523
--- /dev/null
+++ b/api/onnx_web/utils.py
@@ -0,0 +1,28 @@
+from os import path
+from typing import Any, Dict, Tuple
+
+
+Border = Tuple[int, int, int, int]
+Point = Tuple[int, int]
+Size = Tuple[int, int]
+
+
+def get_and_clamp_float(args, key: str, default_value: float, max_value: float, min_value=0.0) -> float:
+    return min(max(float(args.get(key, default_value)), min_value), max_value)
+
+
+def get_and_clamp_int(args, key: str, default_value: int, max_value: int, min_value=1) -> int:
+    return min(max(int(args.get(key, default_value)), min_value), max_value)
+
+
+def get_from_map(args, key: str, values: Dict[str, Any], default: Any):
+    selected = args.get(key, default)
+    if selected in values:
+        return values[selected]
+    else:
+        return values[default]
+
+
+def safer_join(base, tail):
+    safer_path = path.relpath(path.normpath(path.join('/', tail)), '/')
+    return path.join(base, safer_path)
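
The helpers moved into api/onnx_web/utils.py are small enough to sanity-check in isolation. A minimal sketch, not part of the diff: it assumes the package is importable as onnx_web and uses a plain dict in place of Flask's request.args (only .get() is called), with made-up sample values.

    from onnx_web.utils import get_and_clamp_int, get_from_map, safer_join

    # Stand-in for Flask's request.args; the helpers only call .get() on it.
    args = {'steps': '200', 'scheduler': 'unknown'}
    schedulers = {'ddim': 'DDIMScheduler', 'ddpm': 'DDPMScheduler'}

    # Parses the string value, then clamps it into [min_value, max_value].
    print(get_and_clamp_int(args, 'steps', 50, 100))            # 100

    # Keys not present in the map fall back to the entry for the default key.
    print(get_from_map(args, 'scheduler', schedulers, 'ddim'))  # DDIMScheduler

    # Path traversal in the tail is normalized away before joining.
    print(safer_join('/outputs', '../../etc/passwd'))           # /outputs/etc/passwd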