Merge branch 'master' into dev_0830
junpeng0715 authored Sep 5, 2022
2 parents 044b4a9 + 649ef60 commit e1b07b8
Showing 33 changed files with 1,880 additions and 208 deletions.
2 changes: 1 addition & 1 deletion R-package/R/lgb.Booster.R
@@ -1319,7 +1319,7 @@ lgb.save <- function(booster, filename, num_iteration = NULL) {
lgb.dump <- function(booster, num_iteration = NULL) {

  if (!lgb.is.Booster(x = booster)) {
-    stop("lgb.save: booster should be an ", sQuote("lgb.Booster"))
+    stop("lgb.dump: booster should be an ", sQuote("lgb.Booster"))
  }

# Return booster at requested iteration
8 changes: 4 additions & 4 deletions R-package/R/lgb.Dataset.R
@@ -161,7 +161,7 @@ Dataset <- R6::R6Class(
      # Provided indices, but some indices are missing?
      if (sum(is.na(cate_indices)) > 0L) {
        stop(
-          "lgb.self.get.handle: supplied an unknown feature in categorical_feature: "
+          "lgb.Dataset.construct: supplied an unknown feature in categorical_feature: "
          , sQuote(private$categorical_feature[is.na(cate_indices)])
        )
      }
@@ -172,7 +172,7 @@
      data_is_not_filename <- !is.character(private$raw_data)
      if (data_is_not_filename && max(private$categorical_feature) > ncol(private$raw_data)) {
        stop(
-          "lgb.self.get.handle: supplied a too large value in categorical_feature: "
+          "lgb.Dataset.construct: supplied a too large value in categorical_feature: "
          , max(private$categorical_feature)
          , " but only "
          , ncol(private$raw_data)
@@ -1250,11 +1250,11 @@ lgb.Dataset.set.reference <- function(dataset, reference) {
lgb.Dataset.save <- function(dataset, fname) {

  if (!lgb.is.Dataset(x = dataset)) {
-    stop("lgb.Dataset.set: input dataset should be an lgb.Dataset object")
+    stop("lgb.Dataset.save: input dataset should be an lgb.Dataset object")
  }

  if (!is.character(fname)) {
-    stop("lgb.Dataset.set: fname should be a character or a file connection")
+    stop("lgb.Dataset.save: fname should be a character or a file connection")
  }

return(invisible(dataset$save_binary(fname = fname)))
2 changes: 1 addition & 1 deletion docs/Features.rst
@@ -71,7 +71,7 @@ Optimization in Network Communication
-------------------------------------

Distributed learning in LightGBM needs only a few collective communication algorithms, such as "All reduce", "All gather", and "Reduce scatter".
-LightGBM implements state-of-art algorithms\ `[9] <#references>`__.
+LightGBM implements state-of-the-art algorithms\ `[9] <#references>`__.
These collective communication algorithms can provide much better performance than point-to-point communication.
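For intuition, an efficient all-reduce is commonly decomposed into a reduce-scatter followed by an all-gather, which keeps per-link traffic near-optimal compared with naive point-to-point exchange. Below is a single-process C++ sketch (an editor's illustration of the general technique, not LightGBM's network code) that computes the end state of that decomposition on toy data:

#include <iostream>
#include <vector>

int main() {
  const int k = 4;          // number of machines
  const int n = 8;          // gradient-vector length, divisible by k
  const int chunk = n / k;

  // inputs[m] is machine m's local gradient vector (toy data).
  std::vector<std::vector<double>> inputs(k, std::vector<double>(n));
  for (int m = 0; m < k; ++m) {
    for (int i = 0; i < n; ++i) inputs[m][i] = m + 1.0;
  }

  // Reduce-scatter end state: machine m owns the fully reduced chunk m.
  std::vector<std::vector<double>> owned(k, std::vector<double>(chunk, 0.0));
  for (int m = 0; m < k; ++m) {
    for (int i = 0; i < chunk; ++i) {
      for (int src = 0; src < k; ++src) {
        owned[m][i] += inputs[src][m * chunk + i];
      }
    }
  }

  // All-gather end state: every machine holds the concatenation of the
  // reduced chunks, i.e. the element-wise sum of all inputs.
  std::vector<double> result(n);
  for (int m = 0; m < k; ++m) {
    for (int i = 0; i < chunk; ++i) result[m * chunk + i] = owned[m][i];
  }

  for (double v : result) std::cout << v << ' ';  // prints 10 eight times
  std::cout << '\n';
  return 0;
}

In a real ring implementation each chunk travels the ring O(k) times in total, versus the O(k^2) messages of all-pairs point-to-point exchange.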

.. _Optimization in Parallel Learning:
169 changes: 163 additions & 6 deletions include/LightGBM/cuda/cuda_algorithms.hpp
@@ -18,17 +18,11 @@

#include <algorithm>

-#define NUM_BANKS_DATA_PARTITION (16)
-#define LOG_NUM_BANKS_DATA_PARTITION (4)
#define GLOBAL_PREFIX_SUM_BLOCK_SIZE (1024)

#define BITONIC_SORT_NUM_ELEMENTS (1024)
#define BITONIC_SORT_DEPTH (11)
#define BITONIC_SORT_QUERY_ITEM_BLOCK_SIZE (10)

-#define CONFLICT_FREE_INDEX(n) \
-  ((n) + ((n) >> LOG_NUM_BANKS_DATA_PARTITION)) \

namespace LightGBM {

template <typename T>
@@ -107,6 +101,9 @@ __device__ __forceinline__ T ShufflePrefixSumExclusive(T value, T* shared_mem_bu
template <typename T>
void ShufflePrefixSumGlobal(T* values, size_t len, T* block_prefix_sum_buffer);

template <typename VAL_T, typename REDUCE_T, typename INDEX_T>
void GlobalInclusiveArgPrefixSum(const INDEX_T* sorted_indices, const VAL_T* in_values, REDUCE_T* out_values, REDUCE_T* block_buffer, size_t n);

template <typename T>
__device__ __forceinline__ T ShuffleReduceSumWarp(T value, const data_size_t len) {
if (len > 0) {
@@ -220,6 +217,54 @@ __device__ __forceinline__ void BitonicArgSort_1024(const VAL_T* scores, INDEX_T
}
}

template <typename VAL_T, typename INDEX_T, bool ASCENDING>
__device__ __forceinline__ void BitonicArgSort_2048(const VAL_T* scores, INDEX_T* indices) {
for (INDEX_T base = 0; base < 2048; base += 1024) {
for (INDEX_T outer_depth = 10; outer_depth >= 1; --outer_depth) {
const INDEX_T outer_segment_length = 1 << (11 - outer_depth);
const INDEX_T outer_segment_index = threadIdx.x / outer_segment_length;
const bool ascending = ((base == 0) ^ ASCENDING) ? (outer_segment_index % 2 > 0) : (outer_segment_index % 2 == 0);
for (INDEX_T inner_depth = outer_depth; inner_depth < 11; ++inner_depth) {
const INDEX_T segment_length = 1 << (11 - inner_depth);
const INDEX_T half_segment_length = segment_length >> 1;
const INDEX_T half_segment_index = threadIdx.x / half_segment_length;
if (half_segment_index % 2 == 0) {
const INDEX_T index_to_compare = threadIdx.x + half_segment_length + base;
if ((scores[indices[threadIdx.x + base]] > scores[indices[index_to_compare]]) == ascending) {
const INDEX_T index = indices[threadIdx.x + base];
indices[threadIdx.x + base] = indices[index_to_compare];
indices[index_to_compare] = index;
}
}
__syncthreads();
}
}
}
const unsigned int index_to_compare = threadIdx.x + 1024;
if (scores[indices[index_to_compare]] > scores[indices[threadIdx.x]]) {
const INDEX_T temp_index = indices[index_to_compare];
indices[index_to_compare] = indices[threadIdx.x];
indices[threadIdx.x] = temp_index;
}
__syncthreads();
for (INDEX_T base = 0; base < 2048; base += 1024) {
for (INDEX_T inner_depth = 1; inner_depth < 11; ++inner_depth) {
const INDEX_T segment_length = 1 << (11 - inner_depth);
const INDEX_T half_segment_length = segment_length >> 1;
const INDEX_T half_segment_index = threadIdx.x / half_segment_length;
if (half_segment_index % 2 == 0) {
const INDEX_T index_to_compare = threadIdx.x + half_segment_length + base;
if (scores[indices[threadIdx.x + base]] < scores[indices[index_to_compare]]) {
const INDEX_T index = indices[threadIdx.x + base];
indices[threadIdx.x + base] = indices[index_to_compare];
indices[index_to_compare] = index;
}
}
__syncthreads();
}
}
}
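The 2048-slot variant above first sorts the two 1024-element halves with bitonic stages, then merges and cleans up with fixed comparisons that leave the larger scores in front. As a reference point, here is a host-side C++ sketch (an editor's illustration, not part of the commit) of the argsort postcondition such kernels establish: `indices` is permuted so the scores viewed through it are non-increasing.

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  std::vector<double> scores = {0.3, 0.9, 0.1, 0.7};
  std::vector<int> indices = {0, 1, 2, 3};
  // Same contract as the device bitonic argsort: sort indices by score,
  // descending, without moving the scores themselves.
  std::stable_sort(indices.begin(), indices.end(),
                   [&](int a, int b) { return scores[a] > scores[b]; });
  for (int i : indices) std::printf("%d ", i);  // prints: 1 3 0 2
  std::printf("\n");
  return 0;
}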

template <typename VAL_T, typename INDEX_T, bool ASCENDING, uint32_t BLOCK_DIM, uint32_t MAX_DEPTH>
__device__ void BitonicArgSortDevice(const VAL_T* values, INDEX_T* indices, const int len) {
__shared__ VAL_T shared_values[BLOCK_DIM];
@@ -384,6 +429,118 @@ __device__ void BitonicArgSortDevice(const VAL_T* values, INDEX_T* indices, cons
}
}

void BitonicArgSortItemsGlobal(
const double* scores,
const int num_queries,
const data_size_t* cuda_query_boundaries,
data_size_t* out_indices);

template <typename VAL_T, typename INDEX_T, bool ASCENDING>
void BitonicArgSortGlobal(const VAL_T* values, INDEX_T* indices, const size_t len);

template <typename VAL_T, typename REDUCE_T>
void ShuffleReduceSumGlobal(const VAL_T* values, size_t n, REDUCE_T* block_buffer);

template <typename VAL_T, typename REDUCE_T>
void ShuffleReduceDotProdGlobal(const VAL_T* values1, const VAL_T* values2, size_t n, REDUCE_T* block_buffer);

template <typename VAL_T, typename REDUCE_VAL_T, typename INDEX_T>
__device__ void ShuffleSortedPrefixSumDevice(const VAL_T* in_values,
const INDEX_T* sorted_indices,
REDUCE_VAL_T* out_values,
const INDEX_T num_data) {
__shared__ REDUCE_VAL_T shared_buffer[32];
const INDEX_T num_data_per_thread = (num_data + static_cast<INDEX_T>(blockDim.x) - 1) / static_cast<INDEX_T>(blockDim.x);
const INDEX_T start = num_data_per_thread * static_cast<INDEX_T>(threadIdx.x);
const INDEX_T end = min(start + num_data_per_thread, num_data);
REDUCE_VAL_T thread_sum = 0;
for (INDEX_T index = start; index < end; ++index) {
thread_sum += static_cast<REDUCE_VAL_T>(in_values[sorted_indices[index]]);
}
__syncthreads();
thread_sum = ShufflePrefixSumExclusive<REDUCE_VAL_T>(thread_sum, shared_buffer);
const REDUCE_VAL_T thread_base = shared_buffer[threadIdx.x];
for (INDEX_T index = start; index < end; ++index) {
out_values[index] = thread_base + static_cast<REDUCE_VAL_T>(in_values[sorted_indices[index]]);
}
__syncthreads();
}
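ShuffleSortedPrefixSumDevice and GlobalInclusiveArgPrefixSum both produce prefix sums over values visited through a sorted index array; the weighted-percentile path below consumes weights_prefix_sum[i] as the inclusive sum of the weights of the i+1 smallest elements. A host-side C++ sketch (illustrative only, not the kernel) of that contract:

#include <cstdio>
#include <vector>

int main() {
  std::vector<double> in = {2.0, 5.0, 1.0};
  std::vector<int> sorted_indices = {2, 0, 1};  // ascending by value: 1, 2, 5
  std::vector<double> out(in.size());
  double running = 0.0;
  // out[i] = sum of in[sorted_indices[j]] for all j <= i (inclusive scan).
  for (size_t i = 0; i < in.size(); ++i) {
    running += in[sorted_indices[i]];
    out[i] = running;
  }
  for (double v : out) std::printf("%g ", v);  // prints: 1 3 8
  std::printf("\n");
  return 0;
}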

template <typename VAL_T, typename INDEX_T, typename WEIGHT_T, typename WEIGHT_REDUCE_T, bool ASCENDING, bool USE_WEIGHT>
__global__ void PercentileGlobalKernel(const VAL_T* values,
const WEIGHT_T* weights,
const INDEX_T* sorted_indices,
const WEIGHT_REDUCE_T* weights_prefix_sum,
const double alpha,
const INDEX_T len,
VAL_T* out_value) {
if (!USE_WEIGHT) {
const double float_pos = (1.0f - alpha) * len;
const INDEX_T pos = static_cast<INDEX_T>(float_pos);
if (pos < 1) {
*out_value = values[sorted_indices[0]];
} else if (pos >= len) {
*out_value = values[sorted_indices[len - 1]];
} else {
const double bias = float_pos - static_cast<double>(pos);
const VAL_T v1 = values[sorted_indices[pos - 1]];
const VAL_T v2 = values[sorted_indices[pos]];
*out_value = static_cast<VAL_T>(v1 - (v1 - v2) * bias);
}
} else {
const WEIGHT_REDUCE_T threshold = weights_prefix_sum[len - 1] * (1.0f - alpha);
__shared__ INDEX_T pos;
if (threadIdx.x == 0) {
pos = len;
}
__syncthreads();
for (INDEX_T index = static_cast<INDEX_T>(threadIdx.x); index < len; index += static_cast<INDEX_T>(blockDim.x)) {
if (weights_prefix_sum[index] > threshold && (index == 0 || weights_prefix_sum[index - 1] <= threshold)) {
pos = index;
}
}
__syncthreads();
pos = min(pos, len - 1);
if (pos == 0 || pos == len - 1) {
*out_value = values[pos];
}
const VAL_T v1 = values[sorted_indices[pos - 1]];
const VAL_T v2 = values[sorted_indices[pos]];
*out_value = static_cast<VAL_T>(v1 - (v1 - v2) * (threshold - weights_prefix_sum[pos - 1]) / (weights_prefix_sum[pos] - weights_prefix_sum[pos - 1]));
}
}
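In the unweighted branch, the kernel turns alpha into a fractional position (1 - alpha) * len and linearly interpolates between the two neighboring sorted values via v1 - (v1 - v2) * bias. A host-side C++ sketch of that arithmetic on toy data (an editor's illustration of the unweighted path only, not the CUDA code):

#include <cstdio>
#include <vector>

double percentile(const std::vector<double>& values,
                  const std::vector<int>& sorted_indices, double alpha) {
  const int len = static_cast<int>(values.size());
  const double float_pos = (1.0 - alpha) * len;   // fractional rank
  const int pos = static_cast<int>(float_pos);
  if (pos < 1) return values[sorted_indices[0]];
  if (pos >= len) return values[sorted_indices[len - 1]];
  const double bias = float_pos - pos;
  const double v1 = values[sorted_indices[pos - 1]];
  const double v2 = values[sorted_indices[pos]];
  return v1 - (v1 - v2) * bias;                   // same interpolation as above
}

int main() {
  std::vector<double> values = {4.0, 1.0, 3.0, 2.0};
  std::vector<int> sorted_indices = {1, 3, 2, 0};  // ascending order
  // alpha = 0.5, len = 4 -> float_pos = 2.0, pos = 2, bias = 0 -> median-ish 2
  std::printf("%g\n", percentile(values, sorted_indices, 0.5));
  return 0;
}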

template <typename VAL_T, typename INDEX_T, typename WEIGHT_T, typename WEIGHT_REDUCE_T, bool ASCENDING, bool USE_WEIGHT>
void PercentileGlobal(const VAL_T* values,
const WEIGHT_T* weights,
INDEX_T* indices,
WEIGHT_REDUCE_T* weights_prefix_sum,
WEIGHT_REDUCE_T* weights_prefix_sum_buffer,
const double alpha,
const INDEX_T len,
VAL_T* cuda_out_value) {
if (len <= 1) {
CopyFromCUDADeviceToCUDADevice<VAL_T>(cuda_out_value, values, 1, __FILE__, __LINE__);
}
BitonicArgSortGlobal<VAL_T, INDEX_T, ASCENDING>(values, indices, len);
SynchronizeCUDADevice(__FILE__, __LINE__);
if (USE_WEIGHT) {
GlobalInclusiveArgPrefixSum<WEIGHT_T, WEIGHT_REDUCE_T, INDEX_T>(indices, weights, weights_prefix_sum, weights_prefix_sum_buffer, static_cast<size_t>(len));
}
SynchronizeCUDADevice(__FILE__, __LINE__);
PercentileGlobalKernel<VAL_T, INDEX_T, WEIGHT_T, WEIGHT_REDUCE_T, ASCENDING, USE_WEIGHT><<<1, GLOBAL_PREFIX_SUM_BLOCK_SIZE>>>(values, weights, indices, weights_prefix_sum, alpha, len, cuda_out_value);
SynchronizeCUDADevice(__FILE__, __LINE__);
}

template <typename VAL_T, typename INDEX_T, typename WEIGHT_T, typename REDUCE_WEIGHT_T, bool ASCENDING, bool USE_WEIGHT>
__device__ VAL_T PercentileDevice(const VAL_T* values,
const WEIGHT_T* weights,
INDEX_T* indices,
REDUCE_WEIGHT_T* weights_prefix_sum,
const double alpha,
const INDEX_T len);


} // namespace LightGBM

#endif // USE_CUDA_EXP
6 changes: 3 additions & 3 deletions include/LightGBM/cuda/cuda_column_data.hpp
@@ -5,8 +5,8 @@

#ifdef USE_CUDA_EXP

-#ifndef LIGHTGBM_CUDA_COLUMN_DATA_HPP_
-#define LIGHTGBM_CUDA_COLUMN_DATA_HPP_
+#ifndef LIGHTGBM_CUDA_CUDA_COLUMN_DATA_HPP_
+#define LIGHTGBM_CUDA_CUDA_COLUMN_DATA_HPP_

#include <LightGBM/config.h>
#include <LightGBM/cuda/cuda_utils.h>
@@ -135,6 +135,6 @@ class CUDAColumnData {

} // namespace LightGBM

-#endif // LIGHTGBM_CUDA_COLUMN_DATA_HPP_
+#endif // LIGHTGBM_CUDA_CUDA_COLUMN_DATA_HPP_

#endif // USE_CUDA_EXP
6 changes: 3 additions & 3 deletions include/LightGBM/cuda/cuda_metadata.hpp
@@ -5,8 +5,8 @@

#ifdef USE_CUDA_EXP

-#ifndef LIGHTGBM_CUDA_META_DATA_HPP_
-#define LIGHTGBM_CUDA_META_DATA_HPP_
+#ifndef LIGHTGBM_CUDA_CUDA_METADATA_HPP_
+#define LIGHTGBM_CUDA_CUDA_METADATA_HPP_

#include <LightGBM/cuda/cuda_utils.h>
#include <LightGBM/meta.h>
@@ -53,6 +53,6 @@ class CUDAMetadata {

} // namespace LightGBM

-#endif // LIGHTGBM_CUDA_META_DATA_HPP_
+#endif // LIGHTGBM_CUDA_CUDA_METADATA_HPP_

#endif // USE_CUDA_EXP
6 changes: 3 additions & 3 deletions include/LightGBM/cuda/cuda_objective_function.hpp
@@ -4,8 +4,8 @@
* license information.
*/

-#ifndef LIGHTGBM_OBJECTIVE_CUDA_CUDA_OBJECTIVE_HPP_
-#define LIGHTGBM_OBJECTIVE_CUDA_CUDA_OBJECTIVE_HPP_
+#ifndef LIGHTGBM_CUDA_CUDA_OBJECTIVE_FUNCTION_HPP_
+#define LIGHTGBM_CUDA_CUDA_OBJECTIVE_FUNCTION_HPP_

#ifdef USE_CUDA_EXP

@@ -24,4 +24,4 @@ class CUDAObjectiveInterface {

#endif // USE_CUDA_EXP

-#endif // LIGHTGBM_OBJECTIVE_CUDA_CUDA_OBJECTIVE_HPP_
+#endif // LIGHTGBM_CUDA_CUDA_OBJECTIVE_FUNCTION_HPP_
6 changes: 3 additions & 3 deletions include/LightGBM/cuda/cuda_row_data.hpp
@@ -5,8 +5,8 @@

#ifdef USE_CUDA_EXP

-#ifndef LIGHTGBM_CUDA_ROW_DATA_HPP_
-#define LIGHTGBM_CUDA_ROW_DATA_HPP_
+#ifndef LIGHTGBM_CUDA_CUDA_ROW_DATA_HPP_
+#define LIGHTGBM_CUDA_CUDA_ROW_DATA_HPP_

#include <LightGBM/bin.h>
#include <LightGBM/config.h>
@@ -174,6 +174,6 @@ class CUDARowData {
};

} // namespace LightGBM
-#endif // LIGHTGBM_CUDA_ROW_DATA_HPP_
+#endif // LIGHTGBM_CUDA_CUDA_ROW_DATA_HPP_

#endif // USE_CUDA_EXP
2 changes: 1 addition & 1 deletion include/LightGBM/cuda/cuda_utils.h
@@ -167,7 +167,7 @@ class CUDAVector {
return host_vector;
}

-  T* RawData() {
+  T* RawData() const {
return data_;
}

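Const-qualifying RawData() lets callers that hold only a const reference to the vector still obtain the underlying pointer. A minimal C++ sketch of why this matters (hypothetical Vec type, not the real CUDAVector):

#include <cstdio>

template <typename T>
class Vec {
 public:
  T* RawData() const { return data_; }  // const-qualified accessor
 private:
  T* data_ = nullptr;
};

void Consume(const Vec<float>& v) {
  // Without the const qualifier on RawData(), this call would not compile.
  float* ptr = v.RawData();
  std::printf("%p\n", static_cast<void*>(ptr));
}

int main() {
  Vec<float> v;
  Consume(v);
  return 0;
}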
4 changes: 2 additions & 2 deletions include/LightGBM/dataset.h
@@ -161,7 +161,7 @@ class Metadata {
* \param values Initial score values for this record, one per class
*/
inline void SetInitScoreAt(data_size_t idx, const double* values) {
-    const auto nclasses = num_classes();
+    const auto nclasses = num_init_score_classes();
const double* val_ptr = values;
for (int i = idx; i < nclasses * num_data_; i += num_data_, ++val_ptr) {
init_score_[i] = *val_ptr;
@@ -265,7 +265,7 @@
/*!
* \brief Get number of classes
*/
-  inline int32_t num_classes() const {
+  inline int32_t num_init_score_classes() const {
if (num_data_ && num_init_score_) {
return static_cast<int>(num_init_score_ / num_data_);
}
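The rename reflects that the class count here is derived from the init-score buffer: init_score_ holds nclasses * num_data_ entries laid out class-major, so SetInitScoreAt strides by num_data_. A toy C++ sketch of that layout and the derived class count (illustrative only, not the Metadata class):

#include <cstdio>
#include <vector>

int main() {
  const int num_data = 4, num_classes = 3;
  std::vector<double> init_score(num_data * num_classes, 0.0);

  // SetInitScoreAt(idx, values): one value per class for row idx, striding
  // by num_data exactly like the loop in the diff above.
  const int idx = 2;
  const double values[] = {0.1, 0.2, 0.3};
  const double* val_ptr = values;
  for (int i = idx; i < num_classes * num_data; i += num_data, ++val_ptr) {
    init_score[i] = *val_ptr;
  }

  // num_init_score_classes() recovers the class count from the buffer size.
  std::printf("classes = %d\n",
              static_cast<int>(init_score.size()) / num_data);  // prints 3
  return 0;
}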
3 changes: 3 additions & 0 deletions include/LightGBM/objective_function.h
@@ -48,6 +48,9 @@ class ObjectiveFunction {
const data_size_t*,
data_size_t) const { return ori_output; }

virtual void RenewTreeOutputCUDA(const double* /*score*/, const data_size_t* /*data_indices_in_leaf*/, const data_size_t* /*num_data_in_leaf*/,
const data_size_t* /*data_start_in_leaf*/, const int /*num_leaves*/, double* /*leaf_value*/) const {}
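The new hook is a no-op by default, so only CUDA objectives that need post-hoc leaf renewal override it. A toy C++ sketch of the override pattern (names and behavior invented for illustration; a real override would launch CUDA kernels over the per-leaf index ranges):

#include <cstdint>
#include <vector>

using data_size_t = int32_t;

// Mini stand-in for the ObjectiveFunction interface shown above.
class ObjectiveFunctionBase {
 public:
  virtual ~ObjectiveFunctionBase() = default;
  virtual void RenewTreeOutputCUDA(const double* /*score*/,
                                   const data_size_t* /*data_indices_in_leaf*/,
                                   const data_size_t* /*num_data_in_leaf*/,
                                   const data_size_t* /*data_start_in_leaf*/,
                                   const int /*num_leaves*/,
                                   double* /*leaf_value*/) const {}
};

// Toy override: shift every leaf value by a constant.
class ShiftObjective : public ObjectiveFunctionBase {
 public:
  void RenewTreeOutputCUDA(const double*, const data_size_t*,
                           const data_size_t*, const data_size_t*,
                           const int num_leaves,
                           double* leaf_value) const override {
    for (int i = 0; i < num_leaves; ++i) leaf_value[i] += 0.5;
  }
};

int main() {
  std::vector<double> leaves = {1.0, 2.0};
  ShiftObjective obj;
  obj.RenewTreeOutputCUDA(nullptr, nullptr, nullptr, nullptr, 2, leaves.data());
  return 0;
}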

virtual double BoostFromScore(int /*class_id*/) const { return 0.0; }

virtual bool ClassNeedTrain(int /*class_id*/) const { return true; }
2 changes: 1 addition & 1 deletion include/LightGBM/tree_learner.h
@@ -92,7 +92,7 @@ class TreeLearner {
virtual void AddPredictionToScore(const Tree* tree, double* out_score) const = 0;

virtual void RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj, std::function<double(const label_t*, int)> residual_getter,
-                               data_size_t total_num_data, const data_size_t* bag_indices, data_size_t bag_cnt) const = 0;
+                               data_size_t total_num_data, const data_size_t* bag_indices, data_size_t bag_cnt, const double* train_score) const = 0;

TreeLearner() = default;
/*! \brief Disable copy */