Skip to content

Commit

Permalink
gpu_hist performance fixes (#5558)
Browse files Browse the repository at this point in the history
* Remove unnecessary cuda API calls

* Fix histogram memory growth
RAMitchell authored Apr 19, 2020

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent e1f22ba commit d6d1035
Showing 7 changed files with 52 additions and 109 deletions.
85 changes: 25 additions & 60 deletions src/common/device_helpers.cuh
Original file line number Diff line number Diff line change
@@ -209,7 +209,6 @@ inline void LaunchN(int device_idx, size_t n, cudaStream_t stream, L lambda) {
if (n == 0) {
return;
}
safe_cuda(cudaSetDevice(device_idx));
const int GRID_SIZE =
static_cast<int>(xgboost::common::DivRoundUp(n, ITEMS_PER_THREAD * BLOCK_THREADS));
LaunchNKernel<<<GRID_SIZE, BLOCK_THREADS, 0, stream>>>( // NOLINT
@@ -368,6 +367,7 @@ struct XGBCachingDeviceAllocatorImpl : thrust::device_malloc_allocator<T> {
GlobalMemoryLogger().RegisterDeallocation(ptr.get(), n * sizeof(T));
GetGlobalCachingAllocator().DeviceFree(ptr.get());
}

__host__ __device__
void construct(T *) // NOLINT
{
@@ -391,6 +391,24 @@ using device_vector = thrust::device_vector<T, XGBDeviceAllocator<T>>; // NOLI
template <typename T>
using caching_device_vector = thrust::device_vector<T, XGBCachingDeviceAllocator<T>>; // NOLINT

// Faster to instantiate than caching_device_vector and invokes no synchronisation
// Use this where vector functionality (e.g. resize) is not required
template <typename T>
class TemporaryArray {
public:
using AllocT = XGBCachingDeviceAllocator<T>;
using value_type = T; // NOLINT
explicit TemporaryArray(size_t n) : size_(n) { ptr_ = AllocT().allocate(n); }
~TemporaryArray() { AllocT().deallocate(ptr_, this->size()); }

thrust::device_ptr<T> data() { return ptr_; } // NOLINT
size_t size() { return size_; } // NOLINT

private:
thrust::device_ptr<T> ptr_;
size_t size_;
};

/**
* \brief A double buffer, useful for algorithms like sort.
*/
@@ -474,84 +492,31 @@ struct PinnedMemory {
}
};

// Keep track of cub library device allocation
struct CubMemory {
void *d_temp_storage { nullptr };
size_t temp_storage_bytes { 0 };

// Thrust
using value_type = char; // NOLINT

CubMemory() = default;

~CubMemory() { Free(); }

template <typename T>
xgboost::common::Span<T> GetSpan(size_t size) {
this->LazyAllocate(size * sizeof(T));
return xgboost::common::Span<T>(static_cast<T*>(d_temp_storage), size);
}

void Free() {
if (this->IsAllocated()) {
XGBDeviceAllocator<uint8_t> allocator;
allocator.deallocate(thrust::device_ptr<uint8_t>(static_cast<uint8_t *>(d_temp_storage)),
temp_storage_bytes);
d_temp_storage = nullptr;
temp_storage_bytes = 0;
}
}

void LazyAllocate(size_t num_bytes) {
if (num_bytes > temp_storage_bytes) {
Free();
XGBDeviceAllocator<uint8_t> allocator;
d_temp_storage = static_cast<void *>(allocator.allocate(num_bytes).get());
temp_storage_bytes = num_bytes;
}
}
// Thrust
char *allocate(std::ptrdiff_t num_bytes) { // NOLINT
LazyAllocate(num_bytes);
return reinterpret_cast<char *>(d_temp_storage);
}

// Thrust
void deallocate(char *ptr, size_t n) { // NOLINT

// Do nothing
}

bool IsAllocated() { return d_temp_storage != nullptr; }
};

/*
* Utility functions
*/

/**
* @brief Helper function to perform device-wide sum-reduction, returns to the
* host
* @param tmp_mem cub temporary memory info
* @param in the input array to be reduced
* @param nVals number of elements in the input array
*/
template <typename T>
typename std::iterator_traits<T>::value_type SumReduction(
dh::CubMemory* tmp_mem, T in, int nVals) {
typename std::iterator_traits<T>::value_type SumReduction(T in, int nVals) {
using ValueT = typename std::iterator_traits<T>::value_type;
size_t tmpSize {0};
ValueT *dummy_out = nullptr;
dh::safe_cuda(cub::DeviceReduce::Sum(nullptr, tmpSize, in, dummy_out, nVals));
// Allocate small extra memory for the return value
tmp_mem->LazyAllocate(tmpSize + sizeof(ValueT));
auto ptr = reinterpret_cast<ValueT *>(tmp_mem->d_temp_storage) + 1;

TemporaryArray<char> temp(tmpSize + sizeof(ValueT));
auto ptr = reinterpret_cast<ValueT *>(temp.data().get()) + 1;
dh::safe_cuda(cub::DeviceReduce::Sum(
reinterpret_cast<void *>(ptr), tmpSize, in,
reinterpret_cast<ValueT *>(tmp_mem->d_temp_storage),
reinterpret_cast<ValueT *>(temp.data().get()),
nVals));
ValueT sum;
dh::safe_cuda(cudaMemcpy(&sum, tmp_mem->d_temp_storage, sizeof(ValueT),
dh::safe_cuda(cudaMemcpy(&sum, temp.data().get(), sizeof(ValueT),
cudaMemcpyDeviceToHost));
return sum;
}
5 changes: 2 additions & 3 deletions src/linear/updater_gpu_coordinate.cu
Original file line number Diff line number Diff line change
@@ -185,7 +185,7 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
counting, f);
auto perm = thrust::make_permutation_iterator(gpair_.data(), skip);

return dh::SumReduction(&temp_, perm, num_row_);
return dh::SumReduction(perm, num_row_);
}

// This needs to be public because of the __device__ lambda.
@@ -213,7 +213,7 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
}; // NOLINT
thrust::transform_iterator<decltype(f), decltype(counting), GradientPair>
multiply_iterator(counting, f);
return dh::SumReduction(&temp_, multiply_iterator, col_size);
return dh::SumReduction(multiply_iterator, col_size);
}

// This needs to be public because of the __device__ lambda.
@@ -249,7 +249,6 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
std::vector<size_t> row_ptr_;
dh::device_vector<xgboost::Entry> data_;
dh::caching_device_vector<GradientPair> gpair_;
dh::CubMemory temp_;
size_t num_row_;
};

11 changes: 2 additions & 9 deletions src/metric/elementwise_metric.cu
Original file line number Diff line number Diff line change
@@ -59,13 +59,6 @@ class ElementWiseMetricsReduction {

#if defined(XGBOOST_USE_CUDA)

~ElementWiseMetricsReduction() {
if (device_ >= 0) {
dh::safe_cuda(cudaSetDevice(device_));
allocator_.Free();
}
}

PackedReduceResult DeviceReduceMetrics(
const HostDeviceVector<bst_float>& weights,
const HostDeviceVector<bst_float>& labels,
@@ -83,8 +76,9 @@ class ElementWiseMetricsReduction {

auto d_policy = policy_;

dh::XGBCachingDeviceAllocator<char> alloc;
PackedReduceResult result = thrust::transform_reduce(
thrust::cuda::par(allocator_),
thrust::cuda::par(alloc),
begin, end,
[=] XGBOOST_DEVICE(size_t idx) {
bst_float weight = is_null_weight ? 1.0f : s_weights[idx];
@@ -130,7 +124,6 @@ class ElementWiseMetricsReduction {
EvalRow policy_;
#if defined(XGBOOST_USE_CUDA)
int device_{-1};
dh::CubMemory allocator_;
#endif // defined(XGBOOST_USE_CUDA)
};

11 changes: 2 additions & 9 deletions src/metric/multiclass_metric.cu
Original file line number Diff line number Diff line change
@@ -73,13 +73,6 @@ class MultiClassMetricsReduction {

#if defined(XGBOOST_USE_CUDA)

~MultiClassMetricsReduction() {
if (device_ >= 0) {
dh::safe_cuda(cudaSetDevice(device_));
allocator_.Free();
}
}

PackedReduceResult DeviceReduceMetrics(
const HostDeviceVector<bst_float>& weights,
const HostDeviceVector<bst_float>& labels,
@@ -98,8 +91,9 @@ class MultiClassMetricsReduction {
auto s_label_error = label_error_.GetSpan<int32_t>(1);
s_label_error[0] = 0;

dh::XGBCachingDeviceAllocator<char> alloc;
PackedReduceResult result = thrust::transform_reduce(
thrust::cuda::par(allocator_),
thrust::cuda::par(alloc),
begin, end,
[=] XGBOOST_DEVICE(size_t idx) {
bst_float weight = is_null_weight ? 1.0f : s_weights[idx];
@@ -152,7 +146,6 @@ class MultiClassMetricsReduction {
#if defined(XGBOOST_USE_CUDA)
dh::PinnedMemory label_error_;
int device_{-1};
dh::CubMemory allocator_;
#endif // defined(XGBOOST_USE_CUDA)
};

1 change: 0 additions & 1 deletion src/tree/gpu_hist/row_partitioner.cuh
Original file line number Diff line number Diff line change
@@ -108,7 +108,6 @@ class RowPartitioner {
template <typename UpdatePositionOpT>
void UpdatePosition(bst_node_t nidx, bst_node_t left_nidx,
bst_node_t right_nidx, UpdatePositionOpT op) {
dh::safe_cuda(cudaSetDevice(device_idx_));
Segment segment = ridx_segments_.at(nidx); // rows belongs to node nidx
auto d_ridx = ridx_.CurrentSpan();
auto d_position = position_.CurrentSpan();
45 changes: 20 additions & 25 deletions src/tree/updater_gpu_hist.cu
Original file line number Diff line number Diff line change
@@ -2,9 +2,6 @@
* Copyright 2017-2020 XGBoost contributors
*/
#include <thrust/copy.h>
#include <thrust/functional.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/reduce.h>
#include <xgboost/tree_updater.h>
#include <algorithm>
@@ -20,8 +17,6 @@
#include "xgboost/span.h"
#include "xgboost/json.h"

#include "../common/common.h"
#include "../common/compressed_iterator.h"
#include "../common/device_helpers.cuh"
#include "../common/hist_util.h"
#include "../common/timer.h"
@@ -324,9 +319,9 @@ class DeviceHistogram {
}

void Reset() {
dh::safe_cuda(cudaMemsetAsync(
data_.data().get(), 0,
data_.size() * sizeof(typename decltype(data_)::value_type)));
auto d_data = data_.data().get();
dh::LaunchN(device_id_, data_.size(),
[=] __device__(size_t idx) { d_data[idx] = 0.0f; });
nidx_map_.clear();
}
bool HistogramExists(int nidx) const {
@@ -348,30 +343,33 @@ class DeviceHistogram {
// Number of items currently used in data
const size_t used_size = nidx_map_.size() * HistogramSize();
const size_t new_used_size = used_size + HistogramSize();
dh::safe_cuda(cudaSetDevice(device_id_));
if (data_.size() >= kStopGrowingSize) {
// Recycle histogram memory
if (new_used_size <= data_.size()) {
// no need to remove old node, just insert the new one.
nidx_map_[nidx] = used_size;
// memset histogram size in bytes
dh::safe_cuda(cudaMemsetAsync(data_.data().get() + used_size, 0,
n_bins_ * sizeof(GradientSumT)));
} else {
std::pair<int, size_t> old_entry = *nidx_map_.begin();
nidx_map_.erase(old_entry.first);
dh::safe_cuda(cudaMemsetAsync(data_.data().get() + old_entry.second, 0,
n_bins_ * sizeof(GradientSumT)));
nidx_map_[nidx] = old_entry.second;
}
// Zero recycled memory
auto d_data = data_.data().get() + nidx_map_[nidx];
dh::LaunchN(device_id_, n_bins_ * 2,
[=] __device__(size_t idx) { d_data[idx] = 0.0f; });
} else {
// Append new node histogram
nidx_map_[nidx] = used_size;
size_t new_required_memory = std::max(data_.size() * 2, HistogramSize());
if (data_.size() < new_required_memory) {
// Check there is enough memory for another histogram node
if (data_.size() < new_used_size + HistogramSize()) {
size_t new_required_memory =
std::max(data_.size() * 2, HistogramSize());
data_.resize(new_required_memory);
}
}

CHECK_GE(data_.size(), nidx_map_.size() * HistogramSize());
}

/**
@@ -428,7 +426,6 @@ struct GPUHistMakerDevice {

GradientSumT histogram_rounding;

dh::CubMemory temp_memory;
dh::PinnedMemory pinned_memory;

std::vector<cudaStream_t> streams{};
@@ -531,15 +528,14 @@ struct GPUHistMakerDevice {
std::vector<DeviceSplitCandidate> EvaluateSplits(
std::vector<int> nidxs, const RegTree& tree,
size_t num_columns) {
dh::safe_cuda(cudaSetDevice(device_id));
auto result_all = pinned_memory.GetSpan<DeviceSplitCandidate>(nidxs.size());

// Work out cub temporary memory requirement
GPUTrainingParam gpu_param(param);
DeviceSplitCandidateReduceOp op(gpu_param);

dh::caching_device_vector<DeviceSplitCandidate> d_result_all(nidxs.size());
dh::caching_device_vector<DeviceSplitCandidate> split_candidates_all(nidxs.size()*num_columns);
dh::TemporaryArray<DeviceSplitCandidate> d_result_all(nidxs.size());
dh::TemporaryArray<DeviceSplitCandidate> split_candidates_all(nidxs.size()*num_columns);

auto& streams = this->GetStreams(nidxs.size());
for (auto i = 0ull; i < nidxs.size(); i++) {
@@ -582,7 +578,7 @@ struct GPUHistMakerDevice {
cub_bytes, d_split_candidates.data(),
d_result.data(), d_split_candidates.size(), op,
DeviceSplitCandidate(), streams[i]);
dh::caching_device_vector<char> cub_temp(cub_bytes);
dh::TemporaryArray<char> cub_temp(cub_bytes);
cub::DeviceReduce::Reduce(reinterpret_cast<void*>(cub_temp.data().get()),
cub_bytes, d_split_candidates.data(),
d_result.data(), d_split_candidates.size(), op,
@@ -651,9 +647,8 @@ struct GPUHistMakerDevice {
// instances to their final leaf. This information is used later to update the
// prediction cache
void FinalisePosition(RegTree const* p_tree, DMatrix* p_fmat) {
const auto d_nodes =
temp_memory.GetSpan<RegTree::Node>(p_tree->GetNodes().size());
dh::safe_cuda(cudaMemcpy(d_nodes.data(), p_tree->GetNodes().data(),
dh::TemporaryArray<RegTree::Node> d_nodes(p_tree->GetNodes().size());
dh::safe_cuda(cudaMemcpy(d_nodes.data().get(), p_tree->GetNodes().data(),
d_nodes.size() * sizeof(RegTree::Node),
cudaMemcpyHostToDevice));

@@ -662,10 +657,10 @@ struct GPUHistMakerDevice {
row_partitioner.reset(new RowPartitioner(device_id, p_fmat->Info().num_row_));
}
if (page->n_rows == p_fmat->Info().num_row_) {
FinalisePositionInPage(page, d_nodes);
FinalisePositionInPage(page, dh::ToSpan(d_nodes));
} else {
for (auto& batch : p_fmat->GetBatches<EllpackPage>(batch_param)) {
FinalisePositionInPage(batch.Impl(), d_nodes);
FinalisePositionInPage(batch.Impl(), dh::ToSpan(d_nodes));
}
}
}
3 changes: 1 addition & 2 deletions tests/cpp/common/test_device_helpers.cu
Original file line number Diff line number Diff line change
@@ -10,8 +10,7 @@

TEST(SumReduce, Test) {
thrust::device_vector<float> data(100, 1.0f);
dh::CubMemory temp;
auto sum = dh::SumReduction(&temp, data.data().get(), data.size());
auto sum = dh::SumReduction(data.data().get(), data.size());
ASSERT_NEAR(sum, 100.0f, 1e-5);
}

0 comments on commit d6d1035

Please sign in to comment.