@@ -63,11 +63,12 @@ struct MoeLoadBalanceStatisticInfo

// rawDataWindowSize means the size of the raw data window.
// e.g. how many steps of raw data are kept in the memory.
int rawDataWindowSize = 1;
// currently we keep only the data of the current iteration; previous iterations are accumulated into expertLoadFactor.
static constexpr int rawDataWindowSize = 1;

// decayFactor means the decay factor of the raw data per step.
// e.g. if decayFactor is 0.9, then the raw data of expert i will be decayed by 0.9 for each step.
float decayFactor = 0.9f;
// e.g. if decayFactor is 0.95, then the raw data of expert i will be decayed by 0.95 for each step.
float decayFactor = 0.95f;
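// i.e. per iteration: expertLoadFactor[i] = expertLoadFactor[i] * decayFactor + expertTokenCount[i]
// (this is the update applied by updateLoadFactorKernel)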
};

// The placement information for GPU
84 changes: 76 additions & 8 deletions cpp/tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.cu
@@ -128,6 +128,19 @@ void moeSetSignalForCpuStageForTest(MoeLoadBalanceSingleLayerSignal* signal)
signal->stepAndOwner += MoeLoadBalanceSingleLayerSignal::kCPU;
}

template <typename TYPE>
__global__ void zeroExpertTokenCountKernel(MoeLoadBalanceMetaInfo metaInfo, int* const enabled, int* expertTokenCount)
{
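    // One block clears one expertCount-sized slice of expertTokenCount; TYPE
    // (int, int2, or int4) widens the store so a single thread can zero
    // multiple counters when expertCount is divisible by the vector width.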
if (*enabled == 0)
{
return;
}
TYPE zeroValue = {0};
int* expertTokenCountPtr = expertTokenCount + metaInfo.expertCount * blockIdx.x;
TYPE* typedExpertTokenCountPtr = reinterpret_cast<TYPE*>(expertTokenCountPtr);
typedExpertTokenCountPtr[threadIdx.x] = zeroValue;
}

template <typename TYPE>
__global__ void shiftWindowKernel(MoeLoadBalanceMetaInfo metaInfo, int* const enabled, int* expertTokenCount)
{
@@ -151,8 +164,8 @@ __global__ void shiftWindowKernel(MoeLoadBalanceMetaInfo metaInfo, int* const en
typedExpertTokenCountPtr[threadIdx.x] = oldExpertTokenCount;
}

__global__ void statisticKernel(MoeLoadBalanceMetaInfo metaInfo, MoeLoadBalanceStatisticInfo statisticInfo,
int totalEltCount, int* const enabled, int* const gatheredRawExpertIds)
__global__ void statisticKernel(MoeLoadBalanceMetaInfo metaInfo, int* expertTokenCount, int totalEltCount,
int* const enabled, int* const gatheredRawExpertIds)
{
extern __shared__ int sharedExpertCount[];
if (*enabled == 0)
@@ -175,19 +188,19 @@ __global__ void statisticKernel(MoeLoadBalanceMetaInfo metaInfo, MoeLoadBalanceS
__syncthreads();
for (int i = threadIdx.x; i < metaInfo.expertCount; i += blockDim.x)
{
atomicAdd_system(&statisticInfo.expertTokenCount[i], sharedExpertCount[i]);
atomicAdd_system(&expertTokenCount[i], sharedExpertCount[i]);
}
}

__global__ void updateLoadFactorKernel(
MoeLoadBalanceMetaInfo metaInfo, MoeLoadBalanceStatisticInfo statisticInfo, int* const enabled)
__global__ void updateLoadFactorKernel(MoeLoadBalanceMetaInfo metaInfo, MoeLoadBalanceStatisticInfo statisticInfo,
int* expertTokenCountPtr, int* const enabled)
{
if (*enabled == 0)
{
return;
}
int expertIdx = blockIdx.x * blockDim.x + threadIdx.x;
if (expertIdx >= metaInfo.expertCount)
{
    // guard the rounded-up launch: threads past expertCount would read and write out of bounds
    return;
}
int expertTokenCount = statisticInfo.expertTokenCount[expertIdx];
int expertTokenCount = expertTokenCountPtr[expertIdx];
float* loadFactor = statisticInfo.expertLoadFactor;
loadFactor[expertIdx] = loadFactor[expertIdx] * statisticInfo.decayFactor + expertTokenCount;
}
@@ -233,16 +246,71 @@ void moeStatisticDevice(MoeLoadBalanceMetaInfo metaInfo, MoeLoadBalanceStatistic
}
int sharedMemorySize = metaInfo.expertCount * sizeof(int);
statisticKernel<<<blockCount, threadCount, sharedMemorySize, stream>>>(
metaInfo, statisticInfo, totalEltCount, enabled, gatheredRawExpertIds);
metaInfo, statisticInfo.expertTokenCount, totalEltCount, enabled, gatheredRawExpertIds);
}

if (isLastStage)
{
// only the last stage needs to update the load factor.
int threadCount = 128;
int blockCount = (metaInfo.expertCount + threadCount - 1) / threadCount;
updateLoadFactorKernel<<<blockCount, threadCount, 0, stream>>>(metaInfo, statisticInfo, enabled);
updateLoadFactorKernel<<<blockCount, threadCount, 0, stream>>>(
metaInfo, statisticInfo, statisticInfo.expertTokenCount, enabled);
}
}

void moeHierarchicalStatisticLocalDevice(MoeLoadBalanceMetaInfo metaInfo, int numTotalTokens,
int* localExpertTokenCount, int* const enabled, bool isFirstStage, bool isLastStage, int* const localRawExpertIds,
cudaStream_t stream)
{
static int const smCount = tensorrt_llm::common::getMultiProcessorCount();
if (isFirstStage)
{
// zero expertTokenCount (the raw data window here is the current iteration only)
// only the first stage needs to do this.
int threadCount = metaInfo.expertCount;
auto* kernelFunc = zeroExpertTokenCountKernel<int>;
if (threadCount % 4 == 0)
{
threadCount /= 4;
kernelFunc = zeroExpertTokenCountKernel<int4>;
}
else if (threadCount % 2 == 0)
{
threadCount /= 2;
kernelFunc = zeroExpertTokenCountKernel<int2>;
}
dim3 gridDim(1);
dim3 blockDim(threadCount);
void* args[]
= {&metaInfo, static_cast<void*>(const_cast<int**>(&enabled)), static_cast<void*>(&localExpertTokenCount)};
TLLM_CHECK_WITH_INFO(
threadCount <= 1024, "expertCount=%d is too large and not supported now.", metaInfo.expertCount);
TLLM_CUDA_CHECK(cudaLaunchKernel(kernelFunc, gridDim, blockDim, &args[0], 0, stream));
}

{
// accumulate the per-expert token counts into localExpertTokenCount.
int threadCount = 1024;
int totalEltCount = numTotalTokens * metaInfo.topK;
int blockCount = (totalEltCount + threadCount - 1) / threadCount;
if (blockCount > smCount)
{
blockCount = smCount;
}
int sharedMemorySize = metaInfo.expertCount * sizeof(int);
statisticKernel<<<blockCount, threadCount, sharedMemorySize, stream>>>(
metaInfo, localExpertTokenCount, totalEltCount, enabled, localRawExpertIds);
}
}

void moeHierarchicalStatisticUpdate(MoeLoadBalanceMetaInfo metaInfo, MoeLoadBalanceStatisticInfo statisticInfo,
int* globalExpertTokenCount, int* const enabled, cudaStream_t stream)
{
int threadCount = 128;
int blockCount = (metaInfo.expertCount + threadCount - 1) / threadCount;
updateLoadFactorKernel<<<blockCount, threadCount, 0, stream>>>(
metaInfo, statisticInfo, globalExpertTokenCount, enabled);
}

template <int MAX_EXPERT_COUNT = 1024, int THREAD_COUNT = 256, int ITEM_PER_THREAD = 4>
26 changes: 26 additions & 0 deletions cpp/tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.h
@@ -70,6 +70,32 @@ void moeSetSignalForCpuStageForTest(MoeLoadBalanceSingleLayerSignal* signal);
void moeStatisticDevice(MoeLoadBalanceMetaInfo metaInfo, MoeLoadBalanceStatisticInfo statisticInfo, int numTotalTokens,
int* const enabled, bool isFirstStage, bool isLastStage, int* const gatheredRawExpertIds, cudaStream_t stream);

// @brief do the statistics based on the local device's data
//
// This function is used to launch a kernel that computes the statistics for local tokens.
//
// @param metaInfo: the meta info
// @param numTotalTokens: the total number of tokens in localRawExpertIds
// @param localExpertTokenCount: the per-expert token count for local tokens, should have shape [metaInfo.expertCount]
// @param enabled: flag on device memory to indicate if the statistic is enabled
// @param isFirstStage: whether the current stage is the first stage (only the first stage resets the counters)
// @param isLastStage: whether the current stage is the last stage
// @param localRawExpertIds: the local raw expert ids, should have shape [numTotalTokens, metaInfo.topK]
// @param stream: the CUDA stream used for the kernel launch
void moeHierarchicalStatisticLocalDevice(MoeLoadBalanceMetaInfo metaInfo, int numTotalTokens,
int* localExpertTokenCount, int* const enabled, bool isFirstStage, bool isLastStage, int* const localRawExpertIds,
cudaStream_t stream);

// @brief update the statistic info based on the global info
//
// This function is used to launch a kernel that updates the statistic info once per iteration.
//
// @param metaInfo: the meta info
// @param statisticInfo: the statistic info
// @param globalExpertTokenCount: the global expert token count, should have shape [metaInfo.expertCount]
// @param enabled: flag on device memory to indicate if the statistic is enabled
// @param stream: the CUDA stream used for the kernel launch
void moeHierarchicalStatisticUpdate(MoeLoadBalanceMetaInfo metaInfo, MoeLoadBalanceStatisticInfo statisticInfo,
int* globalExpertTokenCount, int* const enabled, cudaStream_t stream);
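
// Typical per-iteration usage (illustrative): each rank first calls moeHierarchicalStatisticLocalDevice to fill its
// localExpertTokenCount, the caller then all-reduces those counts across EP ranks, and finally
// moeHierarchicalStatisticUpdate folds the reduced globalExpertTokenCount into statisticInfo.expertLoadFactor.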

// @brief compute the route
//
// This function is used to launch a kernel to compute the route based on the token selected experts and the placement
79 changes: 79 additions & 0 deletions cpp/tensorrt_llm/thop/moeLoadBalanceOp.cpp
@@ -85,6 +85,60 @@ void moeLoadBalanceStatistic(torch::Tensor gatheredRawExpertIds, torch::Tensor e
static_cast<bool>(isFirstStage), static_cast<bool>(isLastStage), gatheredRawExpertIds.data_ptr<int>(), stream);
}

void moeHierarchicalStatisticLocalDevice(torch::Tensor localRawExpertIds, torch::Tensor localExpertTokenCount,
torch::Tensor enabled, int64_t singleLayerLoadBalancerPtr, int64_t isFirstStage, int64_t isLastStage)
{
CHECK_INPUT(localRawExpertIds, torch::kInt32);
CHECK_INPUT(localExpertTokenCount, torch::kInt32);
CHECK_INPUT(enabled, torch::kInt32);
TORCH_CHECK(localRawExpertIds.dim() == 2, "localRawExpertIds must be a 2D tensor");
TORCH_CHECK(localExpertTokenCount.dim() == 1, "localExpertTokenCount must be a 1D tensor");
int topK = localRawExpertIds.size(1);
TORCH_CHECK(enabled.dim() == 1, "enabled must be a 1D tensor");
TORCH_CHECK(enabled.size(0) == 1, "enabled must have 1 element");
TORCH_CHECK(isFirstStage == 0 || isFirstStage == 1, "isFirstStage must be 0 or 1");
TORCH_CHECK(isLastStage == 0 || isLastStage == 1, "isLastStage must be 0 or 1");
TORCH_CHECK(singleLayerLoadBalancerPtr != 0, "singleLayerLoadBalancerPtr must be non-null");

auto* loadBalancer
= reinterpret_cast<tensorrt_llm::runtime::SingleLayerMoeLoadBalancer*>(singleLayerLoadBalancerPtr);
auto stream = at::cuda::getCurrentCUDAStream();

tensorrt_llm::kernels::MoeLoadBalanceMetaInfo metaInfo = loadBalancer->getMetaInfo();

    TORCH_CHECK(localExpertTokenCount.size(0) == metaInfo.expertCount, "localExpertTokenCount should have shape (",
        metaInfo.expertCount, ",)");
TORCH_CHECK(topK == metaInfo.topK, "topK must be equal to metaInfo.topK");

int numTotalTokens = localRawExpertIds.size(0);

tensorrt_llm::kernels::moeHierarchicalStatisticLocalDevice(metaInfo, numTotalTokens,
localExpertTokenCount.data_ptr<int>(), enabled.data_ptr<int>(), static_cast<bool>(isFirstStage),
static_cast<bool>(isLastStage), localRawExpertIds.data_ptr<int>(), stream);
}

void moeHierarchicalStatisticUpdate(
torch::Tensor globalExpertTokenCount, torch::Tensor enabled, int64_t singleLayerLoadBalancerPtr)
{
CHECK_INPUT(globalExpertTokenCount, torch::kInt32);
CHECK_INPUT(enabled, torch::kInt32);
TORCH_CHECK(globalExpertTokenCount.dim() == 1, "globalExpertTokenCount must be a 1D tensor");
TORCH_CHECK(enabled.dim() == 1, "enabled must be a 1D tensor");
TORCH_CHECK(enabled.size(0) == 1, "enabled must have 1 element");
TORCH_CHECK(singleLayerLoadBalancerPtr != 0, "singleLayerLoadBalancerPtr must be non-null");
auto* loadBalancer
= reinterpret_cast<tensorrt_llm::runtime::SingleLayerMoeLoadBalancer*>(singleLayerLoadBalancerPtr);
auto stream = at::cuda::getCurrentCUDAStream();

tensorrt_llm::kernels::MoeLoadBalanceMetaInfo metaInfo = loadBalancer->getMetaInfo();
auto statisticInfo = loadBalancer->getStatisticInfo();

    TORCH_CHECK(globalExpertTokenCount.size(0) == metaInfo.expertCount, "globalExpertTokenCount should have shape (",
        metaInfo.expertCount, ",)");
tensorrt_llm::kernels::moeHierarchicalStatisticUpdate(
metaInfo, *statisticInfo, globalExpertTokenCount.data_ptr<int>(), enabled.data_ptr<int>(), stream);
}

torch::Tensor moeLoadBalanceRouting(
torch::Tensor tokenSelectedExperts, bool offsetByEpRank, int64_t singleLayerLoadBalancerPtr)
{
@@ -182,6 +236,31 @@ TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
m.impl("moe_load_balance_statistic", &torch_ext::moeLoadBalanceStatistic);
}

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def(
    "moe_hierarchical_statistic_local_device(Tensor local_raw_expert_ids, Tensor local_expert_token_count, "
    "Tensor enabled, int single_layer_load_balancer_ptr, int is_first_stage, int is_last_stage) -> ()");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
m.impl("moe_hierarchical_statistic_local_device", &torch_ext::moeHierarchicalStatisticLocalDevice);
}

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def(
"moe_hierarchical_statistic_update(Tensor global_expert_token_count, Tensor enabled, int "
"single_layer_load_balancer_ptr) -> ()");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
m.impl("moe_hierarchical_statistic_update", &torch_ext::moeHierarchicalStatisticUpdate);
}

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def(
27 changes: 10 additions & 17 deletions cpp/tests/kernels/moeLoadBalanceKernelTest.cpp
@@ -110,7 +110,6 @@ struct MoeLoadBalanceTestParam
bool isFirstStage;
bool isLastStage;
float decayFactor;
int rawDataWindowSize;
};

class MoeLoadBalanceStatisticKernelTest : public ::testing::TestWithParam<MoeLoadBalanceTestParam>
@@ -126,14 +125,13 @@ class MoeLoadBalanceStatisticKernelTest : public ::testing::TestWithParam<MoeLoa
mMetaInfo.epSize = param.epSize;
mMetaInfo.slotCountPerRank = param.slotCountPerRank;

mStatisticInfo.rawDataWindowSize = param.rawDataWindowSize;
mStatisticInfo.decayFactor = param.decayFactor;

ASSERT_EQ(cudaStreamCreate(&mStream), cudaSuccess);

// allocate device memory
size_t expertLoadFactorSize = param.expertCount * sizeof(float);
size_t expertTokenCountSize = param.expertCount * param.rawDataWindowSize * sizeof(int);
size_t expertTokenCountSize = param.expertCount * mStatisticInfo.rawDataWindowSize * sizeof(int);
size_t gatheredIdsSize = param.maxTokenCountPerRank * param.epSize * param.topK * sizeof(int);

ASSERT_EQ(cudaMalloc(&mDeviceEnabled, sizeof(int)), cudaSuccess);
Expand All @@ -147,9 +145,9 @@ class MoeLoadBalanceStatisticKernelTest : public ::testing::TestWithParam<MoeLoa

// allocate host memory for verification
mExpectedLoadFactor.resize(param.expertCount, 0.0f);
mExpectedExpertTokenCount.resize(param.expertCount * param.rawDataWindowSize);
mExpectedExpertTokenCount.resize(param.expertCount * mStatisticInfo.rawDataWindowSize);
mHostExpertLoadFactor.resize(param.expertCount);
mHostExpertTokenCount.resize(param.expertCount * param.rawDataWindowSize);
mHostExpertTokenCount.resize(param.expertCount * mStatisticInfo.rawDataWindowSize);
mHostGatheredIds.resize(param.maxTokenCountPerRank * param.epSize * param.topK);

// initialize the random number generator
@@ -188,7 +186,7 @@ class MoeLoadBalanceStatisticKernelTest : public ::testing::TestWithParam<MoeLoa
mExpectedExpertTokenCount = mHostExpertTokenCount;
if (param.isFirstStage)
{
for (int windowIdx = param.rawDataWindowSize - 1; windowIdx >= 0; --windowIdx)
for (int windowIdx = mStatisticInfo.rawDataWindowSize - 1; windowIdx >= 0; --windowIdx)
{
if (windowIdx > 0)
{
@@ -305,7 +303,7 @@ TEST_P(MoeLoadBalanceStatisticKernelTest, TestStatistics)
EXPECT_NEAR(mHostExpertLoadFactor[i], mExpectedLoadFactor[i], 1e-6)
<< "Expert " << i << " load factor mismatch";
}
for (int i = 0; i < param.expertCount * param.rawDataWindowSize; ++i)
for (int i = 0; i < param.expertCount * mStatisticInfo.rawDataWindowSize; ++i)
{
EXPECT_EQ(mHostExpertTokenCount[i], mExpectedExpertTokenCount[i]) << "Expert " << i << " token count mismatch";
}
@@ -323,8 +321,7 @@ INSTANTIATE_TEST_SUITE_P(MoeLoadBalanceStatisticKernelTests, MoeLoadBalanceStati
/* maxTokenCountPerRank */ 128,
/* isFirstStage */ true,
/* isLastStage */ true,
/* decayFactor */ 0.9f,
/* rawDataWindowSize */ 3},
/* decayFactor */ 0.9f},
// large scale test scenarios
MoeLoadBalanceTestParam{/* expertCount */ 64,
/* topK */ 4,
@@ -334,8 +331,7 @@ INSTANTIATE_TEST_SUITE_P(MoeLoadBalanceStatisticKernelTests, MoeLoadBalanceStati
/* maxTokenCountPerRank */ 512,
/* isFirstStage */ false,
/* isLastStage */ true,
/* decayFactor */ 0.95f,
/* rawDataWindowSize */ 5} // can add more test scenarios
/* decayFactor */ 0.95f} // can add more test scenarios
));

class MoeLoadBalanceRouteKernelTest : public ::testing::TestWithParam<MoeLoadBalanceTestParam>
@@ -601,8 +597,7 @@ INSTANTIATE_TEST_SUITE_P(MoeLoadBalanceRouteKernelTests, MoeLoadBalanceRouteKern
/* maxTokenCountPerRank */ 128,
/* isFirstStage */ true,
/* isLastStage */ true,
/* decayFactor */ 0.9f,
/* rawDataWindowSize */ 3},
/* decayFactor */ 0.9f},
// large scale test scenarios
MoeLoadBalanceTestParam{/* expertCount */ 256,
/* topK */ 8,
@@ -612,8 +607,7 @@ INSTANTIATE_TEST_SUITE_P(MoeLoadBalanceRouteKernelTests, MoeLoadBalanceRouteKern
/* maxTokenCountPerRank */ 5000,
/* isFirstStage */ false,
/* isLastStage */ true,
/* decayFactor */ 0.95f,
/* rawDataWindowSize */ 5},
/* decayFactor */ 0.95f},
// edge case: single rank
MoeLoadBalanceTestParam{/* expertCount */ 16,
/* topK */ 2,
@@ -623,5 +617,4 @@ INSTANTIATE_TEST_SUITE_P(MoeLoadBalanceRouteKernelTests, MoeLoadBalanceRouteKern
/* maxTokenCountPerRank */ 64,
/* isFirstStage */ true,
/* isLastStage */ true,
/* decayFactor */ 0.9f,
/* rawDataWindowSize */ 1}));
/* decayFactor */ 0.9f}));
19 changes: 16 additions & 3 deletions tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
@@ -263,9 +263,22 @@ def _(single_layer_load_balancer_ptr: int):
pass

@torch.library.register_fake("trtllm::moe_load_balance_statistic")
def _(single_layer_load_balancer_ptr: int,
gathered_raw_expert_ids: torch.Tensor, enabled: torch.Tensor,
is_first_stage: bool, is_last_stage: bool):
def _(gathered_raw_expert_ids: torch.Tensor, enabled: torch.Tensor,
single_layer_load_balancer_ptr: int, is_first_stage: bool,
is_last_stage: bool):
pass

@torch.library.register_fake(
"trtllm::moe_hierarchical_statistic_local_device")
def _(local_raw_expert_ids: torch.Tensor,
local_expert_token_count: torch.Tensor, enabled: torch.Tensor,
single_layer_load_balancer_ptr: int, is_first_stage: bool,
is_last_stage: bool):
pass

@torch.library.register_fake("trtllm::moe_hierarchical_statistic_update")
def _(global_expert_token_count: torch.Tensor, enabled: torch.Tensor,
single_layer_load_balancer_ptr: int):
pass
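
# A minimal usage sketch of the hierarchical statistic flow, assuming the
# extension is loaded and using illustrative names (`ep_group`,
# `balancer_ptr`); the global reduction between the two ops is the
# caller's responsibility:
#
#   torch.ops.trtllm.moe_hierarchical_statistic_local_device(
#       local_raw_expert_ids, local_expert_token_count, enabled,
#       balancer_ptr, 1, 1)  # is_first_stage=1, is_last_stage=1
#   torch.distributed.all_reduce(local_expert_token_count, group=ep_group)
#   torch.ops.trtllm.moe_hierarchical_statistic_update(
#       local_expert_token_count, enabled, balancer_ptr)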

@torch.library.register_fake("trtllm::moe_load_balance_routing")