Skip to content

Commit

Permalink
Merge pull request #1394 from borglab/hybrid/pruning_test
Browse files Browse the repository at this point in the history
  • Loading branch information
dellaert authored Jan 19, 2023
2 parents 3460147 + b5d5745 commit d920d94
Showing 1 changed file with 12 additions and 14 deletions.
26 changes: 12 additions & 14 deletions gtsam/hybrid/tests/testHybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,33 +220,30 @@ TEST(HybridBayesNet, Optimize) {

/* ****************************************************************************/
// Test Bayes net error
TEST(HybridBayesNet, logProbability) {
TEST(HybridBayesNet, Pruning) {
Switching s(3);

HybridBayesNet::shared_ptr posterior =
s.linearizedFactorGraph.eliminateSequential();
EXPECT_LONGS_EQUAL(5, posterior->size());

HybridValues delta = posterior->optimize();
auto actualTree = posterior->logProbability(delta.continuous());
auto actualTree = posterior->evaluate(delta.continuous());

// Regression test on density tree.
std::vector<DiscreteKey> discrete_keys = {{M(0), 2}, {M(1), 2}};
std::vector<double> leaves = {1.8101301, 3.0128899, 2.8784032, 2.9825507};
std::vector<double> leaves = {6.1112424, 20.346113, 17.785849, 19.738098};
AlgebraicDecisionTree<Key> expected(discrete_keys, leaves);

// regression
EXPECT(assert_equal(expected, actualTree, 1e-6));

// logProbability on pruned Bayes net
// Prune and get probabilities
auto prunedBayesNet = posterior->prune(2);
auto prunedTree = prunedBayesNet.logProbability(delta.continuous());
auto prunedTree = prunedBayesNet.evaluate(delta.continuous());

std::vector<double> pruned_leaves = {2e50, 3.0128899, 2e50, 2.9825507};
// Regression test on pruned logProbability tree
std::vector<double> pruned_leaves = {0.0, 20.346113, 0.0, 19.738098};
AlgebraicDecisionTree<Key> expected_pruned(discrete_keys, pruned_leaves);

// regression
// TODO(dellaert): fix pruning, I have no insight in this code.
// EXPECT(assert_equal(expected_pruned, prunedTree, 1e-6));
EXPECT(assert_equal(expected_pruned, prunedTree, 1e-6));

// Verify logProbability computation and check specific logProbability value
const DiscreteValues discrete_values{{M(0), 1}, {M(1), 1}};
Expand All @@ -261,8 +258,9 @@ TEST(HybridBayesNet, logProbability) {
logProbability +=
posterior->at(4)->asDiscrete()->logProbability(hybridValues);

EXPECT_DOUBLES_EQUAL(logProbability, actualTree(discrete_values), 1e-9);
EXPECT_DOUBLES_EQUAL(logProbability, prunedTree(discrete_values), 1e-9);
double density = exp(logProbability);
EXPECT_DOUBLES_EQUAL(density, actualTree(discrete_values), 1e-9);
EXPECT_DOUBLES_EQUAL(density, prunedTree(discrete_values), 1e-9);
EXPECT_DOUBLES_EQUAL(logProbability, posterior->logProbability(hybridValues),
1e-9);
}
Expand Down

0 comments on commit d920d94

Please sign in to comment.