From cad84fe9f2ba582c065ef58dc587039c35ae4020 Mon Sep 17 00:00:00 2001
From: Adam Louly
Date: Mon, 25 Mar 2024 19:03:31 +0000
Subject: [PATCH 1/2] fix mixtral onnx export

---
 src/transformers/models/mixtral/modeling_mixtral.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py
index 4c4c44bd2297..2d87667e64b9 100644
--- a/src/transformers/models/mixtral/modeling_mixtral.py
+++ b/src/transformers/models/mixtral/modeling_mixtral.py
@@ -866,15 +866,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             if top_x.shape[0] == 0:
                 continue
 
-            # in torch it is faster to index using lists than torch tensors
-            top_x_list = top_x.tolist()
-            idx_list = idx.tolist()
-
             # Index the correct hidden states and compute the expert hidden state for
             # the current expert. We need to make sure to multiply the output hidden
             # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
-            current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
-            current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]
+            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
 
             # However `index_add_` only support torch tensors for indexing so we'll use
             # the `top_x` tensor here.

From d3faf1ab87b009581f63ade38ae2048d0cb6244f Mon Sep 17 00:00:00 2001
From: Adam Louly
Date: Wed, 3 Apr 2024 12:27:34 +0000
Subject: [PATCH 2/2] fix qwen model

---
 src/transformers/models/qwen2_moe/modeling_qwen2_moe.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
index e921af9232dd..cab2ef5ff7e5 100644
--- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
+++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
@@ -843,15 +843,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             if top_x.shape[0] == 0:
                 continue
 
-            # in torch it is faster to index using lists than torch tensors
-            top_x_list = top_x.tolist()
-            idx_list = idx.tolist()
-
             # Index the correct hidden states and compute the expert hidden state for
             # the current expert. We need to make sure to multiply the output hidden
             # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
-            current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
-            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
+            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
 
             # However `index_add_` only support torch tensors for indexing so we'll use
             # the `top_x` tensor here.
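
Both patches apply the same two-line change: the `.tolist()` round-trip is dropped so the per-expert gather stays entirely in tensor ops. A minimal standalone sketch follows (hypothetical shapes and values, not part of the patch) showing that the old and new indexing forms are numerically identical in eager mode; the practical difference is that `.tolist()` reads data-dependent values back into Python ints during `torch.onnx.export` tracing, so the token indices chosen by the router get baked into the exported graph as constants instead of remaining dynamic.

    # Minimal sketch (not from the patch): tensor indexing vs. list indexing.
    # Assumed shapes mirror the MoE block: hidden_states is (num_tokens,
    # hidden_dim) and top_x holds the token indices routed to one expert.
    import torch

    hidden_states = torch.randn(6, 4)
    top_x = torch.tensor([0, 2, 5])

    # New path: pure tensor indexing; the gather stays symbolic when the
    # module is traced for ONNX export.
    current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])

    # Old path: .tolist() converts the indices to Python ints, which a
    # tracer records as fixed constants rather than as graph inputs.
    old_state = hidden_states[None, top_x.tolist()].reshape(-1, hidden_states.shape[-1])

    # In eager mode the two are numerically identical.
    assert torch.equal(current_state, old_state)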