From 561abb81d2286dcf9565e48a47cbf3c28cba8ab4 Mon Sep 17 00:00:00 2001
From: Guillaume B
Date: Thu, 7 Jan 2021 10:00:43 +0100
Subject: [PATCH 1/4] Vectorized `ngram_attention_bias` calculation

---
 .../models/prophetnet/modeling_prophetnet.py  | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py
index 11197182f85b..805dd61150fe 100644
--- a/src/transformers/models/prophetnet/modeling_prophetnet.py
+++ b/src/transformers/models/prophetnet/modeling_prophetnet.py
@@ -171,13 +171,16 @@ def ngram_attention_bias(sequence_length, ngram, device, dtype):
     """
     This function computes the bias for the predict stream
     """
-    bias = torch.ones((ngram, sequence_length, 2 * sequence_length), device=device, dtype=dtype) * float("-inf")
+    left_block = torch.ones((ngram, sequence_length, sequence_length), device=device, dtype=dtype) * float("-inf")
+    right_block = torch.ones((ngram, sequence_length, sequence_length), device=device, dtype=dtype) * float("-inf")
     # create bias
     for stream_idx in range(ngram):
-        for i in range(sequence_length):
-            bias[stream_idx, i, sequence_length + i] = 0
-            bias[stream_idx, i, : max(i - stream_idx, 0) + 1] = 0
-    return bias
+        right_block[stream_idx].fill_diagonal_(0, wrap=False)
+        left_block[stream_idx] = torch.triu(left_block[stream_idx], -stream_idx+1)
+        for i in range(stream_idx):
+            left_block[stream_idx][i][0] = 0
+
+    return torch.cat([left_block, right_block], dim=2)
 
 
 def compute_relative_buckets(num_buckets, max_distance, relative_positions, is_bidirectional=False):

From 287f9cbb4ba4816dd5db8e58506c363f928587c7 Mon Sep 17 00:00:00 2001
From: Guillaume B
Date: Thu, 7 Jan 2021 10:11:57 +0100
Subject: [PATCH 2/4] updated formatting with black

---
 src/transformers/models/prophetnet/modeling_prophetnet.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py
index 805dd61150fe..9789c440293f 100644
--- a/src/transformers/models/prophetnet/modeling_prophetnet.py
+++ b/src/transformers/models/prophetnet/modeling_prophetnet.py
@@ -176,7 +176,7 @@ def ngram_attention_bias(sequence_length, ngram, device, dtype):
     # create bias
     for stream_idx in range(ngram):
         right_block[stream_idx].fill_diagonal_(0, wrap=False)
-        left_block[stream_idx] = torch.triu(left_block[stream_idx], -stream_idx+1)
+        left_block[stream_idx] = torch.triu(left_block[stream_idx], -stream_idx + 1)
         for i in range(stream_idx):
             left_block[stream_idx][i][0] = 0
 

From 3becee4572e445c74c744acee229327b5f3d9838 Mon Sep 17 00:00:00 2001
From: Guillaume B
Date: Thu, 7 Jan 2021 10:22:46 +0100
Subject: [PATCH 3/4] Further optimization

---
 src/transformers/models/prophetnet/modeling_prophetnet.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py
index 9789c440293f..f51b82c70e6b 100644
--- a/src/transformers/models/prophetnet/modeling_prophetnet.py
+++ b/src/transformers/models/prophetnet/modeling_prophetnet.py
@@ -172,11 +172,11 @@ def ngram_attention_bias(sequence_length, ngram, device, dtype):
     This function computes the bias for the predict stream
     """
     left_block = torch.ones((ngram, sequence_length, sequence_length), device=device, dtype=dtype) * float("-inf")
-    right_block = torch.ones((ngram, sequence_length, sequence_length), device=device, dtype=dtype) * float("-inf")
+    right_block = left_block.detach().clone()
     # create bias
     for stream_idx in range(ngram):
         right_block[stream_idx].fill_diagonal_(0, wrap=False)
-        left_block[stream_idx] = torch.triu(left_block[stream_idx], -stream_idx + 1)
+        left_block[stream_idx].triu_(-stream_idx + 1)
         for i in range(stream_idx):
             left_block[stream_idx][i][0] = 0
 

From 15542f5c46f05ab1bd7c428be669067f5a1c3f8c Mon Sep 17 00:00:00 2001
From: Guillaume B
Date: Thu, 7 Jan 2021 10:28:05 +0100
Subject: [PATCH 4/4] one (last) optimization

---
 src/transformers/models/prophetnet/modeling_prophetnet.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py
index f51b82c70e6b..ae53e54ee8f9 100644
--- a/src/transformers/models/prophetnet/modeling_prophetnet.py
+++ b/src/transformers/models/prophetnet/modeling_prophetnet.py
@@ -177,9 +177,8 @@ def ngram_attention_bias(sequence_length, ngram, device, dtype):
     for stream_idx in range(ngram):
         right_block[stream_idx].fill_diagonal_(0, wrap=False)
         left_block[stream_idx].triu_(-stream_idx + 1)
-        for i in range(stream_idx):
-            left_block[stream_idx][i][0] = 0
 
+    left_block[:, :, 0] = 0
     return torch.cat([left_block, right_block], dim=2)
 
 