clean up draft CodeFormer stage
This commit is contained in:
parent
54dd34d211
commit
35681efc1b
|
@ -8,7 +8,7 @@ from PIL import Image
|
||||||
from torchvision.transforms.functional import normalize
|
from torchvision.transforms.functional import normalize
|
||||||
|
|
||||||
from ..device_pool import JobContext
|
from ..device_pool import JobContext
|
||||||
from ..params import ImageParams, StageParams
|
from ..params import ImageParams, StageParams, UpscaleParams
|
||||||
from ..utils import ServerContext
|
from ..utils import ServerContext
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
@ -27,9 +27,15 @@ def correct_codeformer(
|
||||||
stage: StageParams,
|
stage: StageParams,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
source_image: Image.Image,
|
source_image: Image.Image,
|
||||||
|
*,
|
||||||
|
upscale: UpscaleParams = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
|
ARCH_REGISTRY = {}
|
||||||
|
bg_upsampler = None
|
||||||
|
face_upsampler = None
|
||||||
model = "TODO"
|
model = "TODO"
|
||||||
|
w = None
|
||||||
|
|
||||||
# ------------------ set up CodeFormer restorer -------------------
|
# ------------------ set up CodeFormer restorer -------------------
|
||||||
net = ARCH_REGISTRY.get("CodeFormer")(
|
net = ARCH_REGISTRY.get("CodeFormer")(
|
||||||
|
@ -47,7 +53,8 @@ def correct_codeformer(
|
||||||
progress=True,
|
progress=True,
|
||||||
file_name=None,
|
file_name=None,
|
||||||
)
|
)
|
||||||
checkpoint = torch.load(ckpt_path)["params_ema"]
|
checkpoint = torch.load(ckpt_path)
|
||||||
|
checkpoint = checkpoint["params_ema"]
|
||||||
net.load_state_dict(checkpoint)
|
net.load_state_dict(checkpoint)
|
||||||
net.eval()
|
net.eval()
|
||||||
|
|
||||||
|
@ -96,7 +103,7 @@ def correct_codeformer(
|
||||||
# upsample the background
|
# upsample the background
|
||||||
if bg_upsampler is not None:
|
if bg_upsampler is not None:
|
||||||
# Now only support RealESRGAN for upsampling background
|
# Now only support RealESRGAN for upsampling background
|
||||||
bg_img = bg_upsampler.enhance(img, outscale=args.upscale)[0]
|
bg_img = bg_upsampler.enhance(source_image, outscale=upscale.scale)[0]
|
||||||
else:
|
else:
|
||||||
bg_img = None
|
bg_img = None
|
||||||
|
|
||||||
|
|
|
@ -1,2 +1,7 @@
|
||||||
[metadata]
|
[metadata]
|
||||||
description-file = README.md
|
description-file = README.md
|
||||||
|
|
||||||
|
[flake8]
|
||||||
|
ignore = E203, W503
|
||||||
|
max-line-length = 160
|
||||||
|
per-file-ignores = __init__.py:F401
|
|
@ -14,6 +14,10 @@
|
||||||
"settings": {
|
"settings": {
|
||||||
"cSpell.words": [
|
"cSpell.words": [
|
||||||
"astype",
|
"astype",
|
||||||
|
"basicsr",
|
||||||
|
"ckpt",
|
||||||
|
"codebook",
|
||||||
|
"codeformer",
|
||||||
"CUDA",
|
"CUDA",
|
||||||
"ddim",
|
"ddim",
|
||||||
"ddpm",
|
"ddpm",
|
||||||
|
@ -43,11 +47,13 @@
|
||||||
"outpainting",
|
"outpainting",
|
||||||
"outscale",
|
"outscale",
|
||||||
"pndm",
|
"pndm",
|
||||||
|
"pretrain",
|
||||||
"pretrained",
|
"pretrained",
|
||||||
"protobuf",
|
"protobuf",
|
||||||
"randn",
|
"randn",
|
||||||
"realesr",
|
"realesr",
|
||||||
"resrgan",
|
"resrgan",
|
||||||
|
"retinaface",
|
||||||
"rocm",
|
"rocm",
|
||||||
"RRDB",
|
"RRDB",
|
||||||
"runwayml",
|
"runwayml",
|
||||||
|
@ -63,10 +69,14 @@
|
||||||
"timestep",
|
"timestep",
|
||||||
"timesteps",
|
"timesteps",
|
||||||
"tojson",
|
"tojson",
|
||||||
|
"torchvision",
|
||||||
"uncond",
|
"uncond",
|
||||||
"unet",
|
"unet",
|
||||||
|
"unsqueeze",
|
||||||
"untruncated",
|
"untruncated",
|
||||||
|
"upsample",
|
||||||
"upsampler",
|
"upsampler",
|
||||||
|
"upsampling",
|
||||||
"upscaling",
|
"upscaling",
|
||||||
"venv",
|
"venv",
|
||||||
"virtualenv",
|
"virtualenv",
|
||||||
|
|
Loading…
Reference in New Issue