diff --git a/sgl-kernel/README.md b/sgl-kernel/README.md
index 3b0c34da826..9f3a9e1c3c1 100644
--- a/sgl-kernel/README.md
+++ b/sgl-kernel/README.md
@@ -55,9 +55,17 @@ Steps to add a new kernel:
 1. When implementing kernels in [csrc](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc), only define pure CUDA files and C++ interfaces. If you need to use `Torch::tensor`, use `<torch/all.h>` instead of `<torch/extension.h>`. Using `<torch/extension.h>` will cause compilation errors when using SABI.
-2. When creating torch extensions, simply add the function definition with `m.def`:
+2. When creating torch extensions, add the function definition with `m.def` and the device binding with `m.impl`:
+- Using torch.compile requires `m.def` with a schema; the schema is what lets torch.compile capture the custom kernel automatically. Reference: [How to add FakeTensor](https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit?tab=t.0#heading=h.ptttacy8y1u9)
+
+- How to write a schema: [Schema reference](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#func)
+
 ```cpp
-m.def("register_graph_buffers", register_graph_buffers);
+// We need def with schema here for torch.compile
+m.def(
+    "bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, int "
+    "cublas_handle, int cuda_stream) -> ()");
+m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8);
 ```
 3. When exposing Python interfaces, avoid using kwargs in C++ interface kernels.
@@ -96,6 +104,8 @@ Steps to add a new kernel:
 When integrating new third-party libraries like flash-attention, you may encounter data type compatibility issues between the C++ interface and PyTorch bindings. For example, the third-party code might use `float` or `int` types, while PyTorch requires `double` and `int64_t`.
+> The reason we need `double` and `int64_t` in the torch binding is that TORCH_LIBRARY handles the Python-to-C++ conversion: Python's `float` corresponds to `double` in C++, and Python's `int` corresponds to `int64_t`.
+
 To address this issue, we provide the `make_pytorch_shim` function in [sgl_kernel_torch_shim](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/include/sgl_kernel_torch_shim.h) that handles data type conversions automatically.
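As an aside (not part of this diff), a minimal sketch with a hypothetical `third_party_scale` kernel shows why the binding-side signature must use `double`/`int64_t`, and how `make_pytorch_shim` spares you the manual adapter:

```cpp
#include <cstdint>
#include <torch/all.h>

// Hypothetical third-party kernel: plain C++ `float` / `int` parameters.
void third_party_scale(at::Tensor& out, const at::Tensor& in, float alpha, int repeat);

// What TORCH_LIBRARY needs to see instead: Python `float` arrives as `double`,
// Python `int` arrives as `int64_t`, so a hand-written adapter would look like this.
void third_party_scale_binding(at::Tensor& out, const at::Tensor& in, double alpha, int64_t repeat) {
  third_party_scale(out, in, static_cast<float>(alpha), static_cast<int>(repeat));
}

// make_pytorch_shim generates this adaptation for you, so the wrapper above is unnecessary:
//   m.impl("scale", torch::kCUDA, make_pytorch_shim(&third_party_scale));
```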
 When you need to support new data type conversions, you can easily add conversion functions like this:
@@ -119,7 +129,7 @@ To use this with your library functions, simply wrap them with make_pytorch_shim
 /*
  * From flash-attention
  */
-m.def("fwd", make_pytorch_shim(mha_fwd));
+m.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));
 ```
 
 ### Build & Install
diff --git a/sgl-kernel/csrc/torch_extension.cc b/sgl-kernel/csrc/torch_extension.cc
index c26b9024c53..3b91e63cd09 100644
--- a/sgl-kernel/csrc/torch_extension.cc
+++ b/sgl-kernel/csrc/torch_extension.cc
@@ -22,49 +22,121 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
   /*
    * From csrc/allreduce
    */
-  m.def("init_custom_ar", init_custom_ar);
-  m.def("dispose", dispose);
-  m.def("all_reduce", all_reduce);
-  m.def("get_graph_buffer_ipc_meta", get_graph_buffer_ipc_meta);
-  m.def("register_graph_buffers", register_graph_buffers);
+
+  m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
+  m.def("register_graph_buffers", &register_graph_buffers);
+  m.def("dispose", &dispose);
+
+  m.def(
+      "init_custom_ar(int rank_id, int world_size, Tensor rank_data, int[] buffers, int[] tmp_result_buffers, int[] "
+      "barrier_in, int[] barrier_out) -> int");
+  m.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
+
+  m.def("all_reduce(int fa, Tensor inp, Tensor! out) -> ()");
+  m.impl("all_reduce", torch::kCUDA, &all_reduce);
 
   /*
    * From csrc/attention
    */
-  m.def("lightning_attention_decode", lightning_attention_decode);
+  m.def(
+      "lightning_attention_decode(Tensor q, Tensor k, Tensor v, Tensor past_kv, Tensor slope, Tensor! output, Tensor! "
+      "new_kv) -> ()");
+  m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode);
 
   /*
    * From csrc/elementwise
    */
-  m.def("rmsnorm", rmsnorm);
-  m.def("fused_add_rmsnorm", sgl_fused_add_rmsnorm);
-  m.def("gemma_rmsnorm", gemma_rmsnorm);
-  m.def("gemma_fused_add_rmsnorm", gemma_fused_add_rmsnorm);
-  m.def("silu_and_mul", silu_and_mul);
-  m.def("gelu_tanh_and_mul", gelu_tanh_and_mul);
-  m.def("gelu_and_mul", gelu_and_mul);
-  m.def("apply_rope_pos_ids_cos_sin_cache", apply_rope_pos_ids_cos_sin_cache);
+  m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
+  m.impl("rmsnorm", torch::kCUDA, &rmsnorm);
+
+  m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps) -> ()");
+  m.impl("fused_add_rmsnorm", torch::kCUDA, &sgl_fused_add_rmsnorm);
+
+  m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
+  m.impl("gemma_rmsnorm", torch::kCUDA, &gemma_rmsnorm);
+
+  m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, int cuda_stream) -> ()");
+  m.impl("gemma_fused_add_rmsnorm", torch::kCUDA, &gemma_fused_add_rmsnorm);
+
+  m.def("silu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
+  m.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
+
+  m.def("gelu_tanh_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
+  m.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);
+
+  m.def("gelu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
+  m.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
+
+  m.def(
+      "apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, "
+      "Tensor pos_ids, bool interleave, int cuda_stream) -> ()");
+  m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache);
 
   /*
    * From csrc/gemm
    */
-  m.def("awq_dequantize", awq_dequantize);
-  m.def("int8_scaled_mm", int8_scaled_mm);
-  m.def("fp8_scaled_mm", fp8_scaled_mm);
-  m.def("fp8_blockwise_scaled_mm", fp8_blockwise_scaled_mm);
-  m.def("sgl_per_token_group_quant_fp8", sgl_per_token_group_quant_fp8);
-  m.def("sgl_per_token_group_quant_int8", sgl_per_token_group_quant_int8);
-  m.def("sgl_per_tensor_quant_fp8", sgl_per_tensor_quant_fp8);
-  m.def("sgl_per_token_quant_fp8", sgl_per_token_quant_fp8);
-  m.def("cublas_grouped_gemm", cublas_grouped_gemm);
-  m.def("cutlass_scaled_fp4_mm", cutlass_scaled_fp4_mm);
-  m.def("scaled_fp4_quant", scaled_fp4_quant);
+  m.def("awq_dequantize(Tensor qweight, Tensor scales, Tensor qzeros) -> Tensor");
+  m.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);
+
+  m.def(
+      "int8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? "
+      "bias) -> Tensor");
+  m.impl("int8_scaled_mm", torch::kCUDA, &int8_scaled_mm);
+
+  m.def(
+      "fp8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? "
+      "bias) -> Tensor");
+  m.impl("fp8_scaled_mm", torch::kCUDA, &fp8_scaled_mm);
+
+  m.def(
+      "fp8_blockwise_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype) -> "
+      "Tensor");
+  m.impl("fp8_blockwise_scaled_mm", torch::kCUDA, &fp8_blockwise_scaled_mm);
+
+  m.def(
+      "sgl_per_token_group_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
+      " float eps, float fp8_min, float fp8_max) -> ()");
+  m.impl("sgl_per_token_group_quant_fp8", torch::kCUDA, &sgl_per_token_group_quant_fp8);
+
+  m.def(
+      "sgl_per_token_group_quant_int8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
+      " float eps, float int8_min, float int8_max) -> ()");
+  m.impl("sgl_per_token_group_quant_int8", torch::kCUDA, &sgl_per_token_group_quant_int8);
+
+  m.def("sgl_per_tensor_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, bool is_static) -> ()");
+  m.impl("sgl_per_tensor_quant_fp8", torch::kCUDA, &sgl_per_tensor_quant_fp8);
+
+  m.def("sgl_per_token_quant_fp8(Tensor input, Tensor output_q, Tensor output_s) -> ()");
+  m.impl("sgl_per_token_quant_fp8", torch::kCUDA, &sgl_per_token_quant_fp8);
+
+  m.def(
+      "cublas_grouped_gemm(Tensor[] inputs, Tensor[] weights, Tensor[] outputs,"
+      " ScalarType out_dtype, int cublas_handle, int cuda_stream) -> ()");
+  m.impl("cublas_grouped_gemm", torch::kCUDA, &cublas_grouped_gemm);
+
+  m.def(
+      "cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
+      " Tensor block_scale_a, Tensor block_scale_b,"
+      " Tensor alpha) -> ()");
+  m.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);
+
+  m.def(
+      "scaled_fp4_quant(Tensor! output, Tensor! input,"
+      " Tensor! output_scale, Tensor! input_scale) -> ()");
+  m.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant);
 
   /*
    * From csrc/moe
    */
-  m.def("moe_align_block_size", moe_align_block_size);
-  m.def("topk_softmax", topk_softmax);
+  m.def(
+      "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
+      "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()");
+  m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
+
+  m.def(
+      "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
+      "token_expert_indices, Tensor gating_output) -> ()");
+  m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
 
   m.def(
       "moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk) -> "
@@ -74,10 +146,28 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
   /*
    * From csrc/speculative
    */
-  m.def("tree_speculative_sampling_target_only", tree_speculative_sampling_target_only);
-  m.def("verify_tree_greedy", verify_tree_greedy);
-  m.def("build_tree_kernel_efficient", build_tree_kernel_efficient);
-  m.def("segment_packbits", segment_packbits);
+  m.def(
+      "tree_speculative_sampling_target_only(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
+      "Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
+      "Tensor uniform_samples, Tensor target_probs, Tensor draft_probs, "
+      "float threshold_single, float threshold_acc, "
+      "bool deterministic, int cuda_stream) -> ()");
+  m.impl("tree_speculative_sampling_target_only", torch::kCUDA, &tree_speculative_sampling_target_only);
+
+  m.def(
+      "verify_tree_greedy(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
+      "Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
+      "Tensor target_predict, int cuda_stream) -> ()");
+  m.impl("verify_tree_greedy", torch::kCUDA, &verify_tree_greedy);
+
+  m.def(
+      "build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
+      "Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, "
+      "Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num) -> ()");
+  m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient);
+
+  m.def("segment_packbits(Tensor x, Tensor input_indptr, Tensor output_indptr, Tensor! y, int cuda_stream) -> ()");
+  m.impl("segment_packbits", torch::kCUDA, &segment_packbits);
 
   /*
    * From FlashInfer
@@ -86,16 +176,71 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
       "bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, int "
       "cublas_handle, int cuda_stream) -> ()");
   m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8);
-  m.def("min_p_sampling_from_probs", min_p_sampling_from_probs);
-  m.def("top_k_renorm_probs", top_k_renorm_probs);
-  m.def("top_p_renorm_probs", top_p_renorm_probs);
-  m.def("top_k_top_p_sampling_from_probs", top_k_top_p_sampling_from_probs);
-  m.def("top_p_sampling_from_probs", top_p_sampling_from_probs);
+
+  m.def(
+      "min_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor? maybe_min_p_arr, float "
+      "min_p_val, bool deterministic, int cuda_stream) -> ()");
+  m.impl("min_p_sampling_from_probs", torch::kCUDA, &min_p_sampling_from_probs);
+
+  m.def(
+      "top_k_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_k_arr, int top_k_val, int "
+      "cuda_stream) -> ()");
+  m.impl("top_k_renorm_probs", torch::kCUDA, &top_k_renorm_probs);
+
+  m.def(
+      "top_p_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_p_arr, float top_p_val, int "
+      "cuda_stream) -> ()");
+  m.impl("top_p_renorm_probs", torch::kCUDA, &top_p_renorm_probs);
+
+  m.def(
+      "top_k_top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
+      "maybe_top_k_arr, float top_k_val, Tensor? maybe_top_p_arr, float top_p_val, bool deterministic, int "
+      "cuda_stream) -> ()");
+  m.impl("top_k_top_p_sampling_from_probs", torch::kCUDA, &top_k_top_p_sampling_from_probs);
+
+  m.def(
+      "top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
+      "maybe_top_p_arr, float top_p_val, bool deterministic, int cuda_stream) -> ()");
+  m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs);
 
   /*
    * From flash-attention
    */
-  m.def("fwd", make_pytorch_shim(mha_fwd));
+  m.def(
+      "fwd(Tensor! q,"
+      " Tensor k,"
+      " Tensor v,"
+      " Tensor? k_new,"
+      " Tensor? v_new,"
+      " Tensor? q_v,"
+      " Tensor!? out,"
+      " Tensor? cu_seqlens_q,"
+      " Tensor? cu_seqlens_k,"
+      " Tensor? cu_seqlens_k_new,"
+      " Tensor? seqused_q,"
+      " Tensor? seqused_k,"
+      " int? max_seqlen_q,"
+      " int? max_seqlen_k,"
+      " Tensor? page_table,"
+      " Tensor? kv_batch_idx,"
+      " Tensor? leftpad_k,"
+      " Tensor? rotary_cos,"
+      " Tensor? rotary_sin,"
+      " Tensor? seqlens_rotary,"
+      " Tensor? q_descale,"
+      " Tensor? k_descale,"
+      " Tensor? v_descale,"
+      " float softmax_scale,"
+      " bool is_causal,"
+      " int window_size_left,"
+      " int window_size_right,"
+      " float softcap,"
+      " bool is_rotary_interleaved,"
+      " Tensor? scheduler_metadata,"
+      " int num_splits,"
+      " bool? pack_gqa,"
+      " int sm_margin) -> Tensor[]");
+  m.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));
 }
 
 REGISTER_EXTENSION(common_ops)
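Putting the README guidance and the registrations above together, a new kernel would follow the same two-step pattern. The sketch below is illustrative only, not part of this diff: the op name `my_scale` is hypothetical, and it uses the stock `TORCH_LIBRARY_FRAGMENT` macro from `torch/library.h` rather than the repo's `TORCH_LIBRARY_EXPAND` wrapper.

```cpp
#include <cstdint>
#include <torch/all.h>
#include <torch/library.h>

// Hypothetical CUDA kernel entry point, shown only to illustrate the pattern.
void my_scale(at::Tensor& out, const at::Tensor& in, double factor, int64_t cuda_stream);

TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
  // Schema first: this is what lets torch.compile capture and fake the op.
  // In the schema language, "float" maps to C++ double, "int" to int64_t,
  // and "Tensor!" marks an argument the kernel mutates in place.
  m.def("my_scale(Tensor! out, Tensor in, float factor, int cuda_stream) -> ()");
  // Device binding second: dispatch the CUDA implementation.
  m.impl("my_scale", torch::kCUDA, &my_scale);
}
```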