@@ -613,22 +613,8 @@ void run(Data& data, void* stream)
     TLLM_CHECK_WITH_INFO(data.mNumExpertGroups >= data.mNumLimitedGroups,
         "Routing kernel expects top groups %d to be limited by #expert groups %d", data.mNumLimitedGroups,
         data.mNumExpertGroups);
-    if (data.mNumExpertGroups > 1)
-    {
-        TLLM_CHECK_WITH_INFO(data.mNumExpertGroups <= MaxNumGroups,
-            "Routing kernel expects #experts groups %d to be <= #warps %d", data.mNumExpertGroups, MaxNumGroups);
-        TLLM_CHECK_WITH_INFO(data.mNumExperts % data.mNumExpertGroups == 0,
-            "Routing kernel expects #experts %d to be a multiple of #expert groups %d", data.mNumExperts,
-            data.mNumExpertGroups);
-        TLLM_CHECK_WITH_INFO(data.mNumExperts / data.mNumExpertGroups <= WarpSize,
-            "Routing kernel expects #experts per group <= warp size, got %d, data.mNumExpertGroups %d",
-            data.mNumExperts / data.mNumExpertGroups, data.mNumExpertGroups);
-    }
-    else
-    {
-        TLLM_CHECK_WITH_INFO(data.mTopK <= topk::MaxNumTopK, "Routing kernel expects top K %d to be <= #warps %d",
-            data.mTopK, topk::MaxNumTopK);
-    }
+    // Note: Routing-specific constraints (experts per group, topK limits) are checked later
+    // only when routing is actually needed (data.mPtrTopKIds == nullptr)
     TLLM_CHECK_WITH_INFO(
         data.mNumExperts % 4 == 0, "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts);
     int const numBlocks = data.mNumTokens;
@@ -663,6 +649,25 @@ void run(Data& data, void* stream)
     int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / data.mTopK;
     if (data.mPtrTopKIds == nullptr)
     {
+        // Routing needs to be executed - validate routing kernel constraints
+        if (data.mNumExpertGroups > 1)
+        {
+            TLLM_CHECK_WITH_INFO(data.mNumExpertGroups <= MaxNumGroups,
+                "Routing kernel expects #expert groups %d to be <= max groups %d", data.mNumExpertGroups, MaxNumGroups);
+            TLLM_CHECK_WITH_INFO(data.mNumExperts % data.mNumExpertGroups == 0,
+                "Routing kernel expects #experts %d to be a multiple of #expert groups %d", data.mNumExperts,
+                data.mNumExpertGroups);
+            TLLM_CHECK_WITH_INFO(data.mNumExperts / data.mNumExpertGroups <= WarpSize,
+                "Routing kernel expects #experts per group <= warp size (%d), got %d experts / %d groups = %d experts "
+                "per group",
+                WarpSize, data.mNumExperts, data.mNumExpertGroups, data.mNumExperts / data.mNumExpertGroups);
+        }
+        else
+        {
+            TLLM_CHECK_WITH_INFO(data.mTopK <= topk::MaxNumTopK, "Routing kernel expects top K %d to be <= max topk %d",
+                data.mTopK, topk::MaxNumTopK);
+        }
+
         int const numThreadsMain = data.mNumExperts < NumDeepseekExperts ? NumDeepseekExperts : NumKimiK2Experts;
         LAUNCH_ROUTING_DEEPSEEK(data,
             /*coopLaunch=*/false, routingMainKernel, numBlocks, numThreadsMain,
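For reference, a minimal standalone sketch of the pattern this change introduces: constraints that only matter when the routing kernel actually runs are validated behind the same condition that gates the launch. This is not the TensorRT-LLM source; the Data fields and constant values below are assumptions for illustration.

// Sketch of deferred validation: routing-only constraints are checked
// only on the code path that actually launches routing.
#include <cassert>
#include <cstdint>

struct Data
{
    int mNumExperts = 256;
    int mNumExpertGroups = 8;
    int mTopK = 8;
    int32_t const* mPtrTopKIds = nullptr; // non-null => top-K ids precomputed, routing skipped
};

constexpr int WarpSize = 32;    // assumed value
constexpr int MaxNumGroups = 8; // assumed value
constexpr int MaxNumTopK = 8;   // assumed value

void run(Data const& data)
{
    // Always-on constraint: holds whether or not routing runs.
    assert(data.mNumExperts % 4 == 0);

    if (data.mPtrTopKIds == nullptr)
    {
        // Routing executes, so its shape constraints matter now.
        if (data.mNumExpertGroups > 1)
        {
            assert(data.mNumExpertGroups <= MaxNumGroups);
            assert(data.mNumExperts % data.mNumExpertGroups == 0);
            assert(data.mNumExperts / data.mNumExpertGroups <= WarpSize);
        }
        else
        {
            assert(data.mTopK <= MaxNumTopK);
        }
        // ... launch routing kernels ...
    }
    // else: caller supplied top-K ids, so group/top-K limits never apply.
}

The practical effect is that callers who pass precomputed top-K ids are no longer rejected by constraints belonging to a code path they never execute.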
14 changes: 13 additions & 1 deletion tensorrt_llm/_torch/autotuner.py
@@ -1044,13 +1044,25 @@ def _create_tensor_like(self, origin_tensor: torch.Tensor,
         dtype = origin_tensor.dtype
         device = origin_tensor.device
         shapes = []
-        for d in dims:
+        for i, d in enumerate(dims):
             if isinstance(d, StaticDim):
+                assert d.val == origin_tensor.shape[i]
                 shapes.append(d.val)
             else:
                 # TODO: how to make sure the created Tensor has the min/max info
                 assert isinstance(d, DynamicDim)
                 shapes.append(d.opt)
+
+        if len(dims) == 2 and isinstance(dims[0], DynamicDim) and isinstance(
+                dims[1], StaticDim) and (dtype == torch.int32
+                                         or dtype == torch.int64):
+            # We should be careful about int values, since they might be index-like (e.g. topk_index).
+            # We want to keep them legal, so we just repeat the input tensor.
+            repeat_times = (shapes[0] + origin_tensor.shape[0] -
+                            1) // origin_tensor.shape[0]
+            dup_tensor = origin_tensor.repeat(repeat_times, 1)[:shapes[0]]
+            return dup_tensor
+
         # TODO: FIXME, sometimes the content of the tensor can affect the performance, like MOE
         # One solution is to manipulate the tensor content to make it more like the real data
         # during the tuning process. This can be controlled in the preparation phase by the runner.
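For reference, a minimal sketch of the duplication logic added above: an index-like int tensor is grown to the requested row count by repeating its own rows with a ceiling division, so every value remains a legal index instead of arbitrary allocated data. The helper name below is hypothetical; only torch is assumed.

# Hypothetical helper mirroring the repeat-and-slice logic above.
import torch

def repeat_to_rows(origin: torch.Tensor, target_rows: int) -> torch.Tensor:
    """Grow a 2-D tensor along dim 0 to target_rows by cycling its own rows."""
    # Ceiling division: smallest repeat count whose result covers target_rows.
    repeat_times = (target_rows + origin.shape[0] - 1) // origin.shape[0]
    return origin.repeat(repeat_times, 1)[:target_rows]

topk_index = torch.tensor([[0, 3], [1, 2]], dtype=torch.int32)
grown = repeat_to_rows(topk_index, 5)
assert grown.shape == (5, 2)
assert int(grown.max()) <= int(topk_index.max())  # no out-of-range ids introduced

A freshly allocated tensor of the opt shape could hold arbitrary int values, which would be invalid if a kernel interprets them as expert indices; reusing the real tensor's rows keeps the tuning inputs well-formed.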