From 45cee58950224f99d0a64fde5c9675f37591f56b Mon Sep 17 00:00:00 2001 From: esmeetu Date: Wed, 4 Mar 2026 20:07:08 +0800 Subject: [PATCH] fix rank setting Signed-off-by: esmeetu --- flashinfer/comm/trtllm_ar.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/flashinfer/comm/trtllm_ar.py b/flashinfer/comm/trtllm_ar.py index 4392468e80..a2a39756e6 100644 --- a/flashinfer/comm/trtllm_ar.py +++ b/flashinfer/comm/trtllm_ar.py @@ -19,12 +19,13 @@ from ctypes import c_void_p, cast from types import SimpleNamespace from typing import List, Optional, Tuple, Union -from typing_extensions import deprecated -from flashinfer.comm.mnnvl import CommBackend, SymmDeviceMemory, TorchDistBackend import torch import torch.distributed as dist from torch.distributed import ProcessGroup +from typing_extensions import deprecated + +from flashinfer.comm.mnnvl import CommBackend, SymmDeviceMemory, TorchDistBackend from ..jit.comm import gen_trtllm_comm_module from ..utils import register_custom_op, round_up @@ -601,7 +602,7 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion( aligned_size, tp_size, tp_rank, - torch.device("cuda", tp_rank).index, + torch.cuda.current_device(), comm_backend, enable_multicast=False, allocate_signal_pads=False,