From 819af824b047db4c5af471e9c8e306215a0a2d10 Mon Sep 17 00:00:00 2001
From: Sean Sube
Date: Wed, 25 Jan 2023 21:04:00 -0600
Subject: [PATCH] feat(api): initial support for Stable Diffusion upscaling (#66)

---
Review notes (ignored by git when applying; not part of the commit message):
 * This patch was stored with its line breaks and indentation collapsed. The
   layout below is reconstructed from the hunk headers; indentation of context
   and removed lines is a best guess, so `git apply --ignore-whitespace` may
   be needed.
 * pipeline_onnx_stable_diffusion_upscale.py: CLIPTokenizer is now imported
   from transformers instead of diffusers.models, which does not appear to
   export it (TODO confirm against the pinned diffusers version).
 * pipeline_onnx_stable_diffusion_upscale.py: the stubbed __call__ raises
   NotImplementedError instead of returning None, which previously surfaced
   as an AttributeError on `result.images` in upscale_stable_diffusion.
 * upscale.py: upscale_stable_diffusion resolves the model under
   ctx.model_path, the same way ONNXNet locates its model file.
 * Known gap left as-is: run_upscale_pipeline does nothing when the model name
   contains neither 'esrgan' nor 'stable-diffusion'.
 * The blob ids on the `index` lines predate these edits.

 api/onnx_web/convert.py                       |  2 +
 api/onnx_web/onnx/__init__.py                 |  7 ++
 api/onnx_web/onnx/onnx_net.py                 | 72 +++++++++++++++
 .../pipeline_onnx_stable_diffusion_upscale.py | 58 ++++++++++++
 api/onnx_web/upscale.py                       | 91 ++++---------------
 5 files changed, 157 insertions(+), 73 deletions(-)
 create mode 100644 api/onnx_web/onnx/__init__.py
 create mode 100644 api/onnx_web/onnx/onnx_net.py
 create mode 100644 api/onnx_web/onnx/pipeline_onnx_stable_diffusion_upscale.py

diff --git a/api/onnx_web/convert.py b/api/onnx_web/convert.py
index bb175bc1..094eac0e 100644
--- a/api/onnx_web/convert.py
+++ b/api/onnx_web/convert.py
@@ -31,6 +31,8 @@ base_models: Models = {
         ('stable-diffusion-onnx-v2-1', 'stabilityai/stable-diffusion-2-1'),
         ('stable-diffusion-onnx-v2-inpainting',
          'stabilityai/stable-diffusion-2-inpainting'),
+        # should be upscaling with a different converter
+        ('upscaling-stable-diffusion-x4', 'stabilityai/stable-diffusion-x4-upscaler'),
     ],
     'correction': [
         ('correction-gfpgan-v1-3',
diff --git a/api/onnx_web/onnx/__init__.py b/api/onnx_web/onnx/__init__.py
new file mode 100644
index 00000000..7978d5dc
--- /dev/null
+++ b/api/onnx_web/onnx/__init__.py
@@ -0,0 +1,7 @@
+from .onnx_net import (
+    ONNXImage,
+    ONNXNet,
+)
+from .pipeline_onnx_stable_diffusion_upscale import (
+    OnnxStableDiffusionUpscalePipeline,
+)
\ No newline at end of file
diff --git a/api/onnx_web/onnx/onnx_net.py b/api/onnx_web/onnx/onnx_net.py
new file mode 100644
index 00000000..4d26bd72
--- /dev/null
+++ b/api/onnx_web/onnx/onnx_net.py
@@ -0,0 +1,72 @@
+from onnxruntime import InferenceSession
+from os import path
+from typing import Any
+
+import numpy as np
+import torch
+
+from ..utils import (
+    ServerContext,
+)
+
+class ONNXImage():
+    def __init__(self, source) -> None:
+        self.source = source
+        self.data = self
+
+    def __getitem__(self, *args):
+        return torch.from_numpy(self.source.__getitem__(*args)).to(torch.float32)
+
+    def squeeze(self):
+        self.source = np.squeeze(self.source, (0))
+        return self
+
+    def float(self):
+        return self
+
+    def cpu(self):
+        return self
+
+    def clamp_(self, min, max):
+        self.source = np.clip(self.source, min, max)
+        return self
+
+    def numpy(self):
+        return self.source
+
+    def size(self):
+        return np.shape(self.source)
+
+
+class ONNXNet():
+    '''
+    Provides the RRDBNet interface using an ONNX session for DirectML acceleration.
+    '''
+
+    def __init__(self, ctx: ServerContext, model: str, provider='DmlExecutionProvider') -> None:
+        '''
+        TODO: get platform provider from request params
+        '''
+        model_path = path.join(ctx.model_path, model)
+        self.session = InferenceSession(
+            model_path, providers=[provider])
+
+    def __call__(self, image: Any) -> Any:
+        input_name = self.session.get_inputs()[0].name
+        output_name = self.session.get_outputs()[0].name
+        output = self.session.run([output_name], {
+            input_name: image.cpu().numpy()
+        })[0]
+        return ONNXImage(output)
+
+    def eval(self) -> None:
+        pass
+
+    def half(self):
+        return self
+
+    def load_state_dict(self, net, strict=True) -> None:
+        pass
+
+    def to(self, device):
+        return self
diff --git a/api/onnx_web/onnx/pipeline_onnx_stable_diffusion_upscale.py b/api/onnx_web/onnx/pipeline_onnx_stable_diffusion_upscale.py
new file mode 100644
index 00000000..48077b74
--- /dev/null
+++ b/api/onnx_web/onnx/pipeline_onnx_stable_diffusion_upscale.py
@@ -0,0 +1,58 @@
+from diffusers import (
+    DDPMScheduler,
+    OnnxRuntimeModel,
+    StableDiffusionUpscalePipeline,
+)
+from transformers import (
+    CLIPTokenizer,
+)
+from diffusers.schedulers import (
+    KarrasDiffusionSchedulers,
+)
+from typing import (
+    Callable,
+    Union,
+    List,
+    Optional,
+)
+
+import PIL
+import torch
+
+
+class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
+    def __init__(
+        self,
+        vae: OnnxRuntimeModel,
+        text_encoder: OnnxRuntimeModel,
+        tokenizer: CLIPTokenizer,
+        unet: OnnxRuntimeModel,
+        low_res_scheduler: DDPMScheduler,
+        scheduler: KarrasDiffusionSchedulers,
+        max_noise_level: int = 350,
+    ):
+        super().__init__(vae, text_encoder, tokenizer, unet, low_res_scheduler, scheduler, max_noise_level)
+
+    def __call__(
+        self,
+        prompt: Union[str, List[str]] = None,
+        image: Union[torch.FloatTensor, PIL.Image.Image,
+                     List[PIL.Image.Image]] = None,
+        num_inference_steps: int = 75,
+        guidance_scale: float = 9.0,
+        noise_level: int = 20,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: Optional[int] = 1,
+        eta: float = 0.0,
+        generator: Optional[Union[torch.Generator,
+                                  List[torch.Generator]]] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        callback: Optional[Callable[[
+            int, int, torch.FloatTensor], None]] = None,
+        callback_steps: Optional[int] = 1,
+    ):
+        raise NotImplementedError('ONNX Stable Diffusion upscaling is not implemented yet')
diff --git a/api/onnx_web/upscale.py b/api/onnx_web/upscale.py
index 3cd5b4c8..b5eb85f9 100644
--- a/api/onnx_web/upscale.py
+++ b/api/onnx_web/upscale.py
@@ -1,14 +1,16 @@
 from basicsr.archs.rrdbnet_arch import RRDBNet
 from gfpgan import GFPGANer
-from onnxruntime import InferenceSession
 from os import path
 from PIL import Image
 from realesrgan import RealESRGANer
-from typing import Any, Literal, Union
+from typing import Literal, Union
 
 import numpy as np
-import torch
 
+from .onnx import (
+    ONNXNet,
+    OnnxStableDiffusionUpscalePipeline,
+)
 from .utils import (
     ServerContext,
     Size,
@@ -19,69 +21,6 @@ pre_pad = 0
 tile_pad = 10
 
 
-class ONNXImage():
-    def __init__(self, source) -> None:
-        self.source = source
-        self.data = self
-
-    def __getitem__(self, *args):
-        return torch.from_numpy(self.source.__getitem__(*args)).to(torch.float32)
-
-    def squeeze(self):
-        self.source = np.squeeze(self.source, (0))
-        return self
-
-    def float(self):
-        return self
-
-    def cpu(self):
-        return self
-
-    def clamp_(self, min, max):
-        self.source = np.clip(self.source, min, max)
-        return self
-
-    def numpy(self):
-        return self.source
-
-    def size(self):
-        return np.shape(self.source)
-
-
-class ONNXNet():
-    '''
-    Provides the RRDBNet interface using an ONNX session for DirectML acceleration.
-    '''
-
-    def __init__(self, ctx: ServerContext, model: str, provider='DmlExecutionProvider') -> None:
-        '''
-        TODO: get platform provider from request params
-        '''
-        model_path = path.join(ctx.model_path, model)
-        self.session = InferenceSession(
-            model_path, providers=[provider])
-
-    def __call__(self, image: Any) -> Any:
-        input_name = self.session.get_inputs()[0].name
-        output_name = self.session.get_outputs()[0].name
-        output = self.session.run([output_name], {
-            input_name: image.cpu().numpy()
-        })[0]
-        return ONNXImage(output)
-
-    def eval(self) -> None:
-        pass
-
-    def half(self):
-        return self
-
-    def load_state_dict(self, net, strict=True) -> None:
-        pass
-
-    def to(self, device):
-        return self
-
-
 class UpscaleParams():
     def __init__(
         self,
@@ -153,11 +92,7 @@ def upscale_resrgan(ctx: ServerContext, params: UpscaleParams, source_image: Ima
     output = np.array(source_image)
     upsampler = make_resrgan(ctx, params, tile=512)
 
-    if params.scale > 1:
-        output, _ = upsampler.enhance(output, outscale=params.outscale)
-
-    if params.faces:
-        output = upscale_gfpgan(ctx, params, output, upsampler=upsampler)
+    output, _ = upsampler.enhance(output, outscale=params.outscale)
 
     output = Image.fromarray(output, 'RGB')
     print('final output image size', output.size)
@@ -190,13 +125,23 @@ def upscale_gfpgan(ctx: ServerContext, params: UpscaleParams, image, upsampler=N
     return output
 
 
+def upscale_stable_diffusion(ctx: ServerContext, params: UpscaleParams, image: Image) -> Image:
+    print('upscaling with Stable Diffusion')
+    pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(path.join(ctx.model_path, params.upscale_model))
+    result = pipeline('', image=image)
+    return result.images[0]
+
+
 def run_upscale_pipeline(ctx: ServerContext, params: UpscaleParams, image: Image) -> Image:
     print('running upscale pipeline')
 
     if params.scale > 1:
-        image = upscale_resrgan(ctx, params, image)
+        if 'esrgan' in params.upscale_model:
+            image = upscale_resrgan(ctx, params, image)
+        elif 'stable-diffusion' in params.upscale_model:
+            image = upscale_stable_diffusion(ctx, params, image)
 
     if params.faces:
         image = upscale_gfpgan(ctx, params, image)
 
-    return image
\ No newline at end of file
+    return image