Skip to content

Commit

Permalink
Fix OpenMP thread allocation in Linux (#5551)
Browse files Browse the repository at this point in the history
  • Loading branch information
svotaw authored Nov 29, 2022
1 parent 51efd90 commit 4c5d0fb
Show file tree
Hide file tree
Showing 7 changed files with 39 additions and 19 deletions.
3 changes: 2 additions & 1 deletion include/LightGBM/bin.h
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,9 @@ class Bin {
/*!
* \brief Initialize for pushing. By default, no action needed.
* \param num_thread The number of external threads that will be calling the push APIs
* \param omp_max_threads The maximum number of OpenMP threads, per external thread, to allocate per-thread buffers for
*/
virtual void InitStreaming(uint32_t /*num_thread*/) { }
virtual void InitStreaming(uint32_t /*num_thread*/, int32_t /*omp_max_threads*/) { }
/*!
* \brief Push one record
* \param tid Thread id
Expand Down
4 changes: 3 additions & 1 deletion include/LightGBM/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,14 +153,16 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateByReference(const DatasetHandle referenc
* \param has_queries Whether the dataset has Metadata queries/groups
* \param nclasses Number of initial score classes
* \param nthreads Number of external threads that will use the PushRows APIs
* \param omp_max_threads Maximum number of OpenMP threads (-1 for default)
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetInitStreaming(DatasetHandle dataset,
int32_t has_weights,
int32_t has_init_scores,
int32_t has_queries,
int32_t nclasses,
int32_t nthreads);
int32_t nthreads,
int32_t omp_max_threads);

/*!
* \brief Push data to existing dataset, if ``nrow + start_row == num_total_row``, will call ``dataset->FinishLoad``.
Expand Down
16 changes: 14 additions & 2 deletions include/LightGBM/dataset.h
Original file line number Diff line number Diff line change
Expand Up @@ -458,10 +458,18 @@ class Dataset {
int32_t has_init_scores,
int32_t has_queries,
int32_t nclasses,
int32_t nthreads) {
int32_t nthreads,
int32_t omp_max_threads) {
// Initialize optional max thread count with either parameter or OMP setting
if (omp_max_threads > 0) {
omp_max_threads_ = omp_max_threads;
} else if (omp_max_threads_ <= 0) {
omp_max_threads_ = OMP_NUM_THREADS();
}

metadata_.Init(num_data, has_weights, has_init_scores, has_queries, nclasses);
for (int i = 0; i < num_groups_; ++i) {
feature_groups_[i]->InitStreaming(nthreads);
feature_groups_[i]->InitStreaming(nthreads, omp_max_threads_);
}
}

Expand Down Expand Up @@ -846,6 +854,9 @@ class Dataset {
/*! \brief Get whether FinishLoad is automatically called when pushing last row. */
inline bool wait_for_manual_finish() const { return wait_for_manual_finish_; }

/*! \brief Get the maximum number of OpenMP threads (per external thread) that streaming buffers are allocated for; -1 until it has been set. */
inline int omp_max_threads() const { return omp_max_threads_; }

/*! \brief Set whether the Dataset is finished automatically when last row is pushed or with a manual
* MarkFinished API call. Set to true for thread-safe streaming and/or if will be coalesced later.
* FinishLoad should not be called on any Dataset that will be coalesced.
Expand Down Expand Up @@ -947,6 +958,7 @@ class Dataset {
std::vector<int> feature_need_push_zeros_;
std::vector<std::vector<float>> raw_data_;
bool wait_for_manual_finish_;
int omp_max_threads_ = -1;
bool has_raw_;
/*! map feature (inner index) to its index in the list of numeric (non-categorical) features */
std::vector<int> numeric_feature_map_;
Expand Down
7 changes: 4 additions & 3 deletions include/LightGBM/feature_group.h
Original file line number Diff line number Diff line change
Expand Up @@ -192,14 +192,15 @@ class FeatureGroup {
/*!
* \brief Initialize for pushing in a streaming fashion. By default, no action needed.
* \param num_thread The number of external threads that will be calling the push APIs
* \param omp_max_threads The maximum number of OpenMP threads, per external thread, to allocate per-thread buffers for
*/
void InitStreaming(int32_t num_thread) {
void InitStreaming(int32_t num_thread, int32_t omp_max_threads) {
if (is_multi_val_) {
for (int i = 0; i < num_feature_; ++i) {
multi_bin_data_[i]->InitStreaming(num_thread);
multi_bin_data_[i]->InitStreaming(num_thread, omp_max_threads);
}
} else {
bin_data_->InitStreaming(num_thread);
bin_data_->InitStreaming(num_thread, omp_max_threads);
}
}

Expand Down
16 changes: 10 additions & 6 deletions src/c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1018,11 +1018,12 @@ int LGBM_DatasetInitStreaming(DatasetHandle dataset,
int32_t has_init_scores,
int32_t has_queries,
int32_t nclasses,
int32_t nthreads) {
int32_t nthreads,
int32_t omp_max_threads) {
API_BEGIN();
auto p_dataset = reinterpret_cast<Dataset*>(dataset);
auto num_data = p_dataset->num_data();
p_dataset->InitStreaming(num_data, has_weights, has_init_scores, has_queries, nclasses, nthreads);
p_dataset->InitStreaming(num_data, has_weights, has_init_scores, has_queries, nclasses, nthreads, omp_max_threads);
p_dataset->set_wait_for_manual_finish(true);
API_END();
}
Expand Down Expand Up @@ -1073,19 +1074,20 @@ int LGBM_DatasetPushRowsWithMetadata(DatasetHandle dataset,
if (!data) {
Log::Fatal("data cannot be null.");
}
const int num_omp_threads = OMP_NUM_THREADS();
auto p_dataset = reinterpret_cast<Dataset*>(dataset);
auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, 1);
if (p_dataset->has_raw()) {
p_dataset->ResizeRaw(p_dataset->num_numeric_features() + nrow);
}

const int max_omp_threads = p_dataset->omp_max_threads() > 0 ? p_dataset->omp_max_threads() : OMP_NUM_THREADS();

OMP_INIT_EX();
#pragma omp parallel for schedule(static)
for (int i = 0; i < nrow; ++i) {
OMP_LOOP_EX_BEGIN();
// convert internal thread id to be unique based on external thread id
const int internal_tid = omp_get_thread_num() + (num_omp_threads * tid);
const int internal_tid = omp_get_thread_num() + (max_omp_threads * tid);
auto one_row = get_row_fun(i);
p_dataset->PushOneRow(internal_tid, start_row + i, one_row);
OMP_LOOP_EX_END();
Expand Down Expand Up @@ -1154,19 +1156,21 @@ int LGBM_DatasetPushRowsByCSRWithMetadata(DatasetHandle dataset,
if (!data) {
Log::Fatal("data cannot be null.");
}
const int num_omp_threads = OMP_NUM_THREADS();
auto p_dataset = reinterpret_cast<Dataset*>(dataset);
auto get_row_fun = RowFunctionFromCSR<int>(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
int32_t nrow = static_cast<int32_t>(nindptr - 1);
if (p_dataset->has_raw()) {
p_dataset->ResizeRaw(p_dataset->num_numeric_features() + nrow);
}

const int max_omp_threads = p_dataset->omp_max_threads() > 0 ? p_dataset->omp_max_threads() : OMP_NUM_THREADS();

OMP_INIT_EX();
#pragma omp parallel for schedule(static)
for (int i = 0; i < nrow; ++i) {
OMP_LOOP_EX_BEGIN();
// convert internal thread id to be unique based on external thread id
const int internal_tid = omp_get_thread_num() + (num_omp_threads * tid);
const int internal_tid = omp_get_thread_num() + (max_omp_threads * tid);
auto one_row = get_row_fun(i);
p_dataset->PushOneRow(internal_tid, static_cast<data_size_t>(start_row + i), one_row);
OMP_LOOP_EX_END();
Expand Down
8 changes: 4 additions & 4 deletions src/io/sparse_bin.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,10 @@ class SparseBin : public Bin {

~SparseBin() {}

void InitStreaming(uint32_t num_thread) override {
// Each thread needs its own push buffer, so allocate external num_thread times the number of OMP threads
int num_omp_threads = OMP_NUM_THREADS();
push_buffers_.resize(num_omp_threads * num_thread);
void InitStreaming(uint32_t num_thread, int32_t omp_max_threads) override {
// Sizes the push buffers for streaming: each external thread needs its own set of
// OpenMP push buffers, so allocate num_thread times the maximum number of OMP threads
// per external thread.
// NOTE(review): omp_max_threads is signed while num_thread is unsigned, so the product
// below is evaluated as an unsigned value (usual arithmetic conversions). A non-positive
// omp_max_threads (e.g. an unresolved -1 "default") would therefore wrap to a very large
// size. This presumably relies on the caller having already resolved the default to a
// positive thread count - confirm for every call site.
push_buffers_.resize(omp_max_threads * num_thread);
};

void ReSize(data_size_t num_data) override { num_data_ = num_data; }
Expand Down
4 changes: 2 additions & 2 deletions tests/cpp_tests/test_stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ void test_stream_dense(
&dataset_handle);
EXPECT_EQ(0, result) << "LGBM_DatasetCreateFromSampledColumn result code: " << result;

result = LGBM_DatasetInitStreaming(dataset_handle, has_weights, has_init_scores, has_queries, nclasses, 1);
result = LGBM_DatasetInitStreaming(dataset_handle, has_weights, has_init_scores, has_queries, nclasses, 1, -1);
EXPECT_EQ(0, result) << "LGBM_DatasetInitStreaming result code: " << result;
break;
}
Expand Down Expand Up @@ -197,7 +197,7 @@ void test_stream_sparse(
EXPECT_EQ(0, result) << "LGBM_DatasetCreateFromSampledColumn result code: " << result;

dataset = static_cast<Dataset*>(dataset_handle);
dataset->InitStreaming(nrows, has_weights, has_init_scores, has_queries, nclasses, 2);
dataset->InitStreaming(nrows, has_weights, has_init_scores, has_queries, nclasses, 2, -1);
break;
}

Expand Down

0 comments on commit 4c5d0fb

Please sign in to comment.