1
0
Fork 0

clean up draft CodeFormer stage

This commit is contained in:
Sean Sube 2023-02-05 08:06:50 -06:00
parent 54dd34d211
commit 35681efc1b
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 26 additions and 4 deletions

View File

@ -8,7 +8,7 @@ from PIL import Image
from torchvision.transforms.functional import normalize from torchvision.transforms.functional import normalize
from ..device_pool import JobContext from ..device_pool import JobContext
from ..params import ImageParams, StageParams from ..params import ImageParams, StageParams, UpscaleParams
from ..utils import ServerContext from ..utils import ServerContext
logger = getLogger(__name__) logger = getLogger(__name__)
@ -27,9 +27,15 @@ def correct_codeformer(
stage: StageParams, stage: StageParams,
params: ImageParams, params: ImageParams,
source_image: Image.Image, source_image: Image.Image,
*,
upscale: UpscaleParams = None,
**kwargs, **kwargs,
) -> Image.Image: ) -> Image.Image:
ARCH_REGISTRY = {}
bg_upsampler = None
face_upsampler = None
model = "TODO" model = "TODO"
w = None
# ------------------ set up CodeFormer restorer ------------------- # ------------------ set up CodeFormer restorer -------------------
net = ARCH_REGISTRY.get("CodeFormer")( net = ARCH_REGISTRY.get("CodeFormer")(
@ -47,7 +53,8 @@ def correct_codeformer(
progress=True, progress=True,
file_name=None, file_name=None,
) )
checkpoint = torch.load(ckpt_path)["params_ema"] checkpoint = torch.load(ckpt_path)
checkpoint = checkpoint["params_ema"]
net.load_state_dict(checkpoint) net.load_state_dict(checkpoint)
net.eval() net.eval()
@ -96,7 +103,7 @@ def correct_codeformer(
# upsample the background # upsample the background
if bg_upsampler is not None: if bg_upsampler is not None:
# Now only support RealESRGAN for upsampling background # Now only support RealESRGAN for upsampling background
bg_img = bg_upsampler.enhance(img, outscale=args.upscale)[0] bg_img = bg_upsampler.enhance(source_image, outscale=upscale.scale)[0]
else: else:
bg_img = None bg_img = None

View File

@ -1,2 +1,7 @@
[metadata] [metadata]
description-file = README.md description-file = README.md
[flake8]
ignore = E203, W503
max-line-length = 160
per-file-ignores = __init__.py:F401

View File

@ -14,6 +14,10 @@
"settings": { "settings": {
"cSpell.words": [ "cSpell.words": [
"astype", "astype",
"basicsr",
"ckpt",
"codebook",
"codeformer",
"CUDA", "CUDA",
"ddim", "ddim",
"ddpm", "ddpm",
@ -43,11 +47,13 @@
"outpainting", "outpainting",
"outscale", "outscale",
"pndm", "pndm",
"pretrain",
"pretrained", "pretrained",
"protobuf", "protobuf",
"randn", "randn",
"realesr", "realesr",
"resrgan", "resrgan",
"retinaface",
"rocm", "rocm",
"RRDB", "RRDB",
"runwayml", "runwayml",
@ -63,10 +69,14 @@
"timestep", "timestep",
"timesteps", "timesteps",
"tojson", "tojson",
"torchvision",
"uncond", "uncond",
"unet", "unet",
"unsqueeze",
"untruncated", "untruncated",
"upsample",
"upsampler", "upsampler",
"upsampling",
"upscaling", "upscaling",
"venv", "venv",
"virtualenv", "virtualenv",