Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions sgl-kernel/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,17 @@ Steps to add a new kernel:

1. When implementing kernels in [csrc](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc), only define pure CUDA files and C++ interfaces. If you need to use `Torch::tensor`, use `<torch/all.h>` instead of `<torch/extension.h>`. Using `<torch/extension.h>` will cause compilation errors when using SABI.

2. When creating torch extensions, simply add the function definition with `m.def`:
2. When creating torch extensions, add the function definition with `m.def`, and device binding with `m.impl`:
- Using torch.compile need `m.def` with schema, it helps auto capture the custom kernel. Reference: [How to add FakeTensor](https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit?tab=t.0#heading=h.ptttacy8y1u9)

- How to write schema: [Schema reference](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#func)

```cpp
m.def("register_graph_buffers", register_graph_buffers);
// We need def with schema here for torch.compile
m.def(
"bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, int "
"cublas_handle, int cuda_stream) -> ()");
m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8);
```

3. When exposing Python interfaces, avoid using kwargs in C++ interface kernels.
Expand Down Expand Up @@ -96,6 +104,8 @@ Steps to add a new kernel:

When integrating new third-party libraries like flash-attention, you may encounter data type compatibility issues between the C++ interface and PyTorch bindings. For example, the third-party code might use `float` or `int` types, while PyTorch requires `double` and `int64_t`.

> The reason we need `double` and `int64_t` in torch binding is that TORCH_LIBRARY handles the `Python-to-C++` conversion process. Python's `float` data type actually corresponds to `double` in C++, while Python's `int` corresponds to `int64_t` in C++.

To address this issue, we provide the `make_pytorch_shim` function in [sgl_kernel_torch_shim](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/include/sgl_kernel_torch_shim.h) that handles data type conversions automatically.

When you need to support new data type conversions, you can easily add conversion functions like this:
Expand All @@ -119,7 +129,7 @@ To use this with your library functions, simply wrap them with make_pytorch_shim
/*
* From flash-attention
*/
m.def("fwd", make_pytorch_shim(mha_fwd));
m.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));
```

### Build & Install
Expand Down
219 changes: 182 additions & 37 deletions sgl-kernel/csrc/torch_extension.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,49 +22,121 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
/*
* From csrc/allreduce
*/
m.def("init_custom_ar", init_custom_ar);
m.def("dispose", dispose);
m.def("all_reduce", all_reduce);
m.def("get_graph_buffer_ipc_meta", get_graph_buffer_ipc_meta);
m.def("register_graph_buffers", register_graph_buffers);

m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
m.def("register_graph_buffers", &register_graph_buffers);
m.def("dispose", &dispose);

m.def(
"init_custom_ar(int rank_id, int world_size, Tensor rank_data, int[] buffers, int[] tmp_result_buffers, int[] "
"barrier_in, int[] barrier_out) -> int");
m.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);

m.def("all_reduce(int fa, Tensor inp, Tensor! out) -> ()");
m.impl("all_reduce", torch::kCUDA, &all_reduce);

/*
* From csrc/attention
*/
m.def("lightning_attention_decode", lightning_attention_decode);
m.def(
"lightning_attention_decode(Tensor q, Tensor k, Tensor v, Tensor past_kv, Tensor slope, Tensor! output, Tensor! "
"new_kv) -> ()");
m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode);

/*
* From csrc/elementwise
*/
m.def("rmsnorm", rmsnorm);
m.def("fused_add_rmsnorm", sgl_fused_add_rmsnorm);
m.def("gemma_rmsnorm", gemma_rmsnorm);
m.def("gemma_fused_add_rmsnorm", gemma_fused_add_rmsnorm);
m.def("silu_and_mul", silu_and_mul);
m.def("gelu_tanh_and_mul", gelu_tanh_and_mul);
m.def("gelu_and_mul", gelu_and_mul);
m.def("apply_rope_pos_ids_cos_sin_cache", apply_rope_pos_ids_cos_sin_cache);
m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
m.impl("rmsnorm", torch::kCUDA, &rmsnorm);

m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps) -> ()");
m.impl("fused_add_rmsnorm", torch::kCUDA, &sgl_fused_add_rmsnorm);

m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
m.impl("gemma_rmsnorm", torch::kCUDA, &gemma_rmsnorm);

m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, int cuda_stream) -> ()");
m.impl("gemma_fused_add_rmsnorm", torch::kCUDA, &gemma_fused_add_rmsnorm);

m.def("silu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
m.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);

m.def("gelu_tanh_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
m.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);

m.def("gelu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
m.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);

m.def(
"apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, "
"Tensor pos_ids, bool interleave, int cuda_stream) -> ()");
m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache);

/*
* From csrc/gemm
*/
m.def("awq_dequantize", awq_dequantize);
m.def("int8_scaled_mm", int8_scaled_mm);
m.def("fp8_scaled_mm", fp8_scaled_mm);
m.def("fp8_blockwise_scaled_mm", fp8_blockwise_scaled_mm);
m.def("sgl_per_token_group_quant_fp8", sgl_per_token_group_quant_fp8);
m.def("sgl_per_token_group_quant_int8", sgl_per_token_group_quant_int8);
m.def("sgl_per_tensor_quant_fp8", sgl_per_tensor_quant_fp8);
m.def("sgl_per_token_quant_fp8", sgl_per_token_quant_fp8);
m.def("cublas_grouped_gemm", cublas_grouped_gemm);
m.def("cutlass_scaled_fp4_mm", cutlass_scaled_fp4_mm);
m.def("scaled_fp4_quant", scaled_fp4_quant);
m.def("awq_dequantize(Tensor qweight, Tensor scales, Tensor qzeros) -> Tensor");
m.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);

m.def(
"int8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? "
"bias) -> Tensor");
m.impl("int8_scaled_mm", torch::kCUDA, &int8_scaled_mm);

m.def(
"fp8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? "
"bias) -> Tensor");
m.impl("fp8_scaled_mm", torch::kCUDA, &fp8_scaled_mm);

m.def(
"fp8_blockwise_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype) -> "
"Tensor");
m.impl("fp8_blockwise_scaled_mm", torch::kCUDA, &fp8_blockwise_scaled_mm);

m.def(
"sgl_per_token_group_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
" float eps, float fp8_min, float fp8_max) -> ()");
m.impl("sgl_per_token_group_quant_fp8", torch::kCUDA, &sgl_per_token_group_quant_fp8);

m.def(
"sgl_per_token_group_quant_int8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
" float eps, float int8_min, float int8_max) -> ()");
m.impl("sgl_per_token_group_quant_int8", torch::kCUDA, &sgl_per_token_group_quant_int8);

m.def("sgl_per_tensor_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, bool is_static) -> ()");
m.impl("sgl_per_tensor_quant_fp8", torch::kCUDA, &sgl_per_tensor_quant_fp8);

m.def("sgl_per_token_quant_fp8(Tensor input, Tensor output_q, Tensor output_s) -> ()");
m.impl("sgl_per_token_quant_fp8", torch::kCUDA, &sgl_per_token_quant_fp8);

m.def(
"cublas_grouped_gemm(Tensor[] inputs, Tensor[] weights, Tensor[] outputs,"
" ScalarType out_dtype, int cublas_handle, int cuda_stream) -> ()");
m.impl("cublas_grouped_gemm", torch::kCUDA, &cublas_grouped_gemm);

m.def(
"cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
" Tensor block_scale_a, Tensor block_scale_b,"
" Tensor alpha) -> ()");
m.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);

m.def(
"scaled_fp4_quant(Tensor! output, Tensor! input,"
" Tensor! output_scale, Tensor! input_scale) -> ()");
m.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant);

/*
* From csrc/moe
*/
m.def("moe_align_block_size", moe_align_block_size);
m.def("topk_softmax", topk_softmax);
m.def(
"moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
"experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()");
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);

m.def(
"topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
"token_expert_indices, Tensor gating_output) -> ()");
m.impl("topk_softmax", torch::kCUDA, &topk_softmax);

m.def(
"moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk) -> "
Expand All @@ -74,10 +146,28 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
/*
* From csrc/speculative
*/
m.def("tree_speculative_sampling_target_only", tree_speculative_sampling_target_only);
m.def("verify_tree_greedy", verify_tree_greedy);
m.def("build_tree_kernel_efficient", build_tree_kernel_efficient);
m.def("segment_packbits", segment_packbits);
m.def(
"tree_speculative_sampling_target_only(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
"Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
"Tensor uniform_samples, Tensor target_probs, Tensor draft_probs, "
"float threshold_single, float threshold_acc, "
"bool deterministic, int cuda_stream) -> ()");
m.impl("tree_speculative_sampling_target_only", torch::kCUDA, &tree_speculative_sampling_target_only);

m.def(
"verify_tree_greedy(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
"Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
"Tensor target_predict, int cuda_stream) -> ()");
m.impl("verify_tree_greedy", torch::kCUDA, &verify_tree_greedy);

m.def(
"build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
"Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, "
"Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num) -> ()");
m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient);

m.def("segment_packbits(Tensor x, Tensor input_indptr, Tensor output_indptr, Tensor! y, int cuda_stream) -> ()");
m.impl("segment_packbits", torch::kCUDA, &segment_packbits);

/*
* From FlashInfer
Expand All @@ -86,16 +176,71 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
"bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, int "
"cublas_handle, int cuda_stream) -> ()");
m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8);
m.def("min_p_sampling_from_probs", min_p_sampling_from_probs);
m.def("top_k_renorm_probs", top_k_renorm_probs);
m.def("top_p_renorm_probs", top_p_renorm_probs);
m.def("top_k_top_p_sampling_from_probs", top_k_top_p_sampling_from_probs);
m.def("top_p_sampling_from_probs", top_p_sampling_from_probs);

m.def(
"min_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor? maybe_min_p_arr, float "
"min_p_val, bool deterministic, int cuda_stream) -> ()");
m.impl("min_p_sampling_from_probs", torch::kCUDA, &min_p_sampling_from_probs);

m.def(
"top_k_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_k_arr, int top_k_val, int "
"cuda_stream) -> ()");
m.impl("top_k_renorm_probs", torch::kCUDA, &top_k_renorm_probs);

m.def(
"top_p_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_p_arr, float top_p_val, int "
"cuda_stream) -> ()");
m.impl("top_p_renorm_probs", torch::kCUDA, &top_p_renorm_probs);

m.def(
"top_k_top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
"maybe_top_k_arr, float top_k_val, Tensor? maybe_top_p_arr, float top_p_val, bool deterministic, int "
"cuda_stream) -> ()");
m.impl("top_k_top_p_sampling_from_probs", torch::kCUDA, &top_k_top_p_sampling_from_probs);

m.def(
"top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
"maybe_top_p_arr, float top_p_val, bool deterministic, int cuda_stream) -> ()");
m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs);

/*
* From flash-attention
*/
m.def("fwd", make_pytorch_shim(mha_fwd));
m.def(
"fwd(Tensor! q,"
" Tensor k,"
" Tensor v,"
" Tensor? k_new,"
" Tensor? v_new,"
" Tensor? q_v,"
" Tensor!? out,"
" Tensor? cu_seqlens_q,"
" Tensor? cu_seqlens_k,"
" Tensor? cu_seqlens_k_new,"
" Tensor? seqused_q,"
" Tensor? seqused_k,"
" int? max_seqlen_q,"
" int? max_seqlen_k,"
" Tensor? page_table,"
" Tensor? kv_batch_idx,"
" Tensor? leftpad_k,"
" Tensor? rotary_cos,"
" Tensor? rotary_sin,"
" Tensor? seqlens_rotary,"
" Tensor? q_descale,"
" Tensor? k_descale,"
" Tensor? v_descale,"
" float softmax_scale,"
" bool is_causal,"
" int window_size_left,"
" int window_size_right,"
" float softcap,"
" bool is_rotary_interleaved,"
" Tensor? scheduler_metadata,"
" int num_splits,"
" bool? pack_gqa,"
" int sm_margin) -> Tensor[]");
m.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));
}

REGISTER_EXTENSION(common_ops)
Loading