Skip to content

Commit

Permalink
Merge pull request #2336 from Zth9730/fix_multigpu_train
Browse files Browse the repository at this point in the history
[s2t] fix format test=asr
  • Loading branch information
JiaXiao243 authored Aug 31, 2022
2 parents 58ab7e8 + cdcb1a5 commit 0b544ee
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 4 deletions.
1 change: 1 addition & 0 deletions paddlespeech/s2t/modules/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ def forward_chunk(
xs,
att_mask,
pos_emb,
mask_pad=paddle.ones([0, 0, 0], dtype=paddle.bool),
att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache,
cnn_cache=cnn_cache[i:i + 1]
if paddle.shape(cnn_cache)[0] > 0 else cnn_cache, )
Expand Down
3 changes: 1 addition & 2 deletions paddlespeech/s2t/modules/encoder_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,7 @@ def forward(
x: paddle.Tensor,
mask: paddle.Tensor,
pos_emb: paddle.Tensor,
mask_pad: paddle.
Tensor, # paddle.ones([0, 0, 0], dtype=paddle.bool)
mask_pad: paddle.Tensor, #paddle.ones([0, 0, 0],dtype=paddle.bool)
att_cache: paddle.Tensor, # paddle.zeros([0, 0, 0, 0])
cnn_cache: paddle.Tensor, # paddle.zeros([0, 0, 0, 0])
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
Expand Down
8 changes: 6 additions & 2 deletions paddlespeech/server/engine/asr/online/python/asr_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,8 +476,12 @@ def advance_decoding(self, is_finished=False):
# forward chunk
(y, self.att_cache,
self.cnn_cache) = self.model.encoder.forward_chunk(
chunk_xs, self.offset, required_cache_size, self.att_cache,
self.cnn_cache, paddle.ones([0, 0, 0], dtype=paddle.bool))
chunk_xs,
self.offset,
required_cache_size,
att_cache=self.att_cache,
cnn_cache=self.cnn_cache,
att_mask=paddle.ones([0, 0, 0], dtype=paddle.bool))
outputs.append(y)

# update the global offset, in decoding frame unit
Expand Down

0 comments on commit 0b544ee

Please sign in to comment.