Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions csrc/flashinfer_sampling_binding.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,32 +21,42 @@ void softmax(TensorView workspace_buffer, TensorView logits, TensorView output,
Optional<TensorView> maybe_temperature_arr, double temperature_val, bool enable_pdl);

void sampling_from_probs(TensorView probs, TensorView output, Optional<TensorView> maybe_indices,
bool deterministic, uint64_t philox_seed, uint64_t philox_offset);
bool deterministic, Optional<TensorView> maybe_seed_arr, uint64_t seed_val,
Optional<TensorView> maybe_offset_arr, uint64_t offset_val);

void sampling_from_logits(TensorView logits, TensorView output, Optional<TensorView> maybe_indices,
bool deterministic, uint64_t philox_seed, uint64_t philox_offset);
bool deterministic, Optional<TensorView> maybe_seed_arr,
uint64_t seed_val, Optional<TensorView> maybe_offset_arr,
uint64_t offset_val);

void top_p_sampling_from_probs(TensorView probs, TensorView output,
Optional<TensorView> maybe_indices,
Optional<TensorView> maybe_top_p_arr, double top_p_val,
bool deterministic, uint64_t philox_seed, uint64_t philox_offset);
bool deterministic, Optional<TensorView> maybe_seed_arr,
uint64_t seed_val, Optional<TensorView> maybe_offset_arr,
uint64_t offset_val);

void top_k_sampling_from_probs(TensorView probs, TensorView output,
Optional<TensorView> maybe_indices,
Optional<TensorView> maybe_top_k_arr, int64_t top_k_val,
bool deterministic, uint64_t philox_seed, uint64_t philox_offset);
bool deterministic, Optional<TensorView> maybe_seed_arr,
uint64_t seed_val, Optional<TensorView> maybe_offset_arr,
uint64_t offset_val);

void min_p_sampling_from_probs(TensorView probs, TensorView output,
Optional<TensorView> maybe_indices,
Optional<TensorView> maybe_min_p_arr, double min_p_val,
bool deterministic, uint64_t philox_seed, uint64_t philox_offset);
bool deterministic, Optional<TensorView> maybe_seed_arr,
uint64_t seed_val, Optional<TensorView> maybe_offset_arr,
uint64_t offset_val);

void top_k_top_p_sampling_from_probs(TensorView probs, TensorView output,
Optional<TensorView> maybe_indices,
Optional<TensorView> maybe_top_k_arr, double top_k_val,
Optional<TensorView> maybe_top_p_arr, double top_p_val,
bool deterministic, uint64_t philox_seed,
uint64_t philox_offset);
bool deterministic, Optional<TensorView> maybe_seed_arr,
uint64_t seed_val, Optional<TensorView> maybe_offset_arr,
uint64_t offset_val);

void top_p_renorm_probs(TensorView probs, TensorView renorm_probs,
Optional<TensorView> maybe_top_p_arr, double top_p_val);
Expand All @@ -63,7 +73,8 @@ void chain_speculative_sampling(TensorView draft_probs, TensorView draft_token_i
TensorView target_probs, TensorView output_token_ids,
TensorView output_accepted_token_num,
TensorView output_emitted_draft_token_num, bool deterministic,
uint64_t philox_seed, uint64_t philox_offset);
Optional<TensorView> maybe_seed_arr, uint64_t seed_val,
Optional<TensorView> maybe_offset_arr, uint64_t offset_val);

// Softmax
TVM_FFI_DLL_EXPORT_TYPED_FUNC(softmax, softmax);
Expand Down
120 changes: 104 additions & 16 deletions csrc/sampling.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,28 @@ using namespace flashinfer;

using tvm::ffi::Optional;

// Helper function to validate seed/offset tensors for sampling operations
inline void validate_seed_offset_tensors(const Optional<TensorView>& maybe_seed_arr,
const Optional<TensorView>& maybe_offset_arr,
const TensorView& reference_tensor) {
if (maybe_seed_arr.has_value()) {
CHECK_INPUT(maybe_seed_arr.value());
CHECK_DIM(1, maybe_seed_arr.value());
TVM_FFI_ICHECK(maybe_seed_arr.value().dtype() == dl_int64 ||
maybe_seed_arr.value().dtype() == dl_uint64)
<< "seed tensor must be int64 or uint64";
CHECK_DEVICE(maybe_seed_arr.value(), reference_tensor);
}
if (maybe_offset_arr.has_value()) {
CHECK_INPUT(maybe_offset_arr.value());
CHECK_DIM(1, maybe_offset_arr.value());
TVM_FFI_ICHECK(maybe_offset_arr.value().dtype() == dl_int64 ||
maybe_offset_arr.value().dtype() == dl_uint64)
<< "offset tensor must be int64 or uint64";
CHECK_DEVICE(maybe_offset_arr.value(), reference_tensor);
}
}
Comment on lines +25 to +45
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Validate seed/offset tensor lengths to prevent OOB reads.

