Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Group aware GPU sketching. #5551

Merged
merged 4 commits into from
Apr 20, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions include/xgboost/data.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,23 @@ class MetaInfo {

this->weights_.Resize(that.weights_.Size());
this->weights_.Copy(that.weights_);

this->base_margin_.Resize(that.base_margin_.Size());
this->base_margin_.Copy(that.base_margin_);

this->labels_lower_bound_.Resize(that.labels_lower_bound_.Size());
this->labels_lower_bound_.Copy(that.labels_lower_bound_);

this->labels_upper_bound_.Resize(that.labels_upper_bound_.Size());
this->labels_upper_bound_.Copy(that.labels_upper_bound_);
return *this;
}

/*!
* \brief Validate all metainfo.
*/
void Validate() const;

MetaInfo Slice(common::Span<int32_t const> ridxs) const;
/*!
* \brief Get weight of each instances.
Expand Down
148 changes: 91 additions & 57 deletions src/common/hist_util.cu
Original file line number Diff line number Diff line change
Expand Up @@ -138,25 +138,26 @@ void GetColumnSizesScan(int device,
* \param column_sizes_scan Describes the boundaries of column segments in
* sorted data
*/
void ExtractCuts(int device, Span<SketchEntry> cuts,
size_t num_cuts_per_feature, Span<Entry> sorted_data,
Span<size_t> column_sizes_scan) {
dh::LaunchN(device, cuts.size(), [=] __device__(size_t idx) {
void ExtractCuts(int device,
size_t num_cuts_per_feature,
Span<Entry const> sorted_data,
Span<size_t const> column_sizes_scan,
Span<SketchEntry> out_cuts) {
dh::LaunchN(device, out_cuts.size(), [=] __device__(size_t idx) {
// Each thread is responsible for obtaining one cut from the sorted input
size_t column_idx = idx / num_cuts_per_feature;
size_t column_size =
column_sizes_scan[column_idx + 1] - column_sizes_scan[column_idx];
size_t num_available_cuts =
min(size_t(num_cuts_per_feature), column_size);
min(static_cast<size_t>(num_cuts_per_feature), column_size);
size_t cut_idx = idx % num_cuts_per_feature;
if (cut_idx >= num_available_cuts) return;

Span<Entry> column_entries =
Span<Entry const> column_entries =
sorted_data.subspan(column_sizes_scan[column_idx], column_size);

size_t rank = (column_entries.size() * cut_idx) / num_available_cuts;
auto value = column_entries[rank].fvalue;
cuts[idx] = SketchEntry(rank, rank + 1, 1, value);
size_t rank = (column_entries.size() * cut_idx) /
static_cast<float>(num_available_cuts);
out_cuts[idx] = WQSketch::Entry(rank, rank + 1, 1,
column_entries[rank].fvalue);
});
}

Expand All @@ -170,31 +171,32 @@ void ExtractCuts(int device, Span<SketchEntry> cuts,
* \param weights_scan Inclusive scan of weights for each entry in sorted_data.
* \param column_sizes_scan Describes the boundaries of column segments in sorted data.
*/
void ExtractWeightedCuts(int device, Span<SketchEntry> cuts,
size_t num_cuts_per_feature, Span<Entry> sorted_data,
void ExtractWeightedCuts(int device,
size_t num_cuts_per_feature,
Span<Entry> sorted_data,
Span<float> weights_scan,
Span<size_t> column_sizes_scan) {
Span<size_t> column_sizes_scan,
Span<SketchEntry> cuts) {
dh::LaunchN(device, cuts.size(), [=] __device__(size_t idx) {
// Each thread is responsible for obtaining one cut from the sorted input
size_t column_idx = idx / num_cuts_per_feature;
size_t column_size =
column_sizes_scan[column_idx + 1] - column_sizes_scan[column_idx];
size_t num_available_cuts =
min(size_t(num_cuts_per_feature), column_size);
min(static_cast<size_t>(num_cuts_per_feature), column_size);
size_t cut_idx = idx % num_cuts_per_feature;
if (cut_idx >= num_available_cuts) return;

Span<Entry> column_entries =
sorted_data.subspan(column_sizes_scan[column_idx], column_size);
Span<float> column_weights =
weights_scan.subspan(column_sizes_scan[column_idx], column_size);

float total_column_weight = column_weights.back();
Span<float> column_weights_scan =
weights_scan.subspan(column_sizes_scan[column_idx], column_size);
float total_column_weight = column_weights_scan.back();
size_t sample_idx = 0;
if (cut_idx == 0) {
// First cut
sample_idx = 0;
} else if (cut_idx == num_available_cuts - 1) {
} else if (cut_idx == num_available_cuts) {
// Last cut
sample_idx = column_entries.size() - 1;
} else if (num_available_cuts == column_size) {
Expand All @@ -204,15 +206,18 @@ void ExtractWeightedCuts(int device, Span<SketchEntry> cuts,
} else {
bst_float rank = (total_column_weight * cut_idx) /
static_cast<float>(num_available_cuts);
sample_idx = thrust::upper_bound(thrust::seq, column_weights.begin(),
column_weights.end(), rank) -
column_weights.begin() - 1;
sample_idx = thrust::upper_bound(thrust::seq,
column_weights_scan.begin(),
column_weights_scan.end(),
rank) -
column_weights_scan.begin();
sample_idx =
max(size_t(0), min(sample_idx, column_entries.size() - 1));
max(static_cast<size_t>(0),
min(sample_idx, column_entries.size() - 1));
}
// repeated values will be filtered out on the CPU
bst_float rmin = sample_idx > 0 ? column_weights[sample_idx - 1] : 0;
bst_float rmax = column_weights[sample_idx];
bst_float rmin = sample_idx > 0 ? column_weights_scan[sample_idx - 1] : 0.0f;
bst_float rmax = column_weights_scan[sample_idx];
cuts[idx] = WQSketch::Entry(rmin, rmax, rmax - rmin,
column_entries[sample_idx].fvalue);
});
Expand All @@ -224,7 +229,7 @@ void ProcessBatch(int device, const SparsePage& page, size_t begin, size_t end,
dh::XGBCachingDeviceAllocator<char> alloc;
const auto& host_data = page.data.ConstHostVector();
dh::caching_device_vector<Entry> sorted_entries(host_data.begin() + begin,
host_data.begin() + end);
host_data.begin() + end);
thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
sorted_entries.end(), EntryCompareOp());

Expand All @@ -235,9 +240,10 @@ void ProcessBatch(int device, const SparsePage& page, size_t begin, size_t end,
thrust::host_vector<size_t> host_column_sizes_scan(column_sizes_scan);

dh::caching_device_vector<SketchEntry> cuts(num_columns * num_cuts);
ExtractCuts(device, {cuts.data().get(), cuts.size()}, num_cuts,
{sorted_entries.data().get(), sorted_entries.size()},
{column_sizes_scan.data().get(), column_sizes_scan.size()});
ExtractCuts(device, num_cuts,
dh::ToSpan(sorted_entries),
dh::ToSpan(column_sizes_scan),
dh::ToSpan(cuts));

// add cuts into sketches
thrust::host_vector<SketchEntry> host_cuts(cuts);
Expand All @@ -246,28 +252,49 @@ void ProcessBatch(int device, const SparsePage& page, size_t begin, size_t end,

void ProcessWeightedBatch(int device, const SparsePage& page,
Span<const float> weights, size_t begin, size_t end,
SketchContainer* sketch_container, int num_cuts,
size_t num_columns) {
SketchContainer* sketch_container, int num_cuts_per_feature,
size_t num_columns,
bool is_ranking, Span<bst_group_t const> d_group_ptr) {
dh::XGBCachingDeviceAllocator<char> alloc;
const auto& host_data = page.data.ConstHostVector();
dh::caching_device_vector<Entry> sorted_entries(host_data.begin() + begin,
host_data.begin() + end);
host_data.begin() + end);

// Binary search to assign weights to each element
dh::caching_device_vector<float> temp_weights(sorted_entries.size());
auto d_temp_weights = temp_weights.data().get();
page.offset.SetDevice(device);
auto row_ptrs = page.offset.ConstDeviceSpan();
size_t base_rowid = page.base_rowid;
dh::LaunchN(device, temp_weights.size(), [=] __device__(size_t idx) {
size_t element_idx = idx + begin;
size_t ridx = thrust::upper_bound(thrust::seq, row_ptrs.begin(),
row_ptrs.end(), element_idx) -
row_ptrs.begin() - 1;
d_temp_weights[idx] = weights[ridx + base_rowid];
});
if (is_ranking) {
CHECK_GE(d_group_ptr.size(), 2)
<< "Must have at least 1 group for ranking.";
CHECK_EQ(weights.size(), d_group_ptr.size() - 1)
<< "Weight size should equal to number of groups.";
dh::LaunchN(device, temp_weights.size(), [=] __device__(size_t idx) {
size_t element_idx = idx + begin;
size_t ridx = thrust::upper_bound(thrust::seq, row_ptrs.begin(),
row_ptrs.end(), element_idx) -
row_ptrs.begin() - 1;
auto it =
thrust::upper_bound(thrust::seq,
d_group_ptr.cbegin(), d_group_ptr.cend(),
ridx + base_rowid) - 1;
bst_group_t group = thrust::distance(d_group_ptr.cbegin(), it);
d_temp_weights[idx] = weights[group];
});
} else {
CHECK_EQ(weights.size(), page.offset.Size() - 1);
dh::LaunchN(device, temp_weights.size(), [=] __device__(size_t idx) {
size_t element_idx = idx + begin;
size_t ridx = thrust::upper_bound(thrust::seq, row_ptrs.begin(),
row_ptrs.end(), element_idx) -
row_ptrs.begin() - 1;
d_temp_weights[idx] = weights[ridx + base_rowid];
});
}

// Sort
// Sort both entries and wegihts.
thrust::sort_by_key(thrust::cuda::par(alloc), sorted_entries.begin(),
sorted_entries.end(), temp_weights.begin(),
EntryCompareOp());
Expand All @@ -287,26 +314,26 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
thrust::host_vector<size_t> host_column_sizes_scan(column_sizes_scan);

// Extract cuts
dh::caching_device_vector<SketchEntry> cuts(num_columns * num_cuts);
ExtractWeightedCuts(
device, {cuts.data().get(), cuts.size()}, num_cuts,
{sorted_entries.data().get(), sorted_entries.size()},
{temp_weights.data().get(), temp_weights.size()},
{column_sizes_scan.data().get(), column_sizes_scan.size()});
dh::caching_device_vector<SketchEntry> cuts(num_columns * num_cuts_per_feature);
ExtractWeightedCuts(device, num_cuts_per_feature,
dh::ToSpan(sorted_entries),
dh::ToSpan(temp_weights),
dh::ToSpan(column_sizes_scan),
dh::ToSpan(cuts));

// add cuts into sketches
thrust::host_vector<SketchEntry> host_cuts(cuts);
sketch_container->Push(num_cuts, host_cuts, host_column_sizes_scan);
sketch_container->Push(num_cuts_per_feature, host_cuts, host_column_sizes_scan);
}

HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
size_t sketch_batch_num_elements) {
// Configure batch size based on available memory
bool has_weights = dmat->Info().weights_.Size() > 0;
size_t num_cuts = RequiredSampleCuts(max_bins, dmat->Info().num_row_);
size_t num_cuts_per_feature = RequiredSampleCuts(max_bins, dmat->Info().num_row_);
if (sketch_batch_num_elements == 0) {
int bytes_per_element = has_weights ? 24 : 16;
size_t bytes_cuts = num_cuts * dmat->Info().num_col_ * sizeof(SketchEntry);
size_t bytes_cuts = num_cuts_per_feature * dmat->Info().num_col_ * sizeof(SketchEntry);
// use up to 80% of available space
sketch_batch_num_elements =
(dh::AvailableMemory(device) - bytes_cuts) * 0.8 / bytes_per_element;
Expand All @@ -320,15 +347,21 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
dmat->Info().weights_.SetDevice(device);
for (const auto& batch : dmat->GetBatches<SparsePage>()) {
size_t batch_nnz = batch.data.Size();
for (auto begin = 0ull; begin < batch_nnz;
begin += sketch_batch_num_elements) {
auto const& info = dmat->Info();
dh::caching_device_vector<uint32_t> groups(info.group_ptr_.cbegin(),
info.group_ptr_.cend());
for (auto begin = 0ull; begin < batch_nnz; begin += sketch_batch_num_elements) {
size_t end = std::min(batch_nnz, size_t(begin + sketch_batch_num_elements));
if (has_weights) {
bool is_ranking = CutsBuilder::UseGroup(dmat);
ProcessWeightedBatch(
device, batch, dmat->Info().weights_.ConstDeviceSpan(), begin, end,
&sketch_container, num_cuts, dmat->Info().num_col_);
&sketch_container,
num_cuts_per_feature,
dmat->Info().num_col_,
is_ranking, dh::ToSpan(groups));
} else {
ProcessBatch(device, batch, begin, end, &sketch_container, num_cuts,
ProcessBatch(device, batch, begin, end, &sketch_container, num_cuts_per_feature,
dmat->Info().num_col_);
}
}
Expand Down Expand Up @@ -383,9 +416,10 @@ void ProcessBatch(AdapterT* adapter, size_t begin, size_t end, float missing,

// Extract the cuts from all columns concurrently
dh::caching_device_vector<SketchEntry> cuts(adapter->NumColumns() * num_cuts);
ExtractCuts(adapter->DeviceIdx(), {cuts.data().get(), cuts.size()}, num_cuts,
{sorted_entries.data().get(), sorted_entries.size()},
{column_sizes_scan.data().get(), column_sizes_scan.size()});
ExtractCuts(adapter->DeviceIdx(), num_cuts,
dh::ToSpan(sorted_entries),
dh::ToSpan(column_sizes_scan),
dh::ToSpan(cuts));

// Push cuts into sketches stored in host memory
thrust::host_vector<SketchEntry> host_cuts(cuts);
Expand Down
4 changes: 2 additions & 2 deletions src/common/hist_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,11 @@ class HistogramCuts {
class CutsBuilder {
public:
using WQSketch = common::WQuantileSketch<bst_float, bst_float>;
/* \brief return whether group for ranking is used. */
static bool UseGroup(DMatrix* dmat);

protected:
HistogramCuts* p_cuts_;
/* \brief return whether group for ranking is used. */
static bool UseGroup(DMatrix* dmat);

public:
explicit CutsBuilder(HistogramCuts* p_cuts) : p_cuts_{p_cuts} {}
Expand Down
39 changes: 39 additions & 0 deletions src/data/data.cc
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,45 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t
}
}

void MetaInfo::Validate() const {
if (group_ptr_.size() != 0 && weights_.Size() != 0) {
CHECK_EQ(group_ptr_.size(), weights_.Size() + 1)
<< "Size of weights must equal to number of groups when ranking "
"group is used.";
return;
}
if (group_ptr_.size() != 0) {
CHECK_EQ(group_ptr_.back(), num_row_)
<< "Invalid group structure. Number of rows obtained from groups "
"doesn't equal to actual number of rows given by data.";
}
if (weights_.Size() != 0) {
CHECK_EQ(weights_.Size(), num_row_)
<< "Size of weights must equal to number of rows.";
return;
}
if (labels_.Size() != 0) {
CHECK_EQ(labels_.Size(), num_row_)
<< "Size of labels must equal to number of rows.";
return;
}
if (labels_lower_bound_.Size() != 0) {
CHECK_EQ(labels_lower_bound_.Size(), num_row_)
<< "Size of label_lower_bound must equal to number of rows.";
return;
}
if (labels_upper_bound_.Size() != 0) {
CHECK_EQ(labels_upper_bound_.Size(), num_row_)
<< "Size of label_upper_bound must equal to number of rows.";
return;
}
CHECK_LE(num_nonzero_, num_col_ * num_row_);
if (base_margin_.Size() != 0) {
CHECK_EQ(base_margin_.Size() % num_row_, 0)
<< "Size of base margin must be a multiple of number of rows.";
}
}

#if !defined(XGBOOST_USE_CUDA)
void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
common::AssertGPUSupport();
Expand Down
10 changes: 1 addition & 9 deletions src/learner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1048,15 +1048,7 @@ class LearnerImpl : public LearnerIO {

void ValidateDMatrix(DMatrix* p_fmat) const {
MetaInfo const& info = p_fmat->Info();
auto const& weights = info.weights_;
if (info.group_ptr_.size() != 0 && weights.Size() != 0) {
CHECK(weights.Size() == info.group_ptr_.size() - 1)
<< "\n"
<< "weights size: " << weights.Size() << ", "
<< "groups size: " << info.group_ptr_.size() -1 << ", "
<< "num rows: " << p_fmat->Info().num_row_ << "\n"
<< "Number of weights should be equal to number of groups in ranking task.";
}
info.Validate();

auto const row_based_split = [this]() {
return tparam_.dsplit == DataSplitMode::kRow ||
Expand Down
Loading