feat(api): add support for DAT upscalers
This commit is contained in:
parent
b7f2313489
commit
4a9ca4c4a8
|
@ -18,6 +18,7 @@ from .source_s3 import SourceS3Stage
|
|||
from .source_txt2img import SourceTxt2ImgStage
|
||||
from .source_url import SourceURLStage
|
||||
from .upscale_bsrgan import UpscaleBSRGANStage
|
||||
from .upscale_dat import UpscaleDATStage
|
||||
from .upscale_highres import UpscaleHighresStage
|
||||
from .upscale_outpaint import UpscaleOutpaintStage
|
||||
from .upscale_resrgan import UpscaleRealESRGANStage
|
||||
|
@ -47,6 +48,7 @@ CHAIN_STAGES = {
|
|||
"source-txt2img": SourceTxt2ImgStage,
|
||||
"source-url": SourceURLStage,
|
||||
"upscale-bsrgan": UpscaleBSRGANStage,
|
||||
"upscale-dat": UpscaleDATStage,
|
||||
"upscale-highres": UpscaleHighresStage,
|
||||
"upscale-outpaint": UpscaleOutpaintStage,
|
||||
"upscale-resrgan": UpscaleRealESRGANStage,
|
||||
|
|
|
@ -0,0 +1,116 @@
|
|||
from logging import getLogger
|
||||
from os import path
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from ..models.onnx import OnnxModel
|
||||
from ..params import (
|
||||
DeviceParams,
|
||||
ImageParams,
|
||||
Size,
|
||||
SizeChart,
|
||||
StageParams,
|
||||
UpscaleParams,
|
||||
)
|
||||
from ..server import ModelTypes, ServerContext
|
||||
from ..utils import run_gc
|
||||
from ..worker import WorkerContext
|
||||
from .base import BaseStage
|
||||
from .result import StageResult
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
class UpscaleDATStage(BaseStage):
    """Upscale images using a DAT (Dual Aggregation Transformer) ONNX model."""

    # cap the working tile at the smallest chart size; transformer attention
    # windows are memory-hungry at larger tiles
    max_tile = SizeChart.micro

    def load(
        self,
        server: ServerContext,
        _stage: StageParams,
        upscale: UpscaleParams,
        device: DeviceParams,
    ) -> OnnxModel:
        """
        Load the DAT ONNX model named by `upscale.upscale_model`, reusing a
        cached session for the same model path when one exists.
        """
        # must be within the load function for patch to take effect
        model_path = path.join(server.model_path, "%s.onnx" % (upscale.upscale_model))
        cache_key = (model_path,)
        cache_pipe = server.cache.get(ModelTypes.upscaling, cache_key)

        if cache_pipe is not None:
            logger.debug("reusing existing DAT pipeline")
            return cache_pipe

        logger.info("loading DAT model from %s", model_path)

        pipe = OnnxModel(
            server,
            model_path,
            provider=device.ort_provider(),
            sess_options=device.sess_options(),
        )

        server.cache.set(ModelTypes.upscaling, cache_key, pipe)
        run_gc([device])

        return pipe

    def run(
        self,
        worker: WorkerContext,
        server: ServerContext,
        stage: StageParams,
        _params: ImageParams,
        sources: StageResult,
        *,
        upscale: UpscaleParams,
        stage_source: Optional[Image.Image] = None,
        **kwargs,
    ) -> StageResult:
        """
        Upscale every image in `sources` with the DAT model and return the
        results as a new StageResult. Returns `sources` unchanged when no
        upscaling model is configured.
        """
        upscale = upscale.with_args(**kwargs)

        if upscale.upscale_model is None:
            logger.warning("no upscaling model given, skipping")
            return sources

        logger.info("upscaling with DAT model: %s", upscale.upscale_model)
        device = worker.get_device()
        dat = self.load(server, stage, upscale, device)

        outputs = []
        for source in sources.as_numpy():
            # HWC uint8 -> NCHW float32 in [0, 1], with the channel order
            # reversed ([2, 1, 0]) to match the model's training layout
            # (presumably RGB input from PIL -> BGR; confirm against converter)
            image = source / 255.0
            image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
            image = np.expand_dims(image, axis=0)
            logger.trace("DAT input shape: %s", image.shape)

            scale = upscale.outscale
            logger.trace(
                "DAT output shape: %s",
                (
                    image.shape[0],
                    image.shape[1],
                    image.shape[2] * scale,
                    image.shape[3] * scale,
                ),
            )

            output = dat(image)

            # NCHW float -> HWC uint8, undoing the channel swap above
            output = np.clip(np.squeeze(output, axis=0), 0, 1)
            output = output[[2, 1, 0], :, :].transpose((1, 2, 0))
            output = (output * 255.0).round().astype(np.uint8)

            logger.debug("output image shape: %s", output.shape)
            outputs.append(output)

        return StageResult(arrays=outputs)

    def steps(
        self,
        params: ImageParams,
        size: Size,
    ) -> int:
        """Estimate the number of tiles needed to upscale an image of `size`."""
        tile = min(params.unet_tile, self.max_tile)
        # tiles per row times tiles per column; the original expression
        # `size.width // tile * size.height // tile` evaluated left to right
        # as ((w // t) * h) // t, which over-counts whenever the height is
        # not a multiple of the tile size
        return (size.width // tile) * (size.height // tile)
|
|
@ -25,6 +25,7 @@ from .diffusion.diffusion_xl import convert_diffusion_diffusers_xl
|
|||
from .diffusion.lora import blend_loras
|
||||
from .diffusion.textual_inversion import blend_textual_inversions
|
||||
from .upscaling.bsrgan import convert_upscaling_bsrgan
|
||||
from .upscaling.dat import convert_upscaling_dat
|
||||
from .upscaling.resrgan import convert_upscale_resrgan
|
||||
from .upscaling.swinir import convert_upscaling_swinir
|
||||
from .utils import (
|
||||
|
@ -395,7 +396,9 @@ def convert_model_upscaling(conversion: ConversionContext, model):
|
|||
model_type = model.get("model", "resrgan")
|
||||
if model_type == "bsrgan":
|
||||
convert_upscaling_bsrgan(conversion, model, source)
|
||||
elif model_type == "resrgan":
|
||||
elif model_type == "dat":
|
||||
convert_upscaling_dat(conversion, model, source)
|
||||
elif model_type in ["esrgan", "resrgan"]:
|
||||
convert_upscale_resrgan(conversion, model, source)
|
||||
elif model_type == "swinir":
|
||||
convert_upscaling_swinir(conversion, model, source)
|
||||
|
|
|
@ -0,0 +1,70 @@
|
|||
from logging import getLogger
|
||||
from os import path
|
||||
|
||||
import torch
|
||||
from torch.onnx import export
|
||||
|
||||
from ...models.dat import DAT
|
||||
from ..utils import ConversionContext, ModelDict
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
@torch.no_grad()
def convert_upscaling_dat(
    conversion: ConversionContext,
    model: ModelDict,
    source: str,
):
    """
    Convert a DAT upscaling checkpoint to ONNX.

    Loads the Torch checkpoint from `source` (falling back to the model's own
    "source" entry), instantiates a DAT network, and exports it to ONNX with
    dynamic height/width axes. Skips conversion when the destination file
    already exists.
    """
    name = model.get("name")
    source = source or model.get("source")
    scale = model.get("scale", 1)

    dest = path.join(conversion.model_path, name + ".onnx")
    logger.info("converting DAT model: %s -> %s", name, dest)

    if path.isfile(dest):
        logger.info("ONNX model already exists, skipping")
        return

    model = DAT(
        # TODO: params — these look like RRDBNet/ESRGAN constructor arguments;
        # confirm against the DAT constructor signature
        num_in_ch=3,
        num_out_ch=3,
        num_feat=64,
        num_block=23,
        num_grow_ch=32,
        scale=scale,
    )

    torch_model = torch.load(source, map_location=conversion.map_location)
    # checkpoints may store their weights under "params_ema" or "params";
    # strict=False tolerates key mismatches but will silently skip them
    if "params_ema" in torch_model:
        model.load_state_dict(torch_model["params_ema"], strict=False)
    elif "params" in torch_model:
        model.load_state_dict(torch_model["params"], strict=False)
    else:
        model.load_state_dict(torch_model, strict=False)

    model.to(conversion.training_device).train(False)
    model.eval()

    # create the tracing input on the same device as the model; the original
    # used conversion.map_location here, which would fail the export whenever
    # the training device differs from the checkpoint map location
    rng = torch.rand(1, 3, 64, 64, device=conversion.training_device)
    input_names = ["input"]
    output_names = ["output"]
    dynamic_axes = {
        "input": {2: "h", 3: "w"},
        "output": {2: "h", 3: "w"},
    }

    logger.info("exporting ONNX model to %s", dest)
    export(
        model,
        rng,
        dest,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        opset_version=conversion.opset,
        export_params=True,
    )
    logger.info("DAT exported to ONNX successfully")
|
File diff suppressed because it is too large
Load Diff
|
@ -145,6 +145,8 @@ $defs:
|
|||
type: string
|
||||
enum: [
|
||||
bsrgan,
|
||||
dat,
|
||||
esrgan,
|
||||
resrgan,
|
||||
swinir
|
||||
]
|
||||
|
|
Loading…
Reference in New Issue