diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt
index 7a8b8601d35..76661f21115 100755
--- a/sgl-kernel/CMakeLists.txt
+++ b/sgl-kernel/CMakeLists.txt
@@ -199,6 +199,7 @@ set(SOURCES
     "csrc/speculative/eagle_utils.cu"
     "csrc/speculative/speculative_sampling.cu"
     "csrc/speculative/packbit.cu"
+    "csrc/grammar/apply_token_bitmask_inplace_cuda.cu"
     "csrc/common_extension.cc"
     "${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu"
     "${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"
diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc
old mode 100755
new mode 100644
index 7c0df156d94..eee2fdee946
--- a/sgl-kernel/csrc/common_extension.cc
+++ b/sgl-kernel/csrc/common_extension.cc
@@ -233,6 +233,12 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
       "bool is_causal, float softcap, bool return_softmax, "
       "Generator? gen) -> Tensor[]");
   m.impl("varlen_fwd_sparse", torch::kCUDA, &flash::mha_varlen_fwd_sparse);
+
+  /*
+   * From XGrammar
+   */
+  m.def("apply_token_bitmask_inplace_cuda(Tensor logits, Tensor bitmask, Tensor? indices=None) -> ()");
+  m.impl("apply_token_bitmask_inplace_cuda", &ApplyTokenBitmaskInplace);
 }
 
 REGISTER_EXTENSION(common_ops)
diff --git a/sgl-kernel/csrc/grammar/apply_token_bitmask_inplace_cuda.cu b/sgl-kernel/csrc/grammar/apply_token_bitmask_inplace_cuda.cu
new file mode 100644
index 00000000000..9a99debb6b9
--- /dev/null
+++ b/sgl-kernel/csrc/grammar/apply_token_bitmask_inplace_cuda.cu
@@ -0,0 +1,251 @@
+// Adapted from
+// https://github.com/mlc-ai/xgrammar/blob/v0.1.18/python/xgrammar/kernels/apply_token_bitmask_inplace_cuda.cu
+
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// clang-format off
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+#include <torch/all.h>
+// clang-format on
+
+#ifndef CUDART_INF_FP16
+#define CUDART_INF_FP16 __ushort_as_half((unsigned short)0x7C00U)
+#endif
+
+#ifndef CUDART_INF_BF16
+#define CUDART_INF_BF16 __ushort_as_bfloat16((unsigned short)0x7F80U)
+#endif
+
+constexpr int32_t BITS_PER_BLOCK = 32;
+constexpr int32_t THREADS_PER_THREAD_BLOCK = 256;
+
+template <typename T>
+__device__ T NegativeInfinity() {
+  return -INFINITY;
+}
+
+template <>
+__device__ __half NegativeInfinity<__half>() {
+  return -CUDART_INF_FP16;
+}
+
+template <>
+__device__ __nv_bfloat16 NegativeInfinity<__nv_bfloat16>() {
+  return -CUDART_INF_BF16;
+}
+
+template <typename T, typename PackedT>
+__device__ PackedT PackedNegativeInfinity() {
+  constexpr int kAlignment = sizeof(PackedT) / sizeof(T);
+  T packed[kAlignment];
+#pragma unroll
+  for (int i = 0; i < kAlignment; i++) {
+    packed[i] = NegativeInfinity<T>();
+  }
+  return *reinterpret_cast<PackedT*>(packed);
+}
+
+template <typename T, typename PackedT, int32_t kBitsPerThread>
+__global__ void __launch_bounds__(THREADS_PER_THREAD_BLOCK) LogitsBitmaskKernel(
+    T* __restrict__ logits,
+    const int32_t* __restrict__ bitmask,
+    const int32_t* __restrict__ indices,
+    int32_t vocab_size,
+    int32_t logits_stride,
+    int32_t bitmask_stride) {
+  constexpr int kAlignment = sizeof(PackedT) / sizeof(T);
+  constexpr uint32_t kPackedMask = (1 << kAlignment) - 1;
+
+  const int batch_idx = (indices == nullptr) ? blockIdx.y : indices[blockIdx.y];
+
+  const int block_offset = blockIdx.x * THREADS_PER_THREAD_BLOCK * kBitsPerThread;
+  T* logits_gmem_ptr = logits + batch_idx * logits_stride + block_offset;
+  const int32_t* bitmask_gmem_ptr = bitmask + batch_idx * bitmask_stride + block_offset / BITS_PER_BLOCK;
+  const int bitmask_inner_idx = threadIdx.x % (BITS_PER_BLOCK / kAlignment);
+  T logits_reg[kAlignment];
+
+#pragma unroll
+  for (int offset = threadIdx.x * kAlignment; offset < THREADS_PER_THREAD_BLOCK * kBitsPerThread;
+       offset += THREADS_PER_THREAD_BLOCK * kAlignment) {
+    if (block_offset + offset >= vocab_size) {
+      break;
+    }
+
+    const uint32_t bitmask_val =
+        (~bitmask_gmem_ptr[offset / BITS_PER_BLOCK] >> (bitmask_inner_idx * kAlignment)) & kPackedMask;
+
+    if (bitmask_val == 0) {
+      continue;
+    }
+
+    if (bitmask_val == kPackedMask) {
+      *reinterpret_cast<PackedT*>(logits_gmem_ptr + offset) = PackedNegativeInfinity<T, PackedT>();
+      continue;
+    }
+
+    *reinterpret_cast<PackedT*>(logits_reg) = *reinterpret_cast<PackedT*>(logits_gmem_ptr + offset);
+#pragma unroll
+    for (int i = 0; i < kAlignment; i++) {
+      if (((bitmask_val >> i) & 1)) {
+        logits_reg[i] = NegativeInfinity<T>();
+      }
+    }
+    *reinterpret_cast<PackedT*>(logits_gmem_ptr + offset) = *reinterpret_cast<PackedT*>(logits_reg);
+  }
+}
+
+template <typename T, typename = std::enable_if_t<std::is_integral<T>::value>>
+constexpr auto CeilDiv(T numerator, T denominator) {
+  return (numerator + denominator - 1) / denominator;
+}
+
+template <typename T, typename PackedT>
+void ApplyTokenBitmaskInplaceDispatchToBitsPerThread(
+    T* __restrict__ logits,
+    const int32_t* __restrict__ bitmask,
+    const int32_t* __restrict__ indices,
+    int32_t vocab_size,
+    int32_t logits_stride,
+    int32_t bitmask_stride,
+    int32_t num_rows) {
+  constexpr int kAlignment = sizeof(PackedT) / sizeof(T);
+  const int32_t num_blocks_per_row = CeilDiv(2048 / THREADS_PER_THREAD_BLOCK * 128, num_rows);
+  const int32_t num_bits_per_thread = CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * num_blocks_per_row);
+
+  const dim3 block(THREADS_PER_THREAD_BLOCK);
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+
+  if (num_bits_per_thread <= 4 && kAlignment <= 4) {
+    const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 4), num_rows);
+    LogitsBitmaskKernel<T, PackedT, 4>
+        <<<grid, block, 0, stream>>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride);
+  } else if (num_bits_per_thread <= 8 && kAlignment <= 8) {
+    const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 8), num_rows);
+    LogitsBitmaskKernel<T, PackedT, 8>
+        <<<grid, block, 0, stream>>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride);
+  } else if (num_bits_per_thread <= 16 && kAlignment <= 16) {
+    const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 16), num_rows);
+    LogitsBitmaskKernel<T, PackedT, 16>
+        <<<grid, block, 0, stream>>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride);
+  } else {
+    const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 32), num_rows);
+    LogitsBitmaskKernel<T, PackedT, 32>
+        <<<grid, block, 0, stream>>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride);
+  }
+}
+
+template <typename T>
+void ApplyTokenBitmaskInplaceDispatchToPackedT(
+    T* __restrict__ logits,
+    const int32_t* __restrict__ bitmask,
+    const int32_t* __restrict__ indices,
+    int32_t vocab_size,
+    int32_t logits_stride,
+    int32_t bitmask_stride,
+    int32_t num_rows) {
+  if (logits_stride % (sizeof(float4) / sizeof(T)) == 0) {
+    ApplyTokenBitmaskInplaceDispatchToBitsPerThread<T, float4>(
+        logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride, num_rows);
+  } else {
+    ApplyTokenBitmaskInplaceDispatchToBitsPerThread<T, T>(
+        logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride, num_rows);
+  }
+}
+
+void ApplyTokenBitmaskInplace(at::Tensor logits, at::Tensor bitmask, at::optional<at::Tensor> indices = at::nullopt) {
+  TORCH_CHECK(logits.is_cuda(), "logits must be a CUDA tensor.");
+  TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous.");
+  TORCH_CHECK(logits.dim() == 1 || logits.dim() == 2, "logits must be a 1D or 2D tensor.");
+  std::pair<int32_t, int32_t> logits_shape =
+      logits.dim() == 2
+          ? std::make_pair(static_cast<int32_t>(logits.size(0)), static_cast<int32_t>(logits.size(1)))
+          : std::make_pair(1, static_cast<int32_t>(logits.size(0)));
+
+  TORCH_CHECK(bitmask.is_cuda(), "bitmask must be a CUDA tensor.");
+  TORCH_CHECK(bitmask.is_contiguous(), "bitmask must be contiguous.");
+  TORCH_CHECK(bitmask.dim() == 1 || bitmask.dim() == 2, "bitmask must be a 1D or 2D tensor.");
+  std::pair<int32_t, int32_t> bitmask_shape =
+      bitmask.dim() == 2
+          ? std::make_pair(static_cast<int32_t>(bitmask.size(0)), static_cast<int32_t>(bitmask.size(1)))
+          : std::make_pair(1, static_cast<int32_t>(bitmask.size(0)));
+
+  TORCH_CHECK(bitmask.dtype() == torch::kInt32, "bitmask must be of type int32.");
+
+  TORCH_CHECK(
+      (logits_shape.second + BITS_PER_BLOCK - 1) / BITS_PER_BLOCK >= bitmask_shape.second,
+      "The provided logits's vocab size should be no less than the bitmask's vocab size "
+      "(converted from bitmask size). But got vocab size ",
+      logits_shape.second,
+      " vs bitmask size ",
+      bitmask_shape.second);
+
+  int vocab_size = std::min(logits_shape.second, bitmask_shape.second * BITS_PER_BLOCK);
+
+  int32_t num_rows = logits_shape.first;
+  int32_t* indices_ptr = nullptr;
+  if (indices) {
+    TORCH_CHECK(indices->is_cuda(), "indices must be a CUDA tensor.");
+    TORCH_CHECK(indices->is_contiguous(), "indices must be contiguous.");
+    TORCH_CHECK(indices->dim() == 1, "indices must be a 1D tensor.");
+    TORCH_CHECK(indices->dtype() == torch::kInt32, "indices must be of type int32.");
+    num_rows = indices->size(0);
+    indices_ptr = indices->data_ptr<int32_t>();
+  } else {
+    TORCH_CHECK(logits_shape.first == bitmask_shape.first, "logits and bitmask must have the same batch size.");
+  }
+
+  switch (logits.scalar_type()) {
+    case torch::kFloat32: {
+      ApplyTokenBitmaskInplaceDispatchToPackedT(
+          logits.data_ptr<float>(),
+          bitmask.data_ptr<int32_t>(),
+          indices_ptr,
+          vocab_size,
+          logits_shape.second,
+          bitmask_shape.second,
+          num_rows);
+      break;
+    }
+    case torch::kFloat16: {
+      ApplyTokenBitmaskInplaceDispatchToPackedT(
+          reinterpret_cast<__half*>(logits.data_ptr()),
+          bitmask.data_ptr<int32_t>(),
+          indices_ptr,
+          vocab_size,
+          logits_shape.second,
+          bitmask_shape.second,
+          num_rows);
+      break;
+    }
+    case torch::kBFloat16: {
+      ApplyTokenBitmaskInplaceDispatchToPackedT(
+          reinterpret_cast<__nv_bfloat16*>(logits.data_ptr()),
+          bitmask.data_ptr<int32_t>(),
+          indices_ptr,
+          vocab_size,
+          logits_shape.second,
+          bitmask_shape.second,
+          num_rows);
+      break;
+    }
+    default:
+      TORCH_CHECK(false, "logits dtype must be float, half or bfloat16.");
+      break;
+  }
+}
diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h
old mode 100755
new mode 100644
index 3c906f587e1..f8a3294e618
--- a/sgl-kernel/include/sgl_kernel_ops.h
+++ b/sgl-kernel/include/sgl_kernel_ops.h
@@ -352,3 +352,8 @@ std::vector<at::Tensor> mha_varlen_fwd_sparse(
     const bool return_softmax,
     c10::optional<at::Generator> gen_);
 }  // namespace flash
+
+/*
+ * From XGrammar
+ */
+void ApplyTokenBitmaskInplace(at::Tensor logits, at::Tensor bitmask, at::optional<at::Tensor> indices = at::nullopt);
diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py
index e6ba19e0f41..acd21a46ef2 100755
--- a/sgl-kernel/python/sgl_kernel/__init__.py
+++ b/sgl-kernel/python/sgl_kernel/__init__.py
@@ -41,6 +41,7 @@
     sgl_per_token_group_quant_int8,
     sgl_per_token_quant_fp8,
 )
+from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda
 from sgl_kernel.moe import (
     fp8_blockwise_scaled_grouped_mm,
     moe_align_block_size,
diff --git a/sgl-kernel/python/sgl_kernel/grammar.py b/sgl-kernel/python/sgl_kernel/grammar.py
new file mode 100644
index 00000000000..971f94bc2e5
--- /dev/null
+++ b/sgl-kernel/python/sgl_kernel/grammar.py
@@ -0,0 +1,15 @@
+from typing import List, Optional, Union
+
+import torch
+
+
+def apply_token_bitmask_inplace_cuda(
+    logits: torch.Tensor,
+    bitmask: torch.Tensor,
+    indices: Optional[Union[List[int], torch.Tensor]] = None,
+) -> None:
+    if isinstance(indices, list):
+        indices = torch.tensor(indices, dtype=torch.int32, device=logits.device)
+    if indices is not None:
+        indices = indices.to(logits.device)
+    torch.ops.sgl_kernel.apply_token_bitmask_inplace_cuda(logits, bitmask, indices)
diff --git a/sgl-kernel/tests/test_apply_token_bitmask_inplace.py b/sgl-kernel/tests/test_apply_token_bitmask_inplace.py
new file mode 100644
index 00000000000..480479134cb
--- /dev/null
+++ b/sgl-kernel/tests/test_apply_token_bitmask_inplace.py
@@ -0,0 +1,23 @@
+import pytest
+import torch
+from sgl_kernel import apply_token_bitmask_inplace_cuda
+
+
+def test_apply_token_bitmask_inplace_kernel():
+    neginf = float("-inf")
+    bool_mask = torch.tensor([0, 1, 0, 1, 0, 1, 0, 1, 0, 1], dtype=torch.bool)
+    logits = torch.tensor(
+        [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], dtype=torch.float32
+    )
+    expected = torch.where(bool_mask, logits, neginf)
+
+    logits_gpu = logits.to("cuda")
+    bitmask = torch.tensor([0b1010101010], dtype=torch.int32).to("cuda")
+    apply_token_bitmask_inplace_cuda(logits_gpu, bitmask)
+    torch.cuda.synchronize()
+    torch.testing.assert_close(logits_gpu, expected.to("cuda"))
+
+
+if __name__ == "__main__":
+    test_apply_token_bitmask_inplace_kernel()
+    pytest.main([__file__])
diff --git a/sgl-kernel/tests/test_fp8_blockwise_moe.py b/sgl-kernel/tests/test_fp8_blockwise_moe.py
index 27ef9c0214e..18493f00717 100755
--- a/sgl-kernel/tests/test_fp8_blockwise_moe.py
+++ b/sgl-kernel/tests/test_fp8_blockwise_moe.py
@@ -47,6 +47,16 @@ def group_broadcast(t, shape):
     ).to(out_dtype)
 
 
+def is_sm100_supported(device=None) -> bool:
+    return (torch.cuda.get_device_capability(device)[0] == 10) and (
+        torch.version.cuda >= "12.8"
+    )
+
+
+@pytest.mark.skipif(
+    not is_sm100_supported(),
+    reason="fp8_blockwise_scaled_grouped_mm at sgl-kernel is only supported on sm100",
+)
 @pytest.mark.parametrize("num_experts", [8, 16])
 @pytest.mark.parametrize("out_dtype", [torch.half, torch.bfloat16])
 def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
diff --git a/sgl-kernel/tests/test_moe_fused_gate.py b/sgl-kernel/tests/test_moe_fused_gate.py
index 82404e57206..09787224be0 100644
--- a/sgl-kernel/tests/test_moe_fused_gate.py
+++ b/sgl-kernel/tests/test_moe_fused_gate.py
@@ -48,6 +48,7 @@ def test_moe_fused_gate_combined(seq_length, dtype, params, n_share_experts_fusi
         topk_group=topk_group,
         compiled=False,
         n_share_experts_fusion=n_share_experts_fusion,
+        routed_scaling_factor=2.5,
     )
 
     # When n_share_experts_fusion > 0, ignore the comparison of the last topk dimension
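
Usage note (not part of the diff): the sketch below shows how the new Python wrapper is meant to be called, assuming a CUDA device and an sgl-kernel build that includes this change. The bitmask packs 32 tokens per int32 word, least-significant bit first; a set bit keeps the token, a cleared bit forces its logit to -inf, matching the kernel and the unit test above. The `allowed` token ids and tensor shapes are made up for illustration.

    import torch
    from sgl_kernel import apply_token_bitmask_inplace_cuda

    batch_size, vocab_size = 2, 100
    logits = torch.randn(batch_size, vocab_size, dtype=torch.float32, device="cuda")

    # Allow only tokens 3 and 64 in every row; all other logits become -inf.
    allowed = [3, 64]
    num_words = (vocab_size + 31) // 32  # 32 tokens packed per int32 word
    bitmask = torch.zeros(batch_size, num_words, dtype=torch.int32, device="cuda")
    for tok in allowed:
        bitmask[:, tok // 32] |= 1 << (tok % 32)

    apply_token_bitmask_inplace_cuda(logits, bitmask)
    # The optional `indices` argument restricts the update to selected rows,
    # e.g. apply_token_bitmask_inplace_cuda(logits, bitmask, indices=[0]).
    assert torch.isinf(logits[0, 0]) and not torch.isinf(logits[0, 3])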