This commit is contained in:
parent
464bfd01b8
commit
62aa7e8473
|
@ -10,9 +10,11 @@ from .reduce_crop import reduce_crop
|
|||
from .reduce_thumbnail import reduce_thumbnail
|
||||
from .source_noise import source_noise
|
||||
from .source_txt2img import source_txt2img
|
||||
from .upscale_bsrgan import upscale_bsrgan
|
||||
from .upscale_outpaint import upscale_outpaint
|
||||
from .upscale_resrgan import upscale_resrgan
|
||||
from .upscale_stable_diffusion import upscale_stable_diffusion
|
||||
from .upscale_swinir import upscale_swinir
|
||||
|
||||
CHAIN_STAGES = {
|
||||
"blend-img2img": blend_img2img,
|
||||
|
@ -26,7 +28,9 @@ CHAIN_STAGES = {
|
|||
"reduce-thumbnail": reduce_thumbnail,
|
||||
"source-noise": source_noise,
|
||||
"source-txt2img": source_txt2img,
|
||||
"upscale-bsrgan": upscale_bsrgan,
|
||||
"upscale-outpaint": upscale_outpaint,
|
||||
"upscale-resrgan": upscale_resrgan,
|
||||
"upscale-stable-diffusion": upscale_stable_diffusion,
|
||||
"upscale-swinir": upscale_swinir,
|
||||
}
|
||||
|
|
|
@ -0,0 +1,118 @@
|
|||
from logging import getLogger
|
||||
from os import path
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from ..models.onnx import OnnxModel
|
||||
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
||||
from ..server import ServerContext
|
||||
from ..utils import run_gc
|
||||
from ..worker import WorkerContext
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def load_bsrgan(
    server: ServerContext,
    _stage: StageParams,
    upscale: UpscaleParams,
    device: DeviceParams,
):
    """Return the BSRGAN ONNX model named by ``upscale.upscale_model``.

    The model is looked up in ``server.cache`` under the ``"bsrgan"`` type,
    keyed by its full path; on a miss it is loaded through ``OnnxModel``,
    stored in the cache, and returned.

    Args:
        server: server context providing ``model_path`` and ``cache``.
        _stage: unused, kept for signature parity with the other loaders.
        upscale: upscaling parameters; ``upscale_model`` is the model name
            without the ``.onnx`` extension.
        device: device parameters supplying the ORT provider and session
            options.

    Returns:
        The cached or newly created ``OnnxModel``.
    """
    model_file = "%s.onnx" % (upscale.upscale_model)
    model_path = path.join(server.model_path, model_file)
    cache_key = (model_path,)
    cache_pipe = server.cache.get("bsrgan", cache_key)

    if cache_pipe is not None:
        logger.info("reusing existing BSRGAN pipeline")
        return cache_pipe

    logger.debug("loading BSRGAN model from %s", model_path)

    # OnnxModel joins its model argument onto server.model_path on its own, so
    # hand it the bare file name; passing the already-joined path would repeat
    # the model directory whenever server.model_path is relative.
    pipe = OnnxModel(
        server,
        model_file,
        provider=device.ort_provider(),
        sess_options=device.sess_options(),
    )

    server.cache.set("bsrgan", cache_key, pipe)
    run_gc([device])

    return pipe
|
||||
|
||||
|
||||
def upscale_bsrgan(
    job: WorkerContext,
    server: ServerContext,
    stage: StageParams,
    _params: ImageParams,
    source: Image.Image,
    *,
    upscale: UpscaleParams,
    stage_source: Optional[Image.Image] = None,
    **kwargs,
) -> Image.Image:
    """Upscale an image tile-by-tile with a BSRGAN ONNX model.

    Args:
        job: worker context, used to pick the device.
        server: server context, used to locate and cache the model.
        stage: stage parameters, forwarded to the model loader.
        _params: unused image parameters.
        source: the image to upscale.
        upscale: upscaling parameters (``upscale_model``, ``outscale``).
        stage_source: when given, used instead of ``source``.
        **kwargs: overrides applied through ``upscale.with_args``.

    Returns:
        The upscaled RGB image, or the unchanged source image when no
        upscaling model is configured.
    """
    upscale = upscale.with_args(**kwargs)
    source = stage_source or source

    if upscale.upscale_model is None:
        logger.warning("no upscaling model given, skipping")
        return source

    logger.info("upscaling with BSRGAN model: %s", upscale.upscale_model)
    device = job.get_device()
    bsrgan = load_bsrgan(server, stage, upscale, device)

    tile_size = (64, 64)

    def tile_starts(length: int, tile: int) -> list:
        """Offsets of full-size tiles covering ``length`` pixels.

        Tiles sit on a regular grid; when the grid leaves a remainder, one
        extra tile is anchored to the far edge (overlapping its neighbour) so
        every pixel is covered without feeding the model a partial tile.
        """
        starts = list(range(0, length - tile + 1, tile))
        if starts and starts[-1] + tile < length:
            starts.append(length - tile)
        return starts

    tiles_x = tile_starts(source.width, tile_size[0])
    tiles_y = tile_starts(source.height, tile_size[1])
    if not tiles_x or not tiles_y:
        logger.warning(
            "image is smaller than the BSRGAN tile size, no tiles will be upscaled"
        )

    # PIL images convert to (height, width, channels) arrays; reverse the first
    # three channels and move them to the front, giving a (1, 3, H, W) batch.
    # NOTE(review): the [2, 1, 0] reorder presumably maps RGB to BGR - confirm
    # against the channel order the exported model expects.
    image = np.array(source) / 255.0
    image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
    image = np.expand_dims(image, axis=0)
    logger.info("BSRGAN input shape: %s", image.shape)

    # NOTE(review): assumes outscale is an integer equal to the factor the
    # model itself produces; otherwise the tile assignment below cannot fit.
    scale = upscale.outscale
    dest = np.zeros(
        (image.shape[0], image.shape[1], image.shape[2] * scale, image.shape[3] * scale)
    )
    logger.info("BSRGAN output shape: %s", dest.shape)

    for ix1 in tiles_x:
        ix2 = ix1 + tile_size[0]
        for iy1 in tiles_y:
            iy2 = iy1 + tile_size[1]
            logger.info(
                "running BSRGAN on tile: (%s, %s, %s, %s) -> (%s, %s, %s, %s)",
                ix1,
                ix2,
                iy1,
                iy2,
                ix1 * scale,
                ix2 * scale,
                iy1 * scale,
                iy2 * scale,
            )

            # axis 2 is height (y) and axis 3 is width (x)
            dest[
                :,
                :,
                iy1 * scale : iy2 * scale,
                ix1 * scale : ix2 * scale,
            ] = bsrgan(image[:, :, iy1:iy2, ix1:ix2])

    # back to (H, W, 3) with the original channel order and 8-bit values
    dest = np.clip(np.squeeze(dest, axis=0), 0, 1)
    dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0))
    dest = (dest * 255.0).round().astype(np.uint8)

    output = Image.fromarray(dest, "RGB")
    logger.info("output image size: %s x %s", output.width, output.height)
    return output
