Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DecisionTree refactor #1000

Merged
merged 27 commits into from
Jan 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
94f2135
fix decision tree equality and move default constructor to public
varunagrawal Dec 29, 2021
ddaf960
add formatting capabilities to DecisionTree
varunagrawal Dec 29, 2021
8b5d93a
revert incorrect change
varunagrawal Dec 29, 2021
bb6e489
new DecisionTree constructor and methods that takes an op to convert …
varunagrawal Dec 29, 2021
315b10b
minor format
varunagrawal Dec 29, 2021
28071ed
added SFINAE methods for Leaf node equality checks
varunagrawal Dec 30, 2021
f1dedca
replace dot with DOT to prevent collision with vector dot product
varunagrawal Dec 30, 2021
1c76de4
minor fix
varunagrawal Dec 30, 2021
573d0d1
undo change to test
varunagrawal Dec 30, 2021
ed83908
formatter passed as reference and added a default formatter method
varunagrawal Dec 30, 2021
9982057
undo dot changes
varunagrawal Dec 30, 2021
b24da83
add comparator as argument to equals method
varunagrawal Dec 30, 2021
26c48a8
address more review comments
varunagrawal Dec 30, 2021
731cff7
rename comparator to compare and capture tol in the function lambda.
varunagrawal Dec 30, 2021
7f3f332
Removed copy/paste convert
dellaert Jan 2, 2022
78f8cc9
Define empty and check for it in apply variants
dellaert Jan 2, 2022
db3cb4d
Refactor print, equals, convert
dellaert Jan 2, 2022
5c4038c
Fixed dot to have right arguments
dellaert Jan 2, 2022
6c23fd1
Renamed protected method convert -> convertFrom
dellaert Jan 2, 2022
6364254
Fix compile error on windows
dellaert Jan 3, 2022
a9b2c32
Move DefaultFormatter to base class and add defaults. Also replace Su…
varunagrawal Jan 3, 2022
174490e
kill commented out code
varunagrawal Jan 3, 2022
cfb6011
replace typedef with using and improve docstrings
varunagrawal Jan 3, 2022
022b719
Undo DefaultFormatter change
varunagrawal Jan 3, 2022
8a28ac2
Merge pull request #1002 from borglab/feature/decision_tree_2
dellaert Jan 3, 2022
6cd3eeb
Some small doc changes
dellaert Jan 3, 2022
0631193
Added test and fixed constructor
dellaert Jan 3, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 57 additions & 19 deletions gtsam/discrete/AlgebraicDecisionTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,22 @@ namespace gtsam {
* TODO: consider eliminating this class altogether?
*/
template<typename L>
class AlgebraicDecisionTree: public DecisionTree<L, double> {
class GTSAM_EXPORT AlgebraicDecisionTree: public DecisionTree<L, double> {
/**
* @brief Default method used by `labelFormatter` or `valueFormatter` when printing.
*
* @param x The value passed to format.
* @return std::string
*/
static std::string DefaultFormatter(const L& x) {
dellaert marked this conversation as resolved.
Show resolved Hide resolved
std::stringstream ss;
ss << x;
return ss.str();
}

public:
public:

typedef DecisionTree<L, double> Super;
using Base = DecisionTree<L, double>;

/** The Real ring with addition and multiplication */
struct Ring {
Expand Down Expand Up @@ -60,57 +71,66 @@ namespace gtsam {
};

AlgebraicDecisionTree() :
Super(1.0) {
Base(1.0) {
}

AlgebraicDecisionTree(const Super& add) :
Super(add) {
AlgebraicDecisionTree(const Base& add) :
Base(add) {
}

/** Create a new leaf function splitting on a variable */
AlgebraicDecisionTree(const L& label, double y1, double y2) :
Super(label, y1, y2) {
Base(label, y1, y2) {
}

/** Create a new leaf function splitting on a variable */
AlgebraicDecisionTree(const typename Super::LabelC& labelC, double y1, double y2) :
Super(labelC, y1, y2) {
AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1, double y2) :
Base(labelC, y1, y2) {
}

/** Create from keys and vector table */
AlgebraicDecisionTree //
(const std::vector<typename Super::LabelC>& labelCs, const std::vector<double>& ys) {
this->root_ = Super::create(labelCs.begin(), labelCs.end(), ys.begin(),
(const std::vector<typename Base::LabelC>& labelCs, const std::vector<double>& ys) {
this->root_ = Base::create(labelCs.begin(), labelCs.end(), ys.begin(),
ys.end());
}

/** Create from keys and string table */
AlgebraicDecisionTree //
(const std::vector<typename Super::LabelC>& labelCs, const std::string& table) {
(const std::vector<typename Base::LabelC>& labelCs, const std::string& table) {
// Convert string to doubles
std::vector<double> ys;
std::istringstream iss(table);
std::copy(std::istream_iterator<double>(iss),
std::istream_iterator<double>(), std::back_inserter(ys));

// now call recursive Create
this->root_ = Super::create(labelCs.begin(), labelCs.end(), ys.begin(),
this->root_ = Base::create(labelCs.begin(), labelCs.end(), ys.begin(),
ys.end());
}

/** Create a new function splitting on a variable */
template<typename Iterator>
AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label) :
Super(nullptr) {
Base(nullptr) {
this->root_ = compose(begin, end, label);
}

/** Convert */
/**
* Convert labels from type M to type L.
*
* @param other: The AlgebraicDecisionTree with label type M to convert.
* @param map: Map from label type M to label type L.
*/
template<typename M>
AlgebraicDecisionTree(const AlgebraicDecisionTree<M>& other,
const std::map<M, L>& map) {
this->root_ = this->template convert<M, double>(other.root_, map,
Ring::id);
const std::map<M, L>& map) {
dellaert marked this conversation as resolved.
Show resolved Hide resolved
// Functor for label conversion so we can use `convertFrom`.
std::function<L(const M&)> L_of_M = [&map](const M& label) -> L {
return map.at(label);
};
std::function<double(const double&)> op = Ring::id;
this->root_ = this->template convertFrom(other.root_, L_of_M, op);
}

/** sum */
Expand All @@ -134,10 +154,28 @@ namespace gtsam {
}

/** sum out variable */
AlgebraicDecisionTree sum(const typename Super::LabelC& labelC) const {
AlgebraicDecisionTree sum(const typename Base::LabelC& labelC) const {
return this->combine(labelC, &Ring::add);
}

/// print method customized to value type `double`.
void print(const std::string& s,
const typename Base::LabelFormatter& labelFormatter =
&DefaultFormatter) const {
auto valueFormatter = [](const double& v) {
return (boost::format("%4.2g") % v).str();
};
Base::print(s, labelFormatter, valueFormatter);
}

/// Equality method customized to value type `double`.
bool equals(const AlgebraicDecisionTree& other, double tol = 1e-9) const {
// lambda for comparison of two doubles upto some tolerance.
auto compare = [tol](double a, double b) {
return std::abs(a - b) < tol;
};
return Base::equals(other, compare);
}
};
// AlgebraicDecisionTree

Expand Down
Loading