def call_sparse_moe_op(
    self,
    shape,
    hidden_states,
    expert_routing_table,
    router_weights,
):
    """Combine the outputs of all experts using a dense routing-weight table.

    Every expert in ``self.experts`` is applied to the full ``hidden_states``
    tensor; each result is scaled by that expert's per-token routing weight
    and accumulated. Experts that were not selected for a token carry a
    weight of zero, so they contribute nothing to that token's output.

    Args:
        shape: Target shape of the result. ``shape[0]`` and ``shape[1]`` are
            used to reshape the per-token weights, and each expert output is
            viewed as ``shape`` (presumably ``(batch, seq_len, hidden_dim)``
            -- confirm against the caller).
        hidden_states: Input whose first dimension indexes tokens; its dtype
            and device are used for every tensor allocated here.
        expert_routing_table: Integer indices of the selected experts for
            each token, used as the scatter index along the expert axis.
        router_weights: Routing weights aligned with
            ``expert_routing_table``; scattered into the dense table.

    Returns:
        A tensor of shape ``shape`` holding the weighted sum of all expert
        outputs.
    """
    token_count = hidden_states.shape[0]

    # Dense (tokens, experts) table: zero everywhere except at the selected
    # experts, where the corresponding router weight is written.
    gates = hidden_states.new_zeros((token_count, self.num_experts))
    gates.scatter_(-1, expert_routing_table, router_weights)
    # Move the expert axis to the front and append a trailing singleton axis
    # so gates[i] broadcasts against an expert output viewed as `shape`.
    gates = gates.view(shape[0], shape[1], self.num_experts).permute(2, 0, 1).unsqueeze(-1)

    output = hidden_states.new_zeros(shape)

    # Support long sequences exceeding 8192 (only outside of training).
    step_per_expert = not (self.training or shape[1] <= 8192)

    # Run every expert on all tokens and accumulate its gated contribution.
    for idx in range(self.num_experts):
        expert = self.experts[idx]
        gate = gates[idx]
        output += expert(hidden_states).view(shape) * gate

        if step_per_expert:
            htcore.mark_step()

    return output