|
|
@ -0,0 +1,118 @@
|
|||
from logging import getLogger
|
||||
from os import path
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from ..models.onnx import OnnxModel
|
||||
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
||||
from ..server import ServerContext
|
||||
from ..utils import run_gc
|
||||
from ..worker import WorkerContext
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def load_swinir(
    server: ServerContext,
    _stage: StageParams,
    upscale: UpscaleParams,
    device: DeviceParams,
):
    """Return the SwinIR ONNX model named by ``upscale.upscale_model``.

    The model is looked up in ``server.cache`` under the ``"swinir"`` type,
    keyed by its full path; on a miss it is loaded through ``OnnxModel``,
    stored in the cache, and returned.

    Args:
        server: server context providing ``model_path`` and ``cache``.
        _stage: unused, kept for signature parity with the other loaders.
        upscale: upscaling parameters; ``upscale_model`` is the model name
            without the ``.onnx`` extension.
        device: device parameters supplying the ORT provider and session
            options.

    Returns:
        The cached or newly created ``OnnxModel``.
    """
    model_file = "%s.onnx" % (upscale.upscale_model)
    model_path = path.join(server.model_path, model_file)
    cache_key = (model_path,)
    cache_pipe = server.cache.get("swinir", cache_key)

    if cache_pipe is not None:
        logger.info("reusing existing SwinIR pipeline")
        return cache_pipe

    logger.debug("loading SwinIR model from %s", model_path)

    # OnnxModel joins its model argument onto server.model_path on its own, so
    # hand it the bare file name; passing the already-joined path would repeat
    # the model directory whenever server.model_path is relative.
    pipe = OnnxModel(
        server,
        model_file,
        provider=device.ort_provider(),
        sess_options=device.sess_options(),
    )

    server.cache.set("swinir", cache_key, pipe)
    run_gc([device])

    return pipe
