Description
FlashInfer appears to be missing the latest MoE communication kernels for multi-node NVLink (MNNVL) / GB200.
TensorRT-LLM's path is mnnvl_moe_alltoallv_combine -> torch.ops.trtllm.moe_comm -> moeCommOp -> tensorrt_llm::kernels::moeAllToAll; see https://github.com/NVIDIA/TensorRT-LLM/blob/222bc911cd35405f3539c366da6c03c00e9a7fb7/cpp/tensorrt_llm/kernels/fusedMoeCommKernels.cu#L1406
FlashInfer's path is flashinfer.comm.trtllm_alltoall.mnnvl_moe_alltoallv_combine -> moe_comm -> moeCommOp -> flashinfer::trtllm_alltoall::moeAllToAll; see https://github.com/flashinfer-ai/flashinfer/blob/9721ff7ff11cd537ea5c3aba61aef0e037dddf74/include/flashinfer/comm/trtllm_alltoall.cuh#L522
The lowering paths are essentially identical; the kernel implementations, however, are different.
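For reference, a minimal sketch of how the two Python-level entry points can be inspected side by side. Only the module and op paths quoted above come from the sources; everything else (installed packages, that importing tensorrt_llm registers its torch custom ops) is an assumption:

```python
# Sketch: locate the two Python-level entry points named in this issue.
# Assumes both flashinfer and tensorrt_llm are installed in the environment.
import importlib

# FlashInfer: mnnvl_moe_alltoallv_combine -> moe_comm -> moeCommOp
#             -> flashinfer::trtllm_alltoall::moeAllToAll
fi_a2a = importlib.import_module("flashinfer.comm.trtllm_alltoall")
print(fi_a2a.mnnvl_moe_alltoallv_combine)

# TensorRT-LLM: mnnvl_moe_alltoallv_combine -> torch.ops.trtllm.moe_comm
#               -> moeCommOp -> tensorrt_llm::kernels::moeAllToAll
import tensorrt_llm  # assumed to register the trtllm custom ops with torch
import torch
print(torch.ops.trtllm.moe_comm)
```

Both chains expose the same user-facing API, which is what makes the kernel-level divergence easy to miss.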
The divergence appears to have been introduced in https://github.com/NVIDIA/TensorRT-LLM/pull/6973, so FlashInfer currently carries an older version of TensorRT-LLM's WideEP kernels for GB200 and needs to pick up the more recent, optimized implementations.