diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt
index 95bf67322aa..929a60d3b80 100644
--- a/sgl-kernel/CMakeLists.txt
+++ b/sgl-kernel/CMakeLists.txt
@@ -58,8 +58,8 @@ FetchContent_Populate(repo-deepgemm)
 # flashinfer
 FetchContent_Declare(
   repo-flashinfer
-  GIT_REPOSITORY https://github.com/sgl-project/flashinfer
-  GIT_TAG sgl-kernel
+  GIT_REPOSITORY https://github.com/flashinfer-ai/flashinfer.git
+  GIT_TAG 9220fb3443b5a5d274f00ca5552f798e225239b7
   GIT_SHALLOW OFF
 )
 FetchContent_Populate(repo-flashinfer)
diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc
index d3e0ffae82b..a1c36fff091 100644
--- a/sgl-kernel/csrc/common_extension.cc
+++ b/sgl-kernel/csrc/common_extension.cc
@@ -58,16 +58,16 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   /*
    * From csrc/elementwise
    */
-  m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
+  m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, bool enable_pdl) -> ()");
   m.impl("rmsnorm", torch::kCUDA, &rmsnorm);

-  m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps) -> ()");
+  m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, bool enable_pdl) -> ()");
   m.impl("fused_add_rmsnorm", torch::kCUDA, &sgl_fused_add_rmsnorm);

-  m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
+  m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, bool enable_pdl) -> ()");
   m.impl("gemma_rmsnorm", torch::kCUDA, &gemma_rmsnorm);

-  m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, int cuda_stream) -> ()");
+  m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, bool enable_pdl) -> ()");
   m.impl("gemma_fused_add_rmsnorm", torch::kCUDA, &gemma_fused_add_rmsnorm);

   m.def("silu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
@@ -186,29 +186,24 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8);

   m.def(
-      "min_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor? maybe_min_p_arr, float "
-      "min_p_val, bool deterministic, int cuda_stream) -> ()");
+      "min_p_sampling_from_probs(Tensor probs, Tensor output, Tensor? maybe_indices, Tensor? maybe_min_p_arr, float "
+      "min_p_val, bool deterministic, Generator? gen) -> ()");
   m.impl("min_p_sampling_from_probs", torch::kCUDA, &min_p_sampling_from_probs);

-  m.def(
-      "top_k_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_k_arr, int top_k_val, int "
-      "cuda_stream) -> ()");
+  m.def("top_k_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_k_arr, int top_k_val) -> ()");
   m.impl("top_k_renorm_probs", torch::kCUDA, &top_k_renorm_probs);

-  m.def(
-      "top_p_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_p_arr, float top_p_val, int "
-      "cuda_stream) -> ()");
+  m.def("top_p_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_p_arr, float top_p_val) -> ()");
   m.impl("top_p_renorm_probs", torch::kCUDA, &top_p_renorm_probs);

   m.def(
-      "top_k_top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
-      "maybe_top_k_arr, float top_k_val, Tensor? maybe_top_p_arr, float top_p_val, bool deterministic, int "
-      "cuda_stream) -> ()");
+      "top_k_top_p_sampling_from_probs(Tensor probs, Tensor output, Tensor? maybe_indices, Tensor? maybe_top_k_arr, "
+      "float top_k_val, Tensor? maybe_top_p_arr, float top_p_val, bool deterministic, Generator? gen) -> ()");
   m.impl("top_k_top_p_sampling_from_probs", torch::kCUDA, &top_k_top_p_sampling_from_probs);

   m.def(
-      "top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
-      "maybe_top_p_arr, float top_p_val, bool deterministic, int cuda_stream) -> ()");
+      "top_p_sampling_from_probs(Tensor probs, Tensor output, Tensor? maybe_indices, Tensor? "
+      "maybe_top_p_arr, float top_p_val, bool deterministic, Generator? gen) -> ()");
   m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs);

   /*
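For context: the updated norm and sampling schemas drop the trailing `int cuda_stream` argument (the kernels now launch on the current CUDA stream) and instead take a trailing `bool enable_pdl` or an optional `Generator`. A minimal sketch of invoking one of the re-registered ops directly — the import side effect and shapes are illustrative assumptions, not part of this diff:

```python
import torch
import sgl_kernel  # noqa: F401 -- importing the package registers torch.ops.sgl_kernel.*

x = torch.randn(4, 4096, device="cuda", dtype=torch.float16)
w = torch.ones(4096, device="cuda", dtype=torch.float16)
out = torch.empty_like(x)

# New schema: (output, input, weight, eps, enable_pdl) -- no cuda_stream argument.
torch.ops.sgl_kernel.rmsnorm.default(out, x, w, 1e-6, False)
```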
diff --git a/sgl-kernel/csrc/elementwise/fused_add_rms_norm_kernel.cu b/sgl-kernel/csrc/elementwise/fused_add_rms_norm_kernel.cu
index 41f4d2e7099..02d5b20c9b1 100644
--- a/sgl-kernel/csrc/elementwise/fused_add_rms_norm_kernel.cu
+++ b/sgl-kernel/csrc/elementwise/fused_add_rms_norm_kernel.cu
@@ -21,7 +21,8 @@ limitations under the License.

 using namespace flashinfer;

-void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps) {
+void sgl_fused_add_rmsnorm(
+    torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps, bool enable_pdl) {
   CHECK_INPUT(input);
   CHECK_INPUT(residual);
   CHECK_INPUT(weight);
@@ -46,7 +47,10 @@ void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::T
         static_cast<c_type*>(weight.data_ptr()),
         batch_size,
         hidden_size,
+        input.stride(0),
+        residual.stride(0),
         eps,
+        enable_pdl,
         torch_current_stream);
     TORCH_CHECK(
         status == cudaSuccess, "FusedAddRMSNorm failed with error code " + std::string(cudaGetErrorString(status)));
diff --git a/sgl-kernel/csrc/speculative/speculative_sampling.cuh b/sgl-kernel/csrc/speculative/speculative_sampling.cuh
index 158dfa4710c..a773c0e2705 100644
--- a/sgl-kernel/csrc/speculative/speculative_sampling.cuh
+++ b/sgl-kernel/csrc/speculative/speculative_sampling.cuh
@@ -54,10 +54,10 @@ __global__ void TreeSpeculativeSamplingTargetOnly(
     DType threshold_acc) {
   const uint32_t bx = blockIdx.x, tx = threadIdx.x;

-  extern __shared__ __align__(alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
+  extern __shared__ __align__(alignof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
       uint8_t smem_sampling[];
   auto& temp_storage =
-      reinterpret_cast<SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);
+      reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);

   DType prob_acc = 0.0;
   uint32_t cur_prob_offset = bx * num_draft_tokens * d;
@@ -144,7 +144,7 @@ __global__ void TreeSpeculativeSamplingTargetOnly(
       relu_q_minus_p_vec[j] = max(q_vec[j] - p_vec[j], DType(0));
     }

-    DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM, DETERMINISTIC>(
+    DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM, DETERMINISTIC, DType>(
         i, d, [&](DType x) { return x > 0; }, u, relu_q_minus_p_vec, aggregate_relu_q_minus_p, &temp_storage);
     if (aggregate_relu_q_minus_p > u) {
       break;
@@ -179,7 +179,7 @@ cudaError_t TreeSpeculativeSamplingTargetOnly(
   constexpr uint32_t BLOCK_THREADS = 1024;
   const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);

-  const uint32_t smem_size = sizeof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>);
+  const uint32_t smem_size = sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>);
   dim3 nblks(batch_size);
   dim3 nthrs(BLOCK_THREADS);
   float capped_threshold_acc = fmaxf(threshold_acc, 1e-9f);
diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h
index 118a8ba058e..8a6d1c44b2f 100644
--- a/sgl-kernel/include/sgl_kernel_ops.h
+++ b/sgl-kernel/include/sgl_kernel_ops.h
@@ -102,11 +102,11 @@ int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches,
 /*
  * From csrc/elementwise
  */
-void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
-void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps);
-void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
-void gemma_fused_add_rmsnorm(
-    at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, int64_t cuda_stream);
+void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, bool enable_pdl);
+void sgl_fused_add_rmsnorm(
+    torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps, bool enable_pdl);
+void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, bool enable_pdl);
+void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, bool enable_pdl);
 void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
 void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
 void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
@@ -254,48 +254,38 @@ void segment_packbits(
 */
 void min_p_sampling_from_probs(
     at::Tensor probs,
-    at::Tensor uniform_samples,
-    at::Tensor samples,
+    at::Tensor output,
+    std::optional<at::Tensor> maybe_indices,
     std::optional<at::Tensor> maybe_min_p_arr,
     double min_p_val,
     bool deterministic,
-    int64_t cuda_stream);
+    std::optional<at::Generator> gen);
 void top_k_renorm_probs(
-    at::Tensor probs,
-    at::Tensor renorm_probs,
-    std::optional<at::Tensor> maybe_top_k_arr,
-    int64_t top_k_val,
-    int64_t cuda_stream);
+    at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val);
 void top_p_renorm_probs(
-    at::Tensor probs,
-    at::Tensor renorm_probs,
-    std::optional<at::Tensor> maybe_top_p_arr,
-    double top_p_val,
-    int64_t cuda_stream);
+    at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_p_arr, double top_p_val);
 void top_k_top_p_sampling_from_probs(
     at::Tensor probs,
-    at::Tensor uniform_samples,
-    at::Tensor samples,
-    at::Tensor success,
+    at::Tensor output,
+    std::optional<at::Tensor> maybe_indices,
     std::optional<at::Tensor> maybe_top_k_arr,
     double top_k_val,
     std::optional<at::Tensor> maybe_top_p_arr,
     double top_p_val,
     bool deterministic,
-    int64_t cuda_stream);
+    std::optional<at::Generator> gen);
 void top_p_sampling_from_probs(
     at::Tensor probs,
-    at::Tensor uniform_samples,
-    at::Tensor samples,
-    at::Tensor success,
+    at::Tensor output,
+    std::optional<at::Tensor> maybe_indices,
     std::optional<at::Tensor> maybe_top_p_arr,
     double top_p_val,
     bool deterministic,
-    int64_t cuda_stream);
+    std::optional<at::Generator> gen);

 namespace flash {
 /*
diff --git a/sgl-kernel/python/sgl_kernel/elementwise.py b/sgl-kernel/python/sgl_kernel/elementwise.py
index 307df2a5b3f..dd717d0ca3d 100644
--- a/sgl-kernel/python/sgl_kernel/elementwise.py
+++ b/sgl-kernel/python/sgl_kernel/elementwise.py
@@ -11,17 +11,69 @@ def rmsnorm(
     weight: torch.Tensor,
     eps: float = 1e-6,
     out: Optional[torch.Tensor] = None,
+    enable_pdl: bool = False,
 ) -> torch.Tensor:
+    r"""Root mean square normalization.
+
+    ``out[i] = (input[i] / RMS(input)) * weight[i]``
+
+    Parameters
+    ----------
+    input: torch.Tensor
+        Input tensor, shape (batch_size, hidden_size).
+    weight: torch.Tensor
+        Weight tensor, shape (hidden_size,).
+    eps: float
+        Epsilon for numerical stability.
+    out: Optional[torch.Tensor]
+        The output tensor. If specified, the kernel updates this tensor in place.
+    enable_pdl: bool
+        Whether to enable `programmatic dependent launch
+        <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
+
+    Returns
+    -------
+    output: torch.Tensor
+        Normalized tensor, shape (batch_size, hidden_size).
+    """
     if out is None:
         out = torch.empty_like(input)
-    torch.ops.sgl_kernel.rmsnorm.default(out, input, weight, eps, get_cuda_stream())
+    torch.ops.sgl_kernel.rmsnorm.default(out, input, weight, eps, enable_pdl)
     return out
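A usage sketch for the updated wrapper (assuming `rmsnorm` is re-exported at the `sgl_kernel` package level; shapes are illustrative, and per NVIDIA's documentation PDL only takes effect on architectures that support it, e.g. Hopper and newer):

```python
import torch
from sgl_kernel import rmsnorm  # assumed package-level re-export

x = torch.randn(8, 4096, device="cuda", dtype=torch.float16)
w = torch.randn(4096, device="cuda", dtype=torch.float16)

y = rmsnorm(x, w)                      # allocates the output tensor
rmsnorm(x, w, out=y, enable_pdl=True)  # writes into `y` in place
```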
+ """ if out is None: out = torch.empty_like(input) - torch.ops.sgl_kernel.rmsnorm.default(out, input, weight, eps, get_cuda_stream()) + torch.ops.sgl_kernel.rmsnorm.default(out, input, weight, eps, enable_pdl) return out def fused_add_rmsnorm( - input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6 + input: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + eps: float = 1e-6, + enable_pdl: bool = False, ) -> None: - torch.ops.sgl_kernel.fused_add_rmsnorm.default(input, residual, weight, eps) + r"""Fused add root mean square normalization. + + Step 1: + ``residual[i] += input[i]`` + + Step 2: + ``input[i] = (residual[i] / RMS(residual)) * weight[i]`` + + Parameters + ---------- + input: torch.Tensor + Input tensor, shape (batch_size, hidden_size). + residual: torch.Tensor + Residual tensor, shape (batch_size, hidden_size). + weight: torch.Tensor + Weight tensor, shape (hidden_size,). + eps: float + Epsilon for numerical stability. + enable_pdl: bool + Whether to enable `programmatic dependent launch + `_ + """ + torch.ops.sgl_kernel.fused_add_rmsnorm.default( + input, residual, weight, eps, enable_pdl + ) def gemma_rmsnorm( @@ -29,20 +81,68 @@ def gemma_rmsnorm( weight: torch.Tensor, eps: float = 1e-6, out: Optional[torch.Tensor] = None, + enable_pdl: bool = False, ) -> torch.Tensor: + r"""Gemma-style root mean square normalization. + + ``out[i] = (input[i] / RMS(input)) * (weight[i] + 1)`` + + Parameters + ---------- + input: torch.Tensor + Input tensor, shape (batch_size, hidden_size). + weight: torch.Tensor + Weight tensor, shape (hidden_size,). + eps: float + Epsilon for numerical stability. + out: Optional[torch.Tensor] + The output tensor, if specified, the kernel will update this tensor inplace. + enable_pdl: bool + Whether to enable `programmatic dependent launch + `_ + + Returns + ------- + output: torch.Tensor + Gemma Normalized tensor, shape (batch_size, hidden_size). + """ if out is None: out = torch.empty_like(input) - torch.ops.sgl_kernel.gemma_rmsnorm.default( - out, input, weight, eps, get_cuda_stream() - ) + torch.ops.sgl_kernel.gemma_rmsnorm.default(out, input, weight, eps, enable_pdl) return out def gemma_fused_add_rmsnorm( - input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6 + input: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + eps: float = 1e-6, + enable_pdl: bool = False, ) -> None: + r"""Gemma-style fused add root mean square normalization. + + Step 1: + ``residual[i] += input[i]`` + + Step 2: + ``input[i] = (residual[i] / RMS(residual)) * (weight + 1)`` + + Parameters + ---------- + input: torch.Tensor + Input tensor, shape (batch_size, hidden_size). + residual: torch.Tensor + Residual tensor, shape (batch_size, hidden_size). + weight: torch.Tensor + Weight tensor, shape (hidden_size,). + eps: float + Epsilon for numerical stability. 
@@ -29,20 +81,68 @@ def gemma_rmsnorm(
     weight: torch.Tensor,
     eps: float = 1e-6,
     out: Optional[torch.Tensor] = None,
+    enable_pdl: bool = False,
 ) -> torch.Tensor:
+    r"""Gemma-style root mean square normalization.
+
+    ``out[i] = (input[i] / RMS(input)) * (weight[i] + 1)``
+
+    Parameters
+    ----------
+    input: torch.Tensor
+        Input tensor, shape (batch_size, hidden_size).
+    weight: torch.Tensor
+        Weight tensor, shape (hidden_size,).
+    eps: float
+        Epsilon for numerical stability.
+    out: Optional[torch.Tensor]
+        The output tensor. If specified, the kernel updates this tensor in place.
+    enable_pdl: bool
+        Whether to enable `programmatic dependent launch
+        <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
+
+    Returns
+    -------
+    output: torch.Tensor
+        Gemma-normalized tensor, shape (batch_size, hidden_size).
+    """
     if out is None:
         out = torch.empty_like(input)
-    torch.ops.sgl_kernel.gemma_rmsnorm.default(
-        out, input, weight, eps, get_cuda_stream()
-    )
+    torch.ops.sgl_kernel.gemma_rmsnorm.default(out, input, weight, eps, enable_pdl)
     return out


 def gemma_fused_add_rmsnorm(
-    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
+    input: torch.Tensor,
+    residual: torch.Tensor,
+    weight: torch.Tensor,
+    eps: float = 1e-6,
+    enable_pdl: bool = False,
 ) -> None:
+    r"""Gemma-style fused add root mean square normalization.
+
+    Step 1:
+    ``residual[i] += input[i]``
+
+    Step 2:
+    ``input[i] = (residual[i] / RMS(residual)) * (weight[i] + 1)``
+
+    Parameters
+    ----------
+    input: torch.Tensor
+        Input tensor, shape (batch_size, hidden_size).
+    residual: torch.Tensor
+        Residual tensor, shape (batch_size, hidden_size).
+    weight: torch.Tensor
+        Weight tensor, shape (hidden_size,).
+    eps: float
+        Epsilon for numerical stability.
+    enable_pdl: bool
+        Whether to enable `programmatic dependent launch
+        <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
+    """
     torch.ops.sgl_kernel.gemma_fused_add_rmsnorm.default(
-        input, residual, weight, eps, get_cuda_stream()
+        input, residual, weight, eps, enable_pdl
     )
diff --git a/sgl-kernel/python/sgl_kernel/sampling.py b/sgl-kernel/python/sgl_kernel/sampling.py
index 7c94e9eda5a..5bc0be6c331 100644
--- a/sgl-kernel/python/sgl_kernel/sampling.py
+++ b/sgl-kernel/python/sgl_kernel/sampling.py
@@ -13,11 +13,7 @@ def _top_k_renorm_probs_internal(
     maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
     renorm_probs = torch.empty_like(probs)
     torch.ops.sgl_kernel.top_k_renorm_probs.default(
-        probs,
-        renorm_probs,
-        maybe_top_k_arr,
-        top_k_val,
-        get_cuda_stream(),
+        probs, renorm_probs, maybe_top_k_arr, top_k_val
     )
     return renorm_probs

@@ -26,6 +22,30 @@ def top_k_renorm_probs(
     probs: torch.Tensor,
     top_k: Union[torch.Tensor, int],
 ) -> torch.Tensor:
+    r"""Adapted from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
+    Fused GPU kernel for renormalizing probabilities by top-k thresholding.
+
+    Parameters
+    ----------
+    probs: torch.Tensor
+        Probabilities, shape ``(batch_size, num_classes)``.
+    top_k: Union[torch.Tensor, int]
+        Either a scalar or a tensor of shape ``(batch_size,)``, representing the top-k threshold
+        for re-normalizing probabilities, which should be in ``(0, num_classes)``.
+        If a scalar, the same threshold is used for all requests.
+        If a tensor, each request has its own threshold.
+        We keep the top-k probabilities, set the rest to zero, and renormalize the probabilities.
+
+    Returns
+    -------
+    renorm_probs: torch.Tensor
+        Renormalized probabilities, shape ``(batch_size, num_classes)``.
+
+    Note
+    ----
+    This combination of ``top_k_renorm_probs`` and ``sampling_from_probs`` should be equivalent to
+    ``top_k_sampling_from_probs``.
+    """
     return _top_k_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_k))

@@ -41,11 +61,7 @@ def _top_p_renorm_probs_internal(
     maybe_top_p_arr = maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
     renorm_probs = torch.empty_like(probs)
     torch.ops.sgl_kernel.top_p_renorm_probs.default(
-        probs,
-        renorm_probs,
-        maybe_top_p_arr,
-        top_p_val,
-        get_cuda_stream(),
+        probs, renorm_probs, maybe_top_p_arr, top_p_val
     )
     return renorm_probs

@@ -54,6 +70,32 @@ def top_p_renorm_probs(
     probs: torch.Tensor,
     top_p: Union[torch.Tensor, float],
 ) -> torch.Tensor:
+    r"""Adapted from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
+    Fused GPU kernel for renormalizing probabilities by top-p thresholding.
+
+    Parameters
+    ----------
+    probs: torch.Tensor
+        Probabilities, shape ``(batch_size, num_classes)``.
+    top_p: Union[torch.Tensor, float]
+        Either a scalar or a tensor of shape ``(batch_size,)``, representing the top-p threshold
+        for re-normalizing probabilities, which should be in ``(0, 1)``.
+        If a scalar, the same threshold is used for all requests.
+        If a tensor, each request has its own threshold.
+        We mask out the probabilities less than ``threshold``, where the cumulative sum
+        of ``probs[probs >= threshold]`` is ``top_p``, and renormalize the probabilities.
+
+    Returns
+    -------
+    renorm_probs: torch.Tensor
+        Renormalized probabilities, shape ``(batch_size, num_classes)``.
+
+    Note
+    ----
+    This combination of ``top_p_renorm_probs`` and ``sampling_from_probs`` should be equivalent to
+    ``top_p_sampling_from_probs``.
+    """
     return _top_p_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_p))
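A quick sanity-check sketch for the two renorm helpers (the tests below call them through the `sgl_kernel` module, albeit via the singular `*_renorm_prob` aliases; batch and vocabulary sizes here are arbitrary):

```python
import torch
import sgl_kernel

probs = torch.softmax(torch.randn(4, 32000, device="cuda"), dim=-1)

renorm_k = sgl_kernel.top_k_renorm_probs(probs, 50)   # keep the 50 largest entries per row
renorm_p = sgl_kernel.top_p_renorm_probs(probs, 0.9)  # keep the 0.9-mass nucleus per row

# Both results are proper distributions again.
ones = torch.ones(4, device="cuda")
assert torch.allclose(renorm_k.sum(-1), ones, atol=1e-4)
assert torch.allclose(renorm_p.sum(-1), ones, atol=1e-4)
```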
+ + """ return _top_p_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_p)) @@ -62,93 +104,187 @@ def top_p_renorm_probs( def _top_p_sampling_from_probs_internal( probs: torch.Tensor, - uniform_samples: torch.Tensor, + indices: Optional[torch.Tensor], maybe_top_p_arr: Optional[torch.Tensor], top_p_val: float, deterministic: bool, + generator: Optional[torch.Generator], ) -> Tuple[torch.Tensor, torch.Tensor]: with probs.device as device: probs = probs.float() - uniform_samples = uniform_samples.float() maybe_top_p_arr = ( maybe_top_p_arr.float() if maybe_top_p_arr is not None else None ) samples = torch.empty(probs.size(0), dtype=torch.int32, device=device) - success = torch.empty(probs.size(0), dtype=torch.bool, device=device) torch.ops.sgl_kernel.top_p_sampling_from_probs.default( probs, - uniform_samples, samples, - success, + indices, maybe_top_p_arr, top_p_val, deterministic, - get_cuda_stream(), + generator, ) - return samples, success + return samples def top_p_sampling_from_probs( probs: torch.Tensor, - uniform_samples: torch.Tensor, top_p: Union[torch.Tensor, float], + indices: Optional[torch.Tensor] = None, deterministic: bool = True, + generator: Optional[torch.Generator] = None, check_nan: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: + r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py + Fused GPU kernel for top-p sampling (nucleus sampling) from probabilities, + this operator implements GPU-based rejection sampling without explicit sorting. + Check the `blog post `_ for more details. + + The multiple rounds of rejection sampling are implemented in a single CUDA kernel, + which is more efficient than the naive implementation that launches a series of kernels. + + Parameters + ---------- + probs: torch.Tensor + Probabilities for sampling. When indices is not provided, shape should be ``(batch_size, num_classes)`` + and the i-th output will be sampled from the i-th row of probabilities. When indices is provided, + shape should be ``(unique_batch_size, num_classes)`` where unique_batch_size is the number of unique + probability distributions. + top_p: Union[torch.Tensor, float] + Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-p sampling. + If a scalar, the same threshold is used for all requests. + If a tensor, each request has its own threshold. + indices: Optional[torch.Tensor] + Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs. + For example, if indices[i] = j, then the i-th output will be sampled from probs[j]. + This allows reusing the same probability distribution for multiple outputs. + If indices is not provided, the i-th output will be sampled from the i-th row of probs. + deterministic: bool + Whether to use deterministic kernel implementation, default is ``True``. + generator: Optional[torch.Generator] + A random number generator for the operation. + check_nan: bool + Whether to check nan in :attr:`probs`, default is ``False``. + + Returns + ------- + samples: torch.Tensor + Sampled categories, shape ``(batch_size,)``. + + Note + ---- + This function expects float32 inputs, and the output is int32. 
+ + """ if check_nan: if torch.any(torch.isnan(probs)): raise ValueError("Input probs contains NaN.") return _top_p_sampling_from_probs_internal( - probs, uniform_samples, *_to_tensor_scalar_tuple(top_p), deterministic + probs, indices, *_to_tensor_scalar_tuple(top_p), deterministic, generator ) def _top_k_top_p_sampling_from_probs_internal( probs: torch.Tensor, - uniform_samples: torch.Tensor, + indices: Optional[torch.Tensor], maybe_top_k_arr: Optional[torch.Tensor], top_k_val: int, maybe_top_p_arr: Optional[torch.Tensor], top_p_val: float, deterministic: bool, + generator: Optional[torch.Generator], ) -> Tuple[torch.Tensor, torch.Tensor]: with probs.device as device: probs = probs.float() - uniform_samples = uniform_samples.float() maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None maybe_top_p_arr = ( maybe_top_p_arr.float() if maybe_top_p_arr is not None else None ) samples = torch.empty(probs.size(0), dtype=torch.int32, device=device) - success = torch.empty(probs.size(0), dtype=torch.bool, device=device) torch.ops.sgl_kernel.top_k_top_p_sampling_from_probs.default( probs, - uniform_samples, samples, - success, + indices, maybe_top_k_arr, top_k_val, maybe_top_p_arr, top_p_val, deterministic, - get_cuda_stream(), + generator, ) - return samples, success + return samples def top_k_top_p_sampling_from_probs( probs: torch.Tensor, - uniform_samples: torch.Tensor, top_k: Union[torch.Tensor, int], top_p: Union[torch.Tensor, float], + indices: Optional[torch.Tensor] = None, filter_apply_order: str = "top_k_first", deterministic: bool = True, + generator: Optional[torch.Generator] = None, check_nan: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: + r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py + Fused GPU kernel for top-k and top-p sampling from probabilities, + + this operator implements GPU-based rejection sampling without explicit sorting. + Check the `blog post `_ for more details. + + The multiple rounds of rejection sampling are implemented in a single CUDA kernel, + which is more efficient than the naive implementation that launches a series of kernels. + + Parameters + ---------- + probs: torch.Tensor + Probabilities for sampling. When indices is not provided, shape should be ``(batch_size, num_classes)`` + and the i-th output will be sampled from the i-th row of probabilities. When indices is provided, + shape should be ``(unique_batch_size, num_classes)`` where unique_batch_size is the number of unique + probability distributions. + top_k: Union[torch.Tensor, int] + Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-k sampling. + If a scalar, the same threshold is used for all requests. + If a tensor, each request has its own threshold. + top_p: Union[torch.Tensor, float] + Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-p sampling. + If a scalar, the same threshold is used for all requests. + If a tensor, each request has its own threshold. + indices: Optional[torch.Tensor] + Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs. + For example, if indices[i] = j, then the i-th output will be sampled from probs[j]. + This allows reusing the same probability distribution for multiple outputs. + If indices is not provided, the i-th output will be sampled from the i-th row of probs. 
@@ -167,44 +304,82 @@ def top_k_top_p_sampling_from_probs(

 def _min_p_sampling_from_probs_internal(
     probs: torch.Tensor,
-    uniform_samples: torch.Tensor,
+    indices: Optional[torch.Tensor],
     maybe_min_p_arr: Optional[torch.Tensor],
     min_p_val: float,
     deterministic: bool,
+    generator: Optional[torch.Generator],
 ) -> torch.Tensor:
     with probs.device as device:
         probs = probs.float()
-        uniform_samples = uniform_samples.float()
         maybe_min_p_arr = (
             maybe_min_p_arr.float() if maybe_min_p_arr is not None else None
         )
         samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
         torch.ops.sgl_kernel.min_p_sampling_from_probs.default(
             probs,
-            uniform_samples,
             samples,
+            indices,
             maybe_min_p_arr,
             min_p_val,
             deterministic,
-            get_cuda_stream(),
+            generator,
         )
         return samples


 def min_p_sampling_from_probs(
     probs: torch.Tensor,
-    uniform_samples: torch.Tensor,
     min_p: Union[torch.Tensor, float],
+    indices: Optional[torch.Tensor] = None,
     deterministic: bool = True,
+    generator: Optional[torch.Generator] = None,
     check_nan: bool = False,
 ) -> torch.Tensor:
-    if uniform_samples.dim() == 2:
-        # Take the first row (round) of uniform_samples
-        uniform_samples = uniform_samples[0]
+    r"""Adapted from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
+    Fused GPU kernel for `min_p sampling <https://arxiv.org/abs/2407.01082>`_ from probabilities.
+
+    This operator implements GPU-based rejection sampling without explicit sorting.
+    Check the `blog post <https://flashinfer.ai/2025/03/10/sampling.html>`_ for more details.
+
+    The multiple rounds of rejection sampling are implemented in a single CUDA kernel,
+    which is more efficient than the naive implementation that launches a series of kernels.
+
+    Parameters
+    ----------
+    probs: torch.Tensor
+        Probabilities for sampling. When indices is not provided, shape should be ``(batch_size, num_classes)``
+        and the i-th output will be sampled from the i-th row of probabilities. When indices is provided,
+        shape should be ``(unique_batch_size, num_classes)`` where unique_batch_size is the number of unique
+        probability distributions.
+    min_p: Union[torch.Tensor, float]
+        Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for min-p sampling.
+        If a scalar, the same threshold is used for all requests.
+        If a tensor, each request has its own threshold.
+    indices: Optional[torch.Tensor]
+        Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs.
+        For example, if indices[i] = j, then the i-th output will be sampled from probs[j].
+        This allows reusing the same probability distribution for multiple outputs.
+        If indices is not provided, the i-th output will be sampled from the i-th row of probs.
+    deterministic: bool
+        Whether to use deterministic kernel implementation, default is ``True``.
+    generator: Optional[torch.Generator]
+        A random number generator for the operation.
+    check_nan: bool
+        Whether to check nan in :attr:`probs`, default is ``False``.
+
+    Returns
+    -------
+    samples: torch.Tensor
+        Sampled categories, shape ``(batch_size,)``.
+
+    Note
+    ----
+    This function expects float32 inputs, and the output is int32.
+    """
     if check_nan:
         if torch.any(torch.isnan(probs)):
             raise ValueError("Input probs contains NaN.")
     return _min_p_sampling_from_probs_internal(
-        probs, uniform_samples, *_to_tensor_scalar_tuple(min_p), deterministic
+        probs, indices, *_to_tensor_scalar_tuple(min_p), deterministic, generator
     )
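A matching sketch for min-p, whose threshold is scaled by each row's maximum probability:

```python
import torch
import sgl_kernel

probs = torch.softmax(torch.randn(4, 32000, device="cuda"), dim=-1)

# Keep tokens with prob >= 0.1 * max(prob) per row, then sample one id per row.
samples = sgl_kernel.min_p_sampling_from_probs(probs, 0.1)
```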
diff --git a/sgl-kernel/tests/test_sampling.py b/sgl-kernel/tests/test_sampling.py
index 7d3bc5059ee..14f41e5efde 100644
--- a/sgl-kernel/tests/test_sampling.py
+++ b/sgl-kernel/tests/test_sampling.py
@@ -5,8 +5,8 @@
 import torch


-@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
-@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
+@pytest.mark.parametrize("batch_size", [1, 99, 989])
+@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
 @pytest.mark.parametrize("p", [0.1, 0.5])
 def test_top_k_top_p_joint_sampling_from_probs(batch_size, vocab_size, p):
     torch.manual_seed(42)
@@ -16,14 +16,13 @@ def test_top_k_top_p_joint_sampling_from_probs(batch_size, vocab_size, p):
         k = int(vocab_size * 0.1)
     else:
         raise ValueError("p not recognized")
-    max_top_k_trails = 32
     eps = 1e-4
-    pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
+    pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0")
     normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
     # top-p mask
     sorted_prob, indices = torch.sort(normalized_prob, descending=False)
     cdf = torch.cumsum(sorted_prob, dim=-1)
-    mask_top_p = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0)
+    mask_top_p = torch.zeros(batch_size, vocab_size, dtype=torch.int32, device="cuda:0")
     mask_top_p.scatter_add_(1, indices, (cdf > (1 - p) - eps).int())
     # top-k mask
     sorted_prob, _ = torch.sort(normalized_prob, descending=True)
@@ -31,40 +30,35 @@ def test_top_k_top_p_joint_sampling_from_probs(batch_size, vocab_size, p):
     mask_top_k = (normalized_prob >= pivot.unsqueeze(-1)).int()
     # overall mask
     mask = torch.minimum(mask_top_p, mask_top_k)
-    uniform_samples = torch.empty(max_top_k_trails, batch_size, dtype=torch.float32).to(
-        0
-    )
-    top_p_tensor = torch.full((batch_size,), p).to(0)
-    top_k_tensor = torch.full((batch_size,), k).to(0)
+    top_p_tensor = torch.full((batch_size,), p, device="cuda:0")
+    top_k_tensor = torch.full((batch_size,), k, device="cuda:0")

     num_trails = 1000
     for _ in range(num_trails):
-        uniform_samples.uniform_()
-        samples, success = sgl_kernel.top_k_top_p_sampling_from_probs(
+        samples = sgl_kernel.top_k_top_p_sampling_from_probs(
             normalized_prob,
-            uniform_samples,
             top_k_tensor,
             top_p_tensor,
             filter_apply_order="joint",
         )

-        assert torch.all(success)
         assert torch.all(samples < vocab_size) and torch.all(samples >= 0)
         assert torch.all(mask[torch.arange(batch_size), samples] == 1), normalized_prob[
             torch.arange(batch_size), samples
         ]


-@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
-@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
+@pytest.mark.parametrize("batch_size", [1, 99, 989])
+@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
 @pytest.mark.parametrize("p", [0.1, 0.5, 0.9])
 def test_top_p_renorm_probs(batch_size, vocab_size, p):
-    pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
+    torch.manual_seed(42)
+    pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0")
     normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
     sorted_prob, indices = torch.sort(normalized_prob, descending=False)
     cdf = torch.cumsum(sorted_prob, dim=-1)
-    mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0)
+    mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32, device="cuda:0")
     mask.scatter_add_(1, indices, (cdf >= (1 - p)).int())
-    renorm_prob_ground_truth = normalized_prob
+    renorm_prob_ground_truth = normalized_prob.clone()
     renorm_prob_ground_truth[mask == 0] = 0
     renorm_prob_ground_truth = renorm_prob_ground_truth / renorm_prob_ground_truth.sum(
         dim=-1, keepdim=True
     )
@@ -79,56 +73,54 @@ def test_top_p_renorm_probs(batch_size, vocab_size, p):
     )


-@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
-@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
+@pytest.mark.parametrize("batch_size", [1, 99, 989])
+@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
 @pytest.mark.parametrize("k", [10, 100, 500])
 def test_top_k_renorm_probs(batch_size, vocab_size, k):
     if k > vocab_size:
         pytest.skip("k should be less than vocab_size")
     torch.manual_seed(42)
-    pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
+    pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0")
     normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
     sorted_prob, _ = torch.sort(normalized_prob, descending=True)
     pivot = sorted_prob[:, k - 1]
     mask = (normalized_prob >= pivot.unsqueeze(-1)).int()
-    renorm_prob_ground_truth = normalized_prob
+    renorm_prob_ground_truth = normalized_prob.clone()
     renorm_prob_ground_truth[mask == 0] = 0
     renorm_prob_ground_truth = renorm_prob_ground_truth / renorm_prob_ground_truth.sum(
         dim=-1, keepdim=True
     )

     renorm_prob = sgl_kernel.top_k_renorm_prob(normalized_prob, k)
-    torch.testing.assert_close(
-        renorm_prob_ground_truth,
-        renorm_prob,
-        rtol=1e-3,
-        atol=1e-3,
-    )
+    for i in range(batch_size):
+        torch.testing.assert_close(
+            renorm_prob_ground_truth[i],
+            renorm_prob[i],
+            rtol=1e-3,
+            atol=1e-3,
+        )


-@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
-@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
+@pytest.mark.parametrize("batch_size", [1, 99, 989])
+@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
 @pytest.mark.parametrize("p", [0.05, 0.1, 0.2, 0.7, 1])
 def test_min_p_sampling(batch_size, vocab_size, p):
     torch.manual_seed(42)
-    pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
+    pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0")
     normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
     sorted_prob, indices = torch.sort(normalized_prob, descending=False)
     # scale min-p
     top_probs = sorted_prob[:, -1].unsqueeze(-1)
     scaled_p = p * top_probs
     # min-p mask
-    mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0)
+    mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32, device="cuda:0")
     mask.scatter_add_(1, indices, (sorted_prob >= scaled_p).int())
-    uniform_samples = torch.empty(batch_size, dtype=torch.float32).to(0)
-    min_p_tensor = torch.full((batch_size,), p).to(0)
+    min_p_tensor = torch.full((batch_size,), p, device="cuda:0")

     num_trails = 1000
     for _ in range(num_trails):
-        uniform_samples.uniform_()
         samples = sgl_kernel.min_p_sampling_from_probs(
             normalized_prob,
-            uniform_samples,
             min_p_tensor,
         )

         assert torch.all(mask[torch.arange(batch_size), samples] == 1), samples[
             torch.nonzero(mask[torch.arange(batch_size), samples] == 0)
         ]


 if __name__ == "__main__":
     pytest.main([__file__])