diff --git a/tests/cpp/common/test_hist_util.cu b/tests/cpp/common/test_hist_util.cu index 36c18f00a9b5..5bec5b2db5c8 100644 --- a/tests/cpp/common/test_hist_util.cu +++ b/tests/cpp/common/test_hist_util.cu @@ -281,6 +281,59 @@ TEST(HistUtil, AdapterDeviceSketchMemory) { bytes_num_elements + bytes_cuts + bytes_num_columns + bytes_constant); } +TEST(HistUtil, AdapterSketchBatchMemory) { + int num_columns = 100; + int num_rows = 1000; + int num_bins = 256; + auto x = GenerateRandom(num_rows, num_columns); + auto x_device = thrust::device_vector(x); + auto adapter = AdapterFromData(x_device, num_rows, num_columns); + + dh::GlobalMemoryLogger().Clear(); + ConsoleLogger::Configure({{"verbosity", "3"}}); + common::HistogramCuts batched_cuts; + SketchContainer sketch_container(num_bins, num_columns, num_rows); + AdapterDeviceSketch(adapter.Value(), num_bins, std::numeric_limits::quiet_NaN(), + 0, &sketch_container); + ConsoleLogger::Configure({{"verbosity", "0"}}); + size_t bytes_num_elements = num_rows * num_columns * sizeof(Entry); + size_t bytes_num_columns = (num_columns + 1) * sizeof(size_t); + size_t bytes_cuts = RequiredSampleCutsTest(num_bins, num_rows) * num_columns * + sizeof(DenseCuts::WQSketch::Entry); + size_t bytes_constant = 1000; + EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), + bytes_num_elements + bytes_cuts + bytes_num_columns + bytes_constant); +} + +TEST(HistUtil, AdapterSketchBatchWeightedMemory) { + int num_columns = 100; + int num_rows = 1000; + int num_bins = 256; + auto x = GenerateRandom(num_rows, num_columns); + auto x_device = thrust::device_vector(x); + auto adapter = AdapterFromData(x_device, num_rows, num_columns); + MetaInfo info; + auto& h_weights = info.weights_.HostVector(); + h_weights.resize(num_rows); + std::fill(h_weights.begin(), h_weights.end(), 1.0f); + + dh::GlobalMemoryLogger().Clear(); + ConsoleLogger::Configure({{"verbosity", "3"}}); + common::HistogramCuts batched_cuts; + SketchContainer sketch_container(num_bins, num_columns, num_rows); + AdapterDeviceSketchWeighted(adapter.Value(), num_bins, info, + std::numeric_limits::quiet_NaN(), 0, + &sketch_container); + ConsoleLogger::Configure({{"verbosity", "0"}}); + + size_t bytes_num_elements = + num_rows * num_columns * (sizeof(Entry) + sizeof(float)); + size_t bytes_cuts = RequiredSampleCutsTest(num_bins, num_rows) * num_columns * + sizeof(DenseCuts::WQSketch::Entry); + EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), + size_t((bytes_num_elements + bytes_cuts) * 1.05)); +} + TEST(HistUtil, AdapterDeviceSketchCategorical) { int categorical_sizes[] = {2, 6, 8, 12}; int num_bins = 256;