Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions flashinfer/comm/trtllm_ar.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@
from ctypes import c_void_p, cast
from types import SimpleNamespace
from typing import List, Optional, Tuple, Union
from typing_extensions import deprecated

from flashinfer.comm.mnnvl import CommBackend, SymmDeviceMemory, TorchDistBackend
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from typing_extensions import deprecated

from flashinfer.comm.mnnvl import CommBackend, SymmDeviceMemory, TorchDistBackend

from ..jit.comm import gen_trtllm_comm_module
from ..utils import register_custom_op, round_up
Expand Down Expand Up @@ -601,7 +602,7 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion(
aligned_size,
tp_size,
tp_rank,
torch.device("cuda", tp_rank).index,
torch.cuda.current_device(),
comm_backend,
enable_multicast=False,
allocate_signal_pads=False,
Expand Down
Loading