feat(api): initial support for Stable Diffusion upscaling (#66)
This commit is contained in:
parent
483b8e3f19
commit
819af824b0
|
@ -31,6 +31,8 @@ base_models: Models = {
|
|||
('stable-diffusion-onnx-v2-1', 'stabilityai/stable-diffusion-2-1'),
|
||||
('stable-diffusion-onnx-v2-inpainting',
|
||||
'stabilityai/stable-diffusion-2-inpainting'),
|
||||
# should be upscaling with a different converter
|
||||
('upscaling-stable-diffusion-x4', 'stabilityai/stable-diffusion-x4-upscaler'),
|
||||
],
|
||||
'correction': [
|
||||
('correction-gfpgan-v1-3',
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
from .onnx_net import (
|
||||
ONNXImage,
|
||||
ONNXNet,
|
||||
)
|
||||
from .pipeline_onnx_stable_diffusion_upscale import (
|
||||
OnnxStableDiffusionUpscalePipeline,
|
||||
)
|
|
@ -0,0 +1,72 @@
|
|||
from onnxruntime import InferenceSession
|
||||
from os import path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..utils import (
|
||||
ServerContext,
|
||||
)
|
||||
|
||||
class ONNXImage():
    '''
    Duck-types the small slice of the torch.Tensor API that the upscaling
    code path touches, backed by a numpy array coming out of an ONNX
    session instead of a real tensor.
    '''

    def __init__(self, source) -> None:
        # raw numpy output of the ONNX session
        self.source = source
        # callers read `.data` off the model output; alias it back to self
        self.data = self

    def __getitem__(self, *args):
        # index into the backing array, hand back a float32 torch tensor
        item = self.source.__getitem__(*args)
        return torch.from_numpy(item).float()

    def squeeze(self):
        # drop the leading batch axis in place; fluent return
        self.source = np.squeeze(self.source, (0))
        return self

    def clamp_(self, min, max):
        # in-place clamp, mirroring torch.Tensor.clamp_
        self.source = np.clip(self.source, min, max)
        return self

    def numpy(self):
        # expose the backing array directly
        return self.source

    def size(self):
        # tuple shape, approximating torch.Size
        return np.shape(self.source)

    # --- no-op shims: numpy data is already float-typed host memory ---

    def float(self):
        return self

    def cpu(self):
        return self
|
||||
|
||||
|
||||
class ONNXNet():
    '''
    Provides the RRDBNet interface using an ONNX session for DirectML acceleration.

    Presents the handful of torch-module methods callers use, while all
    real work happens inside an onnxruntime InferenceSession.
    '''

    def __init__(self, ctx: ServerContext, model: str, provider='DmlExecutionProvider') -> None:
        '''
        Load the named model from the server's model directory.

        TODO: get platform provider from request params
        '''
        session_path = path.join(ctx.model_path, model)
        self.session = InferenceSession(session_path, providers=[provider])

    def __call__(self, image: Any) -> Any:
        '''Run one inference pass; image must expose cpu()/numpy() like a tensor.'''
        # bind the session's first input to the image data
        feed = {
            self.session.get_inputs()[0].name: image.cpu().numpy(),
        }
        # fetch only the first output and wrap it for torch-style callers
        fetches = [self.session.get_outputs()[0].name]
        result = self.session.run(fetches, feed)
        return ONNXImage(result[0])

    # --- torch.nn.Module shims: an ONNX session needs none of these ---

    def eval(self) -> None:
        pass

    def half(self):
        return self

    def load_state_dict(self, net, strict=True) -> None:
        pass

    def to(self, device):
        return self
|
|
@ -0,0 +1,58 @@
|
|||
from diffusers import (
|
||||
DDPMScheduler,
|
||||
OnnxRuntimeModel,
|
||||
StableDiffusionUpscalePipeline,
|
||||
)
|
||||
from diffusers.models import (
|
||||
CLIPTokenizer,
|
||||
)
|
||||
from diffusers.schedulers import (
|
||||
KarrasDiffusionSchedulers,
|
||||
)
|
||||
from typing import (
|
||||
Callable,
|
||||
Union,
|
||||
List,
|
||||
Optional,
|
||||
)
|
||||
|
||||
import PIL
|
||||
import torch
|
||||
|
||||
|
||||
class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
    # ONNX variant of diffusers' StableDiffusionUpscalePipeline: same
    # construction contract, but vae/text_encoder/unet arrive as
    # OnnxRuntimeModel wrappers rather than torch modules.
    #
    # NOTE(review): this file imports CLIPTokenizer from diffusers.models;
    # CLIPTokenizer is declared in transformers — confirm that import path
    # actually resolves.
    def __init__(
        self,
        vae: OnnxRuntimeModel,
        text_encoder: OnnxRuntimeModel,
        tokenizer: CLIPTokenizer,
        unet: OnnxRuntimeModel,
        low_res_scheduler: DDPMScheduler,
        scheduler: KarrasDiffusionSchedulers,
        max_noise_level: int = 350,
    ):
        # Forward everything positionally to the torch pipeline's __init__;
        # argument order here must match StableDiffusionUpscalePipeline's
        # signature exactly.
        super().__init__(vae, text_encoder, tokenizer, unet, low_res_scheduler, scheduler, max_noise_level)

    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        image: Union[torch.FloatTensor, PIL.Image.Image,
                     List[PIL.Image.Image]] = None,
        num_inference_steps: int = 75,
        guidance_scale: float = 9.0,
        noise_level: int = 20,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator,
                                  List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[
            int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
    ):
        # NOTE(review): unimplemented stub — returns None, so callers that
        # read `.images` off the result (see upscale_stable_diffusion) will
        # raise AttributeError until this override is written.
        pass
|
|
@ -1,14 +1,16 @@
|
|||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from gfpgan import GFPGANer
|
||||
from onnxruntime import InferenceSession
|
||||
from os import path
|
||||
from PIL import Image
|
||||
from realesrgan import RealESRGANer
|
||||
from typing import Any, Literal, Union
|
||||
from typing import Literal, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from .onnx import (
|
||||
ONNXNet,
|
||||
OnnxStableDiffusionUpscalePipeline,
|
||||
)
|
||||
from .utils import (
|
||||
ServerContext,
|
||||
Size,
|
||||
|
@ -19,69 +21,6 @@ pre_pad = 0
|
|||
tile_pad = 10
|
||||
|
||||
|
||||
class ONNXImage():
    # Minimal torch.Tensor look-alike over a numpy array, letting ONNX
    # session output flow through code written against tensors.
    # (This copy in upscale.py duplicates onnx/onnx_net.py; the diff
    # removes it in favor of the module import.)
    def __init__(self, source) -> None:
        # raw numpy output of the ONNX session
        self.source = source
        # callers read `.data` off the model output; alias it back to self
        self.data = self

    def __getitem__(self, *args):
        # index the backing array, return a float32 torch tensor
        return torch.from_numpy(self.source.__getitem__(*args)).to(torch.float32)

    def squeeze(self):
        # drop the leading batch axis in place; fluent return
        self.source = np.squeeze(self.source, (0))
        return self

    def float(self):
        # no-op shim: data is already float-typed numpy
        return self

    def cpu(self):
        # no-op shim: numpy arrays always live in host memory
        return self

    def clamp_(self, min, max):
        # in-place clamp, mirroring torch.Tensor.clamp_
        self.source = np.clip(self.source, min, max)
        return self

    def numpy(self):
        # expose the backing array directly
        return self.source

    def size(self):
        # tuple shape, approximating torch.Size
        return np.shape(self.source)
||||
|
||||
|
||||
class ONNXNet():
    '''
    Provides the RRDBNet interface using an ONNX session for DirectML acceleration.
    '''
    # (This copy in upscale.py duplicates onnx/onnx_net.py; the diff
    # removes it in favor of the module import.)

    def __init__(self, ctx: ServerContext, model: str, provider='DmlExecutionProvider') -> None:
        '''
        Load the named model from the server's model directory.

        TODO: get platform provider from request params
        '''
        model_path = path.join(ctx.model_path, model)
        self.session = InferenceSession(
            model_path, providers=[provider])

    def __call__(self, image: Any) -> Any:
        # single inference pass: feed the tensor-like image as the session's
        # first input, wrap the first output for torch-style callers
        input_name = self.session.get_inputs()[0].name
        output_name = self.session.get_outputs()[0].name
        output = self.session.run([output_name], {
            input_name: image.cpu().numpy()
        })[0]
        return ONNXImage(output)

    # the methods below are no-op shims for the torch-module API that
    # callers expect; an ONNX session needs none of them

    def eval(self) -> None:
        pass

    def half(self):
        return self

    def load_state_dict(self, net, strict=True) -> None:
        pass

    def to(self, device):
        return self
|
||||
|
||||
|
||||
class UpscaleParams():
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -153,11 +92,7 @@ def upscale_resrgan(ctx: ServerContext, params: UpscaleParams, source_image: Ima
|
|||
output = np.array(source_image)
|
||||
upsampler = make_resrgan(ctx, params, tile=512)
|
||||
|
||||
if params.scale > 1:
|
||||
output, _ = upsampler.enhance(output, outscale=params.outscale)
|
||||
|
||||
if params.faces:
|
||||
output = upscale_gfpgan(ctx, params, output, upsampler=upsampler)
|
||||
output, _ = upsampler.enhance(output, outscale=params.outscale)
|
||||
|
||||
output = Image.fromarray(output, 'RGB')
|
||||
print('final output image size', output.size)
|
||||
|
@ -190,13 +125,23 @@ def upscale_gfpgan(ctx: ServerContext, params: UpscaleParams, image, upsampler=N
|
|||
return output
|
||||
|
||||
|
||||
def upscale_stable_diffusion(ctx: ServerContext, params: UpscaleParams, image: Image) -> Image:
    '''Upscale an image through the ONNX Stable Diffusion x4 pipeline.'''
    print('upscaling with Stable Diffusion')
    # NOTE(review): ctx is unused here — ONNXNet joins ctx.model_path with the
    # model name before loading; presumably from_pretrained needs the same
    # treatment. Verify whether upscale_model is a hub id or a local name.
    pipe = OnnxStableDiffusionUpscalePipeline.from_pretrained(params.upscale_model)
    output = pipe('', image=image)
    return output.images[0]
|
||||
|
||||
|
||||
def run_upscale_pipeline(ctx: ServerContext, params: UpscaleParams, image: Image) -> Image:
    # Dispatch to the upscaling backend named by params.upscale_model, then
    # optionally run GFPGAN face correction on the result.
    print('running upscale pipeline')

    # backend chosen by substring match on the model name; any other name
    # falls through with the image unchanged
    if 'esrgan' in params.upscale_model:
        image = upscale_resrgan(ctx, params, image)
    elif 'stable-diffusion' in params.upscale_model:
        image = upscale_stable_diffusion(ctx, params, image)

    # face correction runs after upscaling, on the upscaled image
    if params.faces:
        image = upscale_gfpgan(ctx, params, image)

    return image
|
||||
|
|
Loading…
Reference in New Issue