fix(api): split up and cache upscaling and correction stages
This commit is contained in:
parent
caafc9ebc9
commit
56a4acee2a
|
@ -1,3 +1,8 @@
|
|||
from .chain import (
|
||||
correct_gfpgan,
|
||||
upscale_resrgan,
|
||||
upscale_stable_diffusion,
|
||||
)
|
||||
from .diffusion import (
|
||||
get_latents_from_seed,
|
||||
load_pipeline,
|
||||
|
@ -18,17 +23,16 @@ from .image import (
|
|||
noise_source_uniform,
|
||||
)
|
||||
from .params import (
|
||||
UpscaleParams,
|
||||
ImageParams,
|
||||
Border,
|
||||
Param,
|
||||
Point,
|
||||
Border,
|
||||
Size,
|
||||
ImageParams,
|
||||
StageParams,
|
||||
UpscaleParams,
|
||||
)
|
||||
from .upscale import (
|
||||
load_resrgan,
|
||||
run_upscale_correction,
|
||||
correct_gfpgan,
|
||||
upscale_resrgan,
|
||||
)
|
||||
from .utils import (
|
||||
get_and_clamp_float,
|
||||
|
|
|
@ -4,3 +4,12 @@ from .base import (
|
|||
StageCallback,
|
||||
StageParams,
|
||||
)
|
||||
from .correct_gfpgan import (
|
||||
correct_gfpgan,
|
||||
)
|
||||
from .upscale_resrgan import (
|
||||
upscale_resrgan,
|
||||
)
|
||||
from .upscale_stable_diffusion import (
|
||||
upscale_stable_diffusion,
|
||||
)
|
|
@ -1,15 +1,12 @@
|
|||
from PIL import Image
|
||||
from os import path
|
||||
from typing import Any, List, Optional, Protocol, Tuple
|
||||
from typing import Any, Callable, List, Optional, Protocol, Tuple
|
||||
|
||||
from ..image import (
|
||||
process_tiles,
|
||||
)
|
||||
from ..params import (
|
||||
ImageParams,
|
||||
StageParams,
|
||||
)
|
||||
from ..utils import (
|
||||
ImageParams,
|
||||
ServerContext,
|
||||
)
|
||||
|
||||
|
@ -29,6 +26,35 @@ class StageCallback(Protocol):
|
|||
PipelineStage = Tuple[StageCallback, StageParams, Optional[dict]]
|
||||
|
||||
|
||||
def process_tiles(
|
||||
source: Image,
|
||||
tile: int,
|
||||
scale: int,
|
||||
filters: List[Callable],
|
||||
) -> Image:
|
||||
width, height = source.size
|
||||
image = Image.new('RGB', (width * scale, height * scale))
|
||||
|
||||
tiles_x = width // tile
|
||||
tiles_y = height // tile
|
||||
total = tiles_x * tiles_y
|
||||
|
||||
for y in range(tiles_y):
|
||||
for x in range(tiles_x):
|
||||
idx = (y * tiles_x) + x
|
||||
left = x * tile
|
||||
top = y * tile
|
||||
print('processing tile %s of %s, %s.%s' % (idx, total, y, x))
|
||||
tile_image = source.crop((left, top, left + tile, top + tile))
|
||||
|
||||
for filter in filters:
|
||||
tile_image = filter(tile_image)
|
||||
|
||||
image.paste(tile_image, (left * scale, top * scale))
|
||||
|
||||
return image
|
||||
|
||||
|
||||
class ChainPipeline:
|
||||
'''
|
||||
Run many stages in series, passing the image results from each to the next, and processing
|
||||
|
@ -59,7 +85,7 @@ class ChainPipeline:
|
|||
image = source
|
||||
|
||||
for stage_pipe, stage_params, stage_kwargs in self.stages:
|
||||
name = stage_params.label or stage_pipe.__name__
|
||||
name = stage_params.name or stage_pipe.__name__
|
||||
kwargs = stage_kwargs or {}
|
||||
print('running pipeline stage %s on result image with dimensions %sx%s' %
|
||||
(name, image.width, image.height))
|
||||
|
|
|
@ -0,0 +1,68 @@
|
|||
from gfpgan import GFPGANer
|
||||
from os import path
|
||||
from PIL import Image
|
||||
from realesrgan import RealESRGANer
|
||||
from typing import Optional
|
||||
|
||||
from ..params import (
|
||||
ImageParams,
|
||||
StageParams,
|
||||
UpscaleParams,
|
||||
)
|
||||
from ..utils import (
|
||||
ServerContext,
|
||||
)
|
||||
from .upscale_resrgan import (
|
||||
load_resrgan,
|
||||
)
|
||||
|
||||
|
||||
last_pipeline_instance = None
|
||||
last_pipeline_params = None
|
||||
|
||||
|
||||
def load_gfpgan(ctx: ServerContext, upscale: UpscaleParams):
|
||||
if upsampler is None:
|
||||
upsampler = load_resrgan(ctx, upscale)
|
||||
|
||||
face_path = path.join(ctx.model_path, '%s.pth' %
|
||||
(upscale.correction_model))
|
||||
|
||||
if last_pipeline_instance != None and face_path == last_pipeline_params:
|
||||
print('reusing existing GFPGAN pipeline')
|
||||
return last_pipeline_instance
|
||||
|
||||
# TODO: doesn't have a model param, not sure how to pass ONNX model
|
||||
gfpgan = GFPGANer(
|
||||
model_path=face_path,
|
||||
upscale=upscale.outscale,
|
||||
arch='clean',
|
||||
channel_multiplier=2,
|
||||
bg_upsampler=upsampler)
|
||||
|
||||
last_pipeline_instance = gfpgan
|
||||
last_pipeline_params = face_path
|
||||
|
||||
return gfpgan
|
||||
|
||||
|
||||
def correct_gfpgan(
|
||||
ctx: ServerContext,
|
||||
_stage: StageParams,
|
||||
_params: ImageParams,
|
||||
image: Image.Image,
|
||||
*,
|
||||
upscale: UpscaleParams,
|
||||
upsampler: Optional[RealESRGANer] = None,
|
||||
) -> Image:
|
||||
if upscale.correction_model is None:
|
||||
print('no face model given, skipping')
|
||||
return image
|
||||
|
||||
print('correcting faces with GFPGAN model: %s' % upscale.correction_model)
|
||||
gfpgan = load_gfpgan(ctx, upscale)
|
||||
|
||||
_, _, output = gfpgan.enhance(
|
||||
image, has_aligned=False, only_center_face=False, paste_back=True, weight=upscale.face_strength)
|
||||
|
||||
return output
|
|
@ -0,0 +1,84 @@
|
|||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from os import path
|
||||
from PIL import Image
|
||||
from realesrgan import RealESRGANer
|
||||
|
||||
from ..onnx import (
|
||||
OnnxNet,
|
||||
)
|
||||
from ..params import (
|
||||
ImageParams,
|
||||
StageParams,
|
||||
UpscaleParams,
|
||||
)
|
||||
from ..utils import (
|
||||
ServerContext,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
last_pipeline_instance = None
|
||||
last_pipeline_params = (None, None)
|
||||
|
||||
|
||||
def load_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
|
||||
model_file = '%s.%s' % (params.upscale_model, params.format)
|
||||
model_path = path.join(ctx.model_path, model_file)
|
||||
if not path.isfile(model_path):
|
||||
raise Exception('Real ESRGAN model not found at %s' % model_path)
|
||||
|
||||
if last_pipeline_instance != None and (model_path, params.format) == last_pipeline_params:
|
||||
print('reusing existing Real ESRGAN pipeline')
|
||||
return last_pipeline_instance
|
||||
|
||||
# use ONNX acceleration, if available
|
||||
if params.format == 'onnx':
|
||||
model = OnnxNet(ctx, model_file, provider=params.provider)
|
||||
elif params.format == 'pth':
|
||||
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
|
||||
num_block=23, num_grow_ch=32, scale=params.scale)
|
||||
raise Exception('unknown platform %s' % params.format)
|
||||
|
||||
dni_weight = None
|
||||
if params.upscale_model == 'realesr-general-x4v3' and params.denoise != 1:
|
||||
wdn_model_path = model_path.replace(
|
||||
'realesr-general-x4v3', 'realesr-general-wdn-x4v3')
|
||||
model_path = [model_path, wdn_model_path]
|
||||
dni_weight = [params.denoise, 1 - params.denoise]
|
||||
|
||||
# TODO: shouldn't need the PTH file
|
||||
upsampler = RealESRGANer(
|
||||
scale=params.scale,
|
||||
model_path=path.join(ctx.model_path, '%s.pth' % params.upscale_model),
|
||||
dni_weight=dni_weight,
|
||||
model=model,
|
||||
tile=tile,
|
||||
tile_pad=params.tile_pad,
|
||||
pre_pad=params.pre_pad,
|
||||
half=params.half)
|
||||
|
||||
last_pipeline_instance = upsampler
|
||||
last_pipeline_params = (model_path, params.format)
|
||||
|
||||
return upsampler
|
||||
|
||||
|
||||
def upscale_resrgan(
|
||||
ctx: ServerContext,
|
||||
stage: StageParams,
|
||||
_params: ImageParams,
|
||||
source_image: Image.Image,
|
||||
*,
|
||||
upscale: UpscaleParams,
|
||||
) -> Image:
|
||||
print('upscaling image with Real ESRGAN', upscale.scale)
|
||||
|
||||
output = np.array(source_image)
|
||||
upsampler = load_resrgan(ctx, upscale, tile=stage.tile_size)
|
||||
|
||||
output, _ = upsampler.enhance(output, outscale=upscale.outscale)
|
||||
|
||||
output = Image.fromarray(output, 'RGB')
|
||||
print('final output image size', output.size)
|
||||
return output
|
|
@ -0,0 +1,72 @@
|
|||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDPMScheduler,
|
||||
StableDiffusionUpscalePipeline,
|
||||
)
|
||||
from os import path
|
||||
from PIL import Image
|
||||
|
||||
from ..onnx import (
|
||||
OnnxStableDiffusionUpscalePipeline,
|
||||
)
|
||||
from ..params import (
|
||||
ImageParams,
|
||||
StageParams,
|
||||
UpscaleParams,
|
||||
)
|
||||
from ..utils import (
|
||||
ServerContext,
|
||||
)
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
last_pipeline_instance = None
|
||||
last_pipeline_params = (None, None)
|
||||
|
||||
|
||||
def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams):
|
||||
model_path = path.join(ctx.model_path, upscale.upscale_model)
|
||||
|
||||
if last_pipeline_instance != None and (model_path, upscale.format) == last_pipeline_params:
|
||||
print('reusing existing Stable Diffusion upscale pipeline')
|
||||
return last_pipeline_instance
|
||||
|
||||
if upscale.format == 'onnx':
|
||||
# ValueError: Pipeline <class 'onnx_web.onnx.pipeline_onnx_stable_diffusion_upscale.OnnxStableDiffusionUpscalePipeline'>
|
||||
# expected {'vae', 'unet', 'text_encoder', 'tokenizer', 'scheduler', 'low_res_scheduler'},
|
||||
# but only {'scheduler', 'tokenizer', 'text_encoder', 'unet'} were passed.
|
||||
pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(
|
||||
model_path,
|
||||
vae=AutoencoderKL.from_pretrained(
|
||||
model_path, subfolder='vae_encoder'),
|
||||
low_res_scheduler=DDPMScheduler.from_pretrained(
|
||||
model_path, subfolder='scheduler'),
|
||||
)
|
||||
else:
|
||||
pipeline = StableDiffusionUpscalePipeline.from_pretrained(
|
||||
'stabilityai/stable-diffusion-x4-upscaler')
|
||||
|
||||
return pipeline
|
||||
|
||||
|
||||
def upscale_stable_diffusion(
|
||||
ctx: ServerContext,
|
||||
_stage: StageParams,
|
||||
params: ImageParams,
|
||||
source: Image.Image,
|
||||
*,
|
||||
upscale: UpscaleParams,
|
||||
) -> Image.Image:
|
||||
print('upscaling with Stable Diffusion')
|
||||
|
||||
pipeline = load_stable_diffusion(ctx, upscale)
|
||||
generator = torch.manual_seed(params.seed)
|
||||
seed = generator.initial_seed()
|
||||
|
||||
return pipeline(
|
||||
params.prompt,
|
||||
source,
|
||||
generator=torch.manual_seed(seed),
|
||||
num_inference_steps=params.steps,
|
||||
).images[0]
|
|
@ -8,10 +8,6 @@ from diffusers import (
|
|||
from PIL import Image, ImageChops
|
||||
from typing import Any, Optional
|
||||
|
||||
import gc
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from .chain import (
|
||||
StageParams,
|
||||
)
|
||||
|
@ -33,6 +29,10 @@ from .utils import (
|
|||
ServerContext,
|
||||
)
|
||||
|
||||
import gc
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
last_pipeline_instance = None
|
||||
last_pipeline_options = (None, None, None)
|
||||
last_pipeline_scheduler = None
|
||||
|
|
|
@ -185,32 +185,3 @@ def expand_image(
|
|||
full_noise, full_source, full_mask.convert('L'))
|
||||
|
||||
return (full_source, full_mask, full_noise, (full_width, full_height))
|
||||
|
||||
|
||||
def process_tiles(
|
||||
source: Image,
|
||||
tile: int,
|
||||
scale: int,
|
||||
filters: List[Callable],
|
||||
) -> Image:
|
||||
width, height = source.size
|
||||
image = Image.new('RGB', (width * scale, height * scale))
|
||||
|
||||
tiles_x = width // tile
|
||||
tiles_y = height // tile
|
||||
total = tiles_x * tiles_y
|
||||
|
||||
for y in range(tiles_y):
|
||||
for x in range(tiles_x):
|
||||
idx = (y * tiles_x) + x
|
||||
left = x * tile
|
||||
top = y * tile
|
||||
print('processing tile %s of %s, %s.%s' % (idx, total, y, x))
|
||||
tile_image = source.crop((left, top, left + tile, top + tile))
|
||||
|
||||
for filter in filters:
|
||||
tile_image = filter(tile_image)
|
||||
|
||||
image.paste(tile_image, (left * scale, top * scale))
|
||||
|
||||
return image
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from .onnx_net import (
|
||||
ONNXImage,
|
||||
ONNXNet,
|
||||
OnnxImage,
|
||||
OnnxNet,
|
||||
)
|
||||
from .pipeline_onnx_stable_diffusion_upscale import (
|
||||
OnnxStableDiffusionUpscalePipeline,
|
||||
|
|
|
@ -9,7 +9,7 @@ from ..utils import (
|
|||
ServerContext,
|
||||
)
|
||||
|
||||
class ONNXImage():
|
||||
class OnnxImage():
|
||||
def __init__(self, source) -> None:
|
||||
self.source = source
|
||||
self.data = self
|
||||
|
@ -38,7 +38,7 @@ class ONNXImage():
|
|||
return np.shape(self.source)
|
||||
|
||||
|
||||
class ONNXNet():
|
||||
class OnnxNet():
|
||||
'''
|
||||
Provides the RRDBNet interface using an ONNX session for DirectML acceleration.
|
||||
'''
|
||||
|
@ -57,7 +57,7 @@ class ONNXNet():
|
|||
output = self.session.run([output_name], {
|
||||
input_name: image.cpu().numpy()
|
||||
})[0]
|
||||
return ONNXImage(output)
|
||||
return OnnxImage(output)
|
||||
|
||||
def eval(self) -> None:
|
||||
pass
|
||||
|
|
|
@ -35,3 +35,76 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
|||
**kwargs,
|
||||
):
|
||||
super().__call__(*args, **kwargs)
|
||||
|
||||
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
|
||||
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
text_embeddings = self.text_encoder(input_ids=text_input_ids.int().to(device))
|
||||
text_embeddings = text_embeddings[0]
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
bs_embed, seq_len, _ = text_embeddings.shape
|
||||
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt) #, 1)
|
||||
# text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
max_length = text_input_ids.shape[-1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.int().to(device))
|
||||
uncond_embeddings = uncond_embeddings[0]
|
||||
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = uncond_embeddings.shape[1]
|
||||
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt) #, 1)
|
||||
# uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
|
||||
return text_embeddings
|
||||
|
|
|
@ -45,12 +45,12 @@ from .params import (
|
|||
Border,
|
||||
ImageParams,
|
||||
Size,
|
||||
UpscaleParams,
|
||||
)
|
||||
from .upscale import (
|
||||
correct_gfpgan,
|
||||
upscale_resrgan,
|
||||
upscale_stable_diffusion,
|
||||
UpscaleParams,
|
||||
)
|
||||
from .utils import (
|
||||
is_debug,
|
||||
|
|
|
@ -1,29 +1,14 @@
|
|||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDPMScheduler,
|
||||
StableDiffusionUpscalePipeline,
|
||||
)
|
||||
from gfpgan import GFPGANer
|
||||
from os import path
|
||||
from PIL import Image
|
||||
from realesrgan import RealESRGANer
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from .chain import (
|
||||
correct_gfpgan,
|
||||
upscale_stable_diffusion,
|
||||
upscale_resrgan,
|
||||
ChainPipeline,
|
||||
StageParams,
|
||||
)
|
||||
from .onnx import (
|
||||
ONNXNet,
|
||||
OnnxStableDiffusionUpscalePipeline,
|
||||
)
|
||||
from .params import (
|
||||
ImageParams,
|
||||
Size,
|
||||
StageParams,
|
||||
UpscaleParams,
|
||||
)
|
||||
from .utils import (
|
||||
|
@ -31,144 +16,6 @@ from .utils import (
|
|||
)
|
||||
|
||||
|
||||
def load_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
|
||||
'''
|
||||
TODO: cache this instance
|
||||
'''
|
||||
model_file = '%s.%s' % (params.upscale_model, params.format)
|
||||
model_path = path.join(ctx.model_path, model_file)
|
||||
if not path.isfile(model_path):
|
||||
raise Exception('Real ESRGAN model not found at %s' % model_path)
|
||||
|
||||
# use ONNX acceleration, if available
|
||||
if params.format == 'onnx':
|
||||
model = ONNXNet(ctx, model_file, provider=params.provider)
|
||||
elif params.format == 'pth':
|
||||
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
|
||||
num_block=23, num_grow_ch=32, scale=params.scale)
|
||||
raise Exception('unknown platform %s' % params.format)
|
||||
|
||||
dni_weight = None
|
||||
if params.upscale_model == 'realesr-general-x4v3' and params.denoise != 1:
|
||||
wdn_model_path = model_path.replace(
|
||||
'realesr-general-x4v3', 'realesr-general-wdn-x4v3')
|
||||
model_path = [model_path, wdn_model_path]
|
||||
dni_weight = [params.denoise, 1 - params.denoise]
|
||||
|
||||
# TODO: shouldn't need the PTH file
|
||||
upsampler = RealESRGANer(
|
||||
scale=params.scale,
|
||||
model_path=path.join(ctx.model_path, '%s.pth' % params.upscale_model),
|
||||
dni_weight=dni_weight,
|
||||
model=model,
|
||||
tile=tile,
|
||||
tile_pad=params.tile_pad,
|
||||
pre_pad=params.pre_pad,
|
||||
half=params.half)
|
||||
|
||||
return upsampler
|
||||
|
||||
|
||||
def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams):
|
||||
'''
|
||||
TODO: cache this instance
|
||||
'''
|
||||
if upscale.format == 'onnx':
|
||||
model_path = path.join(ctx.model_path, upscale.upscale_model)
|
||||
# ValueError: Pipeline <class 'onnx_web.onnx.pipeline_onnx_stable_diffusion_upscale.OnnxStableDiffusionUpscalePipeline'>
|
||||
# expected {'vae', 'unet', 'text_encoder', 'tokenizer', 'scheduler', 'low_res_scheduler'},
|
||||
# but only {'scheduler', 'tokenizer', 'text_encoder', 'unet'} were passed.
|
||||
pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(
|
||||
model_path,
|
||||
vae=AutoencoderKL.from_pretrained(
|
||||
model_path, subfolder='vae_encoder'),
|
||||
low_res_scheduler=DDPMScheduler.from_pretrained(
|
||||
model_path, subfolder='scheduler'),
|
||||
)
|
||||
else:
|
||||
pipeline = StableDiffusionUpscalePipeline.from_pretrained(
|
||||
'stabilityai/stable-diffusion-x4-upscaler')
|
||||
|
||||
return pipeline
|
||||
|
||||
|
||||
def upscale_resrgan(
|
||||
ctx: ServerContext,
|
||||
stage: StageParams,
|
||||
_params: ImageParams,
|
||||
source_image: Image.Image,
|
||||
*,
|
||||
upscale: UpscaleParams,
|
||||
) -> Image:
|
||||
print('upscaling image with Real ESRGAN', upscale.scale)
|
||||
|
||||
output = np.array(source_image)
|
||||
upsampler = load_resrgan(ctx, upscale, tile=stage.tile_size)
|
||||
|
||||
output, _ = upsampler.enhance(output, outscale=upscale.outscale)
|
||||
|
||||
output = Image.fromarray(output, 'RGB')
|
||||
print('final output image size', output.size)
|
||||
return output
|
||||
|
||||
|
||||
def correct_gfpgan(
|
||||
ctx: ServerContext,
|
||||
_stage: StageParams,
|
||||
_params: ImageParams,
|
||||
image: Image.Image,
|
||||
*,
|
||||
upscale: UpscaleParams,
|
||||
upsampler: Optional[RealESRGANer] = None,
|
||||
) -> Image:
|
||||
if upscale.correction_model is None:
|
||||
print('no face model given, skipping')
|
||||
return image
|
||||
|
||||
print('correcting faces with GFPGAN model: %s' % upscale.correction_model)
|
||||
|
||||
if upsampler is None:
|
||||
upsampler = load_resrgan(ctx, upscale)
|
||||
|
||||
face_path = path.join(ctx.model_path, '%s.pth' %
|
||||
(upscale.correction_model))
|
||||
|
||||
# TODO: doesn't have a model param, not sure how to pass ONNX model
|
||||
face_enhancer = GFPGANer(
|
||||
model_path=face_path,
|
||||
upscale=upscale.outscale,
|
||||
arch='clean',
|
||||
channel_multiplier=2,
|
||||
bg_upsampler=upsampler)
|
||||
|
||||
_, _, output = face_enhancer.enhance(
|
||||
image, has_aligned=False, only_center_face=False, paste_back=True, weight=upscale.face_strength)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def upscale_stable_diffusion(
|
||||
ctx: ServerContext,
|
||||
_stage: StageParams,
|
||||
params: ImageParams,
|
||||
source: Image.Image,
|
||||
*,
|
||||
upscale: UpscaleParams,
|
||||
) -> Image:
|
||||
print('upscaling with Stable Diffusion')
|
||||
|
||||
pipeline = load_stable_diffusion(ctx, upscale)
|
||||
generator = torch.manual_seed(params.seed)
|
||||
seed = generator.initial_seed()
|
||||
|
||||
return pipeline(
|
||||
params.prompt,
|
||||
source,
|
||||
generator=torch.manual_seed(seed),
|
||||
num_inference_steps=params.steps,
|
||||
).images[0]
|
||||
|
||||
|
||||
def run_upscale_correction(
|
||||
ctx: ServerContext,
|
||||
stage: StageParams,
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from os import environ, path
|
||||
from time import time
|
||||
from struct import pack
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from hashlib import sha256
|
||||
|
||||
from .params import (
|
||||
|
|
|
@ -12,6 +12,10 @@
|
|||
"max": 30,
|
||||
"step": 0.1
|
||||
},
|
||||
"correction": {
|
||||
"default": "",
|
||||
"keys": []
|
||||
},
|
||||
"denoise": {
|
||||
"default": 0.5,
|
||||
"min": 0,
|
||||
|
@ -110,6 +114,10 @@
|
|||
"max": 512,
|
||||
"step": 8
|
||||
},
|
||||
"upscaling": {
|
||||
"default": "",
|
||||
"keys": []
|
||||
},
|
||||
"width": {
|
||||
"default": 512,
|
||||
"min": 64,
|
||||
|
|
Loading…
Reference in New Issue