Conversation

@thomasw21
Contributor

@thomasw21 thomasw21 commented Jul 14, 2022

Notable changes:

  • Remove the attention_mask sum trick, and instead use torch.masked_fill
  • Simplify the causal attention mask creation
  • Move back to baddbmm instead of bmm. It was unclear why the change was necessary.
  • Remove/deprecate position_ids, as they don't make sense in BLOOM.
  • Introduce an fp32 cast for lm_head (and consequently the word embeddings, in order to respect the weight sharing). The intuition is as follows:
    One of the things we're wondering about is something we'd like to call "max-collapse". Since 16-bit floats can represent at most 65536 distinct values, with a vocabulary of 255k+ some logit values are going to collapse, i.e. multiple values become equal. If that happens to the max value, greedy decoding can change between fp32 and fp16/bf16 (see the sketch after this list).
  • Move the test back to testing generation in 16-bit precision
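
To make "max-collapse" concrete, here is a minimal sketch (with made-up logit values, not taken from BLOOM) showing how two logits that are distinct in fp32 can become equal after a cast to fp16, which is enough to change a greedy argmax:

import torch

# Hypothetical logits: the first two values are distinct in fp32 ...
logits_fp32 = torch.tensor([10.1234, 10.1236, 9.5000])

# ... but near 10.0 the fp16 grid spacing is 2**-7 ≈ 0.0078, so both round to 10.125.
logits_fp16 = logits_fp32.to(torch.float16)

print(logits_fp16)                  # tensor([10.1250, 10.1250, 9.5000], dtype=torch.float16)
print(logits_fp32.argmax().item())  # 1
print(logits_fp16.argmax().item())  # the max is now a tie, so greedy decoding may pick a different token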

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@thomasw21 thomasw21 changed the title [WIP] BLOOM cleanup some code BLOOM cleanup some code Jul 27, 2022
Comment on lines 777 to 779
# FIXME @thomasw21: it's quite annoying that weight tie is not done in `super().post_init()` but in `from_pretrained`, not sure about the reason why. Consequently, we need to modify the word embeddings instead ...
# self.lm_head.to(torch.float32)
self.transformer.word_embeddings.to(torch.float32)
Contributor Author

Yeah, I'm not sure why, but when instantiating a model using from_pretrained, weight tying isn't done in post_init but rather in from_pretrained. This causes the old implementation to fail: you would cast lm_head to fp32, then tie the lm_head weight to the word embeddings, which ends up ignoring the fp32 cast. Ideally we'd want to tie the lm_head weight to the word embeddings first, then cast lm_head to fp32 (which should also cast word_embeddings). So I guess my question is: why is weight tying done in from_pretrained?

use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
Contributor

Suggested change
return_dict=None,
return_dict=None,
**kwargs,

I would remove position_ids from the forward and absorb it with kwargs and then do:

if kwargs.pop("position_ids", None) is not None:
    warnings.warn(...)

Contributor Author

@thomasw21 thomasw21 Jul 28, 2022

Hmm, how is this superior? Introducing kwargs feels like opening a can of worms. We could just keep position_ids as a function argument, do nothing with it, and signal that it's unused and deprecated. I've been meaning to introduce a decorator for signalling deprecated arguments, if that makes sense to you.

Contributor

Guess it's not that big of a deal in the end, but as I understand it, the whole position_ids part of this PR is about correcting a bug, not deprecating working functionality: if people passed position_ids before (which they very probably have not), it was just a no-op. So removing this unused/useless function argument is completely fine for me. To be more or less 100% backward compatible, I think it's best if we add a kwarg that absorbs position_ids as explained above, because it makes it very clear from the docstring that position_ids doesn't exist anymore. If we leave it as is, the docs still present position_ids as an existing, functional function argument, which it is not IMO.

Contributor Author

Hmm ... the advantage of keeping position_ids without introducing **kwargs is that forward will fail if you feed it any other unexpected argument (typically potato=None), whereas it won't if you use **kwargs.

Collaborator

No, the kwargs should then be consumed: you pop position_ids, raise a deprecation warning if we want to, and then raise a ValueError if there is anything else left in the kwargs.
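
For concreteness, a minimal sketch of that pattern (the **deprecated_arguments name and the exact message are illustrative, not the final implementation):

import warnings

def forward(input_ids=None, attention_mask=None, **deprecated_arguments):
    # Consume the deprecated kwarg: warn if it was passed, then reject anything else left over.
    if deprecated_arguments.pop("position_ids", None) is not None:
        warnings.warn(
            "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0.",
            FutureWarning,
        )
    if len(deprecated_arguments) > 0:
        raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
    # ... actual forward computation goes here ...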

Contributor Author

Okay.

self.post_init()

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
Contributor

Uff that's quite hacky here - not a big fan of overwriting methods such as from_pretrained(...) :-/ I think I'd be more in favor of adding a:

if model.config.force_lm_head_in_fp32 and model.lm_head.weight.dtype != torch.float32:
    model.lm_head.to(torch.float32)

to the forward call of the model. IMO better than adding a new from_pretrained(...) and a post_init() - wdyt?

Also just to understand

Contributor Author

I think the forward is okay as well, but it feels like the wrong callback; it means that everything we add after the cast to fp32 when setting up the model is going to have to live in forward ...

Well, the flow I want is tie weights -> cast lm_head to fp32, and those should happen more or less in the same callback. (I initially thought post_init was enough, but it turns out from_pretrained ignores the weight tying in post_init and does it itself; if you could explain why, I'd be interested @patrickvonplaten.)

Contributor Author

Actually I think I've found a better way: fec68bf

) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
    if position_ids is not None:
        warnings.warn(
            "`position_ids` is deprecated and will be removed in v5.0.0",
Contributor

I thought position_ids were simply wrong - or are there use cases where people might have passed position_ids to overwrite alibi? Looking a bit closer into the code, it seems position_ids is never used, so I think we can safely remove it and should therefore adapt this message to: "position_ids have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore passing position_ids."

@patrickvonplaten patrickvonplaten changed the title BLOOM cleanup some code [Bloom] Remove unused position_ids, improve modeling code (lm head, alibi, attention multiplication) Jul 28, 2022
Collaborator

@sgugger sgugger left a comment

I disagree with the fact that forcing the lm head to float32 is just "fixing a bug", and I insist that the corresponding flag should be left as False by default. It will instantly break the API that has a careful way of balancing those weights, and the final layer (which is also the biggest) will suddenly be 2x bigger. Likewise for the scripts we will publish in the BLOOM inference blog post.

self.post_init()

def tie_weights(self):
    super(BloomForCausalLM, self).tie_weights()
Collaborator

Suggested change
super(BloomForCausalLM, self).tie_weights()
super().tie_weights()

Let's write Python3 code pretty please.

super(BloomForCausalLM, self).tie_weights()

if self.config.force_lm_head_in_fp32:
    self.lm_head.to(torch.float32)
Collaborator

This will also work with device_map="auto" since this happens after weights are loaded, so good for me.

# Initialize weights and apply final processing
self.post_init()

def tie_weights(self):
Contributor

Don't think that this is better than adding an if statement to the forward. To me, upcasting the lm_head is similar/identical to upcasting the attention softmax to make it more stable, which happens in the forward.

From a user's perspective, I would expect the weights to be downloaded and loaded into the model exactly how they are on the Hub and not changed on the fly depending on a config - think this is quite confusing.

Then, during the forward pass, can't we just force the outputs of the model to be in float32 depending on the config, like we do in GPTJ:

lm_logits = self.lm_head(hidden_states).to(torch.float32)

I'm also fine with upcasting the weights in the forward, doing the nn.Linear(...) operation, and then downcasting them again, so that the weights dtype before forward() == weights dtype after forward(). It would be nice if neither loading the weights into the model nor the forward pass actually changed the weights dtype.
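
For illustration, a hedged sketch of that second option (upcast inside the forward, compute the logits in fp32, leave the stored weights untouched); the helper name and the flag are assumptions from this thread, not the merged code:

import torch
import torch.nn.functional as F

def compute_lm_logits(hidden_states: torch.Tensor, lm_head: torch.nn.Linear, upcast_to_fp32: bool) -> torch.Tensor:
    if upcast_to_fp32:
        # Cast activations and weights *before* the matmul so the logits are computed in fp32;
        # the parameters stored on the module keep their original fp16/bf16 dtype.
        return F.linear(
            hidden_states.to(torch.float32),
            lm_head.weight.to(torch.float32),
            lm_head.bias.to(torch.float32) if lm_head.bias is not None else None,
        )
    return lm_head(hidden_states)

Note that casting only the output of lm_head, as in the GPTJ line above, happens after the half-precision matmul, which is the concern raised further down in this thread.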

Contributor

Also, I just don't like overwriting functions like tie_weights from PreTrainedModel; it makes the code hard to read (e.g. just loading the model weights now switches back and forth multiple times between files).

@sgugger @stas00 what do you think here?

Collaborator

Forcing the logits to be in float32 instead of the whole lm head would be way better in terms of memory management (if it works). That's a change we could have without opt-in.

Contributor Author

From a user's perspective, I would expect the weights to be downloaded and loaded into the model exactly how they are on the Hub and not changed on the fly depending on a config - think this is quite confusing.

But that's not true: if you don't specify torch_dtype, the model gets loaded in fp32 even when the weights are stored in another precision.

Then, during the forward pass, can't we just force the outputs of the model to be in float32 depending on the config, like we do in GPTJ:

This shouldn't work: you run nn.Linear in fp16/bf16, the values collapse at that point, and only then do you cast to fp32, when they have already collapsed. I don't think nn.Linear allows outputting tensors in a different precision than the input.

@sgugger
Collaborator

sgugger commented Jul 28, 2022

This PR tries to group way too many things together, which is very bad practice: when we realize after merging it that everything is broken, we won't find the cause easily. Please break it down into at least three parts:

  • clean up code without any change
  • removing/deprecating position_ids
  • the float32 upcasting (which is where all the heated discussion is, so it really should be its own PR)

I am veto-ing any merge of everything altogether as we've already had quite enough of "Oh the last BLOOM PR broke Xxx." 🙃

Contributor

@patrickvonplaten patrickvonplaten left a comment

Removing the position_ids is OK for me since it's a bug fix. I'm not very happy about upcasting weights at loading time in a way that's very much hidden from the user; I'd prefer to do this in the forward while making sure the dtype doesn't change before/after forward.

Also, this PR does many things in one "clean code" PR - it would have been nice to do multiple PRs for this, since BLOOM is important and there are many important changes. As a reviewer, it's super helpful to be able to concentrate fully on one thing in a PR (e.g. removing position_ids).

@ydshieh
Collaborator

ydshieh commented Jul 28, 2022

Here are some comments so far:

  • Move back to baddbmm instead of bmm. It was unclear why the change was necessary.

    • Here: Bloom Optimize operations #17866 (comment)
    • But if this PR solves the issue for FP16, OK for me to use baddbmm.
    • (just wondering what the differences are, and whether there is any reason you prefer to use baddbmm?)
  • There are 4 force_lm_head_in_fp32 in the test file. Other than the one in test_force_lm_head_in_fp32_is_close_to_fp16, I don't know why we set it to False.
    • Is it to keep the same behavior as before (the whole model in FP16)?
    • But prepare_config_and_inputs has default force_lm_head_in_fp32=True, so most tests now use True. It is a bit confusing to me that we keep them False in a few places.
  • I agree with @sgugger that the default value for force_lm_head_in_fp32 should be False.
    • Although True here is good for generation, this is kind of special (casting part of the model weights to a different dtype)
    • Also, it's good to keep the previous behavior by default -> do not introduce surprising things to users

@ydshieh
Collaborator

ydshieh commented Jul 28, 2022

Also, I really appreciate your finding on "max-collapse" (especially being able to demonstrate it!), and glad that it improves the generation here.

But I personally would not expect FP16 generations to always match FP32 generations (even with force_lm_head_in_fp32=True), and we don't need tests that compare results across FP16/FP32. (I don't remember if we have a common test doing so, though.)

@thomasw21
Contributor Author

Here: #17866 (comment)
But if this PR solves the issue for FP16, OK for me to use baddbmm.
(just wondering what the differences are, and whether there is any reason you prefer to use baddbmm?)

I think @younesbelkada and @NouamaneTazi changed the original behaviour; it was unclear what it actually fixed. The reason I want to use baddbmm is that the training codebase used baddbmm, so there's no reason to use bmm.
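
For reference, a small sketch (arbitrary shapes and scaling factor, not the actual BLOOM values) of the equivalence in question: baddbmm fuses the batched matmul with the additive alibi bias and the scaling in a single call, whereas bmm needs an explicit add afterwards:

import torch

batch_x_heads, seq_q, seq_k, head_dim = 4, 8, 8, 16
query = torch.randn(batch_x_heads, seq_q, head_dim)
key_t = torch.randn(batch_x_heads, head_dim, seq_k)
alibi = torch.randn(batch_x_heads, seq_q, seq_k)  # additive positional bias
scale = 1.0 / head_dim**0.5

# Fused: beta * alibi + alpha * (query @ key_t)
fused = torch.baddbmm(alibi, query, key_t, beta=1.0, alpha=scale)

# Unfused equivalent built from bmm plus an explicit add
unfused = torch.bmm(query, key_t) * scale + alibi

assert torch.allclose(fused, unfused, atol=1e-5)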

There are 4 force_lm_head_in_fp32 in the test file. Other than the one in test_force_lm_head_in_fp32_is_close_to_fp16, I don't know why we set it to False.
Is it to keep the same behavior as before (the whole model in FP16)?
But prepare_config_and_inputs has default force_lm_head_in_fp32=True, so most tests now use True. It is a bit confusing to me that we keep them False in a few places.

Yeah, I initially thought that upcasting would give much better inference (at least with greedy decoding). Turns out that's not true, at least for the 176B model (it was true on the small models in tests), so as @sgugger and @patrickvonplaten suggest, I'll try to figure out whether that feature is actually necessary at all.

@thomasw21
Contributor Author

Whoops, forgot to answer some questions:
I agree that the default should be False now :D

But I personally would not expect FP16 generations to always match FP32 generations (even with force_lm_head_in_fp32=True), and we don't need tests that compare results across FP16/FP32. (I don't remember if we have a common test doing so, though.)

Well, technically, given the checkpoints are in float16 or bfloat16, there should be little reason for the generations not to match. It's the promise of pretraining in those half precisions: "use half the compute/time to get more or less the same model". I would not be surprised if they didn't match perfectly, but at the same time, now that they do, it's a great signal that the model is robust to numerical inaccuracies. Consequently, I think the test matching fp16 (with fp32 lm_head) output against full fp32 output makes sense.

@Muennighoff
Contributor

As #18344 (comment) has been merged, can you merge main into this branch?

@thomasw21
Contributor Author

thomasw21 commented Aug 4, 2022

Actually, I'm going to close this PR. Any reason why you want this branch to stay alive? The only thing that should be missing is the fp32 upcasting, which I've done in another branch.

@patrickvonplaten
Contributor

Good to be closed for me

@thomasw21 thomasw21 closed this Aug 26, 2022