The helper checks dtype/device/ndim but not length. A zero-length tensor or mismatched batch length will pass and kernels read element 0 (or bx), causing OOB. Please enforce length == 1 or == output batch size (use output.size(0) when indices is supplied).

πŸ”§ Suggested validation (and call-site update)
-inline void validate_seed_offset_tensors(const Optional<TensorView>& maybe_seed_arr,
-                                         const Optional<TensorView>& maybe_offset_arr,
-                                         const TensorView& reference_tensor) {
+inline void validate_seed_offset_tensors(const Optional<TensorView>& maybe_seed_arr,
+                                         const Optional<TensorView>& maybe_offset_arr,
+                                         const TensorView& reference_tensor,
+                                         int64_t batch_size) {
   if (maybe_seed_arr.has_value()) {
     CHECK_INPUT(maybe_seed_arr.value());
     CHECK_DIM(1, maybe_seed_arr.value());
+    TVM_FFI_ICHECK(maybe_seed_arr.value().size(0) == 1 ||
+                   maybe_seed_arr.value().size(0) == batch_size)
+        << "seed tensor length must be 1 or batch_size";
     TVM_FFI_ICHECK(maybe_seed_arr.value().dtype() == dl_int64 ||
                    maybe_seed_arr.value().dtype() == dl_uint64)
         << "seed tensor must be int64 or uint64";
     CHECK_DEVICE(maybe_seed_arr.value(), reference_tensor);
   }
   if (maybe_offset_arr.has_value()) {
     CHECK_INPUT(maybe_offset_arr.value());
     CHECK_DIM(1, maybe_offset_arr.value());
+    TVM_FFI_ICHECK(maybe_offset_arr.value().size(0) == 1 ||
+                   maybe_offset_arr.value().size(0) == batch_size)
+        << "offset tensor length must be 1 or batch_size";
     TVM_FFI_ICHECK(maybe_offset_arr.value().dtype() == dl_int64 ||
                    maybe_offset_arr.value().dtype() == dl_uint64)
         << "offset tensor must be int64 or uint64";
     CHECK_DEVICE(maybe_offset_arr.value(), reference_tensor);
   }
 }
-  validate_seed_offset_tensors(maybe_seed_arr, maybe_offset_arr, logits);
+  validate_seed_offset_tensors(maybe_seed_arr, maybe_offset_arr, logits, output.size(0));
πŸ€– Prompt for AI Agents
In `@csrc/sampling.cu` around lines 25 - 45, validate_seed_offset_tensors
currently checks dtype/ndim/device but not tensor length, allowing zero-length
or mismatched batch-length tensors which can cause OOB reads; update
validate_seed_offset_tensors to also check that when maybe_seed_arr/
maybe_offset_arr is present their size(0) is either 1 or equals the expected
batch size (pass the expected batch size or the output tensor and compare
against output.size(0) when indices are supplied), and update call sites that
invoke validate_seed_offset_tensors to supply the output tensor or explicit
batch size so the function can enforce length == 1 || length ==
expected_batch_size.


void softmax(TensorView workspace_buffer, TensorView logits, TensorView output,
Optional<TensorView> maybe_temperature_arr, double temperature_val, bool enable_pdl) {
CHECK_INPUT(workspace_buffer);
Expand All @@ -46,11 +68,15 @@ void softmax(TensorView workspace_buffer, TensorView logits, TensorView output,
}

void sampling_from_logits(TensorView logits, TensorView output, Optional<TensorView> maybe_indices,
bool deterministic, uint64_t philox_seed, uint64_t philox_offset) {
bool deterministic, Optional<TensorView> maybe_seed_arr,
uint64_t seed_val, Optional<TensorView> maybe_offset_arr,
uint64_t offset_val) {
CHECK_INPUT(logits);
CHECK_DIM(2, logits); // logits: (batch_size, vocab_size)
CHECK_MAYBE_INPUT_TYPES(maybe_indices, dl_int32, dl_int64);
CHECK_MAYBE_SAME_DTYPE(maybe_indices, output);
validate_seed_offset_tensors(maybe_seed_arr, maybe_offset_arr, logits);

unsigned int batch_size = output.size(0);
unsigned int vocab_size = logits.size(1);

Expand All @@ -62,19 +88,28 @@ void sampling_from_logits(TensorView logits, TensorView output, Optional<TensorV
static_cast<float*>(logits.data_ptr()), static_cast<IdType*>(output.data_ptr()),
maybe_indices.has_value() ? static_cast<IdType*>(maybe_indices.value().data_ptr())
: nullptr,
batch_size, vocab_size, deterministic, philox_seed, philox_offset, stream);
batch_size, vocab_size, deterministic,
maybe_seed_arr.has_value() ? static_cast<uint64_t*>(maybe_seed_arr.value().data_ptr())
: nullptr,
seed_val,
maybe_offset_arr.has_value() ? static_cast<uint64_t*>(maybe_offset_arr.value().data_ptr())
: nullptr,
offset_val, stream);
TVM_FFI_ICHECK(status == cudaSuccess)
<< "SamplingFromLogits failed with error code " << cudaGetErrorString(status);
return true;
});
}