|
||||
|
||||
|
||||
def upscale_swinir(
    job: WorkerContext,
    server: ServerContext,
    stage: StageParams,
    _params: ImageParams,
    source: Image.Image,
    *,
    upscale: UpscaleParams,
    stage_source: Optional[Image.Image] = None,
    **kwargs,
) -> Image.Image:
    """Upscale an image tile-by-tile with a SwinIR ONNX model.

    Args:
        job: worker context, used to pick the device.
        server: server context, used to locate and cache the model.
        stage: stage parameters, forwarded to the model loader.
        _params: unused image parameters.
        source: the image to upscale.
        upscale: upscaling parameters (``upscale_model``, ``outscale``).
        stage_source: when given, used instead of ``source``.
        **kwargs: overrides applied through ``upscale.with_args``.

    Returns:
        The upscaled RGB image, or the unchanged source image when no
        upscaling model is configured.
    """
    upscale = upscale.with_args(**kwargs)
    source = stage_source or source

    if upscale.upscale_model is None:
        logger.warning("no upscaling model given, skipping")
        return source

    logger.info("upscaling with SwinIR model: %s", upscale.upscale_model)
    device = job.get_device()
    swinir = load_swinir(server, stage, upscale, device)

    tile_size = (64, 64)

    def tile_starts(length: int, tile: int) -> list:
        """Offsets of full-size tiles covering ``length`` pixels.

        Tiles sit on a regular grid; when the grid leaves a remainder, one
        extra tile is anchored to the far edge (overlapping its neighbour) so
        every pixel is covered without feeding the model a partial tile.
        """
        starts = list(range(0, length - tile + 1, tile))
        if starts and starts[-1] + tile < length:
            starts.append(length - tile)
        return starts

    tiles_x = tile_starts(source.width, tile_size[0])
    tiles_y = tile_starts(source.height, tile_size[1])
    if not tiles_x or not tiles_y:
        logger.warning(
            "image is smaller than the SwinIR tile size, no tiles will be upscaled"
        )

    # PIL images convert to (height, width, channels) arrays; reverse the first
    # three channels and move them to the front, giving a (1, 3, H, W) batch.
    # NOTE(review): the [2, 1, 0] reorder presumably maps RGB to BGR - confirm
    # against the channel order the exported model expects.
    image = np.array(source) / 255.0
    image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
    image = np.expand_dims(image, axis=0)
    logger.info("SwinIR input shape: %s", image.shape)

    # NOTE(review): assumes outscale is an integer equal to the factor the
    # model itself produces; otherwise the tile assignment below cannot fit.
    scale = upscale.outscale
    dest = np.zeros(
        (image.shape[0], image.shape[1], image.shape[2] * scale, image.shape[3] * scale)
    )
    logger.info("SwinIR output shape: %s", dest.shape)

    for ix1 in tiles_x:
        ix2 = ix1 + tile_size[0]
        for iy1 in tiles_y:
            iy2 = iy1 + tile_size[1]
            logger.info(
                "running SwinIR on tile: (%s, %s, %s, %s) -> (%s, %s, %s, %s)",
                ix1,
                ix2,
                iy1,
                iy2,
                ix1 * scale,
                ix2 * scale,
                iy1 * scale,
                iy2 * scale,
            )

            # axis 2 is height (y) and axis 3 is width (x)
            dest[
                :,
                :,
                iy1 * scale : iy2 * scale,
                ix1 * scale : ix2 * scale,
            ] = swinir(image[:, :, iy1:iy2, ix1:ix2])

    # back to (H, W, 3) with the original channel order and 8-bit values
    dest = np.clip(np.squeeze(dest, axis=0), 0, 1)
    dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0))
    dest = (dest * 255.0).round().astype(np.uint8)

    output = Image.fromarray(dest, "RGB")
    logger.info("output image size: %s x %s", output.width, output.height)
    return output
