diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index 11197182f85b..ae53e54ee8f9 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -171,13 +171,15 @@ def ngram_attention_bias(sequence_length, ngram, device, dtype): """ This function computes the bias for the predict stream """ - bias = torch.ones((ngram, sequence_length, 2 * sequence_length), device=device, dtype=dtype) * float("-inf") + left_block = torch.ones((ngram, sequence_length, sequence_length), device=device, dtype=dtype) * float("-inf") + right_block = left_block.detach().clone() # create bias for stream_idx in range(ngram): - for i in range(sequence_length): - bias[stream_idx, i, sequence_length + i] = 0 - bias[stream_idx, i, : max(i - stream_idx, 0) + 1] = 0 - return bias + right_block[stream_idx].fill_diagonal_(0, wrap=False) + left_block[stream_idx].triu_(-stream_idx + 1) + + left_block[:, :, 0] = 0 + return torch.cat([left_block, right_block], dim=2) def compute_relative_buckets(num_buckets, max_distance, relative_positions, is_bidirectional=False):