diff --git a/csrc/sampling.cu b/csrc/sampling.cu index 6f59507991..63c5692faa 100644 --- a/csrc/sampling.cu +++ b/csrc/sampling.cu @@ -70,7 +70,9 @@ void sampling_from_logits(TensorView logits, TensorView output, Optional maybe_indices, - bool deterministic, uint64_t philox_seed, uint64_t philox_offset) { + bool deterministic, uint64_t philox_seed, uint64_t philox_offset, + Optional maybe_seed_arr, + Optional maybe_offset_arr) { CHECK_INPUT(probs); CHECK_DIM(2, probs); // probs: (batch_size, vocab_size) CHECK_MAYBE_INPUT_TYPES(maybe_indices, dl_int32, dl_int64); @@ -86,7 +88,12 @@ void sampling_from_probs(TensorView probs, TensorView output, Optional(probs.data_ptr()), static_cast(output.data_ptr()), maybe_indices.has_value() ? static_cast(maybe_indices.value().data_ptr()) : nullptr, - batch_size, vocab_size, deterministic, philox_seed, philox_offset, stream); + batch_size, vocab_size, deterministic, philox_seed, philox_offset, + maybe_seed_arr.has_value() ? static_cast(maybe_seed_arr.value().data_ptr()) + : nullptr, + maybe_offset_arr.has_value() ? static_cast(maybe_offset_arr.value().data_ptr()) + : nullptr, + stream); TVM_FFI_ICHECK(status == cudaSuccess) << "SamplingFromProbs failed with error code " << cudaGetErrorString(status); return true; @@ -96,7 +103,9 @@ void sampling_from_probs(TensorView probs, TensorView output, Optional maybe_indices, Optional maybe_top_p_arr, double top_p_val, - bool deterministic, uint64_t philox_seed, uint64_t philox_offset) { + bool deterministic, uint64_t philox_seed, uint64_t philox_offset, + Optional maybe_seed_arr, + Optional maybe_offset_arr) { CHECK_INPUT(probs); CHECK_DIM(2, probs); // probs: (batch_size, vocab_size) CHECK_MAYBE_INPUT_TYPES(maybe_indices, dl_int32, dl_int64); @@ -115,7 +124,12 @@ void top_p_sampling_from_probs(TensorView probs, TensorView output, maybe_indices.has_value() ? static_cast(maybe_indices.value().data_ptr()) : nullptr, has_top_p_arr ? static_cast(maybe_top_p_arr.value().data_ptr()) : nullptr, - batch_size, top_p_val, vocab_size, deterministic, philox_seed, philox_offset, stream); + batch_size, top_p_val, vocab_size, deterministic, philox_seed, philox_offset, + maybe_seed_arr.has_value() ? static_cast(maybe_seed_arr.value().data_ptr()) + : nullptr, + maybe_offset_arr.has_value() ? static_cast(maybe_offset_arr.value().data_ptr()) + : nullptr, + stream); TVM_FFI_ICHECK(status == cudaSuccess) << "TopPSamplingFromProbs failed with error code " << cudaGetErrorString(status); return true; @@ -125,7 +139,9 @@ void top_p_sampling_from_probs(TensorView probs, TensorView output, void top_k_sampling_from_probs(TensorView probs, TensorView output, Optional maybe_indices, Optional maybe_top_k_arr, int64_t top_k_val, - bool deterministic, uint64_t philox_seed, uint64_t philox_offset) { + bool deterministic, uint64_t philox_seed, uint64_t philox_offset, + Optional maybe_seed_arr, + Optional maybe_offset_arr) { CHECK_INPUT(probs); CHECK_INPUT(output); CHECK_DEVICE(output, probs); @@ -147,7 +163,12 @@ void top_k_sampling_from_probs(TensorView probs, TensorView output, maybe_indices.has_value() ? static_cast(maybe_indices.value().data_ptr()) : nullptr, has_top_k_arr ? static_cast(maybe_top_k_arr.value().data_ptr()) : nullptr, - batch_size, top_k_val, vocab_size, deterministic, philox_seed, philox_offset, stream); + batch_size, top_k_val, vocab_size, deterministic, philox_seed, philox_offset, + maybe_seed_arr.has_value() ? static_cast(maybe_seed_arr.value().data_ptr()) + : nullptr, + maybe_offset_arr.has_value() ? static_cast(maybe_offset_arr.value().data_ptr()) + : nullptr, + stream); TVM_FFI_ICHECK(status == cudaSuccess) << "TopKSamplingFromProbs failed with error code " << cudaGetErrorString(status); return true; @@ -157,7 +178,9 @@ void top_k_sampling_from_probs(TensorView probs, TensorView output, void min_p_sampling_from_probs(TensorView probs, TensorView output, Optional maybe_indices, Optional maybe_min_p_arr, double min_p_val, - bool deterministic, uint64_t philox_seed, uint64_t philox_offset) { + bool deterministic, uint64_t philox_seed, uint64_t philox_offset, + Optional maybe_seed_arr, + Optional maybe_offset_arr) { CHECK_INPUT(probs); CHECK_INPUT(output); CHECK_DEVICE(output, probs); @@ -180,7 +203,12 @@ void min_p_sampling_from_probs(TensorView probs, TensorView output, static_cast(output.data_ptr()), maybe_indices.has_value() ? static_cast(maybe_indices.value().data_ptr()) : nullptr, - batch_size, min_p_val, vocab_size, deterministic, philox_seed, philox_offset, stream); + batch_size, min_p_val, vocab_size, deterministic, philox_seed, philox_offset, + maybe_seed_arr.has_value() ? static_cast(maybe_seed_arr.value().data_ptr()) + : nullptr, + maybe_offset_arr.has_value() ? static_cast(maybe_offset_arr.value().data_ptr()) + : nullptr, + stream); TVM_FFI_ICHECK(status == cudaSuccess) << "MinPSamplingFromProb failed with error code " << cudaGetErrorString(status); return true; @@ -192,7 +220,8 @@ void top_k_top_p_sampling_from_probs(TensorView probs, TensorView output, Optional maybe_top_k_arr, double top_k_val, Optional maybe_top_p_arr, double top_p_val, bool deterministic, uint64_t philox_seed, - uint64_t philox_offset) { + uint64_t philox_offset, Optional maybe_seed_arr, + Optional maybe_offset_arr) { CHECK_INPUT(probs); CHECK_INPUT(output); CHECK_DEVICE(output, probs); @@ -219,6 +248,10 @@ void top_k_top_p_sampling_from_probs(TensorView probs, TensorView output, maybe_indices.has_value() ? static_cast(maybe_indices.value().data_ptr()) : nullptr, batch_size, top_k_val, top_p_val, vocab_size, deterministic, philox_seed, philox_offset, + maybe_seed_arr.has_value() ? static_cast(maybe_seed_arr.value().data_ptr()) + : nullptr, + maybe_offset_arr.has_value() ? static_cast(maybe_offset_arr.value().data_ptr()) + : nullptr, stream); TVM_FFI_ICHECK(status == cudaSuccess) << "TopKTopPSamplingFromProbs failed with error code " << cudaGetErrorString(status); diff --git a/flashinfer/sampling.py b/flashinfer/sampling.py index 8571016a15..e49ab611be 100644 --- a/flashinfer/sampling.py +++ b/flashinfer/sampling.py @@ -49,6 +49,47 @@ def get_seed_and_offset( return int(seed), int(offset) +def _validate_per_request_generator( + seed_arr: torch.Tensor, + offset_arr: torch.Tensor, + batch_size: int, +) -> None: + """Validate per-request generator tensors. + + Parameters + ---------- + seed_arr : torch.Tensor + Seed array tensor + offset_arr : torch.Tensor + Offset array tensor + batch_size : int + Expected batch size + + Raises + ------ + TypeError + If tensors are not int64 + ValueError + If tensors are not on CUDA device or have incorrect shape + """ + if seed_arr.dtype != torch.int64: + raise TypeError(f"seed_arr must be int64 tensor, got {seed_arr.dtype}") + if offset_arr.dtype != torch.int64: + raise TypeError(f"offset_arr must be int64 tensor, got {offset_arr.dtype}") + if not seed_arr.is_cuda: + raise ValueError(f"seed_arr must be on CUDA device, got {seed_arr.device}") + if not offset_arr.is_cuda: + raise ValueError(f"offset_arr must be on CUDA device, got {offset_arr.device}") + if seed_arr.shape != (batch_size,): + raise ValueError( + f"seed_arr must have shape ({batch_size},), got {seed_arr.shape}" + ) + if offset_arr.shape != (batch_size,): + raise ValueError( + f"offset_arr must have shape ({batch_size},), got {offset_arr.shape}" + ) + + @functools.cache def get_sampling_module(): module = gen_sampling_module().build_and_load() @@ -144,16 +185,35 @@ def sampling_from_probs( batch_size = indices.size(0) if indices is not None else probs.size(0) out_dtype = indices.dtype if indices is not None else torch.int32 samples = torch.empty(batch_size, dtype=out_dtype, device=device) - if seed is None or offset is None: - seed, offset = get_seed_and_offset(batch_size, generator, device) - module.sampling_from_probs( - probs, - samples, - indices, - deterministic, - seed, - offset, - ) + + # Check if generator is a tuple of tensors (per-request generators) + if isinstance(generator, tuple): + seed_arr, offset_arr = generator + _validate_per_request_generator(seed_arr, offset_arr, batch_size) + module.sampling_from_probs( + probs, + samples, + indices, + deterministic, + 0, + 0, # scalar seed/offset (ignored when arrays provided) + seed_arr, + offset_arr, + ) + else: + # Traditional single generator path + if seed is None or offset is None: + seed, offset = get_seed_and_offset(batch_size, generator, device) + module.sampling_from_probs( + probs, + samples, + indices, + deterministic, + seed, + offset, + None, + None, # no per-request generators + ) return samples # torch library for sampling_from_probs @@ -190,18 +250,39 @@ def top_p_sampling_from_probs( batch_size = indices.size(0) if indices is not None else probs.size(0) out_dtype = indices.dtype if indices is not None else torch.int32 samples = torch.empty(batch_size, dtype=out_dtype, device=device) - if seed is None or offset is None: - seed, offset = get_seed_and_offset(batch_size * 32, generator, device) - module.top_p_sampling_from_probs( - probs, - samples, - indices, - maybe_top_p_arr, - top_p_val, - deterministic, - seed, - offset, - ) + + # Check if generator is a tuple of tensors (per-request generators) + if isinstance(generator, tuple): + seed_arr, offset_arr = generator + _validate_per_request_generator(seed_arr, offset_arr, batch_size) + module.top_p_sampling_from_probs( + probs, + samples, + indices, + maybe_top_p_arr, + top_p_val, + deterministic, + 0, + 0, # scalar seed/offset (ignored when arrays provided) + seed_arr, + offset_arr, + ) + else: + # Traditional single generator path + if seed is None or offset is None: + seed, offset = get_seed_and_offset(batch_size * 32, generator, device) + module.top_p_sampling_from_probs( + probs, + samples, + indices, + maybe_top_p_arr, + top_p_val, + deterministic, + seed, + offset, + None, + None, # no per-request generators + ) return samples @register_fake_op("flashinfer::top_p_sampling_from_probs") @@ -237,18 +318,39 @@ def top_k_sampling_from_probs( maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None out_dtype = indices.dtype if indices is not None else torch.int32 samples = torch.empty(batch_size, dtype=out_dtype, device=device) - if seed is None or offset is None: - seed, offset = get_seed_and_offset(batch_size * 32, generator, device) - module.top_k_sampling_from_probs( - probs, - samples, - indices, - maybe_top_k_arr, - top_k_val, - deterministic, - seed, - offset, - ) + + # Check if generator is a tuple of tensors (per-request generators) + if isinstance(generator, tuple): + seed_arr, offset_arr = generator + _validate_per_request_generator(seed_arr, offset_arr, batch_size) + module.top_k_sampling_from_probs( + probs, + samples, + indices, + maybe_top_k_arr, + top_k_val, + deterministic, + 0, + 0, # scalar seed/offset (ignored when arrays provided) + seed_arr, + offset_arr, + ) + else: + # Traditional single generator path + if seed is None or offset is None: + seed, offset = get_seed_and_offset(batch_size * 32, generator, device) + module.top_k_sampling_from_probs( + probs, + samples, + indices, + maybe_top_k_arr, + top_k_val, + deterministic, + seed, + offset, + None, + None, # no per-request generators + ) return samples @register_fake_op("flashinfer::top_k_sampling_from_probs") @@ -286,18 +388,39 @@ def min_p_sampling_from_probs( batch_size = indices.size(0) if indices is not None else probs.size(0) out_dtype = indices.dtype if indices is not None else torch.int32 samples = torch.empty(batch_size, dtype=out_dtype, device=device) - if seed is None or offset is None: - seed, offset = get_seed_and_offset(batch_size, generator, device) - module.min_p_sampling_from_probs( - probs, - samples, - indices, - maybe_min_p_arr, - min_p_val, - deterministic, - seed, - offset, - ) + + # Check if generator is a tuple of tensors (per-request generators) + if isinstance(generator, tuple): + seed_arr, offset_arr = generator + _validate_per_request_generator(seed_arr, offset_arr, batch_size) + module.min_p_sampling_from_probs( + probs, + samples, + indices, + maybe_min_p_arr, + min_p_val, + deterministic, + 0, + 0, # scalar seed/offset (ignored when arrays provided) + seed_arr, + offset_arr, + ) + else: + # Traditional single generator path + if seed is None or offset is None: + seed, offset = get_seed_and_offset(batch_size, generator, device) + module.min_p_sampling_from_probs( + probs, + samples, + indices, + maybe_min_p_arr, + min_p_val, + deterministic, + seed, + offset, + None, + None, # no per-request generators + ) return samples # torch library for top_k_top_p_sampling_from_probs @@ -324,20 +447,43 @@ def top_k_top_p_sampling_from_probs( batch_size = indices.size(0) if indices is not None else probs.size(0) out_dtype = indices.dtype if indices is not None else torch.int32 samples = torch.empty(batch_size, dtype=out_dtype, device=device) - if seed is None or offset is None: - seed, offset = get_seed_and_offset(batch_size * 32, generator, device) - module.top_k_top_p_sampling_from_probs( - probs, - samples, - indices, - maybe_top_k_arr, - top_k_val, - maybe_top_p_arr, - top_p_val, - deterministic, - seed, - offset, - ) + + # Check if generator is a tuple of tensors (per-request generators) + if isinstance(generator, tuple): + seed_arr, offset_arr = generator + _validate_per_request_generator(seed_arr, offset_arr, batch_size) + module.top_k_top_p_sampling_from_probs( + probs, + samples, + indices, + maybe_top_k_arr, + top_k_val, + maybe_top_p_arr, + top_p_val, + deterministic, + 0, + 0, # scalar seed/offset (ignored when arrays provided) + seed_arr, + offset_arr, + ) + else: + # Traditional single generator path + if seed is None or offset is None: + seed, offset = get_seed_and_offset(batch_size * 32, generator, device) + module.top_k_top_p_sampling_from_probs( + probs, + samples, + indices, + maybe_top_k_arr, + top_k_val, + maybe_top_p_arr, + top_p_val, + deterministic, + seed, + offset, + None, + None, # no per-request generators + ) return samples @register_fake_op("flashinfer::top_k_top_p_sampling_from_probs") @@ -668,7 +814,9 @@ def sampling_from_probs( probs: torch.Tensor, indices: Optional[torch.Tensor] = None, deterministic: bool = True, - generator: Optional[torch.Generator] = None, + generator: Optional[ + Union[torch.Generator, Tuple[torch.Tensor, torch.Tensor]] + ] = None, check_nan: bool = False, seed: Optional[int] = None, offset: Optional[int] = None, @@ -691,8 +839,38 @@ def sampling_from_probs( and output dtype defaults to ``torch.int32``. deterministic: bool Whether to use deterministic kernel implementation, default is ``True``. - generator: Optional[torch.Generator] - A random number generator for the operation. + generator: Optional[Union[torch.Generator, Tuple[torch.Tensor, torch.Tensor]]] + Random number generator. Can be either: + + - A ``torch.Generator`` for traditional single-generator sampling (default) + - A tuple of ``(seed_arr, offset_arr)`` tensors for per-request generators, + where both are int64 tensors of shape ``(batch_size,)`` on CUDA. + + **Per-request generator behavior:** + + - Each request uses its own seed from ``seed_arr[i]`` + - Offsets track RNG state and are automatically updated in-place after sampling + - Each RNG call consumes 4 values, so offsets increment by 4 per call + - For iterative samplers (top_p, top_k, top_k_top_p), offsets increment by ``4 * num_rounds`` + where ``num_rounds`` varies based on the sampling algorithm + - Sequential calls with the same generator tuple will use updated offsets (cumulative) + + **Example with per-request generators:** + + .. code-block:: python + + # Create per-request generators + seed_arr = torch.randint(0, 2**32, (batch_size,), dtype=torch.int64, device="cuda") + offset_arr = torch.zeros(batch_size, dtype=torch.int64, device="cuda") + + # First sampling call + samples1 = sampling_from_probs(probs1, generator=(seed_arr, offset_arr)) + # offset_arr is now [4, 4, 4, ...] (automatically updated) + + # Second sampling call reuses same generators with updated offsets + samples2 = sampling_from_probs(probs2, generator=(seed_arr, offset_arr)) + # offset_arr is now [8, 8, 8, ...] (cumulative) + check_nan: bool Whether to check nan in :attr:`probs`, default is ``False``. seed: Optional[int] @@ -742,7 +920,9 @@ def top_p_sampling_from_probs( top_p: Union[torch.Tensor, float], indices: Optional[torch.Tensor] = None, deterministic: bool = True, - generator: Optional[torch.Generator] = None, + generator: Optional[ + Union[torch.Generator, Tuple[torch.Tensor, torch.Tensor]] + ] = None, check_nan: bool = False, seed: Optional[int] = None, offset: Optional[int] = None, @@ -774,8 +954,12 @@ def top_p_sampling_from_probs( and output dtype defaults to ``torch.int32``. deterministic: bool Whether to use deterministic kernel implementation, default is ``True``. - generator: Optional[torch.Generator] - A random number generator for the operation. + generator: Optional[Union[torch.Generator, Tuple[torch.Tensor, torch.Tensor]]] + Random number generator. Can be either: + - A ``torch.Generator`` for traditional single-generator sampling (default) + - A tuple of ``(seed_arr, offset_arr)`` tensors for per-request generators, + where both are int64 tensors of shape ``(batch_size,)`` on CUDA. + Offsets are automatically updated in-place after sampling. check_nan: bool Whether to check nan in :attr:`probs`, default is ``False``. seed: Optional[int] @@ -839,7 +1023,9 @@ def top_k_sampling_from_probs( top_k: Union[torch.Tensor, int], indices: Optional[torch.Tensor] = None, deterministic: bool = True, - generator: Optional[torch.Generator] = None, + generator: Optional[ + Union[torch.Generator, Tuple[torch.Tensor, torch.Tensor]] + ] = None, check_nan: bool = False, seed: Optional[int] = None, offset: Optional[int] = None, @@ -871,8 +1057,12 @@ def top_k_sampling_from_probs( and output dtype defaults to ``torch.int32``. deterministic: bool Whether to use deterministic kernel implementation, default is ``True``. - generator: Optional[torch.Generator] - A random number generator for the operation. + generator: Optional[Union[torch.Generator, Tuple[torch.Tensor, torch.Tensor]]] + Random number generator. Can be either: + - A ``torch.Generator`` for traditional single-generator sampling (default) + - A tuple of ``(seed_arr, offset_arr)`` tensors for per-request generators, + where both are int64 tensors of shape ``(batch_size,)`` on CUDA. + Offsets are automatically updated in-place after sampling. check_nan: bool Whether to check nan in :attr:`probs`, default is ``False``. seed: Optional[int] @@ -936,7 +1126,9 @@ def min_p_sampling_from_probs( min_p: Union[torch.Tensor, float], indices: Optional[torch.Tensor] = None, deterministic: bool = True, - generator: Optional[torch.Generator] = None, + generator: Optional[ + Union[torch.Generator, Tuple[torch.Tensor, torch.Tensor]] + ] = None, check_nan: bool = False, seed: Optional[int] = None, offset: Optional[int] = None, @@ -969,8 +1161,12 @@ def min_p_sampling_from_probs( and output dtype defaults to ``torch.int32``. deterministic: bool Whether to use deterministic kernel implementation, default is ``True``. - generator: Optional[torch.Generator] - A random number generator for the operation. + generator: Optional[Union[torch.Generator, Tuple[torch.Tensor, torch.Tensor]]] + Random number generator. Can be either: + - A ``torch.Generator`` for traditional single-generator sampling (default) + - A tuple of ``(seed_arr, offset_arr)`` tensors for per-request generators, + where both are int64 tensors of shape ``(batch_size,)`` on CUDA. + Offsets are automatically updated in-place after sampling. check_nan: bool Whether to check nan in :attr:`probs`, default is ``False``. seed: Optional[int] @@ -1031,7 +1227,9 @@ def top_k_top_p_sampling_from_logits( indices: Optional[torch.Tensor] = None, filter_apply_order: str = "top_k_first", deterministic: bool = True, - generator: Optional[torch.Generator] = None, + generator: Optional[ + Union[torch.Generator, Tuple[torch.Tensor, torch.Tensor]] + ] = None, check_nan: bool = False, seed: Optional[int] = None, offset: Optional[int] = None, @@ -1072,8 +1270,12 @@ def top_k_top_p_sampling_from_logits( If ``"joint"``, we apply top-k and top-p filter simultaneously in each round. Default is ``"top_k_first"``. deterministic: bool Whether to use deterministic kernel implementation, default is ``True``. - generator: Optional[torch.Generator] - A random number generator for the operation. + generator: Optional[Union[torch.Generator, Tuple[torch.Tensor, torch.Tensor]]] + Random number generator. Can be either: + - A ``torch.Generator`` for traditional single-generator sampling (default) + - A tuple of ``(seed_arr, offset_arr)`` tensors for per-request generators, + where both are int64 tensors of shape ``(batch_size,)`` on CUDA. + Offsets are automatically updated in-place after sampling. check_nan: bool Whether to check nan in :attr:`probs`, default is ``False``. seed: Optional[int] @@ -1164,7 +1366,9 @@ def top_k_top_p_sampling_from_probs( indices: Optional[torch.Tensor] = None, filter_apply_order: str = "top_k_first", deterministic: bool = True, - generator: Optional[torch.Generator] = None, + generator: Optional[ + Union[torch.Generator, Tuple[torch.Tensor, torch.Tensor]] + ] = None, check_nan: bool = False, seed: Optional[int] = None, offset: Optional[int] = None, @@ -1205,8 +1409,12 @@ def top_k_top_p_sampling_from_probs( If ``"joint"``, we apply top-k and top-p filter simultaneously in each round. Default is ``"top_k_first"``. deterministic: bool Whether to use deterministic kernel implementation, default is ``True``. - generator: Optional[torch.Generator] - A random number generator for the operation. + generator: Optional[Union[torch.Generator, Tuple[torch.Tensor, torch.Tensor]]] + Random number generator. Can be either: + - A ``torch.Generator`` for traditional single-generator sampling (default) + - A tuple of ``(seed_arr, offset_arr)`` tensors for per-request generators, + where both are int64 tensors of shape ``(batch_size,)`` on CUDA. + Offsets are automatically updated in-place after sampling. check_nan: bool Whether to check nan in :attr:`probs`, default is ``False``. seed: Optional[int] diff --git a/include/flashinfer/sampling.cuh b/include/flashinfer/sampling.cuh index 6bd7881b73..d9773948b5 100644 --- a/include/flashinfer/sampling.cuh +++ b/include/flashinfer/sampling.cuh @@ -742,10 +742,20 @@ template __global__ void SamplingFromProbKernel(DType* probs, IdType* output, IdType* indices, uint32_t d, - uint64_t philox_seed, uint64_t philox_offset) { + uint64_t philox_seed, uint64_t philox_offset, + uint64_t* seed_arr = nullptr, + uint64_t* offset_arr = nullptr) { curandStatePhilox4_32_10_t state; const uint32_t bx = blockIdx.x, tx = threadIdx.x; - curand_init(philox_seed, bx, philox_offset, &state); + + // Use per-request seed/offset if arrays provided, otherwise use scalar values + uint64_t seed = (seed_arr != nullptr) ? seed_arr[bx] : philox_seed; + uint64_t offset = (offset_arr != nullptr) ? offset_arr[bx] : philox_offset; + // When using per-request seeds, subsequence should be 0 since bx is already incorporated via + // seed_arr[bx] When using scalar seed, use bx as subsequence to differentiate between blocks + uint64_t subsequence = (seed_arr != nullptr) ? 0 : bx; + + curand_init(seed, subsequence, offset, &state); const uint32_t row_idx = indices == nullptr ? bx : indices[bx]; extern __shared__ __align__( @@ -783,6 +793,12 @@ __global__ void SamplingFromProbKernel(DType* probs, IdType* output, IdType* ind sampled_id = temp_storage.last_valid_id; } output[bx] = sampled_id; + + // Atomically update offset if using per-request generators + // Each curand_uniform call consumes 4 values from the RNG state + if (tx == 0 && offset_arr != nullptr) { + atomicAdd(&offset_arr[bx], 4ULL); + } } template __global__ void TopKSamplingFromProbKernel(DType* probs, IdType* output, IdType* indices, IdType* top_k_arr, uint32_t top_k_val, uint32_t d, - uint64_t philox_seed, uint64_t philox_offset) { + uint64_t philox_seed, uint64_t philox_offset, + uint64_t* seed_arr = nullptr, + uint64_t* offset_arr = nullptr) { const uint32_t batch_size = gridDim.x; const uint32_t bx = blockIdx.x, tx = threadIdx.x; + + // Use per-request seed/offset if arrays provided, otherwise use scalar values + uint64_t seed = (seed_arr != nullptr) ? seed_arr[bx] : philox_seed; + uint64_t offset = (offset_arr != nullptr) ? offset_arr[bx] : philox_offset; + // When using per-request seeds, subsequence should be 0 since bx is already incorporated via + // seed_arr[bx] When using scalar seed, use bx as subsequence to differentiate between blocks + uint64_t subsequence = (seed_arr != nullptr) ? 0 : bx; + curandStatePhilox4_32_10_t state; - curand_init(philox_seed, bx, philox_offset, &state); + curand_init(seed, subsequence, offset, &state); const uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx]; const uint32_t row_idx = indices == nullptr ? bx : indices[bx]; @@ -899,6 +925,14 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, IdType* output, IdType* __syncthreads(); if (tx == 0) { output[bx] = sampled_id; + + // Atomically update offset if using per-request generators + // TopK sampling calls curand_uniform once per round (variable count) + // We increment by round * 4 since each call consumes 4 values + // Note: All threads converge to same round count due to __syncthreads() in loop + if (offset_arr != nullptr) { + atomicAdd(&offset_arr[bx], static_cast(round * 4)); + } } } @@ -907,11 +941,21 @@ template __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output, IdType* indices, float* top_p_arr, float top_p_val, uint32_t d, - uint64_t philox_seed, uint64_t philox_offset) { + uint64_t philox_seed, uint64_t philox_offset, + uint64_t* seed_arr = nullptr, + uint64_t* offset_arr = nullptr) { const uint32_t batch_size = gridDim.x; const uint32_t bx = blockIdx.x, tx = threadIdx.x; + + // Use per-request seed/offset if arrays provided, otherwise use scalar values + uint64_t seed = (seed_arr != nullptr) ? seed_arr[bx] : philox_seed; + uint64_t offset = (offset_arr != nullptr) ? offset_arr[bx] : philox_offset; + // When using per-request seeds, subsequence should be 0 since bx is already incorporated via + // seed_arr[bx] When using scalar seed, use bx as subsequence to differentiate between blocks + uint64_t subsequence = (seed_arr != nullptr) ? 0 : bx; + curandStatePhilox4_32_10_t state; - curand_init(philox_seed, bx, philox_offset, &state); + curand_init(seed, subsequence, offset, &state); const uint32_t row_idx = indices == nullptr ? bx : indices[bx]; float top_p = (top_p_arr == nullptr) ? top_p_val : top_p_arr[row_idx]; @@ -927,7 +971,9 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output, IdType* float q = 1; double low = 0, high = 1.f; int sampled_id; + int round = 0; do { + round += 1; temp_storage.sampled_id = d; __syncthreads(); float u = curand_uniform(&state) * q; @@ -1010,6 +1056,14 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output, IdType* __syncthreads(); if (tx == 0) { output[bx] = sampled_id; + + // Atomically update offset if using per-request generators + // TopP sampling calls curand_uniform once per round (variable count) + // We increment by round * 4 since each call consumes 4 values + // Note: All threads converge to same round count due to __syncthreads() in loop + if (offset_arr != nullptr) { + atomicAdd(&offset_arr[bx], static_cast(round * 4)); + } } } @@ -1018,11 +1072,21 @@ template __global__ void MinPSamplingFromProbKernel(DType* probs, float* min_p_arr, IdType* output, IdType* indices, float min_p_val, uint32_t d, - uint64_t philox_seed, uint64_t philox_offset) { + uint64_t philox_seed, uint64_t philox_offset, + uint64_t* seed_arr = nullptr, + uint64_t* offset_arr = nullptr) { const uint32_t bx = blockIdx.x, tx = threadIdx.x; float p = (min_p_arr == nullptr) ? min_p_val : min_p_arr[bx]; + + // Use per-request seed/offset if arrays provided, otherwise use scalar values + uint64_t seed = (seed_arr != nullptr) ? seed_arr[bx] : philox_seed; + uint64_t offset = (offset_arr != nullptr) ? offset_arr[bx] : philox_offset; + // When using per-request seeds, subsequence should be 0 since bx is already incorporated via + // seed_arr[bx] When using scalar seed, use bx as subsequence to differentiate between blocks + uint64_t subsequence = (seed_arr != nullptr) ? 0 : bx; + curandStatePhilox4_32_10_t state; - curand_init(philox_seed, bx, philox_offset, &state); + curand_init(seed, subsequence, offset, &state); const uint32_t row_idx = indices == nullptr ? bx : indices[bx]; extern __shared__ __align__( @@ -1091,6 +1155,12 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, float* min_p_arr, IdTyp sampled_id = temp_storage.last_valid_id; } output[bx] = sampled_id; + + // Atomically update offset if using per-request generators + // MinP sampling calls curand_uniform once + if (tx == 0 && offset_arr != nullptr) { + atomicAdd(&offset_arr[bx], 4ULL); + } } template (round * 4)); + } } } @@ -1382,14 +1471,16 @@ cudaError_t SamplingFromLogits(T* logits, IdType* output, IdType* indices, uint3 template cudaError_t SamplingFromProb(T* probs, IdType* output, IdType* indices, uint32_t batch_size, uint32_t d, bool deterministic, uint64_t philox_seed, - uint64_t philox_offset, cudaStream_t stream = 0) { + uint64_t philox_offset, uint64_t* seed_arr = nullptr, + uint64_t* offset_arr = nullptr, cudaStream_t stream = 0) { const uint32_t vec_size = std::gcd(16 / sizeof(T), d); auto compute_capacity = GetCudaComputeCapability(); DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, { dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &output, &indices, &d, &philox_seed, &philox_offset}; + void* args[] = {&probs, &output, &indices, &d, + &philox_seed, &philox_offset, &seed_arr, &offset_arr}; const uint32_t smem_size = sizeof(SamplingTempStorage); DISPATCH_ALIGNED_VEC_SIZE( @@ -1407,6 +1498,7 @@ template cudaError_t TopKSamplingFromProb(T* probs, IdType* output, IdType* indices, T* top_k_arr, uint32_t batch_size, uint32_t top_k_val, uint32_t d, bool deterministic, uint64_t philox_seed, uint64_t philox_offset, + uint64_t* seed_arr = nullptr, uint64_t* offset_arr = nullptr, cudaStream_t stream = 0) { const uint32_t vec_size = std::gcd(16 / sizeof(T), d); @@ -1415,8 +1507,8 @@ cudaError_t TopKSamplingFromProb(T* probs, IdType* output, IdType* indices, T* t const uint32_t smem_size = sizeof(SamplingTempStorage); dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &output, &indices, &top_k_arr, - &top_k_val, &d, &philox_seed, &philox_offset}; + void* args[] = {&probs, &output, &indices, &top_k_arr, &top_k_val, + &d, &philox_seed, &philox_offset, &seed_arr, &offset_arr}; DISPATCH_ALIGNED_VEC_SIZE( vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { @@ -1435,6 +1527,7 @@ template cudaError_t TopPSamplingFromProb(T* probs, IdType* output, IdType* indices, T* top_p_arr, uint32_t batch_size, T top_p_val, uint32_t d, bool deterministic, uint64_t philox_seed, uint64_t philox_offset, + uint64_t* seed_arr = nullptr, uint64_t* offset_arr = nullptr, cudaStream_t stream = 0) { const uint32_t vec_size = std::gcd(16 / sizeof(T), d); @@ -1443,8 +1536,8 @@ cudaError_t TopPSamplingFromProb(T* probs, IdType* output, IdType* indices, T* t const uint32_t smem_size = sizeof(SamplingTempStorage); dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &output, &indices, &top_p_arr, - &top_p_val, &d, &philox_seed, &philox_offset}; + void* args[] = {&probs, &output, &indices, &top_p_arr, &top_p_val, + &d, &philox_seed, &philox_offset, &seed_arr, &offset_arr}; DISPATCH_ALIGNED_VEC_SIZE( vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { @@ -1463,6 +1556,7 @@ template cudaError_t MinPSamplingFromProb(T* probs, T* min_p_arr, IdType* output, IdType* indices, uint32_t batch_size, float min_p_val, uint32_t d, bool deterministic, uint64_t philox_seed, uint64_t philox_offset, + uint64_t* seed_arr = nullptr, uint64_t* offset_arr = nullptr, cudaStream_t stream = 0) { const uint32_t vec_size = std::gcd(16 / sizeof(T), d); @@ -1471,8 +1565,8 @@ cudaError_t MinPSamplingFromProb(T* probs, T* min_p_arr, IdType* output, IdType* const uint32_t smem_size = sizeof(SamplingTempStorage); dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &min_p_arr, &output, &indices, - &min_p_val, &d, &philox_seed, &philox_offset}; + void* args[] = {&probs, &min_p_arr, &output, &indices, &min_p_val, + &d, &philox_seed, &philox_offset, &seed_arr, &offset_arr}; DISPATCH_ALIGNED_VEC_SIZE( vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { @@ -1492,6 +1586,7 @@ cudaError_t TopKTopPSamplingFromProb(T* probs, IdType* top_k_arr, T* top_p_arr, IdType* indices, uint32_t batch_size, IdType top_k_val, T top_p_val, uint32_t d, bool deterministic, uint64_t philox_seed, uint64_t philox_offset, + uint64_t* seed_arr = nullptr, uint64_t* offset_arr = nullptr, cudaStream_t stream = 0) { const uint32_t vec_size = std::gcd(16 / sizeof(T), d); @@ -1500,8 +1595,8 @@ cudaError_t TopKTopPSamplingFromProb(T* probs, IdType* top_k_arr, T* top_p_arr, const uint32_t smem_size = sizeof(SamplingTempStorage); dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &top_k_arr, &top_p_arr, &output, &indices, - &top_k_val, &top_p_val, &d, &philox_seed, &philox_offset}; + void* args[] = {&probs, &top_k_arr, &top_p_arr, &output, &indices, &top_k_val, + &top_p_val, &d, &philox_seed, &philox_offset, &seed_arr, &offset_arr}; DISPATCH_ALIGNED_VEC_SIZE( vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { diff --git a/tests/utils/test_sampling.py b/tests/utils/test_sampling.py index fb1c015c7e..3aa4fa5813 100644 --- a/tests/utils/test_sampling.py +++ b/tests/utils/test_sampling.py @@ -1069,6 +1069,219 @@ def test_int64_indices_sampling(batch_size, vocab_size, sampling_type, indices_d assert torch.all(samples < vocab_size) and torch.all(samples >= 0) +@pytest.mark.parametrize("batch_size", [1, 32, 128]) +@pytest.mark.parametrize("vocab_size", [111, 32000]) +def test_per_request_generator_reproducibility(batch_size, vocab_size): + """Test that per-request generators produce reproducible results.""" + torch.manual_seed(42) + pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0") + normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) + + # Create per-request generator tensors + seed_arr = torch.randint( + 0, 2**32, (batch_size,), dtype=torch.int64, device="cuda:0" + ) + offset_arr1 = torch.zeros(batch_size, dtype=torch.int64, device="cuda:0") + offset_arr2 = torch.zeros(batch_size, dtype=torch.int64, device="cuda:0") + + # Same seeds and offsets should produce identical samples + samples1 = flashinfer.sampling.sampling_from_probs( + normalized_prob, generator=(seed_arr, offset_arr1) + ) + samples2 = flashinfer.sampling.sampling_from_probs( + normalized_prob, generator=(seed_arr.clone(), offset_arr2) + ) + + assert torch.all(samples1 == samples2), ( + "Per-request generators with same seeds should produce identical samples" + ) + + +@pytest.mark.parametrize("batch_size", [8, 32]) +@pytest.mark.parametrize("vocab_size", [111, 32000]) +def test_per_request_generator_independence(batch_size, vocab_size): + """Test that different per-request seeds produce different samples.""" + torch.manual_seed(42) + pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0") + normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) + + # Different seeds for each request + seed_arr1 = torch.arange(batch_size, dtype=torch.int64, device="cuda:0") + seed_arr2 = torch.arange(batch_size, dtype=torch.int64, device="cuda:0") + 1000 + offset_arr1 = torch.zeros(batch_size, dtype=torch.int64, device="cuda:0") + offset_arr2 = torch.zeros(batch_size, dtype=torch.int64, device="cuda:0") + + samples1 = flashinfer.sampling.sampling_from_probs( + normalized_prob, generator=(seed_arr1, offset_arr1) + ) + samples2 = flashinfer.sampling.sampling_from_probs( + normalized_prob, generator=(seed_arr2, offset_arr2) + ) + + # Different seeds should produce mostly different samples + match_rate = (samples1 == samples2).float().mean().item() + assert match_rate < 0.9, ( + f"Different per-request seeds should produce mostly different samples, " + f"got {match_rate:.2%} match rate" + ) + + +@pytest.mark.parametrize("batch_size", [8, 32]) +@pytest.mark.parametrize("vocab_size", [111, 32000]) +@pytest.mark.parametrize( + "sampling_func", + [ + "sampling_from_probs", + "top_p_sampling_from_probs", + "top_k_sampling_from_probs", + "min_p_sampling_from_probs", + "top_k_top_p_sampling_from_probs", + ], +) +def test_per_request_generator_offset_update(batch_size, vocab_size, sampling_func): + """Test that offset_arr is correctly updated after sampling.""" + torch.manual_seed(42) + pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0") + normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) + + seed_arr = torch.randint( + 0, 2**32, (batch_size,), dtype=torch.int64, device="cuda:0" + ) + offset_arr = torch.zeros(batch_size, dtype=torch.int64, device="cuda:0") + + # Call sampling function + if sampling_func == "sampling_from_probs": + flashinfer.sampling.sampling_from_probs( + normalized_prob, generator=(seed_arr, offset_arr) + ) + # Simple samplers: fixed increment of 4 + assert torch.all(offset_arr == 4), ( + f"{sampling_func}: Offsets should be exactly 4, got {offset_arr}" + ) + elif sampling_func == "top_p_sampling_from_probs": + flashinfer.sampling.top_p_sampling_from_probs( + normalized_prob, 0.9, generator=(seed_arr, offset_arr) + ) + # Iterative samplers: variable increments (multiples of 4) + assert torch.all(offset_arr > 0), ( + f"{sampling_func}: All offsets should be updated (> 0)" + ) + assert torch.all(offset_arr % 4 == 0), ( + f"{sampling_func}: All offsets should be multiples of 4" + ) + elif sampling_func == "top_k_sampling_from_probs": + k = min(100, vocab_size) + flashinfer.sampling.top_k_sampling_from_probs( + normalized_prob, k, generator=(seed_arr, offset_arr) + ) + # Iterative samplers: variable increments (multiples of 4) + assert torch.all(offset_arr > 0), ( + f"{sampling_func}: All offsets should be updated (> 0)" + ) + assert torch.all(offset_arr % 4 == 0), ( + f"{sampling_func}: All offsets should be multiples of 4" + ) + elif sampling_func == "min_p_sampling_from_probs": + flashinfer.sampling.min_p_sampling_from_probs( + normalized_prob, 0.1, generator=(seed_arr, offset_arr) + ) + # Simple sampler: fixed increment of 4 + assert torch.all(offset_arr == 4), ( + f"{sampling_func}: Offsets should be exactly 4, got {offset_arr}" + ) + elif sampling_func == "top_k_top_p_sampling_from_probs": + k = min(100, vocab_size) + flashinfer.sampling.top_k_top_p_sampling_from_probs( + normalized_prob, + k, + 0.9, + generator=(seed_arr, offset_arr), + filter_apply_order="joint", + ) + # Iterative samplers: variable increments (multiples of 4) + assert torch.all(offset_arr > 0), ( + f"{sampling_func}: All offsets should be updated (> 0)" + ) + assert torch.all(offset_arr % 4 == 0), ( + f"{sampling_func}: All offsets should be multiples of 4" + ) + + +@pytest.mark.parametrize("batch_size", [8, 32]) +@pytest.mark.parametrize("vocab_size", [111, 32000]) +def test_per_request_generator_validation(batch_size, vocab_size): + """Test that invalid per-request generator inputs are rejected.""" + torch.manual_seed(42) + pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0") + normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) + + # Test 1: Wrong dtype (float instead of int64) + seed_arr = torch.randint( + 0, 2**32, (batch_size,), dtype=torch.float32, device="cuda:0" + ) + offset_arr = torch.zeros(batch_size, dtype=torch.int64, device="cuda:0") + with pytest.raises( + TypeError, match="seed_arr and offset_arr must be int64 tensors" + ): + flashinfer.sampling.sampling_from_probs( + normalized_prob, generator=(seed_arr, offset_arr) + ) + + # Test 2: Wrong device (CPU instead of CUDA) + seed_arr = torch.randint(0, 2**32, (batch_size,), dtype=torch.int64, device="cpu") + offset_arr = torch.zeros(batch_size, dtype=torch.int64, device="cuda:0") + with pytest.raises( + ValueError, match="seed_arr and offset_arr must be on CUDA device" + ): + flashinfer.sampling.sampling_from_probs( + normalized_prob, generator=(seed_arr, offset_arr) + ) + + # Test 3: Wrong shape + seed_arr = torch.randint( + 0, 2**32, (batch_size + 1,), dtype=torch.int64, device="cuda:0" + ) + offset_arr = torch.zeros(batch_size, dtype=torch.int64, device="cuda:0") + with pytest.raises(ValueError, match="seed_arr and offset_arr must have shape"): + flashinfer.sampling.sampling_from_probs( + normalized_prob, generator=(seed_arr, offset_arr) + ) + + +@pytest.mark.parametrize("batch_size", [8, 32]) +@pytest.mark.parametrize("vocab_size", [111, 32000]) +def test_per_request_generator_vs_traditional(batch_size, vocab_size): + """Test that per-request generator produces valid samples (no correctness comparison).""" + torch.manual_seed(42) + pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0") + normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) + + # Per-request generators + seed_arr = torch.randint( + 0, 2**32, (batch_size,), dtype=torch.int64, device="cuda:0" + ) + offset_arr = torch.zeros(batch_size, dtype=torch.int64, device="cuda:0") + samples_per_request = flashinfer.sampling.sampling_from_probs( + normalized_prob, generator=(seed_arr, offset_arr) + ) + + # Traditional generator + gen = torch.Generator("cuda:0") + gen.manual_seed(42) + samples_traditional = flashinfer.sampling.sampling_from_probs( + normalized_prob, generator=gen + ) + + # Both should produce valid samples + assert torch.all(samples_per_request < vocab_size) and torch.all( + samples_per_request >= 0 + ) + assert torch.all(samples_traditional < vocab_size) and torch.all( + samples_traditional >= 0 + ) + # We don't expect them to match since they use different RNG mechanisms + + @pytest.mark.parametrize("batch_size", [1, 19, 99]) @pytest.mark.parametrize("vocab_size", [111, 32000]) def test_sampling_with_default_device_cuda(batch_size, vocab_size):