feat(api): add support for DAT upscalers

Sean Sube 2023-12-31 11:20:51 -06:00
parent b7f2313489
commit 4a9ca4c4a8
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
6 changed files with 1290 additions and 1 deletion

@@ -18,6 +18,7 @@ from .source_s3 import SourceS3Stage
 from .source_txt2img import SourceTxt2ImgStage
 from .source_url import SourceURLStage
 from .upscale_bsrgan import UpscaleBSRGANStage
+from .upscale_dat import UpscaleDATStage
 from .upscale_highres import UpscaleHighresStage
 from .upscale_outpaint import UpscaleOutpaintStage
 from .upscale_resrgan import UpscaleRealESRGANStage
@@ -47,6 +48,7 @@ CHAIN_STAGES = {
     "source-txt2img": SourceTxt2ImgStage,
     "source-url": SourceURLStage,
     "upscale-bsrgan": UpscaleBSRGANStage,
+    "upscale-dat": UpscaleDATStage,
     "upscale-highres": UpscaleHighresStage,
     "upscale-outpaint": UpscaleOutpaintStage,
     "upscale-resrgan": UpscaleRealESRGANStage,

@@ -0,0 +1,116 @@
from logging import getLogger
from os import path
from typing import Optional

import numpy as np
from PIL import Image

from ..models.onnx import OnnxModel
from ..params import (
    DeviceParams,
    ImageParams,
    Size,
    SizeChart,
    StageParams,
    UpscaleParams,
)
from ..server import ModelTypes, ServerContext
from ..utils import run_gc
from ..worker import WorkerContext
from .base import BaseStage
from .result import StageResult

logger = getLogger(__name__)


class UpscaleDATStage(BaseStage):
    max_tile = SizeChart.micro

    def load(
        self,
        server: ServerContext,
        _stage: StageParams,
        upscale: UpscaleParams,
        device: DeviceParams,
    ):
        # must be within the load function for patch to take effect
        model_path = path.join(server.model_path, "%s.onnx" % (upscale.upscale_model))
        cache_key = (model_path,)
        cache_pipe = server.cache.get(ModelTypes.upscaling, cache_key)

        if cache_pipe is not None:
            logger.debug("reusing existing DAT pipeline")
            return cache_pipe

        logger.info("loading DAT model from %s", model_path)

        pipe = OnnxModel(
            server,
            model_path,
            provider=device.ort_provider(),
            sess_options=device.sess_options(),
        )

        server.cache.set(ModelTypes.upscaling, cache_key, pipe)
        run_gc([device])

        return pipe

    def run(
        self,
        worker: WorkerContext,
        server: ServerContext,
        stage: StageParams,
        _params: ImageParams,
        sources: StageResult,
        *,
        upscale: UpscaleParams,
        stage_source: Optional[Image.Image] = None,
        **kwargs,
    ) -> StageResult:
        upscale = upscale.with_args(**kwargs)

        if upscale.upscale_model is None:
            logger.warning("no upscaling model given, skipping")
            return sources

        logger.info("upscaling with DAT model: %s", upscale.upscale_model)
        device = worker.get_device()
        dat = self.load(server, stage, upscale, device)

        outputs = []
        for source in sources.as_numpy():
            # scale to [0, 1], reverse the channel order, and add a batch axis (HWC to NCHW)
            image = source / 255.0
            image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
            image = np.expand_dims(image, axis=0)
            logger.trace("DAT input shape: %s", image.shape)

            scale = upscale.outscale
            logger.trace(
                "DAT output shape: %s",
                (
                    image.shape[0],
                    image.shape[1],
                    image.shape[2] * scale,
                    image.shape[3] * scale,
                ),
            )

            output = dat(image)

            # drop the batch axis, restore the channel order and HWC layout, and convert back to uint8
            output = np.clip(np.squeeze(output, axis=0), 0, 1)
            output = output[[2, 1, 0], :, :].transpose((1, 2, 0))
            output = (output * 255.0).round().astype(np.uint8)

            logger.debug("output image shape: %s", output.shape)
            outputs.append(output)

        return StageResult(arrays=outputs)

    def steps(
        self,
        params: ImageParams,
        size: Size,
    ) -> int:
        tile = min(params.unet_tile, self.max_tile)
        return size.width // tile * size.height // tile
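
The stage scales each source image to [0, 1], reverses the channel order, and hands an NCHW float32 batch to the cached ONNX session, then reverses those steps on the result. The sketch below runs the same round trip against a converted DAT model with onnxruntime directly; the model and image file names are placeholders, and the "input" tensor name follows the converter further down.

import numpy as np
import onnxruntime
from PIL import Image

# load a converted DAT model (placeholder path) on CPU
session = onnxruntime.InferenceSession(
    "models/upscaling-dat-x4.onnx", providers=["CPUExecutionProvider"]
)

# same preprocessing as UpscaleDATStage.run: [0, 1] range, reversed channels, NCHW batch
source = np.array(Image.open("input.png").convert("RGB"))
image = source / 255.0
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
image = np.expand_dims(image, axis=0)

output = session.run(None, {"input": image})[0]

# undo the preprocessing and save the upscaled result
output = np.clip(np.squeeze(output, axis=0), 0, 1)
output = output[[2, 1, 0], :, :].transpose((1, 2, 0))
Image.fromarray((output * 255.0).round().astype(np.uint8)).save("output.png")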

@@ -25,6 +25,7 @@ from .diffusion.diffusion_xl import convert_diffusion_diffusers_xl
 from .diffusion.lora import blend_loras
 from .diffusion.textual_inversion import blend_textual_inversions
 from .upscaling.bsrgan import convert_upscaling_bsrgan
+from .upscaling.dat import convert_upscaling_dat
 from .upscaling.resrgan import convert_upscale_resrgan
 from .upscaling.swinir import convert_upscaling_swinir
 from .utils import (
@@ -395,7 +396,9 @@ def convert_model_upscaling(conversion: ConversionContext, model):
     model_type = model.get("model", "resrgan")
     if model_type == "bsrgan":
         convert_upscaling_bsrgan(conversion, model, source)
-    elif model_type == "resrgan":
+    elif model_type == "dat":
+        convert_upscaling_dat(conversion, model, source)
+    elif model_type in ["esrgan", "resrgan"]:
         convert_upscale_resrgan(conversion, model, source)
     elif model_type == "swinir":
         convert_upscaling_swinir(conversion, model, source)

@@ -0,0 +1,70 @@
from logging import getLogger
from os import path

import torch
from torch.onnx import export

from ...models.dat import DAT
from ..utils import ConversionContext, ModelDict

logger = getLogger(__name__)


@torch.no_grad()
def convert_upscaling_dat(
    conversion: ConversionContext,
    model: ModelDict,
    source: str,
):
    name = model.get("name")
    source = source or model.get("source")
    scale = model.get("scale", 1)

    dest = path.join(conversion.model_path, name + ".onnx")
    logger.info("converting DAT model: %s -> %s", name, dest)

    if path.isfile(dest):
        logger.info("ONNX model already exists, skipping")
        return

    model = DAT(
        # TODO: params
        num_in_ch=3,
        num_out_ch=3,
        num_feat=64,
        num_block=23,
        num_grow_ch=32,
        scale=scale,
    )

    torch_model = torch.load(source, map_location=conversion.map_location)
    if "params_ema" in torch_model:
        model.load_state_dict(torch_model["params_ema"], strict=False)
    elif "params" in torch_model:
        model.load_state_dict(torch_model["params"], strict=False)
    else:
        model.load_state_dict(torch_model, strict=False)

    model.to(conversion.training_device).train(False)
    model.eval()

    rng = torch.rand(1, 3, 64, 64, device=conversion.map_location)
    input_names = ["input"]
    output_names = ["output"]
    dynamic_axes = {
        "input": {2: "h", 3: "w"},
        "output": {2: "h", 3: "w"},
    }

    logger.info("exporting ONNX model to %s", dest)
    export(
        model,
        rng,
        dest,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        opset_version=conversion.opset,
        export_params=True,
    )

    logger.info("DAT exported to ONNX successfully")

api/onnx_web/models/dat.py: new file, 1096 additions

File diff suppressed because it is too large.

@@ -145,6 +145,8 @@ $defs:
     type: string
     enum: [
       bsrgan,
+      dat,
+      esrgan,
       resrgan,
       swinir
     ]
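
With the enum extended, an extras entry can select the dat model type. As a rough illustration, this is the kind of entry the converter above would consume, limited to the ModelDict keys it actually reads (name, source, scale, plus the dispatching model field); the name and path are placeholders, and any other fields the schema requires are omitted.

# hypothetical extras entry for a DAT upscaler; values are placeholders
dat_upscaler = {
    "name": "upscaling-dat-x4",
    "model": "dat",  # new enum value added above
    "scale": 4,
    "source": "/downloads/DAT_x4.pth",
}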