From 759db71ac7de32a5e231fded41bd5ee81ddc26ef Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 31 Dec 2023 12:25:22 -0600 Subject: [PATCH] use ref arch with einops --- api/onnx_web/models/dat.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/api/onnx_web/models/dat.py b/api/onnx_web/models/dat.py index bef039bd..36c73442 100644 --- a/api/onnx_web/models/dat.py +++ b/api/onnx_web/models/dat.py @@ -6,6 +6,8 @@ import numpy as np import torch import torch.nn as nn import torch.utils.checkpoint as checkpoint +from einops import rearrange +from einops.layers.torch import Rearrange from timm.models.layers import DropPath, trunc_normal_ from torch.nn import functional as F @@ -834,9 +836,13 @@ class ResidualGroup(nn.Module): x = checkpoint.checkpoint(blk, x, x_size) else: x = blk(x, x_size) - x = torch.einsum("b (h w) c -> b c h w", x) # h=H, w=W) + + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W) x = self.conv(x) - x = torch.einsum("b c h w -> b (h w) c", x) + x = rearrange(x, "b c h w -> b (h w) c") + # x = torch.einsum("b (h w) c -> b c h w", x) # h=H, w=W) + # x = self.conv(x) + # x = torch.einsum("b c h w -> b (h w) c", x) x = res + x return x @@ -961,7 +967,8 @@ class DAT(nn.Module): heads = num_heads self.before_RG = nn.Sequential( - torch.einsum("b c h w -> b (h w) c"), # TODO: will this curry? + Rearrange("b c h w -> b (h w) c"), + # torch.einsum("b c h w -> b (h w) c"), # TODO: will this curry? nn.LayerNorm(embed_dim), ) @@ -1040,7 +1047,8 @@ class DAT(nn.Module): for layer in self.layers: x = layer(x, x_size) x = self.norm(x) - x = torch.einsum("b (h w) c -> b c h w", x) # h=H, w=W) + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W) + # x = torch.einsum("b (h w) c -> b c h w", x) # h=H, w=W) return x