Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 179 additions & 36 deletions cpp/tensorrt_llm/kernels/mhcKernels/mhcFusedHcKernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,11 @@ inline void fhcZeroWorkspaces(float* y_acc, uint32_t y_elems, float* r_acc, uint
} // namespace

// ---- mHC fused kernel shape constants (mirrors the Python module) ----
static constexpr uint32_t FHC_SHAPE_N = 24; // HC_MULT * (2 + HC_MULT) = 4 * 6 = 24
static constexpr uint32_t FHC_HIDDEN = 4096; // only this hidden size is currently wired up
// HC_MULT * (2 + HC_MULT) = 4 * 6 = 24.
static constexpr uint32_t FHC_SHAPE_N = 24;
static constexpr uint32_t FHC_HC_MULT = 4;
static constexpr uint32_t FHC_HIDDEN_FLASH = 4096;
static constexpr uint32_t FHC_HIDDEN_PRO = 7168;
static constexpr uint32_t FHC_BLOCK_M = 64;
static constexpr uint32_t FHC_BLOCK_N = 32;
static constexpr uint32_t FHC_BLOCK_K = 64;
Expand All @@ -103,6 +105,38 @@ static constexpr uint32_t FHC_N_INPUT_STG = 2;
static constexpr uint32_t FHC_NUM_MMA_TH = 128;
static constexpr uint32_t FHC_NUM_PMAP_TH = 128;

template <uint32_t Hidden>
static constexpr bool isSupportedFhcHidden()
{
return Hidden == FHC_HIDDEN_FLASH || Hidden == FHC_HIDDEN_PRO;
}

static bool isSupportedFhcHiddenRuntime(int hidden_size)
{
return hidden_size == static_cast<int>(FHC_HIDDEN_FLASH) || hidden_size == static_cast<int>(FHC_HIDDEN_PRO);
}

// Validate the tcgen05 MMA fused-HC compile-time shape contract. Hidden must
// be divisible into BLOCK_K tiles, kNumSplits must evenly divide those tiles,
// and the hidden dimension must align with the per-token post-map vectorized
// write granularity. Keep this in sync with the Python tactic filter.
template <uint32_t Hidden, uint32_t KS>
static constexpr bool isSupportedFhcMmaKS()
{
static_assert(isSupportedFhcHidden<Hidden>(), "Unsupported fused-HC hidden size");
static_assert(KS > 0, "kNumSplits must be positive");

constexpr uint32_t hTilesPerHc = Hidden / FHC_BLOCK_K;
constexpr uint32_t blockSizeBf = FHC_NUM_MMA_TH + FHC_NUM_PMAP_TH;
constexpr uint32_t warpSizeBf = 32;
constexpr uint32_t numWarpsBf = blockSizeBf / warpSizeBf;
constexpr uint32_t toksPerCta = (FHC_BLOCK_M + KS - 1) / KS;
constexpr uint32_t warpsPerTok = (numWarpsBf > toksPerCta) ? (numWarpsBf / toksPerCta) : 1u;
constexpr uint32_t bf16VecLi = 8;

return Hidden % FHC_BLOCK_K == 0 && hTilesPerHc % KS == 0 && Hidden % (warpsPerTok * warpSizeBf * bf16VecLi) == 0;
}

