[Bloom] Remove unused position_ids, improve modeling code (lm head, alibi, attention multiplication) #18141
Conversation
# FIXME @thomasw21: it's quite annoying that weight tie is not done in `super().post_init()` but in `from_pretrained`, not sure about the reason why. Consequently, we need to modify the word embeddings instead ...
# self.lm_head.to(torch.float32)
self.transformer.word_embeddings.to(torch.float32)
Yeah, I'm not sure why, but when instantiating a model using from_pretrained, weight tying isn't done in post_init but rather in from_pretrained. This causes the old implementation to fail: you cast lm_head to fp32, then tie the lm_head weights to the word embeddings, which discards the fp32 cast. Ideally we'd want to tie the lm_head weights to the word embeddings first, then cast lm_head to fp32 (which should also cast the word_embeddings). So I guess my question is: why is weight tying done in from_pretrained?
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
return_dict=None,
**kwargs,
I would remove position_ids from the forward and absorb it with kwargs and then do:
if kwargs.pop("position_ids", None) is not None:
    warnings.warn(...)
Hmm, how is this superior? Introducing **kwargs feels like opening a can of worms. I mean, we could just keep position_ids in the function arguments, do nothing with it, and signal that we don't use it and deprecate it? I've been meaning to introduce a decorator to signal deprecated arguments, if that makes sense to you.
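For illustration, a minimal sketch of the kind of decorator being described; the name deprecate_kwarg and its exact behaviour are assumptions, not code from this PR:

import functools
import warnings

def deprecate_kwarg(name):
    # Hypothetical helper: warn when a deprecated keyword argument is passed, then drop it
    # before calling the wrapped function.
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            if kwargs.pop(name, None) is not None:
                warnings.warn(f"`{name}` is deprecated and has no effect.", FutureWarning)
            return fn(*args, **kwargs)
        return wrapper
    return decorator

@deprecate_kwarg("position_ids")
def forward(input_ids=None, attention_mask=None):
    ...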
Guess it's not that big of a deal in the end, but as I understand it, the whole position_ids part of this PR is correcting a bug, not deprecating working functionality: if people have passed position_ids before (which they very probably have not), it was just a no-op. So removing this unused function argument is completely fine for me. To be more or less 100% backward compatible, I think it's best if we add a kwarg that absorbs position_ids as explained above, because it makes it very clear from the docstring that position_ids doesn't exist anymore. If we leave it as is, then position_ids still looks, from the docs, like an existing and functional function argument, which it is not IMO.
Hmm ... the advantage of keeping position_ids without introducing **kwargs is that forward will fail if you feed it anything else (typically potato=None), whereas it won't if you use **kwargs.
No, the kwargs should then be consumed: you pop position_ids, raise a deprecation warning if we want to, and then raise a ValueError if there is anything else left in the kwargs.
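A minimal sketch of that pattern; the argument names are illustrative, not necessarily what the PR ends up using:

import warnings

def forward(self, input_ids=None, attention_mask=None, **deprecated_arguments):
    # `position_ids` never had any effect in BLOOM, so we only warn and drop it.
    if deprecated_arguments.pop("position_ids", None) is not None:
        warnings.warn("`position_ids` have no functionality in BLOOM and will be removed in v5.0.0.", FutureWarning)
    # Anything else left over (e.g. potato=None) is a real mistake and should still fail loudly.
    if len(deprecated_arguments) > 0:
        raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
    ...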
Okay.
self.post_init()

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
Uff, that's quite hacky here - not a big fan of overwriting methods such as from_pretrained(...) :-/ I think I'd be more in favor of adding a:

if model.config.force_lm_head_in_fp32 and model.lm_head.weight.dtype != torch.float32:
    model.lm_head.to(torch.float32)

to the forward call of the model. IMO better than adding a new from_pretrained(...) and a post_init() - wdyt?
Also just to understand
I think the forward is okay as well, but it feels like the wrong callback: it means that everything we add after the cast-to-fp32 when setting up the model is going to have to live in forward ...
Well, the flow I want is tie weights -> cast lm_head to fp32, and that should happen more or less in the same callback (I initially thought post_init was enough, but it turns out from_pretrained ignores the weight tying in post_init and does it itself; if you could explain why, I'd be interested @patrickvonplaten).
Actually I think I've found a better way: fec68bf
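Judging from the diffs further down in the thread, that approach overrides tie_weights so the upcast always runs after the tying; a sketch (force_lm_head_in_fp32 is the config flag proposed in this PR, not an existing option):

import torch
from transformers.models.bloom.modeling_bloom import BloomPreTrainedModel

class BloomForCausalLM(BloomPreTrainedModel):
    def tie_weights(self):
        # Tie lm_head to the word embeddings first ...
        super().tie_weights()
        # ... then upcast, so the cast cannot be undone by a later re-tie
        # (from_pretrained ties weights again after loading them).
        if self.config.force_lm_head_in_fp32:  # proposed flag
            self.lm_head.to(torch.float32)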
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
    if position_ids is not None:
        warnings.warn(
            "`position_ids` is deprecated and will be removed in v5.0.0",
I thought position_ids were simply wrong - or are there use cases where people might have passed position_ids to overwrite alibi? But actually, looking a bit closer at the code, it seems like position_ids is never used, so I think we can safely remove it and should therefore adapt this message to: "position_ids have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore passing position_ids."
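Concretely, with the adapted wording the check above would become something like this (using FutureWarning as the category is an assumption):

if position_ids is not None:
    warnings.warn(
        "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. "
        "You can safely ignore passing `position_ids`.",
        FutureWarning,
    )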
I disagree with the fact that forcing the lm head to float32 is just "fixing a bug" and I insist that the corresponding flag should be left as False by default. It will instantly break the API that has a careful way of balancing those weights and the final layer (which is also the biggest) will suddenly be 2x bigger. Likewise for the scripts we will publish in the BLOOM inference blog post.
self.post_init()

def tie_weights(self):
    super(BloomForCausalLM, self).tie_weights()
super().tie_weights()

Let's write Python 3 code, pretty please.
super(BloomForCausalLM, self).tie_weights()

if self.config.force_lm_head_in_fp32:
    self.lm_head.to(torch.float32)
This will also work with device_map="auto" since this happens after weights are loaded, so good for me.
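For illustration, the loading path being discussed would look roughly like this from the user side (force_lm_head_in_fp32 is the flag proposed in this PR, not an existing config option):

import torch
from transformers import AutoModelForCausalLM

# Hypothetical usage of the proposed flag; device_map="auto" dispatches weights as they are loaded,
# and the upcast only runs afterwards, inside tie_weights.
model = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloom",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    force_lm_head_in_fp32=True,  # proposed in this PR
)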
# Initialize weights and apply final processing
self.post_init()

def tie_weights(self):
I don't think that this is better than adding an if statement to the forward. To me, upcasting the lm_head is similar/identical to upcasting the attention softmax to make it more stable, which happens in the forward.
From a user's perspective, I would expect the weights to be downloaded and loaded into the model exactly as they are on the Hub and not changed on the fly depending on a config - I think this is quite confusing.
Then, during the forward operation, can't we just force the outputs of the model to be in float32 depending on the config, like we do in GPTJ:

lm_logits = self.lm_head(hidden_states).to(torch.float32)

I'm also fine with upcasting the weights in the forward, then doing the nn.Linear(...) operation, and then downcasting them again, so that the weights dtype before forward() == the weights dtype after forward(). It would be nice if neither loading the weights into the model nor the forward pass actually changed the weights dtype.
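A sketch of that upcast/downcast-inside-forward variant; the helper name project_to_vocab is illustrative, not code from the PR:

import torch

def project_to_vocab(lm_head, hidden_states):
    # Upcast the head, run the projection in fp32, then restore the original dtype
    # so the weights look unchanged before/after the forward pass.
    original_dtype = lm_head.weight.dtype
    lm_head.to(torch.float32)
    lm_logits = lm_head(hidden_states.to(torch.float32))  # inputs must match the weight dtype
    lm_head.to(original_dtype)
    return lm_logits

# usage inside forward(): lm_logits = project_to_vocab(self.lm_head, hidden_states)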
Forcing the logits to be in float32 instead of the whole lm head would be way better in terms of memory management (if it works). That's a change we could have without opt-in.
> From a user's perspective, I would expect the weights to be downloaded and loaded into the model exactly as they are on the Hub and not changed on the fly depending on a config - I think this is quite confusing.

But that's not true: if you don't specify torch_dtype, the weights get loaded as fp32 even when they are stored in another precision.

> Then, during the forward operation, can't we just force the outputs of the model to be in float32 depending on the config, like we do in GPTJ:

This shouldn't work: you run nn.Linear in fp16/bf16, the values collapse at that point, and only then do you cast to fp32 - but it's already collapsed. I don't think nn.Linear allows outputting tensors in a different precision than the input.
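In other words, contrasting the two orderings (fragments meant to live inside BloomForCausalLM.forward; the second line assumes the lm_head weights are already in fp32):

# Casting after the matmul is too late: the projection itself ran in fp16/bf16,
# so any collapse has already happened by the time we cast.
lm_logits = self.lm_head(hidden_states).to(torch.float32)

# Forcing the lm head to fp32 makes the matmul itself run in fp32
# (the input must be cast too, since nn.Linear expects matching dtypes).
lm_logits = self.lm_head(hidden_states.to(torch.float32))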
This PR tries to group way too many things together, which is very bad practice: when we realize after merging it that everything is broken, we won't find the cause easily. Please break it down into at least three parts:
I am veto-ing any merge of everything altogether, as we've already had quite enough of "Oh, the last BLOOM PR broke Xxx." 🙃
patrickvonplaten left a comment:
Removing the position_ids is OK for me since it's a bug fix. I'm not very happy about upcasting weights at loading time in a way that's very much hidden from the user; I'd prefer to do this in the forward while making sure the dtype doesn't change before/after forward.
Also, this PR does many things in one "clean code" PR - it would have been nice to actually do multiple PRs for this, since BLOOM is important and there are many important changes. As a reviewer, it's super helpful to be able to concentrate fully on one thing in a PR (e.g. removing position_ids).
Here are some comments so far:
Also, I really appreciate your finding on "max-collapse" (especially being able to demonstrate it!), and I'm glad that it improves the generation here. But I personally would not expect FP16 generations to always match FP32 generations (even with the lm head in fp32).
I think @younesbelkada and @NouamaneTazi changed the original behaviour; it was unclear what it actually fixed. The reason why I want to use
Yeah, so I initially thought that upcasting would give much better inference (at least with greedy decoding). Turns out that's not true, at least for 176B (it was true on the small models in tests), so, as @sgugger and @patrickvonplaten suggested, I'll try to figure out whether that feature is actually necessary at all.
Whoops, forgot to answer some questions:
Well, technically, given the checkpoints are in float16 or bfloat16, there should be little reason for generations not to match. I mean, it's the promise of pretraining in those half precisions: "use half the compute/time to get more or less the same model". I would not be surprised if they didn't match perfectly, but at the same time, now that they do, it's a great signal that the model is robust to numerical inaccuracies. Consequently, I think the test matching fp16 (with fp32 lm_head) output against full fp32 output makes sense.
As #18344 (comment) has been merged, can you merge main into this branch?
Actually going to close this PR, any reason why you want this branch to still be alive? What should be missing if the
Good to be closed for me.
Notable changes:

- Remove the `attention_mask` sum trick, and instead use `torch.masked_fill`.
- Use `baddbmm` instead of `bmm`. It was unclear why the change was necessary.
- ~~Remove~~ Deprecate `position_ids`, as they don't make sense in BLOOM.

One of the things we're wondering about is something we'd like to call "max-collapse". Given that 16-bit floats can represent at most 65536 distinct values, with a vocabulary of 250k+ some values are going to collapse, i.e. multiple logits are going to be equal. So if that happens to the max value, greedy decoding can change between fp32 and fp16/bf16.
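A small sketch of how this "max-collapse" effect can be observed directly (illustrative numbers, not the PR's tests):

import torch

# A 16-bit float has only 2**16 = 65536 bit patterns, so a 250k-entry logits
# vector necessarily contains repeated values after casting down from fp32.
logits_fp32 = torch.randn(250_880)  # roughly BLOOM's vocabulary size
logits_fp16 = logits_fp32.to(torch.float16)
print(logits_fp32.unique().numel())  # ~250k distinct values
print(logits_fp16.unique().numel())  # far fewer distinct values

# If the two largest fp32 logits are closer than one fp16 step, they collapse
# onto the same fp16 value and greedy decoding (argmax) becomes ambiguous.
logits_fp32[0], logits_fp32[1] = 10.0000, 10.0001  # hypothetical near-tie at the top
logits_fp16 = logits_fp32.to(torch.float16)
print((logits_fp16 == logits_fp16.max()).sum().item())  # 2 -> the fp16 argmax is no longer unique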