From a92bf3742be78b96edc25bffac95027dc78fc400 Mon Sep 17 00:00:00 2001
From: shiyu1994
Date: Wed, 13 Sep 2023 01:06:20 +0800
Subject: [PATCH] [fix] fix quantized training (fixes #5982) (fixes #5994) (#6092)

* fix leaf splits update after split in quantized training

* fix preparation of ordered gradients for quantized training

* remove force_row_wise in distributed test for quantized training

* Update src/treelearner/leaf_splits.hpp

---------

Co-authored-by: James Lamb
---
 src/io/dataset.cpp                      |  37 +++++---
 src/treelearner/leaf_splits.hpp         |  19 ++++
 src/treelearner/serial_tree_learner.cpp | 115 ++++++++++++++++++++----
 src/treelearner/serial_tree_learner.h   |   2 +
 tests/python_package_test/test_dask.py  |   1 -
 5 files changed, 142 insertions(+), 32 deletions(-)

diff --git a/src/io/dataset.cpp b/src/io/dataset.cpp
index d5aa707adcc0..cd692afb031a 100644
--- a/src/io/dataset.cpp
+++ b/src/io/dataset.cpp
@@ -1278,21 +1278,34 @@ void Dataset::ConstructHistogramsInner(
   auto ptr_ordered_grad = gradients;
   auto ptr_ordered_hess = hessians;
   if (num_used_dense_group > 0) {
-    if (USE_INDICES) {
-      if (USE_HESSIAN) {
-#pragma omp parallel for schedule(static, 512) if (num_data >= 1024)
+    if (USE_QUANT_GRAD) {
+      int16_t* ordered_gradients_and_hessians = reinterpret_cast<int16_t*>(ordered_gradients);
+      const int16_t* gradients_and_hessians = reinterpret_cast<const int16_t*>(gradients);
+      if (USE_INDICES) {
+        #pragma omp parallel for schedule(static, 512) if (num_data >= 1024)
         for (data_size_t i = 0; i < num_data; ++i) {
-          ordered_gradients[i] = gradients[data_indices[i]];
-          ordered_hessians[i] = hessians[data_indices[i]];
+          ordered_gradients_and_hessians[i] = gradients_and_hessians[data_indices[i]];
         }
-        ptr_ordered_grad = ordered_gradients;
-        ptr_ordered_hess = ordered_hessians;
-      } else {
-#pragma omp parallel for schedule(static, 512) if (num_data >= 1024)
-        for (data_size_t i = 0; i < num_data; ++i) {
-          ordered_gradients[i] = gradients[data_indices[i]];
+        ptr_ordered_grad = reinterpret_cast<score_t*>(ordered_gradients);
+        ptr_ordered_hess = nullptr;
+      }
+    } else {
+      if (USE_INDICES) {
+        if (USE_HESSIAN) {
+          #pragma omp parallel for schedule(static, 512) if (num_data >= 1024)
+          for (data_size_t i = 0; i < num_data; ++i) {
+            ordered_gradients[i] = gradients[data_indices[i]];
+            ordered_hessians[i] = hessians[data_indices[i]];
+          }
+          ptr_ordered_grad = ordered_gradients;
+          ptr_ordered_hess = ordered_hessians;
+        } else {
+          #pragma omp parallel for schedule(static, 512) if (num_data >= 1024)
+          for (data_size_t i = 0; i < num_data; ++i) {
+            ordered_gradients[i] = gradients[data_indices[i]];
+          }
+          ptr_ordered_grad = ordered_gradients;
         }
-        ptr_ordered_grad = ordered_gradients;
       }
     }
   }
   OMP_INIT_EX();
diff --git a/src/treelearner/leaf_splits.hpp b/src/treelearner/leaf_splits.hpp
index 163bfc4df9ca..fdf55693a0e9 100644
--- a/src/treelearner/leaf_splits.hpp
+++ b/src/treelearner/leaf_splits.hpp
@@ -53,6 +53,25 @@ class LeafSplits {
     weight_ = weight;
   }
 
+  /*!
+  * \brief Init split on current leaf on partial data.
+  * \param leaf Index of current leaf
+  * \param data_partition current data partition
+  * \param sum_gradients
+  * \param sum_hessians
+  * \param sum_gradients_and_hessians
+  * \param weight
+  */
+  void Init(int leaf, const DataPartition* data_partition, double sum_gradients,
+            double sum_hessians, int64_t sum_gradients_and_hessians, double weight) {
+    leaf_index_ = leaf;
+    data_indices_ = data_partition->GetIndexOnLeaf(leaf, &num_data_in_leaf_);
+    sum_gradients_ = sum_gradients;
+    sum_hessians_ = sum_hessians;
+    int_sum_gradients_and_hessians_ = sum_gradients_and_hessians;
+    weight_ = weight;
+  }
+
   /*!
   * \brief Init split on current leaf on partial data.
   * \param leaf Index of current leaf
diff --git a/src/treelearner/serial_tree_learner.cpp b/src/treelearner/serial_tree_learner.cpp
index c322c1a796c2..37d9a2a50713 100644
--- a/src/treelearner/serial_tree_learner.cpp
+++ b/src/treelearner/serial_tree_learner.cpp
@@ -841,32 +841,65 @@ void SerialTreeLearner::SplitInner(Tree* tree, int best_leaf, int* left_leaf,
 #endif
 
   // init the leaves that used on next iteration
-  if (best_split_info.left_count < best_split_info.right_count) {
-    CHECK_GT(best_split_info.left_count, 0);
-    smaller_leaf_splits_->Init(*left_leaf, data_partition_.get(),
-                               best_split_info.left_sum_gradient,
-                               best_split_info.left_sum_hessian,
-                               best_split_info.left_output);
-    larger_leaf_splits_->Init(*right_leaf, data_partition_.get(),
-                              best_split_info.right_sum_gradient,
-                              best_split_info.right_sum_hessian,
-                              best_split_info.right_output);
+  if (!config_->use_quantized_grad) {
+    if (best_split_info.left_count < best_split_info.right_count) {
+      CHECK_GT(best_split_info.left_count, 0);
+      smaller_leaf_splits_->Init(*left_leaf, data_partition_.get(),
+                                 best_split_info.left_sum_gradient,
+                                 best_split_info.left_sum_hessian,
+                                 best_split_info.left_output);
+      larger_leaf_splits_->Init(*right_leaf, data_partition_.get(),
+                                best_split_info.right_sum_gradient,
+                                best_split_info.right_sum_hessian,
+                                best_split_info.right_output);
+    } else {
+      CHECK_GT(best_split_info.right_count, 0);
+      smaller_leaf_splits_->Init(*right_leaf, data_partition_.get(),
+                                 best_split_info.right_sum_gradient,
+                                 best_split_info.right_sum_hessian,
+                                 best_split_info.right_output);
+      larger_leaf_splits_->Init(*left_leaf, data_partition_.get(),
+                                best_split_info.left_sum_gradient,
+                                best_split_info.left_sum_hessian,
+                                best_split_info.left_output);
+    }
   } else {
-    CHECK_GT(best_split_info.right_count, 0);
-    smaller_leaf_splits_->Init(*right_leaf, data_partition_.get(),
-                               best_split_info.right_sum_gradient,
-                               best_split_info.right_sum_hessian,
-                               best_split_info.right_output);
-    larger_leaf_splits_->Init(*left_leaf, data_partition_.get(),
-                              best_split_info.left_sum_gradient,
-                              best_split_info.left_sum_hessian,
-                              best_split_info.left_output);
+    if (best_split_info.left_count < best_split_info.right_count) {
+      CHECK_GT(best_split_info.left_count, 0);
+      smaller_leaf_splits_->Init(*left_leaf, data_partition_.get(),
+                                 best_split_info.left_sum_gradient,
+                                 best_split_info.left_sum_hessian,
+                                 best_split_info.left_sum_gradient_and_hessian,
+                                 best_split_info.left_output);
+      larger_leaf_splits_->Init(*right_leaf, data_partition_.get(),
+                                best_split_info.right_sum_gradient,
+                                best_split_info.right_sum_hessian,
+                                best_split_info.right_sum_gradient_and_hessian,
+                                best_split_info.right_output);
+    } else {
+      CHECK_GT(best_split_info.right_count, 0);
+      smaller_leaf_splits_->Init(*right_leaf, data_partition_.get(),
+                                 best_split_info.right_sum_gradient,
+                                 best_split_info.right_sum_hessian,
+                                 best_split_info.right_sum_gradient_and_hessian,
+                                 best_split_info.right_output);
+      larger_leaf_splits_->Init(*left_leaf, data_partition_.get(),
+                                best_split_info.left_sum_gradient,
+                                best_split_info.left_sum_hessian,
+                                best_split_info.left_sum_gradient_and_hessian,
+                                best_split_info.left_output);
+    }
   }
   if (config_->use_quantized_grad && config_->tree_learner != std::string("data")) {
     gradient_discretizer_->SetNumBitsInHistogramBin(*left_leaf, *right_leaf,
                                                     data_partition_->leaf_count(*left_leaf),
                                                     data_partition_->leaf_count(*right_leaf));
   }
+
+  #ifdef DEBUG
+  CheckSplit(best_split_info, *left_leaf, *right_leaf);
+  #endif
+
   auto leaves_need_update = constraints_->Update(
       is_numerical_split, *left_leaf, *right_leaf,
       best_split_info.monotone_type, best_split_info.right_output,
@@ -1024,4 +1057,48 @@ std::vector<int8_t> node_used_features = col_sampler_.GetByNode(tree, leaf);
   *split = bests[best_idx];
 }
 
+#ifdef DEBUG
+void SerialTreeLearner::CheckSplit(const SplitInfo& best_split_info, const int left_leaf_index, const int right_leaf_index) {
+  data_size_t num_data_in_left = 0;
+  data_size_t num_data_in_right = 0;
+  const data_size_t* data_indices_in_left = data_partition_->GetIndexOnLeaf(left_leaf_index, &num_data_in_left);
+  const data_size_t* data_indices_in_right = data_partition_->GetIndexOnLeaf(right_leaf_index, &num_data_in_right);
+  if (config_->use_quantized_grad) {
+    int32_t sum_left_gradient = 0;
+    int32_t sum_left_hessian = 0;
+    int32_t sum_right_gradient = 0;
+    int32_t sum_right_hessian = 0;
+    const int8_t* discretized_grad_and_hess = gradient_discretizer_->discretized_gradients_and_hessians();
+    for (data_size_t i = 0; i < num_data_in_left; ++i) {
+      const data_size_t index = data_indices_in_left[i];
+      sum_left_gradient += discretized_grad_and_hess[2 * index + 1];
+      sum_left_hessian += discretized_grad_and_hess[2 * index];
+    }
+    for (data_size_t i = 0; i < num_data_in_right; ++i) {
+      const data_size_t index = data_indices_in_right[i];
+      sum_right_gradient += discretized_grad_and_hess[2 * index + 1];
+      sum_right_hessian += discretized_grad_and_hess[2 * index];
+    }
+    Log::Warning("============================ start leaf split info ============================");
+    Log::Warning("left_leaf_index = %d, right_leaf_index = %d", left_leaf_index, right_leaf_index);
+    Log::Warning("num_data_in_left = %d, num_data_in_right = %d", num_data_in_left, num_data_in_right);
+    Log::Warning("sum_left_gradient = %d, best_split_info->left_sum_gradient_and_hessian.gradient = %d", sum_left_gradient,
+                 static_cast<int32_t>(best_split_info.left_sum_gradient_and_hessian >> 32));
+    Log::Warning("sum_left_hessian = %d, best_split_info->left_sum_gradient_and_hessian.hessian = %d", sum_left_hessian,
+                 static_cast<int32_t>(best_split_info.left_sum_gradient_and_hessian & 0x00000000ffffffff));
+    Log::Warning("sum_right_gradient = %d, best_split_info->right_sum_gradient_and_hessian.gradient = %d", sum_right_gradient,
+                 static_cast<int32_t>(best_split_info.right_sum_gradient_and_hessian >> 32));
+    Log::Warning("sum_right_hessian = %d, best_split_info->right_sum_gradient_and_hessian.hessian = %d", sum_right_hessian,
+                 static_cast<int32_t>(best_split_info.right_sum_gradient_and_hessian & 0x00000000ffffffff));
+    CHECK_EQ(num_data_in_left, best_split_info.left_count);
+    CHECK_EQ(num_data_in_right, best_split_info.right_count);
+    CHECK_EQ(sum_left_gradient, static_cast<int32_t>(best_split_info.left_sum_gradient_and_hessian >> 32));
+    CHECK_EQ(sum_left_hessian, static_cast<int32_t>(best_split_info.left_sum_gradient_and_hessian & 0x00000000ffffffff));
+    CHECK_EQ(sum_right_gradient, static_cast<int32_t>(best_split_info.right_sum_gradient_and_hessian >> 32));
+    CHECK_EQ(sum_right_hessian, static_cast<int32_t>(best_split_info.right_sum_gradient_and_hessian & 0x00000000ffffffff));
+    Log::Warning("============================ end leaf split info ============================");
+  }
+}
+#endif
+
 }  // namespace LightGBM
diff --git a/src/treelearner/serial_tree_learner.h b/src/treelearner/serial_tree_learner.h
index d815d265c0d2..93e0787a90cf 100644
--- a/src/treelearner/serial_tree_learner.h
+++ b/src/treelearner/serial_tree_learner.h
@@ -171,7 +171,9 @@
 
   std::set<int> FindAllForceFeatures(Json force_split_leaf_setting);
 
+  #ifdef DEBUG
   void CheckSplit(const SplitInfo& best_split_info, const int left_leaf_index, const int right_leaf_index);
+  #endif
 
   /*!
    * \brief Get the number of data in a leaf
diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py
index cb69440b3cde..9da50945385c 100644
--- a/tests/python_package_test/test_dask.py
+++ b/tests/python_package_test/test_dask.py
@@ -1838,7 +1838,6 @@ def test_distributed_quantized_training(cluster):
         'num_grad_quant_bins': 30,
         'quant_train_renew_leaf': True,
         'verbose': -1,
-        'force_row_wise': True,
     }
 
     quant_dask_classifier = lgb.DaskLGBMRegressor(
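
Note (illustrative, not part of the patch): the short sketch below shows one way to exercise the quantized-training path touched by this fix from the LightGBM Python package on a single machine. It is a minimal sketch under assumptions: the synthetic dataset and the exact parameter values are placeholders chosen for demonstration, while use_quantized_grad, num_grad_quant_bins, and quant_train_renew_leaf are the parameters that also appear in the distributed test above.

    # Illustrative sketch only: single-machine quantized training with the same
    # quantization parameters used by the distributed test above.
    # The dataset and parameter values are placeholders, not part of the patch.
    import lightgbm as lgb
    from sklearn.datasets import make_regression

    X, y = make_regression(n_samples=1_000, n_features=20, random_state=42)

    params = {
        "objective": "regression",
        "use_quantized_grad": True,      # route training through the quantized-gradient code path
        "num_grad_quant_bins": 30,       # number of bins for gradient quantization, as in the test
        "quant_train_renew_leaf": True,  # renew leaf values with original gradients after growth
        "verbose": -1,
    }

    booster = lgb.train(params, lgb.Dataset(X, label=y), num_boost_round=10)
    print(booster.predict(X[:5]))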