diff --git a/gtsam/discrete/Assignment.h b/gtsam/discrete/Assignment.h index cdbf0a2e96..90e2dbdd83 100644 --- a/gtsam/discrete/Assignment.h +++ b/gtsam/discrete/Assignment.h @@ -33,6 +33,8 @@ namespace gtsam { template class Assignment : public std::map { public: + using std::map::operator=; + void print(const std::string& s = "Assignment: ") const { std::cout << s << ": "; for (const typename Assignment::value_type& keyValue : *this) diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 01c7b689c1..627c1a5aac 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -666,8 +666,13 @@ namespace gtsam { if (!choice) throw std::invalid_argument("DecisionTree::VisitWith: Invalid NodePtr"); for (size_t i = 0; i < choice->nrChoices(); i++) { - choices[choice->label()] = i; // Set assignment for label to i + choices[choice->label()] = i; // Set assignment for label to i + (*this)(choice->branches()[i]); // recurse! + + // Remove the choice so we are backtracking + auto choice_it = choices.find(choice->label()); + choices.erase(choice_it); } } }; diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index 967023eebd..91deed6253 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -346,6 +346,44 @@ TEST(DecisionTree, visitWith) { EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); } +/* ************************************************************************** */ +// Test visit, with Choices argument. +TEST(DecisionTree, VisitWithPruned) { + // Create pruned tree + std::pair A("A", 2), B("B", 2), C("C", 2); + std::vector> labels = {C, B, A}; + std::vector nodes = {0, 0, 2, 3, 4, 4, 6, 7}; + DT tree(labels, nodes); + + std::vector> choices; + auto func = [&](const Assignment& choice, const int& d) { + choices.push_back(choice); + }; + tree.visitWith(func); + + EXPECT_LONGS_EQUAL(6, choices.size()); + + Assignment expectedAssignment; + + expectedAssignment = {{"B", 0}, {"C", 0}}; + EXPECT(expectedAssignment == choices.at(0)); + + expectedAssignment = {{"A", 0}, {"B", 1}, {"C", 0}}; + EXPECT(expectedAssignment == choices.at(1)); + + expectedAssignment = {{"A", 1}, {"B", 1}, {"C", 0}}; + EXPECT(expectedAssignment == choices.at(2)); + + expectedAssignment = {{"B", 0}, {"C", 1}}; + EXPECT(expectedAssignment == choices.at(3)); + + expectedAssignment = {{"A", 0}, {"B", 1}, {"C", 1}}; + EXPECT(expectedAssignment == choices.at(4)); + + expectedAssignment = {{"A", 1}, {"B", 1}, {"C", 1}}; + EXPECT(expectedAssignment == choices.at(5)); +} + /* ************************************************************************** */ // Test fold. TEST(DecisionTree, fold) {