
use ref arch with einops

Sean Sube 2023-12-31 12:25:22 -06:00
parent 9aad798a4a
commit 759db71ac7
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed file with 12 additions and 4 deletions


@@ -6,6 +6,8 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch.utils.checkpoint as checkpoint
+from einops import rearrange
+from einops.layers.torch import Rearrange
 from timm.models.layers import DropPath, trunc_normal_
 from torch.nn import functional as F
@@ -834,9 +836,13 @@ class ResidualGroup(nn.Module):
                 x = checkpoint.checkpoint(blk, x, x_size)
             else:
                 x = blk(x, x_size)
-        x = torch.einsum("b (h w) c -> b c h w", x) # h=H, w=W)
+        x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W)
         x = self.conv(x)
-        x = torch.einsum("b c h w -> b (h w) c", x)
+        x = rearrange(x, "b c h w -> b (h w) c")
+        # x = torch.einsum("b (h w) c -> b c h w", x) # h=H, w=W)
+        # x = self.conv(x)
+        # x = torch.einsum("b c h w -> b (h w) c", x)
         x = res + x
         return x
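A minimal sketch of the pattern this hunk adopts in ResidualGroup: rearrange splits the flattened token axis back into explicit spatial axes for the conv, then flattens again. The shapes and the Conv2d standing in for self.conv are illustrative assumptions, not values from the repo.

    import torch
    import torch.nn as nn
    from einops import rearrange

    B, H, W, C = 2, 8, 8, 16
    conv = nn.Conv2d(C, C, kernel_size=3, padding=1)  # stand-in for self.conv

    x = torch.randn(B, H * W, C)                        # (B, H*W, C) token tensor
    x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W)  # un-flatten into a feature map
    x = conv(x)
    x = rearrange(x, "b c h w -> b (h w) c")            # flatten back to tokens
    print(x.shape)  # torch.Size([2, 64, 16])

Note that torch.einsum's subscript notation can only permute, sum, or contract axes; it has no way to split the (h w) axis into separate h and w axes, which is exactly what the rearrange pattern with explicit h=H, w=W does here.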
@@ -961,7 +967,8 @@ class DAT(nn.Module):
         heads = num_heads
         self.before_RG = nn.Sequential(
-            torch.einsum("b c h w -> b (h w) c"), # TODO: will this curry?
+            Rearrange("b c h w -> b (h w) c"),
+            # torch.einsum("b c h w -> b (h w) c"), # TODO: will this curry?
             nn.LayerNorm(embed_dim),
         )
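The Rearrange layer from einops.layers.torch is an nn.Module, so it can sit inside nn.Sequential and be applied when the Sequential is called; a bare torch.einsum expression cannot, since it is a plain function that needs its tensor argument immediately (the concern behind the removed "will this curry?" TODO). A small sketch with an assumed embed_dim, not taken from the model config:

    import torch
    import torch.nn as nn
    from einops.layers.torch import Rearrange

    embed_dim = 32  # illustrative value
    before_RG = nn.Sequential(
        Rearrange("b c h w -> b (h w) c"),  # module form: pattern stored, applied in forward()
        nn.LayerNorm(embed_dim),
    )
    tokens = before_RG(torch.randn(2, embed_dim, 8, 8))
    print(tokens.shape)  # torch.Size([2, 64, 32])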
@@ -1040,7 +1047,8 @@ class DAT(nn.Module):
         for layer in self.layers:
             x = layer(x, x_size)
         x = self.norm(x)
-        x = torch.einsum("b (h w) c -> b c h w", x) # h=H, w=W)
+        x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W)
+        # x = torch.einsum("b (h w) c -> b c h w", x) # h=H, w=W)
         return x
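As a quick sanity check on the final hunk, rearrange with explicit h and w is an exact inverse of the earlier flatten; a tiny round-trip with illustrative shapes:

    import torch
    from einops import rearrange

    H, W = 6, 5
    img = torch.randn(2, 3, H, W)
    tokens = rearrange(img, "b c h w -> b (h w) c")
    back = rearrange(tokens, "b (h w) c -> b c h w", h=H, w=W)
    assert torch.equal(img, back)  # pure permute/reshape, so values match exactly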