static CUtensorMap makeTma2D(void* base, CUtensorMapDataType dtype, uint64_t gmemInner, uint64_t gmemOuter,
uint32_t smemInner, uint32_t smemOuter, uint64_t gmemOuterStrideBytes, uint32_t swizzleBytes, uint32_t elemBytes)
{
Expand Down Expand Up @@ -144,25 +178,42 @@ static constexpr uint32_t fhcSmemSize()
using FusedRoutFn = void (*)(
uint32_t, CUtensorMap, CUtensorMap, CUtensorMap, CUtensorMap, float*, float const*, float const*, float*);

template <uint32_t KS>
template <uint32_t Hidden, uint32_t KS>
static FusedRoutFn fhcInstance()
{
return &fused_mhc::fused_tf32_pmap_gemm_rout_atomic_impl<FHC_SHAPE_N, FHC_HIDDEN, FHC_HC_MULT, FHC_BLOCK_M,
FHC_BLOCK_N, FHC_BLOCK_K, FHC_SWIZZLE_CD, FHC_N_B_STAGES, FHC_N_INPUT_STG, FHC_NUM_MMA_TH, FHC_NUM_PMAP_TH, KS,
static_assert(isSupportedFhcHidden<Hidden>(), "Unsupported fused-HC hidden size");
static_assert(isSupportedFhcMmaKS<Hidden, KS>(), "Unsupported fused-HC MMA kNumSplits for hidden size");
return &fused_mhc::fused_tf32_pmap_gemm_rout_atomic_impl<FHC_SHAPE_N, Hidden, FHC_HC_MULT, FHC_BLOCK_M, FHC_BLOCK_N,
FHC_BLOCK_K, FHC_SWIZZLE_CD, FHC_N_B_STAGES, FHC_N_INPUT_STG, FHC_NUM_MMA_TH, FHC_NUM_PMAP_TH, KS,
/*kEarlyRelease=*/false>;
}

template <uint32_t Hidden, uint32_t KS>
static FusedRoutFn fhcInstanceIfSupported()
{
if constexpr (isSupportedFhcMmaKS<Hidden, KS>())
{
return fhcInstance<Hidden, KS>();
}
else
{
TLLM_CHECK_WITH_INFO(false, "mhcFusedHcLaunch: unsupported kNumSplits=%u for hidden_size=%u", KS, Hidden);
return nullptr;
}
}

template <uint32_t Hidden>
static FusedRoutFn pickFhc(uint32_t ks)
{
switch (ks)
{
case 1: return fhcInstance<1>();
case 2: return fhcInstance<2>();
case 4: return fhcInstance<4>();
case 8: return fhcInstance<8>();
case 16: return fhcInstance<16>();
case 32: return fhcInstance<32>();
case 64: return fhcInstance<64>();
case 1: return fhcInstanceIfSupported<Hidden, 1>();
case 2: return fhcInstanceIfSupported<Hidden, 2>();
case 4: return fhcInstanceIfSupported<Hidden, 4>();
case 8: return fhcInstanceIfSupported<Hidden, 8>();
case 16: return fhcInstanceIfSupported<Hidden, 16>();
case 32: return fhcInstanceIfSupported<Hidden, 32>();
case 64: return fhcInstanceIfSupported<Hidden, 64>();
default: TLLM_CHECK_WITH_INFO(false, "mhcFusedHcLaunch: unsupported kNumSplits=%u", ks); return nullptr;
}
}
Expand Down Expand Up @@ -197,20 +248,24 @@ static int selectBigFuseBS(int M)
return 512;
}

