Add Mega: Moving Average Equipped Gated Attention #21766
Conversation
Alright @ArthurZucker, I think that's everything except the threads with ongoing discussion. I'm super happy with how this is shaping up! The latest batch of commits addresses the remaining feedback.

Thanks again, and I'll hold off on any more changes until you get a chance to review the updates and resolve the open discussions. Excited to get up and running with MEGA!
ArthurZucker
left a comment
Wow! A lot of work, and I think we are almost there! I left a few nits here and there again, but it should be ready soon. 🚀
```python
return torch.clip(gelu(x), self.min, self.max)


class AccurateGELUActivation(nn.Module):
```
Nice 😉
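For context, the class referenced in this diff presumably implements the tanh approximation of GELU; a minimal sketch of that formulation (not necessarily the exact code in the PR) would be:

```python
import math

import torch
from torch import nn


class AccurateGELUActivation(nn.Module):
    """Tanh approximation of GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))."""

    def __init__(self):
        super().__init__()
        self.precomputed_constant = math.sqrt(2 / math.pi)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return 0.5 * input * (1 + torch.tanh(self.precomputed_constant * (input + 0.044715 * torch.pow(input, 3))))
```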
```
ema_delta_alpha_range (`float`, *optional*, defaults to 0.2):
    The standard deviation for initializing the delta (damping factor) and alpha (decay factor) parameters in
    MultiDimensionDampedEMA.
ema_beta_range (`float`, *optional*, defaults to 0.02):
    The standard deviation for initializing the beta parameter (expansion matrix) in MultiDimensionDampedEMA.
ema_gamma_omega_range (`float`, *optional*, defaults to 1.0):
    The standard deviation for initializing the gamma (projection matrix) and omega (residual weight)
    parameters in MultiDimensionEMA.
```
nice 😉
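As an illustration (the values shown are just the documented defaults, and this assumes the config class is exported as `MegaConfig`), these ranges are plain float arguments on the config:

```python
from transformers import MegaConfig

# EMA-related initialization ranges from the docstring above, set explicitly
config = MegaConfig(
    ema_delta_alpha_range=0.2,
    ema_beta_range=0.02,
    ema_gamma_omega_range=1.0,
)
```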
```python
# utility for causal LM masking in the format that Mega expects
def generate_causal_mask(seq_len):
```
This looks a lot like `create_extended_attention_mask_for_decoder`, which you can use in all the `PreTrainedModel`s! (The biggest difference seems to be that this one is not batched.)
Oh, great point! I didn't know about that method - super helpful 😄
Mega's attention methods expect a non-batched causal mask, so I can just index the one produced by the built-in method!
This is addressed in the upcoming commit - I also realized that my old method was not handling the device. Since I haven't been able to test locally with a GPU, that had not triggered a test failure.
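For reference, the non-batched causal mask Mega expects is essentially a lower-triangular boolean matrix, which is what indexing a single example out of the batched mask from `create_extended_attention_mask_for_decoder` yields. A minimal sketch (the helper name is just for illustration):

```python
import torch


def simple_causal_mask(seq_len: int, device=None) -> torch.Tensor:
    # (seq_len, seq_len) lower-triangular mask: position i may attend to positions <= i.
    # Passing `device` avoids the CPU/GPU mismatch mentioned above.
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
```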
```python
@require_torch
class MegaModelIntegrationTest(TestCasePlus):
```
Thanks for adding that, looks good to me
@ArthurZucker as an update, it looks like the fix for left-padding is going to be a more significant effort to implement. The relative bias is applied in the attention function, and it expects all of the inputs to be left-to-right starting at position 0. We can probably refactor to accept position IDs like they did for CodeGen, but we'll also need to change how the bias is added, since it is currently computed once for the full, unpadded sequence.

I'll dig more into this tomorrow, but in the meantime, I've pushed updates that address the rest of your comments! If you have any other suggestions on the fix for relative positions, I'd love to hear them! 😄
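As a rough sketch of the CodeGen-style approach (not the PR's implementation; the exact handling is an assumption), position IDs could be derived from the attention mask so that left-padding does not shift the relative positions of the real tokens:

```python
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1]])  # one left-padded sequence
position_ids = attention_mask.long().cumsum(-1) - 1  # running count of real tokens
position_ids.masked_fill_(attention_mask == 0, 0)  # park padding positions at 0
print(position_ids)  # tensor([[0, 0, 0, 1, 2]])
```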
Sure! Also, it's not that important to have left padding in this PR; it can be added in another PR!
Thanks @ArthurZucker! After digging into it, I do think it will require a pretty significant refactor to support left-padding in this PR. If you're comfortable with it, I agree that it could make sense in a new PR. I just added an entry calling out this limitation, and pulled in the latest changes from the base branch.
Awesome, it's alright with me to leave this to another PR. I'll do my final review before pinging @sgugger for another pair of eyes!
ArthurZucker
left a comment
Well done! 🔥
I left a few comments here and there about naming conventions and docstrings, but this is very detailed, and I love that you took the time to address all of my comments! Thanks for bearing with me 😉
```python
return output


class MovingAverageGatedAttention(nn.Module):
```
```diff
- class MovingAverageGatedAttention(nn.Module):
+ class MegaMovingAverageGatedAttention(nn.Module):
```
Mega should prefix all the Mega classes (norms and MultiDimensionDampedEMA included). We are probably also going to rename MultiDimensionDampedEMA to MultiDimensionDampedEma! This is really a nit, but users have an easier time guessing the names if we stick to camel case everywhere.
Working on this now and will add in a local commit to make sure it's updated wherever it's referenced
```python
self.norm = MegaSequenceNorm(
    self.config.normalization_type, self.config.hidden_size, affine=self.config.norm_affine
)
self.move = MultiDimensionDampedEMA(config)
```
```diff
- self.move = MultiDimensionDampedEMA(config)
+ self.ema_gate = MultiDimensionDampedEMA(config)
```
Also going in a local commit as this will need to change in the weight conversion script
```python
config.chunk_size = (
    input_ids.size(1) * 2
)  # we want the chunk size to be < sequence length, and the sequence length to be a multiple of chunk size
```
```diff
- config.chunk_size = (
-     input_ids.size(1) * 2
- )  # we want the chunk size to be < sequence length, and the sequence length to be a multiple of chunk size
+ # we want the chunk size to be < sequence length, and the sequence length to be a multiple of chunk size
+ config.chunk_size = (input_ids.size(1) * 2)
```
Applying this in the local commit because I'm not sure whether the parentheses are forcing the automated style checks to expand across lines
```python
self.parent.assertEqual(result[0].shape, (self.batch_size, self.seq_length, self.hidden_size))


def check_chunking_shorter_sequence(
```
Nice!
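A rough sketch of what a check like this can look like (the `use_chunking` flag and the model tester plumbing are assumptions based on the snippets above, not the exact test code):

```python
def check_chunking_shorter_sequence(self, config, input_ids, attention_mask):
    # When chunking is enabled with a chunk size larger than the sequence length,
    # the model should still return hidden states of the expected shape.
    config.use_chunking = True
    config.chunk_size = input_ids.size(1) * 2
    model = MegaModel(config)
    result = model(input_ids, attention_mask=attention_mask)
    self.parent.assertEqual(result[0].shape, (self.batch_size, self.seq_length, self.hidden_size))
```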
```python
query_key, attention_gate = torch.split(
    F.silu(query_key_gates), [self.config.shared_representation_size, self.config.intermediate_size], dim=-1
)
```
Let's also split this into two lines (activation first, then the split).
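Something like the following, sketched from the snippet above (the intermediate variable name is just illustrative):

```python
# apply the activation first, then split into query/key and attention gate
activated_qk_gates = F.silu(query_key_gates)
query_key, attention_gate = torch.split(
    activated_qk_gates, [self.config.shared_representation_size, self.config.intermediate_size], dim=-1
)
```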
Done in upcoming commit
sgugger
left a comment
Thanks a lot for all your work adding this model! My main comment is that all building blocks in the modeling file should be prefixed by Mega to avoid any name conflicts with other models.
```python
return embeddings


class SimpleRelativePositionalBias(nn.Module):
```
Should be prefixed by Mega
Done in upcoming commit
```python
return tile


class RotaryRelativePositionalBias(nn.Module):
```
Should be prefixed by Mega
Done in upcoming commit
```python
NORM2FN = {
    "layernorm": lambda embedding_dim, eps, affine: nn.LayerNorm(embedding_dim, eps, elementwise_affine=affine),
    "scalenorm": lambda embedding_dim, eps, affine: ScaleNorm(dim=-1, eps=eps, affine=affine),
    "rmsnorm": lambda embedding_dim, eps, affine: MegaRMSNorm(embedding_dim, eps=eps, affine=affine),
    "batchnorm": lambda embedding_dim, eps, affine: nn.BatchNorm1d(embedding_dim, eps=eps, affine=affine),
    "syncbatchnorm": lambda embedding_dim, eps, affine: nn.SyncBatchNorm(embedding_dim, eps=eps, affine=affine),
}
```
Using lambda functions here will make `MegaSequenceNorm`, and then the whole model, unpicklable, I fear. It's probably better to have five if/else branches in the `MegaSequenceNorm` module.
My bad! I was the one pushing for this! I hadn't accounted for the pickling issue.
Ah, that makes sense! I originally had it as an if/else in `MegaSequenceNorm.__init__`, so I'll just go back to that design.
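A minimal sketch of that if/else design, based on the NORM2FN mapping above (the norm classes referenced here are the ones defined elsewhere in the modeling file, and the exact signature may differ in the final PR):

```python
from torch import nn


class MegaSequenceNorm(nn.Module):
    def __init__(self, norm_type, embedding_dim, eps=1e-5, affine=True):
        super().__init__()
        if norm_type == "layernorm":
            self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine=affine)
        elif norm_type == "scalenorm":
            self.norm = ScaleNorm(dim=-1, eps=eps, affine=affine)  # defined in the same modeling file
        elif norm_type == "rmsnorm":
            self.norm = MegaRMSNorm(embedding_dim, eps=eps, affine=affine)  # defined in the same modeling file
        elif norm_type == "batchnorm":
            self.norm = nn.BatchNorm1d(embedding_dim, eps=eps, affine=affine)
        elif norm_type == "syncbatchnorm":
            self.norm = nn.SyncBatchNorm(embedding_dim, eps=eps, affine=affine)
        else:
            raise ValueError(f"Unknown normalization type: {norm_type}")

    def forward(self, input):
        return self.norm(input)
```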
```python
ALL_LAYERNORM_LAYERS.append(MegaSequenceNorm)


class MultiDimensionDampedEMA(nn.Module):
```
Should be prefixed by Mega
Done in upcoming commit
```python
# Normalization modules
# copied from original Mega repo without modification except variable names
class ScaleNorm(nn.Module):
```
Also needs to be prefixed by Mega
```python
from transformers.models.mega.modeling_mega import (
    MEGA_PRETRAINED_MODEL_ARCHIVE_LIST,
)
```
Fits in one line.
Done in upcoming commit
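i.e., presumably just:

```python
from transformers.models.mega.modeling_mega import MEGA_PRETRAINED_MODEL_ARCHIVE_LIST
```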
Thanks again @ArthurZucker and @sgugger! Appreciate the feedback, and it should all be addressed in the latest changes 🤗
sgugger
left a comment
Congrats on adding this new model to Transformers!
Great working with you @mnaylor5! Congrats again on the merge 🔥

Congrats @mnaylor5! Feel free to share on social media and we'll amplify your post.

Thanks so much @ArthurZucker and @NielsRogge! I learned a ton through this process, and it's so rewarding to see my code in a library I use so much ❤️ I posted something on LinkedIn a couple of days ago, and I'll tag you guys in the comments as well!
* add mega file structure and plain pytorch version of mega source code
* added config class with old naming conventions
* filled in mega documentation
* added config class and embeddings with optional token types
* updated notes
* starting the conversion process, deleted intermediate and added use_cache back to config
* renamed config attributes in modeling_mega.py
* checkpointing before refactoring incremental decoding functions
* removed stateful incremental key/values for EMA and self-attention
* refactored MovingAverageGatedAttention to remove stateful k/v history and use unified attention mask
* MovingAverageGatedAttention works with incremental decoding + past values, added sequence length enforcement
* more comments in MovingAverageGatedAttention + checkpointing before GatedCrossAttention
* bug fix in attention mask handling in MovingAverageGatedAttention
* removed incremental state from GatedCrossAttention and removed IncrementalState class
* finished gated cross attention and got MegaLayer working
* fixed causal masking in mega decoder
* fixed how padding and causal masks are passed through MegaLayer with and without k/v caching
* finished MegaModel; tested with encoder, decoder-only, and cross-attention type inputs; started work on downstream classes; removed mentions of position_ids
* added optional dense hidden layer for masked and causal LM classes
* docstring updates in MultiHeadEMA and GatedCrossAttention, removed unnecessary inputs in cross-attention
* removed before_attn_fn in Mega class and updated docstrings and comments up to there
* bug fix in MovingAverageGatedAttention masking
* working conversion of MLM checkpoint in scratchpad script -- perfect matches
* moved arg for hidden dense layer in LM head to config; discovered issue where from_pretrained is renaming gamma and beta parameters
* renamed gamma and beta parameters to avoid HF renaming when loading from checkpoint
* finished checkpoint conversion script
* cleanup old class in mega config script
* removed 'copied from' statements and passing integration tests
* added num_attention_heads=1 to config for integration compatibility, decoder tests working, generation tests failing
* fixed tuple output of megamodel
* all common tests passing after fixing issues in decoder, gradient retention, and initialization
* added mega-specific tests, ready for more documentation and style checks
* updated docstrings; checkpoint before style fixes
* style and quality checks, fixed initialization problem in float_tensor, ready for PR
* added mega to toctree
* removed unnecessary arg in megaconfig
* removed unused arg and fixed code samples with leftover roberta models
* Apply suggestions from code review: applied all suggestions except the one renaming a class, as I'll need to update that throughout (Co-authored-by: Arthur <[email protected]>)
* fixed issue where .view breaks batch dimension, conversion script fixed with absolute imports, updated readme with Mega->MEGA
* removed asserts in Mega code, renamed sequencenorm, gatedcrossattention, and NFFN, replaced get_activation_fn with ACTFN, and added sequencenorm to layer norms
* reformatted .forward() docstrings to match style and removed unused mask input in cross-attention
* removed all reset_parameters() methods and rolled into MegaPreTrainedModel._init_weights()
* renamed all single-letter variables and improved readability in tensor size comments, Mega->MEGA in 2 documentation files
* variable names in NFFN
* manual Mega->MEGA changes in docs
* Mega->MEGA in config auto
* style and quality fixes
* Apply suggestions from code review (Co-authored-by: Arthur <[email protected]>)
* renamed parameters and variables with confusing names, added copied from statements, moved fft conv to its own method, other cleanup from PR comments
* commit before dealing with merge conflicts
* made new attention activation functions available in ACT2FN and added generation test from OPT
* style and quality in activations and tests
* documentation fixes, renaming variables in dropout and rotary positions, used built-in causal masking, encoders->layers in MegaModel, moved comments into docstrings
* style and quality fixes after latest updates, before rotary position ids
* causal mask in MegaBlock docstring + added missing device passing
* Apply suggestions from code review (Co-authored-by: Arthur <[email protected]>)
* Update README.md (Co-authored-by: Sylvain Gugger <[email protected]>)
* added Mega prefixes where missing, reverted MegaSequenceNorm to if-else, other module renaming requested in PR
* style and quality fixes + readme updates pointing to main

---------

Co-authored-by: Arthur <[email protected]>
Co-authored-by: Sylvain Gugger <[email protected]>
What does this PR do?
Fixes #19982
This pull request adds Mega: Moving Average Equipped Gated Attention, which is the current leader of the LRA benchmark. The implementation is adapted from the original fairseq-based repo, and I used an MLM checkpoint that I created with the original implementation on the wikitext-103 dataset. There is no proposed Mega tokenizer, so I used the RoBERTa tokenizer, which is also what I used for the wikitext checkpoint. The proposed implementation works in both encoder and decoder settings, and all relevant tests are passing.
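For anyone wanting to try it out, usage should look like any other Transformers encoder. A hedged sketch (the checkpoint identifier below is illustrative and assumes the wikitext-103 MLM checkpoint described above is published on the Hub):

```python
from transformers import AutoTokenizer, MegaForMaskedLM

# hypothetical checkpoint identifier for the wikitext-103 MLM checkpoint described above
checkpoint = "mnaylor/mega-base-wikitext"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)  # RoBERTa tokenizer, per the description
model = MegaForMaskedLM.from_pretrained(checkpoint)

inputs = tokenizer("The capital of France is <mask>.", return_tensors="pt")
outputs = model(**inputs)
print(outputs.logits.shape)  # (batch_size, sequence_length, vocab_size)
```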
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
@ArthurZucker and @younesbelkada for text models; tagging @NielsRogge for visibility as he responded to the original issue.