use ref arch with einops
This commit is contained in:
parent
9aad798a4a
commit
759db71ac7
|
@ -6,6 +6,8 @@ import numpy as np
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint as checkpoint
|
||||
from einops import rearrange
|
||||
from einops.layers.torch import Rearrange
|
||||
from timm.models.layers import DropPath, trunc_normal_
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
@ -834,9 +836,13 @@ class ResidualGroup(nn.Module):
|
|||
x = checkpoint.checkpoint(blk, x, x_size)
|
||||
else:
|
||||
x = blk(x, x_size)
|
||||
x = torch.einsum("b (h w) c -> b c h w", x) # h=H, w=W)
|
||||
|
||||
x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W)
|
||||
x = self.conv(x)
|
||||
x = torch.einsum("b c h w -> b (h w) c", x)
|
||||
x = rearrange(x, "b c h w -> b (h w) c")
|
||||
# x = torch.einsum("b (h w) c -> b c h w", x) # h=H, w=W)
|
||||
# x = self.conv(x)
|
||||
# x = torch.einsum("b c h w -> b (h w) c", x)
|
||||
x = res + x
|
||||
|
||||
return x
|
||||
|
@ -961,7 +967,8 @@ class DAT(nn.Module):
|
|||
heads = num_heads
|
||||
|
||||
self.before_RG = nn.Sequential(
|
||||
torch.einsum("b c h w -> b (h w) c"), # TODO: will this curry?
|
||||
Rearrange("b c h w -> b (h w) c"),
|
||||
# torch.einsum("b c h w -> b (h w) c"), # TODO: will this curry?
|
||||
nn.LayerNorm(embed_dim),
|
||||
)
|
||||
|
||||
|
@ -1040,7 +1047,8 @@ class DAT(nn.Module):
|
|||
for layer in self.layers:
|
||||
x = layer(x, x_size)
|
||||
x = self.norm(x)
|
||||
x = torch.einsum("b (h w) c -> b c h w", x) # h=H, w=W)
|
||||
x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W)
|
||||
# x = torch.einsum("b (h w) c -> b c h w", x) # h=H, w=W)
|
||||
|
||||
return x
|
||||
|
||||
|
|
Loading…
Reference in New Issue