Skip to content

Commit

Permalink
Merge branch 'main' into moe
Browse files Browse the repository at this point in the history
  • Loading branch information
xingchensong committed Apr 14, 2024
2 parents dd0a055 + 9e20eb9 commit 4871bae
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/unit_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
max-parallel: 20
matrix:
os: [ubuntu-latest]
python-version: [3.10]
python-version: [3.10.14]
steps:
- name: Cache Python Packages
uses: actions/cache@v1
Expand Down
16 changes: 11 additions & 5 deletions wenet/transformer/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,9 @@ def forward(
v = torch.cat([value_cache, v], dim=2)
# NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
# non-trivial to calculate `next_cache_start` here.
new_cache = torch.cat((k, v), dim=-1) if not self.training else cache
# new_cache = torch.cat((k, v), dim=-1) if not self.training else cache
new_cache = torch.cat(
(k, v), dim=-1) if not self.training else torch.zeros(0, 0, 0, 0)

# for multi query or multi group attention
if self.h_kv != self.h:
Expand Down Expand Up @@ -379,7 +381,8 @@ def forward(

# NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
# non-trivial to calculate `next_cache_start` here.
new_cache = torch.cat((k, v), dim=-1) if not self.training else cache
new_cache = torch.cat(
(k, v), dim=-1) if not self.training else torch.zeros(0, 0, 0, 0)

        # for multi query or multi group attention
        if self.h_kv != self.h:
Expand Down Expand Up @@ -472,7 +475,8 @@ def forward(

else:
q, k, v = self.forward_qkv(query, key, value)
new_cache = torch.cat((k, v), dim=-1) if not self.training else cache
new_cache = torch.cat(
(k, v), dim=-1) if not self.training else torch.zeros(0, 0, 0, 0)

            # for multi query or multi group attention
            if self.h_kv != self.h:
Expand Down Expand Up @@ -569,7 +573,8 @@ def forward(
dim=-1)
k = torch.cat([key_cache, k], dim=2)
v = torch.cat([value_cache, v], dim=2)
new_cache = torch.cat((k, v), dim=-1) if not self.training else cache
new_cache = torch.cat(
(k, v), dim=-1) if not self.training else torch.zeros(0, 0, 0, 0)

rel_k = self.rel_k_embed(
self._relative_indices(k.size(2), query.device)) # (t2, t2, d_k)
Expand Down Expand Up @@ -670,7 +675,8 @@ def forward(
dim=-1)
k = torch.cat([key_cache, k], dim=2)
v = torch.cat([value_cache, v], dim=2)
new_cache = torch.cat((k, v), dim=-1) if not self.training else cache
new_cache = torch.cat(
(k, v), dim=-1) if not self.training else torch.zeros(0, 0, 0, 0)

if self.h_kv != self.h:
k = torch.repeat_interleave(
Expand Down

0 comments on commit 4871bae

Please sign in to comment.