@@ -59,7 +59,7 @@ namespace gtsam {
59
59
/* * constant stored in this leaf */
60
60
Y constant_;
61
61
62
- /* * The number of assignments contained within this leaf
62
+ /* * The number of assignments contained within this leaf.
63
63
* Particularly useful when leaves have been pruned.
64
64
*/
65
65
size_t nrAssignments_;
@@ -68,7 +68,7 @@ namespace gtsam {
68
68
Leaf (const Y& constant, size_t nrAssignments = 1 )
69
69
: constant_(constant), nrAssignments_(nrAssignments) {}
70
70
71
- /* * return the constant */
71
+ // / Return the constant
72
72
const Y& constant () const {
73
73
return constant_;
74
74
}
@@ -81,19 +81,19 @@ namespace gtsam {
81
81
return constant_ == q.constant_ ;
82
82
}
83
83
84
- // / polymorphic equality: is q is a leaf, could be
84
+ // / polymorphic equality: is q a leaf and is it the same as this leaf?
85
85
bool sameLeaf (const Node& q) const override {
86
86
return (q.isLeaf () && q.sameLeaf (*this ));
87
87
}
88
88
89
- /* * equality up to tolerance */
89
+ // / equality up to tolerance
90
90
bool equals (const Node& q, const CompareFunc& compare) const override {
91
91
const Leaf* other = dynamic_cast <const Leaf*>(&q);
92
92
if (!other) return false ;
93
93
return compare (this ->constant_ , other->constant_ );
94
94
}
95
95
96
- /* * print */
96
+ // / print
97
97
void print (const std::string& s, const LabelFormatter& labelFormatter,
98
98
const ValueFormatter& valueFormatter) const override {
99
99
std::cout << s << " Leaf " << valueFormatter (constant_) << std::endl;
@@ -122,8 +122,8 @@ namespace gtsam {
122
122
123
123
// / Apply unary operator with assignment
124
124
NodePtr apply (const UnaryAssignment& op,
125
- const Assignment<L>& choices ) const override {
126
- NodePtr f (new Leaf (op (choices , constant_), nrAssignments_));
125
+ const Assignment<L>& assignment ) const override {
126
+ NodePtr f (new Leaf (op (assignment , constant_), nrAssignments_));
127
127
return f;
128
128
}
129
129
@@ -168,7 +168,10 @@ namespace gtsam {
168
168
std::vector<NodePtr> branches_;
169
169
170
170
private:
171
- /* * incremental allSame */
171
+ /* *
172
+ * Incremental allSame.
173
+ * Records if all the branches are the same leaf.
174
+ */
172
175
size_t allSame_;
173
176
174
177
using ChoicePtr = boost::shared_ptr<const Choice>;
@@ -181,9 +184,9 @@ namespace gtsam {
181
184
#endif
182
185
}
183
186
184
- /* * If all branches of a choice node f are the same, just return a branch */
187
+ // / If all branches of a choice node f are the same, just return a branch.
185
188
static NodePtr Unique (const ChoicePtr& f) {
186
- #ifndef DT_NO_PRUNING
189
+ #ifndef GTSAM_DT_NO_PRUNING
187
190
if (f->allSame_ ) {
188
191
assert (f->branches ().size () > 0 );
189
192
NodePtr f0 = f->branches_ [0 ];
@@ -205,15 +208,13 @@ namespace gtsam {
205
208
206
209
bool isLeaf () const override { return false ; }
207
210
208
- /* * Constructor, given choice label and mandatory expected branch count */
211
+ // / Constructor, given choice label and mandatory expected branch count.
209
212
Choice (const L& label, size_t count) :
210
213
label_ (label), allSame_(true ) {
211
214
branches_.reserve (count);
212
215
}
213
216
214
- /* *
215
- * Construct from applying binary op to two Choice nodes
216
- */
217
+ // / Construct from applying binary op to two Choice nodes.
217
218
Choice (const Choice& f, const Choice& g, const Binary& op) :
218
219
allSame_ (true ) {
219
220
// Choose what to do based on label
@@ -241,6 +242,7 @@ namespace gtsam {
241
242
}
242
243
}
243
244
245
+ // / Return the label of this choice node.
244
246
const L& label () const {
245
247
return label_;
246
248
}
@@ -262,7 +264,7 @@ namespace gtsam {
262
264
branches_.push_back (node);
263
265
}
264
266
265
- /* * print (as a tree) */
267
+ // / print (as a tree).
266
268
void print (const std::string& s, const LabelFormatter& labelFormatter,
267
269
const ValueFormatter& valueFormatter) const override {
268
270
std::cout << s << " Choice(" ;
@@ -308,7 +310,7 @@ namespace gtsam {
308
310
return (q.isLeaf () && q.sameLeaf (*this ));
309
311
}
310
312
311
- /* * equality */
313
+ // / equality
312
314
bool equals (const Node& q, const CompareFunc& compare) const override {
313
315
const Choice* other = dynamic_cast <const Choice*>(&q);
314
316
if (!other) return false ;
@@ -321,7 +323,7 @@ namespace gtsam {
321
323
return true ;
322
324
}
323
325
324
- /* * evaluate */
326
+ // / evaluate
325
327
const Y& operator ()(const Assignment<L>& x) const override {
326
328
#ifndef NDEBUG
327
329
typename Assignment<L>::const_iterator it = x.find (label_);
@@ -336,13 +338,13 @@ namespace gtsam {
336
338
return (*child)(x);
337
339
}
338
340
339
- /* *
340
- * Construct from applying unary op to a Choice node
341
- */
341
+ // / Construct from applying unary op to a Choice node.
342
342
Choice (const L& label, const Choice& f, const Unary& op) :
343
343
label_ (label), allSame_(true ) {
344
344
branches_.reserve (f.branches_ .size ()); // reserve space
345
- for (const NodePtr& branch : f.branches_ ) push_back (branch->apply (op));
345
+ for (const NodePtr& branch : f.branches_ ) {
346
+ push_back (branch->apply (op));
347
+ }
346
348
}
347
349
348
350
/* *
@@ -353,37 +355,37 @@ namespace gtsam {
353
355
* @param f The original choice node to apply the op on.
354
356
* @param op Function to apply on the choice node. Takes Assignment and
355
357
* value as arguments.
356
- * @param choices The Assignment that will go to op.
358
+ * @param assignment The Assignment that will go to op.
357
359
*/
358
360
Choice (const L& label, const Choice& f, const UnaryAssignment& op,
359
- const Assignment<L>& choices )
361
+ const Assignment<L>& assignment )
360
362
: label_(label), allSame_(true ) {
361
363
branches_.reserve (f.branches_ .size ()); // reserve space
362
364
363
- Assignment<L> choices_ = choices ;
365
+ Assignment<L> assignment_ = assignment ;
364
366
365
367
for (size_t i = 0 ; i < f.branches_ .size (); i++) {
366
- choices_ [label_] = i; // Set assignment for label to i
368
+ assignment_ [label_] = i; // Set assignment for label to i
367
369
368
370
const NodePtr branch = f.branches_ [i];
369
- push_back (branch->apply (op, choices_ ));
371
+ push_back (branch->apply (op, assignment_ ));
370
372
371
- // Remove the choice so we are backtracking
372
- auto choice_it = choices_ .find (label_);
373
- choices_ .erase (choice_it );
373
+ // Remove the assignment so we are backtracking
374
+ auto assignment_it = assignment_ .find (label_);
375
+ assignment_ .erase (assignment_it );
374
376
}
375
377
}
376
378
377
- /* * apply unary operator */
379
+ // / apply unary operator.
378
380
NodePtr apply (const Unary& op) const override {
379
381
auto r = boost::make_shared<Choice>(label_, *this , op);
380
382
return Unique (r);
381
383
}
382
384
383
385
// / Apply unary operator with assignment
384
386
NodePtr apply (const UnaryAssignment& op,
385
- const Assignment<L>& choices ) const override {
386
- auto r = boost::make_shared<Choice>(label_, *this , op, choices );
387
+ const Assignment<L>& assignment ) const override {
388
+ auto r = boost::make_shared<Choice>(label_, *this , op, assignment );
387
389
return Unique (r);
388
390
}
389
391
@@ -678,7 +680,16 @@ namespace gtsam {
678
680
}
679
681
680
682
/* ***************************************************************************/
681
- // Functor performing depth-first visit without Assignment<L> argument.
683
+ /* *
684
+ * Functor performing depth-first visit to each leaf with the leaf value as
685
+ * the argument.
686
+ *
687
+ * NOTE: We differentiate between leaves and assignments. Concretely, a 3
688
+ * binary variable tree will have 2^3=8 assignments, but based on pruning, it
689
+ * can have less than 8 leaves. For example, if a tree has all assignment
690
+ * values as 1, then pruning will cause the tree to have only 1 leaf yet 8
691
+ * assignments.
692
+ */
682
693
template <typename L, typename Y>
683
694
struct Visit {
684
695
using F = std::function<void (const Y&)>;
@@ -707,33 +718,74 @@ namespace gtsam {
707
718
}
708
719
709
720
/* ***************************************************************************/
710
- // Functor performing depth-first visit with Assignment<L> argument.
721
+ /* *
722
+ * Functor performing depth-first visit to each leaf with the Leaf object
723
+ * passed as an argument.
724
+ *
725
+ * NOTE: We differentiate between leaves and assignments. Concretely, a 3
726
+ * binary variable tree will have 2^3=8 assignments, but based on pruning, it
727
+ * can have <8 leaves. For example, if a tree has all assignment values as 1,
728
+ * then pruning will cause the tree to have only 1 leaf yet 8 assignments.
729
+ */
730
+ template <typename L, typename Y>
731
+ struct VisitLeaf {
732
+ using F = std::function<void (const typename DecisionTree<L, Y>::Leaf&)>;
733
+ explicit VisitLeaf (F f) : f(f) {} // /< Construct from folding function.
734
+ F f; // /< folding function object.
735
+
736
+ // / Do a depth-first visit on the tree rooted at node.
737
+ void operator ()(const typename DecisionTree<L, Y>::NodePtr& node) const {
738
+ using Leaf = typename DecisionTree<L, Y>::Leaf;
739
+ if (auto leaf = boost::dynamic_pointer_cast<const Leaf>(node))
740
+ return f (*leaf);
741
+
742
+ using Choice = typename DecisionTree<L, Y>::Choice;
743
+ auto choice = boost::dynamic_pointer_cast<const Choice>(node);
744
+ if (!choice)
745
+ throw std::invalid_argument (" DecisionTree::VisitLeaf: Invalid NodePtr" );
746
+ for (auto && branch : choice->branches ()) (*this )(branch); // recurse!
747
+ }
748
+ };
749
+
750
+ template <typename L, typename Y>
751
+ template <typename Func>
752
+ void DecisionTree<L, Y>::visitLeaf(Func f) const {
753
+ VisitLeaf<L, Y> visit (f);
754
+ visit (root_);
755
+ }
756
+
757
+ /* ***************************************************************************/
758
+ /* *
759
+ * Functor performing depth-first visit to each leaf with the leaf's
760
+ * `Assignment<L>` and value passed as arguments.
761
+ *
762
+ * NOTE: Follows the same pruning semantics as `visit`.
763
+ */
711
764
template <typename L, typename Y>
712
765
struct VisitWith {
713
- using Choices = Assignment<L>;
714
- using F = std::function<void (const Choices&, const Y&)>;
766
+ using F = std::function<void (const Assignment<L>&, const Y&)>;
715
767
explicit VisitWith (F f) : f(f) {} // /< Construct from folding function.
716
- Choices choices ; // /< Assignment, mutating through recursion.
717
- F f; // /< folding function object.
768
+ Assignment<L> assignment ; // /< Assignment, mutating through recursion.
769
+ F f; // /< folding function object.
718
770
719
771
// / Do a depth-first visit on the tree rooted at node.
720
772
void operator ()(const typename DecisionTree<L, Y>::NodePtr& node) {
721
773
using Leaf = typename DecisionTree<L, Y>::Leaf;
722
774
if (auto leaf = boost::dynamic_pointer_cast<const Leaf>(node))
723
- return f (choices , leaf->constant ());
775
+ return f (assignment , leaf->constant ());
724
776
725
777
using Choice = typename DecisionTree<L, Y>::Choice;
726
778
auto choice = boost::dynamic_pointer_cast<const Choice>(node);
727
779
if (!choice)
728
780
throw std::invalid_argument (" DecisionTree::VisitWith: Invalid NodePtr" );
729
781
for (size_t i = 0 ; i < choice->nrChoices (); i++) {
730
- choices [choice->label ()] = i; // Set assignment for label to i
782
+ assignment [choice->label ()] = i; // Set assignment for label to i
731
783
732
784
(*this )(choice->branches ()[i]); // recurse!
733
785
734
786
// Remove the choice so we are backtracking
735
- auto choice_it = choices .find (choice->label ());
736
- choices .erase (choice_it);
787
+ auto choice_it = assignment .find (choice->label ());
788
+ assignment .erase (choice_it);
737
789
}
738
790
}
739
791
};
@@ -763,12 +815,26 @@ namespace gtsam {
763
815
}
764
816
765
817
/* ***************************************************************************/
766
- // labels is just done with a visit
818
+ /* *
819
+ * Get (partial) labels by performing a visit.
820
+ *
821
+ * This method performs a depth-first search to go to every leaf and records
822
+ * the keys assignment which leads to that leaf. Since the tree can be pruned,
823
+ * there might be a leaf at a lower depth which results in a partial
824
+ * assignment (i.e. not all keys are specified).
825
+ *
826
+ * E.g. given a tree with 3 keys, there may be a branch where the 3rd key has
827
+ * the same values for all the leaves. This leads to the branch being pruned
828
+ * so we get a leaf which is arrived at by just the first 2 keys and their
829
+ * assignments.
830
+ */
767
831
template <typename L, typename Y>
768
832
std::set<L> DecisionTree<L, Y>::labels() const {
769
833
std::set<L> unique;
770
- auto f = [&](const Assignment<L>& choices, const Y&) {
771
- for (auto && kv : choices) unique.insert (kv.first );
834
+ auto f = [&](const Assignment<L>& assignment, const Y&) {
835
+ for (auto && kv : assignment) {
836
+ unique.insert (kv.first );
837
+ }
772
838
};
773
839
visitWith (f);
774
840
return unique;
@@ -817,8 +883,8 @@ namespace gtsam {
817
883
throw std::runtime_error (
818
884
" DecisionTree::apply(unary op) undefined for empty tree." );
819
885
}
820
- Assignment<L> choices ;
821
- return DecisionTree (root_->apply (op, choices ));
886
+ Assignment<L> assignment ;
887
+ return DecisionTree (root_->apply (op, assignment ));
822
888
}
823
889
824
890
/* ***************************************************************************/
0 commit comments