diff --git a/csrc/kernels/topk_per_row_kernels.cu b/csrc/kernels/topk_per_row_kernels.cu index ccef79d35d..b9210e8b2d 100644 --- a/csrc/kernels/topk_per_row_kernels.cu +++ b/csrc/kernels/topk_per_row_kernels.cu @@ -625,36 +625,73 @@ __device__ void last_filter(T const* in_buf, IdxT* p_out_cnt = &counter->out_cnt; IdxT* p_out_back_cnt = &counter->out_back_cnt; IdxT* p_equal = out_idx + k - num_of_kth_needed; - for(IdxT i = threadIdx.x; i < current_len; i += blockDim.x) - { - const T value = in_buf[i]; - auto const bits = (twiddle_in(value, select_min) >> start_bit) << start_bit; - if(bits < kth_value_bits) + if(in_idx_buf) { + for(IdxT i = threadIdx.x; i < current_len; i += blockDim.x) { - IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); - if(WRITE_TOPK_VALUES) + const T value = in_buf[i]; + auto const bits = (twiddle_in(value, select_min) >> start_bit) << start_bit; + if(bits < kth_value_bits) { - out[pos] = value; + IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); + if(WRITE_TOPK_VALUES) + { + out[pos] = value; + } + // For one-block version, `in_idx_buf` could be nullptr at pass 0. + // For non one-block version, if writing has been skipped, `in_idx_buf` + // could be nullptr if `in_buf` is `in` + out_idx[pos] = in_idx_buf[i]; + } + else if(bits == kth_value_bits) + { + IdxT new_idx = in_idx_buf[i]; + IdxT back_pos = atomicAdd(p_out_back_cnt, static_cast(1)); + if(back_pos < num_of_kth_needed) + { + IdxT pos = k - 1 - back_pos; + if(WRITE_TOPK_VALUES) + { + out[pos] = value; + } + if constexpr(!prioritize_smaller_indice) + { + out_idx[pos] = new_idx; + } + } } - // For one-block version, `in_idx_buf` could be nullptr at pass 0. - // For non one-block version, if writing has been skipped, `in_idx_buf` - // could be nullptr if `in_buf` is `in` - out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; } - else if(bits == kth_value_bits) + }else { + for(IdxT i = threadIdx.x; i < current_len; i += blockDim.x) { - IdxT new_idx = in_idx_buf ? in_idx_buf[i] : i; - IdxT back_pos = atomicAdd(p_out_back_cnt, static_cast(1)); - if(back_pos < num_of_kth_needed) + const T value = in_buf[i]; + auto const bits = (twiddle_in(value, select_min) >> start_bit) << start_bit; + if(bits < kth_value_bits) { - IdxT pos = k - 1 - back_pos; + IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); if(WRITE_TOPK_VALUES) { out[pos] = value; } - if constexpr(!prioritize_smaller_indice) + // For one-block version, `in_idx_buf` could be nullptr at pass 0. + // For non one-block version, if writing has been skipped, `in_idx_buf` + // could be nullptr if `in_buf` is `in` + out_idx[pos] = i; + } + else if(bits == kth_value_bits) + { + IdxT new_idx = i; + IdxT back_pos = atomicAdd(p_out_back_cnt, static_cast(1)); + if(back_pos < num_of_kth_needed) { - out_idx[pos] = new_idx; + IdxT pos = k - 1 - back_pos; + if(WRITE_TOPK_VALUES) + { + out[pos] = value; + } + if constexpr(!prioritize_smaller_indice) + { + out_idx[pos] = new_idx; + } } } } @@ -1155,29 +1192,57 @@ __device__ void filter_and_histogram_for_one_block(T const* in_buf, auto const kth_value_bits = counter->kth_value_bits; int const previous_start_bit = calc_start_bit(pass - 1); - for(IdxT i = threadIdx.x; i < previous_len; i += blockDim.x) - { - const T value = in_buf[i]; - auto const previous_bits = (twiddle_in(value, select_min) >> previous_start_bit) - << previous_start_bit; - if(previous_bits == kth_value_bits) + if(in_idx_buf) { + for(IdxT i = threadIdx.x; i < previous_len; i += blockDim.x) { + const T value = in_buf[i]; + auto const previous_bits = (twiddle_in(value, select_min) >> previous_start_bit) + << previous_start_bit; + if(previous_bits == kth_value_bits) + { - IdxT pos = atomicAdd(p_filter_cnt, static_cast(1)); - out_buf[pos] = value; - out_idx_buf[pos] = in_idx_buf ? in_idx_buf[i] : i; + IdxT pos = atomicAdd(p_filter_cnt, static_cast(1)); + out_buf[pos] = value; + out_idx_buf[pos] = in_idx_buf[i]; - int bucket = calc_bucket(value, start_bit, mask, select_min); - atomicAdd(histogram + bucket, static_cast(1)); + int bucket = calc_bucket(value, start_bit, mask, select_min); + atomicAdd(histogram + bucket, static_cast(1)); + } + else if(previous_bits < kth_value_bits) + { + IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); + if(WRITE_TOPK_VALUES) + { + out[pos] = value; + } + out_idx[pos] = in_idx_buf[i]; + } } - else if(previous_bits < kth_value_bits) + } else { + for(IdxT i = threadIdx.x; i < previous_len; i += blockDim.x) { - IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); - if(WRITE_TOPK_VALUES) + const T value = in_buf[i]; + auto const previous_bits = (twiddle_in(value, select_min) >> previous_start_bit) + << previous_start_bit; + if(previous_bits == kth_value_bits) { - out[pos] = value; + + IdxT pos = atomicAdd(p_filter_cnt, static_cast(1)); + out_buf[pos] = value; + out_idx_buf[pos] = i; + + int bucket = calc_bucket(value, start_bit, mask, select_min); + atomicAdd(histogram + bucket, static_cast(1)); + } + else if(previous_bits < kth_value_bits) + { + IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); + if(WRITE_TOPK_VALUES) + { + out[pos] = value; + } + out_idx[pos] = i; } - out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; } } }