Skip to content

Commit

Permalink
[bestrq] fix time stack position (#1869)
Browse files Browse the repository at this point in the history
* [bestrq] fix time stack position

fix time stack position

* fix lint

* fix mask tail of batch
  • Loading branch information
Mddct authored May 30, 2023
1 parent 2505d18 commit abcc9c5
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion wenet/ssl/bestrq/bestqr_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def forward(
subsampling_masks = masked_masks.unfold(1,
size=self.stack_frames,
step=self.stride)
code_ids_mask, _ = torch.max(subsampling_masks, 2)
code_ids_mask, _ = torch.min(subsampling_masks, 2)

# 2.0 stack fbank
unmasked_xs = self._stack_features(input)
Expand Down Expand Up @@ -251,6 +251,7 @@ def _apply_mask_signal(
def _stack_features(self, input: torch.Tensor) -> torch.Tensor:

stack_input = input.unfold(1, size=self.stack_frames, step=self.stride)
stack_input = stack_input.transpose(-1, -2)
b, n, f, d = stack_input.size()
stack_input = stack_input.reshape(b, n, f * d)

Expand Down

0 comments on commit abcc9c5

Please sign in to comment.