diff --git a/csrc/flashinfer_sampling_binding.cu b/csrc/flashinfer_sampling_binding.cu index 5282ff480a..a3a6a08275 100644 --- a/csrc/flashinfer_sampling_binding.cu +++ b/csrc/flashinfer_sampling_binding.cu @@ -21,32 +21,42 @@ void softmax(TensorView workspace_buffer, TensorView logits, TensorView output, Optional maybe_temperature_arr, double temperature_val, bool enable_pdl); void sampling_from_probs(TensorView probs, TensorView output, Optional maybe_indices, - bool deterministic, uint64_t philox_seed, uint64_t philox_offset); + bool deterministic, Optional maybe_seed_arr, uint64_t seed_val, + Optional maybe_offset_arr, uint64_t offset_val); void sampling_from_logits(TensorView logits, TensorView output, Optional maybe_indices, - bool deterministic, uint64_t philox_seed, uint64_t philox_offset); + bool deterministic, Optional maybe_seed_arr, + uint64_t seed_val, Optional maybe_offset_arr, + uint64_t offset_val); void top_p_sampling_from_probs(TensorView probs, TensorView output, Optional maybe_indices, Optional maybe_top_p_arr, double top_p_val, - bool deterministic, uint64_t philox_seed, uint64_t philox_offset); + bool deterministic, Optional maybe_seed_arr, + uint64_t seed_val, Optional maybe_offset_arr, + uint64_t offset_val); void top_k_sampling_from_probs(TensorView probs, TensorView output, Optional maybe_indices, Optional maybe_top_k_arr, int64_t top_k_val, - bool deterministic, uint64_t philox_seed, uint64_t philox_offset); + bool deterministic, Optional maybe_seed_arr, + uint64_t seed_val, Optional maybe_offset_arr, + uint64_t offset_val); void min_p_sampling_from_probs(TensorView probs, TensorView output, Optional maybe_indices, Optional maybe_min_p_arr, double min_p_val, - bool deterministic, uint64_t philox_seed, uint64_t philox_offset); + bool deterministic, Optional maybe_seed_arr, + uint64_t seed_val, Optional maybe_offset_arr, + uint64_t offset_val); void top_k_top_p_sampling_from_probs(TensorView probs, TensorView output, Optional maybe_indices, Optional maybe_top_k_arr, double top_k_val, Optional maybe_top_p_arr, double top_p_val, - bool deterministic, uint64_t philox_seed, - uint64_t philox_offset); + bool deterministic, Optional maybe_seed_arr, + uint64_t seed_val, Optional maybe_offset_arr, + uint64_t offset_val); void top_p_renorm_probs(TensorView probs, TensorView renorm_probs, Optional maybe_top_p_arr, double top_p_val); @@ -63,7 +73,8 @@ void chain_speculative_sampling(TensorView draft_probs, TensorView draft_token_i TensorView target_probs, TensorView output_token_ids, TensorView output_accepted_token_num, TensorView output_emitted_draft_token_num, bool deterministic, - uint64_t philox_seed, uint64_t philox_offset); + Optional maybe_seed_arr, uint64_t seed_val, + Optional maybe_offset_arr, uint64_t offset_val); // Softmax TVM_FFI_DLL_EXPORT_TYPED_FUNC(softmax, softmax); diff --git a/csrc/sampling.cu b/csrc/sampling.cu index 6f59507991..45f6dd9634 100644 --- a/csrc/sampling.cu +++ b/csrc/sampling.cu @@ -22,6 +22,28 @@ using namespace flashinfer; using tvm::ffi::Optional; +// Helper function to validate seed/offset tensors for sampling operations +inline void validate_seed_offset_tensors(const Optional& maybe_seed_arr, + const Optional& maybe_offset_arr, + const TensorView& reference_tensor) { + if (maybe_seed_arr.has_value()) { + CHECK_INPUT(maybe_seed_arr.value()); + CHECK_DIM(1, maybe_seed_arr.value()); + TVM_FFI_ICHECK(maybe_seed_arr.value().dtype() == dl_int64 || + maybe_seed_arr.value().dtype() == dl_uint64) + << "seed tensor must be int64 or uint64"; + CHECK_DEVICE(maybe_seed_arr.value(), reference_tensor); + } + if (maybe_offset_arr.has_value()) { + CHECK_INPUT(maybe_offset_arr.value()); + CHECK_DIM(1, maybe_offset_arr.value()); + TVM_FFI_ICHECK(maybe_offset_arr.value().dtype() == dl_int64 || + maybe_offset_arr.value().dtype() == dl_uint64) + << "offset tensor must be int64 or uint64"; + CHECK_DEVICE(maybe_offset_arr.value(), reference_tensor); + } +} + void softmax(TensorView workspace_buffer, TensorView logits, TensorView output, Optional maybe_temperature_arr, double temperature_val, bool enable_pdl) { CHECK_INPUT(workspace_buffer); @@ -46,11 +68,15 @@ void softmax(TensorView workspace_buffer, TensorView logits, TensorView output, } void sampling_from_logits(TensorView logits, TensorView output, Optional maybe_indices, - bool deterministic, uint64_t philox_seed, uint64_t philox_offset) { + bool deterministic, Optional maybe_seed_arr, + uint64_t seed_val, Optional maybe_offset_arr, + uint64_t offset_val) { CHECK_INPUT(logits); CHECK_DIM(2, logits); // logits: (batch_size, vocab_size) CHECK_MAYBE_INPUT_TYPES(maybe_indices, dl_int32, dl_int64); CHECK_MAYBE_SAME_DTYPE(maybe_indices, output); + validate_seed_offset_tensors(maybe_seed_arr, maybe_offset_arr, logits); + unsigned int batch_size = output.size(0); unsigned int vocab_size = logits.size(1); @@ -62,7 +88,13 @@ void sampling_from_logits(TensorView logits, TensorView output, Optional(logits.data_ptr()), static_cast(output.data_ptr()), maybe_indices.has_value() ? static_cast(maybe_indices.value().data_ptr()) : nullptr, - batch_size, vocab_size, deterministic, philox_seed, philox_offset, stream); + batch_size, vocab_size, deterministic, + maybe_seed_arr.has_value() ? static_cast(maybe_seed_arr.value().data_ptr()) + : nullptr, + seed_val, + maybe_offset_arr.has_value() ? static_cast(maybe_offset_arr.value().data_ptr()) + : nullptr, + offset_val, stream); TVM_FFI_ICHECK(status == cudaSuccess) << "SamplingFromLogits failed with error code " << cudaGetErrorString(status); return true; @@ -70,11 +102,14 @@ void sampling_from_logits(TensorView logits, TensorView output, Optional maybe_indices, - bool deterministic, uint64_t philox_seed, uint64_t philox_offset) { + bool deterministic, Optional maybe_seed_arr, uint64_t seed_val, + Optional maybe_offset_arr, uint64_t offset_val) { CHECK_INPUT(probs); CHECK_DIM(2, probs); // probs: (batch_size, vocab_size) CHECK_MAYBE_INPUT_TYPES(maybe_indices, dl_int32, dl_int64); CHECK_MAYBE_SAME_DTYPE(maybe_indices, output); + validate_seed_offset_tensors(maybe_seed_arr, maybe_offset_arr, probs); + unsigned int batch_size = output.size(0); unsigned int vocab_size = probs.size(1); @@ -86,7 +121,13 @@ void sampling_from_probs(TensorView probs, TensorView output, Optional(probs.data_ptr()), static_cast(output.data_ptr()), maybe_indices.has_value() ? static_cast(maybe_indices.value().data_ptr()) : nullptr, - batch_size, vocab_size, deterministic, philox_seed, philox_offset, stream); + batch_size, vocab_size, deterministic, + maybe_seed_arr.has_value() ? static_cast(maybe_seed_arr.value().data_ptr()) + : nullptr, + seed_val, + maybe_offset_arr.has_value() ? static_cast(maybe_offset_arr.value().data_ptr()) + : nullptr, + offset_val, stream); TVM_FFI_ICHECK(status == cudaSuccess) << "SamplingFromProbs failed with error code " << cudaGetErrorString(status); return true; @@ -96,11 +137,15 @@ void sampling_from_probs(TensorView probs, TensorView output, Optional maybe_indices, Optional maybe_top_p_arr, double top_p_val, - bool deterministic, uint64_t philox_seed, uint64_t philox_offset) { + bool deterministic, Optional maybe_seed_arr, + uint64_t seed_val, Optional maybe_offset_arr, + uint64_t offset_val) { CHECK_INPUT(probs); CHECK_DIM(2, probs); // probs: (batch_size, vocab_size) CHECK_MAYBE_INPUT_TYPES(maybe_indices, dl_int32, dl_int64); CHECK_MAYBE_SAME_DTYPE(maybe_indices, output); + validate_seed_offset_tensors(maybe_seed_arr, maybe_offset_arr, probs); + unsigned int batch_size = output.size(0); unsigned int vocab_size = probs.size(1); check_tensor_param(maybe_top_p_arr, probs); @@ -115,7 +160,13 @@ void top_p_sampling_from_probs(TensorView probs, TensorView output, maybe_indices.has_value() ? static_cast(maybe_indices.value().data_ptr()) : nullptr, has_top_p_arr ? static_cast(maybe_top_p_arr.value().data_ptr()) : nullptr, - batch_size, top_p_val, vocab_size, deterministic, philox_seed, philox_offset, stream); + batch_size, top_p_val, vocab_size, deterministic, + maybe_seed_arr.has_value() ? static_cast(maybe_seed_arr.value().data_ptr()) + : nullptr, + seed_val, + maybe_offset_arr.has_value() ? static_cast(maybe_offset_arr.value().data_ptr()) + : nullptr, + offset_val, stream); TVM_FFI_ICHECK(status == cudaSuccess) << "TopPSamplingFromProbs failed with error code " << cudaGetErrorString(status); return true; @@ -125,7 +176,9 @@ void top_p_sampling_from_probs(TensorView probs, TensorView output, void top_k_sampling_from_probs(TensorView probs, TensorView output, Optional maybe_indices, Optional maybe_top_k_arr, int64_t top_k_val, - bool deterministic, uint64_t philox_seed, uint64_t philox_offset) { + bool deterministic, Optional maybe_seed_arr, + uint64_t seed_val, Optional maybe_offset_arr, + uint64_t offset_val) { CHECK_INPUT(probs); CHECK_INPUT(output); CHECK_DEVICE(output, probs); @@ -133,6 +186,8 @@ void top_k_sampling_from_probs(TensorView probs, TensorView output, CHECK_DIM(1, output); // output: (batch_size) CHECK_MAYBE_INPUT_TYPES(maybe_indices, dl_int32, dl_int64); CHECK_MAYBE_SAME_DTYPE(maybe_indices, output); + validate_seed_offset_tensors(maybe_seed_arr, maybe_offset_arr, probs); + unsigned int batch_size = output.size(0); unsigned int vocab_size = probs.size(1); check_tensor_param(maybe_top_k_arr, probs); @@ -147,7 +202,13 @@ void top_k_sampling_from_probs(TensorView probs, TensorView output, maybe_indices.has_value() ? static_cast(maybe_indices.value().data_ptr()) : nullptr, has_top_k_arr ? static_cast(maybe_top_k_arr.value().data_ptr()) : nullptr, - batch_size, top_k_val, vocab_size, deterministic, philox_seed, philox_offset, stream); + batch_size, top_k_val, vocab_size, deterministic, + maybe_seed_arr.has_value() ? static_cast(maybe_seed_arr.value().data_ptr()) + : nullptr, + seed_val, + maybe_offset_arr.has_value() ? static_cast(maybe_offset_arr.value().data_ptr()) + : nullptr, + offset_val, stream); TVM_FFI_ICHECK(status == cudaSuccess) << "TopKSamplingFromProbs failed with error code " << cudaGetErrorString(status); return true; @@ -157,7 +218,9 @@ void top_k_sampling_from_probs(TensorView probs, TensorView output, void min_p_sampling_from_probs(TensorView probs, TensorView output, Optional maybe_indices, Optional maybe_min_p_arr, double min_p_val, - bool deterministic, uint64_t philox_seed, uint64_t philox_offset) { + bool deterministic, Optional maybe_seed_arr, + uint64_t seed_val, Optional maybe_offset_arr, + uint64_t offset_val) { CHECK_INPUT(probs); CHECK_INPUT(output); CHECK_DEVICE(output, probs); @@ -165,6 +228,8 @@ void min_p_sampling_from_probs(TensorView probs, TensorView output, CHECK_DIM(1, output); // output: (batch_size) CHECK_MAYBE_INPUT_TYPES(maybe_indices, dl_int32, dl_int64); CHECK_MAYBE_SAME_DTYPE(maybe_indices, output); + validate_seed_offset_tensors(maybe_seed_arr, maybe_offset_arr, probs); + unsigned int batch_size = output.size(0); unsigned int vocab_size = probs.size(1); check_tensor_param(maybe_min_p_arr, probs); @@ -180,7 +245,13 @@ void min_p_sampling_from_probs(TensorView probs, TensorView output, static_cast(output.data_ptr()), maybe_indices.has_value() ? static_cast(maybe_indices.value().data_ptr()) : nullptr, - batch_size, min_p_val, vocab_size, deterministic, philox_seed, philox_offset, stream); + batch_size, min_p_val, vocab_size, deterministic, + maybe_seed_arr.has_value() ? static_cast(maybe_seed_arr.value().data_ptr()) + : nullptr, + seed_val, + maybe_offset_arr.has_value() ? static_cast(maybe_offset_arr.value().data_ptr()) + : nullptr, + offset_val, stream); TVM_FFI_ICHECK(status == cudaSuccess) << "MinPSamplingFromProb failed with error code " << cudaGetErrorString(status); return true; @@ -191,8 +262,9 @@ void top_k_top_p_sampling_from_probs(TensorView probs, TensorView output, Optional maybe_indices, Optional maybe_top_k_arr, double top_k_val, Optional maybe_top_p_arr, double top_p_val, - bool deterministic, uint64_t philox_seed, - uint64_t philox_offset) { + bool deterministic, Optional maybe_seed_arr, + uint64_t seed_val, Optional maybe_offset_arr, + uint64_t offset_val) { CHECK_INPUT(probs); CHECK_INPUT(output); CHECK_DEVICE(output, probs); @@ -200,6 +272,8 @@ void top_k_top_p_sampling_from_probs(TensorView probs, TensorView output, CHECK_DIM(1, output); // output: (batch_size) CHECK_MAYBE_INPUT_TYPES(maybe_indices, dl_int32, dl_int64); CHECK_MAYBE_SAME_DTYPE(maybe_indices, output); + validate_seed_offset_tensors(maybe_seed_arr, maybe_offset_arr, probs); + unsigned int batch_size = output.size(0); unsigned int vocab_size = probs.size(1); check_tensor_param(maybe_top_k_arr, probs); @@ -218,8 +292,13 @@ void top_k_top_p_sampling_from_probs(TensorView probs, TensorView output, static_cast(output.data_ptr()), maybe_indices.has_value() ? static_cast(maybe_indices.value().data_ptr()) : nullptr, - batch_size, top_k_val, top_p_val, vocab_size, deterministic, philox_seed, philox_offset, - stream); + batch_size, top_k_val, top_p_val, vocab_size, deterministic, + maybe_seed_arr.has_value() ? static_cast(maybe_seed_arr.value().data_ptr()) + : nullptr, + seed_val, + maybe_offset_arr.has_value() ? static_cast(maybe_offset_arr.value().data_ptr()) + : nullptr, + offset_val, stream); TVM_FFI_ICHECK(status == cudaSuccess) << "TopKTopPSamplingFromProbs failed with error code " << cudaGetErrorString(status); return true; @@ -230,12 +309,15 @@ void chain_speculative_sampling(TensorView draft_probs, TensorView draft_token_i TensorView target_probs, TensorView output_token_ids, TensorView output_accepted_token_num, TensorView output_emitted_draft_token_num, bool deterministic, - uint64_t philox_seed, uint64_t philox_offset) { + Optional maybe_seed_arr, uint64_t seed_val, + Optional maybe_offset_arr, uint64_t offset_val) { CHECK_INPUT(draft_probs); CHECK_INPUT(draft_token_ids); CHECK_INPUT(target_probs); CHECK_DEVICE(draft_token_ids, draft_probs); CHECK_DEVICE(target_probs, draft_probs); + validate_seed_offset_tensors(maybe_seed_arr, maybe_offset_arr, draft_probs); + CHECK_DIM(3, draft_probs); // draft_probs: (batch_size, num_speculate_tokens, vocab_size) CHECK_DIM(2, draft_token_ids); // draft_token_ids: (batch_size, num_speculate_tokens) CHECK_DIM(3, target_probs); // target_probs: (batch_size, num_speculate_tokens + 1, vocab_size) @@ -256,7 +338,13 @@ void chain_speculative_sampling(TensorView draft_probs, TensorView draft_token_i static_cast(target_probs.data_ptr()), static_cast(output_token_ids.data_ptr()), static_cast(output_accepted_token_num.data_ptr()), static_cast(output_emitted_draft_token_num.data_ptr()), batch_size, - num_speculate_tokens, vocab_size, deterministic, philox_seed, philox_offset, stream); + num_speculate_tokens, vocab_size, deterministic, + maybe_seed_arr.has_value() ? static_cast(maybe_seed_arr.value().data_ptr()) + : nullptr, + seed_val, + maybe_offset_arr.has_value() ? static_cast(maybe_offset_arr.value().data_ptr()) + : nullptr, + offset_val, stream); TVM_FFI_ICHECK(status == cudaSuccess) << "ChainSpeculativeSampling failed with error code " << cudaGetErrorString(status); diff --git a/flashinfer/sampling.py b/flashinfer/sampling.py index 8571016a15..034109b510 100644 --- a/flashinfer/sampling.py +++ b/flashinfer/sampling.py @@ -93,8 +93,8 @@ def sampling_from_logits( indices: Optional[torch.Tensor], deterministic: bool, generator: Optional[torch.Generator], - seed: Optional[int] = None, - offset: Optional[int] = None, + seed: Optional[Union[int, torch.Tensor]] = None, + offset: Optional[Union[int, torch.Tensor]] = None, ) -> torch.Tensor: device = logits.device # TODO: support more data types in logits to avoid conversion @@ -107,13 +107,20 @@ def sampling_from_logits( seed, offset = get_seed_and_offset( batch_size * logits.size(1), generator, device ) + + maybe_seed_arr, seed_val, maybe_offset_arr, offset_val = ( + _validate_and_convert_seed_offset(seed, offset, device, batch_size) + ) + module.sampling_from_logits( logits, samples, indices, deterministic, - seed, - offset, + maybe_seed_arr, + seed_val, + maybe_offset_arr, + offset_val, ) return samples @@ -136,8 +143,8 @@ def sampling_from_probs( indices: Optional[torch.Tensor], deterministic: bool, generator: Optional[torch.Generator], - seed: Optional[int] = None, - offset: Optional[int] = None, + seed: Optional[Union[int, torch.Tensor]] = None, + offset: Optional[Union[int, torch.Tensor]] = None, ) -> torch.Tensor: device = probs.device probs = probs.float() @@ -146,13 +153,20 @@ def sampling_from_probs( samples = torch.empty(batch_size, dtype=out_dtype, device=device) if seed is None or offset is None: seed, offset = get_seed_and_offset(batch_size, generator, device) + + maybe_seed_arr, seed_val, maybe_offset_arr, offset_val = ( + _validate_and_convert_seed_offset(seed, offset, device, batch_size) + ) + module.sampling_from_probs( probs, samples, indices, deterministic, - seed, - offset, + maybe_seed_arr, + seed_val, + maybe_offset_arr, + offset_val, ) return samples @@ -179,8 +193,8 @@ def top_p_sampling_from_probs( top_p_val: float, deterministic: bool, generator: Optional[torch.Generator], - seed: Optional[int] = None, - offset: Optional[int] = None, + seed: Optional[Union[int, torch.Tensor]] = None, + offset: Optional[Union[int, torch.Tensor]] = None, ) -> torch.Tensor: device = probs.device probs = probs.float() @@ -192,6 +206,11 @@ def top_p_sampling_from_probs( samples = torch.empty(batch_size, dtype=out_dtype, device=device) if seed is None or offset is None: seed, offset = get_seed_and_offset(batch_size * 32, generator, device) + + maybe_seed_arr, seed_val, maybe_offset_arr, offset_val = ( + _validate_and_convert_seed_offset(seed, offset, device, batch_size) + ) + module.top_p_sampling_from_probs( probs, samples, @@ -199,8 +218,10 @@ def top_p_sampling_from_probs( maybe_top_p_arr, top_p_val, deterministic, - seed, - offset, + maybe_seed_arr, + seed_val, + maybe_offset_arr, + offset_val, ) return samples @@ -228,8 +249,8 @@ def top_k_sampling_from_probs( top_k_val: int, deterministic: bool, generator: Optional[torch.Generator], - seed: Optional[int] = None, - offset: Optional[int] = None, + seed: Optional[Union[int, torch.Tensor]] = None, + offset: Optional[Union[int, torch.Tensor]] = None, ) -> torch.Tensor: device = probs.device probs = probs.float() @@ -239,6 +260,11 @@ def top_k_sampling_from_probs( samples = torch.empty(batch_size, dtype=out_dtype, device=device) if seed is None or offset is None: seed, offset = get_seed_and_offset(batch_size * 32, generator, device) + + maybe_seed_arr, seed_val, maybe_offset_arr, offset_val = ( + _validate_and_convert_seed_offset(seed, offset, device, batch_size) + ) + module.top_k_sampling_from_probs( probs, samples, @@ -246,8 +272,10 @@ def top_k_sampling_from_probs( maybe_top_k_arr, top_k_val, deterministic, - seed, - offset, + maybe_seed_arr, + seed_val, + maybe_offset_arr, + offset_val, ) return samples @@ -275,8 +303,8 @@ def min_p_sampling_from_probs( min_p_val: float, deterministic: bool, generator: Optional[torch.Generator], - seed: Optional[int] = None, - offset: Optional[int] = None, + seed: Optional[Union[int, torch.Tensor]] = None, + offset: Optional[Union[int, torch.Tensor]] = None, ) -> torch.Tensor: device = probs.device probs = probs.float() @@ -288,6 +316,11 @@ def min_p_sampling_from_probs( samples = torch.empty(batch_size, dtype=out_dtype, device=device) if seed is None or offset is None: seed, offset = get_seed_and_offset(batch_size, generator, device) + + maybe_seed_arr, seed_val, maybe_offset_arr, offset_val = ( + _validate_and_convert_seed_offset(seed, offset, device, batch_size) + ) + module.min_p_sampling_from_probs( probs, samples, @@ -295,8 +328,10 @@ def min_p_sampling_from_probs( maybe_min_p_arr, min_p_val, deterministic, - seed, - offset, + maybe_seed_arr, + seed_val, + maybe_offset_arr, + offset_val, ) return samples @@ -312,8 +347,8 @@ def top_k_top_p_sampling_from_probs( top_p_val: float, deterministic: bool, generator: Optional[torch.Generator], - seed: Optional[int] = None, - offset: Optional[int] = None, + seed: Optional[Union[int, torch.Tensor]] = None, + offset: Optional[Union[int, torch.Tensor]] = None, ) -> torch.Tensor: device = probs.device probs = probs.float() @@ -326,6 +361,11 @@ def top_k_top_p_sampling_from_probs( samples = torch.empty(batch_size, dtype=out_dtype, device=device) if seed is None or offset is None: seed, offset = get_seed_and_offset(batch_size * 32, generator, device) + + maybe_seed_arr, seed_val, maybe_offset_arr, offset_val = ( + _validate_and_convert_seed_offset(seed, offset, device, batch_size) + ) + module.top_k_top_p_sampling_from_probs( probs, samples, @@ -335,8 +375,10 @@ def top_k_top_p_sampling_from_probs( maybe_top_p_arr, top_p_val, deterministic, - seed, - offset, + maybe_seed_arr, + seed_val, + maybe_offset_arr, + offset_val, ) return samples @@ -473,8 +515,8 @@ def chain_speculative_sampling( output_emitted_draft_token_num: torch.Tensor, deterministic: bool, generator: Optional[torch.Generator], - seed: Optional[int] = None, - offset: Optional[int] = None, + seed: Optional[Union[int, torch.Tensor]] = None, + offset: Optional[Union[int, torch.Tensor]] = None, ) -> torch.Tensor: device = draft_probs.device draft_probs = draft_probs.float() @@ -484,10 +526,16 @@ def chain_speculative_sampling( output_emitted_draft_token_num = output_emitted_draft_token_num.int() b, n = draft_token_ids.shape output_token_ids = torch.empty((b, n + 1), dtype=torch.int32, device=device) + batch_size = b if seed is None or offset is None: seed, offset = get_seed_and_offset( draft_probs.size(0) * (draft_probs.size(1) + 1), generator, device ) + + maybe_seed_arr, seed_val, maybe_offset_arr, offset_val = ( + _validate_and_convert_seed_offset(seed, offset, device, batch_size) + ) + module.chain_speculative_sampling( draft_probs, draft_token_ids, @@ -496,8 +544,10 @@ def chain_speculative_sampling( output_accepted_token_num, output_emitted_draft_token_num, deterministic, - seed, - offset, + maybe_seed_arr, + seed_val, + maybe_offset_arr, + offset_val, ) return output_token_ids @@ -538,6 +588,67 @@ def _to_tensor_scalar_tuple(x): return (None, x) +def _validate_and_convert_seed_offset( + seed: Union[int, torch.Tensor], + offset: Union[int, torch.Tensor], + device: torch.device, + batch_size: int, +) -> Tuple[Optional[torch.Tensor], int, Optional[torch.Tensor], int]: + """Validate and convert seed/offset to tensor/scalar tuples for sampling kernels. + + Parameters + ---------- + seed : Union[int, torch.Tensor] + Seed value or tensor. + offset : Union[int, torch.Tensor] + Offset value or tensor. + device : torch.device + Expected device for tensor inputs. + batch_size : int + Expected batch size for tensor length validation. + + Returns + ------- + Tuple[Optional[torch.Tensor], int, Optional[torch.Tensor], int] + (maybe_seed_arr, seed_val, maybe_offset_arr, offset_val) + + Raises + ------ + ValueError + If seed and offset are not both tensors or both scalars, or if tensor + properties (device, dtype, ndim, size) are invalid. + """ + # Validate tensor/scalar consistency + if isinstance(seed, torch.Tensor) != isinstance(offset, torch.Tensor): + raise ValueError("seed and offset must both be tensors or both be scalars") + + # Convert to tensor/scalar tuple + maybe_seed_arr, seed_val = _to_tensor_scalar_tuple(seed) + maybe_offset_arr, offset_val = _to_tensor_scalar_tuple(offset) + + # Validate tensor properties + if maybe_seed_arr is not None: + if maybe_seed_arr.device != device: + raise ValueError(f"seed tensor must be on {device}") + if maybe_seed_arr.dtype not in [torch.int64, torch.uint64]: + raise ValueError("seed tensor must be int64/uint64") + if maybe_seed_arr.ndim != 1: + raise ValueError("seed tensor must be 1D") + if maybe_seed_arr.size(0) not in [1, batch_size]: + raise ValueError(f"seed tensor length must be 1 or {batch_size}") + if maybe_offset_arr is not None: + if maybe_offset_arr.device != device: + raise ValueError(f"offset tensor must be on {device}") + if maybe_offset_arr.dtype not in [torch.int64, torch.uint64]: + raise ValueError("offset tensor must be int64/uint64") + if maybe_offset_arr.ndim != 1: + raise ValueError("offset tensor must be 1D") + if maybe_offset_arr.size(0) not in [1, batch_size]: + raise ValueError(f"offset tensor length must be 1 or {batch_size}") + + return maybe_seed_arr, seed_val, maybe_offset_arr, offset_val + + @flashinfer_api def softmax( logits: torch.Tensor, @@ -603,8 +714,8 @@ def sampling_from_logits( deterministic: bool = True, generator: Optional[torch.Generator] = None, check_nan: bool = False, - seed: Optional[int] = None, - offset: Optional[int] = None, + seed: Optional[Union[int, torch.Tensor]] = None, + offset: Optional[Union[int, torch.Tensor]] = None, ) -> torch.Tensor: r"""Fused GPU kernel for category sampling from logits. It's equivalent to sampling from :attr:`logits` after applying softmax. @@ -629,10 +740,23 @@ def sampling_from_logits( A random number generator for the operation. check_nan: bool Whether to check nan in :attr:`logits`, default is ``False``. - seed: Optional[int] - seed value to use for the rng during the sampling operation. - offset: Optional[int] - offset value to use for the rng during the sampling operation. + seed: Optional[Union[int, torch.Tensor]] + Random seed value for the sampling operation. Can be either an integer or a torch.Tensor. + When provided as a torch.Tensor, it must be int64 or uint64 dtype, 1D, and length 1 or batch_size. + Using torch.Tensor is required for CUDA graph compatibility. + + Warning: If you provide seed and offset explicitly, you are responsible for updating + their values between calls to ensure different random samples. Common approaches include: + - Incrementing offset by the number of random values consumed + - Updating seed based on the number of calls to the operation + offset: Optional[Union[int, torch.Tensor]] + Random offset value for the sampling operation. Can be either an integer or a torch.Tensor. + When provided as a torch.Tensor, it must be int64 or uint64 dtype, 1D, and length 1 or batch_size. + Using torch.Tensor is required for CUDA graph compatibility. + + Warning: If you provide seed and offset explicitly, you are responsible for updating + their values between calls to ensure different random samples. The offset should be + incremented based on the number of random values consumed by the operation. Returns ------- samples: torch.Tensor @@ -670,8 +794,8 @@ def sampling_from_probs( deterministic: bool = True, generator: Optional[torch.Generator] = None, check_nan: bool = False, - seed: Optional[int] = None, - offset: Optional[int] = None, + seed: Optional[Union[int, torch.Tensor]] = None, + offset: Optional[Union[int, torch.Tensor]] = None, ) -> torch.Tensor: r"""Fused GPU kernel for category sampling from probabilities. @@ -695,10 +819,23 @@ def sampling_from_probs( A random number generator for the operation. check_nan: bool Whether to check nan in :attr:`probs`, default is ``False``. - seed: Optional[int] - seed value to use for the rng during the sampling operation. - offset: Optional[int] - offset value to use for the rng during the sampling operation. + seed: Optional[Union[int, torch.Tensor]] + Random seed value for the sampling operation. Can be either an integer or a torch.Tensor. + When provided as a torch.Tensor, it must be int64 or uint64 dtype, 1D, and length 1 or batch_size. + Using torch.Tensor is required for CUDA graph compatibility. + + Warning: If you provide seed and offset explicitly, you are responsible for updating + their values between calls to ensure different random samples. Common approaches include: + - Incrementing offset by the number of random values consumed + - Updating seed based on the number of calls to the operation + offset: Optional[Union[int, torch.Tensor]] + Random offset value for the sampling operation. Can be either an integer or a torch.Tensor. + When provided as a torch.Tensor, it must be int64 or uint64 dtype, 1D, and length 1 or batch_size. + Using torch.Tensor is required for CUDA graph compatibility. + + Warning: If you provide seed and offset explicitly, you are responsible for updating + their values between calls to ensure different random samples. The offset should be + incremented based on the number of random values consumed by the operation. Returns ------- @@ -744,8 +881,8 @@ def top_p_sampling_from_probs( deterministic: bool = True, generator: Optional[torch.Generator] = None, check_nan: bool = False, - seed: Optional[int] = None, - offset: Optional[int] = None, + seed: Optional[Union[int, torch.Tensor]] = None, + offset: Optional[Union[int, torch.Tensor]] = None, ) -> torch.Tensor: r"""Fused GPU kernel for top-p sampling (nucleus sampling) from probabilities, this operator implements GPU-based rejection sampling without explicit sorting. @@ -778,10 +915,23 @@ def top_p_sampling_from_probs( A random number generator for the operation. check_nan: bool Whether to check nan in :attr:`probs`, default is ``False``. - seed: Optional[int] - seed value to use for the rng during the sampling operation. - offset: Optional[int] - offset value to use for the rng during the sampling operation. + seed: Optional[Union[int, torch.Tensor]] + Random seed value for the sampling operation. Can be either an integer or a torch.Tensor. + When provided as a torch.Tensor, it must be int64 or uint64 dtype, 1D, and length 1 or batch_size. + Using torch.Tensor is required for CUDA graph compatibility. + + Warning: If you provide seed and offset explicitly, you are responsible for updating + their values between calls to ensure different random samples. Common approaches include: + - Incrementing offset by the number of random values consumed + - Updating seed based on the number of calls to the operation + offset: Optional[Union[int, torch.Tensor]] + Random offset value for the sampling operation. Can be either an integer or a torch.Tensor. + When provided as a torch.Tensor, it must be int64 or uint64 dtype, 1D, and length 1 or batch_size. + Using torch.Tensor is required for CUDA graph compatibility. + + Warning: If you provide seed and offset explicitly, you are responsible for updating + their values between calls to ensure different random samples. The offset should be + incremented based on the number of random values consumed by the operation. Returns ------- @@ -841,8 +991,8 @@ def top_k_sampling_from_probs( deterministic: bool = True, generator: Optional[torch.Generator] = None, check_nan: bool = False, - seed: Optional[int] = None, - offset: Optional[int] = None, + seed: Optional[Union[int, torch.Tensor]] = None, + offset: Optional[Union[int, torch.Tensor]] = None, ) -> torch.Tensor: r"""Fused GPU kernel for top-k sampling from probabilities, this operator implements GPU-based rejection sampling without explicit sorting. @@ -875,10 +1025,23 @@ def top_k_sampling_from_probs( A random number generator for the operation. check_nan: bool Whether to check nan in :attr:`probs`, default is ``False``. - seed: Optional[int] - seed value to use for the rng during the sampling operation. - offset: Optional[int] - offset value to use for the rng during the sampling operation. + seed: Optional[Union[int, torch.Tensor]] + Random seed value for the sampling operation. Can be either an integer or a torch.Tensor. + When provided as a torch.Tensor, it must be int64 or uint64 dtype, 1D, and length 1 or batch_size. + Using torch.Tensor is required for CUDA graph compatibility. + + Warning: If you provide seed and offset explicitly, you are responsible for updating + their values between calls to ensure different random samples. Common approaches include: + - Incrementing offset by the number of random values consumed + - Updating seed based on the number of calls to the operation + offset: Optional[Union[int, torch.Tensor]] + Random offset value for the sampling operation. Can be either an integer or a torch.Tensor. + When provided as a torch.Tensor, it must be int64 or uint64 dtype, 1D, and length 1 or batch_size. + Using torch.Tensor is required for CUDA graph compatibility. + + Warning: If you provide seed and offset explicitly, you are responsible for updating + their values between calls to ensure different random samples. The offset should be + incremented based on the number of random values consumed by the operation. Returns ------- @@ -938,8 +1101,8 @@ def min_p_sampling_from_probs( deterministic: bool = True, generator: Optional[torch.Generator] = None, check_nan: bool = False, - seed: Optional[int] = None, - offset: Optional[int] = None, + seed: Optional[Union[int, torch.Tensor]] = None, + offset: Optional[Union[int, torch.Tensor]] = None, ) -> torch.Tensor: r"""Fused GPU kernel for `min_p sampling `_ from probabilities, @@ -973,10 +1136,23 @@ def min_p_sampling_from_probs( A random number generator for the operation. check_nan: bool Whether to check nan in :attr:`probs`, default is ``False``. - seed: Optional[int] - seed value to use for the rng during the sampling operation. - offset: Optional[int] - offset value to use for the rng during the sampling operation. + seed: Optional[Union[int, torch.Tensor]] + Random seed value for the sampling operation. Can be either an integer or a torch.Tensor. + When provided as a torch.Tensor, it must be int64 or uint64 dtype, 1D, and length 1 or batch_size. + Using torch.Tensor is required for CUDA graph compatibility. + + Warning: If you provide seed and offset explicitly, you are responsible for updating + their values between calls to ensure different random samples. Common approaches include: + - Incrementing offset by the number of random values consumed + - Updating seed based on the number of calls to the operation + offset: Optional[Union[int, torch.Tensor]] + Random offset value for the sampling operation. Can be either an integer or a torch.Tensor. + When provided as a torch.Tensor, it must be int64 or uint64 dtype, 1D, and length 1 or batch_size. + Using torch.Tensor is required for CUDA graph compatibility. + + Warning: If you provide seed and offset explicitly, you are responsible for updating + their values between calls to ensure different random samples. The offset should be + incremented based on the number of random values consumed by the operation. Returns ------- @@ -1033,8 +1209,8 @@ def top_k_top_p_sampling_from_logits( deterministic: bool = True, generator: Optional[torch.Generator] = None, check_nan: bool = False, - seed: Optional[int] = None, - offset: Optional[int] = None, + seed: Optional[Union[int, torch.Tensor]] = None, + offset: Optional[Union[int, torch.Tensor]] = None, ) -> torch.Tensor: r"""Fused GPU kernel for top-k and top-p sampling from pre-softmax logits, @@ -1076,10 +1252,23 @@ def top_k_top_p_sampling_from_logits( A random number generator for the operation. check_nan: bool Whether to check nan in :attr:`probs`, default is ``False``. - seed: Optional[int] - seed value to use for the rng during the sampling operation. - offset: Optional[int] - offset value to use for the rng during the sampling operation. + seed: Optional[Union[int, torch.Tensor]] + Random seed value for the sampling operation. Can be either an integer or a torch.Tensor. + When provided as a torch.Tensor, it must be int64 or uint64 dtype, 1D, and length 1 or batch_size. + Using torch.Tensor is required for CUDA graph compatibility. + + Warning: If you provide seed and offset explicitly, you are responsible for updating + their values between calls to ensure different random samples. Common approaches include: + - Incrementing offset by the number of random values consumed + - Updating seed based on the number of calls to the operation + offset: Optional[Union[int, torch.Tensor]] + Random offset value for the sampling operation. Can be either an integer or a torch.Tensor. + When provided as a torch.Tensor, it must be int64 or uint64 dtype, 1D, and length 1 or batch_size. + Using torch.Tensor is required for CUDA graph compatibility. + + Warning: If you provide seed and offset explicitly, you are responsible for updating + their values between calls to ensure different random samples. The offset should be + incremented based on the number of random values consumed by the operation. Returns ------- @@ -1166,8 +1355,8 @@ def top_k_top_p_sampling_from_probs( deterministic: bool = True, generator: Optional[torch.Generator] = None, check_nan: bool = False, - seed: Optional[int] = None, - offset: Optional[int] = None, + seed: Optional[Union[int, torch.Tensor]] = None, + offset: Optional[Union[int, torch.Tensor]] = None, ) -> torch.Tensor: r"""Fused GPU kernel for top-k and top-p sampling from probabilities, @@ -1209,10 +1398,23 @@ def top_k_top_p_sampling_from_probs( A random number generator for the operation. check_nan: bool Whether to check nan in :attr:`probs`, default is ``False``. - seed: Optional[int] - seed value to use for the rng during the sampling operation. - offset: Optional[int] - offset value to use for the rng during the sampling operation. + seed: Optional[Union[int, torch.Tensor]] + Random seed value for the sampling operation. Can be either an integer or a torch.Tensor. + When provided as a torch.Tensor, it must be int64 or uint64 dtype, 1D, and length 1 or batch_size. + Using torch.Tensor is required for CUDA graph compatibility. + + Warning: If you provide seed and offset explicitly, you are responsible for updating + their values between calls to ensure different random samples. Common approaches include: + - Incrementing offset by the number of random values consumed + - Updating seed based on the number of calls to the operation + offset: Optional[Union[int, torch.Tensor]] + Random offset value for the sampling operation. Can be either an integer or a torch.Tensor. + When provided as a torch.Tensor, it must be int64 or uint64 dtype, 1D, and length 1 or batch_size. + Using torch.Tensor is required for CUDA graph compatibility. + + Warning: If you provide seed and offset explicitly, you are responsible for updating + their values between calls to ensure different random samples. The offset should be + incremented based on the number of random values consumed by the operation. Returns ------- @@ -1506,8 +1708,8 @@ def chain_speculative_sampling( maybe_output_emitted_draft_token_num: Optional[torch.Tensor] = None, deterministic: bool = True, generator: Optional[torch.Generator] = None, - seed: Optional[int] = None, - offset: Optional[int] = None, + seed: Optional[Union[int, torch.Tensor]] = None, + offset: Optional[Union[int, torch.Tensor]] = None, ) -> torch.Tensor: r"""Fused-GPU kernel for speculative sampling for sequence generation (proposed in paper `Accelerating Large Language Model Decoding with Speculative Sampling `_), @@ -1543,10 +1745,23 @@ def chain_speculative_sampling( Whether to use deterministic kernel implementation, default is ``True``. generator: Optional[torch.Generator] A random number generator for the operation. - seed: Optional[int] - seed value to use for the rng during the sampling operation. - offset: Optional[int] - offset value to use for the rng during the sampling operation. + seed: Optional[Union[int, torch.Tensor]] + Random seed value for the sampling operation. Can be either an integer or a torch.Tensor. + When provided as a torch.Tensor, it must be int64 or uint64 dtype, 1D, and length 1 or batch_size. + Using torch.Tensor is required for CUDA graph compatibility. + + Warning: If you provide seed and offset explicitly, you are responsible for updating + their values between calls to ensure different random samples. Common approaches include: + - Incrementing offset by the number of random values consumed + - Updating seed based on the number of calls to the operation + offset: Optional[Union[int, torch.Tensor]] + Random offset value for the sampling operation. Can be either an integer or a torch.Tensor. + When provided as a torch.Tensor, it must be int64 or uint64 dtype, 1D, and length 1 or batch_size. + Using torch.Tensor is required for CUDA graph compatibility. + + Warning: If you provide seed and offset explicitly, you are responsible for updating + their values between calls to ensure different random samples. The offset should be + incremented based on the number of random values consumed by the operation. Returns ------- diff --git a/include/flashinfer/sampling.cuh b/include/flashinfer/sampling.cuh index 1dddd107fe..7008208171 100644 --- a/include/flashinfer/sampling.cuh +++ b/include/flashinfer/sampling.cuh @@ -707,8 +707,14 @@ template __global__ void SamplingFromLogitsKernel(DType* logits, IdType* output, IdType* indices, uint32_t d, - uint64_t philox_seed, uint64_t philox_offset) { + uint64_t* seed_arr, uint64_t seed_val, + uint64_t* offset_arr, uint64_t offset_val) { const uint32_t bx = blockIdx.x, tx = threadIdx.x; + + // Resolve seed/offset from tensor or scalar + uint64_t philox_seed = seed_arr ? seed_arr[0] : seed_val; + uint64_t philox_offset = offset_arr ? offset_arr[0] : offset_val; + const uint32_t row_idx = indices == nullptr ? bx : indices[bx]; using SharedMem = typename BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage; @@ -748,9 +754,15 @@ template __global__ void SamplingFromProbKernel(DType* probs, IdType* output, IdType* indices, uint32_t d, - uint64_t philox_seed, uint64_t philox_offset) { + uint64_t* seed_arr, uint64_t seed_val, uint64_t* offset_arr, + uint64_t offset_val) { curandStatePhilox4_32_10_t state; const uint32_t bx = blockIdx.x, tx = threadIdx.x; + + // Resolve seed/offset from tensor or scalar + uint64_t philox_seed = seed_arr ? seed_arr[0] : seed_val; + uint64_t philox_offset = offset_arr ? offset_arr[0] : offset_val; + curand_init(philox_seed, bx, philox_offset, &state); const uint32_t row_idx = indices == nullptr ? bx : indices[bx]; @@ -796,9 +808,15 @@ template __global__ void TopKSamplingFromProbKernel(DType* probs, IdType* output, IdType* indices, IdType* top_k_arr, uint32_t top_k_val, uint32_t d, - uint64_t philox_seed, uint64_t philox_offset) { + uint64_t* seed_arr, uint64_t seed_val, + uint64_t* offset_arr, uint64_t offset_val) { const uint32_t batch_size = gridDim.x; const uint32_t bx = blockIdx.x, tx = threadIdx.x; + + // Resolve seed/offset from tensor or scalar + uint64_t philox_seed = seed_arr ? seed_arr[0] : seed_val; + uint64_t philox_offset = offset_arr ? offset_arr[0] : offset_val; + curandStatePhilox4_32_10_t state; curand_init(philox_seed, bx, philox_offset, &state); const uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx]; @@ -913,9 +931,15 @@ template __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output, IdType* indices, float* top_p_arr, float top_p_val, uint32_t d, - uint64_t philox_seed, uint64_t philox_offset) { + uint64_t* seed_arr, uint64_t seed_val, + uint64_t* offset_arr, uint64_t offset_val) { const uint32_t batch_size = gridDim.x; const uint32_t bx = blockIdx.x, tx = threadIdx.x; + + // Resolve seed/offset from tensor or scalar + uint64_t philox_seed = seed_arr ? seed_arr[0] : seed_val; + uint64_t philox_offset = offset_arr ? offset_arr[0] : offset_val; + curandStatePhilox4_32_10_t state; curand_init(philox_seed, bx, philox_offset, &state); const uint32_t row_idx = indices == nullptr ? bx : indices[bx]; @@ -1024,8 +1048,14 @@ template __global__ void MinPSamplingFromProbKernel(DType* probs, float* min_p_arr, IdType* output, IdType* indices, float min_p_val, uint32_t d, - uint64_t philox_seed, uint64_t philox_offset) { + uint64_t* seed_arr, uint64_t seed_val, + uint64_t* offset_arr, uint64_t offset_val) { const uint32_t bx = blockIdx.x, tx = threadIdx.x; + + // Resolve seed/offset from tensor or scalar + uint64_t philox_seed = seed_arr ? seed_arr[0] : seed_val; + uint64_t philox_offset = offset_arr ? offset_arr[0] : offset_val; + float p = (min_p_arr == nullptr) ? min_p_val : min_p_arr[bx]; curandStatePhilox4_32_10_t state; curand_init(philox_seed, bx, philox_offset, &state); @@ -1104,10 +1134,16 @@ template __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* top_k_arr, float* top_p_arr, IdType* output, IdType* indices, IdType top_k_val, - float top_p_val, uint32_t d, uint64_t philox_seed, - uint64_t philox_offset) { + float top_p_val, uint32_t d, uint64_t* seed_arr, + uint64_t seed_val, uint64_t* offset_arr, + uint64_t offset_val) { const uint32_t batch_size = gridDim.x; const uint32_t bx = blockIdx.x, tx = threadIdx.x; + + // Resolve seed/offset from tensor or scalar + uint64_t philox_seed = seed_arr ? seed_arr[0] : seed_val; + uint64_t philox_offset = offset_arr ? offset_arr[0] : offset_val; + curandStatePhilox4_32_10_t state; curand_init(philox_seed, bx, philox_offset, &state); const uint32_t row_idx = indices == nullptr ? bx : indices[bx]; @@ -1362,15 +1398,16 @@ cudaError_t OnlineSoftmax(DType* logits, DType* output, uint32_t batch_size, uin template cudaError_t SamplingFromLogits(T* logits, IdType* output, IdType* indices, uint32_t batch_size, - uint32_t d, bool deterministic, uint64_t philox_seed, - uint64_t philox_offset, cudaStream_t stream = 0) { + uint32_t d, bool deterministic, uint64_t* seed_arr, + uint64_t seed_val, uint64_t* offset_arr, uint64_t offset_val, + cudaStream_t stream = 0) { const uint32_t vec_size = std::gcd(16 / sizeof(T), d); auto compute_capacity = GetCudaComputeCapability(); DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, { dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); - void* args[] = {&logits, &output, &indices, &d, &philox_seed, &philox_offset}; + void* args[] = {&logits, &output, &indices, &d, &seed_arr, &seed_val, &offset_arr, &offset_val}; const uint32_t smem_size = sizeof( typename BlockReduce, BLOCK_THREADS, REDUCE_ALGO>::TempStorage); @@ -1387,15 +1424,15 @@ cudaError_t SamplingFromLogits(T* logits, IdType* output, IdType* indices, uint3 template cudaError_t SamplingFromProb(T* probs, IdType* output, IdType* indices, uint32_t batch_size, - uint32_t d, bool deterministic, uint64_t philox_seed, - uint64_t philox_offset, cudaStream_t stream = 0) { + uint32_t d, bool deterministic, uint64_t* seed_arr, uint64_t seed_val, + uint64_t* offset_arr, uint64_t offset_val, cudaStream_t stream = 0) { const uint32_t vec_size = std::gcd(16 / sizeof(T), d); auto compute_capacity = GetCudaComputeCapability(); DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, { dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &output, &indices, &d, &philox_seed, &philox_offset}; + void* args[] = {&probs, &output, &indices, &d, &seed_arr, &seed_val, &offset_arr, &offset_val}; const uint32_t smem_size = sizeof(SamplingTempStorage); DISPATCH_ALIGNED_VEC_SIZE( @@ -1412,7 +1449,8 @@ cudaError_t SamplingFromProb(T* probs, IdType* output, IdType* indices, uint32_t template cudaError_t TopKSamplingFromProb(T* probs, IdType* output, IdType* indices, T* top_k_arr, uint32_t batch_size, uint32_t top_k_val, uint32_t d, - bool deterministic, uint64_t philox_seed, uint64_t philox_offset, + bool deterministic, uint64_t* seed_arr, uint64_t seed_val, + uint64_t* offset_arr, uint64_t offset_val, cudaStream_t stream = 0) { const uint32_t vec_size = std::gcd(16 / sizeof(T), d); @@ -1421,8 +1459,8 @@ cudaError_t TopKSamplingFromProb(T* probs, IdType* output, IdType* indices, T* t const uint32_t smem_size = sizeof(SamplingTempStorage); dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &output, &indices, &top_k_arr, - &top_k_val, &d, &philox_seed, &philox_offset}; + void* args[] = {&probs, &output, &indices, &top_k_arr, &top_k_val, + &d, &seed_arr, &seed_val, &offset_arr, &offset_val}; DISPATCH_ALIGNED_VEC_SIZE( vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { @@ -1440,8 +1478,8 @@ cudaError_t TopKSamplingFromProb(T* probs, IdType* output, IdType* indices, T* t template cudaError_t TopPSamplingFromProb(T* probs, IdType* output, IdType* indices, T* top_p_arr, uint32_t batch_size, T top_p_val, uint32_t d, bool deterministic, - uint64_t philox_seed, uint64_t philox_offset, - cudaStream_t stream = 0) { + uint64_t* seed_arr, uint64_t seed_val, uint64_t* offset_arr, + uint64_t offset_val, cudaStream_t stream = 0) { const uint32_t vec_size = std::gcd(16 / sizeof(T), d); auto compute_capacity = GetCudaComputeCapability(); @@ -1449,8 +1487,8 @@ cudaError_t TopPSamplingFromProb(T* probs, IdType* output, IdType* indices, T* t const uint32_t smem_size = sizeof(SamplingTempStorage); dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &output, &indices, &top_p_arr, - &top_p_val, &d, &philox_seed, &philox_offset}; + void* args[] = {&probs, &output, &indices, &top_p_arr, &top_p_val, + &d, &seed_arr, &seed_val, &offset_arr, &offset_val}; DISPATCH_ALIGNED_VEC_SIZE( vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { @@ -1468,7 +1506,8 @@ cudaError_t TopPSamplingFromProb(T* probs, IdType* output, IdType* indices, T* t template cudaError_t MinPSamplingFromProb(T* probs, T* min_p_arr, IdType* output, IdType* indices, uint32_t batch_size, float min_p_val, uint32_t d, - bool deterministic, uint64_t philox_seed, uint64_t philox_offset, + bool deterministic, uint64_t* seed_arr, uint64_t seed_val, + uint64_t* offset_arr, uint64_t offset_val, cudaStream_t stream = 0) { const uint32_t vec_size = std::gcd(16 / sizeof(T), d); @@ -1477,8 +1516,8 @@ cudaError_t MinPSamplingFromProb(T* probs, T* min_p_arr, IdType* output, IdType* const uint32_t smem_size = sizeof(SamplingTempStorage); dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &min_p_arr, &output, &indices, - &min_p_val, &d, &philox_seed, &philox_offset}; + void* args[] = {&probs, &min_p_arr, &output, &indices, &min_p_val, + &d, &seed_arr, &seed_val, &offset_arr, &offset_val}; DISPATCH_ALIGNED_VEC_SIZE( vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { @@ -1497,8 +1536,8 @@ template cudaError_t TopKTopPSamplingFromProb(T* probs, IdType* top_k_arr, T* top_p_arr, IdType* output, IdType* indices, uint32_t batch_size, IdType top_k_val, T top_p_val, uint32_t d, bool deterministic, - uint64_t philox_seed, uint64_t philox_offset, - cudaStream_t stream = 0) { + uint64_t* seed_arr, uint64_t seed_val, uint64_t* offset_arr, + uint64_t offset_val, cudaStream_t stream = 0) { const uint32_t vec_size = std::gcd(16 / sizeof(T), d); auto compute_capacity = GetCudaComputeCapability(); @@ -1506,8 +1545,8 @@ cudaError_t TopKTopPSamplingFromProb(T* probs, IdType* top_k_arr, T* top_p_arr, const uint32_t smem_size = sizeof(SamplingTempStorage); dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &top_k_arr, &top_p_arr, &output, &indices, - &top_k_val, &top_p_val, &d, &philox_seed, &philox_offset}; + void* args[] = {&probs, &top_k_arr, &top_p_arr, &output, &indices, &top_k_val, + &top_p_val, &d, &seed_arr, &seed_val, &offset_arr, &offset_val}; DISPATCH_ALIGNED_VEC_SIZE( vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { @@ -1750,9 +1789,15 @@ __global__ void ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token IdType* output_accepted_token_num, IdType* output_emitted_draft_token_num, uint32_t num_speculative_tokens, uint32_t d, - uint64_t philox_seed, uint64_t philox_offset) { + uint64_t* seed_arr, uint64_t seed_val, + uint64_t* offset_arr, uint64_t offset_val) { const uint32_t bx = blockIdx.x, tx = threadIdx.x; const uint32_t row_idx = bx; + + // Resolve seed/offset from tensor or scalar + uint64_t philox_seed = seed_arr ? seed_arr[0] : seed_val; + uint64_t philox_offset = offset_arr ? offset_arr[0] : offset_val; + curandStatePhilox4_32_10_t curand_state; curand_init(philox_seed, bx, philox_offset, &curand_state); @@ -1879,13 +1924,11 @@ __global__ void ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token } template -cudaError_t ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token_ids, - DType* target_probs, IdType* output_token_ids, - IdType* output_accepted_token_num, - IdType* output_emitted_draft_token_num, uint32_t batch_size, - uint32_t num_speculative_tokens, uint32_t d, - bool deterministic, uint64_t philox_seed, - uint64_t philox_offset, cudaStream_t stream = 0) { +cudaError_t ChainSpeculativeSampling( + DType* draft_probs, IdType* draft_token_ids, DType* target_probs, IdType* output_token_ids, + IdType* output_accepted_token_num, IdType* output_emitted_draft_token_num, uint32_t batch_size, + uint32_t num_speculative_tokens, uint32_t d, bool deterministic, uint64_t* seed_arr, + uint64_t seed_val, uint64_t* offset_arr, uint64_t offset_val, cudaStream_t stream = 0) { const uint32_t vec_size = std::gcd(16 / sizeof(DType), d); auto compute_capacity = GetCudaComputeCapability(); @@ -1901,8 +1944,10 @@ cudaError_t ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token_ids &output_emitted_draft_token_num, &num_speculative_tokens, &d, - &philox_seed, - &philox_offset}; + &seed_arr, + &seed_val, + &offset_arr, + &offset_val}; DISPATCH_ALIGNED_VEC_SIZE( vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { auto kernel = ChainSpeculativeSampling