Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix ut #2477

Merged
merged 2 commits into from
Apr 14, 2024
Merged

fix ut #2477

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/unit_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
max-parallel: 20
matrix:
os: [ubuntu-latest]
python-version: [3.10]
python-version: [3.10.14]
steps:
- name: Cache Python Packages
uses: actions/cache@v1
Expand Down
16 changes: 11 additions & 5 deletions wenet/transformer/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,9 @@ def forward(
v = torch.cat([value_cache, v], dim=2)
# NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
# non-trivial to calculate `next_cache_start` here.
new_cache = torch.cat((k, v), dim=-1) if not self.training else cache
# new_cache = torch.cat((k, v), dim=-1) if not self.training else cache
new_cache = torch.cat(
(k, v), dim=-1) if not self.training else torch.zeros(0, 0, 0, 0)

# for multi query or multi group attention
if self.h_kv != self.h:
Expand Down Expand Up @@ -379,7 +381,8 @@ def forward(

# NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
# non-trivial to calculate `next_cache_start` here.
new_cache = torch.cat((k, v), dim=-1) if not self.training else cache
new_cache = torch.cat(
(k, v), dim=-1) if not self.training else torch.zeros(0, 0, 0, 0)

# for multi query or multi groups attention
if self.h_kv != self.h:
Expand Down Expand Up @@ -472,7 +475,8 @@ def forward(

else:
q, k, v = self.forward_qkv(query, key, value)
new_cache = torch.cat((k, v), dim=-1) if not self.training else cache
new_cache = torch.cat(
(k, v), dim=-1) if not self.training else torch.zeros(0, 0, 0, 0)

# for multi query or multi groups attention
if self.h_kv != self.h:
Expand Down Expand Up @@ -569,7 +573,8 @@ def forward(
dim=-1)
k = torch.cat([key_cache, k], dim=2)
v = torch.cat([value_cache, v], dim=2)
new_cache = torch.cat((k, v), dim=-1) if not self.training else cache
new_cache = torch.cat(
(k, v), dim=-1) if not self.training else torch.zeros(0, 0, 0, 0)

rel_k = self.rel_k_embed(
self._relative_indices(k.size(2), query.device)) # (t2, t2, d_k)
Expand Down Expand Up @@ -670,7 +675,8 @@ def forward(
dim=-1)
k = torch.cat([key_cache, k], dim=2)
v = torch.cat([value_cache, v], dim=2)
new_cache = torch.cat((k, v), dim=-1) if not self.training else cache
new_cache = torch.cat(
(k, v), dim=-1) if not self.training else torch.zeros(0, 0, 0, 0)

if self.h_kv != self.h:
k = torch.repeat_interleave(
Expand Down
Loading