Skip to content

Commit

Permalink
Setup tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Apr 16, 2020
1 parent 33ec6b5 commit eb34a74
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 32 deletions.
14 changes: 0 additions & 14 deletions tests/cpp/common/test_hist_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -445,19 +445,5 @@ TEST(HistUtil, SparseIndexBinData) {
}
}
}

// Smoke test: runs DeviceSketch on a randomly generated DMatrix whose weight
// vector comes from GenerateRandomWeights(kGroups) and whose "group" info
// holds kGroups equally sized groups. Nothing is asserted on the result; the
// test only exercises the call.
TEST(HistUtil, DeviceSketchFromGroupWeights) {
  size_t constexpr kRows = 10000, kCols = 100, kBins = 256;
  size_t constexpr kGroups = 10;
  auto dmat = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix();
  dmat->Info().weights_.HostVector() = GenerateRandomWeights(kGroups);

  // Every group gets kRows / kGroups rows, so the sizes add up to kRows.
  std::vector<bst_group_t> group_sizes(
      kGroups, static_cast<bst_group_t>(kRows / kGroups));
  dmat->Info().SetInfo("group", group_sizes.data(), DataType::kUInt32,
                       kGroups);

  DeviceSketch(0, dmat.get(), kBins, 0);
}

} // namespace common
} // namespace xgboost
21 changes: 21 additions & 0 deletions tests/cpp/common/test_hist_util.cu
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,28 @@ TEST(HistUtil, SketchingEquivalent) {
EXPECT_EQ(dmat_cuts.MinValues(), adapter_cuts.MinValues());
}
}
}

// Sketches the same DMatrix twice: first with kGroups weights that are all
// 1.0f plus "group" info describing kGroups equally sized groups, then again
// after the weight vector has been emptied. The test passes when both
// HistogramCuts report the same number of values, min values and pointers.
TEST(HistUtil, DeviceSketchFromGroupWeights) {
  size_t constexpr kRows = 10000, kCols = 100, kBins = 256;
  size_t constexpr kGroups = 10;
  auto dmat = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix();

  // Keep a reference to the host-side weights so they can be cleared later.
  auto& weights = dmat->Info().weights_.HostVector();
  weights.assign(kGroups, 1.0f);

  // Every group gets kRows / kGroups rows, so the sizes add up to kRows.
  std::vector<bst_group_t> group_sizes(
      kGroups, static_cast<bst_group_t>(kRows / kGroups));
  dmat->Info().SetInfo("group", group_sizes.data(), DataType::kUInt32,
                       kGroups);
  HistogramCuts weighted_cuts = DeviceSketch(0, dmat.get(), kBins, 0);

  // Second run: same data and groups, but with no weights at all.
  weights.clear();
  HistogramCuts cuts = DeviceSketch(0, dmat.get(), kBins, 0);

  ASSERT_EQ(cuts.Values().size(), weighted_cuts.Values().size());
  ASSERT_EQ(cuts.MinValues().size(), weighted_cuts.MinValues().size());
  ASSERT_EQ(cuts.Ptrs().size(), weighted_cuts.Ptrs().size());
}
} // namespace common
} // namespace xgboost
61 changes: 43 additions & 18 deletions tests/python-gpu/test_gpu_ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import urllib.request
import zipfile


class TestRanking(unittest.TestCase):
@classmethod
def setUpClass(cls):
Expand All @@ -22,7 +23,7 @@ def setUpClass(cls):
target = cls.dpath + '/MQ2008.zip'

if os.path.exists(cls.dpath) and os.path.exists(target):
print ("Skipping dataset download...")
print("Skipping dataset download...")
else:
urllib.request.urlretrieve(url=src, filename=target)
with zipfile.ZipFile(target, 'r') as f:
Expand Down Expand Up @@ -50,17 +51,30 @@ def setUpClass(cls):
cls.qid_test = qid_test
cls.qid_valid = qid_valid

