
Conversation

@abheesht17 (Collaborator) commented Jan 16, 2023

This bug cropped up while I was implementing BartBackbone: #661. Instead of passing value_dim=hidden_dim to the cross-attention layer, we should pass value_dim=head_dim.

Let's look at the TransformerDecoderBlock layer in the tensorflow/models repo. value_dim is not passed to the keras.layers.MultiHeadAttention layer, so it defaults to key_dim, i.e., value_dim = key_dim = head_dim.

Intuitively, if we pass value_dim = hidden_dim = 768 with num_heads = 12, the value projection weight will have shape (768, 12, 768). This is incorrect; the shape should be (768, 12, 64), since head_dim = hidden_dim / num_heads = 64.
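For reference (not part of the original comment), a minimal sketch of the shape difference using stock keras.layers.MultiHeadAttention; the weight-name filter used to pick out the value projection kernel is an assumption and may vary across Keras versions:

```python
import numpy as np
from tensorflow import keras

hidden_dim, num_heads, head_dim = 768, 12, 64
x = np.zeros((1, 8, hidden_dim), dtype="float32")

# value_dim left unset: it defaults to key_dim (= head_dim).
mha_default = keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=head_dim)
mha_default(x, x)  # first call builds the projection weights

# value_dim explicitly set to hidden_dim: the value projection is inflated.
mha_inflated = keras.layers.MultiHeadAttention(
    num_heads=num_heads, key_dim=head_dim, value_dim=hidden_dim
)
mha_inflated(x, x)

def value_kernel_shape(layer):
    # Assumes the value projection weights contain "value" and "kernel" in
    # their names, which holds for current TF Keras but is not guaranteed.
    return [w.shape for w in layer.weights
            if "value" in w.name and "kernel" in w.name][0]

print(value_kernel_shape(mha_default))   # (768, 12, 64)
print(value_kernel_shape(mha_inflated))  # (768, 12, 768)
```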

@abheesht17 requested review from jbischof and mattdangerw and removed the request for mattdangerw, January 16, 2023 10:54
@abheesht17 (Collaborator, Author) commented:

Oops, accidentally removed review request for @mattdangerw. Adding it back.

@jbischof (Contributor) left a comment:

Thanks for catching this!

@mattdangerw (Member) left a comment:

Fix looks good! I'm unclear why we need the testing change, though; is it actually changing the test in any way?

```diff
-    intermediate_dim=4, num_heads=2
+    intermediate_dim=4,
+    num_heads=2,
+    has_cross_attention=True,
```
@mattdangerw (Member) commented:

Hmm, why do we actually need this? Won't the line at the start of call, has_encoder_sequence = encoder_sequence is not None, mean the layer will be built with cross-attention as soon as the decoder is called on two inputs?
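For context, a minimal sketch of the behavior described above (assuming keras_nlp.layers.TransformerDecoder from this repo): calling the decoder with both a decoder and an encoder sequence builds the cross-attention sub-layer on that first call, without any extra constructor flag.

```python
import numpy as np
import keras_nlp

decoder = keras_nlp.layers.TransformerDecoder(intermediate_dim=4, num_heads=2)

decoder_sequence = np.random.uniform(size=(2, 10, 16)).astype("float32")
encoder_sequence = np.random.uniform(size=(2, 12, 16)).astype("float32")

# Passing two inputs means has_encoder_sequence is True inside call(),
# so the cross-attention layer is built here on the first call.
outputs = decoder(decoder_sequence, encoder_sequence)
print(outputs.shape)  # (2, 10, 16)
```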

@abheesht17 (Collaborator, Author) commented:

Sorry 🤦🏼, not needed. Changing it back.

@mattdangerw (Member) left a comment:

Oops, I actually marked this as "changes requested" until we figure out the testing bit.

@mattdangerw merged commit 8ea419b into keras-team:master Jan 18, 2023