1
0
Fork 0

feat(api): initial support for Stable Diffusion upscaling (#66)

This commit is contained in:
Sean Sube 2023-01-25 21:04:00 -06:00
parent 483b8e3f19
commit 819af824b0
5 changed files with 157 additions and 73 deletions

View File

@ -31,6 +31,8 @@ base_models: Models = {
('stable-diffusion-onnx-v2-1', 'stabilityai/stable-diffusion-2-1'),
('stable-diffusion-onnx-v2-inpainting',
'stabilityai/stable-diffusion-2-inpainting'),
# should be upscaling with a different converter
('upscaling-stable-diffusion-x4', 'stabilityai/stable-diffusion-x4-upscaler'),
],
'correction': [
('correction-gfpgan-v1-3',

View File

@ -0,0 +1,7 @@
from .onnx_net import (
ONNXImage,
ONNXNet,
)
from .pipeline_onnx_stable_diffusion_upscale import (
OnnxStableDiffusionUpscalePipeline,
)

View File

@ -0,0 +1,72 @@
from onnxruntime import InferenceSession
from os import path
from typing import Any
import numpy as np
import torch
from ..utils import (
ServerContext,
)
class ONNXImage():
    '''
    Wraps a numpy array behind the small set of tensor-style methods that are
    called on the output of ONNXNet.
    '''

    def __init__(self, source) -> None:
        self.source = source
        # `.data` refers back to the wrapper itself rather than to the array
        self.data = self

    def __getitem__(self, *args):
        '''
        Index the wrapped array and return the selection as a float32 torch tensor.
        '''
        selection = self.source.__getitem__(*args)
        tensor = torch.from_numpy(selection)
        return tensor.to(torch.float32)

    def squeeze(self):
        '''
        Remove axis 0 from the wrapped array and return this wrapper.
        '''
        self.source = np.squeeze(self.source, axis=0)
        return self

    def float(self):
        '''
        No-op, returns this wrapper unchanged.
        '''
        return self

    def cpu(self):
        '''
        No-op, returns this wrapper unchanged.
        '''
        return self

    def clamp_(self, min, max):
        '''
        Clip the wrapped array to [min, max] and return this wrapper.
        '''
        clipped = np.clip(self.source, min, max)
        self.source = clipped
        return self

    def numpy(self):
        '''
        Return the wrapped array itself (not a copy).
        '''
        return self.source

    def size(self):
        '''
        Return the shape of the wrapped array as reported by numpy.
        '''
        return np.shape(self.source)
class ONNXNet():
    '''
    Provides the RRDBNet interface using an ONNX session for DirectML acceleration.

    Only __call__ does any work; eval, half, load_state_dict and to are
    no-op stand-ins.
    '''

    def __init__(self, ctx: ServerContext, model: str, provider='DmlExecutionProvider') -> None:
        '''
        Open an inference session for `model`, resolved under ctx.model_path,
        using a single execution provider.

        TODO: get platform provider from request params
        '''
        self.session = InferenceSession(
            path.join(ctx.model_path, model),
            providers=[provider],
        )

    def __call__(self, image: Any) -> Any:
        '''
        Run the session on `image` and wrap the result in an ONNXImage.

        Only the first declared input and the first declared output of the
        session are used. `image` must provide `.cpu().numpy()`.
        '''
        feed_key = self.session.get_inputs()[0].name
        fetch_key = self.session.get_outputs()[0].name
        results = self.session.run(
            [fetch_key],
            {feed_key: image.cpu().numpy()},
        )
        return ONNXImage(results[0])

    def eval(self) -> None:
        '''
        No-op.
        '''
        return None

    def half(self):
        '''
        No-op, returns this object unchanged.
        '''
        return self

    def load_state_dict(self, net, strict=True) -> None:
        '''
        No-op, both arguments are ignored.
        '''
        return None

    def to(self, device):
        '''
        No-op, `device` is ignored and this object is returned unchanged.
        '''
        return self

View File

@ -0,0 +1,58 @@
from diffusers import (
DDPMScheduler,
OnnxRuntimeModel,
StableDiffusionUpscalePipeline,
)
from diffusers.models import (
CLIPTokenizer,
)
from diffusers.schedulers import (
KarrasDiffusionSchedulers,
)
from typing import (
Callable,
Union,
List,
Optional,
)
import PIL
import torch
class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
    '''
    Subclass of StableDiffusionUpscalePipeline whose constructor is annotated
    for ONNX runtime models.

    NOTE(review): __call__ is an empty placeholder in this version, so
    invoking the pipeline returns None instead of a result object.
    '''

    def __init__(
        self,
        vae: OnnxRuntimeModel,
        text_encoder: OnnxRuntimeModel,
        tokenizer: CLIPTokenizer,
        unet: OnnxRuntimeModel,
        low_res_scheduler: DDPMScheduler,
        scheduler: KarrasDiffusionSchedulers,
        max_noise_level: int = 350,
    ):
        '''
        Forward every component, in the same order, to the parent constructor.
        '''
        super().__init__(
            vae,
            text_encoder,
            tokenizer,
            unet,
            low_res_scheduler,
            scheduler,
            max_noise_level,
        )

    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        image: Union[torch.FloatTensor, PIL.Image.Image,
                     List[PIL.Image.Image]] = None,
        num_inference_steps: int = 75,
        guidance_scale: float = 9.0,
        noise_level: int = 20,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator,
                                  List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[
            int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
    ):
        '''
        Placeholder: accepts the arguments listed above, does nothing with
        them and returns None.
        '''
        return None

View File

@ -1,14 +1,16 @@
from basicsr.archs.rrdbnet_arch import RRDBNet
from gfpgan import GFPGANer
from onnxruntime import InferenceSession
from os import path
from PIL import Image
from realesrgan import RealESRGANer
from typing import Any, Literal, Union
from typing import Literal, Union
import numpy as np
import torch
from .onnx import (
ONNXNet,
OnnxStableDiffusionUpscalePipeline,
)
from .utils import (
ServerContext,
Size,
@ -19,69 +21,6 @@ pre_pad = 0
tile_pad = 10
class ONNXImage():
    '''
    Wraps a numpy array behind the small set of tensor-style methods that are
    called on the output of ONNXNet.
    '''

    def __init__(self, source) -> None:
        self.source = source
        # `.data` refers back to the wrapper itself rather than to the array
        self.data = self

    def __getitem__(self, *args):
        '''
        Index the wrapped array and return the selection as a float32 torch tensor.
        '''
        selection = self.source.__getitem__(*args)
        tensor = torch.from_numpy(selection)
        return tensor.to(torch.float32)

    def squeeze(self):
        '''
        Remove axis 0 from the wrapped array and return this wrapper.
        '''
        self.source = np.squeeze(self.source, axis=0)
        return self

    def float(self):
        '''
        No-op, returns this wrapper unchanged.
        '''
        return self

    def cpu(self):
        '''
        No-op, returns this wrapper unchanged.
        '''
        return self

    def clamp_(self, min, max):
        '''
        Clip the wrapped array to [min, max] and return this wrapper.
        '''
        clipped = np.clip(self.source, min, max)
        self.source = clipped
        return self

    def numpy(self):
        '''
        Return the wrapped array itself (not a copy).
        '''
        return self.source

    def size(self):
        '''
        Return the shape of the wrapped array as reported by numpy.
        '''
        return np.shape(self.source)
class ONNXNet():
    '''
    Provides the RRDBNet interface using an ONNX session for DirectML acceleration.

    Only __call__ does any work; eval, half, load_state_dict and to are
    no-op stand-ins.
    '''

    def __init__(self, ctx: ServerContext, model: str, provider='DmlExecutionProvider') -> None:
        '''
        Open an inference session for `model`, resolved under ctx.model_path,
        using a single execution provider.

        TODO: get platform provider from request params
        '''
        self.session = InferenceSession(
            path.join(ctx.model_path, model),
            providers=[provider],
        )

    def __call__(self, image: Any) -> Any:
        '''
        Run the session on `image` and wrap the result in an ONNXImage.

        Only the first declared input and the first declared output of the
        session are used. `image` must provide `.cpu().numpy()`.
        '''
        feed_key = self.session.get_inputs()[0].name
        fetch_key = self.session.get_outputs()[0].name
        results = self.session.run(
            [fetch_key],
            {feed_key: image.cpu().numpy()},
        )
        return ONNXImage(results[0])

    def eval(self) -> None:
        '''
        No-op.
        '''
        return None

    def half(self):
        '''
        No-op, returns this object unchanged.
        '''
        return self

    def load_state_dict(self, net, strict=True) -> None:
        '''
        No-op, both arguments are ignored.
        '''
        return None

    def to(self, device):
        '''
        No-op, `device` is ignored and this object is returned unchanged.
        '''
        return self
class UpscaleParams():
def __init__(
self,
@ -153,11 +92,7 @@ def upscale_resrgan(ctx: ServerContext, params: UpscaleParams, source_image: Ima
output = np.array(source_image)
upsampler = make_resrgan(ctx, params, tile=512)
if params.scale > 1:
output, _ = upsampler.enhance(output, outscale=params.outscale)
if params.faces:
output = upscale_gfpgan(ctx, params, output, upsampler=upsampler)
output, _ = upsampler.enhance(output, outscale=params.outscale)
output = Image.fromarray(output, 'RGB')
print('final output image size', output.size)
@ -190,13 +125,23 @@ def upscale_gfpgan(ctx: ServerContext, params: UpscaleParams, image, upsampler=N
return output
def upscale_stable_diffusion(ctx: ServerContext, params: UpscaleParams, image: Image) -> Image:
    '''
    Upscale `image` with the Stable Diffusion upscaling pipeline.

    The pipeline is loaded from params.upscale_model, run with an empty
    prompt, and the first image of its result is returned. `ctx` is accepted
    for signature parity but is not read here.

    NOTE(review): params.upscale_model is handed to from_pretrained as-is,
    without being joined to ctx.model_path - confirm that is intended.
    '''
    print('upscaling with Stable Diffusion')

    upscaler = OnnxStableDiffusionUpscalePipeline.from_pretrained(
        params.upscale_model)
    outputs = upscaler('', image=image)
    return outputs.images[0]
def run_upscale_pipeline(ctx: ServerContext, params: UpscaleParams, image: Image) -> Image:
    '''
    Run the optional upscaling and face-correction stages on `image`.

    When params.scale is greater than 1, the upscaler is chosen from the
    name in params.upscale_model: names containing 'esrgan' go through
    upscale_resrgan, names containing 'stable-diffusion' go through
    upscale_stable_diffusion. A name matching neither leaves the image
    unscaled. When params.faces is set, upscale_gfpgan is applied afterwards.

    Returns the processed image, or the input image unchanged when no stage
    applies.
    '''
    print('running upscale pipeline')

    if params.scale > 1:
        # Exactly one upscaler runs per request; the earlier unconditional
        # upscale_resrgan call would have scaled ESRGAN requests twice.
        if 'esrgan' in params.upscale_model:
            image = upscale_resrgan(ctx, params, image)
        elif 'stable-diffusion' in params.upscale_model:
            image = upscale_stable_diffusion(ctx, params, image)

    if params.faces:
        image = upscale_gfpgan(ctx, params, image)

    return image