[Fp8] Fix experts
#43154
Merged
Commits (7):
- `125115c` fix fp8 experts (vasqu)
- `d426e07` sync index select to base indexing (vasqu)
- `3635125` Merge branch 'main' into fix-fp8-experts (vasqu)
- `b3f4c52` add test and comment; test will be enabled with the minimax PR (vasqu)
- `3b3acdd` Merge branch 'main' into fix-fp8-experts (vasqu)
- `2619c96` style (vasqu)
- `d0d8606` Merge branch 'main' into fix-fp8-experts (vasqu)
Diff:

```diff
@@ -126,6 +126,14 @@ def setUpClass(cls):
             cls.model_name, device_map=cls.device_map, quantization_config=cls.quantization_config
         )
 
+    def setup(self):
+        """
+        Clear also on each setup (e.g. if a different model is used than the base cls one)
+        """
+        gc.collect()
+        backend_empty_cache(torch_device)
+        gc.collect()
+
     def tearDown(self):
         gc.collect()
         backend_empty_cache(torch_device)
@@ -368,6 +376,38 @@ def test_compute_module_sizes(self):
         # we should at least have 1.5 times memory reduction in total
         assert model_size[""] > quantized_model_size[""] * 1.5
 
+    @unittest.skip(reason="Dependent on #42028, will be removed alongside that PR")
+    def test_quantized_moe_forward(self):
+        """
+        Checks implicitly if the moe implementation is correct, i.e. it does not crash for cases
+        where the indices go over `top_k` as shown within the Minimax M2 model
+        """
+        model = AutoModelForCausalLM.from_pretrained(
+            "hf-internal-testing/MiniMax-M2-Tiny-FP8",  # single layer version
+            device_map=self.device_map,
+        )
+
+        tokenizer = AutoTokenizer.from_pretrained("MiniMaxAI/MiniMax-M2")
+        messages = [
+            {"role": "user", "content": [{"type": "text", "text": "What is your favourite condiment?"}]},
+            {
+                "role": "assistant",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!",
+                    }
+                ],
+            },
+            {"role": "user", "content": [{"type": "text", "text": "Do you have mayonnaise recipes?"}]},
+        ]
+        model_inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to(
+            self.device_map
+        )
+
+        # Only caring about this not crashing
+        _ = model.generate(**model_inputs, max_new_tokens=24)
+
+
     @require_torch_accelerator
     @unittest.skipIf(
```

Review comment on lines +379 to +380 (vasqu, Contributor, Author):

> This acts as a sanity integration check, but it depends on the MiniMax M2 PR (#42028), so I will remove this skip when merging that PR. I think this is the easiest way, as these weights force the issue.
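One plausible reading of "the indices go over `top_k`" is the following, sketched here in plain Python with hypothetical names (this is not the transformers implementation): in Mixtral-style dispatch, each token is routed to `top_k` experts, and each expert then gathers every token routed to it. The number of tokens a single expert receives is bounded by the number of tokens, not by `top_k`, so any per-expert buffer or index sized with `top_k` overflows for "popular" experts.

```python
# Hypothetical stand-in for Mixtral-style expert dispatch (illustrative only,
# not the transformers code).
num_tokens, top_k, num_experts = 4, 2, 8

# (num_tokens, top_k) matrix of selected expert ids, e.g. from a top-k over
# router logits. Values range over all experts, not over top_k slots.
selected = [
    [0, 5],
    [5, 7],
    [0, 7],
    [5, 0],
]

# Per-expert list of (token_index, slot_index) pairs, i.e. which tokens each
# expert must gather and which of the token's top_k slots it fills.
dispatch = {e: [] for e in range(num_experts)}
for tok, slots in enumerate(selected):
    for slot, expert in enumerate(slots):
        dispatch[expert].append((tok, slot))

# Expert 5 serves three tokens even though top_k == 2: per-expert token counts
# are bounded by num_tokens, not by top_k.
print(len(dispatch[5]))  # -> 3
```

A tiny routing table like this is enough to reproduce the shape mismatch class of bug; the single-layer FP8 MiniMax M2 weights in the test above presumably trigger exactly such an uneven routing pattern.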
Review comment (reviewer):

> Thanks! Maybe we can add a small comment saying that this was mostly copied from the deepspeed_v3 modeling, so that future changes there are also propagated here.

Author reply:

> Added a comment, referencing mixtral.
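For context, the transformers repository marks duplicated code with `# Copied from ...` markers so that a consistency script (to the best of my knowledge, `utils/check_copies.py`) can flag drift between a copy and its source. The exact comment added in this PR is not shown in the thread; the sketch below only illustrates the convention the reply alludes to, with a hypothetical target path.

```python
# Illustrative sketch of the transformers "Copied from" convention; the real
# comment text and target in the PR are not shown in this thread.
COPIED_FROM = "# Copied from transformers.models.mixtral.modeling_mixtral"


def has_copy_marker(source: str) -> bool:
    """Return True if any line in `source` carries a Copied-from marker."""
    return any(line.strip().startswith("# Copied from") for line in source.splitlines())


snippet = f"{COPIED_FROM}\ndef forward(self, hidden_states):\n    ...\n"
print(has_copy_marker(snippet))  # -> True
```

The value of the marker over a free-form note is precisely the reviewer's concern: it turns "remember to propagate changes" into something a CI check can enforce.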