Conversation

@ydshieh (Collaborator) commented Jul 4, 2022

What does this PR do?

Fix the torchscript tests for GPT-NeoX. The main issue is that the current RotaryEmbedding changes the model structure inside forward.

This PR creates the necessary embeddings in __init__, which essentially makes the (embedding) cache mechanism useless. Furthermore, the attribute names now seem a bit confusing. We could add a config attribute (e.g. init_sin_cos_cache_seq_len) with a value <= max_position_embeddings, but I think that would be overkill.
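
For illustration, here is a minimal sketch of the idea (an assumption of the general pattern, not the exact diff in this PR): the sin/cos tables are registered as buffers in __init__, so forward only slices into them and tracing never sees the module change shape.

```python
import torch
from torch import nn


class RotaryEmbedding(nn.Module):
    """Sketch: build the sin/cos tables once at construction time."""

    def __init__(self, dim, max_position_embeddings, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_position_embeddings).float()
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        # Registered once here instead of lazily in forward, so
        # torch.jit.trace records a static module structure.
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)

    def forward(self, x, seq_len):
        # Pure slicing; no attributes are created or resized at call time.
        return (
            self.cos_cached[:, :, :seq_len].to(x.dtype),
            self.sin_cached[:, :, :seq_len].to(x.dtype),
        )
```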

I'm not certain it's worth merging; still, with a PR open, we at least have a reference.

The currently failing test run is:
https://github.com/huggingface/transformers/runs/7216768053?check_suite_focus=true

@HuggingFaceDocBuilderDev commented Jul 4, 2022

The documentation is not available anymore as the PR was closed or merged.

@sgugger (Collaborator) left a comment

LGTM, thanks for fixing!

```diff
             beta=1.0,
-            alpha=(1.0 / self.norm_factor),
+            alpha=(torch.tensor(1.0, dtype=self.norm_factor.dtype, device=self.norm_factor.device) / self.norm_factor),
+            # alpha=(1.0 / self.norm_factor),
```

This leftover commented-out line should be cleaned up.
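
For context, a standalone sketch of why the alpha change matters (the shapes and head size below are assumptions, not taken from the PR): dividing a Python float by a float32 buffer yields a float32 scalar tensor, which can disagree with the inputs' dtype under half precision; building the scalar from the buffer's own dtype and device keeps the types consistent.

```python
import torch

# Assumed values for illustration: norm_factor plays the role of the
# sqrt(head_size) buffer used to scale attention scores.
norm_factor = torch.sqrt(torch.tensor(64.0, dtype=torch.float32))

query = torch.randn(2, 4, 64)       # (batch*heads, q_len, head_size)
key = torch.randn(2, 4, 64)         # (batch*heads, k_len, head_size)
attn_scores = torch.zeros(2, 4, 4)  # accumulator for baddbmm

scores = torch.baddbmm(
    attn_scores,
    query,
    key.transpose(1, 2),
    beta=1.0,
    # Deriving the scalar from norm_factor's dtype/device avoids the
    # implicit promotion that the plain `1.0 / self.norm_factor` caused.
    alpha=(torch.tensor(1.0, dtype=norm_factor.dtype, device=norm_factor.device) / norm_factor),
)
print(scores.shape)  # torch.Size([2, 4, 4])
```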

@patrickvonplaten (Contributor) left a comment

LGTM!

However, could we add the failing test for reference, or do we need to add a new test here?

@ydshieh (Collaborator, Author) commented Jul 7, 2022

> LGTM!
>
> However, could we add the failing test for reference, or do we need to add a new test here?

I updated the PR description to include the currently failing test. Regarding new tests, I don't think any are necessary: we just build the required tensors in __init__ instead of in forward, and the current set of tests should be enough :-)

(However, let me know if you have ideas for additional test cases we should add!)
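
For reference, the kind of check the common torchscript tests perform looks roughly like this (a hedged sketch with a made-up tiny config, not the actual test code):

```python
import torch
from transformers import GPTNeoXConfig, GPTNeoXModel

# Hypothetical tiny config, for illustration only.
config = GPTNeoXConfig(
    vocab_size=128,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=64,
    max_position_embeddings=64,
    torchscript=True,  # makes the model return tuples, as tracing requires
)
model = GPTNeoXModel(config).eval()
input_ids = torch.randint(0, config.vocab_size, (1, 8))

# With the rotary caches built in __init__, tracing succeeds and the
# traced outputs match eager execution.
traced = torch.jit.trace(model, (input_ids,))
torch.testing.assert_close(traced(input_ids)[0], model(input_ids)[0])
```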

@patrickvonplaten (Contributor) commented

> LGTM!
> However, could we add the failing test for reference, or do we need to add a new test here?

> I updated the PR description to include the currently failing test. Regarding new tests, I don't think any are necessary: we just build the required tensors in __init__ instead of in forward, and the current set of tests should be enough :-)
>
> (However, let me know if you have ideas for additional test cases we should add!)

Perfect, thanks!

@LysandreJik (Member) left a comment

LGTM

@LysandreJik merged commit ac98a88 into huggingface:main on Jul 11, 2022.
viclzhu pushed a commit to viclzhu/transformers that referenced this pull request on Jul 18, 2022:
* fix dtype issue in _attn

* fix RotaryEmbedding

* fix RotaryEmbedding 2

* clean up

Co-authored-by: ydshieh <[email protected]>
@ydshieh deleted the fix_torchscript_tests_for_gpt_neox branch on September 7, 2022.