Skip to content

Commit 78d7e90

Browse files
authored
Merge pull request #1155 from borglab/decisiontree-refactor
2 parents 27c7bfe + 8e6a583 commit 78d7e90

6 files changed

+163
-72
lines changed

gtsam/discrete/DecisionTree-inl.h

+113-47
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ namespace gtsam {
5959
/** constant stored in this leaf */
6060
Y constant_;
6161

62-
/** The number of assignments contained within this leaf
62+
/** The number of assignments contained within this leaf.
6363
* Particularly useful when leaves have been pruned.
6464
*/
6565
size_t nrAssignments_;
@@ -68,7 +68,7 @@ namespace gtsam {
6868
Leaf(const Y& constant, size_t nrAssignments = 1)
6969
: constant_(constant), nrAssignments_(nrAssignments) {}
7070

71-
/** return the constant */
71+
/// Return the constant
7272
const Y& constant() const {
7373
return constant_;
7474
}
@@ -81,19 +81,19 @@ namespace gtsam {
8181
return constant_ == q.constant_;
8282
}
8383

84-
/// polymorphic equality: is q is a leaf, could be
84+
/// polymorphic equality: is q a leaf and is it the same as this leaf?
8585
bool sameLeaf(const Node& q) const override {
8686
return (q.isLeaf() && q.sameLeaf(*this));
8787
}
8888

89-
/** equality up to tolerance */
89+
/// equality up to tolerance
9090
bool equals(const Node& q, const CompareFunc& compare) const override {
9191
const Leaf* other = dynamic_cast<const Leaf*>(&q);
9292
if (!other) return false;
9393
return compare(this->constant_, other->constant_);
9494
}
9595

96-
/** print */
96+
/// print
9797
void print(const std::string& s, const LabelFormatter& labelFormatter,
9898
const ValueFormatter& valueFormatter) const override {
9999
std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl;
@@ -122,8 +122,8 @@ namespace gtsam {
122122

123123
/// Apply unary operator with assignment
124124
NodePtr apply(const UnaryAssignment& op,
125-
const Assignment<L>& choices) const override {
126-
NodePtr f(new Leaf(op(choices, constant_), nrAssignments_));
125+
const Assignment<L>& assignment) const override {
126+
NodePtr f(new Leaf(op(assignment, constant_), nrAssignments_));
127127
return f;
128128
}
129129

@@ -168,7 +168,10 @@ namespace gtsam {
168168
std::vector<NodePtr> branches_;
169169

170170
private:
171-
/** incremental allSame */
171+
/**
172+
* Incremental allSame.
173+
* Records whether all the branches are the same leaf.
174+
*/
172175
size_t allSame_;
173176

174177
using ChoicePtr = boost::shared_ptr<const Choice>;
@@ -181,9 +184,9 @@ namespace gtsam {
181184
#endif
182185
}
183186

184-
/** If all branches of a choice node f are the same, just return a branch */
187+
/// If all branches of a choice node f are the same, just return a branch.
185188
static NodePtr Unique(const ChoicePtr& f) {
186-
#ifndef DT_NO_PRUNING
189+
#ifndef GTSAM_DT_NO_PRUNING
187190
if (f->allSame_) {
188191
assert(f->branches().size() > 0);
189192
NodePtr f0 = f->branches_[0];
@@ -205,15 +208,13 @@ namespace gtsam {
205208

206209
bool isLeaf() const override { return false; }
207210

208-
/** Constructor, given choice label and mandatory expected branch count */
211+
/// Constructor, given choice label and mandatory expected branch count.
209212
Choice(const L& label, size_t count) :
210213
label_(label), allSame_(true) {
211214
branches_.reserve(count);
212215
}
213216

214-
/**
215-
* Construct from applying binary op to two Choice nodes
216-
*/
217+
/// Construct from applying binary op to two Choice nodes.
217218
Choice(const Choice& f, const Choice& g, const Binary& op) :
218219
allSame_(true) {
219220
// Choose what to do based on label
@@ -241,6 +242,7 @@ namespace gtsam {
241242
}
242243
}
243244

245+
/// Return the label of this choice node.
244246
const L& label() const {
245247
return label_;
246248
}
@@ -262,7 +264,7 @@ namespace gtsam {
262264
branches_.push_back(node);
263265
}
264266

265-
/** print (as a tree) */
267+
/// print (as a tree).
266268
void print(const std::string& s, const LabelFormatter& labelFormatter,
267269
const ValueFormatter& valueFormatter) const override {
268270
std::cout << s << " Choice(";
@@ -308,7 +310,7 @@ namespace gtsam {
308310
return (q.isLeaf() && q.sameLeaf(*this));
309311
}
310312

311-
/** equality */
313+
/// equality
312314
bool equals(const Node& q, const CompareFunc& compare) const override {
313315
const Choice* other = dynamic_cast<const Choice*>(&q);
314316
if (!other) return false;
@@ -321,7 +323,7 @@ namespace gtsam {
321323
return true;
322324
}
323325

324-
/** evaluate */
326+
/// evaluate
325327
const Y& operator()(const Assignment<L>& x) const override {
326328
#ifndef NDEBUG
327329
typename Assignment<L>::const_iterator it = x.find(label_);
@@ -336,13 +338,13 @@ namespace gtsam {
336338
return (*child)(x);
337339
}
338340

339-
/**
340-
* Construct from applying unary op to a Choice node
341-
*/
341+
/// Construct from applying unary op to a Choice node.
342342
Choice(const L& label, const Choice& f, const Unary& op) :
343343
label_(label), allSame_(true) {
344344
branches_.reserve(f.branches_.size()); // reserve space
345-
for (const NodePtr& branch : f.branches_) push_back(branch->apply(op));
345+
for (const NodePtr& branch : f.branches_) {
346+
push_back(branch->apply(op));
347+
}
346348
}
347349

348350
/**
@@ -353,37 +355,37 @@ namespace gtsam {
353355
* @param f The original choice node to apply the op on.
354356
* @param op Function to apply on the choice node. Takes Assignment and
355357
* value as arguments.
356-
* @param choices The Assignment that will go to op.
358+
* @param assignment The Assignment that will go to op.
357359
*/
358360
Choice(const L& label, const Choice& f, const UnaryAssignment& op,
359-
const Assignment<L>& choices)
361+
const Assignment<L>& assignment)
360362
: label_(label), allSame_(true) {
361363
branches_.reserve(f.branches_.size()); // reserve space
362364

363-
Assignment<L> choices_ = choices;
365+
Assignment<L> assignment_ = assignment;
364366

365367
for (size_t i = 0; i < f.branches_.size(); i++) {
366-
choices_[label_] = i; // Set assignment for label to i
368+
assignment_[label_] = i; // Set assignment for label to i
367369

368370
const NodePtr branch = f.branches_[i];
369-
push_back(branch->apply(op, choices_));
371+
push_back(branch->apply(op, assignment_));
370372

371-
// Remove the choice so we are backtracking
372-
auto choice_it = choices_.find(label_);
373-
choices_.erase(choice_it);
373+
// Remove the assignment so we are backtracking
374+
auto assignment_it = assignment_.find(label_);
375+
assignment_.erase(assignment_it);
374376
}
375377
}
376378

377-
/** apply unary operator */
379+
/// apply unary operator.
378380
NodePtr apply(const Unary& op) const override {
379381
auto r = boost::make_shared<Choice>(label_, *this, op);
380382
return Unique(r);
381383
}
382384

383385
/// Apply unary operator with assignment
384386
NodePtr apply(const UnaryAssignment& op,
385-
const Assignment<L>& choices) const override {
386-
auto r = boost::make_shared<Choice>(label_, *this, op, choices);
387+
const Assignment<L>& assignment) const override {
388+
auto r = boost::make_shared<Choice>(label_, *this, op, assignment);
387389
return Unique(r);
388390
}
389391

@@ -678,7 +680,16 @@ namespace gtsam {
678680
}
679681

680682
/****************************************************************************/
681-
// Functor performing depth-first visit without Assignment<L> argument.
683+
/**
684+
* Functor performing depth-first visit to each leaf with the leaf value as
685+
* the argument.
686+
*
687+
* NOTE: We differentiate between leaves and assignments. Concretely, a 3
688+
* binary variable tree will have 2^3=8 assignments, but based on pruning, it
689+
* can have fewer than 8 leaves. For example, if a tree has all assignment
690+
* values as 1, then pruning will cause the tree to have only 1 leaf yet 8
691+
* assignments.
692+
*/
682693
template <typename L, typename Y>
683694
struct Visit {
684695
using F = std::function<void(const Y&)>;
@@ -707,33 +718,74 @@ namespace gtsam {
707718
}
708719

709720
/****************************************************************************/
710-
// Functor performing depth-first visit with Assignment<L> argument.
721+
/**
722+
* Functor performing depth-first visit to each leaf with the Leaf object
723+
* passed as an argument.
724+
*
725+
* NOTE: We differentiate between leaves and assignments. Concretely, a 3
726+
* binary variable tree will have 2^3=8 assignments, but based on pruning, it
727+
* can have fewer than 8 leaves. For example, if a tree has all assignment values as 1,
728+
* then pruning will cause the tree to have only 1 leaf yet 8 assignments.
729+
*/
730+
template <typename L, typename Y>
731+
struct VisitLeaf {
732+
using F = std::function<void(const typename DecisionTree<L, Y>::Leaf&)>;
733+
explicit VisitLeaf(F f) : f(f) {} ///< Construct from folding function.
734+
F f; ///< folding function object.
735+
736+
/// Do a depth-first visit on the tree rooted at node.
737+
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) const {
738+
using Leaf = typename DecisionTree<L, Y>::Leaf;
739+
if (auto leaf = boost::dynamic_pointer_cast<const Leaf>(node))
740+
return f(*leaf);
741+
742+
using Choice = typename DecisionTree<L, Y>::Choice;
743+
auto choice = boost::dynamic_pointer_cast<const Choice>(node);
744+
if (!choice)
745+
throw std::invalid_argument("DecisionTree::VisitLeaf: Invalid NodePtr");
746+
for (auto&& branch : choice->branches()) (*this)(branch); // recurse!
747+
}
748+
};
749+
750+
template <typename L, typename Y>
751+
template <typename Func>
752+
void DecisionTree<L, Y>::visitLeaf(Func f) const {
753+
VisitLeaf<L, Y> visit(f);
754+
visit(root_);
755+
}
756+
757+
/****************************************************************************/
758+
/**
759+
* Functor performing depth-first visit to each leaf with the leaf's
760+
* `Assignment<L>` and value passed as arguments.
761+
*
762+
* NOTE: Follows the same pruning semantics as `visit`.
763+
*/
711764
template <typename L, typename Y>
712765
struct VisitWith {
713-
using Choices = Assignment<L>;
714-
using F = std::function<void(const Choices&, const Y&)>;
766+
using F = std::function<void(const Assignment<L>&, const Y&)>;
715767
explicit VisitWith(F f) : f(f) {} ///< Construct from folding function.
716-
Choices choices; ///< Assignment, mutating through recursion.
717-
F f; ///< folding function object.
768+
Assignment<L> assignment; ///< Assignment, mutating through recursion.
769+
F f; ///< folding function object.
718770

719771
/// Do a depth-first visit on the tree rooted at node.
720772
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) {
721773
using Leaf = typename DecisionTree<L, Y>::Leaf;
722774
if (auto leaf = boost::dynamic_pointer_cast<const Leaf>(node))
723-
return f(choices, leaf->constant());
775+
return f(assignment, leaf->constant());
724776

725777
using Choice = typename DecisionTree<L, Y>::Choice;
726778
auto choice = boost::dynamic_pointer_cast<const Choice>(node);
727779
if (!choice)
728780
throw std::invalid_argument("DecisionTree::VisitWith: Invalid NodePtr");
729781
for (size_t i = 0; i < choice->nrChoices(); i++) {
730-
choices[choice->label()] = i; // Set assignment for label to i
782+
assignment[choice->label()] = i; // Set assignment for label to i
731783

732784
(*this)(choice->branches()[i]); // recurse!
733785

734786
// Remove the assignment so we are backtracking
735-
auto choice_it = choices.find(choice->label());
736-
choices.erase(choice_it);
787+
auto choice_it = assignment.find(choice->label());
788+
assignment.erase(choice_it);
737789
}
738790
}
739791
};
@@ -763,12 +815,26 @@ namespace gtsam {
763815
}
764816

765817
/****************************************************************************/
766-
// labels is just done with a visit
818+
/**
819+
* Get (partial) labels by performing a visit.
820+
*
821+
* This method performs a depth-first search to go to every leaf and records
822+
* the key assignment that leads to that leaf. Since the tree can be pruned,
823+
* there might be a leaf at a shallower depth, which results in a partial
824+
* assignment (i.e. not all keys are specified).
825+
*
826+
* E.g. given a tree with 3 keys, there may be a branch where the 3rd key has
827+
* the same values for all the leaves. This leads to the branch being pruned
828+
* so we get a leaf which is arrived at by just the first 2 keys and their
829+
* assignments.
830+
*/
767831
template <typename L, typename Y>
768832
std::set<L> DecisionTree<L, Y>::labels() const {
769833
std::set<L> unique;
770-
auto f = [&](const Assignment<L>& choices, const Y&) {
771-
for (auto&& kv : choices) unique.insert(kv.first);
834+
auto f = [&](const Assignment<L>& assignment, const Y&) {
835+
for (auto&& kv : assignment) {
836+
unique.insert(kv.first);
837+
}
772838
};
773839
visitWith(f);
774840
return unique;
@@ -817,8 +883,8 @@ namespace gtsam {
817883
throw std::runtime_error(
818884
"DecisionTree::apply(unary op) undefined for empty tree.");
819885
}
820-
Assignment<L> choices;
821-
return DecisionTree(root_->apply(op, choices));
886+
Assignment<L> assignment;
887+
return DecisionTree(root_->apply(op, assignment));
822888
}
823889

824890
/****************************************************************************/

0 commit comments

Comments
 (0)