Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 42 additions & 9 deletions csrc/sampling.cu
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ void sampling_from_logits(TensorView logits, TensorView output, Optional<TensorV
}

void sampling_from_probs(TensorView probs, TensorView output, Optional<TensorView> maybe_indices,
bool deterministic, uint64_t philox_seed, uint64_t philox_offset) {
bool deterministic, uint64_t philox_seed, uint64_t philox_offset,
Optional<TensorView> maybe_seed_arr,
Optional<TensorView> maybe_offset_arr) {
CHECK_INPUT(probs);
CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
CHECK_MAYBE_INPUT_TYPES(maybe_indices, dl_int32, dl_int64);
Expand All @@ -86,7 +88,12 @@ void sampling_from_probs(TensorView probs, TensorView output, Optional<TensorVie
static_cast<float*>(probs.data_ptr()), static_cast<IdType*>(output.data_ptr()),
maybe_indices.has_value() ? static_cast<IdType*>(maybe_indices.value().data_ptr())
: nullptr,
batch_size, vocab_size, deterministic, philox_seed, philox_offset, stream);
batch_size, vocab_size, deterministic, philox_seed, philox_offset,
maybe_seed_arr.has_value() ? static_cast<uint64_t*>(maybe_seed_arr.value().data_ptr())
: nullptr,
maybe_offset_arr.has_value() ? static_cast<uint64_t*>(maybe_offset_arr.value().data_ptr())
: nullptr,
stream);
TVM_FFI_ICHECK(status == cudaSuccess)
<< "SamplingFromProbs failed with error code " << cudaGetErrorString(status);
return true;
Expand All @@ -96,7 +103,9 @@ void sampling_from_probs(TensorView probs, TensorView output, Optional<TensorVie
void top_p_sampling_from_probs(TensorView probs, TensorView output,
Optional<TensorView> maybe_indices,
Optional<TensorView> maybe_top_p_arr, double top_p_val,
bool deterministic, uint64_t philox_seed, uint64_t philox_offset) {
bool deterministic, uint64_t philox_seed, uint64_t philox_offset,
Optional<TensorView> maybe_seed_arr,
Optional<TensorView> maybe_offset_arr) {
CHECK_INPUT(probs);
CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
CHECK_MAYBE_INPUT_TYPES(maybe_indices, dl_int32, dl_int64);
Expand All @@ -115,7 +124,12 @@ void top_p_sampling_from_probs(TensorView probs, TensorView output,
maybe_indices.has_value() ? static_cast<IdType*>(maybe_indices.value().data_ptr())
: nullptr,
has_top_p_arr ? static_cast<float*>(maybe_top_p_arr.value().data_ptr()) : nullptr,
batch_size, top_p_val, vocab_size, deterministic, philox_seed, philox_offset, stream);
batch_size, top_p_val, vocab_size, deterministic, philox_seed, philox_offset,
maybe_seed_arr.has_value() ? static_cast<uint64_t*>(maybe_seed_arr.value().data_ptr())
: nullptr,
maybe_offset_arr.has_value() ? static_cast<uint64_t*>(maybe_offset_arr.value().data_ptr())
: nullptr,
stream);
TVM_FFI_ICHECK(status == cudaSuccess)
<< "TopPSamplingFromProbs failed with error code " << cudaGetErrorString(status);
return true;
Expand All @@ -125,7 +139,9 @@ void top_p_sampling_from_probs(TensorView probs, TensorView output,
void top_k_sampling_from_probs(TensorView probs, TensorView output,
Optional<TensorView> maybe_indices,
Optional<TensorView> maybe_top_k_arr, int64_t top_k_val,
bool deterministic, uint64_t philox_seed, uint64_t philox_offset) {
bool deterministic, uint64_t philox_seed, uint64_t philox_offset,
Optional<TensorView> maybe_seed_arr,
Optional<TensorView> maybe_offset_arr) {
CHECK_INPUT(probs);
CHECK_INPUT(output);
CHECK_DEVICE(output, probs);
Expand All @@ -147,7 +163,12 @@ void top_k_sampling_from_probs(TensorView probs, TensorView output,
maybe_indices.has_value() ? static_cast<IdType*>(maybe_indices.value().data_ptr())
: nullptr,
has_top_k_arr ? static_cast<float*>(maybe_top_k_arr.value().data_ptr()) : nullptr,
batch_size, top_k_val, vocab_size, deterministic, philox_seed, philox_offset, stream);
batch_size, top_k_val, vocab_size, deterministic, philox_seed, philox_offset,
maybe_seed_arr.has_value() ? static_cast<uint64_t*>(maybe_seed_arr.value().data_ptr())
: nullptr,
maybe_offset_arr.has_value() ? static_cast<uint64_t*>(maybe_offset_arr.value().data_ptr())
: nullptr,
stream);
TVM_FFI_ICHECK(status == cudaSuccess)
<< "TopKSamplingFromProbs failed with error code " << cudaGetErrorString(status);
return true;
Expand All @@ -157,7 +178,9 @@ void top_k_sampling_from_probs(TensorView probs, TensorView output,
void min_p_sampling_from_probs(TensorView probs, TensorView output,
Optional<TensorView> maybe_indices,
Optional<TensorView> maybe_min_p_arr, double min_p_val,
bool deterministic, uint64_t philox_seed, uint64_t philox_offset) {
bool deterministic, uint64_t philox_seed, uint64_t philox_offset,
Optional<TensorView> maybe_seed_arr,
Optional<TensorView> maybe_offset_arr) {
CHECK_INPUT(probs);
CHECK_INPUT(output);
CHECK_DEVICE(output, probs);
Expand All @@ -180,7 +203,12 @@ void min_p_sampling_from_probs(TensorView probs, TensorView output,
static_cast<IdType*>(output.data_ptr()),
maybe_indices.has_value() ? static_cast<IdType*>(maybe_indices.value().data_ptr())
: nullptr,
batch_size, min_p_val, vocab_size, deterministic, philox_seed, philox_offset, stream);
batch_size, min_p_val, vocab_size, deterministic, philox_seed, philox_offset,
maybe_seed_arr.has_value() ? static_cast<uint64_t*>(maybe_seed_arr.value().data_ptr())
: nullptr,
maybe_offset_arr.has_value() ? static_cast<uint64_t*>(maybe_offset_arr.value().data_ptr())
: nullptr,
stream);
TVM_FFI_ICHECK(status == cudaSuccess)
<< "MinPSamplingFromProb failed with error code " << cudaGetErrorString(status);
return true;
Expand All @@ -192,7 +220,8 @@ void top_k_top_p_sampling_from_probs(TensorView probs, TensorView output,
Optional<TensorView> maybe_top_k_arr, double top_k_val,
Optional<TensorView> maybe_top_p_arr, double top_p_val,
bool deterministic, uint64_t philox_seed,
uint64_t philox_offset) {
uint64_t philox_offset, Optional<TensorView> maybe_seed_arr,
Optional<TensorView> maybe_offset_arr) {
CHECK_INPUT(probs);
CHECK_INPUT(output);
CHECK_DEVICE(output, probs);
Expand All @@ -219,6 +248,10 @@ void top_k_top_p_sampling_from_probs(TensorView probs, TensorView output,
maybe_indices.has_value() ? static_cast<IdType*>(maybe_indices.value().data_ptr())
: nullptr,
batch_size, top_k_val, top_p_val, vocab_size, deterministic, philox_seed, philox_offset,
maybe_seed_arr.has_value() ? static_cast<uint64_t*>(maybe_seed_arr.value().data_ptr())
: nullptr,
maybe_offset_arr.has_value() ? static_cast<uint64_t*>(maybe_offset_arr.value().data_ptr())
: nullptr,
stream);
TVM_FFI_ICHECK(status == cudaSuccess)
<< "TopKTopPSamplingFromProbs failed with error code " << cudaGetErrorString(status);
Expand Down
Loading
Loading