diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py
index 4fdafb49c6d9..27717bee9262 100644
--- a/src/transformers/integrations/finegrained_fp8.py
+++ b/src/transformers/integrations/finegrained_fp8.py
@@ -527,6 +527,9 @@ def __init__(self, config, block_size, dtype=torch.float8_e4m3fn):
         # Keep a handle here; actual usage happens in forward of your MoE block
         self.act_fn = ACT2FN[config.hidden_act]
 
+    # We follow the mixtral "eager" moe implementation at
+    # https://github.com/huggingface/transformers/blob/457048fbfdba9a7dee8bd03328c62f49e57b95f9/src/transformers/models/mixtral/modular_mixtral.py#L148
+    # The core changes in this FP8 version should only relate to how we call the linear projections
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -534,18 +537,17 @@ def forward(
         top_k_weights: torch.Tensor,
     ) -> torch.Tensor:
         final_hidden_states = torch.zeros_like(hidden_states)
-        num_experts = top_k_weights.shape[1]
 
         with torch.no_grad():
-            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
+            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
             expert_mask = expert_mask.permute(2, 1, 0)
             expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
         for expert_idx in expert_hit:
             expert_idx = expert_idx[0]
-            if expert_idx == num_experts:
+            if expert_idx == self.num_experts:
                 continue
-            _, token_idx = torch.where(expert_mask[expert_idx])
-            current_state = hidden_states.index_select(0, token_idx)
+            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
+            current_state = hidden_states[token_idx]
             gate, up = self.linear(
                 current_state, self.gate_up_proj[expert_idx], self.gate_up_proj_scale_inv[expert_idx]
             ).chunk(2, dim=-1)
@@ -554,7 +556,7 @@ def forward(
             current_hidden_states = self.linear(
                 current_hidden_states, self.down_proj[expert_idx], self.down_proj_scale_inv[expert_idx]
             )
-            routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1)
+            routing_weights = top_k_weights[token_idx, top_k_pos, None]
             current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype)
             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
 
diff --git a/tests/quantization/finegrained_fp8/test_fp8.py b/tests/quantization/finegrained_fp8/test_fp8.py
index 48bd079092f3..30a726df62cd 100644
--- a/tests/quantization/finegrained_fp8/test_fp8.py
+++ b/tests/quantization/finegrained_fp8/test_fp8.py
@@ -126,6 +126,14 @@ def setUpClass(cls):
             cls.model_name, device_map=cls.device_map, quantization_config=cls.quantization_config
         )
 
+    def setup(self):
+        """
+        Clear also on each setup (e.g. if a different model is used than the base cls one)
+        """
+        gc.collect()
+        backend_empty_cache(torch_device)
+        gc.collect()
+
     def tearDown(self):
         gc.collect()
         backend_empty_cache(torch_device)
@@ -368,6 +376,38 @@ def test_compute_module_sizes(self):
         # we should at least have 1.5 times memory reduction in total
         assert model_size[""] > quantized_model_size[""] * 1.5
 
+    @unittest.skip(reason="Dependent on #42028, will be removed alongside that PR")
+    def test_quantized_moe_forward(self):
+        """
+        Checks implicitly if the moe implementation is correct, i.e. it does not crash for cases
+        where the indices go over `top_k` as shown within the Minimax M2 model
+        """
+        model = AutoModelForCausalLM.from_pretrained(
+            "hf-internal-testing/MiniMax-M2-Tiny-FP8",  # single layer version
+            device_map=self.device_map,
+        )
+
+        tokenizer = AutoTokenizer.from_pretrained("MiniMaxAI/MiniMax-M2")
+        messages = [
+            {"role": "user", "content": [{"type": "text", "text": "What is your favourite condiment?"}]},
+            {
+                "role": "assistant",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!",
+                    }
+                ],
+            },
+            {"role": "user", "content": [{"type": "text", "text": "Do you have mayonnaise recipes?"}]},
+        ]
+        model_inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to(
+            self.device_map
+        )
+
+        # Only caring about this not crashing
+        _ = model.generate(**model_inputs, max_new_tokens=24)
+
 
 @require_torch_accelerator
 @unittest.skipIf(