diff --git a/csrc/selective_state_update.cu b/csrc/selective_state_update.cu
index a5e58df924..f2dbd121c5 100644
--- a/csrc/selective_state_update.cu
+++ b/csrc/selective_state_update.cu
@@ -187,24 +187,63 @@ void selective_state_update(TensorView state, TensorView x, TensorView dt, Tenso
   auto dtype_key = std::make_tuple(state_dtype_code, input_dtype_code, weight_dtype_code, matrixA_dtype_code);
 
-  // Currently only support: input_t = weight_t = state_t = bfloat16, matrixA_t = float
-  if (dtype_key == std::make_tuple(bfloat16_code, bfloat16_code, bfloat16_code, float32_code)) {
+  if (dtype_key == std::make_tuple(/*state*/ bfloat16_code, /*input */ bfloat16_code,
+                                   /*weight */ bfloat16_code, /*matrixA */ float32_code)) {
     using state_t = nv_bfloat16;
     using input_t = nv_bfloat16;
     using weight_t = nv_bfloat16;
     using matrixA_t = float;
-
+    invokeSelectiveStateUpdate<state_t, input_t, weight_t, matrixA_t>(p, stream);
+  } else if (dtype_key == std::make_tuple(/*state*/ float16_code, /*input */ bfloat16_code,
+                                          /*weight */ bfloat16_code, /*matrixA */ float32_code)) {
+    using state_t = half;
+    using input_t = nv_bfloat16;
+    using weight_t = nv_bfloat16;
+    using matrixA_t = float;
+    invokeSelectiveStateUpdate<state_t, input_t, weight_t, matrixA_t>(p, stream);
+  } else if (dtype_key == std::make_tuple(/*state*/ float32_code, /*input */ bfloat16_code,
+                                          /*weight */ bfloat16_code, /*matrixA */ float32_code)) {
+    using state_t = float;
+    using input_t = nv_bfloat16;
+    using weight_t = nv_bfloat16;
+    using matrixA_t = float;
+    invokeSelectiveStateUpdate<state_t, input_t, weight_t, matrixA_t>(p, stream);
+  } else if (dtype_key == std::make_tuple(/*state*/ bfloat16_code, /*input */ bfloat16_code,
+                                          /*weight */ float32_code, /*matrixA */ float32_code)) {
+    using state_t = nv_bfloat16;
+    using input_t = nv_bfloat16;
+    using weight_t = float;
+    using matrixA_t = float;
+    invokeSelectiveStateUpdate<state_t, input_t, weight_t, matrixA_t>(p, stream);
+  } else if (dtype_key == std::make_tuple(/*state*/ float16_code, /*input */ bfloat16_code,
+                                          /*weight */ float32_code, /*matrixA */ float32_code)) {
+    using state_t = half;
+    using input_t = nv_bfloat16;
+    using weight_t = float;
+    using matrixA_t = float;
+    invokeSelectiveStateUpdate<state_t, input_t, weight_t, matrixA_t>(p, stream);
+  } else if (dtype_key == std::make_tuple(/*state*/ float32_code, /*input */ bfloat16_code,
+                                          /*weight */ float32_code, /*matrixA */ float32_code)) {
+    using state_t = float;
+    using input_t = nv_bfloat16;
+    using weight_t = float;
+    using matrixA_t = float;
     invokeSelectiveStateUpdate<state_t, input_t, weight_t, matrixA_t>(p, stream);
   } else {
     // Default case: unsupported dtype combination
-    TVM_FFI_ICHECK(false) << "Unsupported dtype combination for selective_state_update: "
-                          << "state_dtype=" << state_dtype.code << ":" << state_dtype.bits << ", "
-                          << "input_dtype=" << input_dtype.code << ":" << input_dtype.bits << ", "
-                          << "weight_dtype=" << weight_dtype.code << ":" << weight_dtype.bits
-                          << ", "
-                          << "matrixA_dtype=" << matrixA_dtype.code << ":" << matrixA_dtype.bits
-                          << ". Currently only support: "
-                          << "state=bfloat16, input=bfloat16, weight=bfloat16, matrixA=float32";
+    TVM_FFI_ICHECK(false)
+        << "Unsupported dtype combination for selective_state_update: "
+        << "state_dtype=" << state_dtype.code << ":" << state_dtype.bits << ", "
+        << "input_dtype=" << input_dtype.code << ":" << input_dtype.bits << ", "
+        << "weight_dtype=" << weight_dtype.code << ":" << weight_dtype.bits << ", "
+        << "matrixA_dtype=" << matrixA_dtype.code << ":" << matrixA_dtype.bits
+        << ". Supported combos include:\n"
+        << " (state=bfloat16, input=bfloat16, weight=bfloat16, matrixA=float32)\n"
+        << " (state=float16, input=bfloat16, weight=bfloat16, matrixA=float32)\n"
+        << " (state=float32, input=bfloat16, weight=bfloat16, matrixA=float32)\n"
+        << " (state=bfloat16, input=bfloat16, weight=float32, matrixA=float32)\n"
+        << " (state=float16, input=bfloat16, weight=float32, matrixA=float32)\n"
+        << " (state=float32, input=bfloat16, weight=float32, matrixA=float32)";
   }
 }
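For readers skimming the patch: the branch ladder above is a runtime-to-compile-time dtype dispatch. A tuple of dtype codes is compared against each supported combination, and every match selects one template instantiation of the launcher. A minimal, self-contained sketch of the same pattern — the enum values and invoke() stub below are made up for illustration and are not the FlashInfer API:

#include <cstdio>
#include <tuple>

// Hypothetical dtype codes standing in for the DLPack code:bits pairs used above.
enum Code { kBF16, kF16, kF32 };

template <typename state_t, typename weight_t>
void invoke() {  // stub standing in for the templated kernel launcher
  std::printf("instantiation: sizeof(state_t)=%zu sizeof(weight_t)=%zu\n",
              sizeof(state_t), sizeof(weight_t));
}

void dispatch(Code state_code, Code weight_code) {
  auto key = std::make_tuple(state_code, weight_code);
  if (key == std::make_tuple(kF32, kF32)) {
    invoke<float, float>();
  } else if (key == std::make_tuple(kF16, kF32)) {
    invoke<unsigned short, float>();  // 2-byte stand-in for half
  } else {
    std::printf("unsupported dtype combination\n");
  }
}

int main() {
  dispatch(kF32, kF32);
  dispatch(kF16, kF32);
  return 0;
}

The tuple comparison keeps the runtime check in one expression per combination, while each branch's using-declarations give the instantiation readable names.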
diff --git a/include/flashinfer/mamba/selective_state_update.cuh b/include/flashinfer/mamba/selective_state_update.cuh
index 1ea1dd3182..fd6524e0f7 100644
--- a/include/flashinfer/mamba/selective_state_update.cuh
+++ b/include/flashinfer/mamba/selective_state_update.cuh
@@ -62,23 +62,33 @@ __device__ __forceinline__ float thresholded_softplus(float dt_value) {
   return (dt_value <= threshold) ? softplus(dt_value) : dt_value;
 }
 
-template <typename T>
-__device__ inline auto make_zero() -> T;
-
-template <>
-__device__ inline auto make_zero<float2>() -> float2 {
-  return make_float2(0.f, 0.f);
-}
+// Simple packed vector type for loading N elements of type T
+template <typename T, int N>
+struct alignas(N * sizeof(T)) PackedAligned {
+  T val[N];
+  static constexpr int count = N;
+  using dtype = T;
+};
 
-template <typename load_t, typename compute_t>
-__device__ inline auto make_zeros() -> load_t {
-  load_t rValue;
+template <typename load_t>
+__device__ __forceinline__ auto make_zeros() -> load_t {
+  load_t ret{};
 #pragma unroll
-  for (int i = 0; i < sizeof(load_t) / sizeof(compute_t); i++) {
-    auto* dst = reinterpret_cast<compute_t*>(&rValue) + i;
-    convertAndStore(dst, 0.f);
-  }
-  return rValue;
+  for (int i = 0; i < ret.count; i++)
+    ret.val[i] = typename load_t::dtype{};  // default initialization
+  return ret;
+}
+
+// Computes the vector load size that ensures full warp utilization.
+// Avoids cases like: dstate=64, load_t = sizeof(float4)/sizeof(f16), warpsize=32 (32 * 8 > 64)
+// in which case a part of the warp would be idle.
+template <typename T, int DSTATE>
+inline constexpr auto getVectorLoadSizeForFullUtilization() -> unsigned {
+  static_assert(sizeof(float4) >= sizeof(T));
+  constexpr unsigned maxHardwareLoadSize = sizeof(float4) / sizeof(T);
+  constexpr unsigned warpSize = 32;
+  constexpr unsigned maxLogicalLoadSize = (unsigned)DSTATE / warpSize;
+  return maxHardwareLoadSize < maxLogicalLoadSize ? maxHardwareLoadSize : maxLogicalLoadSize;
+}
 
 __device__ __forceinline__ float warpReduceSum(float val) {
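The helper above clamps the per-lane vector width to min(widest 16-byte hardware load, DSTATE / warpSize). A self-contained restatement of that rule, with the values it produces for a few combinations — this recomputes the same formula in a toy function and is not the library header:

#include <cuda_bf16.h>
#include <cstdio>

// Toy restatement of the width-selection rule from the header above.
template <typename T, int DSTATE>
constexpr unsigned loadWidth() {
  constexpr unsigned maxHardware = 16u / sizeof(T);        // widest load: 16 B (float4)
  constexpr unsigned maxLogical = unsigned(DSTATE) / 32u;  // one warp must cover DSTATE
  return maxHardware < maxLogical ? maxHardware : maxLogical;
}

int main() {
  static_assert(loadWidth<__nv_bfloat16, 64>() == 2);   // clamped by 64 elems / 32 lanes
  static_assert(loadWidth<__nv_bfloat16, 256>() == 8);  // clamped by 16 B / 2 B
  static_assert(loadWidth<float, 256>() == 4);          // clamped by 16 B / 4 B
  std::printf("widths: %u %u %u\n", loadWidth<__nv_bfloat16, 64>(),
              loadWidth<__nv_bfloat16, 256>(), loadWidth<float, 256>());
  return 0;
}

Clamping to DSTATE / warpSize is what keeps every lane of the warp busy in the state loop: a wider load would leave trailing lanes with nothing to read.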
@@ -89,19 +99,17 @@ __device__ __forceinline__ float warpReduceSum(float val) {
   return val;
 }
 
-template <typename input_t, typename weight_t, typename state_t>
-struct VectorizedLoadTraits {};
-
-template <>
-struct VectorizedLoadTraits<__nv_bfloat16, __nv_bfloat16, __nv_bfloat16> {
-  using input = float2;
-  using weight = float2;
-  using state = float2;
-  static constexpr auto chunk_size = sizeof(input) / sizeof(__nv_bfloat16);
+template <typename input_t, int dim, int dstate>
+struct SharedStorageSimple {
+  alignas(alignof(PackedAligned<input_t, sizeof(float4) / sizeof(input_t)>)) input_t x[dim];
+  float out[dim];
+  alignas(alignof(PackedAligned<input_t, sizeof(float4) / sizeof(input_t)>)) input_t z[dim];
+  alignas(alignof(PackedAligned<input_t, sizeof(float4) / sizeof(input_t)>)) input_t B[dstate];
+  alignas(alignof(PackedAligned<input_t, sizeof(float4) / sizeof(input_t)>)) input_t C[dstate];
 };
 
-template <typename state_t, typename input_t, typename weight_t, typename matrixA_t, int DSTATE,
-          int numWarps>
+template <typename state_t, typename input_t, typename weight_t, typename matrixA_t, int DIM,
+          int DSTATE, int numWarps>
 __global__ void selective_state_update_kernel_simple(SelectiveStateUpdateParams params) {
   auto* __restrict__ output =
       reinterpret_cast<input_t*>(params.output);  // output: (batch, nheads, dim)
@@ -126,22 +134,24 @@ __global__ void selective_state_update_kernel_simple(SelectiveStateUpdateParams
   int const nheads = params.nheads;
   int const ngroups = params.ngroups;
-  int const dim = params.dim;
 
   constexpr auto warpSize = 32;
-  auto const dim_offset = blockIdx.x * warpSize * numWarps;
-  auto const batch = blockIdx.y;
-  auto const head = blockIdx.z;
+  constexpr auto rowsPerWarp = (DIM + numWarps - 1) / numWarps;
+
+  auto const batch = blockIdx.x;
+  auto const head = blockIdx.y;
   auto const group = head / (nheads / ngroups);
   auto lane = threadIdx.x % warpSize;
   auto warp = threadIdx.y;
 
   auto const state_batch = (state_batch_indices) ? state_batch_indices[batch] : batch;
-  state += state_batch * nheads * dim * DSTATE + head * dim * DSTATE;
+  state += (state_batch * nheads + head) * DIM * DSTATE;
+
+  __shared__ SharedStorageSimple<input_t, DIM, DSTATE> sram;
 
-  __shared__ input_t sx[numWarps * warpSize];
-  __shared__ float sdt[numWarps * warpSize];
-  __shared__ weight_t sz[numWarps * warpSize];
+  static constexpr auto stateLoadSize = getVectorLoadSizeForFullUtilization<state_t, DSTATE>();
+  using load_state_t = PackedAligned<state_t, stateLoadSize>;
+  using load_input_t = PackedAligned<input_t, stateLoadSize>;
 
   auto const A_value = toFloat(A[head]);
@@ -155,73 +165,81 @@ __global__ void selective_state_update_kernel_simple(SelectiveStateUpdateParams
   auto d_value = D ? toFloat(D[head]) : 0.f;
 
-  auto _d = warp * warpSize + lane;
-  auto d = dim_offset + _d;
-  if (d < dim) {
-    sx[_d] = x[batch * params.x_stride_batch + head * dim + d];
-    if (z) {
-      sz[_d] = z[batch * params.z_stride_batch + head * dim + d];
-    } else {
-      convertAndStore(&sz[_d], 0.f);
+  if (warp == 0) {
+    for (auto d = lane * load_input_t::count; d < DIM; d += warpSize * load_input_t::count) {
+      auto* dst = reinterpret_cast<load_input_t*>(&sram.x[d]);
+      *dst = *reinterpret_cast<load_input_t const*>(
+          &x[batch * params.x_stride_batch + head * DIM + d]);
+    }
+    for (auto i = lane * load_input_t::count; i < DSTATE; i += warpSize * load_input_t::count) {
+      auto* dst = reinterpret_cast<load_input_t*>(&sram.B[i]);
+      *dst = *reinterpret_cast<load_input_t const*>(
+          &B[batch * params.B_stride_batch + group * DSTATE + i]);
+    }
+  } else if (warp == 1) {  // Load z, C
+    for (auto d = lane * load_input_t::count; d < DIM; d += warpSize * load_input_t::count) {
+      auto* dst = reinterpret_cast<load_input_t*>(&sram.z[d]);
+      *dst = z ? *reinterpret_cast<load_input_t const*>(
+                     &z[batch * params.z_stride_batch + head * DIM + d])
+               : make_zeros<load_input_t>();
+    }
+    for (auto i = lane * load_input_t::count; i < DSTATE; i += warpSize * load_input_t::count) {
+      auto* dst = reinterpret_cast<load_input_t*>(&sram.C[i]);
+      *dst = *reinterpret_cast<load_input_t const*>(
+          &C[batch * params.C_stride_batch + group * DSTATE + i]);
     }
-  } else {
-    convertAndStore(&sx[_d], 0.f);
-    convertAndStore(&sz[_d], 0.f);
   }
+
   __syncthreads();
 
-  using Load = VectorizedLoadTraits<input_t, weight_t, state_t>;
-
-  for (auto _d = warp * warpSize; _d < (warp + 1) * warpSize; _d++) {
-    auto d = dim_offset + _d;
-    if (d >= dim) break;
+  for (auto _d = warp * rowsPerWarp; _d < (warp + 1) * rowsPerWarp; _d++) {
+    auto d = _d;
+    if (d >= DIM) break;
 
-    float x_value = toFloat(sx[_d]);
+    float x_value = toFloat(sram.x[_d]);
     float out_value = d_value * x_value * int(lane == 0);  // first lane has the value
 
-    for (int i = threadIdx.x * Load::chunk_size; i < DSTATE; i += warpSize * Load::chunk_size) {
-      auto rState = make_zeros<typename Load::state, state_t>();
+    for (int i = lane * load_state_t::count; i < DSTATE; i += warpSize * load_state_t::count) {
+      auto rState = make_zeros<load_state_t>();
       if (state_batch != params.pad_slot_id)
-        rState = *reinterpret_cast<typename Load::state const*>(&state[d * DSTATE + i]);
-      auto rB = *reinterpret_cast<typename Load::input const*>(
-          &B[batch * params.B_stride_batch + group * DSTATE + i]);
-      auto rC = *reinterpret_cast<typename Load::input const*>(
-          &C[batch * params.C_stride_batch + group * DSTATE + i]);
+        rState = *reinterpret_cast<load_state_t const*>(&state[d * DSTATE + i]);
 
-      auto* state_vals = reinterpret_cast<state_t*>(&rState);
-      auto const* B_vals = reinterpret_cast<input_t const*>(&rB);
-      auto const* C_vals = reinterpret_cast<input_t const*>(&rC);
-
-      for (int ii = 0; ii < Load::chunk_size; ii++) {
-        auto state_value = toFloat(state_vals[ii]);
-        auto B_value = toFloat(B_vals[ii]);
-        auto C_value = toFloat(C_vals[ii]);
+      for (int ii = 0; ii < load_state_t::count; ii++) {
+        auto state_value = toFloat(rState.val[ii]);
+        auto B_value = toFloat(sram.B[i + ii]);
+        auto C_value = toFloat(sram.C[i + ii]);
 
         auto const dB = B_value * dt_value;
         auto const new_state = state_value * dA + dB * x_value;
-        convertAndStore(&state_vals[ii], new_state);
+
+        convertAndStore(&rState.val[ii], new_state);
         out_value += new_state * C_value;
       }
 
       if (state_batch != params.pad_slot_id)
-        *reinterpret_cast<typename Load::state*>(&state[d * DSTATE + i]) = rState;
+        *reinterpret_cast<load_state_t*>(&state[d * DSTATE + i]) = rState;
     }
 
     // warpReduce the out_value
     out_value = warpReduceSum(out_value);
     if (lane == 0) {
-      sdt[_d] = out_value;
+      sram.out[_d] = out_value;
     }
   }
 
-  if (d < dim) {
-    auto out_value = sdt[_d];
-    if (z) {
-      float z_value = toFloat(sz[_d]);
-      float sig_z = __fdividef(1.f, (1.f + __expf(0.f - z_value)));
-      float silu_z = z_value * sig_z;
-      out_value *= silu_z;
+  __syncthreads();
+
+  for (int l = lane; l < rowsPerWarp; l += warpSize) {
+    auto d = warp * rowsPerWarp + l;
+    if (d < DIM) {
+      auto out_value = sram.out[d];
+      if (z) {
+        float z_value = toFloat(sram.z[d]);
+        float sig_z = __fdividef(1.f, (1.f + __expf(0.f - z_value)));
+        float silu_z = z_value * sig_z;
+        out_value *= silu_z;
+      }
+      convertAndStore(&output[batch * params.out_stride_batch + head * DIM + d], out_value);
     }
-    convertAndStore(&output[batch * params.out_stride_batch + head * dim + d], out_value);
   }
 }
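The kernel above assigns each warp a contiguous block of rows; within a row, the 32 lanes stride over DSTATE, each accumulating a partial sum, and warpReduceSum folds the partials so lane 0 can publish the row's output. A standalone toy kernel isolating just that reduce pattern — the butterfly-shuffle body below is one common implementation of the helper, whose own body is elided in the hunk above:

#include <cstdio>

__device__ __forceinline__ float warpReduceSum(float val) {
  // Butterfly reduction: after the loop every lane holds the full sum.
  for (int offset = 16; offset > 0; offset /= 2)
    val += __shfl_xor_sync(0xffffffff, val, offset);
  return val;
}

__global__ void rowDot(const float* a, const float* b, float* out, int n) {
  float partial = 0.f;
  for (int i = threadIdx.x; i < n; i += 32)  // each lane strides over the row
    partial += a[i] * b[i];
  partial = warpReduceSum(partial);
  if (threadIdx.x == 0) *out = partial;  // lane 0 publishes, like sram.out[d]
}

int main() {
  float *a, *b, *out;
  cudaMallocManaged(&a, 128 * sizeof(float));
  cudaMallocManaged(&b, 128 * sizeof(float));
  cudaMallocManaged(&out, sizeof(float));
  for (int i = 0; i < 128; i++) { a[i] = 1.f; b[i] = 0.5f; }
  rowDot<<<1, 32>>>(a, b, out, 128);
  cudaDeviceSynchronize();
  printf("dot = %f (expect 64.0)\n", *out);
  return 0;
}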
@@ -229,11 +247,11 @@ template <typename state_t, typename input_t, int numStages, int rowsPerStage, int dim, int dstate>
 struct SharedStorage {
   alignas(128) state_t state[numStages][rowsPerStage * dstate];
-  input_t x[dim];
+  alignas(alignof(PackedAligned<input_t, sizeof(float4) / sizeof(input_t)>)) input_t x[dim];
   float out[dim];  // dt is special cause we're gonna store input in there as well
-  input_t z[dim];
-  input_t B[dstate];
-  input_t C[dstate];
+  alignas(alignof(PackedAligned<input_t, sizeof(float4) / sizeof(input_t)>)) input_t z[dim];
+  alignas(alignof(PackedAligned<input_t, sizeof(float4) / sizeof(input_t)>)) input_t B[dstate];
+  alignas(alignof(PackedAligned<input_t, sizeof(float4) / sizeof(input_t)>)) input_t C[dstate];
 
   using barrier_t = cuda::barrier<cuda::thread_scope_block>;
   barrier_t bar_empty[numStages];
@@ -241,6 +259,94 @@ struct SharedStorage {
   barrier_t bar_consumers;
 };
 
+template <bool useStateCache, int DIM, int DSTATE, int numStages, int rowsPerStage,
+          int consumerWarps, int warpSize, typename state_t, typename input_t>
+__device__ __forceinline__ void consumer_func_vertical(
+    int lane, int warp, float d_value, float dt_value, float dA,
+    SharedStorage<state_t, input_t, numStages, rowsPerStage, DIM, DSTATE>&
+        sram) {
+  namespace cde = cuda::device::experimental;
+  for (auto dBegin = 0, stage = 0; dBegin < DIM;
+       dBegin += rowsPerStage, stage = (stage + 1) % numStages) {
+    // wait for the producer
+    sram.bar_full[stage].wait(sram.bar_full[stage].arrive());
+
+#pragma unroll
+    for (auto dd = warp; dd < rowsPerStage; dd += consumerWarps) {
+      auto d = dBegin + dd;
+      float const x_value = toFloat(sram.x[d]);
+      float out_value = d_value * x_value * int(lane == 0);  // first lane has the value
+
+      constexpr auto bankSize = sizeof(uint32_t);
+      constexpr auto stateValuesPerBank = bankSize / sizeof(state_t);
+
+      if constexpr (sizeof(state_t) == sizeof(input_t)) {
+        for (int i = lane * stateValuesPerBank; i < DSTATE; i += warpSize * stateValuesPerBank) {
+          auto* sState_ptr = reinterpret_cast<uint32_t*>(&sram.state[stage][dd * DSTATE + i]);
+          uint32_t rState = *sState_ptr;
+          auto* rState_ptr = reinterpret_cast<state_t*>(&rState);
+
+          uint32_t rB = *reinterpret_cast<uint32_t const*>(&sram.B[i]);
+          auto* rB_ptr = reinterpret_cast<input_t const*>(&rB);
+
+          uint32_t rC = *reinterpret_cast<uint32_t const*>(&sram.C[i]);
+          auto* rC_ptr = reinterpret_cast<input_t const*>(&rC);
+
+          for (int e = 0; e < stateValuesPerBank; e++) {
+            float state_value;
+            if constexpr (!useStateCache) {
+              state_value = 0.f;
+            } else {
+              state_value = toFloat(rState_ptr[e]);
+            }
+            auto const B_value = toFloat(rB_ptr[e]);
+            auto const C_value = toFloat(rC_ptr[e]);
+
+            auto const dB = B_value * dt_value;
+            auto const new_state = state_value * dA + dB * x_value;
+
+            convertAndStore(&rState_ptr[e], new_state);
+            out_value += new_state * C_value;
+          }
+          *sState_ptr = rState;
+        }
+      } else {
+        for (int i = lane * stateValuesPerBank; i < DSTATE; i += warpSize * stateValuesPerBank) {
+          auto* sState_ptr = reinterpret_cast<uint32_t*>(&sram.state[stage][dd * DSTATE + i]);
+          uint32_t rState = *sState_ptr;
+          auto* rState_ptr = reinterpret_cast<state_t*>(&rState);
+
+          for (int e = 0; e < stateValuesPerBank; e++) {
+            float state_value;
+            if constexpr (!useStateCache) {
+              state_value = 0.f;
+            } else {
+              state_value = toFloat(rState_ptr[e]);
+            }
+            auto const B_value = toFloat(sram.B[i + e]);
+            auto const C_value = toFloat(sram.C[i + e]);
+            auto const dB = B_value * dt_value;
+            auto const new_state = state_value * dA + dB * x_value;
+
+            convertAndStore(&rState_ptr[e], new_state);
+            out_value += new_state * C_value;
+          }
+          *sState_ptr = rState;
+        }
+      }
+
+      out_value = warpReduceSum(out_value);
+      if (lane == 0) {
+        sram.out[d] = out_value;
+      }
+    }
+
+    // Unblock producer
+    cde::fence_proxy_async_shared_cta();
+    auto _ = sram.bar_empty[stage].arrive();
+  }
+}
+
 template <typename state_t, typename input_t, typename weight_t, typename matrixA_t, int DIM,
           int DSTATE, int consumerWarps, int numStages, int rowsPerStage>
 __global__ void selective_state_update_kernel_producer_consumer_vertical(
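The consumer routine above deliberately moves shared-memory state in sizeof(uint32_t) packets — one shared-memory bank per lane — and unpacks the sub-bank elements in registers. A toy kernel isolating that trick for two bf16 values per 32-bit word (illustrative only, not the library code):

#include <cuda_bf16.h>
#include <cstdint>
#include <cstdio>

// Two 16-bit state elements travel through shared memory as one 32-bit word,
// then are unpacked in registers for the arithmetic, mirroring the bank-width
// access pattern in consumer_func_vertical.
__global__ void scaleStates(__nv_bfloat16* state, float scale, int n) {
  __shared__ alignas(4) __nv_bfloat16 sState[256];
  constexpr int perBank = sizeof(uint32_t) / sizeof(__nv_bfloat16);  // == 2
  for (int i = threadIdx.x * perBank; i < n; i += blockDim.x * perBank) {
    uint32_t word = *reinterpret_cast<uint32_t*>(&state[i]);  // one 32-bit load
    auto* vals = reinterpret_cast<__nv_bfloat16*>(&word);     // register view
    for (int e = 0; e < perBank; e++)
      vals[e] = __float2bfloat16(__bfloat162float(vals[e]) * scale);
    *reinterpret_cast<uint32_t*>(&sState[i]) = word;          // one 32-bit store
  }
  __syncthreads();
  for (int i = threadIdx.x; i < n; i += blockDim.x) state[i] = sState[i];
}

int main() {
  __nv_bfloat16* s;
  cudaMallocManaged(&s, 256 * sizeof(__nv_bfloat16));
  for (int i = 0; i < 256; i++) s[i] = __float2bfloat16(1.f);
  scaleStates<<<1, 128>>>(s, 2.f, 256);
  cudaDeviceSynchronize();
  printf("s[0] = %f (expect 2.0)\n", __bfloat162float(s[0]));
  return 0;
}

Keeping each lane's access exactly one bank wide avoids both bank conflicts and the sub-word read-modify-write traffic that per-element 16-bit stores would cause.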
@@ -261,7 +367,6 @@ __global__ void selective_state_update_kernel_producer_consumer_vertical(
 
   int const nheads = params.nheads;
   int const ngroups = params.ngroups;
-  int const dim = params.dim;
 
   constexpr auto warpSize = 32;
   constexpr auto numWarps = 1 + consumerWarps;
@@ -276,9 +381,9 @@ __global__ void selective_state_update_kernel_producer_consumer_vertical(
 
   using sram_t = SharedStorage<state_t, input_t, numStages, rowsPerStage, DIM, DSTATE>;
 
-#pragma nv_diag_suppress 20054
-  __shared__ sram_t sram;
-#pragma nv_diag_default 20054
+  // Use dynamic shared memory to allow opting into extended shared memory on SM90+
+  extern __shared__ __align__(128) char smem[];
+  sram_t& sram = *reinterpret_cast<sram_t*>(smem);
 
   namespace cde = cuda::device::experimental;
   namespace cg = cooperative_groups;
@@ -296,10 +401,11 @@ __global__ void selective_state_update_kernel_producer_consumer_vertical(
   }
   __syncthreads();
 
-  if (warp == consumerWarps) {
-    auto const state_offset = (state_batch * nheads + head) * dim;
+  if (warp == consumerWarps)  // producer
+  {
+    auto const state_offset = (state_batch * nheads + head) * DIM;
 
-    for (int d = 0, stage = 0; d < dim + rowsPerStage * numStages;
+    for (int d = 0, stage = 0; d < DIM + rowsPerStage * numStages;
          d += rowsPerStage, stage = (stage + 1) % numStages) {
       if (lane == 0) {
         cg::invoke_one(cg::coalesced_threads(), [&]() {
@@ -316,7 +422,7 @@ __global__ void selective_state_update_kernel_producer_consumer_vertical(
         cde::cp_async_bulk_wait_group_read<0>();
       }
 
-      if (d < dim) {
+      if (d < DIM) {
         cde::cp_async_bulk_tensor_2d_global_to_shared(&sram.state[stage][0], &tensorState,
                                                       /*x*/ 0, /*y*/ state_offset + d,
                                                       sram.bar_full[stage]);
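The producer half above and the consumer half below coordinate through the bar_full/bar_empty barrier pairs. A minimal single-stage sketch of that arrive/wait handshake using cuda::barrier — simplified to one producer warp and one buffer, whereas the real kernel pipelines numStages buffers and also fences the TMA proxy:

#include <cuda/barrier>
#include <cstdio>

// One producer warp fills a shared buffer; all threads rendezvous on the
// barrier before the consumers read it. Compile with -arch=sm_70 or newer.
__global__ void pipeline(const float* in, float* out, int n) {
  __shared__ float buf[256];
  __shared__ cuda::barrier<cuda::thread_scope_block> bar_full;
  if (threadIdx.x == 0) init(&bar_full, blockDim.x);  // ADL friend from <cuda/barrier>
  __syncthreads();

  if (threadIdx.x < 32) {  // "producer" warp stages data into shared memory
    for (int i = threadIdx.x; i < n; i += 32) buf[i] = in[i];
  }
  bar_full.arrive_and_wait();  // consumers block until the stage is full

  for (int i = threadIdx.x; i < n; i += blockDim.x) out[i] = buf[i] * 2.f;
}

int main() {
  float *in, *out;
  cudaMallocManaged(&in, 256 * sizeof(float));
  cudaMallocManaged(&out, 256 * sizeof(float));
  for (int i = 0; i < 256; i++) in[i] = float(i);
  pipeline<<<1, 128>>>(in, out, 256);
  cudaDeviceSynchronize();
  printf("out[10] = %f (expect 20.0)\n", out[10]);
  return 0;
}

Splitting arrive from wait, as the kernel does with bar_empty, lets a warp signal "my stage is free" and keep working instead of stalling at a full __syncthreads().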
@@ -335,8 +441,7 @@ __global__ void selective_state_update_kernel_producer_consumer_vertical(
     }
   } else {  // consumers
-    using load_t = float2;
-    static constexpr auto vectorizedLoadSize = sizeof(load_t) / sizeof(weight_t);
+    using load_t = PackedAligned<input_t, sizeof(float4) / sizeof(input_t)>;
 
 #pragma unroll
     // Unblock the producer
@@ -359,23 +464,23 @@ __global__ void selective_state_update_kernel_producer_consumer_vertical(
     auto const dA = __expf(A_value * dt_value);
 
     if (warp == 0) {  // Load x, B
-      for (auto d = lane * vectorizedLoadSize; d < dim; d += warpSize * vectorizedLoadSize) {
+      for (auto d = lane * load_t::count; d < DIM; d += warpSize * load_t::count) {
         auto* dst = reinterpret_cast<load_t*>(&sram.x[d]);
-        *dst = *reinterpret_cast<load_t const*>(&x[batch * params.x_stride_batch + head * dim + d]);
+        *dst = *reinterpret_cast<load_t const*>(&x[batch * params.x_stride_batch + head * DIM + d]);
       }
-      for (auto i = lane * vectorizedLoadSize; i < DSTATE; i += warpSize * vectorizedLoadSize) {
+      for (auto i = lane * load_t::count; i < DSTATE; i += warpSize * load_t::count) {
         auto* dst = reinterpret_cast<load_t*>(&sram.B[i]);
         *dst = *reinterpret_cast<load_t const*>(
             &B[batch * params.B_stride_batch + group * DSTATE + i]);
       }
     } else if (warp == 1) {  // Load z, C
-      for (auto d = lane * vectorizedLoadSize; d < dim; d += warpSize * vectorizedLoadSize) {
+      for (auto d = lane * load_t::count; d < DIM; d += warpSize * load_t::count) {
         auto* dst = reinterpret_cast<load_t*>(&sram.z[d]);
-        *dst =
-            z ? *reinterpret_cast<load_t const*>(&z[batch * params.z_stride_batch + head * dim + d])
-              : make_zero<load_t>();
+        *dst = z ? *reinterpret_cast<load_t const*>(
+                       &z[batch * params.z_stride_batch + head * DIM + d])
+                 : make_zeros<load_t>();
       }
-      for (auto i = lane * vectorizedLoadSize; i < DSTATE; i += warpSize * vectorizedLoadSize) {
+      for (auto i = lane * load_t::count; i < DSTATE; i += warpSize * load_t::count) {
         auto* dst = reinterpret_cast<load_t*>(&sram.C[i]);
         *dst = *reinterpret_cast<load_t const*>(
             &C[batch * params.C_stride_batch + group * DSTATE + i]);
       }
@@ -384,46 +489,19 @@ __global__ void selective_state_update_kernel_producer_consumer_vertical(
 
     sram.bar_consumers.wait(sram.bar_consumers.arrive());
 
-    for (auto dBegin = 0, stage = 0; dBegin < dim;
-         dBegin += rowsPerStage, stage = (stage + 1) % numStages) {
-      // wait for the producer
-      sram.bar_full[stage].wait(sram.bar_full[stage].arrive());
-
-#pragma unroll
-      for (auto dd = warp; dd < rowsPerStage; dd += consumerWarps) {
-        auto d = dBegin + dd;
-        float const x_value = toFloat(sram.x[d]);
-        float out_value = d_value * x_value * int(lane == 0);  // first lane has the value
-
-        for (int i = lane; i < DSTATE; i += warpSize) {
-          auto const state_value = (state_batch != params.pad_slot_id)
-                                       ? toFloat(sram.state[stage][dd * DSTATE + i])
-                                       : 0.f;
-          auto const B_value = toFloat(sram.B[i]);
-          auto const C_value = toFloat(sram.C[i]);
-
-          auto const dB = B_value * dt_value;
-          auto const new_state = state_value * dA + dB * x_value;
-
-          convertAndStore(&sram.state[stage][dd * DSTATE + i], new_state);
-          out_value += new_state * C_value;
-        }
-
-        out_value = warpReduceSum(out_value);
-        if (lane == 0) {
-          sram.out[d] = out_value;
-        }
-      }
-
-      // Unblock producer
-      cde::fence_proxy_async_shared_cta();
-      auto _ = sram.bar_empty[stage].arrive();
-    }
+    if (state_batch != params.pad_slot_id)
+      consumer_func_vertical<true, DIM, DSTATE, numStages, rowsPerStage, consumerWarps, warpSize>(
+          lane, warp, d_value, dt_value, dA, sram);
+    else
+      consumer_func_vertical<false, DIM, DSTATE, numStages, rowsPerStage, consumerWarps, warpSize>(
+          lane, warp, d_value, dt_value, dA, sram);
 
     // Write output
     sram.bar_consumers.wait(sram.bar_consumers.arrive());
 
     auto d = warp * warpSize + lane;
-    if (d < dim) {
+    if (d < DIM) {
       auto out_value = sram.out[d];
       if (z) {
         float z_value = toFloat(sram.z[d]);
@@ -431,7 +509,7 @@ __global__ void selective_state_update_kernel_producer_consumer_vertical(
         float silu_z = z_value * sig_z;
         out_value *= silu_z;
       }
-      convertAndStore(&output[batch * params.out_stride_batch + head * dim + d], out_value);
+      convertAndStore(&output[batch * params.out_stride_batch + head * DIM + d], out_value);
     }
   }
 #endif
@@ -441,55 +519,103 @@ template <typename state_t, typename input_t, typename weight_t, typename matrixA_t>
-  using Load = VectorizedLoadTraits<input_t, weight_t, state_t>;
-  constexpr size_t vec_size = sizeof(typename Load::input);
-  constexpr size_t state_vec_size = sizeof(typename Load::state);
-  if (Load::chunk_size > 1) {
-    FLASHINFER_CHECK(reinterpret_cast<uintptr_t>(params.state) % state_vec_size == 0,
-                     "state pointer must be aligned to ", state_vec_size, " bytes");
-    FLASHINFER_CHECK(reinterpret_cast<uintptr_t>(params.B) % vec_size == 0,
-                     "B pointer must be aligned to ", vec_size, " bytes");
-    FLASHINFER_CHECK(reinterpret_cast<uintptr_t>(params.C) % vec_size == 0,
-                     "C pointer must be aligned to ", vec_size, " bytes");
-    FLASHINFER_CHECK((params.B_stride_batch * sizeof(input_t)) % vec_size == 0,
-                     "B batch stride must be aligned to ", vec_size, " bytes");
-    FLASHINFER_CHECK((params.C_stride_batch * sizeof(input_t)) % vec_size == 0,
-                     "C batch stride must be aligned to ", vec_size, " bytes");
-    FLASHINFER_CHECK((params.dim * params.dstate * sizeof(state_t)) % state_vec_size == 0,
-                     "state head stride must be aligned to ", state_vec_size, " bytes");
-  }
+  auto dispatch_dim_dstate = [&]<int DIM, int DSTATE>() {
+    // Alignment checks for vectorized loads in simple kernel
+    constexpr auto stateLoadSize = getVectorLoadSizeForFullUtilization<state_t, DSTATE>();
+    using load_state_t = PackedAligned<state_t, stateLoadSize>;
+    using load_input_t = PackedAligned<input_t, stateLoadSize>;
+
+    FLASHINFER_CHECK(reinterpret_cast<uintptr_t>(params.state) % sizeof(load_state_t) == 0,
+                     "state pointer must be aligned to ", sizeof(load_state_t), " bytes");
+    FLASHINFER_CHECK(reinterpret_cast<uintptr_t>(params.x) % sizeof(load_input_t) == 0,
+                     "x pointer must be aligned to ", sizeof(load_input_t), " bytes");
+    FLASHINFER_CHECK((params.x_stride_batch * sizeof(input_t)) % sizeof(load_input_t) == 0,
+                     "x batch stride must be aligned to ", sizeof(load_input_t), " bytes");
+    if (params.z) {
+      FLASHINFER_CHECK(reinterpret_cast<uintptr_t>(params.z) % sizeof(load_input_t) == 0,
+                       "z pointer must be aligned to ", sizeof(load_input_t), " bytes");
+      FLASHINFER_CHECK((params.z_stride_batch * sizeof(input_t)) % sizeof(load_input_t) == 0,
+                       "z batch stride must be aligned to ", sizeof(load_input_t), " bytes");
+    }
+    FLASHINFER_CHECK(reinterpret_cast<uintptr_t>(params.B) % sizeof(load_input_t) == 0,
+                     "B pointer must be aligned to ", sizeof(load_input_t), " bytes");
+    FLASHINFER_CHECK(reinterpret_cast<uintptr_t>(params.C) % sizeof(load_input_t) == 0,
+                     "C pointer must be aligned to ", sizeof(load_input_t), " bytes");
+    FLASHINFER_CHECK((params.B_stride_batch * sizeof(input_t)) % sizeof(load_input_t) == 0,
+                     "B batch stride must be aligned to ", sizeof(load_input_t), " bytes");
+    FLASHINFER_CHECK((params.C_stride_batch * sizeof(input_t)) % sizeof(load_input_t) == 0,
+                     "C batch stride must be aligned to ", sizeof(load_input_t), " bytes");
+    FLASHINFER_CHECK((params.dim * params.dstate * sizeof(state_t)) % sizeof(load_state_t) == 0,
+                     "state head stride must be aligned to ", sizeof(load_state_t), " bytes");
+
+    constexpr int numWarps = 4;
+    dim3 block(warpSize, numWarps);
+    dim3 grid(params.batch, params.nheads);
+    selective_state_update_kernel_simple<state_t, input_t, weight_t, matrixA_t, DIM, DSTATE,
+                                         numWarps><<<grid, block, 0, stream>>>(params);
+  };
 
-  auto dispatch_dstate = [&]<int DSTATE>() {
-    constexpr int numWarps = 2;
-    int const blocks_per_dim = (params.dim + 32 * numWarps - 1) / (32 * numWarps);
-    dim3 block(32, numWarps);
-    dim3 grid(blocks_per_dim, params.batch, params.nheads);
-    selective_state_update_kernel_simple<state_t, input_t, weight_t, matrixA_t, DSTATE, numWarps>
-        <<<grid, block, 0, stream>>>(params);
+  auto dispatch_dstate = [&]<int DIM>() {
+    switch (params.dstate) {
+      case 64:
+        dispatch_dim_dstate.template operator()<DIM, 64>();
+        break;
+      case 128:
+        dispatch_dim_dstate.template operator()<DIM, 128>();
+        break;
+      case 256:
+        dispatch_dim_dstate.template operator()<DIM, 256>();
+        break;
+      default:
+        FLASHINFER_CHECK(false, "Unsupported dstate value. Supported values are: 64, 128, 256");
+    }
   };
 
-  switch (params.dstate) {
+  switch (params.dim) {
     case 64:
       dispatch_dstate.template operator()<64>();
       break;
     case 128:
       dispatch_dstate.template operator()<128>();
       break;
-    case 256:
-      dispatch_dstate.template operator()<256>();
-      break;
     default:
-      FLASHINFER_CHECK(false, "Unsupported dstate value. Supported values are: 64, 128, 256");
+      FLASHINFER_CHECK(false, "Unsupported dim value. Supported values are: 64, 128");
   }
 }
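The dispatcher lowers runtime (dim, dstate) extents to template parameters through nested C++20 templated lambdas, which is why the calls go through .template operator()<...>(). A compact standalone illustration of the idiom (generic names, not the library's dispatcher; compile with -std=c++20):

#include <cstdio>

// Nested switch-to-template lowering: runtime (dim, dstate) pairs become
// the compile-time template arguments <DIM, DSTATE>.
void dispatch(int dim, int dstate) {
  auto launch = [&]<int DIM, int DSTATE>() {
    // A real dispatcher would instantiate and launch kernel<DIM, DSTATE> here.
    printf("instantiated DIM=%d DSTATE=%d\n", DIM, DSTATE);
  };
  auto pick_dstate = [&]<int DIM>() {
    switch (dstate) {
      case 64:  launch.template operator()<DIM, 64>();  break;
      case 128: launch.template operator()<DIM, 128>(); break;
      default:  printf("unsupported dstate %d\n", dstate);
    }
  };
  switch (dim) {
    case 64:  pick_dstate.template operator()<64>();  break;
    case 128: pick_dstate.template operator()<128>(); break;
    default:  printf("unsupported dim %d\n", dim);
  }
}

int main() { dispatch(64, 128); }

Each supported pair compiles to its own kernel instantiation, so the cost of the pattern is binary size: the case list bounds how many variants get built.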
 #ifdef FLASHINFER_MAMBA_ENABLE_SM90
   else {
     auto dispatch_dim_dstate = [&]<int DIM, int DSTATE>() {
+      // Alignment checks for vectorized loads in Hopper kernel
+      // Note: State uses TMA which requires 128B alignment (checked below)
+      // x, z, B, and C use PackedAligned
+      using load_input_t = PackedAligned<input_t, sizeof(float4) / sizeof(input_t)>;
+
+      FLASHINFER_CHECK(reinterpret_cast<uintptr_t>(params.x) % sizeof(load_input_t) == 0,
+                       "x pointer must be aligned to ", sizeof(load_input_t), " bytes");
+      FLASHINFER_CHECK((params.x_stride_batch * sizeof(input_t)) % sizeof(load_input_t) == 0,
+                       "x batch stride must be aligned to ", sizeof(load_input_t), " bytes");
+      if (params.z) {
+        FLASHINFER_CHECK(reinterpret_cast<uintptr_t>(params.z) % sizeof(load_input_t) == 0,
+                         "z pointer must be aligned to ", sizeof(load_input_t), " bytes");
+        FLASHINFER_CHECK((params.z_stride_batch * sizeof(input_t)) % sizeof(load_input_t) == 0,
+                         "z batch stride must be aligned to ", sizeof(load_input_t), " bytes");
+      }
+      FLASHINFER_CHECK(reinterpret_cast<uintptr_t>(params.B) % sizeof(load_input_t) == 0,
+                       "B pointer must be aligned to ", sizeof(load_input_t), " bytes");
+      FLASHINFER_CHECK(reinterpret_cast<uintptr_t>(params.C) % sizeof(load_input_t) == 0,
+                       "C pointer must be aligned to ", sizeof(load_input_t), " bytes");
+      FLASHINFER_CHECK((params.B_stride_batch * sizeof(input_t)) % sizeof(load_input_t) == 0,
+                       "B batch stride must be aligned to ", sizeof(load_input_t), " bytes");
+      FLASHINFER_CHECK((params.C_stride_batch * sizeof(input_t)) % sizeof(load_input_t) == 0,
+                       "C batch stride must be aligned to ", sizeof(load_input_t), " bytes");
+
       constexpr auto numConsumers = 4;
       constexpr auto numWarps = 1 + numConsumers;
       constexpr auto numStages = 3;
@@ -511,7 +637,14 @@ void invokeSelectiveStateUpdate(SelectiveStateUpdateParams& params, cudaStream_t
       auto tensorState = tma::createTensorMap(
           params.state, params.state_cache_size * nh * dim, DSTATE, rowsPerStage, DSTATE);
 
-      scan_func<<<grid, block, 0, stream>>>(params, tensorState);
+      // Calculate shared memory size and opt-in to extended shared memory
+      using sram_t = SharedStorage<state_t, input_t, numStages, rowsPerStage, DIM, DSTATE>;
+      constexpr size_t smem_size = sizeof(sram_t);
+      FLASHINFER_CUDA_CALL(
+          cudaFuncSetAttribute(scan_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+
+      scan_func<<<grid, block, smem_size, stream>>>(params, tensorState);
     };
 
     auto dispatch_dstate = [&]<int DIM>() {
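The launch above follows the standard CUDA recipe for using more than the default 48 KB of shared memory: raise cudaFuncAttributeMaxDynamicSharedMemorySize on the kernel, then pass the same byte count as the third launch-configuration argument and access the buffer via extern __shared__. A self-contained sketch of that recipe with a toy kernel (the 64 KB figure assumes a Volta-or-newer GPU):

#include <cstdio>

__global__ void bigSmemKernel(float* out) {
  extern __shared__ __align__(128) char smem[];  // dynamic shared memory
  float* buf = reinterpret_cast<float*>(smem);
  buf[threadIdx.x] = float(threadIdx.x);
  __syncthreads();
  out[threadIdx.x] = buf[threadIdx.x];
}

int main() {
  constexpr size_t smem_size = 64 * 1024;  // 64 KB > the 48 KB default limit
  // Opt in before launching; without this the launch fails with
  // cudaErrorInvalidValue on sizes above the default limit.
  cudaFuncSetAttribute(bigSmemKernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
  float* out;
  cudaMallocManaged(&out, 256 * sizeof(float));
  bigSmemKernel<<<1, 256, smem_size>>>(out);
  cudaError_t err = cudaDeviceSynchronize();
  printf("status: %s, out[5]=%f\n", cudaGetErrorString(err), out[5]);
  return 0;
}

Switching the patch from a statically declared __shared__ struct to this dynamic form is what removes the nv_diag_suppress pragmas and lets the Hopper path size its pipeline stages past the static limit.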
diff --git a/tests/mamba/test_selective_state_update.py b/tests/mamba/test_selective_state_update.py
index e50994646a..1a85930b9c 100644
--- a/tests/mamba/test_selective_state_update.py
+++ b/tests/mamba/test_selective_state_update.py
@@ -16,6 +16,7 @@ def create_test_inputs(
     input_dtype,
     weight_dtype,
     matrixA_dtype,
+    state_dtype,
     z_none=True,
 ):
     # Set seed for reproducibility
@@ -29,7 +30,7 @@ def create_test_inputs(
     ssm_state_cache_size = max(384, int(2 * batch_size))
 
     state_cache = torch.randn(
-        ssm_state_cache_size, nheads, dim, dstate, dtype=input_dtype, device=device
+        ssm_state_cache_size, nheads, dim, dstate, dtype=state_dtype, device=device
     )
 
     x = torch.randn(batch_size, nheads, dim, dtype=input_dtype, device=device)
@@ -86,23 +87,22 @@ def create_test_inputs(
 @pytest.mark.parametrize("nheads", [8, 64])
 @pytest.mark.parametrize("dim", [64, 128])
 @pytest.mark.parametrize("dstate", [64, 128, 256])
-@pytest.mark.parametrize("ngroups", [8])
-@pytest.mark.parametrize("delta_softplus", [True])
-@pytest.mark.parametrize("input_dtype", [torch.bfloat16])
-@pytest.mark.parametrize("weight_dtype", [torch.bfloat16])
-@pytest.mark.parametrize("matrixA_dtype", [torch.float32])
+@pytest.mark.parametrize("state_dtype", [torch.float16, torch.bfloat16, torch.float32])
+@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.bfloat16])
 def test_selective_state_update(
     batch,
     nheads,
     dim,
     dstate,
-    ngroups,
-    delta_softplus,
-    input_dtype,
+    state_dtype,
     weight_dtype,
-    matrixA_dtype,
 ):
     """Test selective_state_update correctness against reference implementation."""
+    ngroups = 8
+    delta_softplus = True
+    input_dtype = torch.bfloat16
+    matrixA_dtype = torch.float32
+
     inputs = create_test_inputs(
         batch,
         nheads,
@@ -112,6 +112,7 @@ def test_selective_state_update(
         input_dtype,
         weight_dtype,
         matrixA_dtype,
+        state_dtype=state_dtype,
         z_none=True,
     )
 
@@ -238,6 +239,7 @@ def test_selective_state_update_with_z():
     input_dtype = torch.bfloat16
     weight_dtype = torch.bfloat16
     matrixA_dtype = torch.float32
+    state_dtype = torch.bfloat16
 
     inputs = create_test_inputs(
         batch,
@@ -248,6 +250,7 @@ def test_selective_state_update_with_z():
         input_dtype,
         weight_dtype,
         matrixA_dtype,
+        state_dtype=state_dtype,
         z_none=False,
     )
 
@@ -285,8 +288,8 @@ def test_selective_state_update_with_z():
     atol = 1e-3
     rtol = 1e-2
 
-    torch.testing.assert_allclose(y_ref, y_test, atol=atol, rtol=rtol)
-    torch.testing.assert_allclose(
+    torch.testing.assert_close(y_ref, y_test, atol=atol, rtol=rtol)
+    torch.testing.assert_close(
         state_ref[inputs["slot_idx"]],
         state[inputs["slot_idx"]],
        atol=atol,