void mhcFusedHcLaunch(__nv_bfloat16 const* x_prev, __nv_bfloat16 const* residual_prev, float const* post_mix_prev,
float const* comb_mix_prev, float const* w_t, float const* hc_scale, float const* hc_base,
__nv_bfloat16* residual_cur, float* post_mix_cur, float* comb_mix_cur, __nv_bfloat16* layer_input_cur,
float* y_acc_workspace, float* r_acc_workspace, int M, int hidden_size, int hc_mult, int num_k_splits,
int bigfuse_block_size, float rms_eps, float hc_pre_eps, float hc_sinkhorn_eps, float hc_post_mult_value,
int sinkhorn_repeat, cudaStream_t stream)
template <uint32_t Hidden>
static void mhcFusedHcLaunchImpl(__nv_bfloat16 const* x_prev, __nv_bfloat16 const* residual_prev,
float const* post_mix_prev, float const* comb_mix_prev, float const* w_t, float const* hc_scale,
float const* hc_base, __nv_bfloat16* residual_cur, float* post_mix_cur, float* comb_mix_cur,
__nv_bfloat16* layer_input_cur, float* y_acc_workspace, float* r_acc_workspace, int M, int hidden_size, int hc_mult,
int num_k_splits, int bigfuse_block_size, float rms_eps, float hc_pre_eps, float hc_sinkhorn_eps,
float hc_post_mult_value, int sinkhorn_repeat, cudaStream_t stream)
{
if (M <= 0)
return;

static_assert(isSupportedFhcHidden<Hidden>(), "Unsupported fused-HC hidden size");
TLLM_CHECK_WITH_INFO(hidden_size == static_cast<int>(Hidden),
"mhcFusedHcLaunch: dispatched Hidden=%u but got hidden_size=%d", Hidden, hidden_size);
TLLM_CHECK_WITH_INFO(hc_mult == static_cast<int>(FHC_HC_MULT),
"mhcFusedHcLaunch: hc_mult=%d not supported (only %u)", hc_mult, FHC_HC_MULT);

constexpr uint32_t SHAPE_K = FHC_HC_MULT * FHC_HIDDEN;
constexpr uint32_t SHAPE_K = FHC_HC_MULT * Hidden;

uint32_t const m_u = static_cast<uint32_t>(M);
uint32_t const ks = (num_k_splits > 0) ? static_cast<uint32_t>(num_k_splits) : pickKSplits(M);
Expand All @@ -225,8 +280,8 @@ void mhcFusedHcLaunch(__nv_bfloat16 const* x_prev, __nv_bfloat16 const* residual
SHAPE_K, m_u, FHC_BLOCK_K, FHC_BLOCK_M, static_cast<uint64_t>(SHAPE_K) * sizeof(__nv_bfloat16),
/*swizzleBytes=*/128, sizeof(__nv_bfloat16));

CUtensorMap desc_x = makeTma2D(const_cast<__nv_bfloat16*>(x_prev), CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, FHC_HIDDEN,
m_u, FHC_BLOCK_K, FHC_BLOCK_M, static_cast<uint64_t>(FHC_HIDDEN) * sizeof(__nv_bfloat16),
CUtensorMap desc_x = makeTma2D(const_cast<__nv_bfloat16*>(x_prev), CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, Hidden, m_u,
FHC_BLOCK_K, FHC_BLOCK_M, static_cast<uint64_t>(Hidden) * sizeof(__nv_bfloat16),
/*swizzleBytes=*/128, sizeof(__nv_bfloat16));

CUtensorMap desc_b = makeTma2D(const_cast<float*>(w_t), CU_TENSOR_MAP_DATA_TYPE_TFLOAT32, SHAPE_K, FHC_SHAPE_N,
Expand All @@ -239,7 +294,7 @@ void mhcFusedHcLaunch(__nv_bfloat16 const* x_prev, __nv_bfloat16 const* residual

// ---- Step 1: fused post-mapping + TF32 GEMM + sqrsum + residual_out ----
constexpr uint32_t fused_smem = fhcSmemSize();
FusedRoutFn fa = pickFhc(ks);
FusedRoutFn fa = pickFhc<Hidden>(ks);
TLLM_CUDA_CHECK(cudaFuncSetAttribute(
reinterpret_cast<void const*>(fa), cudaFuncAttributeMaxDynamicSharedMemorySize, fused_smem));

Expand All @@ -257,6 +312,37 @@ void mhcFusedHcLaunch(__nv_bfloat16 const* x_prev, __nv_bfloat16 const* residual
hc_post_mult_value, sinkhorn_repeat, /*num_splits=*/1, /*block_size=*/bs, stream);
}

void mhcFusedHcLaunch(__nv_bfloat16 const* x_prev, __nv_bfloat16 const* residual_prev, float const* post_mix_prev,
float const* comb_mix_prev, float const* w_t, float const* hc_scale, float const* hc_base,
__nv_bfloat16* residual_cur, float* post_mix_cur, float* comb_mix_cur, __nv_bfloat16* layer_input_cur,
float* y_acc_workspace, float* r_acc_workspace, int M, int hidden_size, int hc_mult, int num_k_splits,
int bigfuse_block_size, float rms_eps, float hc_pre_eps, float hc_sinkhorn_eps, float hc_post_mult_value,
int sinkhorn_repeat, cudaStream_t stream)
{
if (M <= 0)
return;

TLLM_CHECK_WITH_INFO(isSupportedFhcHiddenRuntime(hidden_size),
"mhcFusedHcLaunch: unsupported hidden_size=%d; supported hidden sizes are 4096 and 7168", hidden_size);

switch (hidden_size)
{
case static_cast<int>(FHC_HIDDEN_FLASH):
mhcFusedHcLaunchImpl<FHC_HIDDEN_FLASH>(x_prev, residual_prev, post_mix_prev, comb_mix_prev, w_t, hc_scale,
hc_base, residual_cur, post_mix_cur, comb_mix_cur, layer_input_cur, y_acc_workspace, r_acc_workspace, M,
hidden_size, hc_mult, num_k_splits, bigfuse_block_size, rms_eps, hc_pre_eps, hc_sinkhorn_eps,
hc_post_mult_value, sinkhorn_repeat, stream);
return;
case static_cast<int>(FHC_HIDDEN_PRO):
mhcFusedHcLaunchImpl<FHC_HIDDEN_PRO>(x_prev, residual_prev, post_mix_prev, comb_mix_prev, w_t, hc_scale,
hc_base, residual_cur, post_mix_cur, comb_mix_cur, layer_input_cur, y_acc_workspace, r_acc_workspace, M,
hidden_size, hc_mult, num_k_splits, bigfuse_block_size, rms_eps, hc_pre_eps, hc_sinkhorn_eps,
hc_post_mult_value, sinkhorn_repeat, stream);
return;
default: return;
}
}

// ===================================================================
// FMA-path fused hyper-connection boundary launcher.
//
Expand Down Expand Up @@ -317,6 +403,8 @@ void mhcFusedHcFmaLaunch(__nv_bfloat16 const* x_prev, __nv_bfloat16 const* resid

TLLM_CHECK_WITH_INFO(hc_mult == static_cast<int>(FHC_HC_MULT),
"mhcFusedHcFmaLaunch: hc_mult=%d not supported (only %u)", hc_mult, FHC_HC_MULT);
TLLM_CHECK_WITH_INFO(hidden_size % static_cast<int>(FHC_BLOCK_K) == 0,
"mhc fused-HC FMA path requires hidden_size to be divisible by %u, got %d", FHC_BLOCK_K, hidden_size);
TLLM_CHECK_WITH_INFO(
FHC_SHAPE_N % tile_n == 0, "mhcFusedHcFmaLaunch: SHAPE_N=%u not divisible by tile_n=%d", FHC_SHAPE_N, tile_n);

Expand Down Expand Up @@ -357,29 +445,48 @@ using FusedAllInOneFn = void (*)(uint32_t, CUtensorMap, CUtensorMap, CUtensorMap
__nv_bfloat16*, float*, float*, int*, float const*, float const*, float const*, float const*, float*, float*, float,
float, float, float, uint32_t);

template <uint32_t KS>
template <uint32_t Hidden, uint32_t KS>
Comment thread
pcastonguay marked this conversation as resolved.
static FusedAllInOneFn fhcAllInOneInstance()
{
return &fused_mhc::fused_allinone_tf32_pmap_gemm_atomic_impl<FHC_SHAPE_N, FHC_HIDDEN, FHC_HC_MULT, FHC_BLOCK_M,
static_assert(isSupportedFhcHidden<Hidden>(), "Unsupported fused-HC hidden size");
static_assert(isSupportedFhcMmaKS<Hidden, KS>(), "Unsupported fused-HC MMA kNumSplits for hidden size");
return &fused_mhc::fused_allinone_tf32_pmap_gemm_atomic_impl<FHC_SHAPE_N, Hidden, FHC_HC_MULT, FHC_BLOCK_M,
FHC_BLOCK_N, FHC_BLOCK_K, FHC_SWIZZLE_CD, FHC_N_B_STAGES, FHC_N_INPUT_STG, FHC_NUM_MMA_TH, FHC_NUM_PMAP_TH, KS>;
}

template <uint32_t Hidden, uint32_t KS>
static FusedAllInOneFn fhcAllInOneInstanceIfSupported()
{
if constexpr (isSupportedFhcMmaKS<Hidden, KS>())
{
return fhcAllInOneInstance<Hidden, KS>();
}
else
{
TLLM_CHECK_WITH_INFO(
false, "mhcFusedHcAllInOneLaunch: unsupported kNumSplits=%u for hidden_size=%u", KS, Hidden);
return nullptr;
}
}

template <uint32_t Hidden>
static FusedAllInOneFn pickFhcAllInOne(uint32_t ks)
{
switch (ks)
{
case 1: return fhcAllInOneInstance<1>();
case 2: return fhcAllInOneInstance<2>();
case 4: return fhcAllInOneInstance<4>();
case 8: return fhcAllInOneInstance<8>();
case 16: return fhcAllInOneInstance<16>();
case 32: return fhcAllInOneInstance<32>();
case 64: return fhcAllInOneInstance<64>();
case 1: return fhcAllInOneInstanceIfSupported<Hidden, 1>();
case 2: return fhcAllInOneInstanceIfSupported<Hidden, 2>();
case 4: return fhcAllInOneInstanceIfSupported<Hidden, 4>();
case 8: return fhcAllInOneInstanceIfSupported<Hidden, 8>();
case 16: return fhcAllInOneInstanceIfSupported<Hidden, 16>();
case 32: return fhcAllInOneInstanceIfSupported<Hidden, 32>();
case 64: return fhcAllInOneInstanceIfSupported<Hidden, 64>();
default: TLLM_CHECK_WITH_INFO(false, "mhcFusedHcAllInOneLaunch: unsupported kNumSplits=%u", ks); return nullptr;
}
}

void mhcFusedHcAllInOneLaunch(__nv_bfloat16 const* x_prev, __nv_bfloat16 const* residual_prev,
template <uint32_t Hidden>
static void mhcFusedHcAllInOneLaunchImpl(__nv_bfloat16 const* x_prev, __nv_bfloat16 const* residual_prev,
float const* post_mix_prev, float const* comb_mix_prev, float const* w_t, float const* hc_scale,
float const* hc_base, __nv_bfloat16* residual_cur, float* post_mix_cur, float* comb_mix_cur,
__nv_bfloat16* layer_input_cur, float* y_acc_workspace, float* r_acc_workspace, int* done_counter_workspace, int M,
Expand All @@ -389,10 +496,13 @@ void mhcFusedHcAllInOneLaunch(__nv_bfloat16 const* x_prev, __nv_bfloat16 const*
if (M <= 0)
return;

static_assert(isSupportedFhcHidden<Hidden>(), "Unsupported fused-HC hidden size");
TLLM_CHECK_WITH_INFO(hidden_size == static_cast<int>(Hidden),
"mhcFusedHcAllInOneLaunch: dispatched Hidden=%u but got hidden_size=%d", Hidden, hidden_size);
TLLM_CHECK_WITH_INFO(hc_mult == static_cast<int>(FHC_HC_MULT),
"mhcFusedHcAllInOneLaunch: hc_mult=%d not supported (only %u)", hc_mult, FHC_HC_MULT);

constexpr uint32_t SHAPE_K = FHC_HC_MULT * FHC_HIDDEN;
constexpr uint32_t SHAPE_K = FHC_HC_MULT * Hidden;

uint32_t const m_u = static_cast<uint32_t>(M);
uint32_t const ks = (num_k_splits > 0) ? static_cast<uint32_t>(num_k_splits) : 1u;
Expand All @@ -407,8 +517,8 @@ void mhcFusedHcAllInOneLaunch(__nv_bfloat16 const* x_prev, __nv_bfloat16 const*
SHAPE_K, m_u, FHC_BLOCK_K, FHC_BLOCK_M, static_cast<uint64_t>(SHAPE_K) * sizeof(__nv_bfloat16),
/*swizzleBytes=*/128, sizeof(__nv_bfloat16));

CUtensorMap desc_x = makeTma2D(const_cast<__nv_bfloat16*>(x_prev), CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, FHC_HIDDEN,
m_u, FHC_BLOCK_K, FHC_BLOCK_M, static_cast<uint64_t>(FHC_HIDDEN) * sizeof(__nv_bfloat16),
CUtensorMap desc_x = makeTma2D(const_cast<__nv_bfloat16*>(x_prev), CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, Hidden, m_u,
FHC_BLOCK_K, FHC_BLOCK_M, static_cast<uint64_t>(Hidden) * sizeof(__nv_bfloat16),
/*swizzleBytes=*/128, sizeof(__nv_bfloat16));

CUtensorMap desc_b = makeTma2D(const_cast<float*>(w_t), CU_TENSOR_MAP_DATA_TYPE_TFLOAT32, SHAPE_K, FHC_SHAPE_N,
Expand All @@ -421,7 +531,7 @@ void mhcFusedHcAllInOneLaunch(__nv_bfloat16 const* x_prev, __nv_bfloat16 const*

// ---- Launch the single all-in-one kernel ----
constexpr uint32_t fused_smem = fhcAllInOneSmemSize();
FusedAllInOneFn fa = pickFhcAllInOne(ks);
FusedAllInOneFn fa = pickFhcAllInOne<Hidden>(ks);
TLLM_CUDA_CHECK(cudaFuncSetAttribute(
reinterpret_cast<void const*>(fa), cudaFuncAttributeMaxDynamicSharedMemorySize, fused_smem));

Expand All @@ -433,6 +543,37 @@ void mhcFusedHcAllInOneLaunch(__nv_bfloat16 const* x_prev, __nv_bfloat16 const*
static_cast<uint32_t>(sinkhorn_repeat));
}

void mhcFusedHcAllInOneLaunch(__nv_bfloat16 const* x_prev, __nv_bfloat16 const* residual_prev,
float const* post_mix_prev, float const* comb_mix_prev, float const* w_t, float const* hc_scale,
float const* hc_base, __nv_bfloat16* residual_cur, float* post_mix_cur, float* comb_mix_cur,
__nv_bfloat16* layer_input_cur, float* y_acc_workspace, float* r_acc_workspace, int* done_counter_workspace, int M,
int hidden_size, int hc_mult, int num_k_splits, float rms_eps, float hc_pre_eps, float hc_sinkhorn_eps,
float hc_post_mult_value, int sinkhorn_repeat, cudaStream_t stream)
{
if (M <= 0)
return;

TLLM_CHECK_WITH_INFO(isSupportedFhcHiddenRuntime(hidden_size),
"mhcFusedHcAllInOneLaunch: unsupported hidden_size=%d; supported hidden sizes are 4096 and 7168", hidden_size);

switch (hidden_size)
{
case static_cast<int>(FHC_HIDDEN_FLASH):
mhcFusedHcAllInOneLaunchImpl<FHC_HIDDEN_FLASH>(x_prev, residual_prev, post_mix_prev, comb_mix_prev, w_t,
hc_scale, hc_base, residual_cur, post_mix_cur, comb_mix_cur, layer_input_cur, y_acc_workspace,
r_acc_workspace, done_counter_workspace, M, hidden_size, hc_mult, num_k_splits, rms_eps, hc_pre_eps,
hc_sinkhorn_eps, hc_post_mult_value, sinkhorn_repeat, stream);
return;
case static_cast<int>(FHC_HIDDEN_PRO):
mhcFusedHcAllInOneLaunchImpl<FHC_HIDDEN_PRO>(x_prev, residual_prev, post_mix_prev, comb_mix_prev, w_t, hc_scale,
hc_base, residual_cur, post_mix_cur, comb_mix_cur, layer_input_cur, y_acc_workspace, r_acc_workspace,
done_counter_workspace, M, hidden_size, hc_mult, num_k_splits, rms_eps, hc_pre_eps, hc_sinkhorn_eps,
hc_post_mult_value, sinkhorn_repeat, stream);
return;
default: return;
}
}

// ===================================================================
// All-in-one single-kernel fused hyper-connection launcher (FMA).
//
Expand Down Expand Up @@ -505,6 +646,8 @@ void mhcFusedHcFmaAllInOneLaunch(__nv_bfloat16 const* x_prev, __nv_bfloat16 cons

TLLM_CHECK_WITH_INFO(hc_mult == static_cast<int>(FHC_HC_MULT),
"mhcFusedHcFmaAllInOneLaunch: hc_mult=%d not supported (only %u)", hc_mult, FHC_HC_MULT);
TLLM_CHECK_WITH_INFO(hidden_size % static_cast<int>(FHC_BLOCK_K) == 0,
"mhc fused-HC FMA path requires hidden_size to be divisible by %u, got %d", FHC_BLOCK_K, hidden_size);
TLLM_CHECK_WITH_INFO(FHC_SHAPE_N % tile_n == 0,
"mhcFusedHcFmaAllInOneLaunch: SHAPE_N=%u not divisible by tile_n=%d", FHC_SHAPE_N, tile_n);

Expand Down
10 changes: 8 additions & 2 deletions cpp/tensorrt_llm/kernels/mhcKernels/mhcKernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,10 @@ void mhcPostMappingLaunch(__nv_bfloat16 const* residual, __nv_bfloat16 const* x,
// y_acc_workspace: fp32 [M, 24]
// r_acc_workspace: fp32 [M]
//
// Shape constraints (B200 / sm_100a): hidden_size == 4096, hc_mult == 4.
// Shape constraints (B200/B300 / sm_100a):
// hc_mult == 4
// hidden_size in {4096, 7168} for SM100/tcgen05 MMA fused-HC paths.
// FMA fused-HC paths use runtime hidden_size but still require hidden_size % 64 == 0.
// Passing num_k_splits == 0 or bigfuse_block_size == 0 falls back to the
// internal heuristics.
void mhcFusedHcLaunch(__nv_bfloat16 const* x_prev, __nv_bfloat16 const* residual_prev, float const* post_mix_prev,
Expand Down Expand Up @@ -101,7 +104,10 @@ void mhcFusedHcFmaLaunch(__nv_bfloat16 const* x_prev, __nv_bfloat16 const* resid
// r_acc_workspace: fp32 [M]
// done_counter_workspace: int32 [ceil(M / 64)]
//
// Shape constraints (B200 / sm_100a): hidden_size == 4096, hc_mult == 4.
// Shape constraints (B200/B300 / sm_100a):
// hc_mult == 4
// hidden_size in {4096, 7168} for SM100/tcgen05 MMA fused-HC paths.
// FMA fused-HC paths use runtime hidden_size but still require hidden_size % 64 == 0.
// Passing num_k_splits == 0 falls back to internal heuristics.
void mhcFusedHcAllInOneLaunch(__nv_bfloat16 const* x_prev, __nv_bfloat16 const* residual_prev,
float const* post_mix_prev, float const* comb_mix_prev, float const* w_t, float const* hc_scale,
Expand Down
Loading
Loading