
Support MuParametrization and MuTransfer #64

Draft · wants to merge 7 commits into main

Conversation

@NZ99 (Contributor) commented Dec 22, 2023

This PR adds the initial modeling-related changes to support MuP. It is incomplete, as the interaction between MuP and some other modeling-related techniques is not yet clear to me (e.g. do we need any modifications to make MuP compatible with ALiBi?), but I'm sharing it now to facilitate discussion and collaboration, as agreed during community project meetings.

I'm uploading my code as-is after leaving for NeurIPS -- I'll check over the holidays whether anything important is missing or needs fixing in the code so far. Apologies in case there are indeed issues with the code as pushed.

Important TODOs:

  • check overall MuP implementation for correctness
  • implement coordinate checking
  • integrate Optuna (or an alternative framework, if preferred) to run the actual hparam optimization experiments

@NZ99 added the enhancement (New feature or request) and help wanted (Extra attention is needed) labels Dec 22, 2023
@NZ99 self-assigned this Dec 22, 2023
@othertea (Collaborator) left a comment:


Hey @NZ99 thanks a ton for this MuParametrization PR!! I'm very excited about getting this working 🙂

I have a few high-level comments:

  • Coordinate checking sounds like a great idea! I'm fully in support of doing this before merging a final PR.
  • On the other hand, I wonder if we should save the Optuna integration for a subsequent PR. This first MuP PR already has a lot of changes, and I would be happy if we just got a version that is verified by coordinate checking. That said, an advantage of integrating Optuna in this PR is that we could run a hyperparameter search with it and verify that hyperparameters do transfer successfully. So I see the pros and cons of each, and of course, it's your call!
  • Regarding the organization of the mup code: I noticed that both the repo you refer to in this PR and the muTransformers repo have an initialization function that is called on the modules, whereas in this PR you chose to initialize the modules within each one's __init__ function. I'm curious what the pros and cons of this choice are. Again, I leave it up to you whether or not you want a more global init function. (Though personally I think the LayerNorm init happens often enough that you may want to factor it out either way.)
  • Are we planning on using the MuP repo? It could be useful if we want to use its MuReadout module (a potential location for this is in an in-line comment below). And I believe we can't use the default PyTorch optimizers, so we either need to use the MuP optimizers or implement our own with the correct scaling. (There's a rough sketch of how I picture the pieces fitting together right after this list.)
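For reference, here's an untested sketch of how I imagine the mup package slotting in, based on its README. The model/config names and widths are purely illustrative, not what's in the PR:

from mup import set_base_shapes, MuAdam

# the model we actually train, plus "base" and "delta" models that differ only in width
model = APTLMHeadModel(config=APTConfig(n_embd=1024, use_mup=True))
base = APTLMHeadModel(config=APTConfig(n_embd=256, use_mup=True))
delta = APTLMHeadModel(config=APTConfig(n_embd=512, use_mup=True))

# record which dimensions are "width" so the readout and the optimizer can rescale them
set_base_shapes(model, base, delta=delta)

# re-apply the init after set_base_shapes so the muP variances take effect
model.apply(model._init_weights)

# muP-aware optimizer with per-parameter lr scaling; a plain torch.optim.AdamW
# would not give hyperparameter transfer
optimizer = MuAdam(model.parameters(), lr=1e-3)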

And of course, @pascalnotin please feel free to chime in on any thoughts regarding the above comments or anything else!

I've also left some (not-100%-comprehensive) in-line comments. Let me know what you think!

mup_init_scale = 1.0,
mup_output_temp = 1.0,
mup_attn_mult = 1.0,
mup_embedding_mult = 1.0,
@othertea (Collaborator) commented Dec 29, 2023

Just checking, some of these config args seem unused in the modeling code as far as I can tell. I assume this is because the PR is still a draft, and they will actually be used in the final PR?

I also believe one is missing: wte_zero_init, which is used here:

@NZ99 (Contributor, Author) replied:

I think this is a leftover from initially basing the implementation on Marco's GPT-NeoX one and later switching to base it on the Mu-scaling one. I will recheck them all and update accordingly.

Comment on lines 57 to 74
if self.is_cross_attention:
    self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
    self.q_attn = Conv1D(self.embed_dim, self.embed_dim)

    #muP -- q_attn
    if self.use_mup:
        self.q_attn.weight.data.normal_(mean=0.0, std=math.sqrt((config.initializer_range ** 2) / config.mup_width_scale))
        self.q_attn.bias.zero_()

else:
    self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
self.c_proj = Conv1D(self.embed_dim, self.embed_dim)

self.attn_dropout = nn.Dropout(config.attn_pdrop)
self.resid_dropout = nn.Dropout(config.resid_pdrop)
#muP -- c_attn specific, see https://github.com/cofe-ai/Mu-scaling/blob/a3d13c3eaa02e28cc6ef95116358b128424f9fe4/modeling/modeling_gpt2_mup.py#L487
if self.use_mup:
    if config.query_zero_init:
        _, fanout = self.c_attn.weight.shape
        self.c_attn.weight.data[:, :fanout//3] = 0
    self.c_attn.bias.zero_()
@othertea (Collaborator) commented Dec 29, 2023

I think lines 70-74 should be under the else (line 66) block?

More generally, my understanding is that (self.c_attn, self.q_attn) in the cross-attention case (self.is_cross_attention==True) plays a similar role to self.c_attn in the no-cross-attention case (self.is_cross_attention==False). In particular, self.q_attn's parameters in the cross-attention case have the same function as the first third of the parameters of self.c_attn (the ones that are set in line 73) in the no-cross-attention case. Therefore, the way you initialize the weights of self.q_attn should match the way you initialize the first third of the weights of self.c_attn in the no-cross-attention case.

In addition to moving lines 70-74 into the else block mentioned above, this includes

  • adding zero-initialization of self.q_attn to the cross-attention case when config.query_zero_init==True, and
  • adding muP initialization (lines 62-64) of the first third of the parameters of self.c_attn in the no-cross-attention case.

We may also need to initialize (the rest of) self.c_attn with muP init as well.
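Concretely, the init I have in mind would look roughly like this (untested sketch, reusing your std expression and assuming the same surrounding imports as the modeling file; please double-check):

mup_std = math.sqrt((config.initializer_range ** 2) / config.mup_width_scale)

if self.is_cross_attention:
    self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
    self.q_attn = Conv1D(self.embed_dim, self.embed_dim)

    if self.use_mup:
        # key/value projection gets the muP init
        self.c_attn.weight.data.normal_(mean=0.0, std=mup_std)
        self.c_attn.bias.zero_()
        # q_attn plays the role of the query third of c_attn below, so it is
        # treated the same way: zeroed if query_zero_init is set, muP init otherwise
        if config.query_zero_init:
            self.q_attn.weight.data.zero_()
        else:
            self.q_attn.weight.data.normal_(mean=0.0, std=mup_std)
        self.q_attn.bias.zero_()
else:
    self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)

    if self.use_mup:
        self.c_attn.weight.data.normal_(mean=0.0, std=mup_std)
        if config.query_zero_init:
            _, fanout = self.c_attn.weight.shape
            self.c_attn.weight.data[:, :fanout // 3] = 0  # the query third of the outputs
        self.c_attn.bias.zero_()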

Let me know what you think of this reasoning!

@NZ99 (Contributor, Author) replied:

This is extremely helpful, thank you. Updating accordingly.

if self.use_mup:
    attn_weights = attn_weights / torch.full(
        [], value.size(-1), dtype=attn_weights.dtype, device=attn_weights.device
    )
@othertea (Collaborator) commented:

Just to check, I'm assuming this is where you're planning on using config.mup_attn_mult?

@NZ99 (Contributor, Author) replied:

I think so, yes, though from a configuration point of view I'm unsure how best to handle this (and the same applies to some of the other MuP-specific configuration options). In theory I wanted a high-level configuration option that enables (or disables) MuP, but I also want it to be configurable for experimentation. At the same time, I'm conflicted about this more fine-grained configuration, since MuP simply doesn't work (that is, hparams do not transfer) unless MuP's requirements, including attention scaling, are respected.

@NZ99 (Contributor, Author) replied:

For now, I'll require both config options to be set.
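Something like the following is what I have in mind for the forward pass (sketch only, not yet pushed; I'd stash config.mup_attn_mult on the module in __init__ as self.mup_attn_mult):

if self.use_mup:
    # muP: scale attention logits by 1/d_head (times a tunable multiplier) instead of 1/sqrt(d_head)
    attn_weights = attn_weights * self.mup_attn_mult / value.size(-1)
else:
    # standard scaled dot-product attention
    attn_weights = attn_weights / (value.size(-1) ** 0.5)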

protein_lm/modeling/models/apt/model_pytorch.py (outdated comment thread, resolved)

#muP TO DO: check proper behavior for LM head, nothing should be done (?)
#see https://github.com/cofe-ai/Mu-scaling/blob/a3d13c3eaa02e28cc6ef95116358b128424f9fe4/modeling/modeling_gpt2_mup.py#L472
#see also table 8's caption in https://arxiv.org/pdf/2203.03466.pdf
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
@othertea (Collaborator) commented:

We may want to replace this nn.Linear with MuReadout from the muP repository/package or do the equivalent initialization manually. I think cofe-ai's Mu-scaling repo also uses MuReadout; lm_head is initialized in a different file than the one your comment references: https://github.com/cofe-ai/Mu-scaling/blob/a3d13c3eaa02e28cc6ef95116358b128424f9fe4/modeling/lm_mup.py#L17-L25

@NZ99 (Contributor, Author) replied:

OK, I've looked into it and agree that we need to switch from the nn.Linear layer to MuReadout. Do we plan to use weight tying? Because mup also has a MuSharedReadout (https://github.com/microsoft/mup/blob/19814971934ef91dd546f88e913fc963e096d11c/mup/layer.py#L59-L68) that would likely be handy in this case. For now I'll add a note about weight tying and just switch to the normal MuReadout.

@NZ99 (Contributor, Author) replied:

Thinking about it, we should probably just add a config option about weight tying, and support both cases accordingly.
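For concreteness, roughly what I mean (sketch only; the tie_word_embeddings flag and the self.transformer.wte reference are assumptions about how we'd wire it up, not code that's in the PR):

from mup.layer import MuReadout, MuSharedReadout

if config.use_mup:
    if config.tie_word_embeddings:
        # reuse the token embedding weight, as with standard weight tying
        self.lm_head = MuSharedReadout(self.transformer.wte.weight, bias=False)
    else:
        self.lm_head = MuReadout(config.n_embd, config.vocab_size, bias=False)
else:
    self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)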

@NZ99 (Contributor, Author) replied:

Actually, the Mu-scaling repo adds a width_mult config option to MuReadout because the original mup repo calculates and sets it on a layer-by-layer basis. I think we might want to follow the Mu-scaling approach and integrate MuReadout and MuSharedReadout into our code accordingly.

@othertea (Collaborator) replied:

Sorry, I'm not quite sure I'm understanding -- are you suggesting that we write our own MuReadout and MuSharedReadout modules ourselves?

@NZ99 (Contributor, Author) commented Jan 3, 2024

Hey @othertea, thank you very much for these comments, they are really appreciated! Also, sorry for the late reply; being sick with COVID over the holidays, I have only just gotten around to this.

  • Re: coordinate checking and integrating Optuna (and general PR organization), I am happy to go either way. I have a similar sense of the advantages of either approach as you describe above, so maybe we can be satisfied with coordinate checking results for this PR and have a separate one for the full hparam search pipeline. I would probably err on the side of this approach, but I'm happy to go with either, really.
  • My feeling about how the code is organized in the codebases I used as inspiration is that (and this is just a guess, to be clear) they wanted to keep the modeling code as close to the original HF implementation as possible, leading to a separate global initialization function. I chose to handle initialization via each module's __init__ function mostly because I wanted to maintain flexibility and avoid things potentially breaking should we introduce new modules that we want to test. On the other hand, I'm unsure just how much flexibility this approach provides in reality, mostly because if a new module is introduced without making sure it conforms to muP's rules we won't see any breaking change immediately; rather, hparam results will just stop transferring as e.g. model width is increased. I am happy to refactor the code in case a global initialization function is preferred (see the sketch after this list for what that could look like).
  • I'm unsure about relying on the MuP repo, since it does not seem to have received updates recently. I'm happy to use it too if that is the general preference, though. One other option that I'd suggest (and that was not available back when I opened this PR) is ezmup, which is up to date w.r.t. the latest developments in MuP's general area (namely https://arxiv.org/abs/2310.17813, which I've not yet taken a look at), seems to be self-contained (e.g. it also contains MuP-compatible optimizers), and already implements coordinate checking. It would also allow us to keep modeling edits to a minimum, since all it seemingly requires is setting our width-related hparams to a multiple of some large, prime "width_basis". I'm going to test this, as IMO it would be a good solution; I just wonder about the consequences for scaling both in this case and in the case of integrating code from the mup repo (which, for example, does not seem to support FSDP yet).
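For reference, the global initialization function alternative would look roughly like this (sketch only, mirroring the per-module init currently in the PR; it assumes the modeling file's imports of math, torch.nn as nn, and Conv1D):

def mup_init_weights(module, config):
    # single entry point applying the muP init rules to any module type
    mup_std = math.sqrt((config.initializer_range ** 2) / config.mup_width_scale)
    if isinstance(module, (nn.Linear, Conv1D)):
        module.weight.data.normal_(mean=0.0, std=mup_std)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
        module.weight.data.fill_(1.0)
        module.bias.data.zero_()

# usage: model.apply(lambda m: mup_init_weights(m, config))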

Also, thank you very much for the detailed inline comments! I will take a look soon.

@NZ99 (Contributor, Author) commented Jan 11, 2024

Made a mistake and pushed from an older account. Anyway, I tried ezmup, but without much success. It runs into errors because it doesn't support embedding biases and transformers' Conv1D out of the box, and even after fixing that it still errors out. A pity, because it would have been an easy way to add support for MuP. I will keep testing it. In the meantime, I pushed some updates based on your extremely helpful review. I've kept the organization as is for now, but I'm happy to refactor depending on your (and other folks') preference. Re: MuReadout, there is also a MuSharedReadout (see the dedicated comment mentioned above). I also want to add more documentation for the configuration options. Next up is the coordinate checking; I'm looking into that today.

@NZ99 (Contributor, Author) commented Jan 25, 2024

Added the mup refactor and the coordinate checking integration discussed previously on Discord. You can find a colab showcasing mup coordinate checking results on the model here.
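In outline, the coordinate check in the colab does roughly the following (sketch; base_model, delta_model, dataloader and the widths stand in for the actual test setup):

from mup import set_base_shapes
from mup.coord_check import get_coord_data, plot_coord_data

def lazy_model(width):
    # mup's coord check expects a dict of width -> callable returning a fresh model
    def f():
        model = APTLMHeadModel(config=APTConfig(n_embd=width, n_layer=8, use_mup=True))
        set_base_shapes(model, base_model, delta=delta_model)
        model.apply(model._init_weights)
        return model
    return f

models = {width: lazy_model(width) for width in (64, 128, 256, 512, 1024)}
df = get_coord_data(models, dataloader, optimizer='adam', lr=1e-3, mup=True,
                    nseeds=3, nsteps=4)
plot_coord_data(df, save_to='coord_check_mup.png')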

@NZ99 (Contributor, Author) commented Jan 25, 2024

Given that MuP now passes coordinate checking in my tests, maybe we can review (cc @othertea? I would love a second pair of eyes, given that I don't trust myself much here) and consider merging, with follow-up work going into a to-be-opened second PR for Optuna.

Let me know what you think :)

@othertea (Collaborator) left a comment:

Thanks a lot for the updates @NZ99! I can confirm that I can reproduce your coord check plots and that they look good, and that if you do not apply _init_weights, the plots look bad, which is great!
I added some more in-line comments down below. I'll take another look, but for now, the only thing I really notice is that we may be missing an attn_mult parameter.
And to prepare for a merge, could you resolve the merge conflicts (e.g., by rebasing with main or by merging with main)? I think we should be able to merge this very soon 🙂

@@ -0,0 +1 @@
from models import *
@othertea (Collaborator) commented Feb 12, 2024

Suggested change
from models import *
from .models import *

I think you need these to be relative imports? They didn't work as is for me.
Alternatively, instead of changing all of these to relative imports we can remove these lines and import them by specifying the full module paths in test_coord_check.py

@@ -0,0 +1 @@
from apt import *
@othertea (Collaborator) commented:

Suggested change
from apt import *
from .apt import *

protein_lm/modeling/models/apt/__init__.py (outdated comment thread, resolved)
self.attn_dropout = nn.Dropout(config.attn_pdrop)
self.resid_dropout = nn.Dropout(config.resid_pdrop)
if self.use_mup:
    self.attn_dropout = nn.Identity()
@othertea (Collaborator) commented:

Should we consider asserting that the dropout probabilities are set to 0 in this case (in configs)?
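e.g. something along these lines in the config validation (sketch):

if config.use_mup:
    # dropout is replaced by nn.Identity under muP, so a non-zero prob would be silently ignored
    assert config.attn_pdrop == 0.0 and config.resid_pdrop == 0.0, \
        "muP requires attn_pdrop and resid_pdrop to be 0"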

    [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
)
if self.use_mup:
    attn_weights = attn_weights / torch.full(
@othertea (Collaborator) commented:

Should we be multiplying by some attn_mult here that we add as a config option (as in Mu-Scaling or mutransformers)?

@NZ99 (Contributor, Author) replied:

Good catch, yes, thank you! Will update accordingly.

prepend_bos=True,
append_eos=True,
eos_idx=2)
# mup implementation does not currently support this
@othertea (Collaborator) commented:

Similar to the dropout case, should we consider adding an assertion that we are not using mup with this in the configs?

@NZ99 (Contributor, Author) replied:

Yes, I think this is a good idea!

protein_lm/modeling/models/apt/model_pytorch.py (outdated comment thread, resolved)
# note that this has to be run after mup.set_base_shape for it to work
# see https://github.com/microsoft/mup#basic-usage
# not sure if this is required here
self.apply(self._init_weights)
@othertea (Collaborator) commented:

So it seems to me that we shouldn't call this here? As in your coordinate check example, you will have to call it again afterwards anyway (and only if you're using mup)?

@NZ99 (Contributor, Author) replied:

Right, I think this was left over from some earlier testing and I forgot to remove it. Indeed it shouldn't have an effect, so there's no reason to keep it. Thanks!
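i.e. the ordering will live at the call site instead, something like this (sketch; target_config, base_model and delta_model are placeholders, as in the coord check script):

model = APTLMHeadModel(config=target_config)            # __init__ no longer applies _init_weights
set_base_shapes(model, base_model, delta=delta_model)   # muP bookkeeping first
model.apply(model._init_weights)                        # then the muP initialization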


if __name__ == "__main__":
    delta_model = APTLMHeadModel(config=APTConfig(n_embd=200, n_layer=8, num_attention_heads=10, n_inner=200, use_mup=True))
    delta_model.apply(delta_model._init_weights)
@othertea (Collaborator) commented:

Suggested change
delta_model.apply(delta_model._init_weights)

I think only the actual model needs to have _init_weights applied? I checked, but you should double check too!

@fabigr8 commented Mar 17, 2024

I had a look at the mup example for the Transformer, and it seems they do not call init_weights when mup is active:
see https://github.com/microsoft/mup/blob/main/examples/Transformer/main.py#L189 (line 189 and line 310 onwards). They only call a weight initialization at the end of the definition of their transformer model:
https://github.com/microsoft/mup/blob/main/examples/Transformer/model.py#L105C9-L105C28 (line 105).

We also call initialization:

# Initialize weights and apply final processing
self.post_init()

only that ours is the post_init inherited from transformers' PreTrainedModel (line 1243):
https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L1243C5-L1243C8

If I understood this right, you are right, yes! So we have to skip _init_weights when using MuP.
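One possible way to implement that skip (just a sketch, to be double-checked; the class name is whatever base class our models actually inherit from):

class APTPreTrainedModel(PreTrainedModel):
    def init_weights(self):
        if getattr(self.config, "use_mup", False):
            # with muP, weights must be (re)initialized only after mup.set_base_shapes
            # has been called, so skip the automatic init triggered by post_init()
            return
        super().init_weights()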

delta_model.apply(delta_model._init_weights)

base_model = APTLMHeadModel(config=APTConfig(n_embd=1, n_layer=8, num_attention_heads=1, n_inner=1, use_mup=True))
base_model.apply(base_model._init_weights)
@othertea (Collaborator) commented:

Suggested change
base_model.apply(base_model._init_weights)

Labels: enhancement (New feature or request), help wanted (Extra attention is needed)
Project status: In Progress
4 participants