Skip to content

Commit

Permalink
Add tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jun 10, 2020
1 parent d327f60 commit fbacbc1
Showing 1 changed file with 53 additions and 0 deletions.
53 changes: 53 additions & 0 deletions tests/cpp/common/test_hist_util.cu
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,59 @@ TEST(HistUtil, AdapterDeviceSketchMemory) {
bytes_num_elements + bytes_cuts + bytes_num_columns + bytes_constant);
}

TEST(HistUtil, AdapterSketchBatchMemory) {
int num_columns = 100;
int num_rows = 1000;
int num_bins = 256;
auto x = GenerateRandom(num_rows, num_columns);
auto x_device = thrust::device_vector<float>(x);
auto adapter = AdapterFromData(x_device, num_rows, num_columns);

dh::GlobalMemoryLogger().Clear();
ConsoleLogger::Configure({{"verbosity", "3"}});
common::HistogramCuts batched_cuts;
SketchContainer sketch_container(num_bins, num_columns, num_rows);
AdapterDeviceSketch(adapter.Value(), num_bins, std::numeric_limits<float>::quiet_NaN(),
0, &sketch_container);
ConsoleLogger::Configure({{"verbosity", "0"}});
size_t bytes_num_elements = num_rows * num_columns * sizeof(Entry);
size_t bytes_num_columns = (num_columns + 1) * sizeof(size_t);
size_t bytes_cuts = RequiredSampleCutsTest(num_bins, num_rows) * num_columns *
sizeof(DenseCuts::WQSketch::Entry);
size_t bytes_constant = 1000;
EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(),
bytes_num_elements + bytes_cuts + bytes_num_columns + bytes_constant);
}

TEST(HistUtil, AdapterSketchBatchWeightedMemory) {
int num_columns = 100;
int num_rows = 1000;
int num_bins = 256;
auto x = GenerateRandom(num_rows, num_columns);
auto x_device = thrust::device_vector<float>(x);
auto adapter = AdapterFromData(x_device, num_rows, num_columns);
MetaInfo info;
auto& h_weights = info.weights_.HostVector();
h_weights.resize(num_rows);
std::fill(h_weights.begin(), h_weights.end(), 1.0f);

dh::GlobalMemoryLogger().Clear();
ConsoleLogger::Configure({{"verbosity", "3"}});
common::HistogramCuts batched_cuts;
SketchContainer sketch_container(num_bins, num_columns, num_rows);
AdapterDeviceSketchWeighted(adapter.Value(), num_bins, info,
std::numeric_limits<float>::quiet_NaN(), 0,
&sketch_container);
ConsoleLogger::Configure({{"verbosity", "0"}});

size_t bytes_num_elements =
num_rows * num_columns * (sizeof(Entry) + sizeof(float));
size_t bytes_cuts = RequiredSampleCutsTest(num_bins, num_rows) * num_columns *
sizeof(DenseCuts::WQSketch::Entry);
EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(),
size_t((bytes_num_elements + bytes_cuts) * 1.05));
}

TEST(HistUtil, AdapterDeviceSketchCategorical) {
int categorical_sizes[] = {2, 6, 8, 12};
int num_bins = 256;
Expand Down

0 comments on commit fbacbc1

Please sign in to comment.