|
|
@ -13,12 +13,14 @@ from transformers import CLIPTokenizer
|
|||
from yaml import safe_load
|
||||
|
||||
from ..constants import ONNX_MODEL, ONNX_WEIGHTS
|
||||
from .correction_gfpgan import convert_correction_gfpgan
|
||||
from .correction.gfpgan import convert_correction_gfpgan
|
||||
from .diffusion.diffusers import convert_diffusion_diffusers
|
||||
from .diffusion.lora import blend_loras
|
||||
from .diffusion.original import convert_diffusion_original
|
||||
from .diffusion.textual_inversion import blend_textual_inversions
|
||||
from .upscale_resrgan import convert_upscale_resrgan
|
||||
from .upscaling.bsrgan import convert_upscaling_bsrgan
|
||||
from .upscaling.resrgan import convert_upscale_resrgan
|
||||
from .upscaling.swinir import convert_upscaling_swinir
|
||||
from .utils import (
|
||||
ConversionContext,
|
||||
download_progress,
|
||||
|
@ -108,6 +110,18 @@ base_models: Models = {
|
|||
"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
|
||||
4,
|
||||
),
|
||||
{
|
||||
"model": "swinir",
|
||||
"name": "upscaling-swinir-classical-x4",
|
||||
"source": "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/001_classicalSR_DF2K_s64w8_SwinIR-M_x4.pth",
|
||||
"scale": 4,
|
||||
},
|
||||
{
|
||||
"model": "bsrgan",
|
||||
"name": "upscaling-bsrgan-x4",
|
||||
"source": "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth",
|
||||
"scale": 4,
|
||||
},
|
||||
],
|
||||
# download only
|
||||
"sources": [
|
||||
|
@ -415,7 +429,17 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
|||
source = fetch_model(
|
||||
conversion, name, model["source"], format=model_format
|
||||
)
|
||||
convert_upscale_resrgan(conversion, model, source)
|
||||
model_type = model.get("model", "resrgan")
|
||||
if model_type == "bsrgan":
|
||||
convert_upscaling_bsrgan(conversion, model, source)
|
||||
elif model_type == "resrgan":
|
||||
convert_upscale_resrgan(conversion, model, source)
|
||||
elif model_type == "swinir":
|
||||
convert_upscaling_swinir(conversion, model, source)
|
||||
else:
|
||||
logger.error(
|
||||
"unknown upscaling model type %s for %s", model_type, name
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"error converting upscaling model %s",
|
||||
|
@ -435,7 +459,13 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
|||
source = fetch_model(
|
||||
conversion, name, model["source"], format=model_format
|
||||
)
|
||||
convert_correction_gfpgan(conversion, model, source)
|
||||
model_type = model.get("model", "gfpgan")
|
||||
if model_type == "gfpgan":
|
||||
convert_correction_gfpgan(conversion, model, source)
|
||||
else:
|
||||
logger.error(
|
||||
"unknown correction model type %s for %s", model_type, name
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"error converting correction model %s",
|
||||
|
|
|
@ -5,7 +5,7 @@ import torch
|
|||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from torch.onnx import export
|
||||
|
||||
from .utils import ConversionContext, ModelDict
|
||||
from ..utils import ConversionContext, ModelDict
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -18,7 +18,7 @@ def convert_correction_gfpgan(
|
|||
):
|
||||
name = model.get("name")
|
||||
source = source or model.get("source")
|
||||
scale = model.get("scale")
|
||||
scale = model.get("scale", 1)
|
||||
|
||||
dest = path.join(conversion.model_path, name + ".onnx")
|
||||
logger.info("converting GFPGAN model: %s -> %s", name, dest)
|
|
@ -277,6 +277,7 @@ def blend_loras(
|
|||
fixed_node_names = [fix_node_name(node.name) for node in base_model.graph.node]
|
||||
logger.trace("fixed node names: %s", fixed_node_names)
|
||||
|
||||
unmatched_keys = []
|
||||
for base_key, weights in blended.items():
|
||||
conv_key = base_key + "_Conv"
|
||||
gemm_key = base_key + "_Gemm"
|
||||
|
@ -377,7 +378,7 @@ def blend_loras(
|
|||
del base_model.graph.initializer[matmul_idx]
|
||||
base_model.graph.initializer.insert(matmul_idx, updated_node)
|
||||
else:
|
||||
logger.warning("could not find any nodes for %s", base_key)
|
||||
unmatched_keys.append(base_key)
|
||||
|
||||
logger.debug(
|
||||
"node counts: %s -> %s, %s -> %s",
|
||||
|
@ -387,6 +388,9 @@ def blend_loras(
|
|||
len(base_model.graph.node),
|
||||
)
|
||||
|
||||
if len(unmatched_keys) > 0:
|
||||
logger.warning("could not find nodes for some keys: %s", unmatched_keys)
|
||||
|
||||
return base_model
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,70 @@
|
|||
from logging import getLogger
|
||||
from os import path
|
||||
|
||||
import torch
|
||||
from torch.onnx import export
|
||||
|
||||
from ...models.rrdb import RRDBNet
|
||||
from ..utils import ConversionContext, ModelDict
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
@torch.no_grad()
def convert_upscaling_bsrgan(
    conversion: ConversionContext,
    model: ModelDict,
    source: str,
):
    """Export a BSRGAN checkpoint to ``<model_path>/<name>.onnx``.

    Does nothing when the destination file already exists. The checkpoint is
    read from ``source`` (falling back to the model's ``source`` entry) and may
    hold its weights under ``params_ema``, under ``params``, or directly as the
    state dict.
    """
    name = model.get("name")
    source = source or model.get("source")
    scale = model.get("scale", 1)

    dest = path.join(conversion.model_path, name + ".onnx")
    logger.info("converting BSRGAN model: %s -> %s", name, dest)

    if path.isfile(dest):
        logger.info("ONNX model already exists, skipping")
        return

    logger.info("loading and training model")
    network = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=scale)

    # prefer the EMA weights, then the plain weights, then the raw checkpoint
    checkpoint = torch.load(source, map_location=conversion.map_location)
    state_dict = checkpoint
    for key in ("params_ema", "params"):
        if key in checkpoint:
            state_dict = checkpoint[key]
            break
    network.load_state_dict(state_dict, strict=False)

    network.to(conversion.training_device).train(False)
    network.eval()

    # a random 64x64 sample drives the trace; height and width stay dynamic
    sample = torch.rand(1, 3, 64, 64, device=conversion.map_location)
    spatial_axes = {2: "h", 3: "w"}

    logger.info("exporting ONNX model to %s", dest)
    export(
        network,
        sample,
        dest,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={
            "input": dict(spatial_axes),
            "output": dict(spatial_axes),
        },
        opset_version=conversion.opset,
        export_params=True,
    )
    logger.info("BSRGAN exported to ONNX successfully")
|
|
@ -4,7 +4,7 @@ from os import path
|
|||
import torch
|
||||
from torch.onnx import export
|
||||
|
||||
from .utils import ConversionContext, ModelDict
|
||||
from ..utils import ConversionContext, ModelDict
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
|
@ -0,0 +1,74 @@
|
|||
from logging import getLogger
|
||||
from os import path
|
||||
|
||||
import torch
|
||||
from torch.onnx import export
|
||||
|
||||
from ...models.swinir import SwinIR
|
||||
from ..utils import ConversionContext, ModelDict
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
@torch.no_grad()
def convert_upscaling_swinir(
    conversion: ConversionContext,
    model: ModelDict,
    source: str,
):
    """Export a SwinIR checkpoint to ``<model_path>/<name>.onnx``.

    Does nothing when the destination file already exists. The checkpoint is
    read from ``source`` (falling back to the model's ``source`` entry) and may
    hold its weights under ``params_ema``, under ``params``, or directly as the
    state dict.

    Args:
        conversion: conversion context (paths, devices, ONNX opset).
        model: model entry; ``name`` and ``scale`` are read from it.
        source: path of the downloaded checkpoint.
    """
    name = model.get("name")
    source = source or model.get("source")
    scale = model.get("scale", 1)

    dest = path.join(conversion.model_path, name + ".onnx")
    logger.info("converting SwinIR model: %s -> %s", name, dest)

    if path.isfile(dest):
        logger.info("ONNX model already exists, skipping")
        return

    logger.info("loading and training model")
    img_size = (64, 64)  # TODO: does this need to be a fixed value?
    network = SwinIR(
        depths=[6, 6, 6, 6, 6, 6],
        embed_dim=180,
        img_range=1.0,
        img_size=img_size,
        in_chans=3,
        mlp_ratio=2,
        num_heads=[6, 6, 6, 6, 6, 6],
        resi_connection="1conv",
        upscale=scale,
        upsampler="pixelshuffle",
        window_size=8,
    )

    torch_model = torch.load(source, map_location=conversion.map_location)
    if "params_ema" in torch_model:
        network.load_state_dict(torch_model["params_ema"], strict=False)
    elif "params" in torch_model:
        network.load_state_dict(torch_model["params"], strict=False)
    else:
        # checkpoint is the bare state dict, same fallback as the BSRGAN
        # converter, rather than failing with a KeyError on "params"
        network.load_state_dict(torch_model, strict=False)

    network.to(conversion.training_device).train(False)
    network.eval()

    # a random 64x64 sample drives the trace; height and width stay dynamic
    rng = torch.rand(1, 3, 64, 64, device=conversion.map_location)
    input_names = ["input"]
    output_names = ["output"]
    dynamic_axes = {
        "input": {2: "h", 3: "w"},
        "output": {2: "h", 3: "w"},
    }

    logger.info("exporting ONNX model to %s", dest)
    export(
        network,
        rng,
        dest,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        opset_version=conversion.opset,
        export_params=True,
    )
    logger.info("SwinIR exported to ONNX successfully")
|
|
@ -7,8 +7,10 @@ from ..chain import (
|
|||
ChainPipeline,
|
||||
correct_codeformer,
|
||||
correct_gfpgan,
|
||||
upscale_bsrgan,
|
||||
upscale_resrgan,
|
||||
upscale_stable_diffusion,
|
||||
upscale_swinir,
|
||||
)
|
||||
from ..params import ImageParams, SizeChart, StageParams, UpscaleParams
|
||||
from ..server import ServerContext
|
||||
|
@ -41,15 +43,28 @@ def run_upscale_correction(
|
|||
|
||||
upscale_stage = None
|
||||
if upscale.scale > 1:
|
||||
if "esrgan" in upscale.upscale_model:
|
||||
if "bsrgan" in upscale.upscale_model:
|
||||
bsrgan_params = StageParams(
|
||||
tile_size=stage.tile_size,
|
||||
outscale=upscale.outscale,
|
||||
)
|
||||
upscale_stage = (upscale_bsrgan, bsrgan_params, None)
|
||||
elif "esrgan" in upscale.upscale_model:
|
||||
esrgan_params = StageParams(
|
||||
tile_size=stage.tile_size, outscale=upscale.outscale
|
||||
tile_size=stage.tile_size,
|
||||
outscale=upscale.outscale,
|
||||
)
|
||||
upscale_stage = (upscale_resrgan, esrgan_params, None)
|
||||
elif "stable-diffusion" in upscale.upscale_model:
|
||||
mini_tile = min(SizeChart.mini, stage.tile_size)
|
||||
sd_params = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
|
||||
upscale_stage = (upscale_stable_diffusion, sd_params, None)
|
||||
elif "swinir" in upscale.correction_model:
|
||||
swinir_params = StageParams(
|
||||
tile_size=stage.tile_size,
|
||||
outscale=upscale.outscale,
|
||||
)
|
||||
upscale_stage = (upscale_swinir, swinir_params, None)
|
||||
else:
|
||||
logger.warn("unknown upscaling model: %s", upscale.upscale_model)
|
||||
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
from logging import getLogger
|
||||
from os import path
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..server import ServerContext
|
||||
from ..torch_before_ort import InferenceSession, SessionOptions
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
class OnnxModel:
    """Callable wrapper around a single-input, single-output ONNX session."""

    def __init__(
        self,
        server: ServerContext,
        model: str,
        provider: str = "DmlExecutionProvider",
        sess_options: Optional[SessionOptions] = None,
    ) -> None:
        """Open an inference session for ``model`` under ``server.model_path``.

        Args:
            server: server context providing ``model_path``.
            model: model file path, joined onto ``server.model_path``.
            provider: execution provider to request for the session.
            sess_options: session options for the session, if any.
        """
        model_path = path.join(server.model_path, model)
        # Session options belong in sess_options; provider_options is a
        # separate per-provider setting and is not the place for them.
        # NOTE(review): assumes InferenceSession follows the onnxruntime
        # constructor signature - confirm against torch_before_ort.
        self.session = InferenceSession(
            model_path, providers=[provider], sess_options=sess_options
        )

    def __call__(self, image: Any) -> Any:
        """Run the session on ``image`` and return its first output."""
        input_name = self.session.get_inputs()[0].name
        output_name = self.session.get_outputs()[0].name
        output = self.session.run([output_name], {input_name: image})[0]
        return output
|
|
@ -0,0 +1,112 @@
|
|||
# from https://github.com/cszn/BSRGAN/blob/main/models/network_rrdbnet.py
|
||||
|
||||
import functools
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.nn.init as init
|
||||
|
||||
|
||||
def initialize_weights(net_l, scale=1):
    """Re-initialise the layers of one network or a list of networks in place.

    Conv2d and Linear weights get Kaiming-normal (fan-in) values multiplied by
    ``scale`` and their biases, when present, are zeroed. BatchNorm2d weights
    are set to 1 and biases to 0.
    """
    networks = net_l if isinstance(net_l, list) else [net_l]
    for network in networks:
        for module in network.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                init.kaiming_normal_(module.weight, a=0, mode="fan_in")
                # shrink the initial weights, e.g. for residual blocks
                module.weight.data *= scale
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                init.constant_(module.weight, 1)
                init.constant_(module.bias.data, 0.0)
|
||||
|
||||
|
||||
def make_layer(block, n_layers):
    """Chain ``n_layers`` fresh instances produced by ``block()`` sequentially."""
    return nn.Sequential(*[block() for _ in range(n_layers)])
|
||||
|
||||
|
||||
class ResidualDenseBlock_5C(nn.Module):
    """Five-convolution dense block with a scaled residual connection."""

    def __init__(self, nf=64, gc=32, bias=True):
        super().__init__()
        # nf: feature channels in and out of the block
        # gc: growth channel, i.e. intermediate channels added by each conv
        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # initialization
        convs = [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5]
        initialize_weights(convs, 0.1)

    def forward(self, x):
        # every convolution sees the input plus all earlier activations
        x1 = self.lrelu(self.conv1(x))

        stacked = torch.cat((x, x1), 1)
        x2 = self.lrelu(self.conv2(stacked))

        stacked = torch.cat((x, x1, x2), 1)
        x3 = self.lrelu(self.conv3(stacked))

        stacked = torch.cat((x, x1, x2, x3), 1)
        x4 = self.lrelu(self.conv4(stacked))

        stacked = torch.cat((x, x1, x2, x3, x4), 1)
        x5 = self.conv5(stacked)

        return x5 * 0.2 + x
|
||||
|
||||
|
||||
class RRDB(nn.Module):
    """Residual in Residual Dense Block: three dense blocks plus a scaled skip."""

    def __init__(self, nf, gc=32):
        super().__init__()
        self.RDB1 = ResidualDenseBlock_5C(nf, gc)
        self.RDB2 = ResidualDenseBlock_5C(nf, gc)
        self.RDB3 = ResidualDenseBlock_5C(nf, gc)

    def forward(self, x):
        # run the three dense blocks back to back, then add the scaled result
        # onto the block input
        residual = self.RDB3(self.RDB2(self.RDB1(x)))
        return residual * 0.2 + x
|
||||
|
||||
|
||||
class RRDBNet(nn.Module):
    """RRDB-based super-resolution network.

    Args:
        in_nc: number of input channels.
        out_nc: number of output channels.
        nf: number of feature channels used throughout the network.
        nb: number of RRDB blocks in the trunk.
        gc: growth channels inside each dense block.
        sf: scale factor. The first 2x upsampling step is always applied and
            a second one is added only when ``sf == 4``, so any other value
            yields a 2x network.
    """

    def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4):
        super().__init__()
        RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
        self.sf = sf

        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
        self.RRDB_trunk = make_layer(RRDB_block_f, nb)
        self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        # upsampling
        self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        if self.sf == 4:
            self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        # shallow features plus the trunk output (global residual)
        fea = self.conv_first(x)
        trunk = self.trunk_conv(self.RRDB_trunk(fea))
        fea = fea + trunk

        # nearest-neighbour 2x upsampling, each step followed by a convolution
        fea = self.lrelu(
            self.upconv1(F.interpolate(fea, scale_factor=2, mode="nearest"))
        )
        if self.sf == 4:
            fea = self.lrelu(
                self.upconv2(F.interpolate(fea, scale_factor=2, mode="nearest"))
            )
        out = self.conv_last(self.lrelu(self.HRconv(fea)))

        return out
|
File diff suppressed because it is too large
Load Diff
|
@ -20,7 +20,7 @@ from ..image import ( # mask filters; noise sources
|
|||
noise_source_normal,
|
||||
noise_source_uniform,
|
||||
)
|
||||
from ..models import NetworkModel
|
||||
from ..models.meta import NetworkModel
|
||||
from ..params import DeviceParams
|
||||
from ..torch_before_ort import get_available_providers
|
||||
from ..utils import merge
|
||||
|
|
|
@ -17,6 +17,7 @@ codeformer-perceptor
|
|||
facexlib
|
||||
gfpgan
|
||||
realesrgan
|
||||
timm
|
||||
|
||||
### Server packages ###
|
||||
boto3
|
||||
|
|
|
@ -62,6 +62,14 @@ $defs:
|
|||
correction_model:
|
||||
allOf:
|
||||
- $ref: "#/$defs/base_model"
|
||||
- type: object
|
||||
properties:
|
||||
model:
|
||||
type: enum
|
||||
enum: [
|
||||
codeformer,
|
||||
gfpgan
|
||||
]
|
||||
|
||||
diffusion_model:
|
||||
allOf:
|
||||
|
@ -87,6 +95,13 @@ $defs:
|
|||
- type: object
|
||||
required: [scale]
|
||||
properties:
|
||||
model:
|
||||
type: enum
|
||||
enum: [
|
||||
bsrgan,
|
||||
resrgan,
|
||||
swinir
|
||||
]
|
||||
scale:
|
||||
type: number
|
||||
|
||||
|
|
|
@ -81,10 +81,18 @@ export const I18N_STRINGS_EN = {
|
|||
'inversion-minecraft': 'Minecraft Concept',
|
||||
'inversion-ugly-sonic': 'Ugly Sonic',
|
||||
// upscaling
|
||||
'upscaling-bsrgan-x2': 'BSRGAN x2',
|
||||
'upscaling-bsrgan-x4': 'BSRGAN x4',
|
||||
'upscaling-real-esrgan-x2-plus': 'Real ESRGAN x2 Plus',
|
||||
'upscaling-real-esrgan-x4-plus': 'Real ESRGAN x4 Plus',
|
||||
'upscaling-real-esrgan-x4-v3': 'Real ESRGAN x4 v3',
|
||||
'upscaling-stable-diffusion-x4': 'Stable Diffusion x4',
|
||||
'upscaling-swinir-classical-x2': 'SwinIR Classical x2',
|
||||
'upscaling-swinir-classical-x3': 'SwinIR Classical x3',
|
||||
'upscaling-swinir-classical-x4': 'SwinIR Classical x4',
|
||||
'upscaling-swinir-classical-x8': 'SwinIR Classical x8',
|
||||
'upscaling-swinir-real-x2': 'SwinIR Real x2',
|
||||
'upscaling-swinir-real-x4': 'SwinIR Real x4',
|
||||
// extras
|
||||
'diffusion-stablydiffused-aesthetic-v2-6': 'Aesthetic Mix v2.6',
|
||||
'diffusion-anything': 'Anything',
|
||||
|
|
Loading…
Reference in New Issue