Skip to content

Commit dd4b4c7

Browse files
committed
Initial commit.
Signed-off-by: Bo Li <[email protected]>
1 parent 4f84a45 commit dd4b4c7

File tree

8 files changed

+3646
-0
lines changed

8 files changed

+3646
-0
lines changed

cpp/tensorrt_llm/common/vec_dtypes.cuh

Lines changed: 1877 additions & 0 deletions
Large diffs are not rendered by default.

cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu

Lines changed: 524 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
/*
 * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include <NvInferRuntime.h>

#include <cuda_bf16.h>
#include <cuda_fp16.h>

namespace tensorrt_llm::kernels::moe_a2a
{
// Compile-time upper bounds used to size the fixed arrays in the parameter
// structs below.

/// Maximum number of experts hosted on a single rank.
static constexpr int kMaxExperts = 256;

/// Maximum number of top-k experts selected per token.
static constexpr int kMaxTopK = 8;

/// Maximum number of distinct payload types communicated per token.
static constexpr int kMaxPayloads = 8;

/// Maximum supported expert-parallel (EP) world size.
static constexpr int kMaxRanks = 64;
/// Describes a single payload type to be communicated.
/// A payload is a contiguous per-token buffer of fixed-size elements.
struct PayloadDescriptor
{
    /// Source data pointer, laid out as [local_num_tokens, elements_per_token].
    void const* src_data;

    /// Size of each element in bytes.
    int element_size;

    /// Number of elements per token (e.g., hidden_size, top_k).
    int elements_per_token;
};
39+
// Kernel pointers packed into a struct for device access
40+
// Dispatch kernel pointers - const source data
41+
struct DispatchKernelPointers
42+
{
43+
void const* src_data_ptrs[kMaxPayloads]; // Array of source data pointers
44+
void* recv_buffers[kMaxRanks][kMaxPayloads]; // 2D array of receive buffer pointers
45+
int payload_bytes_per_token[kMaxPayloads]; // Bytes per token for each payload
46+
int* completion_flags[kMaxRanks]; // Per-rank completion flags pointers
47+
};
48+
49+
// Combine kernel pointers - non-const output in src_data_ptrs[0], const recv buffers
50+
struct CombineKernelPointers
51+
{
52+
void* src_data_ptrs[kMaxPayloads]; // src_data_ptrs[0] is output
53+
void const* recv_buffers[kMaxRanks][kMaxPayloads]; // 2D array of receive buffer pointers (const)
54+
int* completion_flags[kMaxRanks]; // Per-rank completion flags pointers
55+
};
56+
57+
// Dispatch phase parameters
58+
struct MoeA2ADispatchParams
59+
{
60+
// EP configuration
61+
int ep_size; // Number of EP ranks
62+
int ep_rank; // Current EP rank
63+
64+
// Token configuration
65+
int local_num_tokens; // Number of tokens on this rank
66+
int max_tokens_per_rank; // Maximum tokens per rank for pre-allocation
67+
int top_k; // Number of experts per token
68+
69+
// Expert routing information
70+
int32_t const* token_selected_experts; // [local_num_tokens, top_k]
71+
72+
// Generic payloads
73+
int num_payloads; // Number of different payload types
74+
PayloadDescriptor payloads[kMaxPayloads]; // Array of payload descriptors
75+
76+
// Receive buffers and synchronization
77+
void* recv_buffers[kMaxRanks][kMaxPayloads]; // Per-rank receive buffers for each payload
78+
int* completion_flags[kMaxRanks]; // Per-rank completion flags pointers
79+
80+
// Communication tracking
81+
int* send_counters; // [ep_size] atomic counters - tracks tokens sent to each target rank
82+
int* send_indices; // [local_num_tokens, ep_size] send index tensor
83+
int* local_token_counter; // Atomic counter for completed tokens on this rank
84+
85+
cudaStream_t stream;
86+
};
87+
88+
// Dispatch kernels
89+
void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params);
90+
91+
// Combine phase parameters
92+
struct MoeA2ACombineParams
93+
{
94+
// EP configuration
95+
int ep_size; // Number of EP ranks
96+
int ep_rank; // Current EP rank
97+
98+
// Token configuration
99+
int local_num_tokens; // Number of tokens on this rank
100+
int max_tokens_per_rank; // Maximum tokens per rank for pre-allocation
101+
int top_k; // Number of experts per token
102+
103+
// Expert routing information
104+
int const* send_indices; // [local_num_tokens, ep_size] from dispatch
105+
106+
// Single payload information
107+
void const* recv_buffers[kMaxRanks]; // Per-rank receive buffers (only for single payload)
108+
void* output_data; // Output buffer [local_num_tokens, elements_per_token]
109+
int elements_per_token; // Number of elements per token
110+
nvinfer1::DataType dtype; // Data type for proper summation
111+
112+
// Synchronization
113+
int* local_token_counter; // Atomic counter for completed tokens
114+
int* completion_flags[kMaxRanks]; // Per-rank completion flags pointers
115+
116+
cudaStream_t stream;
117+
};
118+
119+
// Combine kernels
120+
void moe_a2a_combine_launch(MoeA2ACombineParams const& params);
121+
122+
} // namespace tensorrt_llm::kernels::moe_a2a

cpp/tensorrt_llm/thop/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ add_library(
7070
moeOp.cpp
7171
moeUtilOp.cpp
7272
moeCommOp.cpp
73+
moeAlltoAllOp.cpp
7374
moeLoadBalanceOp.cpp
7475
mxFp4BlockScaleMoe.cpp
7576
mxFp8Quantize.cpp

0 commit comments

Comments
 (0)