Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 100 additions & 35 deletions csrc/kernels/topk_per_row_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -625,36 +625,73 @@ __device__ void last_filter(T const* in_buf,
IdxT* p_out_cnt = &counter->out_cnt;
IdxT* p_out_back_cnt = &counter->out_back_cnt;
IdxT* p_equal = out_idx + k - num_of_kth_needed;
for(IdxT i = threadIdx.x; i < current_len; i += blockDim.x)
{
const T value = in_buf[i];
auto const bits = (twiddle_in(value, select_min) >> start_bit) << start_bit;
if(bits < kth_value_bits)
if(in_idx_buf) {
for(IdxT i = threadIdx.x; i < current_len; i += blockDim.x)
{
IdxT pos = atomicAdd(p_out_cnt, static_cast<IdxT>(1));
if(WRITE_TOPK_VALUES)
const T value = in_buf[i];
auto const bits = (twiddle_in(value, select_min) >> start_bit) << start_bit;
if(bits < kth_value_bits)
{
out[pos] = value;
IdxT pos = atomicAdd(p_out_cnt, static_cast<IdxT>(1));
if(WRITE_TOPK_VALUES)
{
out[pos] = value;
}
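// `in_idx_buf` is non-null in this branch, so the original index is read from it.
// (For one-block version, `in_idx_buf` can be nullptr at pass 0; for non
// one-block version it can be nullptr when writing has been skipped and
// `in_buf` is `in`. That case takes the else branch and uses `i` directly.)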
out_idx[pos] = in_idx_buf[i];
}
else if(bits == kth_value_bits)
{
IdxT new_idx = in_idx_buf[i];
IdxT back_pos = atomicAdd(p_out_back_cnt, static_cast<IdxT>(1));
if(back_pos < num_of_kth_needed)
{
IdxT pos = k - 1 - back_pos;
if(WRITE_TOPK_VALUES)
{
out[pos] = value;
}
if constexpr(!prioritize_smaller_indice)
{
out_idx[pos] = new_idx;
}
}
}
// For one-block version, `in_idx_buf` could be nullptr at pass 0.
// For non one-block version, if writing has been skipped, `in_idx_buf`
// could be nullptr if `in_buf` is `in`
out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i;
}
else if(bits == kth_value_bits)
}else {
for(IdxT i = threadIdx.x; i < current_len; i += blockDim.x)
{
IdxT new_idx = in_idx_buf ? in_idx_buf[i] : i;
IdxT back_pos = atomicAdd(p_out_back_cnt, static_cast<IdxT>(1));
if(back_pos < num_of_kth_needed)
const T value = in_buf[i];
auto const bits = (twiddle_in(value, select_min) >> start_bit) << start_bit;
if(bits < kth_value_bits)
{
IdxT pos = k - 1 - back_pos;
IdxT pos = atomicAdd(p_out_cnt, static_cast<IdxT>(1));
if(WRITE_TOPK_VALUES)
{
out[pos] = value;
}
if constexpr(!prioritize_smaller_indice)
// `in_idx_buf` is nullptr in this branch, so the position `i` itself is the index.
// This occurs at pass 0 of the one-block version, or in the non one-block
// version when writing has been skipped and `in_buf` is `in`.
out_idx[pos] = i;
}
else if(bits == kth_value_bits)
{
IdxT new_idx = i;
IdxT back_pos = atomicAdd(p_out_back_cnt, static_cast<IdxT>(1));
if(back_pos < num_of_kth_needed)
{
out_idx[pos] = new_idx;
IdxT pos = k - 1 - back_pos;
if(WRITE_TOPK_VALUES)
{
out[pos] = value;
}
if constexpr(!prioritize_smaller_indice)
{
out_idx[pos] = new_idx;
}
}
}
}
Expand Down Expand Up @@ -1155,29 +1192,57 @@ __device__ void filter_and_histogram_for_one_block(T const* in_buf,
auto const kth_value_bits = counter->kth_value_bits;
int const previous_start_bit = calc_start_bit<T, BitsPerPass>(pass - 1);

for(IdxT i = threadIdx.x; i < previous_len; i += blockDim.x)
{
const T value = in_buf[i];
auto const previous_bits = (twiddle_in(value, select_min) >> previous_start_bit)
<< previous_start_bit;
if(previous_bits == kth_value_bits)
if(in_idx_buf) {
for(IdxT i = threadIdx.x; i < previous_len; i += blockDim.x)
{
const T value = in_buf[i];
auto const previous_bits = (twiddle_in(value, select_min) >> previous_start_bit)
<< previous_start_bit;
if(previous_bits == kth_value_bits)
{

IdxT pos = atomicAdd(p_filter_cnt, static_cast<IdxT>(1));
out_buf[pos] = value;
out_idx_buf[pos] = in_idx_buf ? in_idx_buf[i] : i;
IdxT pos = atomicAdd(p_filter_cnt, static_cast<IdxT>(1));
out_buf[pos] = value;
out_idx_buf[pos] = in_idx_buf[i];

int bucket = calc_bucket<T, BitsPerPass>(value, start_bit, mask, select_min);
atomicAdd(histogram + bucket, static_cast<IdxT>(1));
int bucket = calc_bucket<T, BitsPerPass>(value, start_bit, mask, select_min);
atomicAdd(histogram + bucket, static_cast<IdxT>(1));
}
else if(previous_bits < kth_value_bits)
{
IdxT pos = atomicAdd(p_out_cnt, static_cast<IdxT>(1));
if(WRITE_TOPK_VALUES)
{
out[pos] = value;
}
out_idx[pos] = in_idx_buf[i];
}
}
else if(previous_bits < kth_value_bits)
} else {
for(IdxT i = threadIdx.x; i < previous_len; i += blockDim.x)
{
IdxT pos = atomicAdd(p_out_cnt, static_cast<IdxT>(1));
if(WRITE_TOPK_VALUES)
const T value = in_buf[i];
auto const previous_bits = (twiddle_in(value, select_min) >> previous_start_bit)
<< previous_start_bit;
if(previous_bits == kth_value_bits)
{
out[pos] = value;

IdxT pos = atomicAdd(p_filter_cnt, static_cast<IdxT>(1));
out_buf[pos] = value;
out_idx_buf[pos] = i;

int bucket = calc_bucket<T, BitsPerPass>(value, start_bit, mask, select_min);
atomicAdd(histogram + bucket, static_cast<IdxT>(1));
}
else if(previous_bits < kth_value_bits)
{
IdxT pos = atomicAdd(p_out_cnt, static_cast<IdxT>(1));
if(WRITE_TOPK_VALUES)
{
out[pos] = value;
}
out_idx[pos] = i;
}
out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i;
}
}
}
Expand Down
Loading