
fix(api): split up and cache upscaling and correction stages

This commit is contained in:
Sean Sube 2023-01-27 23:28:14 -06:00
parent caafc9ebc9
commit 56a4acee2a
15 changed files with 371 additions and 209 deletions
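Every stage module touched by this commit uses the same module-level cache, keyed by the parameters used to load the pipeline. A minimal sketch of that pattern, with ThingPipeline and load_thing as placeholder names rather than anything from this commit:

class ThingPipeline:
    def __init__(self, model_path: str):
        self.model_path = model_path

last_pipeline_instance = None
last_pipeline_params = None

def load_thing(model_path: str) -> ThingPipeline:
    global last_pipeline_instance
    global last_pipeline_params
    # reuse the previous instance when the load parameters have not changed
    if last_pipeline_instance is not None and model_path == last_pipeline_params:
        return last_pipeline_instance
    pipeline = ThingPipeline(model_path)
    last_pipeline_instance = pipeline
    last_pipeline_params = model_path
    return pipeline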

View File

@@ -1,3 +1,8 @@
from .chain import (
    correct_gfpgan,
    upscale_resrgan,
    upscale_stable_diffusion,
)
from .diffusion import (
    get_latents_from_seed,
    load_pipeline,
@@ -18,17 +23,16 @@ from .image import (
    noise_source_uniform,
)
from .params import (
    UpscaleParams,
    ImageParams,
    Border,
    Param,
    Point,
    Border,
    Size,
    ImageParams,
    StageParams,
    UpscaleParams,
)
from .upscale import (
    load_resrgan,
    run_upscale_correction,
    correct_gfpgan,
    upscale_resrgan,
)
from .utils import (
    get_and_clamp_float,

View File

@@ -4,3 +4,12 @@ from .base import (
    StageCallback,
    StageParams,
)
from .correct_gfpgan import (
    correct_gfpgan,
)
from .upscale_resrgan import (
    upscale_resrgan,
)
from .upscale_stable_diffusion import (
    upscale_stable_diffusion,
)

View File

@@ -1,15 +1,12 @@
from PIL import Image
from os import path
from typing import Any, List, Optional, Protocol, Tuple
from typing import Any, Callable, List, Optional, Protocol, Tuple
from ..image import (
    process_tiles,
)
from ..params import (
    ImageParams,
    StageParams,
)
from ..utils import (
    ImageParams,
    ServerContext,
)
@@ -29,6 +26,35 @@ class StageCallback(Protocol):
PipelineStage = Tuple[StageCallback, StageParams, Optional[dict]]
def process_tiles(
    source: Image,
    tile: int,
    scale: int,
    filters: List[Callable],
) -> Image:
    width, height = source.size
    image = Image.new('RGB', (width * scale, height * scale))
    tiles_x = width // tile
    tiles_y = height // tile
    total = tiles_x * tiles_y
    for y in range(tiles_y):
        for x in range(tiles_x):
            idx = (y * tiles_x) + x
            left = x * tile
            top = y * tile
            print('processing tile %s of %s, %s.%s' % (idx, total, y, x))
            tile_image = source.crop((left, top, left + tile, top + tile))
            for filter in filters:
                tile_image = filter(tile_image)
            image.paste(tile_image, (left * scale, top * scale))
    return image
class ChainPipeline:
    '''
    Run many stages in series, passing the image results from each to the next, and processing
@@ -59,7 +85,7 @@ class ChainPipeline:
        image = source
        for stage_pipe, stage_params, stage_kwargs in self.stages:
            name = stage_params.label or stage_pipe.__name__
            name = stage_params.name or stage_pipe.__name__
            kwargs = stage_kwargs or {}
            print('running pipeline stage %s on result image with dimensions %sx%s' %
                  (name, image.width, image.height))
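A ChainPipeline is assembled from the PipelineStage tuples defined above. A hedged usage sketch, assuming the constructor takes the stage list and that the pipeline is invoked with the server context, image parameters, and source image (none of which appear in this hunk; ctx, params, upscale, and source are stand-ins):

# sketch only: the ChainPipeline constructor and call signature are outside this hunk
stages = [
    (upscale_resrgan, StageParams(), {'upscale': upscale}),
    (correct_gfpgan, StageParams(), {'upscale': upscale}),
]
chain = ChainPipeline(stages=stages)
# each stage receives the previous stage's output image
result = chain(ctx, params, source)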

View File

@@ -0,0 +1,68 @@
from gfpgan import GFPGANer
from os import path
from PIL import Image
from realesrgan import RealESRGANer
from typing import Optional
from ..params import (
    ImageParams,
    StageParams,
    UpscaleParams,
)
from ..utils import (
    ServerContext,
)
from .upscale_resrgan import (
    load_resrgan,
)
last_pipeline_instance = None
last_pipeline_params = None
def load_gfpgan(ctx: ServerContext, upscale: UpscaleParams, upsampler: Optional[RealESRGANer] = None):
    global last_pipeline_instance
    global last_pipeline_params
    if upsampler is None:
        upsampler = load_resrgan(ctx, upscale)
    face_path = path.join(ctx.model_path, '%s.pth' %
                          (upscale.correction_model))
    if last_pipeline_instance is not None and face_path == last_pipeline_params:
        print('reusing existing GFPGAN pipeline')
        return last_pipeline_instance
    # TODO: doesn't have a model param, not sure how to pass ONNX model
    gfpgan = GFPGANer(
        model_path=face_path,
        upscale=upscale.outscale,
        arch='clean',
        channel_multiplier=2,
        bg_upsampler=upsampler)
    last_pipeline_instance = gfpgan
    last_pipeline_params = face_path
    return gfpgan
def correct_gfpgan(
    ctx: ServerContext,
    _stage: StageParams,
    _params: ImageParams,
    image: Image.Image,
    *,
    upscale: UpscaleParams,
    upsampler: Optional[RealESRGANer] = None,
) -> Image:
    if upscale.correction_model is None:
        print('no face model given, skipping')
        return image
    print('correcting faces with GFPGAN model: %s' % upscale.correction_model)
    gfpgan = load_gfpgan(ctx, upscale, upsampler=upsampler)
    _, _, output = gfpgan.enhance(
        image, has_aligned=False, only_center_face=False, paste_back=True, weight=upscale.face_strength)
    return output
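Since correct_gfpgan accepts an optional upsampler, a caller that has already loaded a Real ESRGAN pipeline can share it instead of loading a second copy. A sketch, with ctx, stage, params, image, and upscale standing in for real values from the surrounding pipeline:

# sketch: reuse one Real ESRGAN instance as the GFPGAN background upsampler
upsampler = load_resrgan(ctx, upscale, tile=stage.tile_size)
corrected = correct_gfpgan(
    ctx, stage, params, image,
    upscale=upscale,
    upsampler=upsampler)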

View File

@@ -0,0 +1,84 @@
from basicsr.archs.rrdbnet_arch import RRDBNet
from os import path
from PIL import Image
from realesrgan import RealESRGANer
from ..onnx import (
    OnnxNet,
)
from ..params import (
    ImageParams,
    StageParams,
    UpscaleParams,
)
from ..utils import (
    ServerContext,
)
import numpy as np
last_pipeline_instance = None
last_pipeline_params = (None, None)
def load_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
    global last_pipeline_instance
    global last_pipeline_params
    model_file = '%s.%s' % (params.upscale_model, params.format)
    model_path = path.join(ctx.model_path, model_file)
    if not path.isfile(model_path):
        raise Exception('Real ESRGAN model not found at %s' % model_path)
    if last_pipeline_instance is not None and (model_path, params.format) == last_pipeline_params:
        print('reusing existing Real ESRGAN pipeline')
        return last_pipeline_instance
    # use ONNX acceleration, if available
    if params.format == 'onnx':
        model = OnnxNet(ctx, model_file, provider=params.provider)
    elif params.format == 'pth':
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
                        num_block=23, num_grow_ch=32, scale=params.scale)
    else:
        raise Exception('unknown platform %s' % params.format)
    dni_weight = None
    if params.upscale_model == 'realesr-general-x4v3' and params.denoise != 1:
        wdn_model_path = model_path.replace(
            'realesr-general-x4v3', 'realesr-general-wdn-x4v3')
        model_path = [model_path, wdn_model_path]
        dni_weight = [params.denoise, 1 - params.denoise]
    # TODO: shouldn't need the PTH file
    upsampler = RealESRGANer(
        scale=params.scale,
        model_path=path.join(ctx.model_path, '%s.pth' % params.upscale_model),
        dni_weight=dni_weight,
        model=model,
        tile=tile,
        tile_pad=params.tile_pad,
        pre_pad=params.pre_pad,
        half=params.half)
    last_pipeline_instance = upsampler
    last_pipeline_params = (model_path, params.format)
    return upsampler
def upscale_resrgan(
    ctx: ServerContext,
    stage: StageParams,
    _params: ImageParams,
    source_image: Image.Image,
    *,
    upscale: UpscaleParams,
) -> Image:
    print('upscaling image with Real ESRGAN', upscale.scale)
    output = np.array(source_image)
    upsampler = load_resrgan(ctx, upscale, tile=stage.tile_size)
    output, _ = upsampler.enhance(output, outscale=upscale.outscale)
    output = Image.fromarray(output, 'RGB')
    print('final output image size', output.size)
    return output
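The denoise handling in load_resrgan uses Real ESRGAN's deep network interpolation (DNI): when denoise is not 1, the general model is blended with its -wdn (weak denoise) variant. A small worked example of the weights as computed above:

# worked example of the DNI weights from load_resrgan
denoise = 0.3
dni_weight = [denoise, 1 - denoise]  # [0.3, 0.7]
# realesr-general-x4v3 contributes 0.3 of the blend,
# realesr-general-wdn-x4v3 the remaining 0.7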

View File

@@ -0,0 +1,72 @@
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    StableDiffusionUpscalePipeline,
)
from os import path
from PIL import Image
from ..onnx import (
    OnnxStableDiffusionUpscalePipeline,
)
from ..params import (
    ImageParams,
    StageParams,
    UpscaleParams,
)
from ..utils import (
    ServerContext,
)
import torch
last_pipeline_instance = None
last_pipeline_params = (None, None)
def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams):
    global last_pipeline_instance
    global last_pipeline_params
    model_path = path.join(ctx.model_path, upscale.upscale_model)
    if last_pipeline_instance is not None and (model_path, upscale.format) == last_pipeline_params:
        print('reusing existing Stable Diffusion upscale pipeline')
        return last_pipeline_instance
    if upscale.format == 'onnx':
        # ValueError: Pipeline <class 'onnx_web.onnx.pipeline_onnx_stable_diffusion_upscale.OnnxStableDiffusionUpscalePipeline'>
        # expected {'vae', 'unet', 'text_encoder', 'tokenizer', 'scheduler', 'low_res_scheduler'},
        # but only {'scheduler', 'tokenizer', 'text_encoder', 'unet'} were passed.
        pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(
            model_path,
            vae=AutoencoderKL.from_pretrained(
                model_path, subfolder='vae_encoder'),
            low_res_scheduler=DDPMScheduler.from_pretrained(
                model_path, subfolder='scheduler'),
        )
    else:
        pipeline = StableDiffusionUpscalePipeline.from_pretrained(
            'stabilityai/stable-diffusion-x4-upscaler')
    last_pipeline_instance = pipeline
    last_pipeline_params = (model_path, upscale.format)
    return pipeline
def upscale_stable_diffusion(
    ctx: ServerContext,
    _stage: StageParams,
    params: ImageParams,
    source: Image.Image,
    *,
    upscale: UpscaleParams,
) -> Image.Image:
    print('upscaling with Stable Diffusion')
    pipeline = load_stable_diffusion(ctx, upscale)
    generator = torch.manual_seed(params.seed)
    seed = generator.initial_seed()
    return pipeline(
        params.prompt,
        source,
        generator=torch.manual_seed(seed),
        num_inference_steps=params.steps,
    ).images[0]
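A note on the seed roundtrip above: torch.manual_seed returns a generator whose initial_seed() is exactly the seed it was given, so the generator handed to the pipeline reproduces the request's seed. A quick check of that equivalence:

import torch

# sketch: the roundtrip in upscale_stable_diffusion is equivalent to seeding once
generator = torch.manual_seed(42)
assert generator.initial_seed() == 42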

View File

@@ -8,10 +8,6 @@ from diffusers import (
from PIL import Image, ImageChops
from typing import Any, Optional
import gc
import numpy as np
import torch
from .chain import (
    StageParams,
)
@@ -33,6 +29,10 @@ from .utils import (
    ServerContext,
)
import gc
import numpy as np
import torch
last_pipeline_instance = None
last_pipeline_options = (None, None, None)
last_pipeline_scheduler = None

View File

@@ -185,32 +185,3 @@ def expand_image(
        full_noise, full_source, full_mask.convert('L'))
    return (full_source, full_mask, full_noise, (full_width, full_height))
def process_tiles(
    source: Image,
    tile: int,
    scale: int,
    filters: List[Callable],
) -> Image:
    width, height = source.size
    image = Image.new('RGB', (width * scale, height * scale))
    tiles_x = width // tile
    tiles_y = height // tile
    total = tiles_x * tiles_y
    for y in range(tiles_y):
        for x in range(tiles_x):
            idx = (y * tiles_x) + x
            left = x * tile
            top = y * tile
            print('processing tile %s of %s, %s.%s' % (idx, total, y, x))
            tile_image = source.crop((left, top, left + tile, top + tile))
            for filter in filters:
                tile_image = filter(tile_image)
            image.paste(tile_image, (left * scale, top * scale))
    return image

View File

@@ -1,6 +1,6 @@
from .onnx_net import (
    ONNXImage,
    ONNXNet,
    OnnxImage,
    OnnxNet,
)
from .pipeline_onnx_stable_diffusion_upscale import (
    OnnxStableDiffusionUpscalePipeline,

View File

@@ -9,7 +9,7 @@ from ..utils import (
    ServerContext,
)
class ONNXImage():
class OnnxImage():
    def __init__(self, source) -> None:
        self.source = source
        self.data = self
@@ -38,7 +38,7 @@ class ONNXImage():
        return np.shape(self.source)
class ONNXNet():
class OnnxNet():
    '''
    Provides the RRDBNet interface using an ONNX session for DirectML acceleration.
    '''
@@ -57,7 +57,7 @@ class ONNXNet():
        output = self.session.run([output_name], {
            input_name: image.cpu().numpy()
        })[0]
        return ONNXImage(output)
        return OnnxImage(output)
    def eval(self) -> None:
        pass

View File

@@ -35,3 +35,76 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
        **kwargs,
    ):
        return super().__call__(*args, **kwargs)
    def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
        batch_size = len(prompt) if isinstance(prompt, list) else 1
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )
        text_embeddings = self.text_encoder(input_ids=text_input_ids.int().to(device))
        text_embeddings = text_embeddings[0]
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt)  # , 1)
        # text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt
            max_length = text_input_ids.shape[-1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )
            uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.int().to(device))
            uncond_embeddings = uncond_embeddings[0]
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = uncond_embeddings.shape[1]
            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt)  # , 1)
            # uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
        return text_embeddings
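The concatenated embeddings returned here drive classifier-free guidance later in the pipeline: both halves go through the UNet in a single batch, and the two predictions are recombined. A sketch of that recombination, using stand-in tensors and a hypothetical guidance_scale (neither appears in this file):

import torch

# stand-in for a UNet output over the doubled batch
noise_pred = torch.randn(2, 4, 64, 64)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
guidance_scale = 7.5  # hypothetical value
# move the prediction away from the unconditional result, toward the prompt
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)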

View File

@@ -45,12 +45,12 @@ from .params import (
    Border,
    ImageParams,
    Size,
    UpscaleParams,
)
from .upscale import (
    correct_gfpgan,
    upscale_resrgan,
    upscale_stable_diffusion,
    UpscaleParams,
)
from .utils import (
    is_debug,

View File

@@ -1,29 +1,14 @@
from basicsr.archs.rrdbnet_arch import RRDBNet
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    StableDiffusionUpscalePipeline,
)
from gfpgan import GFPGANer
from os import path
from PIL import Image
from realesrgan import RealESRGANer
from typing import Optional
import numpy as np
import torch
from .chain import (
    correct_gfpgan,
    upscale_stable_diffusion,
    upscale_resrgan,
    ChainPipeline,
    StageParams,
)
from .onnx import (
    ONNXNet,
    OnnxStableDiffusionUpscalePipeline,
)
from .params import (
    ImageParams,
    Size,
    StageParams,
    UpscaleParams,
)
from .utils import (
@@ -31,144 +16,6 @@ from .utils import (
)
def load_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
    '''
    TODO: cache this instance
    '''
    model_file = '%s.%s' % (params.upscale_model, params.format)
    model_path = path.join(ctx.model_path, model_file)
    if not path.isfile(model_path):
        raise Exception('Real ESRGAN model not found at %s' % model_path)
    # use ONNX acceleration, if available
    if params.format == 'onnx':
        model = ONNXNet(ctx, model_file, provider=params.provider)
    elif params.format == 'pth':
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
                        num_block=23, num_grow_ch=32, scale=params.scale)
    else:
        raise Exception('unknown platform %s' % params.format)
    dni_weight = None
    if params.upscale_model == 'realesr-general-x4v3' and params.denoise != 1:
        wdn_model_path = model_path.replace(
            'realesr-general-x4v3', 'realesr-general-wdn-x4v3')
        model_path = [model_path, wdn_model_path]
        dni_weight = [params.denoise, 1 - params.denoise]
    # TODO: shouldn't need the PTH file
    upsampler = RealESRGANer(
        scale=params.scale,
        model_path=path.join(ctx.model_path, '%s.pth' % params.upscale_model),
        dni_weight=dni_weight,
        model=model,
        tile=tile,
        tile_pad=params.tile_pad,
        pre_pad=params.pre_pad,
        half=params.half)
    return upsampler
def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams):
    '''
    TODO: cache this instance
    '''
    if upscale.format == 'onnx':
        model_path = path.join(ctx.model_path, upscale.upscale_model)
        # ValueError: Pipeline <class 'onnx_web.onnx.pipeline_onnx_stable_diffusion_upscale.OnnxStableDiffusionUpscalePipeline'>
        # expected {'vae', 'unet', 'text_encoder', 'tokenizer', 'scheduler', 'low_res_scheduler'},
        # but only {'scheduler', 'tokenizer', 'text_encoder', 'unet'} were passed.
        pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(
            model_path,
            vae=AutoencoderKL.from_pretrained(
                model_path, subfolder='vae_encoder'),
            low_res_scheduler=DDPMScheduler.from_pretrained(
                model_path, subfolder='scheduler'),
        )
    else:
        pipeline = StableDiffusionUpscalePipeline.from_pretrained(
            'stabilityai/stable-diffusion-x4-upscaler')
    return pipeline
def upscale_resrgan(
    ctx: ServerContext,
    stage: StageParams,
    _params: ImageParams,
    source_image: Image.Image,
    *,
    upscale: UpscaleParams,
) -> Image:
    print('upscaling image with Real ESRGAN', upscale.scale)
    output = np.array(source_image)
    upsampler = load_resrgan(ctx, upscale, tile=stage.tile_size)
    output, _ = upsampler.enhance(output, outscale=upscale.outscale)
    output = Image.fromarray(output, 'RGB')
    print('final output image size', output.size)
    return output
def correct_gfpgan(
    ctx: ServerContext,
    _stage: StageParams,
    _params: ImageParams,
    image: Image.Image,
    *,
    upscale: UpscaleParams,
    upsampler: Optional[RealESRGANer] = None,
) -> Image:
    if upscale.correction_model is None:
        print('no face model given, skipping')
        return image
    print('correcting faces with GFPGAN model: %s' % upscale.correction_model)
    if upsampler is None:
        upsampler = load_resrgan(ctx, upscale)
    face_path = path.join(ctx.model_path, '%s.pth' %
                          (upscale.correction_model))
    # TODO: doesn't have a model param, not sure how to pass ONNX model
    face_enhancer = GFPGANer(
        model_path=face_path,
        upscale=upscale.outscale,
        arch='clean',
        channel_multiplier=2,
        bg_upsampler=upsampler)
    _, _, output = face_enhancer.enhance(
        image, has_aligned=False, only_center_face=False, paste_back=True, weight=upscale.face_strength)
    return output
def upscale_stable_diffusion(
    ctx: ServerContext,
    _stage: StageParams,
    params: ImageParams,
    source: Image.Image,
    *,
    upscale: UpscaleParams,
) -> Image:
    print('upscaling with Stable Diffusion')
    pipeline = load_stable_diffusion(ctx, upscale)
    generator = torch.manual_seed(params.seed)
    seed = generator.initial_seed()
    return pipeline(
        params.prompt,
        source,
        generator=torch.manual_seed(seed),
        num_inference_steps=params.steps,
    ).images[0]
def run_upscale_correction(
    ctx: ServerContext,
    stage: StageParams,

View File

@@ -1,7 +1,7 @@
from os import environ, path
from time import time
from struct import pack
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple
from hashlib import sha256
from .params import (

View File

@@ -12,6 +12,10 @@
    "max": 30,
    "step": 0.1
  },
  "correction": {
    "default": "",
    "keys": []
  },
  "denoise": {
    "default": 0.5,
    "min": 0,
@@ -110,6 +114,10 @@
    "max": 512,
    "step": 8
  },
  "upscaling": {
    "default": "",
    "keys": []
  },
  "width": {
    "default": 512,
    "min": 64,