-
Notifications
You must be signed in to change notification settings - Fork 765
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Hybrid Multifrontal #1323
Hybrid Multifrontal #1323
Changes from 6 commits
0938159
98d3186
7ae4e57
d54cf48
318f738
6e6bbff
5e2cdfd
2394129
05b2d31
3eaf4cc
cd3cfa0
15fffeb
addbe2a
ae0b3e3
bed56e0
5fc114f
22e4a73
0596b2f
62bc9f2
6beffeb
da5d3a2
812bf52
46380ca
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -546,6 +546,24 @@ HybridGaussianFactorGraph::continuousDelta( | |
return delta_tree; | ||
} | ||
|
||
/* ************************************************************************ */ | ||
DecisionTree<Key, VectorValues::shared_ptr> | ||
HybridGaussianFactorGraph::continuousDelta( | ||
const DiscreteKeys &discrete_keys, | ||
const boost::shared_ptr<BayesTreeType> &continuousBayesTree, | ||
const std::vector<DiscreteValues> &assignments) const { | ||
// Create a decision tree of all the different VectorValues | ||
std::vector<VectorValues::shared_ptr> vector_values; | ||
for (const DiscreteValues &assignment : assignments) { | ||
VectorValues values = continuousBayesTree->optimize(assignment); | ||
vector_values.push_back(boost::make_shared<VectorValues>(values)); | ||
} | ||
DecisionTree<Key, VectorValues::shared_ptr> delta_tree(discrete_keys, | ||
vector_values); | ||
|
||
return delta_tree; | ||
} | ||
|
||
/* ************************************************************************ */ | ||
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::continuousProbPrimes( | ||
const DiscreteKeys &orig_discrete_keys, | ||
|
@@ -584,6 +602,67 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::continuousProbPrimes( | |
return probPrimeTree; | ||
} | ||
|
||
/* ************************************************************************ */ | ||
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::continuousProbPrimes( | ||
const DiscreteKeys &orig_discrete_keys, | ||
const boost::shared_ptr<BayesTreeType> &continuousBayesTree) const { | ||
// Generate all possible assignments. | ||
const std::vector<DiscreteValues> assignments = | ||
DiscreteValues::CartesianProduct(orig_discrete_keys); | ||
|
||
// Save a copy of the original discrete key ordering | ||
DiscreteKeys discrete_keys(orig_discrete_keys); | ||
// Reverse discrete keys order for correct tree construction | ||
std::reverse(discrete_keys.begin(), discrete_keys.end()); | ||
|
||
// Create a decision tree of all the different VectorValues | ||
DecisionTree<Key, VectorValues::shared_ptr> delta_tree = | ||
this->continuousDelta(discrete_keys, continuousBayesTree, assignments); | ||
|
||
// Get the probPrime tree with the correct leaf probabilities | ||
std::vector<double> probPrimes; | ||
for (const DiscreteValues &assignment : assignments) { | ||
VectorValues delta = *delta_tree(assignment); | ||
|
||
// If VectorValues is empty, it means this is a pruned branch. | ||
// Set thr probPrime to 0.0. | ||
if (delta.size() == 0) { | ||
probPrimes.push_back(0.0); | ||
continue; | ||
} | ||
|
||
// Compute the error given the delta and the assignment. | ||
double error = this->error(delta, assignment); | ||
probPrimes.push_back(exp(-error)); | ||
} | ||
|
||
AlgebraicDecisionTree<Key> probPrimeTree(discrete_keys, probPrimes); | ||
return probPrimeTree; | ||
} | ||
|
||
/* ************************************************************************ */ | ||
std::pair<Ordering, Ordering> | ||
HybridGaussianFactorGraph::separateContinuousDiscreteOrdering( | ||
const Ordering &ordering) const { | ||
KeySet all_continuous_keys = this->continuousKeys(); | ||
KeySet all_discrete_keys = this->discreteKeys(); | ||
Ordering continuous_ordering, discrete_ordering; | ||
|
||
for (auto &&key : ordering) { | ||
if (std::find(all_continuous_keys.begin(), all_continuous_keys.end(), | ||
key) != all_continuous_keys.end()) { | ||
continuous_ordering.push_back(key); | ||
} else if (std::find(all_discrete_keys.begin(), all_discrete_keys.end(), | ||
key) != all_discrete_keys.end()) { | ||
discrete_ordering.push_back(key); | ||
} else { | ||
throw std::runtime_error("Key in ordering not present in factors."); | ||
} | ||
} | ||
|
||
return std::make_pair(continuous_ordering, discrete_ordering); | ||
} | ||
|
||
/* ************************************************************************ */ | ||
boost::shared_ptr<HybridGaussianFactorGraph::BayesNetType> | ||
HybridGaussianFactorGraph::eliminateHybridSequential( | ||
|
@@ -640,25 +719,99 @@ boost::shared_ptr<HybridGaussianFactorGraph::BayesNetType> | |
HybridGaussianFactorGraph::eliminateSequential( | ||
const Ordering &ordering, const Eliminate &function, | ||
OptionalVariableIndex variableIndex) const { | ||
KeySet all_continuous_keys = this->continuousKeys(); | ||
KeySet all_discrete_keys = this->discreteKeys(); | ||
// Segregate the continuous and the discrete keys | ||
Ordering continuous_ordering, discrete_ordering; | ||
std::tie(continuous_ordering, discrete_ordering) = | ||
this->separateContinuousDiscreteOrdering(ordering); | ||
|
||
return this->eliminateHybridSequential(continuous_ordering, discrete_ordering, | ||
function, variableIndex); | ||
} | ||
|
||
/* ************************************************************************ */ | ||
boost::shared_ptr<HybridGaussianFactorGraph::BayesTreeType> | ||
HybridGaussianFactorGraph::eliminateHybridMultifrontal( | ||
const boost::optional<Ordering> continuous, | ||
const boost::optional<Ordering> discrete, const Eliminate &function, | ||
OptionalVariableIndex variableIndex) const { | ||
Ordering continuous_ordering = | ||
continuous ? *continuous : Ordering(this->continuousKeys()); | ||
Ordering discrete_ordering = | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. const |
||
discrete ? *discrete : Ordering(this->discreteKeys()); | ||
|
||
// Eliminate continuous | ||
HybridBayesTree::shared_ptr bayesTree; | ||
HybridGaussianFactorGraph::shared_ptr discreteGraph; | ||
std::tie(bayesTree, discreteGraph) = | ||
BaseEliminateable::eliminatePartialMultifrontal(continuous_ordering, | ||
function, variableIndex); | ||
|
||
// Get the last continuous conditional which will have all the discrete | ||
Key last_continuous_key = | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. const |
||
continuous_ordering.at(continuous_ordering.size() - 1); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. back() There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Very cool! |
||
auto last_conditional = (*bayesTree)[last_continuous_key]->conditional(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is the type of conditional? In this case, use const Type& instead of auto. I think it's a Bayes net, no? So, rather than conditional, name it something more revealing? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It should be a |
||
DiscreteKeys discrete_keys = last_conditional->discreteKeys(); | ||
|
||
// If not discrete variables, return the eliminated bayes net. | ||
if (discrete_keys.size() == 0) { | ||
return bayesTree; | ||
} | ||
|
||
AlgebraicDecisionTree<Key> probPrimeTree = | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. const, doc |
||
this->continuousProbPrimes(discrete_keys, bayesTree); | ||
|
||
discreteGraph->add(DecisionTreeFactor(discrete_keys, probPrimeTree)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We are mutating here? Highly suspicious to me. Please explain and add explanation as doc. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the addition of the model selection term |
||
|
||
auto updatedBayesTree = | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. doc, const auto& There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can't be const since we add the discrete clique as the root so that the Bayes tree is valid. |
||
discreteGraph->BaseEliminateable::eliminateMultifrontal(discrete_ordering, | ||
function); | ||
|
||
auto discrete_clique = (*updatedBayesTree)[discrete_ordering.at(0)]; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. doc, const auto& |
||
|
||
std::set<HybridBayesTreeClique::shared_ptr> clique_set; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I particularly want a review on this part. Am I creating the Bayes Tree with the discrete root correctly? Unit tests seem to say yes, but if there is a better way, I am all ears. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. without documentation (e.g. with math, I don;t even know what is happeingin). In general true for this entire piece of code. Now that you have the math, migth as well add it? |
||
for (auto node : bayesTree->nodes()) { | ||
clique_set.insert(node.second); | ||
} | ||
|
||
// Set the root of the bayes tree as the discrete clique | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. more mutating? Very suspicious. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wasn't sure how else to update the BayesTree so set the discrete clique to be the root. That is a requirement for the structure of the HybridBayesTree. |
||
for (auto clique : clique_set) { | ||
if (clique->conditional()->parents() == | ||
discrete_clique->conditional()->frontals()) { | ||
updatedBayesTree->addClique(clique, discrete_clique); | ||
|
||
// Segregate the continuous and the discrete keys | ||
for (auto &&key : ordering) { | ||
if (std::find(all_continuous_keys.begin(), all_continuous_keys.end(), | ||
key) != all_continuous_keys.end()) { | ||
continuous_ordering.push_back(key); | ||
} else if (std::find(all_discrete_keys.begin(), all_discrete_keys.end(), | ||
key) != all_discrete_keys.end()) { | ||
discrete_ordering.push_back(key); | ||
} else { | ||
throw std::runtime_error("Key in ordering not present in factors."); | ||
// Remove the clique from the children of the parents since it will get | ||
// added again in addClique. | ||
auto clique_it = std::find(clique->parent()->children.begin(), | ||
clique->parent()->children.end(), clique); | ||
clique->parent()->children.erase(clique_it); | ||
updatedBayesTree->addClique(clique, clique->parent()); | ||
} | ||
} | ||
return updatedBayesTree; | ||
} | ||
|
||
/* ************************************************************************ */ | ||
boost::shared_ptr<HybridGaussianFactorGraph::BayesTreeType> | ||
HybridGaussianFactorGraph::eliminateMultifrontal( | ||
OptionalOrderingType orderingType, const Eliminate &function, | ||
OptionalVariableIndex variableIndex) const { | ||
return BaseEliminateable::eliminateMultifrontal(orderingType, function, | ||
variableIndex); | ||
} | ||
|
||
/* ************************************************************************ */ | ||
boost::shared_ptr<HybridGaussianFactorGraph::BayesTreeType> | ||
HybridGaussianFactorGraph::eliminateMultifrontal( | ||
const Ordering &ordering, const Eliminate &function, | ||
OptionalVariableIndex variableIndex) const { | ||
// Segregate the continuous and the discrete keys | ||
Ordering continuous_ordering, discrete_ordering; | ||
std::tie(continuous_ordering, discrete_ordering) = | ||
this->separateContinuousDiscreteOrdering(ordering); | ||
|
||
return this->eliminateHybridSequential(continuous_ordering, | ||
discrete_ordering); | ||
return this->eliminateHybridMultifrontal( | ||
continuous_ordering, discrete_ordering, function, variableIndex); | ||
} | ||
|
||
} // namespace gtsam |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -231,6 +231,10 @@ class GTSAM_EXPORT HybridGaussianFactorGraph | |
const DiscreteKeys& discrete_keys, | ||
const boost::shared_ptr<BayesNetType>& continuousBayesNet, | ||
const std::vector<DiscreteValues>& assignments) const; | ||
DecisionTree<Key, VectorValues::shared_ptr> continuousDelta( | ||
const DiscreteKeys& discrete_keys, | ||
const boost::shared_ptr<BayesTreeType>& continuousBayesTree, | ||
const std::vector<DiscreteValues>& assignments) const; | ||
|
||
/** | ||
* @brief Compute the unnormalized probabilities of the continuous variables | ||
|
@@ -244,6 +248,12 @@ class GTSAM_EXPORT HybridGaussianFactorGraph | |
AlgebraicDecisionTree<Key> continuousProbPrimes( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is not a great name. It refers to math that is not in the documentation. Suggest There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These are unnormalized probabilities only on the continuous variables and Regarding the exponentiation, the minimum error term will always be 0 due to pruning. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Update: this function will get deleted. |
||
const DiscreteKeys& discrete_keys, | ||
const boost::shared_ptr<BayesNetType>& continuousBayesNet) const; | ||
AlgebraicDecisionTree<Key> continuousProbPrimes( | ||
const DiscreteKeys& discrete_keys, | ||
const boost::shared_ptr<BayesTreeType>& continuousBayesTree) const; | ||
|
||
std::pair<Ordering, Ordering> separateContinuousDiscreteOrdering( | ||
const Ordering& ordering) const; | ||
|
||
/** | ||
* @brief Custom elimination function which computes the correct | ||
|
@@ -269,6 +279,22 @@ class GTSAM_EXPORT HybridGaussianFactorGraph | |
const Eliminate& function = EliminationTraitsType::DefaultEliminate, | ||
OptionalVariableIndex variableIndex = boost::none) const; | ||
|
||
boost::shared_ptr<BayesTreeType> eliminateHybridMultifrontal( | ||
const boost::optional<Ordering> continuous = boost::none, | ||
const boost::optional<Ordering> discrete = boost::none, | ||
const Eliminate& function = EliminationTraitsType::DefaultEliminate, | ||
OptionalVariableIndex variableIndex = boost::none) const; | ||
|
||
boost::shared_ptr<BayesTreeType> eliminateMultifrontal( | ||
OptionalOrderingType orderingType = boost::none, | ||
const Eliminate& function = EliminationTraitsType::DefaultEliminate, | ||
OptionalVariableIndex variableIndex = boost::none) const; | ||
|
||
boost::shared_ptr<BayesTreeType> eliminateMultifrontal( | ||
const Ordering& ordering, | ||
const Eliminate& function = EliminationTraitsType::DefaultEliminate, | ||
OptionalVariableIndex variableIndex = boost::none) const; | ||
|
||
/** | ||
* @brief Return a Colamd constrained ordering where the discrete keys are | ||
* eliminated after the continuous keys. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -136,6 +136,13 @@ TEST(HybridBayesTree, Optimize) { | |
dfg.push_back( | ||
boost::dynamic_pointer_cast<DecisionTreeFactor>(factor->inner())); | ||
} | ||
|
||
// Add the probabilities for each branch | ||
DiscreteKeys discrete_keys = {{M(0), 2}, {M(1), 2}, {M(2), 2}}; | ||
vector<double> probs = {0.012519475, 0.041280228, 0.075018647, 0.081663656, | ||
0.037152205, 0.12248971, 0.07349729, 0.08}; | ||
AlgebraicDecisionTree<Key> potentials(discrete_keys, probs); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is added but not used? |
||
dfg.emplace_shared<DecisionTreeFactor>(discrete_keys, probs); | ||
|
||
DiscreteValues expectedMPE = dfg.optimize(); | ||
VectorValues expectedValues = hybridBayesNet->optimize(expectedMPE); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
const