Introduce mBart #29

Open · kauterry wants to merge 4 commits into main

Conversation

@kauterry (Contributor) commented Sep 9, 2023

What does this PR do? Please describe:
Implements the mBart model and its text tokenizer. We are able to successfully load the base model.

Testing the text tokenizer:

VocabularyInfo(size=65539, unk_idx=3, bos_idx=0, eos_idx=2, pad_idx=1)
0 <s>
1 <pad>
2 </s>
3 <unk>
4 .
5 ,
65530 pleasant
65531 ▁glycogen
65532 criminalization
65533 ▁varietal
65534 ▁duplicating
65535 ▁protester
65536 [en]
65537 [es]
65538 <mask>
sample_tokens=tensor([65536,     0,   655,  9692,  2049,    19,    22,   146,    31, 29678,
           13,  1845, 17277,  4120,     5,    56,    22,    15,  5277,     4,
            2], dtype=torch.int32)
decoded_str='Some theories suggest that it may have descendants in Manchuria, but it is unlikely.'
prefix_indices:  tensor([65536,     0])
suffix_indices:  tensor([2])
encoded_tokens=tensor([65536,     0,   655,  9692,  2049,    19,    22,   146,    31, 29678,
           13,  1845, 17277,  4120,     5,    56,    22,    15,  5277,     4,
            2])
round_trip_str='Some theories suggest that it may have descendants in Manchuria, but it is unlikely.'

We see that encoded_tokens is identical to sample_tokens and decoded_str is identical to round_trip_str, so the tokenizer round-trips correctly.
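
For reference, a minimal sketch of the round-trip check above, written against generic encoder/decoder callables rather than the exact tokenizer API added in this PR (the cast to int64 accounts for the int32 dtype of sample_tokens in the output):

import torch
from typing import Callable

def check_round_trip(
    encoder: Callable[[str], torch.Tensor],
    decoder: Callable[[torch.Tensor], str],
    text: str,
) -> None:
    """Encode, decode, re-encode, and verify both halves of the round trip."""
    sample_tokens = encoder(text)
    decoded_str = decoder(sample_tokens)

    encoded_tokens = encoder(decoded_str)
    round_trip_str = decoder(encoded_tokens)

    assert torch.equal(sample_tokens.long(), encoded_tokens.long()), "token mismatch"
    assert decoded_str == round_trip_str, "string mismatch"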

TODO: Check forward-pass parity against fairseq1 using the same checkpoint.
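
A hedged sketch of what such a parity check could look like; the model loading and forward signatures are placeholders, since the real fairseq and fairseq2 models take batches and padding masks rather than a bare tensor:

import torch

@torch.inference_mode()
def check_forward_parity(
    fairseq1_model: torch.nn.Module,
    fairseq2_model: torch.nn.Module,
    seqs: torch.Tensor,
) -> None:
    """Run the same token batch through both models and compare the outputs."""
    fairseq1_model.eval()
    fairseq2_model.eval()

    out1 = fairseq1_model(seqs)
    out2 = fairseq2_model(seqs)

    # Outputs are assumed to be logits of shape (batch, seq_len, vocab).
    torch.testing.assert_close(out1, out2, atol=1e-4, rtol=1e-4)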

Fixes #{issue number}

Does your PR introduce any breaking changes? If yes, please list them:
List of all backwards-incompatible changes.

Check list:

  • Was the content of this PR discussed and approved via a GitHub issue? (no need for typos or documentation improvements)
  • Did you read the contributor guideline?
  • Did you make sure that your PR does only one thing instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests?
  • Did you verify new and existing tests pass locally with your changes?
  • Did you update the CHANGELOG? (no need for typos, documentation, or minor internal changes)

@kauterry requested a review from cbalioglu as a code owner on September 9, 2023 22:53
@facebook-github-bot added the CLA Signed label on Sep 9, 2023
@cbalioglu (Contributor) left a comment

Overall looks good to me. Just left a few nit comments. Just wondering though; do you have any asset cards that we could bundle with this PR? How did you verify parity with the original fairseq implementation?

src/fairseq2/models/mbart/loader.py (outdated review thread, resolved)
src/fairseq2/models/mbart/tokenizer.py (outdated review thread, resolved)
num_encoder_attn_heads=16,
num_decoder_attn_heads=16,
ffn_inner_dim=4096,
pos_encoder_type="learned",
Contributor:

Looks like pos_encoder_type is always "learned" and norm_order is always POST according to this. If that is the case, I would suggest removing these configuration parameters.

Contributor Author:

I'm having to do this to successfully load the mBart checkpoint with UnitY: https://github.com/fairinternal/seamless_communication/pull/28/files#diff-189811785a49637a011c2db015430cfd708d92f832f8ef30ed7e10dc7f922635R103

The argument about norm_order makes sense; I'll remove that.

src/fairseq2/models/mbart/builder.py (outdated review thread, resolved)

def build_frontend(self, embed: Embedding) -> TransformerFrontend:
    """Build a Transformer encoder/decoder front-end."""
    if self.config.pos_encoder_type == "sinusoidal":
Contributor:

As mentioned above, I don't think that this is necessary. mBART always uses learned positional embeddings.
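
For context, learned positional embeddings are just an nn.Embedding indexed by position and added to the token embeddings; a generic PyTorch illustration, not the actual fairseq2 front-end classes:

import torch
from torch import nn

class LearnedPositionalFrontend(nn.Module):
    """Token embedding plus a learned positional embedding (illustrative only)."""

    def __init__(self, vocab_size: int, model_dim: int, max_seq_len: int, pad_idx: int) -> None:
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, model_dim, padding_idx=pad_idx)
        self.embed_positions = nn.Embedding(max_seq_len, model_dim)

    def forward(self, seqs: torch.Tensor) -> torch.Tensor:
        # seqs: (batch, seq_len) of token indices.
        positions = torch.arange(seqs.size(1), device=seqs.device)
        return self.embed_tokens(seqs) + self.embed_positions(positions)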

@kauterry (Contributor Author):

@cbalioglu I'm yet to verify parity with the fairseq mBart model by running forward passes. The asset has an internal checkpoint, wondering what the best way to open-source that would be.

@cbalioglu (Contributor):

You can use one of mBART's public checkpoints (e.g. mbart.CC25) to verify parity and include it as an asset card in your PR.
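
For the parity check and the checkpoint conversion in the loader, a small sketch for inspecting the public checkpoint's parameter names (download location and archive layout are assumptions; see the mBART example in the fairseq repository):

import torch

# Path inside the extracted mbart.CC25 archive (layout assumed).
ckpt = torch.load("mbart.cc25/model.pt", map_location="cpu")

state_dict = ckpt["model"]

# Print a few parameter names to guide the fairseq -> fairseq2 key mapping.
for name, param in list(state_dict.items())[:10]:
    print(name, tuple(param.shape))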
