diff --git a/csrc/flashinfer_sampling_binding.cu b/csrc/flashinfer_sampling_binding.cu index a3a6a08275..06488a2c8e 100644 --- a/csrc/flashinfer_sampling_binding.cu +++ b/csrc/flashinfer_sampling_binding.cu @@ -20,8 +20,9 @@ using tvm::ffi::Optional; void softmax(TensorView workspace_buffer, TensorView logits, TensorView output, Optional maybe_temperature_arr, double temperature_val, bool enable_pdl); -void sampling_from_probs(TensorView probs, TensorView output, Optional maybe_indices, - bool deterministic, Optional maybe_seed_arr, uint64_t seed_val, +void sampling_from_probs(TensorView probs, TensorView output, TensorView valid, + Optional maybe_indices, bool deterministic, + Optional maybe_seed_arr, uint64_t seed_val, Optional maybe_offset_arr, uint64_t offset_val); void sampling_from_logits(TensorView logits, TensorView output, Optional maybe_indices, @@ -29,28 +30,28 @@ void sampling_from_logits(TensorView logits, TensorView output, Optional maybe_offset_arr, uint64_t offset_val); -void top_p_sampling_from_probs(TensorView probs, TensorView output, +void top_p_sampling_from_probs(TensorView probs, TensorView output, TensorView valid, Optional maybe_indices, Optional maybe_top_p_arr, double top_p_val, bool deterministic, Optional maybe_seed_arr, uint64_t seed_val, Optional maybe_offset_arr, uint64_t offset_val); -void top_k_sampling_from_probs(TensorView probs, TensorView output, +void top_k_sampling_from_probs(TensorView probs, TensorView output, TensorView valid, Optional maybe_indices, Optional maybe_top_k_arr, int64_t top_k_val, bool deterministic, Optional maybe_seed_arr, uint64_t seed_val, Optional maybe_offset_arr, uint64_t offset_val); -void min_p_sampling_from_probs(TensorView probs, TensorView output, +void min_p_sampling_from_probs(TensorView probs, TensorView output, TensorView valid, Optional maybe_indices, Optional maybe_min_p_arr, double min_p_val, bool deterministic, Optional maybe_seed_arr, uint64_t seed_val, Optional maybe_offset_arr, uint64_t 
offset_val); -void top_k_top_p_sampling_from_probs(TensorView probs, TensorView output, +void top_k_top_p_sampling_from_probs(TensorView probs, TensorView output, TensorView valid, Optional maybe_indices, Optional maybe_top_k_arr, double top_k_val, Optional maybe_top_p_arr, double top_p_val, diff --git a/csrc/sampling.cu b/csrc/sampling.cu index 45f6dd9634..8b9500c444 100644 --- a/csrc/sampling.cu +++ b/csrc/sampling.cu @@ -101,11 +101,15 @@ void sampling_from_logits(TensorView logits, TensorView output, Optional maybe_indices, - bool deterministic, Optional maybe_seed_arr, uint64_t seed_val, +void sampling_from_probs(TensorView probs, TensorView output, TensorView valid, + Optional maybe_indices, bool deterministic, + Optional maybe_seed_arr, uint64_t seed_val, Optional maybe_offset_arr, uint64_t offset_val) { CHECK_INPUT(probs); CHECK_DIM(2, probs); // probs: (batch_size, vocab_size) + CHECK_INPUT(valid); + CHECK_DIM(1, valid); + CHECK_DEVICE(valid, probs); CHECK_MAYBE_INPUT_TYPES(maybe_indices, dl_int32, dl_int64); CHECK_MAYBE_SAME_DTYPE(maybe_indices, output); validate_seed_offset_tensors(maybe_seed_arr, maybe_offset_arr, probs); @@ -119,6 +123,7 @@ void sampling_from_probs(TensorView probs, TensorView output, Optional( static_cast(probs.data_ptr()), static_cast(output.data_ptr()), + static_cast(valid.data_ptr()), maybe_indices.has_value() ? 
static_cast(maybe_indices.value().data_ptr()) : nullptr, batch_size, vocab_size, deterministic, @@ -134,7 +139,7 @@ void sampling_from_probs(TensorView probs, TensorView output, Optional maybe_indices, Optional maybe_top_p_arr, double top_p_val, bool deterministic, Optional maybe_seed_arr, @@ -142,6 +147,9 @@ void top_p_sampling_from_probs(TensorView probs, TensorView output, uint64_t offset_val) { CHECK_INPUT(probs); CHECK_DIM(2, probs); // probs: (batch_size, vocab_size) + CHECK_INPUT(valid); + CHECK_DIM(1, valid); + CHECK_DEVICE(valid, probs); CHECK_MAYBE_INPUT_TYPES(maybe_indices, dl_int32, dl_int64); CHECK_MAYBE_SAME_DTYPE(maybe_indices, output); validate_seed_offset_tensors(maybe_seed_arr, maybe_offset_arr, probs); @@ -157,6 +165,7 @@ void top_p_sampling_from_probs(TensorView probs, TensorView output, DISPATCH_DLPACK_IDTYPE_TO_CTYPE(output.dtype(), IdType, [&] { cudaError_t status = sampling::TopPSamplingFromProb( static_cast(probs.data_ptr()), static_cast(output.data_ptr()), + static_cast(valid.data_ptr()), maybe_indices.has_value() ? static_cast(maybe_indices.value().data_ptr()) : nullptr, has_top_p_arr ? 
static_cast(maybe_top_p_arr.value().data_ptr()) : nullptr, @@ -173,7 +182,7 @@ void top_p_sampling_from_probs(TensorView probs, TensorView output, }); } -void top_k_sampling_from_probs(TensorView probs, TensorView output, +void top_k_sampling_from_probs(TensorView probs, TensorView output, TensorView valid, Optional maybe_indices, Optional maybe_top_k_arr, int64_t top_k_val, bool deterministic, Optional maybe_seed_arr, @@ -184,6 +193,9 @@ void top_k_sampling_from_probs(TensorView probs, TensorView output, CHECK_DEVICE(output, probs); CHECK_DIM(2, probs); // probs: (batch_size, vocab_size) CHECK_DIM(1, output); // output: (batch_size) + CHECK_INPUT(valid); + CHECK_DIM(1, valid); + CHECK_DEVICE(valid, probs); CHECK_MAYBE_INPUT_TYPES(maybe_indices, dl_int32, dl_int64); CHECK_MAYBE_SAME_DTYPE(maybe_indices, output); validate_seed_offset_tensors(maybe_seed_arr, maybe_offset_arr, probs); @@ -199,6 +211,7 @@ void top_k_sampling_from_probs(TensorView probs, TensorView output, DISPATCH_DLPACK_IDTYPE_TO_CTYPE(output.dtype(), IdType, [&] { cudaError_t status = sampling::TopKSamplingFromProb( static_cast(probs.data_ptr()), static_cast(output.data_ptr()), + static_cast(valid.data_ptr()), maybe_indices.has_value() ? static_cast(maybe_indices.value().data_ptr()) : nullptr, has_top_k_arr ? 
static_cast(maybe_top_k_arr.value().data_ptr()) : nullptr, @@ -215,7 +228,7 @@ void top_k_sampling_from_probs(TensorView probs, TensorView output, }); } -void min_p_sampling_from_probs(TensorView probs, TensorView output, +void min_p_sampling_from_probs(TensorView probs, TensorView output, TensorView valid, Optional maybe_indices, Optional maybe_min_p_arr, double min_p_val, bool deterministic, Optional maybe_seed_arr, @@ -226,6 +239,9 @@ void min_p_sampling_from_probs(TensorView probs, TensorView output, CHECK_DEVICE(output, probs); CHECK_DIM(2, probs); // probs: (batch_size, vocab_size) CHECK_DIM(1, output); // output: (batch_size) + CHECK_INPUT(valid); + CHECK_DIM(1, valid); + CHECK_DEVICE(valid, probs); CHECK_MAYBE_INPUT_TYPES(maybe_indices, dl_int32, dl_int64); CHECK_MAYBE_SAME_DTYPE(maybe_indices, output); validate_seed_offset_tensors(maybe_seed_arr, maybe_offset_arr, probs); @@ -242,7 +258,7 @@ void min_p_sampling_from_probs(TensorView probs, TensorView output, cudaError_t status = sampling::MinPSamplingFromProb( static_cast(probs.data_ptr()), has_min_p_arr ? static_cast(maybe_min_p_arr.value().data_ptr()) : nullptr, - static_cast(output.data_ptr()), + static_cast(output.data_ptr()), static_cast(valid.data_ptr()), maybe_indices.has_value() ? 
static_cast(maybe_indices.value().data_ptr()) : nullptr, batch_size, min_p_val, vocab_size, deterministic, @@ -258,7 +274,7 @@ void min_p_sampling_from_probs(TensorView probs, TensorView output, }); } -void top_k_top_p_sampling_from_probs(TensorView probs, TensorView output, +void top_k_top_p_sampling_from_probs(TensorView probs, TensorView output, TensorView valid, Optional maybe_indices, Optional maybe_top_k_arr, double top_k_val, Optional maybe_top_p_arr, double top_p_val, @@ -270,6 +286,9 @@ void top_k_top_p_sampling_from_probs(TensorView probs, TensorView output, CHECK_DEVICE(output, probs); CHECK_DIM(2, probs); // probs: (batch_size, vocab_size) CHECK_DIM(1, output); // output: (batch_size) + CHECK_INPUT(valid); + CHECK_DIM(1, valid); + CHECK_DEVICE(valid, probs); CHECK_MAYBE_INPUT_TYPES(maybe_indices, dl_int32, dl_int64); CHECK_MAYBE_SAME_DTYPE(maybe_indices, output); validate_seed_offset_tensors(maybe_seed_arr, maybe_offset_arr, probs); @@ -289,7 +308,7 @@ void top_k_top_p_sampling_from_probs(TensorView probs, TensorView output, static_cast(probs.data_ptr()), has_top_k_arr ? static_cast(maybe_top_k_arr.value().data_ptr()) : nullptr, has_top_p_arr ? static_cast(maybe_top_p_arr.value().data_ptr()) : nullptr, - static_cast(output.data_ptr()), + static_cast(output.data_ptr()), static_cast(valid.data_ptr()), maybe_indices.has_value() ? 
static_cast(maybe_indices.value().data_ptr()) : nullptr, batch_size, top_k_val, top_p_val, vocab_size, deterministic, diff --git a/flashinfer/sampling.py b/flashinfer/sampling.py index 034109b510..3e30f34816 100644 --- a/flashinfer/sampling.py +++ b/flashinfer/sampling.py @@ -145,12 +145,14 @@ def sampling_from_probs( generator: Optional[torch.Generator], seed: Optional[Union[int, torch.Tensor]] = None, offset: Optional[Union[int, torch.Tensor]] = None, - ) -> torch.Tensor: + return_valid: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: device = probs.device probs = probs.float() batch_size = indices.size(0) if indices is not None else probs.size(0) out_dtype = indices.dtype if indices is not None else torch.int32 samples = torch.empty(batch_size, dtype=out_dtype, device=device) + valid = torch.empty(batch_size, dtype=torch.bool, device=device) if seed is None or offset is None: seed, offset = get_seed_and_offset(batch_size, generator, device) @@ -161,6 +163,7 @@ def sampling_from_probs( module.sampling_from_probs( probs, samples, + valid, indices, deterministic, maybe_seed_arr, @@ -168,6 +171,8 @@ def sampling_from_probs( maybe_offset_arr, offset_val, ) + if return_valid: + return samples, valid return samples # torch library for sampling_from_probs @@ -178,9 +183,15 @@ def _fake_sampling_from_probs( indices: Optional[torch.Tensor], deterministic: bool, generator: Optional[torch.Generator], - ) -> torch.Tensor: + return_valid: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: batch_size = indices.size(0) if indices is not None else probs.size(0) out_dtype = indices.dtype if indices is not None else torch.int32 + if return_valid: + return ( + torch.empty(batch_size, dtype=out_dtype, device=probs.device), + torch.empty(batch_size, dtype=torch.bool, device=probs.device), + ) return torch.empty(batch_size, dtype=out_dtype, device=probs.device) # torch library for top_p_sampling_from_probs @@ -195,7 +206,8 @@ def 
top_p_sampling_from_probs( generator: Optional[torch.Generator], seed: Optional[Union[int, torch.Tensor]] = None, offset: Optional[Union[int, torch.Tensor]] = None, - ) -> torch.Tensor: + return_valid: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: device = probs.device probs = probs.float() maybe_top_p_arr = ( @@ -204,6 +216,7 @@ def top_p_sampling_from_probs( batch_size = indices.size(0) if indices is not None else probs.size(0) out_dtype = indices.dtype if indices is not None else torch.int32 samples = torch.empty(batch_size, dtype=out_dtype, device=device) + valid = torch.empty(batch_size, dtype=torch.bool, device=device) if seed is None or offset is None: seed, offset = get_seed_and_offset(batch_size * 32, generator, device) @@ -214,6 +227,7 @@ def top_p_sampling_from_probs( module.top_p_sampling_from_probs( probs, samples, + valid, indices, maybe_top_p_arr, top_p_val, @@ -223,6 +237,8 @@ def top_p_sampling_from_probs( maybe_offset_arr, offset_val, ) + if return_valid: + return samples, valid return samples @register_fake_op("flashinfer::top_p_sampling_from_probs") @@ -233,11 +249,16 @@ def _fake_top_p_sampling_from_probs( top_p_val: float, deterministic: bool, generator: Optional[torch.Generator], - ) -> torch.Tensor: + return_valid: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: batch_size = indices.size(0) if indices is not None else probs.size(0) out_dtype = indices.dtype if indices is not None else torch.int32 - sample = torch.empty(batch_size, dtype=out_dtype, device=probs.device) - return sample + if return_valid: + return ( + torch.empty(batch_size, dtype=out_dtype, device=probs.device), + torch.empty(batch_size, dtype=torch.bool, device=probs.device), + ) + return torch.empty(batch_size, dtype=out_dtype, device=probs.device) # torch library for top_k_sampling_from_probs @@ -251,13 +272,15 @@ def top_k_sampling_from_probs( generator: Optional[torch.Generator], seed: Optional[Union[int, 
torch.Tensor]] = None, offset: Optional[Union[int, torch.Tensor]] = None, - ) -> torch.Tensor: + return_valid: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: device = probs.device probs = probs.float() batch_size = indices.size(0) if indices is not None else probs.size(0) maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None out_dtype = indices.dtype if indices is not None else torch.int32 samples = torch.empty(batch_size, dtype=out_dtype, device=device) + valid = torch.empty(batch_size, dtype=torch.bool, device=device) if seed is None or offset is None: seed, offset = get_seed_and_offset(batch_size * 32, generator, device) @@ -268,6 +291,7 @@ def top_k_sampling_from_probs( module.top_k_sampling_from_probs( probs, samples, + valid, indices, maybe_top_k_arr, top_k_val, @@ -277,6 +301,8 @@ def top_k_sampling_from_probs( maybe_offset_arr, offset_val, ) + if return_valid: + return samples, valid return samples @register_fake_op("flashinfer::top_k_sampling_from_probs") @@ -287,11 +313,16 @@ def _fake_top_k_sampling_from_probs( top_k_val: int, deterministic: bool, generator: Optional[torch.Generator], - ) -> torch.Tensor: + return_valid: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: batch_size = indices.size(0) if indices is not None else probs.size(0) out_dtype = indices.dtype if indices is not None else torch.int32 - sample = torch.empty(batch_size, dtype=out_dtype, device=probs.device) - return sample + if return_valid: + return ( + torch.empty(batch_size, dtype=out_dtype, device=probs.device), + torch.empty(batch_size, dtype=torch.bool, device=probs.device), + ) + return torch.empty(batch_size, dtype=out_dtype, device=probs.device) # torch library for min_p_sampling_from_probs @@ -305,7 +336,8 @@ def min_p_sampling_from_probs( generator: Optional[torch.Generator], seed: Optional[Union[int, torch.Tensor]] = None, offset: Optional[Union[int, torch.Tensor]] = None, - ) -> 
torch.Tensor: + return_valid: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: device = probs.device probs = probs.float() maybe_min_p_arr = ( @@ -314,6 +346,7 @@ def min_p_sampling_from_probs( batch_size = indices.size(0) if indices is not None else probs.size(0) out_dtype = indices.dtype if indices is not None else torch.int32 samples = torch.empty(batch_size, dtype=out_dtype, device=device) + valid = torch.empty(batch_size, dtype=torch.bool, device=device) if seed is None or offset is None: seed, offset = get_seed_and_offset(batch_size, generator, device) @@ -324,6 +357,7 @@ def min_p_sampling_from_probs( module.min_p_sampling_from_probs( probs, samples, + valid, indices, maybe_min_p_arr, min_p_val, @@ -333,10 +367,30 @@ def min_p_sampling_from_probs( maybe_offset_arr, offset_val, ) + if return_valid: + return samples, valid return samples - # torch library for top_k_top_p_sampling_from_probs + @register_fake_op("flashinfer::min_p_sampling_from_probs") + def _fake_min_p_sampling_from_probs( + probs: torch.Tensor, + indices: Optional[torch.Tensor], + maybe_min_p_arr: Optional[torch.Tensor], + min_p_val: float, + deterministic: bool, + generator: Optional[torch.Generator], + return_valid: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + batch_size = indices.size(0) if indices is not None else probs.size(0) + out_dtype = indices.dtype if indices is not None else torch.int32 + if return_valid: + return ( + torch.empty(batch_size, dtype=out_dtype, device=probs.device), + torch.empty(batch_size, dtype=torch.bool, device=probs.device), + ) + return torch.empty(batch_size, dtype=out_dtype, device=probs.device) + # torch library for top_k_top_p_sampling_from_probs @register_custom_op("flashinfer::top_k_top_p_sampling_from_probs", mutates_args=()) def top_k_top_p_sampling_from_probs( probs: torch.Tensor, @@ -349,7 +403,8 @@ def top_k_top_p_sampling_from_probs( generator: Optional[torch.Generator], seed: 
Optional[Union[int, torch.Tensor]] = None, offset: Optional[Union[int, torch.Tensor]] = None, - ) -> torch.Tensor: + return_valid: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: device = probs.device probs = probs.float() maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None @@ -359,6 +414,7 @@ def top_k_top_p_sampling_from_probs( batch_size = indices.size(0) if indices is not None else probs.size(0) out_dtype = indices.dtype if indices is not None else torch.int32 samples = torch.empty(batch_size, dtype=out_dtype, device=device) + valid = torch.empty(batch_size, dtype=torch.bool, device=device) if seed is None or offset is None: seed, offset = get_seed_and_offset(batch_size * 32, generator, device) @@ -369,6 +425,7 @@ def top_k_top_p_sampling_from_probs( module.top_k_top_p_sampling_from_probs( probs, samples, + valid, indices, maybe_top_k_arr, top_k_val, @@ -380,6 +437,8 @@ def top_k_top_p_sampling_from_probs( maybe_offset_arr, offset_val, ) + if return_valid: + return samples, valid return samples @register_fake_op("flashinfer::top_k_top_p_sampling_from_probs") @@ -392,11 +451,16 @@ def _fake_top_k_top_p_sampling_from_probs( top_p_val: float, deterministic: bool, generator: Optional[torch.Generator], - ) -> torch.Tensor: + return_valid: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: batch_size = indices.size(0) if indices is not None else probs.size(0) out_dtype = indices.dtype if indices is not None else torch.int32 - sample = torch.empty(batch_size, dtype=out_dtype, device=probs.device) - return sample + if return_valid: + return ( + torch.empty(batch_size, dtype=out_dtype, device=probs.device), + torch.empty(batch_size, dtype=torch.bool, device=probs.device), + ) + return torch.empty(batch_size, dtype=out_dtype, device=probs.device) # torch library for top_p_renorm_probs @@ -796,6 +860,7 @@ def sampling_from_probs( check_nan: bool = False, seed: Optional[Union[int, 
torch.Tensor]] = None, offset: Optional[Union[int, torch.Tensor]] = None, + return_valid: bool = False, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: r"""Fused GPU kernel for category sampling from probabilities. @@ -869,7 +934,13 @@ if torch.any(torch.isnan(probs)): raise ValueError("Input probs contains NaN.") return get_sampling_module().sampling_from_probs( - probs, indices, deterministic, generator, seed, offset + probs, + indices, + deterministic, + generator, + seed, + offset, + return_valid, ) @@ -883,6 +954,7 @@ def top_p_sampling_from_probs( check_nan: bool = False, seed: Optional[Union[int, torch.Tensor]] = None, offset: Optional[Union[int, torch.Tensor]] = None, + return_valid: bool = False, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: r"""Fused GPU kernel for top-p sampling (nucleus sampling) from probabilities, this operator implements GPU-based rejection sampling without explicit sorting. @@ -980,6 +1052,7 @@ top_p_sampling_from_probs( generator, seed, offset, + return_valid, ) @@ -993,6 +1066,7 @@ def top_k_sampling_from_probs( check_nan: bool = False, seed: Optional[Union[int, torch.Tensor]] = None, offset: Optional[Union[int, torch.Tensor]] = None, + return_valid: bool = False, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: r"""Fused GPU kernel for top-k sampling from probabilities, this operator implements GPU-based rejection sampling without explicit sorting. 
@@ -1090,6 +1164,7 @@ def top_k_sampling_from_probs( generator, seed, offset, + return_valid, ) @@ -1103,6 +1178,7 @@ def min_p_sampling_from_probs( check_nan: bool = False, seed: Optional[Union[int, torch.Tensor]] = None, offset: Optional[Union[int, torch.Tensor]] = None, + return_valid: bool = False, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: r"""Fused GPU kernel for `min_p sampling `_ from probabilities, @@ -1196,6 +1272,7 @@ def min_p_sampling_from_probs( generator, seed, offset, + return_valid, ) @@ -1357,6 +1434,7 @@ def top_k_top_p_sampling_from_probs( check_nan: bool = False, seed: Optional[Union[int, torch.Tensor]] = None, offset: Optional[Union[int, torch.Tensor]] = None, + return_valid: bool = False, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: r"""Fused GPU kernel for top-k and top-p sampling from probabilities, @@ -1465,6 +1543,7 @@ def top_k_top_p_sampling_from_probs( generator=generator, seed=seed, offset=offset, + return_valid=return_valid, ) elif filter_apply_order == "joint": if check_nan: @@ -1479,6 +1558,7 @@ def top_k_top_p_sampling_from_probs( generator, seed, offset, + return_valid, ) else: raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}") diff --git a/include/flashinfer/sampling.cuh b/include/flashinfer/sampling.cuh index 7008208171..1d0711aa6f 100644 --- a/include/flashinfer/sampling.cuh +++ b/include/flashinfer/sampling.cuh @@ -621,7 +621,7 @@ __device__ __forceinline__ void DeviceSamplingFromProb( int max_valid_index = BlockReduce(temp_storage->block_prim.reduce_int) .Reduce(valid_index, MaxReduceOp{}); - if (tx == 0 && max_valid_index != -1) { + if (tx == 0 && max_valid_index != -1 && max_valid_index < (int)d) { temp_storage->last_valid_id = max_valid_index; } __syncthreads(); @@ -753,9 +753,9 @@ __global__ void SamplingFromLogitsKernel(DType* logits, IdType* output, IdType* template -__global__ void SamplingFromProbKernel(DType* probs, IdType* output, IdType* indices, uint32_t d, - uint64_t* seed_arr, uint64_t seed_val, uint64_t* offset_arr, - uint64_t offset_val) { 
+__global__ void SamplingFromProbKernel(DType* probs, IdType* output, bool* valid, IdType* indices, + uint32_t d, uint64_t* seed_arr, uint64_t seed_val, + uint64_t* offset_arr, uint64_t offset_val) { curandStatePhilox4_32_10_t state; const uint32_t bx = blockIdx.x, tx = threadIdx.x; @@ -773,6 +773,7 @@ __global__ void SamplingFromProbKernel(DType* probs, IdType* output, IdType* ind reinterpret_cast&>( smem_sampling); temp_storage.sampled_id = d; + temp_storage.last_valid_id = -1; __syncthreads(); vec_t probs_vec; @@ -798,17 +799,27 @@ __global__ void SamplingFromProbKernel(DType* probs, IdType* output, IdType* ind // NOTE(Zihao): this would happen when u is very close to 1 // and the sum of probabilities is smaller than u // In this case, we use the last valid index as the sampled id + if (temp_storage.last_valid_id == -1) { + if (tx == 0) { + output[bx] = 0; + valid[bx] = false; + } + return; + } sampled_id = temp_storage.last_valid_id; } - output[bx] = sampled_id; + if (tx == 0) { + output[bx] = sampled_id; + valid[bx] = true; + } } template -__global__ void TopKSamplingFromProbKernel(DType* probs, IdType* output, IdType* indices, - IdType* top_k_arr, uint32_t top_k_val, uint32_t d, - uint64_t* seed_arr, uint64_t seed_val, +__global__ void TopKSamplingFromProbKernel(DType* probs, IdType* output, bool* valid, + IdType* indices, IdType* top_k_arr, uint32_t top_k_val, + uint32_t d, uint64_t* seed_arr, uint64_t seed_val, uint64_t* offset_arr, uint64_t offset_val) { const uint32_t batch_size = gridDim.x; const uint32_t bx = blockIdx.x, tx = threadIdx.x; @@ -838,6 +849,7 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, IdType* output, IdType* do { round += 1; temp_storage.sampled_id = d; + temp_storage.last_valid_id = -1; __syncthreads(); float u = curand_uniform(&state) * q; aggregate = 0; @@ -861,6 +873,13 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, IdType* output, IdType* // NOTE(Zihao): this would happen when u is very close to 1 // and 
the sum of probabilities is smaller than u // In this case, we use the last valid index as the sampled id + if (temp_storage.last_valid_id == -1) { + if (tx == 0) { + output[bx] = 0; + valid[bx] = false; + } + return; + } sampled_id = temp_storage.last_valid_id; } double pivot_0 = probs[row_idx * d + sampled_id]; @@ -923,15 +942,16 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, IdType* output, IdType* __syncthreads(); if (tx == 0) { output[bx] = sampled_id; + valid[bx] = true; } } template -__global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output, IdType* indices, - float* top_p_arr, float top_p_val, uint32_t d, - uint64_t* seed_arr, uint64_t seed_val, +__global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output, bool* valid, + IdType* indices, float* top_p_arr, float top_p_val, + uint32_t d, uint64_t* seed_arr, uint64_t seed_val, uint64_t* offset_arr, uint64_t offset_val) { const uint32_t batch_size = gridDim.x; const uint32_t bx = blockIdx.x, tx = threadIdx.x; @@ -959,6 +979,7 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output, IdType* int sampled_id; do { temp_storage.sampled_id = d; + temp_storage.last_valid_id = -1; __syncthreads(); float u = curand_uniform(&state) * q; aggregate = 0; @@ -982,6 +1003,13 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output, IdType* // NOTE(Zihao): this would happen when u is very close to 1 // and the sum of probabilities is smaller than u // In this case, we use the last valid index as the sampled id + if (temp_storage.last_valid_id == -1) { + if (tx == 0) { + output[bx] = 0; + valid[bx] = false; + } + return; + } sampled_id = temp_storage.last_valid_id; } double pivot_0 = probs[row_idx * d + sampled_id]; @@ -1040,6 +1068,7 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output, IdType* __syncthreads(); if (tx == 0) { output[bx] = sampled_id; + valid[bx] = true; } } @@ -1047,8 +1076,8 @@ template __global__ void 
MinPSamplingFromProbKernel(DType* probs, float* min_p_arr, IdType* output, - IdType* indices, float min_p_val, uint32_t d, - uint64_t* seed_arr, uint64_t seed_val, + bool* valid, IdType* indices, float min_p_val, + uint32_t d, uint64_t* seed_arr, uint64_t seed_val, uint64_t* offset_arr, uint64_t offset_val) { const uint32_t bx = blockIdx.x, tx = threadIdx.x; @@ -1103,6 +1132,7 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, float* min_p_arr, IdTyp int sampled_id; temp_storage.sampled_id = d; + temp_storage.last_valid_id = -1; __syncthreads(); float u = curand_uniform(&state) * q; #pragma unroll 2 @@ -1124,19 +1154,27 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, float* min_p_arr, IdTyp // NOTE(Zihao): this would happen when u is very close to 1 // and the sum of probabilities is smaller than u // In this case, we use the last valid index as the sampled id + if (temp_storage.last_valid_id == -1) { + if (tx == 0) { + output[bx] = 0; + valid[bx] = false; + } + return; + } sampled_id = temp_storage.last_valid_id; } output[bx] = sampled_id; + valid[bx] = true; } template __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* top_k_arr, float* top_p_arr, - IdType* output, IdType* indices, IdType top_k_val, - float top_p_val, uint32_t d, uint64_t* seed_arr, - uint64_t seed_val, uint64_t* offset_arr, - uint64_t offset_val) { + IdType* output, bool* valid, IdType* indices, + IdType top_k_val, float top_p_val, uint32_t d, + uint64_t* seed_arr, uint64_t seed_val, + uint64_t* offset_arr, uint64_t offset_val) { const uint32_t batch_size = gridDim.x; const uint32_t bx = blockIdx.x, tx = threadIdx.x; @@ -1164,6 +1202,7 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* top_k_arr, int sampled_id; do { temp_storage.sampled_id = d; + temp_storage.last_valid_id = -1; __syncthreads(); float u = curand_uniform(&state) * q; aggregate = 0; @@ -1188,6 +1227,13 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, 
IdType* top_k_arr, // and the sum of probabilities is smaller than u // In this case, we use the last valid index as the sampled id + if (temp_storage.last_valid_id == -1) { + if (tx == 0) { + output[bx] = 0; + valid[bx] = false; + } + return; + } sampled_id = temp_storage.last_valid_id; } double pivot_0 = probs[row_idx * d + sampled_id]; double pivot_1 = (pivot_0 + high) / 2; @@ -1250,6 +1296,7 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* top_k_arr, __syncthreads(); if (tx == 0) { output[bx] = sampled_id; + valid[bx] = true; } } @@ -1423,16 +1470,18 @@ cudaError_t SamplingFromLogits(T* logits, IdType* output, IdType* indices, uint3 } template -cudaError_t SamplingFromProb(T* probs, IdType* output, IdType* indices, uint32_t batch_size, - uint32_t d, bool deterministic, uint64_t* seed_arr, uint64_t seed_val, - uint64_t* offset_arr, uint64_t offset_val, cudaStream_t stream = 0) { +cudaError_t SamplingFromProb(T* probs, IdType* output, bool* valid, IdType* indices, + uint32_t batch_size, uint32_t d, bool deterministic, + uint64_t* seed_arr, uint64_t seed_val, uint64_t* offset_arr, + uint64_t offset_val, cudaStream_t stream = 0) { const uint32_t vec_size = std::gcd(16 / sizeof(T), d); auto compute_capacity = GetCudaComputeCapability(); DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, { dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &output, &indices, &d, &seed_arr, &seed_val, &offset_arr, &offset_val}; + void* args[] = {&probs, &output, &valid, &indices, &d, + &seed_arr, &seed_val, &offset_arr, &offset_val}; const uint32_t smem_size = sizeof(SamplingTempStorage); DISPATCH_ALIGNED_VEC_SIZE( @@ -1447,8 +1496,8 @@ cudaError_t SamplingFromProb(T* probs, IdType* output, IdType* indices, uint32_t } template -cudaError_t TopKSamplingFromProb(T* probs, IdType* output, IdType* indices, T* top_k_arr, - uint32_t batch_size, uint32_t top_k_val, uint32_t d, +cudaError_t TopKSamplingFromProb(T* probs, IdType* 
output, bool* valid, IdType* indices, + T* top_k_arr, uint32_t batch_size, uint32_t top_k_val, uint32_t d, bool deterministic, uint64_t* seed_arr, uint64_t seed_val, uint64_t* offset_arr, uint64_t offset_val, cudaStream_t stream = 0) { @@ -1459,7 +1508,7 @@ cudaError_t TopKSamplingFromProb(T* probs, IdType* output, IdType* indices, T* t const uint32_t smem_size = sizeof(SamplingTempStorage); dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &output, &indices, &top_k_arr, &top_k_val, + void* args[] = {&probs, &output, &valid, &indices, &top_k_arr, &top_k_val, &d, &seed_arr, &seed_val, &offset_arr, &offset_val}; DISPATCH_ALIGNED_VEC_SIZE( @@ -1476,10 +1525,11 @@ cudaError_t TopKSamplingFromProb(T* probs, IdType* output, IdType* indices, T* t } template -cudaError_t TopPSamplingFromProb(T* probs, IdType* output, IdType* indices, T* top_p_arr, - uint32_t batch_size, T top_p_val, uint32_t d, bool deterministic, - uint64_t* seed_arr, uint64_t seed_val, uint64_t* offset_arr, - uint64_t offset_val, cudaStream_t stream = 0) { +cudaError_t TopPSamplingFromProb(T* probs, IdType* output, bool* valid, IdType* indices, + T* top_p_arr, uint32_t batch_size, T top_p_val, uint32_t d, + bool deterministic, uint64_t* seed_arr, uint64_t seed_val, + uint64_t* offset_arr, uint64_t offset_val, + cudaStream_t stream = 0) { const uint32_t vec_size = std::gcd(16 / sizeof(T), d); auto compute_capacity = GetCudaComputeCapability(); @@ -1487,7 +1537,7 @@ cudaError_t TopPSamplingFromProb(T* probs, IdType* output, IdType* indices, T* t const uint32_t smem_size = sizeof(SamplingTempStorage); dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &output, &indices, &top_p_arr, &top_p_val, + void* args[] = {&probs, &output, &valid, &indices, &top_p_arr, &top_p_val, &d, &seed_arr, &seed_val, &offset_arr, &offset_val}; DISPATCH_ALIGNED_VEC_SIZE( @@ -1504,8 +1554,8 @@ cudaError_t TopPSamplingFromProb(T* probs, IdType* output, IdType* indices, T* t } 
template -cudaError_t MinPSamplingFromProb(T* probs, T* min_p_arr, IdType* output, IdType* indices, - uint32_t batch_size, float min_p_val, uint32_t d, +cudaError_t MinPSamplingFromProb(T* probs, T* min_p_arr, IdType* output, bool* valid, + IdType* indices, uint32_t batch_size, float min_p_val, uint32_t d, bool deterministic, uint64_t* seed_arr, uint64_t seed_val, uint64_t* offset_arr, uint64_t offset_val, cudaStream_t stream = 0) { @@ -1516,7 +1566,7 @@ cudaError_t MinPSamplingFromProb(T* probs, T* min_p_arr, IdType* output, IdType* const uint32_t smem_size = sizeof(SamplingTempStorage); dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &min_p_arr, &output, &indices, &min_p_val, + void* args[] = {&probs, &min_p_arr, &output, &valid, &indices, &min_p_val, &d, &seed_arr, &seed_val, &offset_arr, &offset_val}; DISPATCH_ALIGNED_VEC_SIZE( @@ -1534,8 +1584,8 @@ cudaError_t MinPSamplingFromProb(T* probs, T* min_p_arr, IdType* output, IdType* template cudaError_t TopKTopPSamplingFromProb(T* probs, IdType* top_k_arr, T* top_p_arr, IdType* output, - IdType* indices, uint32_t batch_size, IdType top_k_val, - T top_p_val, uint32_t d, bool deterministic, + bool* valid, IdType* indices, uint32_t batch_size, + IdType top_k_val, T top_p_val, uint32_t d, bool deterministic, uint64_t* seed_arr, uint64_t seed_val, uint64_t* offset_arr, uint64_t offset_val, cudaStream_t stream = 0) { const uint32_t vec_size = std::gcd(16 / sizeof(T), d); @@ -1545,8 +1595,9 @@ cudaError_t TopKTopPSamplingFromProb(T* probs, IdType* top_k_arr, T* top_p_arr, const uint32_t smem_size = sizeof(SamplingTempStorage); dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &top_k_arr, &top_p_arr, &output, &indices, &top_k_val, - &top_p_val, &d, &seed_arr, &seed_val, &offset_arr, &offset_val}; + void* args[] = {&probs, &top_k_arr, &top_p_arr, &output, &valid, + &indices, &top_k_val, &top_p_val, &d, &seed_arr, + &seed_val, &offset_arr, &offset_val}; 
DISPATCH_ALIGNED_VEC_SIZE( vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { diff --git a/tests/utils/test_sampling.py b/tests/utils/test_sampling.py index fb1c015c7e..d033739b1b 100644 --- a/tests/utils/test_sampling.py +++ b/tests/utils/test_sampling.py @@ -1110,6 +1110,63 @@ def test_sampling_with_default_device_cuda(batch_size, vocab_size): torch.set_default_device(original_device) +@pytest.mark.parametrize("batch_size", [1, 4, 19]) +@pytest.mark.parametrize("vocab_size", [111, 32000]) +def test_sampling_nan_input(batch_size, vocab_size): + torch.manual_seed(42) + probs = torch.rand(batch_size, vocab_size, device="cuda:0", dtype=torch.float32) + probs = probs / probs.sum(dim=-1, keepdim=True) + + # Set NaN at different positions: first, middle, last + nan_indices = [0] + if batch_size > 1: + nan_indices.append(batch_size // 2) + if batch_size > 2: + nan_indices.append(batch_size - 1) + + for idx in nan_indices: + probs[idx, :] = float("nan") + + valid_indices = [i for i in range(batch_size) if i not in nan_indices] + + def check_result(result, valid): + # NaN rows should return 0 and valid=False + for idx in nan_indices: + assert result[idx].item() == 0 and not valid[idx].item() + # Non-NaN rows should have valid=True and valid token index + for idx in valid_indices: + assert valid[idx].item() + assert 0 <= result[idx].item() < vocab_size + + # sampling_from_probs + result, valid = flashinfer.sampling.sampling_from_probs(probs, return_valid=True) + check_result(result, valid) + + # top_k_sampling_from_probs + result, valid = flashinfer.sampling.top_k_sampling_from_probs( + probs, top_k=50, return_valid=True + ) + check_result(result, valid) + + # top_p_sampling_from_probs + result, valid = flashinfer.sampling.top_p_sampling_from_probs( + probs, top_p=0.9, return_valid=True + ) + check_result(result, valid) + + # min_p_sampling_from_probs + result, valid = flashinfer.sampling.min_p_sampling_from_probs( + probs, min_p=0.1, 
return_valid=True + ) + check_result(result, valid) + + # top_k_top_p_sampling_from_probs (joint mode) + result, valid = flashinfer.sampling.top_k_top_p_sampling_from_probs( + probs, top_k=50, top_p=0.9, filter_apply_order="joint", return_valid=True + ) + check_result(result, valid) + + if __name__ == "__main__": # test_sampling_freq(128256, gumbel_distribution(0.1), 0.5) test_sampling_from_logits_freq(128256, gumbel_distribution(0.1))