diff --git a/sgl-kernel/csrc/cpu/norm.cpp b/sgl-kernel/csrc/cpu/norm.cpp
index 391a0d4e525..88faafd5b6a 100644
--- a/sgl-kernel/csrc/cpu/norm.cpp
+++ b/sgl-kernel/csrc/cpu/norm.cpp
@@ -4,6 +4,67 @@ namespace {
 
 // NB: avoid using `at::vec::map<>` on bfloat16 or half
 
+// Llama4TextL2Norm
+template <typename scalar_t>
+void l2norm_kernel_impl(
+    scalar_t* __restrict__ output,
+    const scalar_t* __restrict__ input,
+    int64_t batch_size,
+    int64_t hidden_size,
+    float eps = 1e-5) {
+  using bVec = at::vec::Vectorized<scalar_t>;
+  using fVec = at::vec::Vectorized<float>;
+
+  constexpr int kVecSize = bVec::size();
+  at::parallel_for(0, batch_size, 0, [&](int64_t begin, int64_t end) {
+    for (int64_t i = begin; i < end; ++i) {
+      // local ptrs
+      scalar_t* __restrict__ out_ptr = output + i * hidden_size;
+      const scalar_t* __restrict__ input_ptr = input + i * hidden_size;
+
+      fVec sum_fvec = fVec(float(0));
+      float sum_val = float(0);
+
+      int64_t d;
+#pragma GCC unroll 4
+      for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
+        bVec x_bvec = bVec::loadu(input_ptr + d);
+        fVec x_fvec0, x_fvec1;
+        std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
+
+        sum_fvec += x_fvec0 * x_fvec0;
+        sum_fvec += x_fvec1 * x_fvec1;
+      }
+#pragma GCC unroll 4
+      for (; d < hidden_size; ++d) {
+        float x_val = static_cast<float>(input_ptr[d]);
+        sum_val += x_val * x_val;
+      }
+
+      sum_val += vec_reduce_sum(sum_fvec);
+      float rsqrt_var = float(1) / std::sqrt(sum_val / hidden_size + eps);
+      const fVec scale_fvec = fVec(rsqrt_var);
+
+#pragma GCC unroll 4
+      for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
+        bVec x_bvec = bVec::loadu(input_ptr + d);
+        fVec x_fvec0, x_fvec1;
+        std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
+
+        x_fvec0 = x_fvec0 * scale_fvec;
+        x_fvec1 = x_fvec1 * scale_fvec;
+
+        bVec out_bvec = convert_from_float_ext(x_fvec0, x_fvec1);
+        out_bvec.store(out_ptr + d);
+      }
+#pragma GCC unroll 4
+      for (; d < hidden_size; ++d) {
+        float x_val = static_cast<float>(input_ptr[d]);
+        out_ptr[d] = static_cast<scalar_t>(x_val * rsqrt_var);
+      }
+    }
+  });
+}
+
 template <typename scalar_t>
 void rmsnorm_kernel_impl(
     scalar_t* __restrict__ output,
@@ -160,6 +221,22 @@ void fused_add_rmsnorm_kernel_impl(
 }
 
 } // anonymous namespace
 
+// input : {batch_size, hidden_size}
+at::Tensor l2norm_cpu(at::Tensor& input, double eps) {
+  RECORD_FUNCTION("sgl-kernel::l2norm_cpu", std::vector<c10::IValue>({input}));
+
+  CHECK_INPUT(input);
+  CHECK_DIM(2, input);
+  int64_t batch_size = input.size(0);
+  int64_t hidden_size = input.size(1);
+  at::Tensor output = at::empty_like(input);
+
+  AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "l2norm_kernel", [&] {
+    l2norm_kernel_impl<scalar_t>(output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), batch_size, hidden_size, eps);
+  });
+  return output;
+}
+
 // input : {batch_size, hidden_size}
 // weight: {hidden_size}
 at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps) {
diff --git a/sgl-kernel/csrc/cpu/rope.cpp b/sgl-kernel/csrc/cpu/rope.cpp
index 64bc297fe25..1c624946653 100644
--- a/sgl-kernel/csrc/cpu/rope.cpp
+++ b/sgl-kernel/csrc/cpu/rope.cpp
@@ -4,126 +4,343 @@ namespace {
 
 template <typename scalar_t>
-void rope_kernel_impl(
-    scalar_t* __restrict__ q_pe_out,
-    scalar_t* __restrict__ k_pe_out,
-    int64_t* __restrict__ t_pos,
-    scalar_t* __restrict__ q_pe,
-    scalar_t* __restrict__ k_pe,
-    scalar_t* __restrict__ t_emb_pos,
-    int64_t seq_len,
-    int64_t num_head,
+void rotary_embedding_3D_kernel_impl(
+    scalar_t* __restrict__ query_out,
+    scalar_t* __restrict__ key_out,
+    int64_t* __restrict__ positions,
+    scalar_t* __restrict__ query,
+    scalar_t* __restrict__ key,
+    scalar_t* __restrict__ cos_sin_cache,
+    int64_t num_tokens,
+    int64_t num_heads,
+    int64_t num_kv_heads,
+    int64_t head_size,
     int64_t rotary_dim,
-    int64_t HR,
-    int64_t q_pe_stride_s,
-    int64_t out_stride_qs,
-    int64_t out_stride_ks,
-    int64_t HK,
-    int64_t k_pe_stride_s,
-    int64_t q_pe_stride_n,
-    int64_t out_stride_qn) {
+    int64_t query_stride_s,
+    int64_t query_out_stride_s,
+    int64_t key_out_stride_s,
+    int64_t key_stride_s,
+    int64_t query_stride_h,
+    int64_t query_out_stride_h) {
+  int64_t HR = rotary_dim;
+  int64_t HK = rotary_dim;
   int64_t COFF = HR / 2;
-  at::parallel_for(0, seq_len * num_head, GRAIN_SIZE / rotary_dim, [&](int64_t begin, int64_t end) {
+  at::parallel_for(0, num_tokens * num_heads, GRAIN_SIZE / rotary_dim, [&](int64_t begin, int64_t end) {
     int64_t seq{0}, head_id{0};
-    data_index_init(begin, seq, seq_len, head_id, num_head);
+    data_index_init(begin, seq, num_tokens, head_id, num_heads);
     for (int64_t i = begin; i < end; ++i) {
-      int64_t in_offset_q = seq * q_pe_stride_s + head_id * q_pe_stride_n;
-      int64_t out_offset_q = seq * out_stride_qs + head_id * out_stride_qn;
-      int64_t out_offset_k = seq * out_stride_ks;
+      int64_t in_offset_q = seq * query_stride_s + head_id * query_stride_h;
+      int64_t out_offset_q = seq * query_out_stride_s + head_id * query_out_stride_h;
+      int64_t out_offset_k = seq * key_out_stride_s;
       int64_t p = 0;
       scalar_t* sin_start = nullptr;
       scalar_t* cos_start = nullptr;
       // step 0) get the rotary position embedding for the current position
-      p = t_pos[seq];
-      sin_start = t_emb_pos + p * HR + COFF;
-      cos_start = t_emb_pos + p * HR;
+      p = positions[seq];
+      sin_start = cos_sin_cache + p * HR + COFF;
+      cos_start = cos_sin_cache + p * HR;
       // step 1) apply_rotary_pos_emb for the rotary_dim elements in every
       // head of query/key
       for (int64_t h = 0; h < rotary_dim; h += 2) {
         scalar_t cos = cos_start[h >> 1];
         scalar_t sin = sin_start[h >> 1];
-        scalar_t in1 = q_pe[in_offset_q + h];
-        scalar_t in2 = q_pe[in_offset_q + h + 1];
+        scalar_t in1 = query[in_offset_q + h];
+        scalar_t in2 = query[in_offset_q + h + 1];
         scalar_t out1 = in1 * cos - in2 * sin;
         scalar_t out2 = in2 * cos + in1 * sin;
-        q_pe_out[out_offset_q + h] = out1;
-        q_pe_out[out_offset_q + h + 1] = out2;
+        query_out[out_offset_q + h] = out1;
+        query_out[out_offset_q + h + 1] = out2;
       }
       for (int64_t h = 0; h < HK; h += 2) {
         scalar_t cos = cos_start[h >> 1];
         scalar_t sin = sin_start[h >> 1];
-        int64_t k_pe_offset = seq * k_pe_stride_s;
-        scalar_t in1_k = k_pe[k_pe_offset + h];
-        scalar_t in2_k = k_pe[k_pe_offset + h + 1];
+        int64_t k_pe_offset = seq * key_stride_s;
+        scalar_t in1_k = key[k_pe_offset + h];
+        scalar_t in2_k = key[k_pe_offset + h + 1];
         scalar_t out1_k = in1_k * cos - in2_k * sin;
         scalar_t out2_k = in2_k * cos + in1_k * sin;
-        k_pe_out[out_offset_k + h] = out1_k;
-        k_pe_out[out_offset_k + h + 1] = out2_k;
+        key_out[out_offset_k + h] = out1_k;
+        key_out[out_offset_k + h + 1] = out2_k;
       }
       // move to the next index
-      data_index_step(seq, seq_len, head_id, num_head);
+      data_index_step(seq, num_tokens, head_id, num_heads);
     }
   });
 }
+
+template <typename scalar_t>
+void rotary_embedding_neox_2D_kernel_impl(
+    int64_t* __restrict__ positions,
+    scalar_t* __restrict__ query,
+    scalar_t* __restrict__ key,
+    scalar_t* __restrict__ cos_sin_cache,
+    int64_t rotary_dim,
+    int64_t query_stride_s,
+    int64_t key_stride_s,
+    int64_t num_heads,
+    int64_t num_kv_heads,
+    int64_t head_size,
+    int64_t num_tokens) {
+  using bVec = at::vec::Vectorized<scalar_t>;
+  using fVec = at::vec::Vectorized<float>;
+  constexpr int64_t bVecSize = bVec::size();
+
+  int64_t embed_dim = rotary_dim / 2;
+  bool flag = (embed_dim % bVecSize == 0);
+  int64_t loop_upper = flag ? embed_dim : embed_dim - bVecSize;
+
+  auto compute_loop = [&](int64_t token_head, scalar_t* cache_ptr, scalar_t* qk) {
+    int64_t j = 0;
+    for (; j < loop_upper; j += bVecSize) {
+      int64_t rot_offset = j;
+      int64_t x_index = rot_offset;
+      int64_t y_index = embed_dim + rot_offset;
+
+      int64_t out_x = token_head + x_index;
+      int64_t out_y = token_head + y_index;
+
+      bVec _cos = bVec::loadu(cache_ptr + x_index);
+      bVec _sin = bVec::loadu(cache_ptr + y_index);
+
+      bVec _q_x = bVec::loadu(qk + out_x);
+      bVec _q_y = bVec::loadu(qk + out_y);
+      fVec _cos_0, _cos_1;
+      std::tie(_cos_0, _cos_1) = at::vec::convert_to_float(_cos);
+      fVec _sin_0, _sin_1;
+      std::tie(_sin_0, _sin_1) = at::vec::convert_to_float(_sin);
+      fVec _q_x_0, _q_x_1;
+      std::tie(_q_x_0, _q_x_1) = at::vec::convert_to_float(_q_x);
+      fVec _q_y_0, _q_y_1;
+      std::tie(_q_y_0, _q_y_1) = at::vec::convert_to_float(_q_y);
+
+      auto out1_0 = _q_x_0 * _cos_0 - _q_y_0 * _sin_0;
+      auto out1_1 = _q_x_1 * _cos_1 - _q_y_1 * _sin_1;
+      auto out1 = convert_from_float_ext(out1_0, out1_1);
+      out1.store(qk + out_x);
+
+      auto out2_0 = _q_y_0 * _cos_0 + _q_x_0 * _sin_0;
+      auto out2_1 = _q_y_1 * _cos_1 + _q_x_1 * _sin_1;
+      auto out2 = convert_from_float_ext(out2_0, out2_1);
+      out2.store(qk + out_y);
+    }
+    if (!flag) {
+      for (; j < embed_dim; ++j) {
+        int64_t x_index = j;
+        int64_t y_index = embed_dim + j;
+
+        int64_t out_x = token_head + x_index;
+        int64_t out_y = token_head + y_index;
+
+        float _cos = cache_ptr[x_index];
+        float _sin = cache_ptr[y_index];
+
+        float _q_x = qk[out_x];
+        float _q_y = qk[out_y];
+
+        qk[out_x] = _q_x * _cos - _q_y * _sin;
+        qk[out_y] = _q_y * _cos + _q_x * _sin;
+      }
+    }
+  };
+
+#pragma omp parallel for
+  for (int64_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
+    int64_t pos = positions[token_idx];
+    scalar_t* cache_ptr = cos_sin_cache + pos * rotary_dim;
+
+    for (int64_t i = 0; i < num_heads; ++i) {
+      int64_t head_idx = i;
+      int64_t token_head = token_idx * query_stride_s + head_idx * head_size;
+      compute_loop(token_head, cache_ptr, query);
+    }
+
+    for (int64_t i = 0; i < num_kv_heads; ++i) {
+      int64_t head_idx = i;
+      int64_t token_head = token_idx * key_stride_s + head_idx * head_size;
+      compute_loop(token_head, cache_ptr, key);
+    }
+  }
+}
+
+template <typename scalar_t>
+void rotary_embedding_2D_kernel_impl(
+    int64_t* __restrict__ positions,
+    scalar_t* __restrict__ query,
+    scalar_t* __restrict__ key,
+    scalar_t* __restrict__ cos_sin_cache,
+    int64_t rotary_dim,
+    int64_t query_stride_s,
+    int64_t key_stride_s,
+    int64_t num_heads,
+    int64_t num_kv_heads,
+    int64_t head_size,
+    int64_t num_tokens) {
+  int64_t embed_dim = rotary_dim / 2;
+
+  at::parallel_for(0, num_tokens * num_heads, GRAIN_SIZE / rotary_dim, [&](int64_t begin, int64_t end) {
+    int64_t token_idx = {0}, i = {0};
+    data_index_init(begin, token_idx, num_tokens, i, num_heads);
+    for ([[maybe_unused]] auto z : c10::irange(begin, end)) {
+      int64_t pos = positions[token_idx];
+      scalar_t* cache_ptr = cos_sin_cache + pos * rotary_dim;
+      scalar_t* cos_cache_ptr = cache_ptr;
+      scalar_t* sin_cache_ptr = cache_ptr + embed_dim;
+      int64_t head_idx = i;
+      int64_t token_head = token_idx * query_stride_s + head_idx * head_size;
+      scalar_t* head_query = token_head + query;
+      for (int64_t j = 0; j < embed_dim; j += 1) {
+        int64_t rot_offset = j;
+        int64_t x_index = 2 * rot_offset;
+        int64_t y_index = 2 * rot_offset + 1;
+
+        float cos = cos_cache_ptr[rot_offset];
+        float sin = sin_cache_ptr[rot_offset];
+
+        float x = head_query[x_index];
+        float y = head_query[y_index];
+
+        head_query[x_index] = x * cos - y * sin;
+        head_query[y_index] = y * cos + x * sin;
+      }
+      data_index_step(token_idx, num_tokens, i, num_heads);
+    }
+  });
+
+  at::parallel_for(0, num_tokens * num_kv_heads, GRAIN_SIZE / rotary_dim, [&](int64_t begin, int64_t end) {
+    int64_t token_idx{0}, i = {0};
+    data_index_init(begin, token_idx, num_tokens, i, num_kv_heads);
+    for ([[maybe_unused]] auto z : c10::irange(begin, end)) {
+      int64_t pos = positions[token_idx];
+      scalar_t* cache_ptr = cos_sin_cache + pos * rotary_dim;
+      scalar_t* cos_cache_ptr = cache_ptr;
+      scalar_t* sin_cache_ptr = cache_ptr + embed_dim;
+      int64_t head_idx = i;
+      int64_t token_head = token_idx * key_stride_s + head_idx * head_size;
+      scalar_t* head_key = key + token_head;
+      for (int64_t j = 0; j < embed_dim; j += 1) {
+        int64_t rot_offset = j;
+        int64_t x_index = 2 * rot_offset;
+        int64_t y_index = 2 * rot_offset + 1;
+
+        float cos = cos_cache_ptr[rot_offset];
+        float sin = sin_cache_ptr[rot_offset];
+
+        float x = head_key[x_index];
+        float y = head_key[y_index];
+
+        head_key[x_index] = x * cos - y * sin;
+        head_key[y_index] = y * cos + x * sin;
+      }
+      data_index_step(token_idx, num_tokens, i, num_kv_heads);
+    }
+  });
+}
+
 } // namespace
 
-std::tuple<at::Tensor, at::Tensor>
-rotary_position_embedding_cpu(at::Tensor& t_pos, at::Tensor& q_pe, at::Tensor& k_pe, at::Tensor& t_emb_pos) {
-  RECORD_FUNCTION(
-      "sgl-kernel::rotary_position_embedding_cpu", std::vector<c10::IValue>({t_pos, q_pe, k_pe, t_emb_pos}));
-  CHECK_INPUT(t_pos);
-  CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_pe);
-  CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_pe);
-  CHECK_INPUT(t_emb_pos);
-  CHECK_DIM(1, t_pos);
-  CHECK_DIM(3, q_pe);
-  CHECK_DIM(3, k_pe);
-  CHECK_DIM(2, t_emb_pos);
-
-  int64_t seq_len = q_pe.size(0);
-  int64_t num_head = q_pe.size(1);
-  int64_t rotary_dim = q_pe.size(2);
-  int64_t HK = k_pe.size(2);
-  int64_t HR = t_emb_pos.size(1);
-  CHECK_EQ(HR, rotary_dim);
-  CHECK_EQ(k_pe.size(0), seq_len);
-  CHECK_EQ(k_pe.size(1), 1);
-  CHECK_EQ(t_pos.size(0), seq_len);
-  CHECK_EQ(HK, rotary_dim);
-
-  at::Tensor q_pe_out = at::empty_like(q_pe);
-  at::Tensor k_pe_out = at::empty_like(k_pe);
-  int64_t q_pe_stride_s = q_pe.stride(0);
-  int64_t q_pe_stride_n = q_pe.stride(1);
-  int64_t k_pe_stride_s = k_pe.stride(0);
-  int64_t out_stride_qs = q_pe_out.stride(0);
-  int64_t out_stride_qn = q_pe_out.stride(1);
-  int64_t out_stride_ks = k_pe_out.stride(0);
-
-  const auto input_dtype = q_pe.scalar_type();
-  TORCH_CHECK(t_pos.scalar_type() == at::kLong, "expect positions to be int64, got ", t_pos.scalar_type());
-  TORCH_CHECK(input_dtype == k_pe.scalar_type(), "q_pe and k_pe must have the same data type");
-  TORCH_CHECK(input_dtype == t_emb_pos.scalar_type(), "q_pe and t_emb_pos must have the same data type");
-
-  AT_DISPATCH_REDUCED_FLOATING_TYPES(input_dtype, "rotary_position_embedding_cpu", [&] {
-    rope_kernel_impl<scalar_t>(
-        q_pe_out.data_ptr<scalar_t>(),
-        k_pe_out.data_ptr<scalar_t>(),
-        t_pos.data_ptr<int64_t>(),
-        q_pe.data_ptr<scalar_t>(),
-        k_pe.data_ptr<scalar_t>(),
-        t_emb_pos.data_ptr<scalar_t>(),
-        seq_len,
-        num_head,
-        rotary_dim,
-        HR,
-        q_pe_stride_s,
-        out_stride_qs,
-        out_stride_ks,
-        HK,
-        k_pe_stride_s,
-        q_pe_stride_n,
-        out_stride_qn);
+std::tuple<at::Tensor, at::Tensor> rotary_embedding_cpu(
+    at::Tensor& positions,
+    at::Tensor& query,
+    at::Tensor& key,
+    int64_t head_size,
+    at::Tensor& cos_sin_cache,
+    bool is_neox) {
+  RECORD_FUNCTION("sgl-kernel::rotary_embedding_cpu", std::vector<c10::IValue>({query, key}));
+  CHECK_DIM(1, positions);
+  const auto input_dim = query.dim();
+  const auto input_dtype = query.scalar_type();
+  TORCH_CHECK(
+      input_dim == 2 || input_dim == 3,
+      " Query/Key must be 2D [num_tokens, num_heads*head_size] or 3D [num_tokens, num_heads, head_size] tensor");
+  CHECK_DIM(2, cos_sin_cache);
+  CHECK_LAST_DIM_CONTIGUOUS_INPUT(query);
+  CHECK_LAST_DIM_CONTIGUOUS_INPUT(key);
+
+  int64_t rotary_dim = cos_sin_cache.size(1);
+  if (input_dim == 3) {
+    // TODO: add support for head_dim != rotary_dim case when input_dim=3
+    CHECK_EQ(query.size(-1), rotary_dim);
+    // TODO: add support for kv_head != 1
+    CHECK_EQ(key.size(1), 1);
+  }
+
+  int64_t num_tokens = positions.numel();
+  CHECK_EQ(key.size(0), num_tokens);
+  CHECK_EQ(query.size(0), num_tokens);
+
+  TORCH_CHECK(positions.scalar_type() == at::kLong, "expect positions to be int64, got ", positions.scalar_type());
+  TORCH_CHECK(input_dtype == key.scalar_type(), "query and key must have the same data type");
+  TORCH_CHECK(input_dtype == cos_sin_cache.scalar_type(), "query and cos_sin_cache must have the same data type");
+
+  int64_t num_heads = input_dim == 2 ? query.size(-1) / head_size : query.size(1);
+  int64_t num_kv_heads = input_dim == 2 ? key.size(-1) / head_size : key.size(1);
+  int64_t key_stride_s = key.stride(0);
+  int64_t query_stride_s = query.stride(0);
+
+  // input stride of num head dim is meaningful only when input dim = 3
+  int64_t query_stride_h = input_dim == 3 ? query.stride(1) : -1;
+  at::Tensor query_out = at::empty_like(query);
+  at::Tensor key_out = at::empty_like(key);
+  int64_t query_out_stride_s = query_out.stride(0);
+  int64_t key_out_stride_s = key_out.stride(0);
+  // output stride of num head dim is meaningful only when input dim = 3
+  int64_t query_out_stride_h = input_dim == 3 ? query_out.stride(1) : -1;
+
+  AT_DISPATCH_REDUCED_FLOATING_TYPES(input_dtype, "rotary_embedding_cpu", [&] {
+    if (input_dim == 2) {
+      if (is_neox) {
+        rotary_embedding_neox_2D_kernel_impl<scalar_t>(
+            positions.data_ptr<int64_t>(),
+            query.data_ptr<scalar_t>(),
+            key.data_ptr<scalar_t>(),
+            cos_sin_cache.data_ptr<scalar_t>(),
+            rotary_dim,
+            query_stride_s,
+            key_stride_s,
+            num_heads,
+            num_kv_heads,
+            head_size,
+            num_tokens);
+      } else {
+        rotary_embedding_2D_kernel_impl<scalar_t>(
+            positions.data_ptr<int64_t>(),
+            query.data_ptr<scalar_t>(),
+            key.data_ptr<scalar_t>(),
+            cos_sin_cache.data_ptr<scalar_t>(),
+            rotary_dim,
+            query_stride_s,
+            key_stride_s,
+            num_heads,
+            num_kv_heads,
+            head_size,
+            num_tokens);
+      }
+      query_out = query;
+      key_out = key;
+
+    } else {
+      TORCH_CHECK(
+          is_neox == false, " Query/Key with 3D [num_tokens, num_heads, head_size] does not support neox rope yet");
+      // TODO: add neox style support for rope impl with 3D inputs
+      rotary_embedding_3D_kernel_impl<scalar_t>(
+          query_out.data_ptr<scalar_t>(),
+          key_out.data_ptr<scalar_t>(),
+          positions.data_ptr<int64_t>(),
+          query.data_ptr<scalar_t>(),
+          key.data_ptr<scalar_t>(),
+          cos_sin_cache.data_ptr<scalar_t>(),
+          num_tokens,
+          num_heads,
+          num_kv_heads,
+          head_size,
+          rotary_dim,
+          query_stride_s,
+          query_out_stride_s,
+          key_out_stride_s,
+          key_stride_s,
+          query_stride_h,
+          query_out_stride_h);
+    }
   });
-  return std::make_tuple(q_pe_out, k_pe_out);
+  return std::make_tuple(query_out, key_out);
 }
diff --git a/sgl-kernel/csrc/cpu/topk.cpp b/sgl-kernel/csrc/cpu/topk.cpp
index 6a6b64d12ab..da8639a35c3 100644
--- a/sgl-kernel/csrc/cpu/topk.cpp
+++ b/sgl-kernel/csrc/cpu/topk.cpp
@@ -157,6 +157,101 @@ inline void sigmoid(float* __restrict__ out, const scalar_t* __restrict__ input)
   }
 }
 
+template <typename scalar_t, int NUM_EXPERTS>
+void topk_sigmoid_kernel_impl(
+    float* __restrict__ topk_weights,
+    int32_t* __restrict__ topk_ids,
+    const scalar_t* __restrict__ gating_output,
+    int64_t num_tokens,
+    int64_t topk,
+    bool renormalize) {
+  using Vec = at::vec::Vectorized<float>;
+  const int64_t num_experts_per_group = NUM_EXPERTS;
+  at::parallel_for(0, num_tokens, 0, [&](int64_t begin, int64_t end) {
+    alignas(64) float scores[NUM_EXPERTS];
+    using elem_t = std::pair<float, int32_t>;
+    std::vector<elem_t> queue(num_experts_per_group);
+
+    for (int64_t i = begin; i < end; ++i) {
+      at::vec::convert(gating_output + i * NUM_EXPERTS, scores, NUM_EXPERTS);
+
+      float gmax = at::vec::reduce_all<float>(
+          [](Vec& x, Vec& y) { return at::vec::maximum(x, y); }, scores, num_experts_per_group);
+
+      // find position of first max,
+      // note that we may have multiple max values.
+      int first_max_idx = -1;
+      for (int64_t e = 0; e < num_experts_per_group; ++e) {
+        if (scores[e] == gmax) {
+          first_max_idx = e;
+          break;
+        }
+      }
+
+      // scalar sigmoid
+      topk_weights[i] = 1.0 / (1.0 + exp(0.0 - gmax));
+      topk_ids[i] = first_max_idx;
+
+      if (renormalize) {
+        float sum = 0.f;
+        for (int64_t j = 0; j < topk; ++j) {
+          sum += topk_weights[i * topk + j];
+        }
+        float scale = 1.f / sum;
+        for (int64_t j = 0; j < topk; ++j) {
+          topk_weights[i * topk + j] *= scale;
+        }
+      }
+    }
+  });
+}
+
+template <typename scalar_t, int NUM_EXPERTS>
+void topk_softmax_kernel_impl(
+    float* __restrict__ topk_weights,
+    int32_t* __restrict__ topk_ids,
+    const scalar_t* __restrict__ gating_output,
+    int64_t num_tokens,
+    int64_t topk,
+    bool renormalize) {
+  const int64_t num_experts_per_group = NUM_EXPERTS;
+  at::parallel_for(0, num_tokens, 0, [&](int64_t begin, int64_t end) {
+    alignas(64) float scores[NUM_EXPERTS];
+    using elem_t = std::pair<float, int32_t>;
+    std::vector<elem_t> queue(num_experts_per_group);
+
+    for (int64_t i = begin; i < end; ++i) {
+      softmax(scores, gating_output + i * NUM_EXPERTS);
+
+      for (int64_t e = 0; e < num_experts_per_group; ++e) {
+        queue[e] = {scores[e], e};
+      }
+
+      std::partial_sort(
+          queue.begin(),
+          queue.begin() + num_experts_per_group,
+          queue.end(),
+          [](const elem_t& x, const elem_t& y) -> bool { return x.first > y.first; });
+
+      for (int64_t j = 0; j < topk; ++j) {
+        topk_weights[i * topk + j] = queue[j].first;
+        topk_ids[i * topk + j] = queue[j].second;
+      }
+
+      if (renormalize) {
+        float sum = 0.f;
+        for (int64_t j = 0; j < topk; ++j) {
+          sum += topk_weights[i * topk + j];
+        }
+        float scale = 1.f / sum;
+        for (int64_t j = 0; j < topk; ++j) {
+          topk_weights[i * topk + j] *= scale;
+        }
+      }
+    }
+  });
+}
+
 template <typename scalar_t, int NUM_EXPERTS>
 inline void
 apply_bias(float* __restrict__ scores2, const float* __restrict__ scores, const scalar_t* __restrict__ bias) {
@@ -293,6 +388,24 @@ void biased_grouped_topk_kernel_impl(
       topk_group, \
       renormalize);
 
+#define LAUNCH_TOPK_SIGMOID_KERNEL(NE) \
+  topk_sigmoid_kernel_impl<scalar_t, NE>( \
+      topk_weights.data_ptr<float>(), \
+      topk_ids.data_ptr<int32_t>(), \
+      gating_output.data_ptr<scalar_t>(), \
+      num_tokens, \
+      topk, \
+      renormalize);
+
+#define LAUNCH_TOPK_SOFTMAX_KERNEL(NE) \
+  topk_softmax_kernel_impl<scalar_t, NE>( \
+      topk_weights.data_ptr<float>(), \
+      topk_ids.data_ptr<int32_t>(), \
+      gating_output.data_ptr<scalar_t>(), \
+      num_tokens, \
+      topk, \
+      renormalize);
+
 #define LAUNCH_BIASED_GROUPED_TOPK_KERNEL(NE, NTOPK) \
   biased_grouped_topk_kernel_impl<scalar_t, NE, NTOPK>( \
       topk_weights.data_ptr<float>(), \
@@ -306,6 +419,114 @@ void biased_grouped_topk_kernel_impl(
 
 } // anonymous namespace
 
+std::tuple<at::Tensor, at::Tensor>
+topk_sigmoid_cpu(at::Tensor& hidden_states, at::Tensor& gating_output, int64_t topk, bool renormalize) {
+  RECORD_FUNCTION("sgl-kernel::topk_sigmoid_cpu", std::vector<c10::IValue>({hidden_states, gating_output}));
+  CHECK_INPUT(gating_output);
+
+  const auto st = hidden_states.scalar_type();
+  CHECK_EQ(gating_output.scalar_type(), st);
+
+  int64_t num_tokens = hidden_states.size(0);
+  int64_t num_experts = gating_output.size(1);
+  TORCH_CHECK(gating_output.size(0) == num_tokens, "Number of tokens mismatch");
+  TORCH_CHECK(topk == 1, "topk_sigmoid only supports topk=1 case");
+  at::Tensor topk_weights = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kFloat));
+  at::Tensor topk_ids = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kInt));
+
+  AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "topk_sigmoid_kernel", [&] {
+    switch (num_experts) {
+      case 1:
+        LAUNCH_TOPK_SIGMOID_KERNEL(1);
+        break;
+      case 2:
+        LAUNCH_TOPK_SIGMOID_KERNEL(2);
+        break;
+      case 4:
+        LAUNCH_TOPK_SIGMOID_KERNEL(4);
+        break;
+      case 8:
+        LAUNCH_TOPK_SIGMOID_KERNEL(8);
+        break;
+      case 16:
+        LAUNCH_TOPK_SIGMOID_KERNEL(16);
+        break;
+      case 32:
+        LAUNCH_TOPK_SIGMOID_KERNEL(32);
+        break;
+      case 64:
+        LAUNCH_TOPK_SIGMOID_KERNEL(64);
+        break;
+      case 128:
+        LAUNCH_TOPK_SIGMOID_KERNEL(128);
+        break;
+      case 160:
+        LAUNCH_TOPK_SIGMOID_KERNEL(160);
+        break;
+      case 256:
+        LAUNCH_TOPK_SIGMOID_KERNEL(256);
+        break;
+      default:
+        TORCH_CHECK(false, "Unexpected num_experts: ", num_experts);
+    }
+  });
+  return std::make_tuple(topk_weights, topk_ids);
+}
+
+std::tuple<at::Tensor, at::Tensor>
+topk_softmax_cpu(at::Tensor& hidden_states, at::Tensor& gating_output, int64_t topk, bool renormalize) {
+  RECORD_FUNCTION("sgl-kernel::topk_softmax_cpu", std::vector<c10::IValue>({hidden_states, gating_output}));
+  CHECK_INPUT(gating_output);
+
+  const auto st = hidden_states.scalar_type();
+  CHECK_EQ(gating_output.scalar_type(), st);
+
+  int64_t num_tokens = hidden_states.size(0);
+  int64_t num_experts = gating_output.size(1);
+  TORCH_CHECK(gating_output.size(0) == num_tokens, "Number of tokens mismatch");
+
+  at::Tensor topk_weights = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kFloat));
+  at::Tensor topk_ids = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kInt));
+
+  AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "topk_softmax_cpu", [&] {
+    switch (num_experts) {
+      case 1:
+        LAUNCH_TOPK_SOFTMAX_KERNEL(1);
+        break;
+      case 2:
+        LAUNCH_TOPK_SOFTMAX_KERNEL(2);
+        break;
+      case 4:
+        LAUNCH_TOPK_SOFTMAX_KERNEL(4);
+        break;
+      case 8:
+        LAUNCH_TOPK_SOFTMAX_KERNEL(8);
+        break;
+      case 16:
+        LAUNCH_TOPK_SOFTMAX_KERNEL(16);
+        break;
+      case 32:
+        LAUNCH_TOPK_SOFTMAX_KERNEL(32);
+        break;
+      case 64:
+        LAUNCH_TOPK_SOFTMAX_KERNEL(64);
+        break;
+      case 128:
+        LAUNCH_TOPK_SOFTMAX_KERNEL(128);
+        break;
+      case 160:
+        LAUNCH_TOPK_SOFTMAX_KERNEL(160);
+        break;
+      case 256:
+        LAUNCH_TOPK_SOFTMAX_KERNEL(256);
+        break;
+      default:
+        TORCH_CHECK(false, "Unexpected num_experts: ", num_experts);
+    }
+  });
+  return std::make_tuple(topk_weights, topk_ids);
+}
+
 // grouped topk for DeepSeek V2
 std::tuple<at::Tensor, at::Tensor> grouped_topk_cpu(
     at::Tensor& hidden_states,
     at::Tensor& gating_output,
diff --git a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
index 018f8efb846..b718fc1b1e9 100644
--- a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
+++ b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
@@ -23,6 +23,9 @@ limitations under the License.
 // silu_and_mul
 at::Tensor silu_and_mul_cpu(at::Tensor& input);
 
+// l2norm
+at::Tensor l2norm_cpu(at::Tensor& input, double eps);
+
 // rmsnorm
 at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps);
 
@@ -30,6 +33,11 @@ at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps);
 void fused_add_rmsnorm_cpu(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps);
 
 // topk
+std::tuple<at::Tensor, at::Tensor>
+topk_sigmoid_cpu(at::Tensor& hidden_states, at::Tensor& gating_output, int64_t topk, bool renormalize);
+std::tuple<at::Tensor, at::Tensor>
+topk_softmax_cpu(at::Tensor& hidden_states, at::Tensor& gating_output, int64_t topk, bool renormalize);
+
 std::tuple<at::Tensor, at::Tensor> grouped_topk_cpu(
     at::Tensor& hidden_states,
     at::Tensor& gating_output,
@@ -185,8 +193,13 @@ void shm_allreduce(
 at::Tensor shm_allgather(at::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, int64_t dim);
 
 // rope
-std::tuple<at::Tensor, at::Tensor>
-rotary_position_embedding_cpu(at::Tensor& t_pos, at::Tensor& q_pe, at::Tensor& k_pe, at::Tensor& t_emb_pos);
+std::tuple<at::Tensor, at::Tensor> rotary_embedding_cpu(
+    at::Tensor& positions,
+    at::Tensor& query,
+    at::Tensor& key,
+    int64_t head_size,
+    at::Tensor& cos_sin_cache,
+    bool is_neox);
 
 TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   // activation
@@ -196,10 +209,16 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   // norm
   m.def("rmsnorm_cpu(Tensor input, Tensor weight, float eps) -> Tensor");
   m.impl("rmsnorm_cpu", torch::kCPU, &rmsnorm_cpu);
+  m.def("l2norm_cpu(Tensor input, float eps) -> Tensor");
+  m.impl("l2norm_cpu", torch::kCPU, &l2norm_cpu);
   m.def("fused_add_rmsnorm_cpu(Tensor input, Tensor residual, Tensor weight, float eps) -> ()");
   m.impl("fused_add_rmsnorm_cpu", torch::kCPU, &fused_add_rmsnorm_cpu);
 
   // topk
+  m.def("topk_sigmoid_cpu(Tensor hidden_states, Tensor gating_output, int topk, bool renormalize) -> (Tensor, Tensor)");
+  m.impl("topk_sigmoid_cpu", torch::kCPU, &topk_sigmoid_cpu);
+  m.def("topk_softmax_cpu(Tensor hidden_states, Tensor gating_output, int topk, bool renormalize) -> (Tensor, Tensor)");
+  m.impl("topk_softmax_cpu", torch::kCPU, &topk_softmax_cpu);
   m.def(
       "grouped_topk_cpu(Tensor hidden_states, Tensor gating_output, int topk, bool renormalize, int num_expert_group, "
       "int topk_group) -> (Tensor, Tensor)");
@@ -294,8 +313,10 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.impl("shm_allgather", torch::kCPU, &shm_allgather);
 
   // rope
-  m.def("rotary_position_embedding_cpu(Tensor t_pos, Tensor q_pe, Tensor k_pe, Tensor t_emb_pos) -> (Tensor, Tensor)");
-  m.impl("rotary_position_embedding_cpu", torch::kCPU, &rotary_position_embedding_cpu);
+  m.def(
+      "rotary_embedding_cpu(Tensor positions, Tensor query, Tensor key, int head_size, Tensor cos_sin_cache, "
+      "bool is_neox) -> (Tensor, Tensor)");
+  m.impl("rotary_embedding_cpu", torch::kCPU, &rotary_embedding_cpu);
 }
 
 REGISTER_EXTENSION(common_ops)
diff --git a/test/srt/cpu/test_norm.py b/test/srt/cpu/test_norm.py
index 8af46c6a195..b7d139a5be5 100644
--- a/test/srt/cpu/test_norm.py
+++ b/test/srt/cpu/test_norm.py
@@ -63,10 +63,24 @@ def _norm_test(self, m, n, dtype):
         self.assertTrue(torch.allclose(x, ref_x, atol=atol, rtol=rtol))
         self.assertTrue(torch.allclose(residual, ref_residual, atol=atol, rtol=rtol))
 
+    def _l2norm_test(self, m, n, dtype):
+
+        x = torch.randn([m, n], dtype=dtype)
+        hidden_size = x.size(-1)
+        fake_ones_weight = torch.ones(hidden_size, dtype=dtype)
+        variance_epsilon = 1e-6
+
+        out = torch.ops.sgl_kernel.l2norm_cpu(x, variance_epsilon)
+        ref_out = self._forward_native(x, fake_ones_weight, variance_epsilon)
+
+        atol = rtol = precision[ref_out.dtype]
+        self.assertTrue(torch.allclose(ref_out, out, atol=atol, rtol=rtol))
+
     def test_norm(self):
         for params in itertools.product(self.M, self.N, self.dtype):
             with self.subTest(m=params[0], n=params[1], dtype=params[2]):
                 self._norm_test(*params)
+                self._l2norm_test(*params)
 
 
 if __name__ == "__main__":
diff --git a/test/srt/cpu/test_rope.py b/test/srt/cpu/test_rope.py
index 33b6fc6238e..b9c5da42b11 100644
--- a/test/srt/cpu/test_rope.py
+++ b/test/srt/cpu/test_rope.py
@@ -4,7 +4,10 @@
 import torch
 from utils import precision
 
-from sglang.srt.layers.rotary_embedding import DeepseekScalingRotaryEmbedding
+from sglang.srt.layers.rotary_embedding import (
+    DeepseekScalingRotaryEmbedding,
+    RotaryEmbedding,
+)
 from sglang.test.test_utils import CustomTestCase
 
 
@@ -62,10 +65,13 @@
         )
 
         # fused rope kernel
-        q_pe_clone, k_pe_clone = (
-            torch.ops.sgl_kernel.rotary_position_embedding_cpu(
-                positions, q_pe_clone, k_pe_clone, cos_sin_cache
-            )
+        q_pe_clone, k_pe_clone = torch.ops.sgl_kernel.rotary_embedding_cpu(
+            positions,
+            q_pe_clone,
+            k_pe_clone,
+            rope.head_size,
+            cos_sin_cache,
+            False,
         )
 
         atol = rtol = precision[q_pe.dtype]
@@ -73,6 +79,98 @@
         self.assertTrue(torch.allclose(k_pe, k_pe_clone, atol=atol, rtol=rtol))
         torch.testing.assert_close(k_pe, k_pe_clone)
 
+    def test_origin_rope(self):
+        def single_test(
+            head_size: int,
+            rotary_dim: int,
+            max_position_embeddings: int,
+            base: int,
+            is_neox_style: bool,
+            dtype: torch.dtype,
+            device: str,
+            batch_size: int,
+            seq_len: int,
+            num_q_heads: int,
+            num_kv_heads: int,
+        ):
+            torch.manual_seed(100)
+            rope_ref = RotaryEmbedding(
+                head_size,
+                rotary_dim,
+                max_position_embeddings,
+                base,
+                is_neox_style,
+                dtype,
+            ).to(device)
+            pos_ids = torch.arange(seq_len, device=device).repeat(batch_size)
+            query = torch.randn(
+                batch_size * seq_len,
+                num_q_heads * head_size,
+                dtype=dtype,
+                device=device,
+            )
+            key = torch.randn(
+                batch_size * seq_len,
+                num_kv_heads * head_size,
+                dtype=dtype,
+                device=device,
+            )
+
+            query_ref, key_ref = query.clone(), key.clone()
+            query_cpu, key_cpu = query.clone(), key.clone()
+
+            query_ref_out, key_ref_out = rope_ref.forward_native(
+                pos_ids, query_ref, key_ref
+            )
+            query_cpu_out, key_cpu_out = torch.ops.sgl_kernel.rotary_embedding_cpu(
+                pos_ids,
+                query_cpu,
+                key_cpu,
+                rope_ref.head_size,
+                rope_ref.cos_sin_cache.to(query.dtype),
+                rope_ref.is_neox_style,
+            )
+            torch.testing.assert_close(
+                query_ref_out, query_cpu_out, atol=1e-2, rtol=1e-2
+            )
+            torch.testing.assert_close(key_ref_out, key_cpu_out, atol=1e-2, rtol=1e-2)
+
+        test_config = [
+            (64, 64, 32, 8000, True, torch.bfloat16, "cpu", 32, 32, 1, 1),
+            (256, 128, 4096, 10000, True, torch.bfloat16, "cpu", 2, 512, 32, 8),
+            (512, 128, 311, 10000, True, torch.bfloat16, "cpu", 3, 39, 4, 2),
+            (128, 128, 2048, 10000, False, torch.bfloat16, "cpu", 2, 512, 32, 8),
+            (128, 128, 2048, 10000, False, torch.bfloat16, "cpu", 2, 512, 16, 4),
+            (512, 128, 311, 10000, False, torch.bfloat16, "cpu", 3, 39, 4, 2),
+        ]
+
+        for (
+            head_size,
+            rotary_dim,
+            max_position_embeddings,
+            base,
+            is_neox_style,
+            dtype,
+            device,
+            batch_size,
+            seq_len,
+            num_q_heads,
+            num_kv_heads,
+        ) in test_config:
+            single_test(
+                head_size,
+                rotary_dim,
+                max_position_embeddings,
+                base,
+                is_neox_style,
+                dtype,
+                device,
+                batch_size,
+                seq_len,
+                num_q_heads,
+                num_kv_heads,
+            )
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/test/srt/cpu/test_topk.py b/test/srt/cpu/test_topk.py
index 22c9e27848c..3d0138d9afe 100644
--- a/test/srt/cpu/test_topk.py
+++ b/test/srt/cpu/test_topk.py
@@ -8,7 +8,9 @@
 from sglang.srt.layers.moe.topk import (
     biased_grouped_topk_impl as native_biased_grouped_topk,
 )
+from sglang.srt.layers.moe.topk import fused_topk_native as native_fused_topk
 from sglang.srt.layers.moe.topk import grouped_topk as native_grouped_topk
+from sglang.srt.models.llama4 import Llama4MoE
 from sglang.test.test_utils import CustomTestCase
 
 
@@ -94,5 +96,86 @@ def test_biased_grouped_topk(self):
             self._run_single_test(122, 256, 8, 8, 2, renormalize, torch.bfloat16)
 
 
+class TestTopK(CustomTestCase):
+    def _run_single_test(self, M, E, topk, renormalize, dtype):
+        torch.manual_seed(1998)
+
+        # scale gating_output by M, otherwise bfloat16 values collapse to the same value after truncation
+        hidden_states = torch.randn(M, 100, dtype=dtype)
+        gating_output = torch.randn(M, E, dtype=dtype) * 2 * M
+
+        ref_topk_weights, ref_topk_ids = native_fused_topk(
+            hidden_states.float(),
+            gating_output.float(),
+            topk,
+            renormalize,
+        )
+
+        # fused version
+        topk_weights, topk_ids = torch.ops.sgl_kernel.topk_softmax_cpu(
+            hidden_states, gating_output, topk, renormalize
+        )
+
+        res = torch.zeros(M, E, dtype=torch.float)
+        ref = torch.zeros(M, E, dtype=torch.float)
+        res.scatter_(1, topk_ids.long(), topk_weights)
+        ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights)
+        torch.testing.assert_close(res, ref)
+
+    def test_topk(self):
+        for renormalize in [True, False]:
+            self._run_single_test(123, 8, 2, renormalize, torch.bfloat16)
+            self._run_single_test(123, 16, 3, renormalize, torch.bfloat16)
+            self._run_single_test(123, 32, 3, renormalize, torch.bfloat16)
+            self._run_single_test(123, 32, 3, renormalize, torch.bfloat16)
+            self._run_single_test(123, 64, 6, renormalize, torch.bfloat16)
+            self._run_single_test(123, 256, 4, renormalize, torch.bfloat16)
+            self._run_single_test(123, 160, 6, renormalize, torch.bfloat16)
+
+
+class TestCustomTopK(CustomTestCase):
+    def _run_single_test(
+        self, M, E, topk, renormalize, dtype, native_custom_f, fused_custom_f
+    ):
+        torch.manual_seed(16)
+
+        # scale gating_output by M, otherwise bfloat16 values collapse to the same value after truncation
+        hidden_states = torch.randn(M, 100, dtype=dtype)
+        gating_output = torch.randn(M, E, dtype=dtype) * 2 * M
+
+        ref_topk_weights, ref_topk_ids = native_custom_f(
+            hidden_states.float(),
+            gating_output.float(),
+            topk,
+            renormalize,
+        )
+
+        # fused version
+        topk_weights, topk_ids = fused_custom_f(
+            hidden_states, gating_output, topk, renormalize
+        )
+
+        res = torch.zeros(M, E, dtype=torch.float)
+        ref = torch.zeros(M, E, dtype=torch.float)
+        res.scatter_(1, topk_ids.long(), topk_weights)
+        ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights)
+        torch.testing.assert_close(res, ref)
+
+    def test_custom_topk(self):
+        test_custom_functions = [
+            (Llama4MoE.custom_routing_function, torch.ops.sgl_kernel.topk_sigmoid_cpu)
+        ]
+        for native_custom_f, fused_custom_f in test_custom_functions:
+            self._run_single_test(
+                123, 8, 1, False, torch.bfloat16, native_custom_f, fused_custom_f
+            )
+            self._run_single_test(
+                123, 16, 1, False, torch.bfloat16, native_custom_f, fused_custom_f
+            )
+            self._run_single_test(
+                123, 32, 1, False, torch.bfloat16, native_custom_f, fused_custom_f
+            )
+
+
 if __name__ == "__main__":
     unittest.main()
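
Reviewer note (illustrative only, not part of the patch): the two 2D kernels above implement the two common rotary layouts selected by `is_neox`. Both read cos values from the first half of a `cos_sin_cache` row and sin values from the second half; only the pairing of elements differs. A minimal PyTorch sketch of the math they vectorize, assuming `q` has a trailing `rotary_dim` axis and `cos`/`sin` each hold `rotary_dim // 2` values broadcastable against it (the function names here are hypothetical):

import torch

def rotate_neox(q, cos, sin):
    # NeoX / half-split layout (rotary_embedding_neox_2D_kernel_impl):
    # element j pairs with element j + rotary_dim // 2.
    x, y = q.chunk(2, dim=-1)
    return torch.cat([x * cos - y * sin, y * cos + x * sin], dim=-1)

def rotate_interleaved(q, cos, sin):
    # Interleaved layout (rotary_embedding_2D_kernel_impl):
    # even/odd elements (2j, 2j + 1) form a pair.
    x, y = q[..., 0::2], q[..., 1::2]
    out = torch.empty_like(q)
    out[..., 0::2] = x * cos - y * sin
    out[..., 1::2] = y * cos + x * sin
    return out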
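
Reviewer note (illustrative only, not part of the patch): for the supported `topk == 1` case, `topk_sigmoid_cpu` reduces to an argmax over the gating logits followed by a sigmoid on the winning logit, which is how the new test pairs it with `Llama4MoE.custom_routing_function`. A rough reference sketch, assuming ties resolve to the first maximum as in the kernel (the helper name is hypothetical):

import torch

def topk_sigmoid_ref(gating_output: torch.Tensor):
    # One expert per token: take the max logit and weight it by its sigmoid.
    vals, ids = gating_output.float().max(dim=-1, keepdim=True)
    return torch.sigmoid(vals), ids.to(torch.int32)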