void sampling_from_probs(TensorView probs, TensorView output, Optional<TensorView> maybe_indices,
bool deterministic, uint64_t philox_seed, uint64_t philox_offset) {
bool deterministic, Optional<TensorView> maybe_seed_arr, uint64_t seed_val,
Optional<TensorView> maybe_offset_arr, uint64_t offset_val) {
CHECK_INPUT(probs);
CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
CHECK_MAYBE_INPUT_TYPES(maybe_indices, dl_int32, dl_int64);
CHECK_MAYBE_SAME_DTYPE(maybe_indices, output);
validate_seed_offset_tensors(maybe_seed_arr, maybe_offset_arr, probs);

unsigned int batch_size = output.size(0);
unsigned int vocab_size = probs.size(1);

Expand All @@ -86,7 +121,13 @@ void sampling_from_probs(TensorView probs, TensorView output, Optional<TensorVie
static_cast<float*>(probs.data_ptr()), static_cast<IdType*>(output.data_ptr()),
maybe_indices.has_value() ? static_cast<IdType*>(maybe_indices.value().data_ptr())
: nullptr,
batch_size, vocab_size, deterministic, philox_seed, philox_offset, stream);
batch_size, vocab_size, deterministic,
maybe_seed_arr.has_value() ? static_cast<uint64_t*>(maybe_seed_arr.value().data_ptr())
: nullptr,
seed_val,
maybe_offset_arr.has_value() ? static_cast<uint64_t*>(maybe_offset_arr.value().data_ptr())
: nullptr,
offset_val, stream);
TVM_FFI_ICHECK(status == cudaSuccess)
<< "SamplingFromProbs failed with error code " << cudaGetErrorString(status);
return true;
Expand All @@ -96,11 +137,15 @@ void sampling_from_probs(TensorView probs, TensorView output, Optional<TensorVie
void top_p_sampling_from_probs(TensorView probs, TensorView output,
Optional<TensorView> maybe_indices,
Optional<TensorView> maybe_top_p_arr, double top_p_val,
bool deterministic, uint64_t philox_seed, uint64_t philox_offset) {
bool deterministic, Optional<TensorView> maybe_seed_arr,
uint64_t seed_val, Optional<TensorView> maybe_offset_arr,
uint64_t offset_val) {
CHECK_INPUT(probs);
CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
CHECK_MAYBE_INPUT_TYPES(maybe_indices, dl_int32, dl_int64);
CHECK_MAYBE_SAME_DTYPE(maybe_indices, output);
validate_seed_offset_tensors(maybe_seed_arr, maybe_offset_arr, probs);

unsigned int batch_size = output.size(0);
unsigned int vocab_size = probs.size(1);
check_tensor_param(maybe_top_p_arr, probs);
Expand All @@ -115,7 +160,13 @@ void top_p_sampling_from_probs(TensorView probs, TensorView output,
maybe_indices.has_value() ? static_cast<IdType*>(maybe_indices.value().data_ptr())
: nullptr,
has_top_p_arr ? static_cast<float*>(maybe_top_p_arr.value().data_ptr()) : nullptr,
batch_size, top_p_val, vocab_size, deterministic, philox_seed, philox_offset, stream);
batch_size, top_p_val, vocab_size, deterministic,
maybe_seed_arr.has_value() ? static_cast<uint64_t*>(maybe_seed_arr.value().data_ptr())
: nullptr,
seed_val,
maybe_offset_arr.has_value() ? static_cast<uint64_t*>(maybe_offset_arr.value().data_ptr())
: nullptr,
offset_val, stream);
TVM_FFI_ICHECK(status == cudaSuccess)
<< "TopPSamplingFromProbs failed with error code " << cudaGetErrorString(status);
return true;
Expand All @@ -125,14 +176,18 @@ void top_p_sampling_from_probs(TensorView probs, TensorView output,
void top_k_sampling_from_probs(TensorView probs, TensorView output,
Optional<TensorView> maybe_indices,
Optional<TensorView> maybe_top_k_arr, int64_t top_k_val,
bool deterministic, uint64_t philox_seed, uint64_t philox_offset) {
bool deterministic, Optional<TensorView> maybe_seed_arr,
uint64_t seed_val, Optional<TensorView> maybe_offset_arr,
uint64_t offset_val) {
CHECK_INPUT(probs);
CHECK_INPUT(output);
CHECK_DEVICE(output, probs);
CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
CHECK_DIM(1, output); // output: (batch_size)
CHECK_MAYBE_INPUT_TYPES(maybe_indices, dl_int32, dl_int64);
CHECK_MAYBE_SAME_DTYPE(maybe_indices, output);
validate_seed_offset_tensors(maybe_seed_arr, maybe_offset_arr, probs);

unsigned int batch_size = output.size(0);
unsigned int vocab_size = probs.size(1);
check_tensor_param(maybe_top_k_arr, probs);
Expand All @@ -147,7 +202,13 @@ void top_k_sampling_from_probs(TensorView probs, TensorView output,
maybe_indices.has_value() ? static_cast<IdType*>(maybe_indices.value().data_ptr())
: nullptr,
has_top_k_arr ? static_cast<float*>(maybe_top_k_arr.value().data_ptr()) : nullptr,
batch_size, top_k_val, vocab_size, deterministic, philox_seed, philox_offset, stream);
batch_size, top_k_val, vocab_size, deterministic,
maybe_seed_arr.has_value() ? static_cast<uint64_t*>(maybe_seed_arr.value().data_ptr())
: nullptr,
seed_val,
maybe_offset_arr.has_value() ? static_cast<uint64_t*>(maybe_offset_arr.value().data_ptr())
: nullptr,
offset_val, stream);
TVM_FFI_ICHECK(status == cudaSuccess)
<< "TopKSamplingFromProbs failed with error code " << cudaGetErrorString(status);
return true;
Expand All @@ -157,14 +218,18 @@ void top_k_sampling_from_probs(TensorView probs, TensorView output,
void min_p_sampling_from_probs(TensorView probs, TensorView output,
Optional<TensorView> maybe_indices,
Optional<TensorView> maybe_min_p_arr, double min_p_val,
bool deterministic, uint64_t philox_seed, uint64_t philox_offset) {
bool deterministic, Optional<TensorView> maybe_seed_arr,
uint64_t seed_val, Optional<TensorView> maybe_offset_arr,
uint64_t offset_val) {
CHECK_INPUT(probs);
CHECK_INPUT(output);
CHECK_DEVICE(output, probs);
CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
CHECK_DIM(1, output); // output: (batch_size)
CHECK_MAYBE_INPUT_TYPES(maybe_indices, dl_int32, dl_int64);
CHECK_MAYBE_SAME_DTYPE(maybe_indices, output);
validate_seed_offset_tensors(maybe_seed_arr, maybe_offset_arr, probs);

unsigned int batch_size = output.size(0);
unsigned int vocab_size = probs.size(1);
check_tensor_param(maybe_min_p_arr, probs);
Expand All @@ -180,7 +245,13 @@ void min_p_sampling_from_probs(TensorView probs, TensorView output,
static_cast<IdType*>(output.data_ptr()),
maybe_indices.has_value() ? static_cast<IdType*>(maybe_indices.value().data_ptr())
: nullptr,
batch_size, min_p_val, vocab_size, deterministic, philox_seed, philox_offset, stream);
batch_size, min_p_val, vocab_size, deterministic,
maybe_seed_arr.has_value() ? static_cast<uint64_t*>(maybe_seed_arr.value().data_ptr())
: nullptr,
seed_val,
maybe_offset_arr.has_value() ? static_cast<uint64_t*>(maybe_offset_arr.value().data_ptr())
: nullptr,
offset_val, stream);
TVM_FFI_ICHECK(status == cudaSuccess)
<< "MinPSamplingFromProb failed with error code " << cudaGetErrorString(status);
return true;
Expand All @@ -191,15 +262,18 @@ void top_k_top_p_sampling_from_probs(TensorView probs, TensorView output,
Optional<TensorView> maybe_indices,
Optional<TensorView> maybe_top_k_arr, double top_k_val,
Optional<TensorView> maybe_top_p_arr, double top_p_val,
bool deterministic, uint64_t philox_seed,
uint64_t philox_offset) {
bool deterministic, Optional<TensorView> maybe_seed_arr,
uint64_t seed_val, Optional<TensorView> maybe_offset_arr,
uint64_t offset_val) {
CHECK_INPUT(probs);
CHECK_INPUT(output);
CHECK_DEVICE(output, probs);
CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
CHECK_DIM(1, output); // output: (batch_size)
CHECK_MAYBE_INPUT_TYPES(maybe_indices, dl_int32, dl_int64);
CHECK_MAYBE_SAME_DTYPE(maybe_indices, output);
validate_seed_offset_tensors(maybe_seed_arr, maybe_offset_arr, probs);

unsigned int batch_size = output.size(0);
unsigned int vocab_size = probs.size(1);
check_tensor_param(maybe_top_k_arr, probs);
Expand All @@ -218,8 +292,13 @@ void top_k_top_p_sampling_from_probs(TensorView probs, TensorView output,
static_cast<IdType*>(output.data_ptr()),
maybe_indices.has_value() ? static_cast<IdType*>(maybe_indices.value().data_ptr())
: nullptr,
batch_size, top_k_val, top_p_val, vocab_size, deterministic, philox_seed, philox_offset,
stream);
batch_size, top_k_val, top_p_val, vocab_size, deterministic,
maybe_seed_arr.has_value() ? static_cast<uint64_t*>(maybe_seed_arr.value().data_ptr())
: nullptr,
seed_val,
maybe_offset_arr.has_value() ? static_cast<uint64_t*>(maybe_offset_arr.value().data_ptr())
: nullptr,
offset_val, stream);
TVM_FFI_ICHECK(status == cudaSuccess)
<< "TopKTopPSamplingFromProbs failed with error code " << cudaGetErrorString(status);
return true;
Expand All @@ -230,12 +309,15 @@ void chain_speculative_sampling(TensorView draft_probs, TensorView draft_token_i
TensorView target_probs, TensorView output_token_ids,
TensorView output_accepted_token_num,
TensorView output_emitted_draft_token_num, bool deterministic,
uint64_t philox_seed, uint64_t philox_offset) {
Optional<TensorView> maybe_seed_arr, uint64_t seed_val,
Optional<TensorView> maybe_offset_arr, uint64_t offset_val) {
CHECK_INPUT(draft_probs);
CHECK_INPUT(draft_token_ids);
CHECK_INPUT(target_probs);
CHECK_DEVICE(draft_token_ids, draft_probs);
CHECK_DEVICE(target_probs, draft_probs);
validate_seed_offset_tensors(maybe_seed_arr, maybe_offset_arr, draft_probs);

CHECK_DIM(3, draft_probs); // draft_probs: (batch_size, num_speculate_tokens, vocab_size)
CHECK_DIM(2, draft_token_ids); // draft_token_ids: (batch_size, num_speculate_tokens)
CHECK_DIM(3, target_probs); // target_probs: (batch_size, num_speculate_tokens + 1, vocab_size)
Expand All @@ -256,7 +338,13 @@ void chain_speculative_sampling(TensorView draft_probs, TensorView draft_token_i
static_cast<float*>(target_probs.data_ptr()), static_cast<int*>(output_token_ids.data_ptr()),
static_cast<int*>(output_accepted_token_num.data_ptr()),
static_cast<int*>(output_emitted_draft_token_num.data_ptr()), batch_size,
num_speculate_tokens, vocab_size, deterministic, philox_seed, philox_offset, stream);
num_speculate_tokens, vocab_size, deterministic,
maybe_seed_arr.has_value() ? static_cast<uint64_t*>(maybe_seed_arr.value().data_ptr())
: nullptr,
seed_val,
maybe_offset_arr.has_value() ? static_cast<uint64_t*>(maybe_offset_arr.value().data_ptr())
: nullptr,
offset_val, stream);

TVM_FFI_ICHECK(status == cudaSuccess)
<< "ChainSpeculativeSampling failed with error code " << cudaGetErrorString(status);
Expand Down
Loading
Loading