Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix OpenMP thread allocation in Linux #5551

Merged
merged 15 commits on
Nov 29, 2022
3 changes: 2 additions & 1 deletion include/LightGBM/bin.h
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,9 @@ class Bin {
/*!
* \brief Initialize for pushing. By default, no action needed.
* \param num_thread The number of external threads that will be calling the push APIs
* \param omp_max_threads The maximum number of OpenMP threads, per external thread, to allocate push resources for
*/
virtual void InitStreaming(uint32_t /*num_thread*/) { }
virtual void InitStreaming(uint32_t /*num_thread*/, int32_t /*omp_max_threads*/) { }
/*!
* \brief Push one record
* \param tid Thread id
Expand Down
4 changes: 3 additions & 1 deletion include/LightGBM/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,14 +153,16 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateByReference(const DatasetHandle referenc
* \param has_queries Whether the dataset has Metadata queries/groups
* \param nclasses Number of initial score classes
* \param nthreads Number of external threads that will use the PushRows APIs
* \param omp_max_threads Maximum number of OpenMP threads used per external thread (pass -1 to fall back to the default OpenMP thread count)
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetInitStreaming(DatasetHandle dataset,
int32_t has_weights,
int32_t has_init_scores,
int32_t has_queries,
int32_t nclasses,
int32_t nthreads);
int32_t nthreads,
int32_t omp_max_threads);

/*!
* \brief Push data to existing dataset, if ``nrow + start_row == num_total_row``, will call ``dataset->FinishLoad``.
Expand Down
16 changes: 14 additions & 2 deletions include/LightGBM/dataset.h
Original file line number Diff line number Diff line change
Expand Up @@ -458,10 +458,18 @@ class Dataset {
int32_t has_init_scores,
int32_t has_queries,
int32_t nclasses,
int32_t nthreads) {
int32_t nthreads,
int32_t omp_max_threads) {
// Initialize optional max thread count with either parameter or OMP setting
if (omp_max_threads > 0) {
omp_max_threads_ = omp_max_threads;
} else if (omp_max_threads_ <= 0) {
omp_max_threads_ = OMP_NUM_THREADS();
}

metadata_.Init(num_data, has_weights, has_init_scores, has_queries, nclasses);
for (int i = 0; i < num_groups_; ++i) {
feature_groups_[i]->InitStreaming(nthreads);
feature_groups_[i]->InitStreaming(nthreads, omp_max_threads_);
}
}

Expand Down Expand Up @@ -846,6 +854,9 @@ class Dataset {
/*! \brief Get whether FinishLoad is automatically called when pushing last row. */
inline bool wait_for_manual_finish() const { return wait_for_manual_finish_; }

/*! \brief Get the maximum number of OpenMP threads per external thread that streaming was initialized for (-1 if not yet set). */
inline int omp_max_threads() const { return omp_max_threads_; }

/*! \brief Set whether the Dataset is finished automatically when last row is pushed or with a manual
* MarkFinished API call. Set to true for thread-safe streaming and/or if the Dataset will be coalesced later.
* FinishLoad should not be called on any Dataset that will be coalesced.
Expand Down Expand Up @@ -947,6 +958,7 @@ class Dataset {
std::vector<int> feature_need_push_zeros_;
std::vector<std::vector<float>> raw_data_;
bool wait_for_manual_finish_;
int omp_max_threads_ = -1;
bool has_raw_;
/*! map feature (inner index) to its index in the list of numeric (non-categorical) features */
std::vector<int> numeric_feature_map_;
Expand Down
7 changes: 4 additions & 3 deletions include/LightGBM/feature_group.h
Original file line number Diff line number Diff line change
Expand Up @@ -192,14 +192,15 @@ class FeatureGroup {
/*!
* \brief Initialize for pushing in a streaming fashion. By default, no action needed.
* \param num_thread The number of external threads that will be calling the push APIs
* \param omp_max_threads The maximum number of OpenMP threads, per external thread, to allocate push resources for
*/
void InitStreaming(int32_t num_thread) {
void InitStreaming(int32_t num_thread, int32_t omp_max_threads) {
if (is_multi_val_) {
for (int i = 0; i < num_feature_; ++i) {
multi_bin_data_[i]->InitStreaming(num_thread);
multi_bin_data_[i]->InitStreaming(num_thread, omp_max_threads);
}
} else {
bin_data_->InitStreaming(num_thread);
bin_data_->InitStreaming(num_thread, omp_max_threads);
}
}

Expand Down
16 changes: 10 additions & 6 deletions src/c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1018,11 +1018,12 @@ int LGBM_DatasetInitStreaming(DatasetHandle dataset,
int32_t has_init_scores,
int32_t has_queries,
int32_t nclasses,
int32_t nthreads) {
int32_t nthreads,
int32_t omp_max_threads) {
API_BEGIN();
auto p_dataset = reinterpret_cast<Dataset*>(dataset);
auto num_data = p_dataset->num_data();
p_dataset->InitStreaming(num_data, has_weights, has_init_scores, has_queries, nclasses, nthreads);
p_dataset->InitStreaming(num_data, has_weights, has_init_scores, has_queries, nclasses, nthreads, omp_max_threads);
p_dataset->set_wait_for_manual_finish(true);
API_END();
}
Expand Down Expand Up @@ -1073,19 +1074,20 @@ int LGBM_DatasetPushRowsWithMetadata(DatasetHandle dataset,
if (!data) {
Log::Fatal("data cannot be null.");
}
const int num_omp_threads = OMP_NUM_THREADS();
auto p_dataset = reinterpret_cast<Dataset*>(dataset);
auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, 1);
if (p_dataset->has_raw()) {
p_dataset->ResizeRaw(p_dataset->num_numeric_features() + nrow);
}

const int max_omp_threads = p_dataset->omp_max_threads() > 0 ? p_dataset->omp_max_threads() : OMP_NUM_THREADS();

OMP_INIT_EX();
#pragma omp parallel for schedule(static)
for (int i = 0; i < nrow; ++i) {
OMP_LOOP_EX_BEGIN();
// convert internal thread id to be unique based on external thread id
const int internal_tid = omp_get_thread_num() + (num_omp_threads * tid);
const int internal_tid = omp_get_thread_num() + (max_omp_threads * tid);
auto one_row = get_row_fun(i);
p_dataset->PushOneRow(internal_tid, start_row + i, one_row);
OMP_LOOP_EX_END();
Expand Down Expand Up @@ -1154,19 +1156,21 @@ int LGBM_DatasetPushRowsByCSRWithMetadata(DatasetHandle dataset,
if (!data) {
Log::Fatal("data cannot be null.");
}
const int num_omp_threads = OMP_NUM_THREADS();
auto p_dataset = reinterpret_cast<Dataset*>(dataset);
auto get_row_fun = RowFunctionFromCSR<int>(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
int32_t nrow = static_cast<int32_t>(nindptr - 1);
if (p_dataset->has_raw()) {
p_dataset->ResizeRaw(p_dataset->num_numeric_features() + nrow);
}

const int max_omp_threads = p_dataset->omp_max_threads() > 0 ? p_dataset->omp_max_threads() : OMP_NUM_THREADS();

OMP_INIT_EX();
#pragma omp parallel for schedule(static)
for (int i = 0; i < nrow; ++i) {
OMP_LOOP_EX_BEGIN();
// convert internal thread id to be unique based on external thread id
const int internal_tid = omp_get_thread_num() + (num_omp_threads * tid);
const int internal_tid = omp_get_thread_num() + (max_omp_threads * tid);
auto one_row = get_row_fun(i);
p_dataset->PushOneRow(internal_tid, static_cast<data_size_t>(start_row + i), one_row);
OMP_LOOP_EX_END();
Expand Down
8 changes: 4 additions & 4 deletions src/io/sparse_bin.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,10 @@ class SparseBin : public Bin {

~SparseBin() {}

void InitStreaming(uint32_t num_thread) override {
// Each thread needs its own push buffer, so allocate external num_thread times the number of OMP threads
int num_omp_threads = OMP_NUM_THREADS();
push_buffers_.resize(num_omp_threads * num_thread);
void InitStreaming(uint32_t num_thread, int32_t omp_max_threads) override {
// Each external thread needs its own set of OpenMP push buffers,
// so allocate num_thread times the maximum number of OMP threads per external thread
push_buffers_.resize(omp_max_threads * num_thread);
};

void ReSize(data_size_t num_data) override { num_data_ = num_data; }
Expand Down
4 changes: 2 additions & 2 deletions tests/cpp_tests/test_stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ void test_stream_dense(
&dataset_handle);
EXPECT_EQ(0, result) << "LGBM_DatasetCreateFromSampledColumn result code: " << result;

result = LGBM_DatasetInitStreaming(dataset_handle, has_weights, has_init_scores, has_queries, nclasses, 1);
result = LGBM_DatasetInitStreaming(dataset_handle, has_weights, has_init_scores, has_queries, nclasses, 1, -1);
EXPECT_EQ(0, result) << "LGBM_DatasetInitStreaming result code: " << result;
break;
}
Expand Down Expand Up @@ -197,7 +197,7 @@ void test_stream_sparse(
EXPECT_EQ(0, result) << "LGBM_DatasetCreateFromSampledColumn result code: " << result;

dataset = static_cast<Dataset*>(dataset_handle);
dataset->InitStreaming(nrows, has_weights, has_init_scores, has_queries, nclasses, 2);
dataset->InitStreaming(nrows, has_weights, has_init_scores, has_queries, nclasses, 2, -1);
break;
}

Expand Down