
Commit 5d10e19

Need to handle 1D and 2D position_ids at the same time
Signed-off-by: Jin Li <[email protected]>
1 parent 567b357 commit 5d10e19

File tree

1 file changed (+1, -1)

tensorrt_llm/_torch/modules/attention.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -951,7 +951,7 @@ def forward_impl(self, position_ids: Optional[torch.Tensor],

         hidden_states = hidden_states[:num_tokens, ...]
         if position_ids is not None:
-            position_ids = position_ids[:, :num_tokens]
+            position_ids = position_ids[..., :num_tokens]

         if self.is_lite:
             compressed_kv, k_pe = self.kv_a_proj_with_mqa(hidden_states).split(
```
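A quick sketch of the indexing semantics this one-line fix relies on (the tensors below are illustrative, not taken from `forward_impl`): `tensor[:, :n]` assumes at least two dimensions and raises `IndexError` on a 1D tensor, while `tensor[..., :n]` always slices the last dimension, so it handles both 1D and 2D `position_ids`.

```python
import torch

num_tokens = 3

# Hypothetical position_ids of both ranks the forward pass may receive.
pos_1d = torch.arange(5)               # shape (5,)   -- unbatched, 1D
pos_2d = torch.arange(10).view(2, 5)   # shape (2, 5) -- batched, 2D

# The old slice only works for the 2D case; on a 1D tensor it indexes
# a second dimension that does not exist.
try:
    pos_1d[:, :num_tokens]
except IndexError as e:
    print("old slice fails on 1D:", e)

# The new ellipsis slice trims the last (token) dimension for any rank.
print(pos_1d[..., :num_tokens].shape)  # torch.Size([3])
print(pos_2d[..., :num_tokens].shape)  # torch.Size([2, 3])
```

The ellipsis (`...`) expands to "all leading dimensions", which is why a single expression covers both shapes without a rank check.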
