log pipeline timing, add common sizes
This commit is contained in:
parent
8d57d113cd
commit
bb3a7dc0e9
|
@ -1,5 +1,6 @@
|
|||
from PIL import Image
|
||||
from os import path
|
||||
from time import monotonic
|
||||
from typing import Any, List, Optional, Protocol, Tuple
|
||||
|
||||
from ..params import (
|
||||
|
@ -51,10 +52,11 @@ class ChainPipeline:
|
|||
'''
|
||||
self.stages.append(stage)
|
||||
|
||||
def __call__(self, ctx: ServerContext, params: ImageParams, source: Image.Image) -> Image.Image:
|
||||
def __call__(self, ctx: ServerContext, params: ImageParams, source: Image.Image, **pipeline_kwargs) -> Image.Image:
|
||||
'''
|
||||
TODO: handle List[Image] outputs
|
||||
'''
|
||||
start = monotonic()
|
||||
print('running pipeline on source image with dimensions %sx%s' %
|
||||
source.size)
|
||||
image = source
|
||||
|
@ -62,8 +64,10 @@ class ChainPipeline:
|
|||
for stage_pipe, stage_params, stage_kwargs in self.stages:
|
||||
name = stage_params.name or stage_pipe.__name__
|
||||
kwargs = stage_kwargs or {}
|
||||
print('running pipeline stage %s on result image with dimensions %sx%s' %
|
||||
(name, image.width, image.height))
|
||||
kwargs = {**pipeline_kwargs, **kwargs}
|
||||
|
||||
print('running stage %s on result image with dimensions %sx%s, %s' %
|
||||
(name, image.width, image.height, kwargs))
|
||||
|
||||
if image.width > stage_params.tile_size or image.height > stage_params.tile_size:
|
||||
print('source image larger than tile size of %s, tiling stage' % (
|
||||
|
@ -85,8 +89,10 @@ class ChainPipeline:
|
|||
image = stage_pipe(ctx, stage_params, params, image,
|
||||
**kwargs)
|
||||
|
||||
print('finished running pipeline stage %s, result size: %sx%s' %
|
||||
print('finished stage %s, result size: %sx%s' %
|
||||
(name, image.width, image.height))
|
||||
|
||||
print('finished running pipeline, result size: %sx%s' % image.size)
|
||||
end = monotonic()
|
||||
duration = end - start
|
||||
print('finished pipeline in %s seconds, result size: %sx%s' % (duration, image.width, image.height))
|
||||
return image
|
||||
|
|
|
@ -17,6 +17,7 @@ from ..params import (
|
|||
Border,
|
||||
ImageParams,
|
||||
Size,
|
||||
SizeChart,
|
||||
StageParams,
|
||||
)
|
||||
from ..utils import (
|
||||
|
@ -93,7 +94,7 @@ def blend_inpaint(
|
|||
)
|
||||
return result.images[0]
|
||||
|
||||
output = process_tiles(source_image, 512, 1, [outpaint])
|
||||
output = process_tiles(source_image, SizeChart.auto, 1, [outpaint])
|
||||
|
||||
print('final output image size', output.size)
|
||||
return output
|
||||
|
|
|
@ -0,0 +1,84 @@
|
|||
from basicsr.utils import img2tensor, tensor2img
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
||||
from PIL import Image
|
||||
from torchvision.transforms.functional import normalize
|
||||
|
||||
import torch
|
||||
|
||||
pretrain_model_url = {
|
||||
'restoration': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
|
||||
}
|
||||
|
||||
device = 'cpu'
|
||||
upscale = 2
|
||||
|
||||
def correct_codeformer(image: Image.Image) -> Image.Image:
|
||||
# ------------------ set up CodeFormer restorer -------------------
|
||||
net = ARCH_REGISTRY.get('CodeFormer')(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9,
|
||||
connect_list=['32', '64', '128', '256']).to(device)
|
||||
|
||||
# ckpt_path = 'weights/CodeFormer/codeformer.pth'
|
||||
ckpt_path = load_file_from_url(url=pretrain_model_url['restoration'],
|
||||
model_dir='weights/CodeFormer', progress=True, file_name=None)
|
||||
checkpoint = torch.load(ckpt_path)['params_ema']
|
||||
net.load_state_dict(checkpoint)
|
||||
net.eval()
|
||||
|
||||
# ------------------ set up FaceRestoreHelper -------------------
|
||||
# large det_model: 'YOLOv5l', 'retinaface_resnet50'
|
||||
# small det_model: 'YOLOv5n', 'retinaface_mobile0.25'
|
||||
|
||||
face_helper = FaceRestoreHelper(
|
||||
upscale,
|
||||
face_size=512,
|
||||
crop_ratio=(1, 1),
|
||||
det_model = args.detection_model,
|
||||
save_ext='png',
|
||||
use_parse=True,
|
||||
device=device)
|
||||
|
||||
# get face landmarks for each face
|
||||
num_det_faces = face_helper.get_face_landmarks_5(
|
||||
only_center_face=args.only_center_face, resize=640, eye_dist_threshold=5)
|
||||
print(f'\tdetect {num_det_faces} faces')
|
||||
# align and warp each face
|
||||
face_helper.align_warp_face()
|
||||
|
||||
# face restoration for each cropped face
|
||||
for idx, cropped_face in enumerate(face_helper.cropped_faces):
|
||||
# prepare data
|
||||
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
|
||||
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
|
||||
cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
|
||||
|
||||
try:
|
||||
with torch.no_grad():
|
||||
output = net(cropped_face_t, w=w, adain=True)[0]
|
||||
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
|
||||
del output
|
||||
torch.cuda.empty_cache()
|
||||
except Exception as error:
|
||||
print(f'\tFailed inference for CodeFormer: {error}')
|
||||
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
|
||||
|
||||
restored_face = restored_face.astype('uint8')
|
||||
face_helper.add_restored_face(restored_face, cropped_face)
|
||||
|
||||
# upsample the background
|
||||
if bg_upsampler is not None:
|
||||
# Now only support RealESRGAN for upsampling background
|
||||
bg_img = bg_upsampler.enhance(img, outscale=args.upscale)[0]
|
||||
else:
|
||||
bg_img = None
|
||||
|
||||
|
||||
# paste_back
|
||||
face_helper.get_inverse_affine(None)
|
||||
# paste each restored face to the input image
|
||||
if face_upsampler is not None:
|
||||
restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=False, face_upsampler=face_upsampler)
|
||||
else:
|
||||
restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=False)
|
||||
|
||||
return restored_img
|
|
@ -1,5 +1,4 @@
|
|||
from boto3 import (
|
||||
ClientError,
|
||||
Session,
|
||||
)
|
||||
from io import BytesIO
|
||||
|
@ -25,16 +24,17 @@ def persist_s3(
|
|||
endpoint_url: str = None,
|
||||
profile_name: str = None,
|
||||
) -> Image.Image:
|
||||
sess = Session(profile_name=profile_name)
|
||||
s3 = sess.client('s3', endpoint_url=endpoint_url)
|
||||
session = Session(profile_name=profile_name)
|
||||
s3 = session.client('s3', endpoint_url=endpoint_url)
|
||||
|
||||
data = BytesIO()
|
||||
source_image.save(data, format='png')
|
||||
data.seek(0)
|
||||
|
||||
try:
|
||||
response = s3.upload_fileobj(data.getvalue(), bucket, output)
|
||||
response = s3.upload_fileobj(data, bucket, output)
|
||||
print('saved image to %s' % (response))
|
||||
except ClientError as err:
|
||||
except Exception as err:
|
||||
print('error saving image to S3: %s' % (err))
|
||||
|
||||
return source_image
|
||||
|
|
|
@ -17,6 +17,7 @@ from ..params import (
|
|||
Border,
|
||||
ImageParams,
|
||||
Size,
|
||||
SizeChart,
|
||||
StageParams,
|
||||
)
|
||||
from ..utils import (
|
||||
|
@ -93,7 +94,7 @@ def upscale_outpaint(
|
|||
)
|
||||
return result.images[0]
|
||||
|
||||
output = process_tiles(source_image, 512, 1, [outpaint])
|
||||
output = process_tiles(source_image, SizeChart.auto.value, 1, [outpaint])
|
||||
|
||||
print('final output image size', output.size)
|
||||
return output
|
||||
|
|
|
@ -25,7 +25,7 @@ def process_tiles(
|
|||
idx = (y * tiles_x) + x
|
||||
left = x * tile
|
||||
top = y * tile
|
||||
print('processing tile %s of %s, %s.%s' % (idx, total, y, x))
|
||||
print('processing tile %s of %s, %s.%s' % (idx + 1, total, y, x))
|
||||
tile_image = source.crop((left, top, left + tile, top + tile))
|
||||
|
||||
for filter in filters:
|
||||
|
|
|
@ -1,6 +1,19 @@
|
|||
from enum import IntEnum
|
||||
from typing import Any, Dict, Literal, Optional, Tuple, Union
|
||||
|
||||
|
||||
class SizeChart(IntEnum):
|
||||
mini = 128 # small tile for very expensive models
|
||||
half = 256 # half tile for outpainting
|
||||
auto = 512 # auto tile size
|
||||
hd1k = 2**10
|
||||
hd2k = 2**11
|
||||
hd4k = 2**12
|
||||
hd8k = 2**13
|
||||
hd16k = 2**14
|
||||
hd64k = 2**16
|
||||
|
||||
|
||||
Param = Union[str, int, float]
|
||||
Point = Tuple[int, int]
|
||||
|
||||
|
@ -74,7 +87,7 @@ class StageParams:
|
|||
def __init__(
|
||||
self,
|
||||
name: Optional[str] = None,
|
||||
tile_size: int = 512,
|
||||
tile_size: int = SizeChart.auto,
|
||||
outscale: int = 1,
|
||||
# batch_size: int = 1,
|
||||
) -> None:
|
||||
|
|
|
@ -8,6 +8,7 @@ from .chain import (
|
|||
)
|
||||
from .params import (
|
||||
ImageParams,
|
||||
SizeChart,
|
||||
StageParams,
|
||||
UpscaleParams,
|
||||
)
|
||||
|
@ -39,7 +40,7 @@ def run_upscale_correction(
|
|||
outscale=upscale.outscale)
|
||||
chain.append((upscale_resrgan, stage, kwargs))
|
||||
elif 'stable-diffusion' in upscale.upscale_model:
|
||||
mini_tile = min(128, stage.tile_size)
|
||||
mini_tile = min(SizeChart.mini, stage.tile_size)
|
||||
stage = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
|
||||
chain.append((upscale_stable_diffusion, stage, kwargs))
|
||||
|
||||
|
|
Loading…
Reference in New Issue