feat(api): add support for DAT upscalers
This commit is contained in:
parent
b7f2313489
commit
4a9ca4c4a8
|
@ -18,6 +18,7 @@ from .source_s3 import SourceS3Stage
|
||||||
from .source_txt2img import SourceTxt2ImgStage
|
from .source_txt2img import SourceTxt2ImgStage
|
||||||
from .source_url import SourceURLStage
|
from .source_url import SourceURLStage
|
||||||
from .upscale_bsrgan import UpscaleBSRGANStage
|
from .upscale_bsrgan import UpscaleBSRGANStage
|
||||||
|
from .upscale_dat import UpscaleDATStage
|
||||||
from .upscale_highres import UpscaleHighresStage
|
from .upscale_highres import UpscaleHighresStage
|
||||||
from .upscale_outpaint import UpscaleOutpaintStage
|
from .upscale_outpaint import UpscaleOutpaintStage
|
||||||
from .upscale_resrgan import UpscaleRealESRGANStage
|
from .upscale_resrgan import UpscaleRealESRGANStage
|
||||||
|
@ -47,6 +48,7 @@ CHAIN_STAGES = {
|
||||||
"source-txt2img": SourceTxt2ImgStage,
|
"source-txt2img": SourceTxt2ImgStage,
|
||||||
"source-url": SourceURLStage,
|
"source-url": SourceURLStage,
|
||||||
"upscale-bsrgan": UpscaleBSRGANStage,
|
"upscale-bsrgan": UpscaleBSRGANStage,
|
||||||
|
"upscale-dat": UpscaleDATStage,
|
||||||
"upscale-highres": UpscaleHighresStage,
|
"upscale-highres": UpscaleHighresStage,
|
||||||
"upscale-outpaint": UpscaleOutpaintStage,
|
"upscale-outpaint": UpscaleOutpaintStage,
|
||||||
"upscale-resrgan": UpscaleRealESRGANStage,
|
"upscale-resrgan": UpscaleRealESRGANStage,
|
||||||
|
|
|
@ -0,0 +1,116 @@
|
||||||
|
from logging import getLogger
|
||||||
|
from os import path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from ..models.onnx import OnnxModel
|
||||||
|
from ..params import (
|
||||||
|
DeviceParams,
|
||||||
|
ImageParams,
|
||||||
|
Size,
|
||||||
|
SizeChart,
|
||||||
|
StageParams,
|
||||||
|
UpscaleParams,
|
||||||
|
)
|
||||||
|
from ..server import ModelTypes, ServerContext
|
||||||
|
from ..utils import run_gc
|
||||||
|
from ..worker import WorkerContext
|
||||||
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class UpscaleDATStage(BaseStage):
    """
    Upscale images using a DAT (Dual Aggregation Transformer) model that has
    been converted to ONNX.
    """

    # largest tile this model should process at once
    max_tile = SizeChart.micro

    def load(
        self,
        server: ServerContext,
        _stage: StageParams,
        upscale: UpscaleParams,
        device: DeviceParams,
    ):
        """
        Load the DAT ONNX pipeline for `upscale.upscale_model`, reusing a
        cached pipeline when one already exists for the same model path.
        """
        # must be within the load function for patch to take effect
        model_path = path.join(server.model_path, "%s.onnx" % (upscale.upscale_model))
        cache_key = (model_path,)
        cache_pipe = server.cache.get(ModelTypes.upscaling, cache_key)

        if cache_pipe is not None:
            logger.debug("reusing existing DAT pipeline")
            return cache_pipe

        logger.info("loading DAT model from %s", model_path)

        pipe = OnnxModel(
            server,
            model_path,
            provider=device.ort_provider(),
            sess_options=device.sess_options(),
        )

        server.cache.set(ModelTypes.upscaling, cache_key, pipe)
        run_gc([device])

        return pipe

    def run(
        self,
        worker: WorkerContext,
        server: ServerContext,
        stage: StageParams,
        _params: ImageParams,
        sources: StageResult,
        *,
        upscale: UpscaleParams,
        stage_source: Optional[Image.Image] = None,
        **kwargs,
    ) -> StageResult:
        """
        Upscale every source image with the DAT model and return the results.

        Returns `sources` unchanged when no upscaling model is configured.
        """
        upscale = upscale.with_args(**kwargs)

        if upscale.upscale_model is None:
            logger.warning("no upscaling model given, skipping")
            return sources

        logger.info("upscaling with DAT model: %s", upscale.upscale_model)
        device = worker.get_device()
        dat = self.load(server, stage, upscale, device)

        outputs = []
        for source in sources.as_numpy():
            # scale pixel values to [0, 1]; assumes uint8 input — TODO confirm
            image = source / 255.0
            # reverse the channel order and move channels first (HWC -> CHW)
            image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
            # add the batch axis expected by the ONNX model
            image = np.expand_dims(image, axis=0)
            logger.trace("DAT input shape: %s", image.shape)

            scale = upscale.outscale
            logger.trace(
                "DAT output shape: %s",
                (
                    image.shape[0],
                    image.shape[1],
                    image.shape[2] * scale,
                    image.shape[3] * scale,
                ),
            )

            output = dat(image)

            # drop the batch axis and clamp back into the [0, 1] range
            output = np.clip(np.squeeze(output, axis=0), 0, 1)
            # restore channel order and channel-last layout (CHW -> HWC)
            output = output[[2, 1, 0], :, :].transpose((1, 2, 0))
            output = (output * 255.0).round().astype(np.uint8)

            logger.debug("output image shape: %s", output.shape)
            outputs.append(output)

        return StageResult(arrays=outputs)

    def steps(
        self,
        params: ImageParams,
        size: Size,
    ) -> int:
        """
        Estimate the number of tiles needed to cover an image of `size`.
        """
        tile = min(params.unet_tile, self.max_tile)
        # fix: parenthesize so this counts (columns * rows) of tiles; the
        # original `size.width // tile * size.height // tile` evaluated
        # left-to-right as ((width // tile) * height) // tile, which differs
        # whenever the height is not a multiple of the tile size
        return (size.width // tile) * (size.height // tile)
|
|
@ -25,6 +25,7 @@ from .diffusion.diffusion_xl import convert_diffusion_diffusers_xl
|
||||||
from .diffusion.lora import blend_loras
|
from .diffusion.lora import blend_loras
|
||||||
from .diffusion.textual_inversion import blend_textual_inversions
|
from .diffusion.textual_inversion import blend_textual_inversions
|
||||||
from .upscaling.bsrgan import convert_upscaling_bsrgan
|
from .upscaling.bsrgan import convert_upscaling_bsrgan
|
||||||
|
from .upscaling.dat import convert_upscaling_dat
|
||||||
from .upscaling.resrgan import convert_upscale_resrgan
|
from .upscaling.resrgan import convert_upscale_resrgan
|
||||||
from .upscaling.swinir import convert_upscaling_swinir
|
from .upscaling.swinir import convert_upscaling_swinir
|
||||||
from .utils import (
|
from .utils import (
|
||||||
|
@ -395,7 +396,9 @@ def convert_model_upscaling(conversion: ConversionContext, model):
|
||||||
model_type = model.get("model", "resrgan")
|
model_type = model.get("model", "resrgan")
|
||||||
if model_type == "bsrgan":
|
if model_type == "bsrgan":
|
||||||
convert_upscaling_bsrgan(conversion, model, source)
|
convert_upscaling_bsrgan(conversion, model, source)
|
||||||
elif model_type == "resrgan":
|
elif model_type == "dat":
|
||||||
|
convert_upscaling_dat(conversion, model, source)
|
||||||
|
elif model_type in ["esrgan", "resrgan"]:
|
||||||
convert_upscale_resrgan(conversion, model, source)
|
convert_upscale_resrgan(conversion, model, source)
|
||||||
elif model_type == "swinir":
|
elif model_type == "swinir":
|
||||||
convert_upscaling_swinir(conversion, model, source)
|
convert_upscaling_swinir(conversion, model, source)
|
||||||
|
|
|
@ -0,0 +1,70 @@
|
||||||
|
from logging import getLogger
|
||||||
|
from os import path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.onnx import export
|
||||||
|
|
||||||
|
from ...models.dat import DAT
|
||||||
|
from ..utils import ConversionContext, ModelDict
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
def convert_upscaling_dat(
    conversion: ConversionContext,
    model: ModelDict,
    source: str,
):
    """
    Convert a DAT upscaling checkpoint to ONNX.

    Reads the torch checkpoint from `source` (falling back to the model
    dict's own "source" entry), loads it into a DAT network, and exports it
    to `<model_path>/<name>.onnx`. Skips conversion when the ONNX file
    already exists.
    """
    name = model.get("name")
    source = source or model.get("source")
    scale = model.get("scale", 1)

    dest = path.join(conversion.model_path, name + ".onnx")
    logger.info("converting DAT model: %s -> %s", name, dest)

    if path.isfile(dest):
        logger.info("ONNX model already exists, skipping")
        return

    # fix: use a distinct local name instead of reassigning `model`, which
    # shadowed the ModelDict parameter and made later reads ambiguous
    net = DAT(
        # TODO: params
        num_in_ch=3,
        num_out_ch=3,
        num_feat=64,
        num_block=23,
        num_grow_ch=32,
        scale=scale,
    )

    # NOTE(review): torch.load unpickles arbitrary objects — only run this on
    # checkpoints from trusted sources (consider weights_only=True if the
    # installed torch supports it)
    torch_model = torch.load(source, map_location=conversion.map_location)

    # checkpoints may store weights under "params_ema", "params", or at the
    # top level; strict=False silently ignores mismatched keys — TODO confirm
    # that is intended for partially-matching checkpoints
    if "params_ema" in torch_model:
        net.load_state_dict(torch_model["params_ema"], strict=False)
    elif "params" in torch_model:
        net.load_state_dict(torch_model["params"], strict=False)
    else:
        net.load_state_dict(torch_model, strict=False)

    # fix: `.train(False)` followed by `.eval()` was redundant — eval() alone
    # puts every module into inference mode
    net.to(conversion.training_device).eval()

    # dummy input for tracing; H and W are dynamic axes, so 64x64 is arbitrary
    rng = torch.rand(1, 3, 64, 64, device=conversion.map_location)
    input_names = ["input"]
    output_names = ["output"]
    dynamic_axes = {
        "input": {2: "h", 3: "w"},
        "output": {2: "h", 3: "w"},
    }

    logger.info("exporting ONNX model to %s", dest)
    export(
        net,
        rng,
        dest,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        opset_version=conversion.opset,
        export_params=True,
    )
    logger.info("DAT exported to ONNX successfully")
|
File diff suppressed because it is too large
Load Diff
|
@ -145,6 +145,8 @@ $defs:
|
||||||
type: string
|
type: string
|
||||||
enum: [
|
enum: [
|
||||||
bsrgan,
|
bsrgan,
|
||||||
|
dat,
|
||||||
|
esrgan,
|
||||||
resrgan,
|
resrgan,
|
||||||
swinir
|
swinir
|
||||||
]
|
]
|
||||||
|
|
Loading…
Reference in New Issue