From 9e20eb9c330811e07e8e19dbece397ac0636746a Mon Sep 17 00:00:00 2001 From: Dinghao Zhou Date: Sun, 14 Apr 2024 22:58:32 +0800 Subject: [PATCH] fix ut (#2477) * fix ut * fix py version --- .github/workflows/unit_test.yml | 2 +- wenet/transformer/attention.py | 16 +++++++++++----- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/.github/workflows/unit_test.yml b/.github/workflows/unit_test.yml index 0450203fa..a6122c0a5 100644 --- a/.github/workflows/unit_test.yml +++ b/.github/workflows/unit_test.yml @@ -12,7 +12,7 @@ jobs: max-parallel: 20 matrix: os: [ubuntu-latest] - python-version: [3.10] + python-version: [3.10.14] steps: - name: Cache Python Packages uses: actions/cache@v1 diff --git a/wenet/transformer/attention.py b/wenet/transformer/attention.py index c65d8de2e..020c89ef7 100644 --- a/wenet/transformer/attention.py +++ b/wenet/transformer/attention.py @@ -234,7 +234,9 @@ def forward( v = torch.cat([value_cache, v], dim=2) # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's # non-trivial to calculate `next_cache_start` here. - new_cache = torch.cat((k, v), dim=-1) if not self.training else cache + # new_cache = torch.cat((k, v), dim=-1) if not self.training else cache + new_cache = torch.cat( + (k, v), dim=-1) if not self.training else torch.zeros(0, 0, 0, 0) # for multi query or multi group attention if self.h_kv != self.h: @@ -379,7 +381,8 @@ def forward( # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's # non-trivial to calculate `next_cache_start` here. - new_cache = torch.cat((k, v), dim=-1) if not self.training else cache + new_cache = torch.cat( + (k, v), dim=-1) if not self.training else torch.zeros(0, 0, 0, 0) # for multi query or multi groups attention if self.h_kv != self.h: @@ -472,7 +475,8 @@ def forward( else: q, k, v = self.forward_qkv(query, key, value) - new_cache = torch.cat((k, v), dim=-1) if not self.training else cache + new_cache = torch.cat( + (k, v), dim=-1) if not self.training else torch.zeros(0, 0, 0, 0) # for multi query or multi groups attention if self.h_kv != self.h: @@ -569,7 +573,8 @@ def forward( dim=-1) k = torch.cat([key_cache, k], dim=2) v = torch.cat([value_cache, v], dim=2) - new_cache = torch.cat((k, v), dim=-1) if not self.training else cache + new_cache = torch.cat( + (k, v), dim=-1) if not self.training else torch.zeros(0, 0, 0, 0) rel_k = self.rel_k_embed( self._relative_indices(k.size(2), query.device)) # (t2, t2, d_k) @@ -670,7 +675,8 @@ def forward( dim=-1) k = torch.cat([key_cache, k], dim=2) v = torch.cat([value_cache, v], dim=2) - new_cache = torch.cat((k, v), dim=-1) if not self.training else cache + new_cache = torch.cat( + (k, v), dim=-1) if not self.training else torch.zeros(0, 0, 0, 0) if self.h_kv != self.h: k = torch.repeat_interleave(