adding_some_attention_modules #117

Open
wants to merge 1 commit into base: master
176 changes: 173 additions & 3 deletions README.md
@@ -205,6 +205,18 @@ if __name__ == '__main__':

- [37. Axial_attention Attention Usage](#37-Axial_attention-Attention-Usage)

- [38. Frequency Channel Attention Usage](#38-Frequency-Channel-Attention-Usage)

- [39. Attention Augmented Convolutional Networks Usage](#39-Attention-Augmented-Convolutional-Networks-Usage)

- [40. Global Context Attention Usage](#40-Global-Context-Attention-Usage)

- [41. Linear Context Transform Attention Usage](#41-Linear-Context-Transform-Attention-Usage)

- [42. Gated Channel Transformation Usage](#42-Gated-Channel-Transformation-Usage)

- [43. Gaussian Context Attention Usage](#43-Gaussian-Context-Attention-Usage)

- [Backbone Series](#Backbone-series)

- [1. ResNet Usage](#1-ResNet-Usage)
@@ -427,10 +439,10 @@ print(output.shape)

### 3. Simplified Self Attention Usage
#### 3.1. Paper
[None]()
[SimAM: A Simple, Parameter-Free Attention Module for Convolutional Neural Networks (ICML 2021)](https://proceedings.mlr.press/v139/yang21o/yang21o.pdf)

#### 3.2. Overview
![](./model/img/SSA.png)
![](./model/img/SimAttention.png)

#### 3.3. Usage Code
```python
@@ -1184,7 +1196,7 @@ if __name__ == '__main__':

```

-
***

### 31. ACmix Attention Usage

@@ -1205,6 +1217,7 @@ if __name__ == '__main__':
    print(output.shape)

```
***

### 32. MobileViTv2 Attention Usage

@@ -1232,6 +1245,7 @@ if __name__ == '__main__':
    print(output.shape)

```
***

### 33. DAT Attention Usage

@@ -1276,6 +1290,7 @@ if __name__ == '__main__':
    print(output[0].shape)

```
***

### 34. CrossFormer Attention Usage

@@ -1313,6 +1328,7 @@ if __name__ == '__main__':
    print(output.shape)

```
***

### 35. MOATransformer Attention Usage

@@ -1350,6 +1366,7 @@ if __name__ == '__main__':
    print(output.shape)

```
***

### 36. CrissCrossAttention Attention Usage

@@ -1370,6 +1387,7 @@ if __name__ == '__main__':
    print(outputs.shape)

```
***

### 37. Axial_attention Attention Usage

@@ -1393,6 +1411,158 @@ if __name__ == '__main__':
    outputs = model(input)
    print(outputs.shape)

```
***

### 38. Frequency Channel Attention Usage

#### 38.1. Paper

[FcaNet: Frequency Channel Attention Networks (ICCV 2021)](https://arxiv.org/abs/2012.11879)

#### 38.2. Overview

![](./model/img/FCANet.png)

#### 38.3. Usage Code

```python
from model.attention.FCA import MultiSpectralAttentionLayer
import torch

if __name__ == "__main__":
    input = torch.randn(32, 128, 64, 64)  # (b, c, h, w)
    fca_layer = MultiSpectralAttentionLayer(channel = 128, dct_h = 64, dct_w = 64, reduction = 16, freq_sel_method = 'top16')
    output = fca_layer(input)
    print(output.shape)

```
***

### 39. Attention Augmented Convolutional Networks Usage

#### 39.1. Paper

[Attention Augmented Convolutional Networks (ICCV 2019)](https://arxiv.org/abs/1904.09925)

#### 39.2. Overview

![](./model/img/AAAttention.png)

#### 39.3. Usage Code

```python
from model.attention.AAAttention import AugmentedConv
import torch

if __name__ == "__main__":
    input = torch.randn((16, 3, 32, 32))
    augmented_conv = AugmentedConv(in_channels=3, out_channels=64, kernel_size=3, dk=40, dv=4, Nh=4, relative=True, stride=2, shape=16)
    output = augmented_conv(input)
    print(output.shape)

```
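Note: the module asserts that `dk` and `dv` are divisible by `Nh`, and `out_channels` must be larger than `dv` (the convolutional branch produces `out_channels - dv` channels). With `relative=True`, `shape` should equal the spatial size of the attention map, here 32 / stride = 16.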
***

### 40. Global Context Attention Usage

#### 40.1. Paper

[GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond (ICCVW 2019 Best Paper)](https://arxiv.org/abs/1904.11492)

[Global Context Networks (TPAMI 2020)](https://arxiv.org/abs/2012.13375)

#### 40.2. Overview

![](./model/img/GCNet.png)

#### 40.3. Usage Code

```python
from model.attention.GCAttention import GCModule
import torch

if __name__ == "__main__":
    input = torch.randn(16, 64, 32, 32)
    gc_layer = GCModule(64)
    output = gc_layer(input)
    print(output.shape)

```
***

### 41. Linear Context Transform Attention Usage

#### 41.1. Paper

[Linear Context Transform Block (AAAI 2020)](https://arxiv.org/pdf/1909.03834v2)

#### 41.2. Overview

![](./model/img/LCTAttention.png)

#### 41.3. Usage Code

```python
from model.attention.LCTAttention import LCT
import torch

if __name__ == "__main__":
    x = torch.randn(16, 64, 32, 32)
    attn = LCT(64, 8)
    y = attn(x)
    print(y.shape)

```
***

### 42. Gated Channel Transformation Usage

#### 42.1. Paper

[Gated Channel Transformation for Visual Recognition (CVPR 2020)](https://openaccess.thecvf.com/content_CVPR_2020/papers/Yang_Gated_Channel_Transformation_for_Visual_Recognition_CVPR_2020_paper.pdf)

#### 42.2. Overview

![](./model/img/GCT.png)

#### 42.3. Usage Code

```python
from model.attention.GCTAttention import GCT
import torch

if __name__ == "__main__":
    input = torch.randn(16, 64, 32, 32)
    gct_layer = GCT(64)
    output = gct_layer(input)
    print(output.shape)

```
***

### 43. Gaussian Context Attention Usage

#### 43.1. Paper

[Gaussian Context Transformer (CVPR 2021)](https://openaccess.thecvf.com//content/CVPR2021/papers/Ruan_Gaussian_Context_Transformer_CVPR_2021_paper.pdf)

#### 43.2. Overview

![](./model/img/GaussianCA.png)

#### 43.3. Usage Code

```python
from model.attention.GaussianAttention import GCA
import torch

if __name__ == "__main__":
    input = torch.randn(16, 64, 32, 32)
    gca_layer = GCA(64)
    output = gca_layer(input)
    print(output.shape)

```

***
Binary file modified model/__pycache__/__init__.cpython-38.pyc
137 changes: 137 additions & 0 deletions model/attention/AAAttention.py
@@ -0,0 +1,137 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


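# Attention-Augmented Convolution (Bello et al., ICCV 2019): runs a standard convolution and
# multi-head 2D self-attention in parallel and concatenates their outputs along the channel dim.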
class AugmentedConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, dk, dv, Nh, shape=0, relative=False, stride=1):
        super(AugmentedConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.dk = dk
        self.dv = dv
        self.Nh = Nh
        self.shape = shape
        self.relative = relative
        self.stride = stride
        self.padding = (self.kernel_size - 1) // 2

        assert self.Nh != 0, "Nh must be at least 1"
        assert self.dk % self.Nh == 0, "dk should be divisible by Nh (example: out_channels: 20, dk: 40, Nh: 4)"
        assert self.dv % self.Nh == 0, "dv should be divisible by Nh (example: out_channels: 20, dv: 4, Nh: 4)"
        assert stride in [1, 2], "stride must be 1 or 2, got " + str(stride)

        self.conv_out = nn.Conv2d(self.in_channels, self.out_channels - self.dv, self.kernel_size, stride=stride, padding=self.padding)

        self.qkv_conv = nn.Conv2d(self.in_channels, 2 * self.dk + self.dv, kernel_size=self.kernel_size, stride=stride, padding=self.padding)

        self.attn_out = nn.Conv2d(self.dv, self.dv, kernel_size=1, stride=1)

        if self.relative:
            self.key_rel_w = nn.Parameter(torch.randn((2 * self.shape - 1, dk // Nh), requires_grad=True))
            self.key_rel_h = nn.Parameter(torch.randn((2 * self.shape - 1, dk // Nh), requires_grad=True))

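    # Forward pass: the convolutional branch yields (out_channels - dv) channels, the attention
    # branch yields dv channels, and the two are concatenated along dim=1.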
    def forward(self, x):
        # Input x
        # (batch_size, channels, height, width)
        # batch, _, height, width = x.size()

        # conv_out
        # (batch_size, out_channels - dv, height, width)
        conv_out = self.conv_out(x)
        batch, _, height, width = conv_out.size()

        # flat_q, flat_k, flat_v
        # (batch_size, Nh, height * width, dvh or dkh)
        # dvh = dv / Nh, dkh = dk / Nh
        # q, k, v
        # (batch_size, Nh, dkh or dvh, height, width)
        flat_q, flat_k, flat_v, q, k, v = self.compute_flat_qkv(x, self.dk, self.dv, self.Nh)
        logits = torch.matmul(flat_q.transpose(2, 3), flat_k)
        if self.relative:
            h_rel_logits, w_rel_logits = self.relative_logits(q)
            logits += h_rel_logits
            logits += w_rel_logits
        weights = F.softmax(logits, dim=-1)

        # attn_out
        # (batch, Nh, height * width, dvh)
        attn_out = torch.matmul(weights, flat_v.transpose(2, 3))
        attn_out = torch.reshape(attn_out, (batch, self.Nh, self.dv // self.Nh, height, width))
        # combine_heads_2d
        # (batch, dv, height, width)
        attn_out = self.combine_heads_2d(attn_out)
        attn_out = self.attn_out(attn_out)
        return torch.cat((conv_out, attn_out), dim=1)

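    # Projects x with a single convolution into q, k, v, splits each into Nh heads, scales q by
    # dkh ** -0.5, and flattens the spatial dims to shape (N, Nh, d_head, H * W).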
    def compute_flat_qkv(self, x, dk, dv, Nh):
        qkv = self.qkv_conv(x)
        N, _, H, W = qkv.size()
        q, k, v = torch.split(qkv, [dk, dk, dv], dim=1)
        q = self.split_heads_2d(q, Nh)
        k = self.split_heads_2d(k, Nh)
        v = self.split_heads_2d(v, Nh)

        dkh = dk // Nh
        q = q * (dkh ** -0.5)
        flat_q = torch.reshape(q, (N, Nh, dk // Nh, H * W))
        flat_k = torch.reshape(k, (N, Nh, dk // Nh, H * W))
        flat_v = torch.reshape(v, (N, Nh, dv // Nh, H * W))
        return flat_q, flat_k, flat_v, q, k, v

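    # Reshapes (B, C, H, W) into (B, Nh, C // Nh, H, W), splitting the channels into Nh heads.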
    def split_heads_2d(self, x, Nh):
        batch, channels, height, width = x.size()
        ret_shape = (batch, Nh, channels // Nh, height, width)
        split = torch.reshape(x, ret_shape)
        return split

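    # Inverse of split_heads_2d: merges the head and per-head channel dims into one channel dim.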
    def combine_heads_2d(self, x):
        batch, Nh, dv, H, W = x.size()
        ret_shape = (batch, Nh * dv, H, W)
        return torch.reshape(x, ret_shape)

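    # Computes 2D relative-position logits: relative_logits_1d is applied along the width with
    # key_rel_w and, on the transposed tensor, along the height with key_rel_h.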
    def relative_logits(self, q):
        B, Nh, dk, H, W = q.size()
        q = torch.transpose(q, 2, 4).transpose(2, 3)

        rel_logits_w = self.relative_logits_1d(q, self.key_rel_w, H, W, Nh, "w")
        rel_logits_h = self.relative_logits_1d(torch.transpose(q, 2, 3), self.key_rel_h, W, H, Nh, "h")

        return rel_logits_h, rel_logits_w

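    # Computes relative logits along one axis against the learned embeddings rel_k, converts them
    # from relative to absolute indexing, and broadcasts them to shape (B, Nh, H * W, H * W).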
    def relative_logits_1d(self, q, rel_k, H, W, Nh, case):
        rel_logits = torch.einsum('bhxyd,md->bhxym', q, rel_k)
        rel_logits = torch.reshape(rel_logits, (-1, Nh * H, W, 2 * W - 1))
        rel_logits = self.rel_to_abs(rel_logits)

        rel_logits = torch.reshape(rel_logits, (-1, Nh, H, W, W))
        rel_logits = torch.unsqueeze(rel_logits, dim=3)
        rel_logits = rel_logits.repeat((1, 1, 1, H, 1, 1))

        if case == "w":
            rel_logits = torch.transpose(rel_logits, 3, 4)
        elif case == "h":
            rel_logits = torch.transpose(rel_logits, 2, 4).transpose(4, 5).transpose(3, 5)
        rel_logits = torch.reshape(rel_logits, (-1, Nh, H * W, H * W))
        return rel_logits

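    # Converts logits indexed by relative offset (length 2L - 1) into absolute (L x L) position
    # logits via the pad-and-reshape trick.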
    def rel_to_abs(self, x):
        B, Nh, L, _ = x.size()

        col_pad = torch.zeros((B, Nh, L, 1)).to(x)
        x = torch.cat((x, col_pad), dim=3)

        flat_x = torch.reshape(x, (B, Nh, L * 2 * L))
        flat_pad = torch.zeros((B, Nh, L - 1)).to(x)
        flat_x_padded = torch.cat((flat_x, flat_pad), dim=2)

        final_x = torch.reshape(flat_x_padded, (B, Nh, L + 1, 2 * L - 1))
        final_x = final_x[:, :, :L, L - 1:]
        return final_x

if __name__ == "__main__":
    input = torch.randn((16, 3, 32, 32))
    augmented_conv = AugmentedConv(in_channels=3, out_channels=20, kernel_size=3, dk=40, dv=4, Nh=4, relative=True, stride=2, shape=16)
    output = augmented_conv(input)
    print(output.shape)