1
0
Fork 0

use ref arch with einops

This commit is contained in:
Sean Sube 2023-12-31 12:25:22 -06:00
parent 9aad798a4a
commit 759db71ac7
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 12 additions and 4 deletions

View File

@ -6,6 +6,8 @@ import numpy as np
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from einops import rearrange
from einops.layers.torch import Rearrange
from timm.models.layers import DropPath, trunc_normal_
from torch.nn import functional as F
@ -834,9 +836,13 @@ class ResidualGroup(nn.Module):
x = checkpoint.checkpoint(blk, x, x_size)
else:
x = blk(x, x_size)
x = torch.einsum("b (h w) c -> b c h w", x) # h=H, w=W)
x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W)
x = self.conv(x)
x = torch.einsum("b c h w -> b (h w) c", x)
x = rearrange(x, "b c h w -> b (h w) c")
# x = torch.einsum("b (h w) c -> b c h w", x) # h=H, w=W)
# x = self.conv(x)
# x = torch.einsum("b c h w -> b (h w) c", x)
x = res + x
return x
@ -961,7 +967,8 @@ class DAT(nn.Module):
heads = num_heads
self.before_RG = nn.Sequential(
torch.einsum("b c h w -> b (h w) c"), # TODO: will this curry?
Rearrange("b c h w -> b (h w) c"),
# torch.einsum("b c h w -> b (h w) c"), # TODO: will this curry?
nn.LayerNorm(embed_dim),
)
@ -1040,7 +1047,8 @@ class DAT(nn.Module):
for layer in self.layers:
x = layer(x, x_size)
x = self.norm(x)
x = torch.einsum("b (h w) c -> b c h w", x) # h=H, w=W)
x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W)
# x = torch.einsum("b (h w) c -> b c h w", x) # h=H, w=W)
return x