def setup_weighted(x, y, groups):
    """Build a DMatrix from ``x``/``y`` with group info and unit weights.

    ``groups`` holds one id per row. Each run of consecutive equal ids is
    treated as one group; the run lengths are passed to ``set_group`` and a
    weight of 1.0 is assigned to every group.
    """
    dmat = xgboost.DMatrix(x, y)
    # Length of each run of consecutive identical ids in ``groups``.
    run_lengths = [sum(1 for _ in run)
                   for _key, run in itertools.groupby(groups)]
    dmat.set_group(run_lengths)
    # One weight per group, not per row.
    unit_weights = np.ones((len(run_lengths),))
    dmat.set_weight(unit_weights)
    return dmat

cls.dtrain_w = setup_weighted(x_train, y_train, qid_train)
cls.dtest_w = setup_weighted(x_test, y_test, qid_test)
cls.dvalid_w = setup_weighted(x_valid, y_valid, qid_valid)

# model training parameters
cls.params = {'booster': 'gbtree',
'tree_method': 'gpu_hist',
'gpu_id': 0,
'predictor': 'gpu_predictor'
}
'predictor': 'gpu_predictor'}
cls.cpu_params = {'booster': 'gbtree',
'tree_method': 'hist',
'gpu_id': -1,
'predictor': 'cpu_predictor'
}
'predictor': 'cpu_predictor'}

@classmethod
def tearDownClass(cls):
Expand All @@ -81,30 +95,41 @@ def __test_training_with_rank_objective(cls, rank_objective, metric_name, tolera
# specify validations set to watch performance
watchlist = [(cls.dtest, 'eval'), (cls.dtrain, 'train')]

num_trees=2500
check_metric_improvement_rounds=10
num_trees = 2500
check_metric_improvement_rounds = 10

evals_result = {}
cls.params['objective'] = rank_objective
cls.params['eval_metric'] = metric_name
bst = xgboost.train(cls.params, cls.dtrain, num_boost_round=num_trees,
early_stopping_rounds=check_metric_improvement_rounds,
evals=watchlist, evals_result=evals_result)
bst = xgboost.train(
cls.params, cls.dtrain, num_boost_round=num_trees,
early_stopping_rounds=check_metric_improvement_rounds,
evals=watchlist, evals_result=evals_result)
gpu_map_metric = evals_result['train'][metric_name][-1]

evals_result = {}
cls.cpu_params['objective'] = rank_objective
cls.cpu_params['eval_metric'] = metric_name
bstc = xgboost.train(cls.cpu_params, cls.dtrain, num_boost_round=num_trees,
early_stopping_rounds=check_metric_improvement_rounds,
evals=watchlist, evals_result=evals_result)
bstc = xgboost.train(
cls.cpu_params, cls.dtrain, num_boost_round=num_trees,
early_stopping_rounds=check_metric_improvement_rounds,
evals=watchlist, evals_result=evals_result)
cpu_map_metric = evals_result['train'][metric_name][-1]

print("{0} gpu {1} metric {2}".format(rank_objective, metric_name, gpu_map_metric))
print("{0} cpu {1} metric {2}".format(rank_objective, metric_name, cpu_map_metric))
print("gpu best score {0} cpu best score {1}".format(bst.best_score, bstc.best_score))
assert np.allclose(gpu_map_metric, cpu_map_metric, tolerance, tolerance)
assert np.allclose(bst.best_score, bstc.best_score, tolerance, tolerance)
assert np.allclose(gpu_map_metric, cpu_map_metric, tolerance,
tolerance)
assert np.allclose(bst.best_score, bstc.best_score, tolerance,
tolerance)

evals_result_weighted = {}
watchlist = [(cls.dtest_w, 'eval'), (cls.dtrain_w, 'train')]
bst_w = xgboost.train(
cls.params, cls.dtrain_w, num_boost_round=num_trees,
early_stopping_rounds=check_metric_improvement_rounds,
evals=watchlist, evals_result=evals_result_weighted)
weighted_metric = evals_result_weighted['train'][metric_name][-1]
assert np.allclose(bst_w.best_score, bst.best_score)
assert np.allclose(weighted_metric, gpu_map_metric)

def test_training_rank_pairwise_map_metric(self):
"""
Expand Down

0 comments on commit eb34a74

Please sign in to comment.