
[feature request] add optimal gb200 moe comm kernels #2094

@nvrohanv


FlashInfer appears to be missing the latest MoE communication kernels for multi-node NVLink (MNNVL) on GB200.

TensorRT-LLM's call path is mnnvl_moe_alltoallv_combine -> torch.ops.trtllm.moe_comm -> moeCommOp -> tensorrt_llm::kernels::moeAllToAll; see https://github.com/NVIDIA/TensorRT-LLM/blob/222bc911cd35405f3539c366da6c03c00e9a7fb7/cpp/tensorrt_llm/kernels/fusedMoeCommKernels.cu#L1406

FlashInfer's call path is flashinfer.comm.trtllm_alltoall.mnnvl_moe_alltoallv_combine -> moe_comm -> moeCommOp -> flashinfer::trtllm_alltoall::moeAllToAll; see https://github.com/flashinfer-ai/flashinfer/blob/9721ff7ff11cd537ea5c3aba61aef0e037dddf74/include/flashinfer/comm/trtllm_alltoall.cuh#L522

The lowering paths are essentially the same, but the kernel implementations differ.
The divergence appears to date from https://github.com/NVIDIA/TensorRT-LLM/pull/6973, so FlashInfer currently seems to carry an older version of TensorRT-LLM's WideEP kernels for GB200 and should be updated to the newer, optimized implementations.
