use ref arch with einops
This commit is contained in:
parent
9aad798a4a
commit
759db71ac7
|
@ -6,6 +6,8 @@ import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.utils.checkpoint as checkpoint
|
import torch.utils.checkpoint as checkpoint
|
||||||
|
from einops import rearrange
|
||||||
|
from einops.layers.torch import Rearrange
|
||||||
from timm.models.layers import DropPath, trunc_normal_
|
from timm.models.layers import DropPath, trunc_normal_
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
@ -834,9 +836,13 @@ class ResidualGroup(nn.Module):
|
||||||
x = checkpoint.checkpoint(blk, x, x_size)
|
x = checkpoint.checkpoint(blk, x, x_size)
|
||||||
else:
|
else:
|
||||||
x = blk(x, x_size)
|
x = blk(x, x_size)
|
||||||
x = torch.einsum("b (h w) c -> b c h w", x) # h=H, w=W)
|
|
||||||
|
x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W)
|
||||||
x = self.conv(x)
|
x = self.conv(x)
|
||||||
x = torch.einsum("b c h w -> b (h w) c", x)
|
x = rearrange(x, "b c h w -> b (h w) c")
|
||||||
|
# x = torch.einsum("b (h w) c -> b c h w", x) # h=H, w=W)
|
||||||
|
# x = self.conv(x)
|
||||||
|
# x = torch.einsum("b c h w -> b (h w) c", x)
|
||||||
x = res + x
|
x = res + x
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
@ -961,7 +967,8 @@ class DAT(nn.Module):
|
||||||
heads = num_heads
|
heads = num_heads
|
||||||
|
|
||||||
self.before_RG = nn.Sequential(
|
self.before_RG = nn.Sequential(
|
||||||
torch.einsum("b c h w -> b (h w) c"), # TODO: will this curry?
|
Rearrange("b c h w -> b (h w) c"),
|
||||||
|
# torch.einsum("b c h w -> b (h w) c"), # TODO: will this curry?
|
||||||
nn.LayerNorm(embed_dim),
|
nn.LayerNorm(embed_dim),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1040,7 +1047,8 @@ class DAT(nn.Module):
|
||||||
for layer in self.layers:
|
for layer in self.layers:
|
||||||
x = layer(x, x_size)
|
x = layer(x, x_size)
|
||||||
x = self.norm(x)
|
x = self.norm(x)
|
||||||
x = torch.einsum("b (h w) c -> b c h w", x) # h=H, w=W)
|
x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W)
|
||||||
|
# x = torch.einsum("b (h w) c -> b c h w", x) # h=H, w=W)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue