
Bug: Fixed GPTNeoX Flax support #25334

Closed. Wants to merge 31 commits.

Conversation

@HeegyuKim commented Aug 6, 2023

What does this PR do?

Fixes #22950:

  • The previous FlaxGPTNeoX support PR (GPTNeoX Flax support #22950) contains an error in the cached generation process. I resolved it.
  • I also added test code from Addition of test code for GPTNeoX Flax support #24002.
  • There were 7 failures during testing, and I'm not sure why. It doesn't seem to be a fatal issue, but I'd appreciate it if you could check it out.
  • The output is very similar to the PyTorch model's output, and the model works fine.
  • The k/v cache is already implemented.

@sanchit-gandhi (Contributor) left a comment

Looking super clean already @HeegyuKim! Mainly just some very minor comments from me - the overall design is great. Once these are addressed and the tests pass we can get it ready for merge 🚀

@HeegyuKim (Author)

I'm struggling with a test issue. Can you help me? @sanchit-gandhi

Summary

  • In PyTorch, my GPTNeoX tests fail on both test_equivalence_flax_to_pt and test_equivalence_pt_to_flax.
  • But in Flax, GPTNeoX doesn't fail, because my test code overrides those tests so they don't use check_pt_flax_outputs.
  • It's the same for GPT Neo: Flax GPTNeo doesn't use check_pt_flax_outputs, but PyTorch GPTNeo does.
  • However, the GPT Neo test in PyTorch does not fail.
  • Flax GPTNeo fails if it uses check_pt_flax_outputs.

I don't think this is a problem with my model implementation. I wonder why the PyTorch tests fail.

This PR fails the two tests below:

FAILED tests/models/gpt_neox/test_modeling_gpt_neox.py::GPTNeoXModelTest::test_equivalence_flax_to_pt - AssertionError: 1.0483556 not less than or equal to 1e-05 : outputs.last_hidden_state: Difference between PyTorch and Flax is 1.0483555793762207 (>= 1e-05).
FAILED tests/models/gpt_neox/test_modeling_gpt_neox.py::GPTNeoXModelTest::test_equivalence_pt_to_flax - AssertionError: 1.8777691 not less than or equal to 1e-05 : outputs.last_hidden_state: Difference between PyTorch and Flax is 1.877769112586975 (>= 1e-05).

But the two Flax tests in tests/models/gpt_neox/test_modeling_flax_gpt_neox.py pass fine.

The test code, which was copied from #24002, overrides both the test_equivalence_pt_to_flax and test_equivalence_flax_to_pt methods with this comment:

    # overwrite from common since `attention_mask` in combination
    # with `causal_mask` behaves slighly differently

and they use the assert code below

# test_modeling_flax_gptneox.py line 267
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
    self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2)

instead of self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class) in the test_equivalence_pt_to_flax method of test_modeling_common.py.

These overrides are identical to those in tests/models/gpt_neo/test_modeling_flax_gpt_neo.py, but GPTNeo doesn't fail the PyTorch test. I don't know what the difference is.

@sanchit-gandhi (Contributor) commented Aug 17, 2023

Hey @HeegyuKim - could you confirm that you get the same logits out from the Flax model when you generate as with the PyTorch model? i.e. that the generation scores are the same in both cases. If this is indeed the case, then we can know for certain that the Flax implementation is correct, and that we need to override the PT-FX cross tests. Otherwise, there's a divergence that we need to fix!

We can check this with the full GPT NeoX model to ensure we have the right logits here
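For illustration, a minimal logit-comparison sketch along these lines might look as follows. This is only a sketch: it assumes the FlaxGPTNeoXForCausalLM class added in this PR and uses EleutherAI/pythia-70m as a stand-in small GPT-NeoX-architecture checkpoint; it is not the exact script used in this thread.

import numpy as np
import torch
from transformers import AutoTokenizer, GPTNeoXForCausalLM, FlaxGPTNeoXForCausalLM

checkpoint = "EleutherAI/pythia-70m"  # assumption: any small GPT-NeoX checkpoint works here
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer("Hello, my name is", return_tensors="np")

pt_model = GPTNeoXForCausalLM.from_pretrained(checkpoint)
fx_model = FlaxGPTNeoXForCausalLM.from_pretrained(checkpoint, from_pt=True)  # class from this PR

with torch.no_grad():
    pt_logits = pt_model(torch.tensor(inputs["input_ids"])).logits.numpy()
fx_logits = np.asarray(fx_model(**inputs).logits)

# If the port is correct, the next-token logits (and hence the generation scores) should agree closely.
print("max abs diff:", np.abs(pt_logits - fx_logits).max())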

@github-actions (bot)

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@sanchit-gandhi (Contributor)

Hey @HeegyuKim! I thought a little bit more about the PT-FX cross tests in the [WIP] Flax LLaMA port with @vvvm23, and suggested that probably the reason for the failing tests is the random attention mask: #24587 (comment)

If we switch to using a causal attention mask, we are able to get PT-FX equivalence for Flax LLaMA without overriding the tests. Since Flax LLaMA is heavily based off Flax GPT-Neo, I'm fairly certain we'll observe similar behaviour for Flax GPT-NeoX

Would you like to try running the tests using a causal attention mask? E.g. as per #24587 (comment)

@liutianlin0121 (Contributor)

Hi! Thanks @HeegyuKim for the PR. I am wondering if there is any update on this? It would be really cool if we could use GPTNeoXForCausalLM in flax!

@HeegyuKim (Author)

Hello @liutianlin0121, I'm trying to solve the problem whenever I have time. However, even with causal masking applied, the error in the model output is still larger than 1e-5; the current error is around 0.02-0.03. I'm going to try again this weekend.

Even though there are errors, the model works better than expected. I trained several models with this code.

I want to contribute to huggingface but it's not as easy as I thought.

@sanchit-gandhi (Contributor) left a comment

It looks like you're making solid progress @HeegyuKim! Nice work! The most recent CI run is reporting that the difference in the Flax and PyTorch model outputs is > 1 (see CI output). This suggests to me that there is a divergence between the Flax and PyTorch models. Generally, for any model under 1B params, we should be able to get equivalence to within 1e-5 between Flax and PyTorch. It's quite likely that you won't get this equivalence running the matmuls in bfloat16 on TPU, but you should be able to when running them in float32; see #15754 and google/jax#10413 (comment) for details.

Here's a script that I used previously for checking PT / Flax equivalence for BLOOM: https://github.com/sanchit-gandhi/codesnippets/blob/main/check_flax_bloom_jit_small_testing.ipynb You can ignore the bits about JIT'ing the forward pass for the time being. You can also uncomment the check to run it on CPU to force the highest precision, or use the decorator as provided

If we don't get 1e-5 precision, it's usually an indicator that we have a divergence in our model. Here, going through layer-by-layer and checking the hidden-states might be required to pinpoint it. Once you have this equivalence, it's almost guaranteed that the CI will report a difference of less than 1e-5, since it runs on CPU.

Let me know if you have any questions / queries about finishing this PR. You've done a great job and I'd be more than happy to assist you in seeing this to completion!
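As a concrete illustration of the layer-by-layer check mentioned above, a rough sketch could look like the following. It assumes the FlaxGPTNeoXModel class from this PR and a small GPT-NeoX checkpoint (EleutherAI/pythia-70m is used as a hypothetical stand-in); jax.default_matmul_precision forces float32 matmuls so the comparison is not dominated by bfloat16 rounding.

import numpy as np
import torch
import jax
from transformers import AutoTokenizer, GPTNeoXModel, FlaxGPTNeoXModel

checkpoint = "EleutherAI/pythia-70m"  # assumption: any small GPT-NeoX checkpoint
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer("Hello, my name is", return_tensors="np")

pt_model = GPTNeoXModel.from_pretrained(checkpoint)
fx_model = FlaxGPTNeoXModel.from_pretrained(checkpoint, from_pt=True)

with torch.no_grad():
    pt_out = pt_model(torch.tensor(inputs["input_ids"]), output_hidden_states=True)

with jax.default_matmul_precision("float32"):
    fx_out = fx_model(**inputs, output_hidden_states=True)

# The first hidden state where the difference jumps is where the divergence lives.
for i, (pt_h, fx_h) in enumerate(zip(pt_out.hidden_states, fx_out.hidden_states)):
    print(f"hidden state {i}: max abs diff = {np.abs(pt_h.numpy() - np.asarray(fx_h)).max():.2e}")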

@HeegyuKim (Author) commented Sep 23, 2023

Ohhhhh I finally passed the equivalence tests! 🎉🎉

  • I now use the FlaxGPTNeoXRotaryEmbedding class for RoPE and implemented caching; this was the cause of the equivalence failure.
  • I removed the overrides in tests/models/gpt_neox/test_modeling_flax_gpt_neox.py and it works!

But there are still CI failures...

@sanchit-gandhi

@sanchit-gandhi (Contributor)

Well done @HeegyuKim, that's excellent news! Regarding the two failing tests:

  • You can run make fix-copies to update the modelling code with any copied functions. The linter will copy all the code so that it is one-for-one the same.
  • Could you open a pull request on the Hugging Face Hub to add the Flax weights to EleutherAI/gpt-neox-20b? You can first load the PyTorch weights into Flax:
from transformers import FlaxAutoModelForCausalLM

model = FlaxAutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b", from_pt=True)

And then push the converted Flax weights to the Hub:

model.push_to_hub("EleutherAI/gpt-neox-20b", create_pr=True)

@sanchit-gandhi (Contributor) left a comment

Looks in great shape @HeegyuKim! And nice job on getting equivalence with PyTorch! Left a few suggestions below, mainly just small re-factoring to get it ready for merge. Feel free to ping me as soon as you're ready for a final look - think it should be pretty fast to get it merged from here


self.causal_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool")

self.rotary_emb = FlaxGPTNeoXRotaryEmbedding(
Contributor:

We've only implemented one of the three possible rotary embedding types (rope_scaling=None). There are two more RoPE types in the PyTorch modelling code.

I don't think it would be too much work to add these so that we have equivalence with the PyTorch modelling code? WDYT?

Author:

I copied the RoPE scaling code and test code. But in the test CI, the frozen Flax module raises SetAttributeFrozenModuleError.

For the cached RoPE embedding, I think I should use a Flax variable. I'm trying to implement it, and it would be great if you have an appropriate reference or suggestion. @sanchit-gandhi

Contributor:

Ah we can only set attributes in Flax in the setup method. After that, the module gets frozen, meaning we can't add new attributes or update existing ones. So probably the cached embedding isn't going to work - what we can instead do is always initialise the embeddings to max length (config.max_position_embeddings), and then slice the first N embeddings as required each time. This way, we don't ever need to re-compute or update the embeddings, since we always have the max embedding length we require stored in the setup
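A sketch of that idea follows (hypothetical module and attribute names, not the final PR code): the cos/sin tables are built once at config.max_position_embeddings in setup, and each call simply slices the first seq_len positions, so nothing needs to be mutated after the module is frozen.

import flax.linen as nn
import jax.numpy as jnp

class FlaxRotaryEmbeddingSketch(nn.Module):
    dim: int
    max_position_embeddings: int
    base: float = 10000.0

    def setup(self):
        # Attributes may only be set here, so build the full-length tables once.
        inv_freq = 1.0 / (self.base ** (jnp.arange(0, self.dim, 2, dtype=jnp.float32) / self.dim))
        t = jnp.arange(self.max_position_embeddings, dtype=jnp.float32)
        freqs = jnp.outer(t, inv_freq)
        emb = jnp.concatenate((freqs, freqs), axis=-1)
        self.cos_cached = jnp.cos(emb)
        self.sin_cached = jnp.sin(emb)

    def __call__(self, seq_len):
        # Slice the first `seq_len` positions instead of re-computing or updating a cache.
        return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]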


input_mask = None
if self.use_input_mask:
    input_mask = np.tril(np.ones((self.batch_size, self.seq_length)))
Contributor:

Note to reviewer: we use a causal attention mask in the Flax generation tests. This is required to get sensible outputs from the Flax model (which is typical behaviour).

@HeegyuKim (Author)

I think we're almost at the end of our work, but there are some small issues.

Suddenly the wav2vec2 tests fail??

  • The wav2vec2 model tests suddenly fail (tests_torch CI link). It seems to have something to do with the gelu_fast I added, but I don't think wav2vec2 uses gelu_fast, so I don't know why.
  • GPT-NeoX-20B uses the gelu_fast activation. I converted the PyTorch implementation to a Flax version in src/transformers/modeling_flax_utils.py and added gelu_fast.

Flax weights

Copied from issue

  • make fix-copies changes every GPTNeoXBlahBlah -> GPTNeoBlahBlah (config, comments) even when there is a Copied from ... with GPTNeo->GPTNeoX mark. a670443#r1287443977
  • So I moved the copied-from comments onto the exactly matching methods.

@sanchit-gandhi

@sanchit-gandhi (Contributor) commented Oct 10, 2023

Hey @HeegyuKim! Nice job on iterating here! Answering your questions in-line below:

  1. For the failing Wav2Vec2 issues, you can rebase onto the main branch of transformers, where the tests have been fixed:
git fetch upstream
git rebase upstream/main
git push -f origin main

Note that it's important you force push (-f flag) after a rebase to preserve the correct commit history of this PR!
2. Thanks for converting the PyTorch GELU Fast activation to JAX - this looks great!
3. Thanks also for pushing the Flax weights - we can merge them once this PR is approved by the next reviewer and prior to merging this PR. Using from_pt=True is ok for the Flax tests - just make sure you decorate the tests with @is_pt_flax_cross_test since we need both PyTorch and Flax when we load from pre-trained with from_pt=True
4. The issue you had before was that there were extra spaces around the right arrow: GPTNeo -> GPTNeoX should be GPTNeo->GPTNeoX. Can you try adding this copied-from comment before FlaxGPTNeoXPreTrainedModel? This should allow you to copy the entire module:

# Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoPreTrainedModel with GPTNeo->GPTNeoX, GPT_NEO->GPT_NEO_X, "transformer"->"gpt_neox"

@HeegyuKim (Author)

Only the documentation is left now. How can I write the documentation for it? @sanchit-gandhi

Exception: The following objects are in the public init so should be documented:
 - FlaxGPTNeoXForCausalLM
 - FlaxGPTNeoXModel

@sanchit-gandhi (Contributor)

You can do so with make repo-consistency!

@HeegyuKim (Author)

I think I've now passed the necessary CI tests! @sanchit-gandhi

@github-actions (bot)

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@sanchit-gandhi (Contributor) left a comment

Very nice @HeegyuKim! Sorry about the delay with getting you another review. It's mainly small nits from me. Let's put in the request for a core maintainer to take a final look at this and get it merged!

@@ -63,12 +63,17 @@ def quick_gelu(x):
return x * jax.nn.sigmoid(1.702 * x)


def gelu_fast(x):
Contributor:

Thanks for adding this!
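For reference, a tanh-approximation gelu_fast in JAX could look roughly like this. This is a sketch mirroring the constants of PyTorch's FastGELUActivation and is not necessarily byte-for-byte what was added to modeling_flax_utils.py in this PR.

import jax.numpy as jnp

def gelu_fast(x):
    # Tanh approximation of GELU, following the constants in PyTorch's FastGELUActivation.
    return 0.5 * x * (1.0 + jnp.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))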

return self.cos_cached, self.sin_cached

def _compute_cos_sin(self, seq_len):
    t = jnp.arange(seq_len, dtype=self.inv_freq.dtype)
Contributor:

Note to reviewer: single-letter variables chosen to maintain equivalence with the PyTorch modelling code

t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

return jnp.concatenate((-second_half, first_half), axis=-1)


class FlaxGPTNeoXRotaryEmbedding(nn.Module):
Contributor:

The embedding classes are very nice! Ported to JAX while following closely the logic from PyTorch

return unfreeze(init_variables["cache"])

@add_start_docstrings_to_model_forward(GPT_NEOX_INPUTS_DOCSTRING)
def __call__(
Contributor:

Can we not copy this entire method from GPT Neo as well? The code looks to be one-to-one the same now, it's just a comment which is different "# if past_key_values are passed..."

If we do so, then we can actually just copy the entire class from GPT Neo, which would make the copied from statements much simpler.


hidden_states = outputs[0]

lm_logits = self.embed_out(hidden_states)
Contributor:

What about possibly tied word embeddings?

if self.config.tie_word_embeddings:
    shared_kernel = self.transformer.variables["params"]["wte"]["embedding"].T
    lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states)

@ArthurZucker (Collaborator) left a comment

Great work here! @HeegyuKim Thanks for adding the flax support 🤗

@@ -0,0 +1,783 @@
# coding=utf-8
# Copyright 2023 The EleutherAI and The HuggingFace Inc. team.
Collaborator:

Suggested change
# Copyright 2023 The EleutherAI and The HuggingFace Inc. team.
# Copyright 2023 The HuggingFace Inc. team.

Comment on lines +133 to +134
cos = jnp.expand_dims(jnp.expand_dims(jnp.cos(emb), 0), 0)
sin = jnp.expand_dims(jnp.expand_dims(jnp.sin(emb), 0), 0)
Collaborator:

I think we got rid of the extra dimensions in the pytorch version #26162

return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
Collaborator:

thanks for the pointers! 😉

attention_mask = combine_masks(pad_mask, attention_mask)
return key, value, attention_mask

def _split_heads(self, hidden_states):
Collaborator:

It's a tiny bit counter-intuitive that _split_heads does not split! It's a nit, but here's what we have in Bloom for the same function:

# Copied from transformers.models.bloom.modeling_bloom.BloomAttention._split_heads
def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory
    storage as `fused_qkv`

    Args:
        fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]

    Returns:
        query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
        value: [batch_size, seq_length, num_heads, head_dim]
    """
    batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
    fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
    return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]

Collaborator:

(Other JAX implementations have the same, so feel free to choose whatever you prefer.)

Comment on lines +282 to +283
def _merge_heads(self, hidden_states):
    return hidden_states.reshape(hidden_states.shape[:2] + (self.hidden_size,))
Collaborator:

This is only used once, so it's a bit unnecessary; let's remove it.

dropout_rng=dropout_rng,
dropout_rate=self.config.attention_dropout,
deterministic=deterministic,
dtype=jnp.promote_types(self.dtype, jnp.float32),
Collaborator:

why do we have to use this? (Just FMI 🤗 )

return (hidden_states,) + attn_outputs[1:]


class FlaxGPTNeoXPreTrainedModel(FlaxPreTrainedModel):
Collaborator:

We'll have a # Ignore copy soon 😉

The github-actions bot closed this on Dec 12, 2023.
@sanchit-gandhi (Contributor)

The PR is almost finished! Would you like to make the last remaining changes @HeegyuKim such that we can get this one merged? Let us know if you have any questions, more than happy to help here

@HeegyuKim (Author)

Thank you for your comment! I'll check it this weekend

@HeegyuKim (Author)

Hello @sanchit-gandhi, I rebased this PR onto the main branch and pushed it again.

There are two CI failures. The first is a documentation issue:

OSError: Can't load the model for 'EleutherAI/gpt-neox-20b'. If you were trying to load it from 'https://huggingface.co/models', make sure you don't have a local directory with the same name. Otherwise, make sure 'EleutherAI/gpt-neox-20b' is the correct path to a directory containing a file named flax_model.msgpack or pytorch_model.bin.

This problem can be fixed when this GPT-NeoX model PR is merged. Alternatively, we can add from_pt=True to the example.

As for the second issue, I don't know the cause. I would appreciate it if you could tell me the cause of and a solution to this problem.

ValueError: The main __init__ has objects that are not present in transformers.utils.dummy_flax_objects.py. Run `make fix-copies` to fix this.

I ran make fix-copies, but the following error occurs:

> make fix-copies

python utils/check_copies.py --fix_and_overwrite
Traceback (most recent call last):
  File "/home/heegyu/transformers/utils/check_copies.py", line 1129, in <module>
    check_copies(args.fix_and_overwrite, args.file)
  File "/home/heegyu/transformers/utils/check_copies.py", line 778, in check_copies
    new_diffs = is_copy_consistent(filename, overwrite, buffer)
  File "/home/heegyu/transformers/utils/check_copies.py", line 736, in is_copy_consistent
    diff_index = check_codes_match(observed_code, theoretical_code)
  File "/home/heegyu/transformers/utils/check_copies.py", line 549, in check_codes_match
    theoretical_name = re_pattern.search(theoretical_code_header).groups()[0]
AttributeError: 'NoneType' object has no attribute 'groups'

@github-actions (bot) commented Feb 7, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

The github-actions bot closed this on Feb 15, 2024.