
Commit 150fae2

anirudhr20ushareng authored and committed
Added Support for Returning Attention Scores in TransformerEncoder call (keras-team#1879)
* Added: `return_attention_scores` argument to the transformer encoder
* Added: docstring for `return_attention_scores` and a test to check that the argument works
* Fixed: test case by removing print statements and using `self.assertAllEqual`
* Fixed: linting
1 parent ed035d3 commit 150fae2
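
For context, a minimal usage sketch of the new argument. It assumes the layer is exposed publicly as `keras_hub.layers.TransformerEncoder` and uses Keras 3's `keras.random`; the layer sizes and input shape simply mirror the added test case below.

import keras
from keras_hub.layers import TransformerEncoder

# Toy encoder and input matching the new test (hypothetical sizes).
encoder = TransformerEncoder(intermediate_dim=4, num_heads=2)
inputs = keras.random.uniform(shape=(1, 4, 6))  # (batch_size, seq_length, feature_dim)

# Default call: a single tensor with the same shape as `inputs`.
outputs = encoder(inputs)

# With the new flag: a tuple of (attention_output, attention_scores).
outputs, attention_scores = encoder(inputs, return_attention_scores=True)
print(attention_scores.shape)  # (1, 2, 4, 4) = (batch_size, num_heads, seq_length, seq_length)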

File tree: 2 files changed (+39, -7 lines)


keras_hub/src/layers/modeling/transformer_encoder.py

Lines changed: 28 additions & 7 deletions
@@ -170,7 +170,12 @@ def build(self, inputs_shape):
         self.built = True
 
     def call(
-        self, inputs, padding_mask=None, attention_mask=None, training=None
+        self,
+        inputs,
+        padding_mask=None,
+        attention_mask=None,
+        training=None,
+        return_attention_scores=False,
     ):
         """Forward pass of the TransformerEncoder.
 
@@ -185,6 +190,7 @@ def call(
                 [batch_size, sequence_length, sequence_length].
             training: a boolean indicating whether the layer should behave in
                 training mode or in inference mode.
+            return_attention_scores: a boolean indicating whether the output should be `(attention_output, attention_scores)` if `True` or `attention_output` if `False`. Defaults to `False`.
 
         Returns:
             A Tensor of the same shape as the `inputs`.
@@ -200,12 +206,24 @@ def call(
         residual = x
         if self.normalize_first:
             x = self._self_attention_layer_norm(x)
-        x = self._self_attention_layer(
-            query=x,
-            value=x,
-            attention_mask=self_attention_mask,
-            training=training,
-        )
+
+        if return_attention_scores:
+            x, attention_scores = self._self_attention_layer(
+                query=x,
+                value=x,
+                attention_mask=self_attention_mask,
+                return_attention_scores=return_attention_scores,
+                training=training,
+            )
+            return x, attention_scores
+        else:
+            x = self._self_attention_layer(
+                query=x,
+                value=x,
+                attention_mask=self_attention_mask,
+                training=training,
+            )
+
         x = self._self_attention_dropout(x, training=training)
         x = x + residual
         if not self.normalize_first:
@@ -222,6 +240,9 @@ def call(
         if not self.normalize_first:
             x = self._feedforward_layer_norm(x)
 
+        if return_attention_scores:
+            return x, attention_scores
+
         return x
 
     def get_config(self):

keras_hub/src/layers/modeling/transformer_encoder_test.py

Lines changed: 11 additions & 0 deletions
@@ -95,3 +95,14 @@ def test_mask_propagation(self):
         inputs._keras_mask = mask
         outputs = encoder(inputs)
         self.assertAllEqual(outputs._keras_mask, mask)
+
+    def test_attention_scores(self):
+        encoder = TransformerEncoder(intermediate_dim=4, num_heads=2)
+        inputs = random.uniform(shape=[1, 4, 6])
+        outputs, attention_scores = encoder(
+            inputs, return_attention_scores=True
+        )
+        self.assertAllEqual(outputs.shape, inputs.shape)
+
+        # attention scores shape (batch_size, num_of_attn_heads, seq_length, seq_length)
+        self.assertAllEqual(attention_scores.shape, [1, 2, 4, 4])
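
The asserted shape follows from the layer's self-attention block, whose call signature in the diff above matches `keras.layers.MultiHeadAttention`: with `return_attention_scores=True`, that layer returns scores shaped `(batch_size, num_heads, query_length, key_length)`. A standalone sketch with the same hypothetical dimensions:

import keras

# Self-attention over a (1, 4, 6) input with 2 heads, mirroring the test above.
mha = keras.layers.MultiHeadAttention(num_heads=2, key_dim=3)
x = keras.random.uniform(shape=(1, 4, 6))
out, scores = mha(query=x, value=x, return_attention_scores=True)
print(out.shape)     # (1, 4, 6)
print(scores.shape)  # (1, 2, 4, 4)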
