diff --git a/flashinfer/comm/trtllm_ar.py b/flashinfer/comm/trtllm_ar.py index 4392468e80..a2a39756e6 100644 --- a/flashinfer/comm/trtllm_ar.py +++ b/flashinfer/comm/trtllm_ar.py @@ -19,12 +19,13 @@ from ctypes import c_void_p, cast from types import SimpleNamespace from typing import List, Optional, Tuple, Union -from typing_extensions import deprecated -from flashinfer.comm.mnnvl import CommBackend, SymmDeviceMemory, TorchDistBackend import torch import torch.distributed as dist from torch.distributed import ProcessGroup +from typing_extensions import deprecated + +from flashinfer.comm.mnnvl import CommBackend, SymmDeviceMemory, TorchDistBackend from ..jit.comm import gen_trtllm_comm_module from ..utils import register_custom_op, round_up @@ -601,7 +602,7 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion( aligned_size, tp_size, tp_rank, - torch.device("cuda", tp_rank).index, + torch.cuda.current_device(), comm_backend, enable_multicast=False, allocate_signal_pads=False,