
Commit

complete the conformer, just layers of conformer blocks, to ready for soundstorm
lucidrains committed May 17, 2023
1 parent 0e91f03 commit 7a3ba7c
Showing 4 changed files with 86 additions and 5 deletions.
31 changes: 31 additions & 0 deletions README.md
@@ -49,8 +49,39 @@ block = ConformerBlock(
)

x = torch.randn(1, 1024, 512)

block(x) # (1, 1024, 512)
```

Conformer - just multiple `ConformerBlock` from above

```python
import torch
from conformer import Conformer

conformer = Conformer(
dim = 512,
depth = 12, # 12 blocks
dim_head = 64,
heads = 8,
ff_mult = 4,
conv_expansion_factor = 2,
conv_kernel_size = 31,
attn_dropout = 0.,
ff_dropout = 0.,
conv_dropout = 0.
)

x = torch.randn(1, 1024, 512)

conformer(x) # (1, 1024, 512)
```

## Todo

- [ ] switch to a better relative positional encoding. shaw's is dated
- [ ] flash attention with a better RPE
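
On the flash attention item: below is a minimal sketch (not part of this commit) of PyTorch's fused `torch.nn.functional.scaled_dot_product_attention` entry point, and of why Shaw-style per-pair score offsets sit awkwardly with it. The tensor shapes mirror the examples above; everything else is an assumption about the intended direction, not repository code.

```python
# Hedged sketch, not repository code: PyTorch 2.0+'s fused attention call.
# A Shaw-style relative position bias can only enter as an additive float attn_mask,
# which generally rules out the flash kernel and forces a fallback backend; presumably
# this is why the todo pairs flash attention with "a better RPE" (e.g. rotary embeddings,
# which rotate q/k before attention instead of biasing the scores).
import torch
import torch.nn.functional as F

b, h, n, d = 1, 8, 1024, 64
q, k, v = (torch.randn(b, h, n, d) for _ in range(3))

# plain fused attention (eligible for the flash kernel on supported hardware)
out = F.scaled_dot_product_attention(q, k, v)

# with an additive per-pair bias passed as a float mask (illustrative values)
rel_bias = torch.randn(n, n)
out_biased = F.scaled_dot_product_attention(q, k, v, attn_mask = rel_bias)

print(out.shape, out_biased.shape)  # both torch.Size([1, 8, 1024, 64])
```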

## Citations

```bibtex
2 changes: 1 addition & 1 deletion conformer/__init__.py
@@ -1 +1 @@
- from conformer.conformer import ConformerConvModule, ConformerBlock
+ from conformer.conformer import ConformerConvModule, ConformerBlock, Conformer
47 changes: 46 additions & 1 deletion conformer/conformer.py
@@ -85,7 +85,13 @@ def __init__(

self.dropout = nn.Dropout(dropout)

- def forward(self, x, context = None, mask = None, context_mask = None):
+ def forward(
+     self,
+     x,
+     context = None,
+     mask = None,
+     context_mask = None
+ ):
n, device, h, max_pos_emb, has_context = x.shape[-2], x.device, self.heads, self.max_pos_emb, exists(context)
context = default(context, x)

@@ -95,6 +101,7 @@ def forward(self, x, context = None, mask = None, context_mask = None):
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

# shaw's relative positional embedding

seq = torch.arange(n, device = device)
dist = rearrange(seq, 'i -> i ()') - rearrange(seq, 'j -> () j')
dist = dist.clamp(-max_pos_emb, max_pos_emb) + max_pos_emb
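
For reference, a self-contained sketch of the Shaw-style relative positional term this hunk touches; it mirrors the clamp-and-shift indexing and the `einsum` pattern used in the file, but the standalone names (`scale`, `pos_attn`, the small `n`) are illustrative rather than taken from this commit.

```python
# Standalone illustration of Shaw-style relative positional attention logits.
import torch
from torch import nn, einsum
from einops import rearrange

n, heads, dim_head, max_pos_emb = 256, 8, 64, 512  # small n keeps the (n, n, dim_head) lookup light
scale = dim_head ** -0.5

q = torch.randn(1, heads, n, dim_head)
k = torch.randn(1, heads, n, dim_head)
rel_pos_emb = nn.Embedding(2 * max_pos_emb + 1, dim_head)  # one vector per clamped relative offset

# pairwise offsets i - j, clamped to [-max_pos_emb, max_pos_emb] and shifted to valid indices
seq = torch.arange(n)
dist = rearrange(seq, 'i -> i ()') - rearrange(seq, 'j -> () j')
dist = dist.clamp(-max_pos_emb, max_pos_emb) + max_pos_emb

# content logits plus query-dependent positional logits
dots = einsum('b h i d, b h j d -> b h i j', q, k) * scale
pos_attn = einsum('b h n d, n r d -> b h n r', q, rel_pos_emb(dist)) * scale
dots = dots + pos_attn

print(dots.shape)  # torch.Size([1, 8, 256, 256])
```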
@@ -199,3 +206,41 @@ def forward(self, x, mask = None):
x = self.ff2(x) + x
x = self.post_norm(x)
return x

# Conformer

class Conformer(nn.Module):
def __init__(
self,
dim,
*,
depth,
dim_head = 64,
heads = 8,
ff_mult = 4,
conv_expansion_factor = 2,
conv_kernel_size = 31,
attn_dropout = 0.,
ff_dropout = 0.,
conv_dropout = 0.
):
super().__init__()
self.layers = nn.ModuleList([])

for _ in range(depth):
self.layers.append(ConformerBlock(
dim = dim,
dim_head = dim_head,
heads = heads,
ff_mult = ff_mult,
conv_expansion_factor = conv_expansion_factor,
conv_kernel_size = conv_kernel_size,

))

def forward(self, x):

for block in self.layers:
x = block(x)

return x
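
As a quick check of the commit message's framing ("just layers of conformer blocks"), the sketch below, which is illustrative and not part of the commit, applies the blocks in `conformer.layers` by hand and compares against the wrapper's forward pass.

```python
# Illustrative only: the Conformer wrapper should match applying its blocks sequentially.
import torch
from conformer import Conformer

conformer = Conformer(dim = 512, depth = 2)
conformer.eval()  # keep batchnorm in the conv module deterministic

x = torch.randn(1, 128, 512)

with torch.no_grad():
    out_wrapper = conformer(x)

    out_manual = x
    for block in conformer.layers:  # the same ModuleList the forward pass iterates over
        out_manual = block(out_manual)

assert torch.allclose(out_wrapper, out_manual)
```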
11 changes: 8 additions & 3 deletions setup.py
@@ -3,15 +3,20 @@
setup(
name = 'conformer',
packages = find_packages(),
- version = '0.2.5',
+ version = '0.3.0',
license='MIT',
description = 'The convolutional module from the Conformer paper',
author = 'Phil Wang',
author_email = '[email protected]',
url = 'https://github.com/lucidrains/conformer',
- keywords = ['transformers', 'artificial intelligence', 'transformer'],
+ keywords = [
+     'artificial intelligence',
+     'deep learning',
+     'transformers',
+     'audio'
+ ],
install_requires=[
- 'einops',
+ 'einops>=0.6.1',
'torch'
],
classifiers=[
