Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions csrc/flashinfer_sampling_binding.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,37 +20,38 @@ using tvm::ffi::Optional;
// Forward declaration of the softmax entry point; its definition is not part
// of this view, so the notes below come from the signature only.
//   workspace_buffer      - caller-provided workspace tensor
//                           (presumably device scratch space; confirm against the definition)
//   logits, output        - input tensor and result tensor
//   maybe_temperature_arr - optional tensor of temperatures
//   temperature_val       - scalar temperature
//                           (presumably the fallback when no array is given; TODO confirm)
//   enable_pdl            - boolean launch option; NOTE(review): meaning not shown here, confirm
void softmax(TensorView workspace_buffer, TensorView logits, TensorView output,
Optional<TensorView> maybe_temperature_arr, double temperature_val, bool enable_pdl);

void sampling_from_probs(TensorView probs, TensorView output, Optional<TensorView> maybe_indices,
bool deterministic, Optional<TensorView> maybe_seed_arr, uint64_t seed_val,
// Forward declaration of sampling_from_probs (signature that includes `valid`).
// What the definition in csrc/sampling.cu shows for these arguments:
//   probs   - checked with CHECK_INPUT and CHECK_DIM(2, ...): (batch_size, vocab_size);
//             its data is passed to the sampler as float*
//   output  - data is passed as IdType*, with IdType dispatched from output.dtype()
//   valid   - checked with CHECK_INPUT, CHECK_DIM(1, ...) and CHECK_DEVICE(valid, probs);
//             its data is passed to the sampler as bool*
//             NOTE(review): no dtype or length check on `valid` is visible in the shown
//             hunks before the bool* cast - confirm one exists or add it
//   maybe_indices - optional; restricted to int32/int64 and checked against output's dtype;
//                   a null pointer is passed when absent
//   deterministic - forwarded to the sampler unchanged
//   maybe_seed_arr / seed_val, maybe_offset_arr / offset_val
//                 - RNG seed and offset as optional tensor plus scalar; the tensors are
//                   validated by validate_seed_offset_tensors(...)
//                   (presumably the scalar is used when the tensor is absent; TODO confirm)
void sampling_from_probs(TensorView probs, TensorView output, TensorView valid,
Optional<TensorView> maybe_indices, bool deterministic,
Optional<TensorView> maybe_seed_arr, uint64_t seed_val,
Optional<TensorView> maybe_offset_arr, uint64_t offset_val);

// Forward declaration of sampling_from_logits. Unlike the *_from_probs
// declarations next to it, this signature has no `valid` tensor.
// The body is not visible in this view, so parameter notes are signature-level only:
//   logits, output - input tensor and result tensor
//   maybe_indices  - optional index tensor
//   deterministic  - boolean option (semantics not shown here)
//   maybe_seed_arr / seed_val, maybe_offset_arr / offset_val
//                  - RNG seed and offset, each as optional tensor plus scalar
//                    (presumably the scalar is the fallback; TODO confirm)
void sampling_from_logits(TensorView logits, TensorView output, Optional<TensorView> maybe_indices,
bool deterministic, Optional<TensorView> maybe_seed_arr,
uint64_t seed_val, Optional<TensorView> maybe_offset_arr,
uint64_t offset_val);

void top_p_sampling_from_probs(TensorView probs, TensorView output,
// Forward declaration of top_p_sampling_from_probs (signature that includes `valid`).
// What the definition in csrc/sampling.cu shows for these arguments:
//   probs   - checked as 2-D (batch_size, vocab_size); data passed as float*
//   output  - data passed as IdType*, with IdType dispatched from output.dtype()
//   valid   - checked with CHECK_INPUT, CHECK_DIM(1, ...) and CHECK_DEVICE(valid, probs);
//             data passed as bool*
//             NOTE(review): no dtype check on `valid` is visible in the shown hunks - confirm
//   maybe_indices   - optional; restricted to int32/int64; null pointer when absent
//   maybe_top_p_arr - optional; data passed as float*, or nullptr when absent
//   top_p_val       - scalar top-p
//                     (presumably used when no array is supplied; TODO confirm)
//   deterministic, maybe_seed_arr / seed_val, maybe_offset_arr / offset_val
//                   - sampling options; the seed/offset tensors are validated by
//                     validate_seed_offset_tensors(...)
void top_p_sampling_from_probs(TensorView probs, TensorView output, TensorView valid,
Optional<TensorView> maybe_indices,
Optional<TensorView> maybe_top_p_arr, double top_p_val,
bool deterministic, Optional<TensorView> maybe_seed_arr,
uint64_t seed_val, Optional<TensorView> maybe_offset_arr,
uint64_t offset_val);

void top_k_sampling_from_probs(TensorView probs, TensorView output,
// Forward declaration of top_k_sampling_from_probs (signature that includes `valid`).
// What the definition in csrc/sampling.cu shows for these arguments:
//   probs   - checked as 2-D (batch_size, vocab_size); data passed as float*
//   output  - checked as 1-D (batch_size) and on the same device as probs;
//             data passed as IdType*
//   valid   - checked with CHECK_INPUT, CHECK_DIM(1, ...) and CHECK_DEVICE(valid, probs);
//             data passed as bool*
//             NOTE(review): no dtype check on `valid` is visible in the shown hunks - confirm
//   maybe_indices   - optional; restricted to int32/int64; null pointer when absent
//   maybe_top_k_arr - optional; nullptr when absent
//             NOTE(review): the definition casts this tensor's data to float* although
//             top_k_val is int64_t - confirm the dtype callers are expected to pass
//   top_k_val       - scalar top-k
//                     (presumably used when no array is supplied; TODO confirm)
//   deterministic, maybe_seed_arr / seed_val, maybe_offset_arr / offset_val
//                   - sampling options; the seed/offset tensors are validated by
//                     validate_seed_offset_tensors(...)
void top_k_sampling_from_probs(TensorView probs, TensorView output, TensorView valid,
Optional<TensorView> maybe_indices,
Optional<TensorView> maybe_top_k_arr, int64_t top_k_val,
bool deterministic, Optional<TensorView> maybe_seed_arr,
uint64_t seed_val, Optional<TensorView> maybe_offset_arr,
uint64_t offset_val);

void min_p_sampling_from_probs(TensorView probs, TensorView output,
// Forward declaration of min_p_sampling_from_probs (signature that includes `valid`).
// What the definition in csrc/sampling.cu shows for these arguments:
//   probs   - checked as 2-D (batch_size, vocab_size); data passed as float*
//   output  - checked as 1-D (batch_size) and on the same device as probs;
//             data passed as IdType*
//   valid   - checked with CHECK_INPUT, CHECK_DIM(1, ...) and CHECK_DEVICE(valid, probs);
//             data passed as bool*
//             NOTE(review): no dtype check on `valid` is visible in the shown hunks - confirm
//   maybe_indices   - optional; restricted to int32/int64; null pointer when absent
//   maybe_min_p_arr - optional; data passed as float*, or nullptr when absent
//   min_p_val       - scalar min-p; forwarded to the sampler alongside the array pointer
//   deterministic, maybe_seed_arr / seed_val, maybe_offset_arr / offset_val
//                   - sampling options; the seed/offset tensors are validated by
//                     validate_seed_offset_tensors(...)
void min_p_sampling_from_probs(TensorView probs, TensorView output, TensorView valid,
Optional<TensorView> maybe_indices,
Optional<TensorView> maybe_min_p_arr, double min_p_val,
bool deterministic, Optional<TensorView> maybe_seed_arr,
uint64_t seed_val, Optional<TensorView> maybe_offset_arr,
uint64_t offset_val);

void top_k_top_p_sampling_from_probs(TensorView probs, TensorView output,
void top_k_top_p_sampling_from_probs(TensorView probs, TensorView output, TensorView valid,
Optional<TensorView> maybe_indices,
Optional<TensorView> maybe_top_k_arr, double top_k_val,
Optional<TensorView> maybe_top_p_arr, double top_p_val,
Expand Down
35 changes: 27 additions & 8 deletions csrc/sampling.cu
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,15 @@ void sampling_from_logits(TensorView logits, TensorView output, Optional<TensorV
});
}

void sampling_from_probs(TensorView probs, TensorView output, Optional<TensorView> maybe_indices,
bool deterministic, Optional<TensorView> maybe_seed_arr, uint64_t seed_val,
void sampling_from_probs(TensorView probs, TensorView output, TensorView valid,
Optional<TensorView> maybe_indices, bool deterministic,
Optional<TensorView> maybe_seed_arr, uint64_t seed_val,
Optional<TensorView> maybe_offset_arr, uint64_t offset_val) {
CHECK_INPUT(probs);
CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
CHECK_INPUT(valid);
CHECK_DIM(1, valid);
CHECK_DEVICE(valid, probs);
CHECK_MAYBE_INPUT_TYPES(maybe_indices, dl_int32, dl_int64);
CHECK_MAYBE_SAME_DTYPE(maybe_indices, output);
validate_seed_offset_tensors(maybe_seed_arr, maybe_offset_arr, probs);
Expand All @@ -119,6 +123,7 @@ void sampling_from_probs(TensorView probs, TensorView output, Optional<TensorVie
DISPATCH_DLPACK_IDTYPE_TO_CTYPE(output.dtype(), IdType, [&] {
cudaError_t status = sampling::SamplingFromProb<float, IdType>(
static_cast<float*>(probs.data_ptr()), static_cast<IdType*>(output.data_ptr()),
static_cast<bool*>(valid.data_ptr()),
maybe_indices.has_value() ? static_cast<IdType*>(maybe_indices.value().data_ptr())
: nullptr,
batch_size, vocab_size, deterministic,
Expand All @@ -134,14 +139,17 @@ void sampling_from_probs(TensorView probs, TensorView output, Optional<TensorVie
});
}

void top_p_sampling_from_probs(TensorView probs, TensorView output,
void top_p_sampling_from_probs(TensorView probs, TensorView output, TensorView valid,
Optional<TensorView> maybe_indices,
Optional<TensorView> maybe_top_p_arr, double top_p_val,
bool deterministic, Optional<TensorView> maybe_seed_arr,
uint64_t seed_val, Optional<TensorView> maybe_offset_arr,
uint64_t offset_val) {
CHECK_INPUT(probs);
CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
CHECK_INPUT(valid);
CHECK_DIM(1, valid);
CHECK_DEVICE(valid, probs);
CHECK_MAYBE_INPUT_TYPES(maybe_indices, dl_int32, dl_int64);
CHECK_MAYBE_SAME_DTYPE(maybe_indices, output);
validate_seed_offset_tensors(maybe_seed_arr, maybe_offset_arr, probs);
Expand All @@ -157,6 +165,7 @@ void top_p_sampling_from_probs(TensorView probs, TensorView output,
DISPATCH_DLPACK_IDTYPE_TO_CTYPE(output.dtype(), IdType, [&] {
cudaError_t status = sampling::TopPSamplingFromProb<float, IdType>(
static_cast<float*>(probs.data_ptr()), static_cast<IdType*>(output.data_ptr()),
static_cast<bool*>(valid.data_ptr()),
maybe_indices.has_value() ? static_cast<IdType*>(maybe_indices.value().data_ptr())
: nullptr,
has_top_p_arr ? static_cast<float*>(maybe_top_p_arr.value().data_ptr()) : nullptr,
Expand All @@ -173,7 +182,7 @@ void top_p_sampling_from_probs(TensorView probs, TensorView output,
});
}

void top_k_sampling_from_probs(TensorView probs, TensorView output,
void top_k_sampling_from_probs(TensorView probs, TensorView output, TensorView valid,
Optional<TensorView> maybe_indices,
Optional<TensorView> maybe_top_k_arr, int64_t top_k_val,
bool deterministic, Optional<TensorView> maybe_seed_arr,
Expand All @@ -184,6 +193,9 @@ void top_k_sampling_from_probs(TensorView probs, TensorView output,
CHECK_DEVICE(output, probs);
CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
CHECK_DIM(1, output); // output: (batch_size)
CHECK_INPUT(valid);
CHECK_DIM(1, valid);
CHECK_DEVICE(valid, probs);
CHECK_MAYBE_INPUT_TYPES(maybe_indices, dl_int32, dl_int64);
CHECK_MAYBE_SAME_DTYPE(maybe_indices, output);
validate_seed_offset_tensors(maybe_seed_arr, maybe_offset_arr, probs);
Expand All @@ -199,6 +211,7 @@ void top_k_sampling_from_probs(TensorView probs, TensorView output,
DISPATCH_DLPACK_IDTYPE_TO_CTYPE(output.dtype(), IdType, [&] {
cudaError_t status = sampling::TopKSamplingFromProb<float, IdType>(
static_cast<float*>(probs.data_ptr()), static_cast<IdType*>(output.data_ptr()),
static_cast<bool*>(valid.data_ptr()),
maybe_indices.has_value() ? static_cast<IdType*>(maybe_indices.value().data_ptr())
: nullptr,
has_top_k_arr ? static_cast<float*>(maybe_top_k_arr.value().data_ptr()) : nullptr,
Expand All @@ -215,7 +228,7 @@ void top_k_sampling_from_probs(TensorView probs, TensorView output,
});
}

void min_p_sampling_from_probs(TensorView probs, TensorView output,
void min_p_sampling_from_probs(TensorView probs, TensorView output, TensorView valid,
Optional<TensorView> maybe_indices,
Optional<TensorView> maybe_min_p_arr, double min_p_val,
bool deterministic, Optional<TensorView> maybe_seed_arr,
Expand All @@ -226,6 +239,9 @@ void min_p_sampling_from_probs(TensorView probs, TensorView output,
CHECK_DEVICE(output, probs);
CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
CHECK_DIM(1, output); // output: (batch_size)
CHECK_INPUT(valid);
CHECK_DIM(1, valid);
CHECK_DEVICE(valid, probs);
CHECK_MAYBE_INPUT_TYPES(maybe_indices, dl_int32, dl_int64);
CHECK_MAYBE_SAME_DTYPE(maybe_indices, output);
validate_seed_offset_tensors(maybe_seed_arr, maybe_offset_arr, probs);
Expand All @@ -242,7 +258,7 @@ void min_p_sampling_from_probs(TensorView probs, TensorView output,
cudaError_t status = sampling::MinPSamplingFromProb<float, IdType>(
static_cast<float*>(probs.data_ptr()),
has_min_p_arr ? static_cast<float*>(maybe_min_p_arr.value().data_ptr()) : nullptr,
static_cast<IdType*>(output.data_ptr()),
static_cast<IdType*>(output.data_ptr()), static_cast<bool*>(valid.data_ptr()),
maybe_indices.has_value() ? static_cast<IdType*>(maybe_indices.value().data_ptr())
: nullptr,
batch_size, min_p_val, vocab_size, deterministic,
Expand All @@ -258,7 +274,7 @@ void min_p_sampling_from_probs(TensorView probs, TensorView output,
});
}

void top_k_top_p_sampling_from_probs(TensorView probs, TensorView output,
void top_k_top_p_sampling_from_probs(TensorView probs, TensorView output, TensorView valid,
Optional<TensorView> maybe_indices,
Optional<TensorView> maybe_top_k_arr, double top_k_val,
Optional<TensorView> maybe_top_p_arr, double top_p_val,
Expand All @@ -270,6 +286,9 @@ void top_k_top_p_sampling_from_probs(TensorView probs, TensorView output,
CHECK_DEVICE(output, probs);
CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
CHECK_DIM(1, output); // output: (batch_size)
CHECK_INPUT(valid);
CHECK_DIM(1, valid);
CHECK_DEVICE(valid, probs);
CHECK_MAYBE_INPUT_TYPES(maybe_indices, dl_int32, dl_int64);
CHECK_MAYBE_SAME_DTYPE(maybe_indices, output);
validate_seed_offset_tensors(maybe_seed_arr, maybe_offset_arr, probs);
Expand All @@ -289,7 +308,7 @@ void top_k_top_p_sampling_from_probs(TensorView probs, TensorView output,
static_cast<float*>(probs.data_ptr()),
has_top_k_arr ? static_cast<IdType*>(maybe_top_k_arr.value().data_ptr()) : nullptr,
has_top_p_arr ? static_cast<float*>(maybe_top_p_arr.value().data_ptr()) : nullptr,
static_cast<IdType*>(output.data_ptr()),
static_cast<IdType*>(output.data_ptr()), static_cast<bool*>(valid.data_ptr()),
maybe_indices.has_value() ? static_cast<IdType*>(maybe_indices.value().data_ptr())
: nullptr,
batch_size, top_k_val, top_p_val, vocab_size, deterministic,
Expand Down
Loading
Loading