From d5d5ecc3b3fdd954253578e8fdadfb5b5930c36d Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 31 Mar 2022 06:10:40 -0400 Subject: [PATCH 1/5] refactor DecisionTree to make a distinction between leaves and assignments --- gtsam/discrete/DecisionTree-inl.h | 106 +++++++++++++++++------------- gtsam/discrete/DecisionTree.h | 37 ++++++----- 2 files changed, 80 insertions(+), 63 deletions(-) diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 826e54b955..48da55ce1a 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -59,7 +59,7 @@ namespace gtsam { /** constant stored in this leaf */ Y constant_; - /** The number of assignments contained within this leaf + /** The number of assignments contained within this leaf. * Particularly useful when leaves have been pruned. */ size_t nrAssignments_; @@ -68,7 +68,7 @@ namespace gtsam { Leaf(const Y& constant, size_t nrAssignments = 1) : constant_(constant), nrAssignments_(nrAssignments) {} - /** return the constant */ + /// Return the constant const Y& constant() const { return constant_; } @@ -81,19 +81,19 @@ namespace gtsam { return constant_ == q.constant_; } - /// polymorphic equality: is q is a leaf, could be + /// polymorphic equality: is q a leaf and is it the same as this leaf? 
bool sameLeaf(const Node& q) const override { return (q.isLeaf() && q.sameLeaf(*this)); } - /** equality up to tolerance */ + /// equality up to tolerance bool equals(const Node& q, const CompareFunc& compare) const override { const Leaf* other = dynamic_cast(&q); if (!other) return false; return compare(this->constant_, other->constant_); } - /** print */ + /// print void print(const std::string& s, const LabelFormatter& labelFormatter, const ValueFormatter& valueFormatter) const override { std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl; @@ -122,8 +122,8 @@ namespace gtsam { /// Apply unary operator with assignment NodePtr apply(const UnaryAssignment& op, - const Assignment& choices) const override { - NodePtr f(new Leaf(op(choices, constant_), nrAssignments_)); + const Assignment& assignment) const override { + NodePtr f(new Leaf(op(assignment, constant_), nrAssignments_)); return f; } @@ -168,7 +168,10 @@ namespace gtsam { std::vector branches_; private: - /** incremental allSame */ + /** + * Incremental allSame. + * Records if all the branches are the same leaf. + */ size_t allSame_; using ChoicePtr = boost::shared_ptr; @@ -181,7 +184,7 @@ namespace gtsam { #endif } - /** If all branches of a choice node f are the same, just return a branch */ + /// If all branches of a choice node f are the same, just return a branch. static NodePtr Unique(const ChoicePtr& f) { #ifndef DT_NO_PRUNING if (f->allSame_) { @@ -205,15 +208,13 @@ namespace gtsam { bool isLeaf() const override { return false; } - /** Constructor, given choice label and mandatory expected branch count */ + /// Constructor, given choice label and mandatory expected branch count. Choice(const L& label, size_t count) : label_(label), allSame_(true) { branches_.reserve(count); } - /** - * Construct from applying binary op to two Choice nodes - */ + /// Construct from applying binary op to two Choice nodes. 
Choice(const Choice& f, const Choice& g, const Binary& op) : allSame_(true) { // Choose what to do based on label @@ -241,6 +242,7 @@ namespace gtsam { } } + /// Return the label of this choice node. const L& label() const { return label_; } @@ -262,7 +264,7 @@ namespace gtsam { branches_.push_back(node); } - /** print (as a tree) */ + /// print (as a tree). void print(const std::string& s, const LabelFormatter& labelFormatter, const ValueFormatter& valueFormatter) const override { std::cout << s << " Choice("; @@ -308,7 +310,7 @@ namespace gtsam { return (q.isLeaf() && q.sameLeaf(*this)); } - /** equality */ + /// equality bool equals(const Node& q, const CompareFunc& compare) const override { const Choice* other = dynamic_cast(&q); if (!other) return false; @@ -321,7 +323,7 @@ namespace gtsam { return true; } - /** evaluate */ + /// evaluate const Y& operator()(const Assignment& x) const override { #ifndef NDEBUG typename Assignment::const_iterator it = x.find(label_); @@ -336,13 +338,13 @@ namespace gtsam { return (*child)(x); } - /** - * Construct from applying unary op to a Choice node - */ + /// Construct from applying unary op to a Choice node. Choice(const L& label, const Choice& f, const Unary& op) : label_(label), allSame_(true) { branches_.reserve(f.branches_.size()); // reserve space - for (const NodePtr& branch : f.branches_) push_back(branch->apply(op)); + for (const NodePtr& branch : f.branches_) { + push_back(branch->apply(op)); + } } /** @@ -353,28 +355,28 @@ namespace gtsam { * @param f The original choice node to apply the op on. * @param op Function to apply on the choice node. Takes Assignment and * value as arguments. - * @param choices The Assignment that will go to op. + * @param assignment The Assignment that will go to op. 
*/ Choice(const L& label, const Choice& f, const UnaryAssignment& op, - const Assignment& choices) + const Assignment& assignment) : label_(label), allSame_(true) { branches_.reserve(f.branches_.size()); // reserve space - Assignment choices_ = choices; + Assignment assignment_ = assignment; for (size_t i = 0; i < f.branches_.size(); i++) { - choices_[label_] = i; // Set assignment for label to i + assignment_[label_] = i; // Set assignment for label to i const NodePtr branch = f.branches_[i]; - push_back(branch->apply(op, choices_)); + push_back(branch->apply(op, assignment_)); - // Remove the choice so we are backtracking - auto choice_it = choices_.find(label_); - choices_.erase(choice_it); + // Remove the assignment so we are backtracking + auto assignment_it = assignment_.find(label_); + assignment_.erase(assignment_it); } } - /** apply unary operator */ + /// apply unary operator. NodePtr apply(const Unary& op) const override { auto r = boost::make_shared(label_, *this, op); return Unique(r); @@ -382,8 +384,8 @@ namespace gtsam { /// Apply unary operator with assignment NodePtr apply(const UnaryAssignment& op, - const Assignment& choices) const override { - auto r = boost::make_shared(label_, *this, op, choices); + const Assignment& assignment) const override { + auto r = boost::make_shared(label_, *this, op, assignment); return Unique(r); } @@ -678,7 +680,14 @@ namespace gtsam { } /****************************************************************************/ - // Functor performing depth-first visit without Assignment argument. + /** + * Functor performing depth-first visit without Assignment argument. + * + * NOTE: We differentiate between leaves and assignments. Concretely, a 3 + * binary variable tree will have 2^3=8 assignments, but based on pruning, it + * can have <8 leaves. For example, if a tree has all assignment values as 1, + * then pruning will cause the tree to have only 1 leaf yet 8 assignments. 
+ */ template struct Visit { using F = std::function; @@ -707,33 +716,36 @@ namespace gtsam { } /****************************************************************************/ - // Functor performing depth-first visit with Assignment argument. + /** + * Functor performing depth-first visit with Assignment argument. + * + * NOTE: Follows the same pruning semantics as `visit`. + */ template struct VisitWith { - using Choices = Assignment; - using F = std::function; + using F = std::function&, const Y&)>; explicit VisitWith(F f) : f(f) {} ///< Construct from folding function. - Choices choices; ///< Assignment, mutating through recursion. - F f; ///< folding function object. + Assignment assignment; ///< Assignment, mutating through recursion. + F f; ///< folding function object. /// Do a depth-first visit on the tree rooted at node. void operator()(const typename DecisionTree::NodePtr& node) { using Leaf = typename DecisionTree::Leaf; if (auto leaf = boost::dynamic_pointer_cast(node)) - return f(choices, leaf->constant()); + return f(assignment, leaf->constant()); using Choice = typename DecisionTree::Choice; auto choice = boost::dynamic_pointer_cast(node); if (!choice) throw std::invalid_argument("DecisionTree::VisitWith: Invalid NodePtr"); for (size_t i = 0; i < choice->nrChoices(); i++) { - choices[choice->label()] = i; // Set assignment for label to i + assignment[choice->label()] = i; // Set assignment for label to i (*this)(choice->branches()[i]); // recurse! // Remove the choice so we are backtracking - auto choice_it = choices.find(choice->label()); - choices.erase(choice_it); + auto choice_it = assignment.find(choice->label()); + assignment.erase(choice_it); } } }; @@ -763,12 +775,14 @@ namespace gtsam { } /****************************************************************************/ - // labels is just done with a visit + // Get (partial) labels by performing a visit. 
template std::set DecisionTree::labels() const { std::set unique; - auto f = [&](const Assignment& choices, const Y&) { - for (auto&& kv : choices) unique.insert(kv.first); + auto f = [&](const Assignment& assignment, const Y&) { + for (auto&& kv : assignment) { + unique.insert(kv.first); + } }; visitWith(f); return unique; @@ -817,8 +831,8 @@ namespace gtsam { throw std::runtime_error( "DecisionTree::apply(unary op) undefined for empty tree."); } - Assignment choices; - return DecisionTree(root_->apply(op, choices)); + Assignment assignment; + return DecisionTree(root_->apply(op, assignment)); } /****************************************************************************/ diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index c0a2a7a1c6..9520d43bc4 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -105,7 +105,7 @@ namespace gtsam { virtual const Y& operator()(const Assignment& x) const = 0; virtual Ptr apply(const Unary& op) const = 0; virtual Ptr apply(const UnaryAssignment& op, - const Assignment& choices) const = 0; + const Assignment& assignment) const = 0; virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0; virtual Ptr apply_g_op_fL(const Leaf&, const Binary&) const = 0; virtual Ptr apply_g_op_fC(const Choice&, const Binary&) const = 0; @@ -153,7 +153,7 @@ namespace gtsam { /** Create a constant */ explicit DecisionTree(const Y& y); - /** Create a new leaf function splitting on a variable */ + /// Create tree with 2 assignments `y1`, `y2`, splitting on variable `label` DecisionTree(const L& label, const Y& y1, const Y& y2); /** Allow Label+Cardinality for convenience */ @@ -219,9 +219,8 @@ namespace gtsam { /// @name Standard Interface /// @{ - /** Make virtual */ - virtual ~DecisionTree() { - } + /// Make virtual + virtual ~DecisionTree() {} /// Check if tree is empty. 
bool empty() const { return !root_; } @@ -234,11 +233,13 @@ namespace gtsam { /** * @brief Visit all leaves in depth-first fashion. - * - * @param f side-effect taking a value. - * - * @note Due to pruning, leaves might not exhaust choices. - * + * + * @param f (side-effect) Function taking a value. + * + * @note Due to pruning, the number of leaves may not be the same as the + * number of assignments. E.g. if we have a tree on 2 binary variables with + * all values being 1, then there are 2^2=4 assignments, but only 1 leaf. + * * Example: * int sum = 0; * auto visitor = [&](int y) { sum += y; }; @@ -249,14 +250,16 @@ namespace gtsam { /** * @brief Visit all leaves in depth-first fashion. - * - * @param f side-effect taking an assignment and a value. - * - * @note Due to pruning, leaves might not exhaust choices. - * + * + * @param f (side-effect) Function taking an assignment and a value. + * + * @note Due to pruning, the number of leaves may not be the same as the + * number of assignments. E.g. if we have a tree on 2 binary variables with + * all values being 1, then there are 2^2=4 assignments, but only 1 leaf. + * * Example: * int sum = 0; - * auto visitor = [&](const Assignment& choices, int y) { sum += y; }; + * auto visitor = [&](const Assignment& assignment, int y) { sum += y; }; * tree.visitWith(visitor); */ template @@ -275,7 +278,7 @@ namespace gtsam { * * @note X is always passed by value. * @note Due to pruning, leaves might not exhaust choices. 
- * + * * Example: * auto add = [](const double& y, double x) { return y + x; }; * double sum = tree.fold(add, 0.0); From e81e04acf57747a9442d297292049189716d9a2c Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 31 Mar 2022 06:35:59 -0400 Subject: [PATCH 2/5] convert DT_NO_PRUNING flag to GTSAM_DT_NO_PRUNING in case we wish to expose it via cmake --- gtsam/discrete/DecisionTree-inl.h | 2 +- gtsam/discrete/tests/testAlgebraicDecisionTree.cpp | 2 +- gtsam/discrete/tests/testDecisionTree.cpp | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 48da55ce1a..ed345461c1 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -186,7 +186,7 @@ namespace gtsam { /// If all branches of a choice node f are the same, just return a branch. static NodePtr Unique(const ChoicePtr& f) { -#ifndef DT_NO_PRUNING +#ifndef GTSAM_DT_NO_PRUNING if (f->allSame_) { assert(f->branches().size() > 0); NodePtr f0 = f->branches_[0]; diff --git a/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp b/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp index c800321d63..6a3fb23884 100644 --- a/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp +++ b/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp @@ -20,7 +20,7 @@ #include // make sure we have traits #include // headers first to make sure no missing headers -//#define DT_NO_PRUNING +//#define GTSAM_DT_NO_PRUNING #include #include // for convert only #define DISABLE_TIMING diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index 14cf307a58..5ccbcf9162 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -18,7 +18,7 @@ */ // #define DT_DEBUG_MEMORY -// #define DT_NO_PRUNING +// #define GTSAM_DT_NO_PRUNING #define DISABLE_DOT #include From 039ecfc3c3bf67232ffe4da524e02ba65018fe4c Mon Sep 17 00:00:00 2001 From: Varun 
Agrawal Date: Thu, 31 Mar 2022 10:03:23 -0400 Subject: [PATCH 3/5] add new visitLeaf method that provides the leaf as the function argument --- gtsam/discrete/DecisionTree-inl.h | 36 +++++++++++++++++++++++++++++++ gtsam/discrete/DecisionTree.h | 17 +++++++++++++++ 2 files changed, 53 insertions(+) diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index ed345461c1..b3ad667212 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -715,6 +715,42 @@ namespace gtsam { visit(root_); } + /****************************************************************************/ + /** + * Functor performing depth-first visit with Leaf argument. + * + * NOTE: We differentiate between leaves and assignments. Concretely, a 3 + * binary variable tree will have 2^3=8 assignments, but based on pruning, it + * can have <8 leaves. For example, if a tree has all assignment values as 1, + * then pruning will cause the tree to have only 1 leaf yet 8 assignments. + */ + template + struct VisitLeaf { + using F = std::function::Leaf&)>; + explicit VisitLeaf(F f) : f(f) {} ///< Construct from folding function. + F f; ///< folding function object. + + /// Do a depth-first visit on the tree rooted at node. + void operator()(const typename DecisionTree::NodePtr& node) const { + using Leaf = typename DecisionTree::Leaf; + if (auto leaf = boost::dynamic_pointer_cast(node)) + return f(*leaf); + + using Choice = typename DecisionTree::Choice; + auto choice = boost::dynamic_pointer_cast(node); + if (!choice) + throw std::invalid_argument("DecisionTree::VisitLeaf: Invalid NodePtr"); + for (auto&& branch : choice->branches()) (*this)(branch); // recurse! + } + }; + + template + template + void DecisionTree::visitLeaf(Func f) const { + VisitLeaf visit(f); + visit(root_); + } + /****************************************************************************/ /** * Functor performing depth-first visit with Assignment argument. 
diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index 9520d43bc4..1f45d320b9 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -248,6 +248,23 @@ namespace gtsam { template void visit(Func f) const; + /** + * @brief Visit all leaves in depth-first fashion. + * + * @param f (side-effect) Function taking a leaf node reference. + * + * @note Due to pruning, the number of leaves may not be the same as the + * number of assignments. E.g. if we have a tree on 2 binary variables with + * all values being 1, then there are 2^2=4 assignments, but only 1 leaf. + * + * Example: + * int sum = 0; + * auto visitor = [&](const Leaf& leaf) { sum += leaf.constant(); }; + * tree.visitLeaf(visitor); + */ + template + void visitLeaf(Func f) const; + /** * @brief Visit all leaves in depth-first fashion. * From dac84e99321f992620a1004e9f873b4c14c85024 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 31 Mar 2022 10:04:00 -0400 Subject: [PATCH 4/5] update prune to new max number of assignments scheme --- gtsam/discrete/DecisionTreeFactor.cpp | 10 +++++++--- gtsam/discrete/DecisionTreeFactor.h | 7 ++++--- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index acd4d4af2b..4e16fc689e 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -287,12 +287,16 @@ namespace gtsam { cardinalities_(keys.cardinalities()) {} /* ************************************************************************ */ - DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrLeaves) const { - const size_t N = maxNrLeaves; + DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrAssignments) const { + const size_t N = maxNrAssignments; // Get the probabilities in the decision tree so we can threshold.
std::vector probabilities; - this->visit([&](const double& prob) { probabilities.emplace_back(prob); }); + this->visitLeaf([&](const Leaf& leaf) { + size_t nrAssignments = leaf.nrAssignments(); + double prob = leaf.constant(); + probabilities.insert(probabilities.end(), nrAssignments, prob); + }); // The number of probabilities can be lower than max_leaves if (probabilities.size() <= N) { diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 1f3d692921..286571ffc4 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -175,12 +175,13 @@ namespace gtsam { * * Pruning will set the leaves to be "pruned" to 0 indicating a 0 * probability. - * A leaf is pruned if it is not in the top `maxNrLeaves` values. + * An assignment is pruned if it is not in the top `maxNrAssignments` + * values. * - * @param maxNrLeaves The maximum number of leaves to keep. + * @param maxNrAssignments The maximum number of assignments to keep. * @return DecisionTreeFactor */ - DecisionTreeFactor prune(size_t maxNrLeaves) const; + DecisionTreeFactor prune(size_t maxNrAssignments) const; /// @} /// @name Wrapper support From 8e6a583e7ea6811690902da6bb523b7306c8dc77 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 12 Apr 2022 14:04:42 -0400 Subject: [PATCH 5/5] update for better docstrings --- gtsam/discrete/DecisionTree-inl.h | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index b3ad667212..99f29b8e5f 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -681,12 +681,14 @@ namespace gtsam { /****************************************************************************/ /** - * Functor performing depth-first visit without Assignment argument. + * Functor performing depth-first visit to each leaf with the leaf value as + * the argument. 
* * NOTE: We differentiate between leaves and assignments. Concretely, a 3 * binary variable tree will have 2^3=8 assignments, but based on pruning, it - * can have <8 leaves. For example, if a tree has all assignment values as 1, - * then pruning will cause the tree to have only 1 leaf yet 8 assignments. + * can have fewer than 8 leaves. For example, if a tree has all assignment + * values as 1, then pruning will cause the tree to have only 1 leaf yet 8 + * assignments. */ template struct Visit { @@ -717,7 +719,8 @@ namespace gtsam { /****************************************************************************/ /** - * Functor performing depth-first visit with Leaf argument. + * Functor performing depth-first visit to each leaf with the Leaf object + * passed as an argument. * * NOTE: We differentiate between leaves and assignments. Concretely, a 3 * binary variable tree will have 2^3=8 assignments, but based on pruning, it @@ -753,7 +756,8 @@ namespace gtsam { /****************************************************************************/ /** - * Functor performing depth-first visit with Assignment argument. + * Functor performing depth-first visit to each leaf with the leaf's + * `Assignment` and value passed as arguments. * * NOTE: Follows the same pruning semantics as `visit`. */ @@ -811,7 +815,19 @@ namespace gtsam { } /****************************************************************************/ - // Get (partial) labels by performing a visit. + /** + * Get (partial) labels by performing a visit. + * + * This method performs a depth-first search to go to every leaf and records + * the key assignment which leads to that leaf. Since the tree can be pruned, + * there might be a leaf at a lower depth which results in a partial + * assignment (i.e. not all keys are specified). + * + * E.g. given a tree with 3 keys, there may be a branch where the 3rd key has + * the same values for all the leaves.
This leads to the branch being pruned + * so we get a leaf which is arrived at by just the first 2 keys and their + * assignments. + */ template std::set DecisionTree::labels() const { std::set unique;