diff --git a/sgl-kernel/csrc/cpu/common.h b/sgl-kernel/csrc/cpu/common.h index 0fb132607680..d6e5766c83a3 100644 --- a/sgl-kernel/csrc/cpu/common.h +++ b/sgl-kernel/csrc/cpu/common.h @@ -22,6 +22,29 @@ namespace { } \ }() +#define AT_DISPATCH_BOOL2(BOOL_V1, BOOL_NAME1, BOOL_V2, BOOL_NAME2, ...) \ + [&] { \ + if (BOOL_V1) { \ + constexpr bool BOOL_NAME1 = true; \ + if (BOOL_V2) { \ + constexpr bool BOOL_NAME2 = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr bool BOOL_NAME2 = false; \ + return __VA_ARGS__(); \ + } \ + } else { \ + constexpr bool BOOL_NAME1 = false; \ + if (BOOL_V2) { \ + constexpr bool BOOL_NAME2 = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr bool BOOL_NAME2 = false; \ + return __VA_ARGS__(); \ + } \ + } \ + }() + // dispatch: bfloat16, float16, int8_t, fp8_e4m3 #define CPU_DISPATCH_PACKED_TYPES(TYPE, ...) \ [&] { \ @@ -105,6 +128,8 @@ namespace { #define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b) +#define CHECK_GE(a, b) TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. ", a, " vs ", b) + // [NB] Parallel Routines // // * at::parallel_for - applies for most of generic use cases, this will be compiled @@ -321,4 +346,10 @@ struct Unroll<1> { } }; +// conditional data ptr for optional tensor +template +inline T* conditional_data_ptr(const std::optional& opt) { + return opt.has_value() ? opt.value().data_ptr() : nullptr; +} + } // anonymous namespace diff --git a/sgl-kernel/csrc/cpu/mamba/conv.cpp b/sgl-kernel/csrc/cpu/mamba/conv.cpp new file mode 100644 index 000000000000..aceb51b6142c --- /dev/null +++ b/sgl-kernel/csrc/cpu/mamba/conv.cpp @@ -0,0 +1,708 @@ +#include "common.h" +#include "gemm.h" +#include "vec.h" + +namespace { + +template +inline void copy_stub(scalar_t* __restrict__ y, const scalar_t* __restrict__ x, int64_t size) { + using Vec = at::vec::Vectorized; + const bool is_padding = (x == nullptr); + for (int64_t d = 0; d < size; d += Vec::size()) { + Vec data_vec = is_padding ? Vec(0.f) : Vec::loadu(x + d); + data_vec.store(y + d); + } +} + +// no remainder +template +void inline update_conv_state( + scalar_t* __restrict__ conv_states, + const scalar_t* __restrict__ input, + int64_t width, + int64_t dim, + int64_t seqlen, + bool has_initial_states) { + // width for `conv_states` + int64_t width1 = width - 1; + int64_t w = 0; + for (; w < width1 - seqlen; ++w) { + scalar_t* y = conv_states + w * dim; + const scalar_t* x = has_initial_states ? conv_states + (w + seqlen) * dim : nullptr; + copy_stub(y, x, dim); + } + for (; w < width1; ++w) { + scalar_t* y = conv_states + w * dim; + const scalar_t* x = input + (w + seqlen - width1) * dim; + copy_stub(y, x, dim); + } +} + +// A : [M, BLOCK_N] +// B : [BLOCK_N, K], prepacked as [K/2, BLOCK_N, 2] +// C : [M, BLOCK_N] +// bias : [BLOCK_N] +// +// lda : leading dimension of `input` and `out` +// +template +struct tinygemm_kernel { + static inline void apply( + const scalar_t* __restrict__ A, + const scalar_t* __restrict__ B, + scalar_t* __restrict__ C, + const scalar_t* __restrict__ bias, + const scalar_t* __restrict__ conv_states, + bool has_initial_state, + int64_t M, + int64_t lda, + bool is_first_token) { + TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!"); + } +}; + +#if defined(CPU_CAPABILITY_AVX512) +template +struct tinygemm_kernel { + static inline void apply( + const at::BFloat16* __restrict__ A, + const at::BFloat16* __restrict__ B, + at::BFloat16* __restrict__ C, + const at::BFloat16* __restrict__ bias, + const at::BFloat16* __restrict__ conv_states, + bool has_initial_state, + int64_t M, + int64_t lda, + bool is_first_token) { + assert(K == 4); + constexpr int ROWS = K; + constexpr int COLS = BLOCK_N / block_size_n(); + + // leading dimension size for b for next block [K/2, 32, 2] + constexpr int ldb = block_size_n() * K; + + __m512bh va[ROWS * COLS]; + __m512bh vb[ROWS * COLS]; + __m512 vc[COLS * 2]; + + // k: {-3, -2, -1} -> {0, 1, 2} + auto set_conv_states = [&](int k, int col) -> __m512i { + return has_initial_state ? _mm512_loadu_si512(conv_states + (k + K - 1) * lda + col * 32) + : _mm512_setzero_si512(); + }; + +#define MM512_LOAD_A(idx) \ + ((idx) < 0 && is_first_token) ? (__m512bh)(set_conv_states((idx), col)) \ + : (__m512bh)(_mm512_loadu_si512(A + (idx) * lda + col * 32)) + +#define MM512_PACK_A(ap, bp, a, b) \ + do { \ + __m512i r0 = (__m512i)(a); \ + __m512i r1 = (__m512i)(b); \ + __m512i d0 = _mm512_unpacklo_epi16(r0, r1); \ + __m512i d1 = _mm512_unpackhi_epi16(r0, r1); \ + r0 = _mm512_shuffle_i32x4(d0, d1, 0x88); \ + r1 = _mm512_shuffle_i32x4(d0, d1, 0xdd); \ + (ap) = (__m512bh)_mm512_shuffle_i32x4(r0, r1, 0x88); \ + (bp) = (__m512bh)_mm512_shuffle_i32x4(r0, r1, 0xdd); \ + } while (0) + + // step 0 : preload a at time step [-3][-2][-1] + auto preloada = [&](auto i) { + constexpr int col = i; + int64_t m = 0; + va[1 * COLS + col] = MM512_LOAD_A(m - 3); + va[2 * COLS + col] = MM512_LOAD_A(m - 2); + va[3 * COLS + col] = MM512_LOAD_A(m - 1); + }; + Unroll{}(preloada); + + auto loada = [&](auto i, int64_t m) { + constexpr int col = i; + // update previous time step + va[0 * COLS + col] = va[1 * COLS + col]; + va[1 * COLS + col] = va[2 * COLS + col]; + va[2 * COLS + col] = va[3 * COLS + col]; + // load current time step + va[3 * COLS + col] = MM512_LOAD_A(m); + }; + + // step 1 : load weight for just once + auto loadb = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + vb[row * COLS + col] = (__m512bh)(_mm512_loadu_si512(B + col * ldb + row * 32)); + }; + Unroll{}(loadb); + + // [NB] accumulates 4x32 bfloat16 blocks + // + // +------------+------------+ + // | col0 | col1 | + // +------------+------------+ + // | va0 va1 | va0 va1 | + // | va2 va3 | va2 va3 | + // +------------+------------+ + // | vc0 vc1 | vc0 vc1 | + // +------------+------------+ + // + // * va and vb shares the same memory layout + // * block_n 32 with 4 rows equals to 4 registers + // * 37 uops with avx512bf16 v.s. 57 uops with avx512f + // + auto compute = [&](auto i) { + constexpr int col = i; + + // init accumulators + if constexpr (has_bias) { + __m512i b16 = _mm512_loadu_si512(reinterpret_cast(bias + col * 32)); + vc[col * 2 + 0] = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(b16, 0)); + vc[col * 2 + 1] = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(b16, 1)); + } else { + vc[col * 2 + 0] = _mm512_set1_ps(0.f); + vc[col * 2 + 1] = _mm512_set1_ps(0.f); + } + + // convert to vnni2 format + __m512bh va0, va1, va2, va3; + MM512_PACK_A(va0, va1, va[0 * COLS + col], va[1 * COLS + col]); + MM512_PACK_A(va2, va3, va[2 * COLS + col], va[3 * COLS + col]); + + // accumulate + vc[col * 2 + 0] = _mm512_dpbf16_ps(vc[col * 2 + 0], va0, vb[0 * COLS + col]); + vc[col * 2 + 0] = _mm512_dpbf16_ps(vc[col * 2 + 0], va2, vb[2 * COLS + col]); + vc[col * 2 + 1] = _mm512_dpbf16_ps(vc[col * 2 + 1], va1, vb[1 * COLS + col]); + vc[col * 2 + 1] = _mm512_dpbf16_ps(vc[col * 2 + 1], va3, vb[3 * COLS + col]); + }; + + using fVec = at::vec::Vectorized; + using bVec = at::vec::Vectorized; + const fVec one = fVec(1.f); + auto storec = [&](auto i, int64_t m) { + constexpr int col = i; + fVec x0 = fVec(vc[col * 2 + 0]); + fVec x1 = fVec(vc[col * 2 + 1]); + if constexpr (has_silu) { + x0 = x0 / (one + x0.neg().exp_u20()); + x1 = x1 / (one + x1.neg().exp_u20()); + } + bVec out_vec = convert_from_float_ext(x0, x1); + out_vec.store(C + m * lda + col * 32); + }; + + for (int64_t m = 0; m < M; ++m) { + // step 3.a : load a at current time step + Unroll{}(loada, m); + // step 3.b : accumulate for window size (4) + Unroll{}(compute); + // step 3.c : store c at current time step + Unroll{}(storec, m); + } + } +}; +#endif + +#define LAUNCH_TINYGEMM_KERNEL(K, NB_SIZE) \ + tinygemm_kernel::apply( \ + input + bs * seqlen * dim + mb_start * dim + nb_start, \ + weight + nb_start * width, \ + out + bs * seqlen * dim + mb_start * dim + nb_start, \ + has_bias ? bias + nb_start : nullptr, \ + has_conv_states ? conv_states + conv_state_index * (K - 1) * dim + nb_start : nullptr, \ + has_initial_states_value, \ + mb_size, \ + dim, \ + mb_start == 0); + +template +void causal_conv1d_fwd_kernel_impl( + scalar_t* __restrict__ out, + const scalar_t* __restrict__ input, + const scalar_t* __restrict__ weight, + const scalar_t* __restrict__ bias, + scalar_t* __restrict__ conv_states, + const int32_t* __restrict__ conv_indices, + const bool* __restrict__ has_initial_state, + bool silu_activation, + int64_t batch, + int64_t dim, + int64_t seqlen, + int64_t width, + int64_t num_seq_blocks) { + // handle 32 x 64 per block + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n() * 2; + const int64_t NB = div_up(dim, BLOCK_N); + + const int64_t num_blocks_per_seq = div_up(seqlen, BLOCK_M); + const bool has_conv_states = conv_states != nullptr; + const bool has_conv_indices = conv_indices != nullptr; + + // parallel on [batch, seq, NB] + AT_DISPATCH_BOOL2(bias != nullptr, has_bias, silu_activation, has_silu, [&] { + at::parallel_for(0, num_seq_blocks * NB, 0, [&](int64_t begin, int64_t end) { + int64_t mb{0}, nb{0}; + data_index_init(begin, mb, num_seq_blocks, nb, NB); + + for (int64_t i = begin; i < end; ++i) { + int64_t bs = mb / num_blocks_per_seq; + + int64_t mb_start = (mb % num_blocks_per_seq) * BLOCK_M; + int64_t mb_size = std::min(seqlen - mb_start, BLOCK_M); + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(dim - nb_start, BLOCK_N); + + const bool has_initial_states_value = has_conv_states ? has_initial_state[bs] : false; + int32_t conv_state_index = has_conv_indices ? conv_indices[bs] : bs; + + switch (width << 4 | nb_size >> 4) { + case 0x42: + LAUNCH_TINYGEMM_KERNEL(4, 32); + break; + case 0x44: + LAUNCH_TINYGEMM_KERNEL(4, 64); + break; + default: + TORCH_CHECK(false, "Unexpected block size, ", width, " x ", nb_size); + } + + // move to the next index + data_index_step(mb, num_seq_blocks, nb, NB); + } + }); + }); + + // update conv_states if necessary + if (has_conv_states) { + at::parallel_for(0, batch, 0, [&](int64_t begin, int64_t end) { + for (int64_t bs = begin; bs < end; ++bs) { + update_conv_state( + conv_states + bs * (width - 1) * dim, input + bs * seqlen * dim, width, dim, seqlen, has_initial_state[bs]); + } + }); + } +} + +#define LAUNCH_TINYGEMM_VARLEN_KERNEL(K, NB_SIZE) \ + tinygemm_kernel::apply( \ + input + batch_offset * dim + mb_start * dim + nb_start, \ + weight + nb_start * width, \ + out + batch_offset * dim + mb_start * dim + nb_start, \ + has_bias ? bias + nb_start : nullptr, \ + nullptr, \ + false, \ + mb_size, \ + dim, \ + mb_start == 0); + +// TODO: add `has_initial_state` support for varlen kernel +template +void causal_conv1d_fwd_varlen_kernel_impl( + scalar_t* __restrict__ out, + const scalar_t* __restrict__ input, + const scalar_t* __restrict__ weight, + const scalar_t* __restrict__ bias, + scalar_t* __restrict__ conv_states, + const int32_t* __restrict__ query_start_loc, + const int32_t* __restrict__ conv_indices, + const bool* __restrict__ has_initial_state, + const int32_t* __restrict__ block_indices, + bool silu_activation, + int64_t batch, + int64_t dim, + int64_t width, + int64_t num_seq_blocks) { + // handle 32 x 64 per block + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n() * 2; + const int64_t NB = div_up(dim, BLOCK_N); + + const bool has_conv_states = conv_states != nullptr; + const bool has_conv_indices = conv_indices != nullptr; + + // parallel on [batch, seq, NB] + AT_DISPATCH_BOOL2(bias != nullptr, has_bias, silu_activation, has_silu, [&] { + at::parallel_for(0, num_seq_blocks * NB, 0, [&](int64_t begin, int64_t end) { + int64_t mb{0}, nb{0}; + data_index_init(begin, mb, num_seq_blocks, nb, NB); + + for (int64_t i = begin; i < end; ++i) { + int32_t bs = block_indices[mb * 2 + 0]; + int32_t batch_offset = query_start_loc[bs]; + int32_t seqlen = query_start_loc[bs + 1] - query_start_loc[bs]; + + int64_t mb_start = block_indices[mb * 2 + 1] * BLOCK_M; + int64_t mb_size = std::min(seqlen - mb_start, BLOCK_M); + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(dim - nb_start, BLOCK_N); + + switch (width << 4 | nb_size >> 4) { + case 0x42: + LAUNCH_TINYGEMM_VARLEN_KERNEL(4, 32); + break; + case 0x44: + LAUNCH_TINYGEMM_VARLEN_KERNEL(4, 64); + break; + default: + TORCH_CHECK(false, "Unexpected block size, ", width, " x ", nb_size); + } + + // move to the next index + data_index_step(mb, num_seq_blocks, nb, NB); + } + }); + }); + + // update conv_states if necessary + if (has_conv_states) { + at::parallel_for(0, batch, 0, [&](int64_t begin, int64_t end) { + for (int64_t bs = begin; bs < end; ++bs) { + int32_t conv_state_index = has_conv_indices ? conv_indices[bs] : bs; + int32_t seqlen = query_start_loc[bs + 1] - query_start_loc[bs]; + int32_t batch_offset = query_start_loc[bs]; + update_conv_state( + conv_states + conv_state_index * (width - 1) * dim, + input + batch_offset * dim, + width, + dim, + seqlen, + /* has_initial_state */ false); + } + }); + } +} + +template +void causal_conv1d_update_kernel_impl( + scalar_t* __restrict__ out, + const scalar_t* __restrict__ input, + scalar_t* __restrict__ conv_states, + const scalar_t* __restrict__ weight, + const scalar_t* __restrict__ bias, + const int32_t* __restrict__ conv_indices, + bool silu_activation, + int64_t batch, + int64_t dim, + int64_t seqlen, + int64_t width) { + // handle 32 x 64 per block + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n() * 2; + const int64_t NB = div_up(dim, BLOCK_N); + + const bool has_conv_states = conv_states != nullptr; + const bool has_conv_indices = conv_indices != nullptr; + + // parallel on [batch, NB] + AT_DISPATCH_BOOL2(bias != nullptr, has_bias, silu_activation, has_silu, [&] { + at::parallel_for(0, batch * NB, 0, [&](int64_t begin, int64_t end) { + int64_t bs{0}, nb{0}; + data_index_init(begin, bs, batch, nb, NB); + + for (int64_t i = begin; i < end; ++i) { + int64_t mb_start = 0; + int64_t mb_size = 1; + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(dim - nb_start, BLOCK_N); + + const bool has_initial_states_value = true; + int32_t conv_state_index = has_conv_indices ? conv_indices[bs] : bs; + + switch (width << 4 | nb_size >> 4) { + case 0x42: + LAUNCH_TINYGEMM_KERNEL(4, 32); + break; + case 0x44: + LAUNCH_TINYGEMM_KERNEL(4, 64); + break; + default: + TORCH_CHECK(false, "Unexpected block size, ", width, " x ", nb_size); + } + + // move to the next index + data_index_step(bs, batch, nb, NB); + } + }); + }); + +#define CONV_STATE_INDEXR(w) conv_states + conv_state_index*(width - 1) * dim + (w) * dim + + // update conv_states + at::parallel_for(0, batch, 0, [&](int64_t begin, int64_t end) { + for (int64_t bs = begin; bs < end; ++bs) { + // update old states, range [1, width - 1) + int32_t conv_state_index = has_conv_indices ? conv_indices[bs] : bs; + for (int64_t w = 1; w < width - 1; ++w) { + std::memcpy(CONV_STATE_INDEXR(w - 1), CONV_STATE_INDEXR(w), dim * sizeof(scalar_t)); + } + // copy new states + std::memcpy(CONV_STATE_INDEXR(width - 2), input + bs * dim, dim * sizeof(scalar_t)); + } + }); +} + +} // anonymous namespace + +// from [dim, width] or [N, K] +// to [N/BLOCK_N, K/2, BLOCK_N, 2] +at::Tensor causal_conv1d_weight_pack(const at::Tensor& weight) { + CHECK_INPUT(weight); + + int64_t dim = weight.size(0); + int64_t width = weight.size(1); + constexpr int64_t BLOCK_N = block_size_n(); + TORCH_CHECK(width == 4, "causal_conv1d_weight_pack: support only width of 4"); + TORCH_CHECK(dim % BLOCK_N == 0, "causal_conv1d_weight_pack: invalid dim size ", dim); + + const int64_t N = dim, K2 = width >> 1; + const int64_t NB = div_up(N, BLOCK_N); + + auto packed_weight = at::empty_like(weight); + AT_DISPATCH_REDUCED_FLOATING_TYPES(weight.scalar_type(), "causal_conv1d_fwd_kernel_impl", [&] { + // cast to float32 as vnni size is 2 + const float* w_data = reinterpret_cast(weight.data_ptr()); + float* packed_data = reinterpret_cast(packed_weight.data_ptr()); + + at::parallel_for(0, NB * K2 * BLOCK_N, 0, [&](int64_t begin, int64_t end) { + int64_t nb{0}, k2{0}, n{0}; + data_index_init(begin, nb, NB, k2, K2, n, BLOCK_N); + + // TODO: optimize this if we need to online prepacking. + for (int64_t i = begin; i < end; ++i) { + packed_data[i] = w_data[nb * BLOCK_N * K2 + n * K2 + k2]; + + // move to the next index + data_index_step(nb, NB, k2, K2, n, BLOCK_N); + } + }); + }); + return packed_weight; +} + +#define CHECK_OPTIONAL_SHAPE_DTYPE(OPT, SIZE, DTYPE) \ + if (OPT.has_value()) { \ + const auto tensor = OPT.value(); \ + CHECK_CONTIGUOUS(tensor); \ + CHECK_EQ(tensor.size(0), SIZE); \ + CHECK_EQ(tensor.scalar_type(), DTYPE); \ + } + +template +int64_t get_block_count(const std::optional& offsets, int64_t batch, int64_t seqlen) { + if (offsets.has_value()) { + const int32_t* offsets_data = offsets.value().data_ptr(); + int32_t num_seq_blocks = 0; + for (int64_t row = 0; row < batch; ++row) { + num_seq_blocks += div_up(offsets_data[row + 1] - offsets_data[row], BLOCK_M); + } + return num_seq_blocks; + } + return batch * div_up(seqlen, int64_t(BLOCK_M)); +} + +template +at::Tensor get_block_indices(const std::optional& offsets, int64_t num_seq_blocks) { + if (!offsets.has_value()) { + return at::Tensor(); + } + + const at::Tensor& offsets_ = offsets.value(); + at::Tensor indices = at::empty({num_seq_blocks, 2}, offsets_.options()); + + int64_t batch = offsets_.size(0) - 1; + + const int32_t* offsets_data = offsets_.data_ptr(); + int32_t* indices_data = indices.data_ptr(); + + int64_t idx = 0; + for (int32_t row = 0; row < batch; ++row) { + int32_t blocks = div_up(offsets_data[row + 1] - offsets_data[row], BLOCK_M); + + for (int32_t col = 0; col < blocks; ++col) { + indices_data[idx * 2 + 0] = row; + indices_data[idx * 2 + 1] = col; + idx++; + } + } + return indices; +} + +// API aligned with GPUs +// +// x: (batch, dim, seqlen) or (dim, cu_seq_len) for varlen +// weight: (dim, width) +// bias: (dim,) +// query_start_loc: (batch + 1) int32 +// cache_indices: (batch) int32 +// has_initial_state: (batch) bool +// conv_states: (..., dim, width - 1) itype +// activation: either None or "silu" or "swish" +// pad_slot_id: int +// +at::Tensor causal_conv1d_fwd_cpu( + const at::Tensor& x, + const at::Tensor& weight, + const std::optional& bias, + const std::optional& conv_states, + const std::optional& query_start_loc, + const std::optional& conv_state_indices, + const std::optional& has_initial_state, + bool silu_activation, + int64_t pad_slot_id, + bool is_vnni) { + RECORD_FUNCTION("sgl-kernel::causal_conv1d_fwd_cpu", std::vector({x, weight, bias})); + + CHECK_CONTIGUOUS(weight); + auto packed_w = is_vnni ? weight : causal_conv1d_weight_pack(weight); + + const bool is_var_seqlen = query_start_loc.has_value(); + const int64_t input_ndim = is_var_seqlen ? 2 : 3; + TORCH_CHECK(x.dim() == input_ndim, "causal_conv1d_fwd_cpu: expect x to be ", input_ndim, "D tensor."); + TORCH_CHECK(x.stride(-2) == 1 && x.stride(-1) == x.size(-2), "causal_conv1d_fwd_cpu: expect x to be transposed."); + + const int64_t batch = is_var_seqlen ? query_start_loc.value().size(0) - 1 : x.size(0); + const int64_t dim = x.size(-2); + const int64_t seqlen = x.size(-1); + const int64_t width = weight.size(-1); + + const auto scalar_type = x.scalar_type(); + CHECK_EQ(weight.scalar_type(), scalar_type); + CHECK_OPTIONAL_SHAPE_DTYPE(bias, dim, scalar_type); + CHECK_OPTIONAL_SHAPE_DTYPE(query_start_loc, batch + 1, at::kInt); + CHECK_OPTIONAL_SHAPE_DTYPE(conv_state_indices, batch, at::kInt); + CHECK_OPTIONAL_SHAPE_DTYPE(has_initial_state, batch, at::kBool); + + if (conv_states.has_value()) { + auto& conv_states_val = conv_states.value(); + int64_t padded_batch = conv_states_val.size(0); + CHECK_EQ(conv_states_val.scalar_type(), scalar_type); + CHECK_GE(padded_batch, batch); + CHECK_EQ(conv_states_val.size(1), dim); + CHECK_EQ(conv_states_val.size(2), width - 1); + + // adjust `conv_states` to be contiguous on `dim` + // should happen only once + if (conv_states_val.stride(-2) != 1) { + auto conv_states_copy = conv_states_val.clone(); + conv_states_val.as_strided_({padded_batch, dim, width - 1}, {(width - 1) * dim, 1, dim}); + conv_states_val.copy_(conv_states_copy); + } + } + + // block size for sequence blocks, 32 + constexpr int64_t BLOCK_M = block_size_m(); + + // total number of sequence blocks + int64_t num_seq_blocks = get_block_count(query_start_loc, batch, seqlen); + + at::Tensor out = at::empty_like(x); + AT_DISPATCH_REDUCED_FLOATING_TYPES(scalar_type, "causal_conv1d_fwd_kernel_impl", [&] { + if (is_var_seqlen) { + // record seq blocks in Coordinate format, aka [num_seq_blocks, 2] + at::Tensor block_indices = get_block_indices(query_start_loc, num_seq_blocks); + + causal_conv1d_fwd_varlen_kernel_impl( + out.data_ptr(), + x.data_ptr(), + packed_w.data_ptr(), + conditional_data_ptr(bias), + conditional_data_ptr(conv_states), + conditional_data_ptr(query_start_loc), + conditional_data_ptr(conv_state_indices), + conditional_data_ptr(has_initial_state), + block_indices.data_ptr(), + silu_activation, + batch, + dim, + width, + num_seq_blocks); + } else { + causal_conv1d_fwd_kernel_impl( + out.data_ptr(), + x.data_ptr(), + packed_w.data_ptr(), + conditional_data_ptr(bias), + conditional_data_ptr(conv_states), + conditional_data_ptr(conv_state_indices), + conditional_data_ptr(has_initial_state), + silu_activation, + batch, + dim, + seqlen, + width, + num_seq_blocks); + } + }); + return out; +} + +// API aligned with GPUs +// +// x: (batch, dim) or (batch, dim, seqlen) +// conv_state: (..., dim, state_len), where state_len >= width - 1 +// weight: (dim, width) +// bias: (dim,) +// cache_seqlens: (batch,), dtype int32. +// conv_state_indices: (batch,), dtype int32 +// pad_slot_id: int +// out: (batch, dim) or (batch, dim, seqlen) +// +at::Tensor causal_conv1d_update_cpu( + const at::Tensor& x, + const at::Tensor& conv_states, + const at::Tensor& weight, + const std::optional& bias, + bool silu_activation, + const std::optional& cache_seqlens, + const std::optional& conv_state_indices, + int64_t pad_slot_id, + bool is_vnni) { + RECORD_FUNCTION("sgl-kernel::causal_conv1d_update_cpu", std::vector({x, weight, bias})); + + CHECK_CONTIGUOUS(x); + CHECK_CONTIGUOUS(weight); + auto packed_w = is_vnni ? weight : causal_conv1d_weight_pack(weight); + + // TODO: add multi-token prediction support + TORCH_CHECK(x.dim() == 2, "causal_conv1d_update_cpu: expect x to be 2D tensor."); + TORCH_CHECK(!cache_seqlens.has_value(), "causal_conv1d_update_cpu: don't support cache_seqlens."); + + int64_t batch = x.size(0); + int64_t dim = x.size(1); + int64_t seqlen = 1; + int64_t width = weight.size(-1); + + const auto scalar_type = x.scalar_type(); + CHECK_EQ(weight.scalar_type(), scalar_type); + CHECK_OPTIONAL_SHAPE_DTYPE(bias, dim, scalar_type); + CHECK_OPTIONAL_SHAPE_DTYPE(conv_state_indices, batch, at::kInt); + + CHECK_EQ(conv_states.scalar_type(), scalar_type); + CHECK_EQ(conv_states.size(1), dim); + CHECK_EQ(conv_states.size(2), width - 1); + + // adjust `conv_states` to be contiguous on `dim` + if (conv_states.stride(-2) != 1) { + int64_t num_cache_lines = conv_states.size(0); + auto conv_states_copy = conv_states.clone(); + conv_states.as_strided_({num_cache_lines, dim, width - 1}, {(width - 1) * dim, 1, dim}); + conv_states.copy_(conv_states_copy); + } + + at::Tensor out = at::empty_like(x); + AT_DISPATCH_REDUCED_FLOATING_TYPES(scalar_type, "causal_conv1d_update_kernel_impl", [&] { + causal_conv1d_update_kernel_impl( + out.data_ptr(), + x.data_ptr(), + conv_states.data_ptr(), + packed_w.data_ptr(), + conditional_data_ptr(bias), + conditional_data_ptr(conv_state_indices), + silu_activation, + batch, + dim, + seqlen, + width); + }); + return out; +} diff --git a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp index a620b930caa5..44c32659e214 100644 --- a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp +++ b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp @@ -229,6 +229,32 @@ std::tuple qkv_proj_with_rope_fused_weight( int64_t kv_lora_rank, int64_t qk_rope_head_dim); +// mamba causal conv1d +at::Tensor causal_conv1d_weight_pack(const at::Tensor& weight); + +at::Tensor causal_conv1d_fwd_cpu( + const at::Tensor& x, + const at::Tensor& weight, + const std::optional& bias, + const std::optional& conv_states, + const std::optional& query_start_loc, + const std::optional& cache_indices, + const std::optional& has_initial_state, + bool silu_activation, + int64_t pad_slot_id, + bool is_vnni); + +at::Tensor causal_conv1d_update_cpu( + const at::Tensor& x, + const at::Tensor& conv_states, + const at::Tensor& weight, + const std::optional& bias, + bool silu_activation, + const std::optional& cache_seqlens, + const std::optional& conv_state_indices, + int64_t pad_slot_id, + bool is_vnni); + // shared memory init void initialize(int64_t size, int64_t rank); @@ -383,6 +409,21 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { "w2_scale, int[]? block_size, Tensor? a1_scale, Tensor? a2_scale, bool is_vnni) -> Tensor"); m.impl("shared_expert_cpu", torch::kCPU, &shared_expert_cpu); + // causal conv1d + m.def("causal_conv1d_weight_pack(Tensor weight) -> Tensor"); + m.impl("causal_conv1d_weight_pack", torch::kCPU, &causal_conv1d_weight_pack); + + m.def( + "causal_conv1d_fwd_cpu(Tensor x, Tensor weight, Tensor? bias, Tensor? conv_states, Tensor? query_start_loc," + "Tensor? cache_indices, Tensor? has_initial_state, bool silu_activation, int pad_slot_id, bool is_vnni) -> " + "Tensor"); + m.impl("causal_conv1d_fwd_cpu", torch::kCPU, &causal_conv1d_fwd_cpu); + + m.def( + "causal_conv1d_update_cpu(Tensor x, Tensor conv_states, Tensor weight, Tensor? bias, bool silu_activation," + "Tensor? cache_seqlens, Tensor? conv_state_indices, int pad_slot_id, bool is_vnni) -> Tensor"); + m.impl("causal_conv1d_update_cpu", torch::kCPU, &causal_conv1d_update_cpu); + // all reduce m.def("initialize(int size, int rank) -> ()"); m.def("shm_allreduce(Tensor(a!) data, int reduce_op) -> ()"); diff --git a/test/srt/cpu/test_causal_conv1d.py b/test/srt/cpu/test_causal_conv1d.py new file mode 100644 index 000000000000..7f705cab76a3 --- /dev/null +++ b/test/srt/cpu/test_causal_conv1d.py @@ -0,0 +1,327 @@ +import unittest +from typing import Optional + +import sgl_kernel # noqa: F401 +import torch +import torch.nn.functional as F +from utils import parametrize, precision + +from sglang.test.test_utils import CustomTestCase + +causal_conv1d_weight_pack = torch.ops.sgl_kernel.causal_conv1d_weight_pack +causal_conv1d_fwd = torch.ops.sgl_kernel.causal_conv1d_fwd_cpu +causal_conv1d_update = torch.ops.sgl_kernel.causal_conv1d_update_cpu + + +torch.manual_seed(1234) + +PAD_SLOT_ID = -1 + + +def causal_conv1d_ref( + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + initial_states: Optional[torch.Tensor] = None, + return_final_states: bool = False, + final_states_out: Optional[torch.Tensor] = None, + activation: Optional[str] = "silu", +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return (out, None) if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update_ref( + x, conv_state, weight, bias=None, activation=None, cache_seqlens=None +): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the + conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + + x_new = torch.cat([conv_state, x], dim=-1) + conv_state.copy_(x_new[:, :, -state_len:]) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[ + :, :, -seqlen: + ] + + out = out.squeeze(-1) + return out if activation is None else F.silu(out) + + +class TestCausalConv1d(CustomTestCase): + activation = "silu" + + @parametrize( + batch=[1, 1024], + dim=[96, 512], + seqlen=[2, 36], + width=[4], + has_bias=[True, False], + has_initial_state=[True, False], + ) + def test_causal_conv1d( + self, + batch, + dim, + seqlen, + width, + has_bias, + has_initial_state, + dtype=torch.bfloat16, + prepack=True, + ): + x = torch.randn(batch, seqlen, dim).to(dtype).transpose_(-1, -2) + weight = torch.randn(dim, width).to(dtype) + bias = torch.randn(dim).to(dtype) if has_bias else None + + if has_initial_state: + initial_states = torch.randn(batch, dim, width - 1, dtype=dtype) + has_initial_state_tensor = torch.ones(batch, dtype=torch.bool) + else: + initial_states = None + has_initial_state_tensor = None + + packed_weight = causal_conv1d_weight_pack(weight) if prepack else weight + + out_ref, final_states_ref = causal_conv1d_ref( + x, + weight, + bias, + initial_states, + return_final_states=has_initial_state, + activation=self.activation, + ) + + out = causal_conv1d_fwd( + x, + packed_weight, + bias, + initial_states, + None, + None, + has_initial_state_tensor, + self.activation in ["silu"], + PAD_SLOT_ID, + prepack, + ) + + atol = rtol = precision[dtype] + torch.testing.assert_close(out_ref, out, atol=atol, rtol=rtol) + torch.testing.assert_close( + final_states_ref, initial_states, atol=atol, rtol=rtol + ) + + @parametrize( + batch=[11], + dim=[96], + max_seqlen=[66], + width=[4], + ) + def test_causal_conv1d_varlen( + self, + batch, + dim, + max_seqlen, + width, + has_bias=False, + dtype=torch.bfloat16, + prepack=False, + ): + total_entries = batch + 3 + + seqlens = torch.randint(1, max_seqlen, (batch + 1,)) + seqlens[0] = 0 + # 1 or 2 must test + seqlens[-2] = 2 + + query_start_loc = torch.cumsum(seqlens, dim=0).to(torch.int32) + + seqlen = query_start_loc[-1].item() + x = torch.randn(seqlen, dim, dtype=dtype).transpose_(-1, -2) + weight = torch.randn(dim, width, dtype=dtype) + bias = torch.randn(dim, dtype=dtype) if has_bias else None + + final_states = torch.randn(total_entries, dim, width - 1, dtype=dtype) + final_states_ref = final_states.clone() + + has_initial_states = torch.randint(0, 2, (batch,), dtype=torch.bool).fill_( + False + ) + state_indices = torch.randperm(total_entries, dtype=torch.int32)[:batch] + + out_ref = [] + out_ref_b = [] + + return_final_states = final_states is not None + splits = torch.split(x, seqlens[1:].tolist(), dim=1) + for i, x_s in enumerate(splits): + out_ref_b.append( + causal_conv1d_ref( + x_s.unsqueeze(0), + weight, + bias, + activation=self.activation, + return_final_states=return_final_states, + final_states_out=( + final_states_ref[state_indices[i]].unsqueeze(0) + if return_final_states + else None + ), + initial_states=( + final_states_ref[state_indices[i]].unsqueeze(0) + if has_initial_states[i] + else None + ), + ) + ) + out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2)) + out_ref_tensor = torch.cat(out_ref, dim=0).squeeze(0) + + out = causal_conv1d_fwd( + x, + weight, + bias, + final_states, + query_start_loc, + state_indices, + has_initial_states, + self.activation in ["silu"], + PAD_SLOT_ID, + prepack, + ) + + atol = rtol = precision[dtype] + torch.testing.assert_close(out_ref_tensor, out, atol=atol, rtol=rtol) + torch.testing.assert_close(final_states_ref, final_states, atol=atol, rtol=rtol) + + @parametrize( + batch=[11], + dim=[32, 64, 96], + width=[4], + ) + def test_causal_conv1d_update( + self, batch, dim, width, has_bias=False, dtype=torch.bfloat16, prepack=True + ): + x = torch.randn(batch, dim).to(dtype) + conv_state = torch.randn(batch, dim, width - 1, dtype=dtype) + weight = torch.randn(dim, width).to(dtype) + bias = torch.randn(dim).to(dtype) if has_bias else None + + packed_weight = causal_conv1d_weight_pack(weight) if prepack else weight + + conv_state_ref = conv_state.clone() + out_ref = causal_conv1d_update_ref( + x, conv_state_ref, weight, bias, activation=self.activation + ) + + cache_seqlens = None + conv_state_indices = None + out = causal_conv1d_update( + x, + conv_state, + packed_weight, + bias, + self.activation in ["silu"], + cache_seqlens, + conv_state_indices, + PAD_SLOT_ID, + prepack, + ) + + atol = rtol = precision[dtype] + torch.testing.assert_close(out_ref, out, atol=atol, rtol=rtol) + torch.testing.assert_close(conv_state_ref, conv_state, atol=atol, rtol=rtol) + + @parametrize( + batch=[7], + dim=[96], + width=[4], + ) + def test_causal_conv1d_update_with_batch_gather( + self, batch, dim, width, has_bias=False, dtype=torch.bfloat16, prepack=True + ): + total_entries = batch + 3 + + x = torch.randn(batch, dim).to(dtype=dtype) + + conv_state_indices = torch.randperm(total_entries)[:batch].to(dtype=torch.int32) + conv_state = torch.randn(total_entries, dim, width - 1, dtype=dtype) + + weight = torch.randn(dim, width).to(dtype=dtype) + bias = torch.randn(dim).to(dtype=dtype) if has_bias else None + conv_state_ref = conv_state[conv_state_indices, :] + + packed_weight = causal_conv1d_weight_pack(weight) if prepack else weight + + out_ref = causal_conv1d_update_ref( + x, conv_state_ref, weight, bias, activation=self.activation + ) + + cache_seqlens = None + out = causal_conv1d_update( + x, + conv_state, + packed_weight, + bias, + self.activation in ["silu"], + cache_seqlens, + conv_state_indices, + PAD_SLOT_ID, + prepack, + ) + + atol = rtol = precision[dtype] + torch.testing.assert_close(out_ref, out, atol=atol, rtol=rtol) + torch.testing.assert_close( + conv_state_ref, conv_state[conv_state_indices, :], atol=atol, rtol=rtol + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/cpu/utils.py b/test/srt/cpu/utils.py index 6435dad746ce..8f03c1bc9ce1 100644 --- a/test/srt/cpu/utils.py +++ b/test/srt/cpu/utils.py @@ -1,3 +1,4 @@ +import itertools import math import torch @@ -15,6 +16,19 @@ fp8_max, fp8_min = 400, -400 +def parametrize(**params): + def decorator(func): + def wrapper(self): + for combo in itertools.product(*params.values()): + kwargs = dict(zip(params.keys(), combo)) + with self.subTest(**kwargs): + func(self, **kwargs) + + return wrapper + + return decorator + + def SiluAndMul(x: torch.Tensor) -> torch.Tensor: d = x.shape[-1] // 2 return F.silu(x[..., :d]) * x[..., d:] diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 2b2edf540bce..0e45ff663833 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -323,14 +323,18 @@ } # Add Intel Xeon tests -# NOTE: please sort the test cases alphabetically by the test file name suite_xeon = { "per-commit-cpu": [ TestFile("cpu/test_activation.py"), TestFile("cpu/test_binding.py"), + TestFile("cpu/test_causal_conv1d.py"), + TestFile("cpu/test_cpu_graph.py"), TestFile("cpu/test_decode.py"), TestFile("cpu/test_extend.py"), TestFile("cpu/test_gemm.py"), + TestFile("cpu/test_intel_amx_attention_backend_a.py"), + TestFile("cpu/test_intel_amx_attention_backend_b.py"), + TestFile("cpu/test_intel_amx_attention_backend_c.py"), TestFile("cpu/test_mamba.py"), TestFile("cpu/test_mla.py"), TestFile("cpu/test_moe.py"), @@ -340,10 +344,6 @@ TestFile("cpu/test_rope.py"), TestFile("cpu/test_shared_expert.py"), TestFile("cpu/test_topk.py"), - TestFile("cpu/test_cpu_graph.py"), - TestFile("cpu/test_intel_amx_attention_backend_a.py"), - TestFile("cpu/test_intel_amx_attention_backend_b.py"), - TestFile("cpu/test_intel_amx_attention_backend_c.py"), ], }