Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion keras_nlp/layers/transformer_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def _build(self, input_shape, has_cross_attention):
self._cross_attention_layer = keras.layers.MultiHeadAttention(
num_heads=self.num_heads,
key_dim=head_dim,
value_dim=hidden_dim,
value_dim=head_dim,
dropout=self.dropout,
kernel_initializer=clone_initializer(self.kernel_initializer),
bias_initializer=clone_initializer(self.bias_initializer),
Expand Down
4 changes: 3 additions & 1 deletion keras_nlp/layers/transformer_decoder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,9 @@ class MyModel(keras.Model):
def __init__(self):
super().__init__()
self._decoder = transformer_decoder.TransformerDecoder(
intermediate_dim=4, num_heads=2
intermediate_dim=4,
num_heads=2,
has_cross_attention=True,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, why do we need this actually? Won't this line at the start of call has_encoder_sequence = encoder_sequence is not None, mean the layer will be built with cross attention as soon as the decoder is called on two inputs?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry 🤦🏼 , not needed. Changing it back

)
self._dense = keras.layers.Dense(1, activation="sigmoid")

Expand Down