Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hybrid Multifrontal #1323

Merged
merged 23 commits into from
Dec 23, 2022
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
0938159
overload multifrontal elimination
varunagrawal Nov 9, 2022
98d3186
add copy constructor for HybridBayesTreeClique
varunagrawal Nov 10, 2022
7ae4e57
fix HybridBayesTree::optimize to account for pruned nodes
varunagrawal Nov 10, 2022
d54cf48
fix creation of updatedBayesTree
varunagrawal Nov 10, 2022
318f738
fixup the final tests
varunagrawal Nov 10, 2022
6e6bbff
update docstring for Ordering::+=
varunagrawal Nov 10, 2022
5e2cdfd
make continuousProbPrimes and continuousDeltas as templates
varunagrawal Nov 13, 2022
2394129
address review comments
varunagrawal Nov 15, 2022
05b2d31
better printing
varunagrawal Dec 3, 2022
3eaf4cc
move multifrontal optimize test to testHybridBayesTree and update doc…
varunagrawal Dec 3, 2022
cd3cfa0
moved sequential elimination code to HybridEliminationTree
varunagrawal Dec 3, 2022
15fffeb
add getters to HybridEliminationTree
varunagrawal Dec 4, 2022
addbe2a
override eliminate in HybridJunctionTree
varunagrawal Dec 4, 2022
ae0b3e3
split up the eliminate method to constituent parts
varunagrawal Dec 4, 2022
bed56e0
mark helper methods as protected and add docstrings
varunagrawal Dec 4, 2022
5fc114f
more unit tests
varunagrawal Dec 4, 2022
22e4a73
Add details about the role of the HybridEliminationTree in hybrid eli…
varunagrawal Dec 4, 2022
0596b2f
remove unrequired code
varunagrawal Dec 10, 2022
62bc9f2
update hybrid elimination and corresponding tests
varunagrawal Dec 10, 2022
6beffeb
remove commented out code
varunagrawal Dec 10, 2022
da5d3a2
Merge pull request #1339 from borglab/hybrid/new-elimination
varunagrawal Dec 10, 2022
812bf52
minor cleanup
varunagrawal Dec 21, 2022
46380ca
Merge branch 'hybrid/tests' into hybrid/multifrontal
varunagrawal Dec 21, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion gtsam/hybrid/GaussianMixture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ bool GaussianMixture::equals(const HybridFactor &lf, double tol) const {
/* *******************************************************************************/
void GaussianMixture::print(const std::string &s,
const KeyFormatter &formatter) const {
std::cout << s;
std::cout << (s.empty() ? "" : s + "\n");
if (isContinuous()) std::cout << "Continuous ";
if (isDiscrete()) std::cout << "Discrete ";
if (isHybrid()) std::cout << "Hybrid ";
Expand Down
31 changes: 24 additions & 7 deletions gtsam/hybrid/HybridBayesTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ struct HybridAssignmentData {
GaussianBayesTree::sharedNode parentClique_;
// The gaussian bayes tree that will be recursively created.
GaussianBayesTree* gaussianbayesTree_;
// Flag indicating if all the nodes are valid. Used in optimize().
bool valid_;

/**
* @brief Construct a new Hybrid Assignment Data object.
Expand All @@ -83,10 +85,13 @@ struct HybridAssignmentData {
*/
HybridAssignmentData(const DiscreteValues& assignment,
const GaussianBayesTree::sharedNode& parentClique,
GaussianBayesTree* gbt)
GaussianBayesTree* gbt, bool valid = true)
: assignment_(assignment),
parentClique_(parentClique),
gaussianbayesTree_(gbt) {}
gaussianbayesTree_(gbt),
valid_(valid) {}

bool isValid() const { return valid_; }

/**
* @brief A function used during tree traversal that operates on each node
Expand All @@ -101,6 +106,7 @@ struct HybridAssignmentData {
HybridAssignmentData& parentData) {
// Extract the gaussian conditional from the Hybrid clique
HybridConditional::shared_ptr hybrid_conditional = node->conditional();

GaussianConditional::shared_ptr conditional;
if (hybrid_conditional->isHybrid()) {
conditional = (*hybrid_conditional->asMixture())(parentData.assignment_);
Expand All @@ -111,15 +117,21 @@ struct HybridAssignmentData {
conditional = boost::make_shared<GaussianConditional>();
}

// Create the GaussianClique for the current node
auto clique = boost::make_shared<GaussianBayesTree::Node>(conditional);
// Add the current clique to the GaussianBayesTree.
parentData.gaussianbayesTree_->addClique(clique, parentData.parentClique_);
GaussianBayesTree::sharedNode clique;
if (conditional) {
// Create the GaussianClique for the current node
clique = boost::make_shared<GaussianBayesTree::Node>(conditional);
// Add the current clique to the GaussianBayesTree.
parentData.gaussianbayesTree_->addClique(clique,
parentData.parentClique_);
} else {
parentData.valid_ = false;
}

// Create new HybridAssignmentData where the current node is the parent
// This will be passed down to the children nodes
HybridAssignmentData data(parentData.assignment_, clique,
parentData.gaussianbayesTree_);
parentData.gaussianbayesTree_, parentData.valid_);
return data;
}
};
Expand All @@ -138,6 +150,9 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
visitorPost);
}

if (!rootData.isValid()) {
return VectorValues();
}
VectorValues result = gbt.optimize();

// Return the optimized bayes net result.
Expand All @@ -151,6 +166,8 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) {

DecisionTreeFactor prunedDecisionTree = decisionTree->prune(maxNrLeaves);
decisionTree->root_ = prunedDecisionTree.root_;
// this->print();
// decisionTree->print("", DefaultKeyFormatter);

/// Helper struct for pruning the hybrid bayes tree.
struct HybridPrunerData {
Expand Down
5 changes: 4 additions & 1 deletion gtsam/hybrid/HybridBayesTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,12 @@ class GTSAM_EXPORT HybridBayesTreeClique
typedef boost::shared_ptr<This> shared_ptr;
typedef boost::weak_ptr<This> weak_ptr;
HybridBayesTreeClique() {}
virtual ~HybridBayesTreeClique() {}
HybridBayesTreeClique(const boost::shared_ptr<HybridConditional>& conditional)
: Base(conditional) {}
///< Copy constructor
HybridBayesTreeClique(const HybridBayesTreeClique& clique) : Base(clique) {}

virtual ~HybridBayesTreeClique() {}
};

/* ************************************************************************* */
Expand Down
2 changes: 1 addition & 1 deletion gtsam/hybrid/HybridEliminationTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
namespace gtsam {

/**
* Elimination Tree type for Hybrid
* Elimination Tree type for Hybrid Factor Graphs.
*
* @ingroup hybrid
*/
Expand Down
123 changes: 7 additions & 116 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,13 +257,14 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
// If there are no more continuous parents, then we should create here a
// DiscreteFactor, with the error for each discrete choice.
if (keysOfSeparator.empty()) {
// TODO(Varun) Use the math from the iMHS_Math-1-indexed document
VectorValues empty_values;
auto factorProb = [&](const GaussianFactor::shared_ptr &factor) {
if (!factor) {
return 0.0; // If nullptr, return 0.0 probability
} else {
return 1.0;
double error =
0.5 * std::abs(factor->augmentedInformation().determinant());
return std::exp(-error);
}
};
DecisionTree<Key, double> fdt(separatorFactors, factorProb);
Expand Down Expand Up @@ -529,122 +530,13 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::probPrime(
}

/* ************************************************************************ */
DecisionTree<Key, VectorValues::shared_ptr>
HybridGaussianFactorGraph::continuousDelta(
const DiscreteKeys &discrete_keys,
const boost::shared_ptr<BayesNetType> &continuousBayesNet,
const std::vector<DiscreteValues> &assignments) const {
// Create a decision tree of all the different VectorValues
std::vector<VectorValues::shared_ptr> vector_values;
for (const DiscreteValues &assignment : assignments) {
VectorValues values = continuousBayesNet->optimize(assignment);
vector_values.push_back(boost::make_shared<VectorValues>(values));
}
DecisionTree<Key, VectorValues::shared_ptr> delta_tree(discrete_keys,
vector_values);

return delta_tree;
}

/* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::continuousProbPrimes(
const DiscreteKeys &orig_discrete_keys,
const boost::shared_ptr<BayesNetType> &continuousBayesNet) const {
// Generate all possible assignments.
const std::vector<DiscreteValues> assignments =
DiscreteValues::CartesianProduct(orig_discrete_keys);

// Save a copy of the original discrete key ordering
DiscreteKeys discrete_keys(orig_discrete_keys);
// Reverse discrete keys order for correct tree construction
std::reverse(discrete_keys.begin(), discrete_keys.end());

// Create a decision tree of all the different VectorValues
DecisionTree<Key, VectorValues::shared_ptr> delta_tree =
this->continuousDelta(discrete_keys, continuousBayesNet, assignments);

// Get the probPrime tree with the correct leaf probabilities
std::vector<double> probPrimes;
for (const DiscreteValues &assignment : assignments) {
VectorValues delta = *delta_tree(assignment);

// If VectorValues is empty, it means this is a pruned branch.
// Set the probPrime to 0.0.
if (delta.size() == 0) {
probPrimes.push_back(0.0);
continue;
}

// Compute the error given the delta and the assignment.
double error = this->error(delta, assignment);
probPrimes.push_back(exp(-error));
}

AlgebraicDecisionTree<Key> probPrimeTree(discrete_keys, probPrimes);
return probPrimeTree;
}

/* ************************************************************************ */
boost::shared_ptr<HybridGaussianFactorGraph::BayesNetType>
HybridGaussianFactorGraph::eliminateHybridSequential(
const boost::optional<Ordering> continuous,
const boost::optional<Ordering> discrete, const Eliminate &function,
OptionalVariableIndex variableIndex) const {
Ordering continuous_ordering =
continuous ? *continuous : Ordering(this->continuousKeys());
Ordering discrete_ordering =
discrete ? *discrete : Ordering(this->discreteKeys());

// Eliminate continuous
HybridBayesNet::shared_ptr bayesNet;
HybridGaussianFactorGraph::shared_ptr discreteGraph;
std::tie(bayesNet, discreteGraph) =
BaseEliminateable::eliminatePartialSequential(continuous_ordering,
function, variableIndex);

// Get the last continuous conditional which will have all the discrete keys
auto last_conditional = bayesNet->at(bayesNet->size() - 1);
DiscreteKeys discrete_keys = last_conditional->discreteKeys();

// If not discrete variables, return the eliminated bayes net.
if (discrete_keys.size() == 0) {
return bayesNet;
}

AlgebraicDecisionTree<Key> probPrimeTree =
this->continuousProbPrimes(discrete_keys, bayesNet);

discreteGraph->add(DecisionTreeFactor(discrete_keys, probPrimeTree));

// Perform discrete elimination
HybridBayesNet::shared_ptr discreteBayesNet =
discreteGraph->BaseEliminateable::eliminateSequential(
discrete_ordering, function, variableIndex);

bayesNet->add(*discreteBayesNet);

return bayesNet;
}

/* ************************************************************************ */
boost::shared_ptr<HybridGaussianFactorGraph::BayesNetType>
HybridGaussianFactorGraph::eliminateSequential(
OptionalOrderingType orderingType, const Eliminate &function,
OptionalVariableIndex variableIndex) const {
return BaseEliminateable::eliminateSequential(orderingType, function,
variableIndex);
}

/* ************************************************************************ */
boost::shared_ptr<HybridGaussianFactorGraph::BayesNetType>
HybridGaussianFactorGraph::eliminateSequential(
const Ordering &ordering, const Eliminate &function,
OptionalVariableIndex variableIndex) const {
std::pair<Ordering, Ordering>
HybridGaussianFactorGraph::separateContinuousDiscreteOrdering(
const Ordering &ordering) const {
KeySet all_continuous_keys = this->continuousKeys();
KeySet all_discrete_keys = this->discreteKeys();
Ordering continuous_ordering, discrete_ordering;

// Segregate the continuous and the discrete keys
for (auto &&key : ordering) {
if (std::find(all_continuous_keys.begin(), all_continuous_keys.end(),
key) != all_continuous_keys.end()) {
Expand All @@ -657,8 +549,7 @@ HybridGaussianFactorGraph::eliminateSequential(
}
}

return this->eliminateHybridSequential(continuous_ordering,
discrete_ordering);
return std::make_pair(continuous_ordering, discrete_ordering);
}

} // namespace gtsam
93 changes: 64 additions & 29 deletions gtsam/hybrid/HybridGaussianFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -217,57 +217,92 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
const DiscreteValues& discreteValues) const;

/**
* @brief Compute the VectorValues solution for the continuous variables for
* each mode.
* @brief Helper method to compute the VectorValues solution for the
* continuous variables for each discrete mode.
* Used as a helper to compute q(\mu | M, Z) which is used by
* both P(X | M, Z) and P(M | Z).
*
* @tparam BAYES Template on the type of Bayes graph, either a bayes net or a
* bayes tree.
* @param discrete_keys The discrete keys which form all the modes.
* @param continuousBayesNet The Bayes Net representing the continuous
* @param continuousBayesNet The Bayes Net/Tree representing the continuous
* eliminated variables.
* @param assignments List of all discrete assignments to create the final
* decision tree.
* @return DecisionTree<Key, VectorValues::shared_ptr>
*/
template <typename BAYES>
DecisionTree<Key, VectorValues::shared_ptr> continuousDelta(
const DiscreteKeys& discrete_keys,
const boost::shared_ptr<BayesNetType>& continuousBayesNet,
const std::vector<DiscreteValues>& assignments) const;
const boost::shared_ptr<BAYES>& continuousBayesNet,
const std::vector<DiscreteValues>& assignments) const {
// Create a decision tree of all the different VectorValues
std::vector<VectorValues::shared_ptr> vector_values;
for (const DiscreteValues& assignment : assignments) {
VectorValues values = continuousBayesNet->optimize(assignment);
vector_values.push_back(boost::make_shared<VectorValues>(values));
}
DecisionTree<Key, VectorValues::shared_ptr> delta_tree(discrete_keys,
vector_values);

return delta_tree;
}

/**
* @brief Compute the unnormalized probabilities of the continuous variables
* for each of the modes.
*
* @tparam BAYES Template on the type of Bayes graph, either a bayes net or a
* bayes tree.
* @param discrete_keys The discrete keys which form all the modes.
* @param continuousBayesNet The Bayes Net representing the continuous
* eliminated variables.
* @return AlgebraicDecisionTree<Key>
*/
template <typename BAYES>
AlgebraicDecisionTree<Key> continuousProbPrimes(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not a great name. It refers to math that is not in the documentation. Suggest unnormalizedProbabilities documenting the fact that it is exp(-E). On a side note: I question the wisdom of exponentiating small numbers before they will somehow be normalized. Typically we avoid doing this until we know the minimum error and then subtract the minimum error so the corresponding maximum prob' == 1.0.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are unnormalized probabilities only on the continuous variables and probPrime is the method name for computing the unnormalized probabilities, so it seemed to make sense. If you still think I should rename it to unnormalizedProbabilities, will gladly do so.

Regarding the exponentiation, the minimum error term will always be 0 due to pruning.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update: this function will get deleted.

const DiscreteKeys& discrete_keys,
const boost::shared_ptr<BayesNetType>& continuousBayesNet) const;
const boost::shared_ptr<BAYES>& continuousBayesNet) const {
// Generate all possible assignments.
const std::vector<DiscreteValues> assignments =
DiscreteValues::CartesianProduct(discrete_keys);

// Save a copy of the original discrete key ordering
DiscreteKeys reversed_discrete_keys(discrete_keys);
// Reverse discrete keys order for correct tree construction
std::reverse(reversed_discrete_keys.begin(), reversed_discrete_keys.end());

// Create a decision tree of all the different VectorValues
DecisionTree<Key, VectorValues::shared_ptr> delta_tree =
this->continuousDelta(reversed_discrete_keys, continuousBayesNet,
assignments);

// Get the probPrime tree with the correct leaf probabilities
std::vector<double> probPrimes;
for (const DiscreteValues& assignment : assignments) {
VectorValues delta = *delta_tree(assignment);

// If VectorValues is empty, it means this is a pruned branch.
// Set the probPrime to 0.0.
if (delta.size() == 0) {
probPrimes.push_back(0.0);
continue;
}

// Compute the error given the delta and the assignment.
double error = this->error(delta, assignment);
probPrimes.push_back(exp(-error));
}

/**
* @brief Custom elimination function which computes the correct
* continuous probabilities.
*
* @param continuous Optional ordering for all continuous variables.
* @param discrete Optional ordering for all discrete variables.
* @return boost::shared_ptr<BayesNetType>
*/
boost::shared_ptr<BayesNetType> eliminateHybridSequential(
const boost::optional<Ordering> continuous = boost::none,
const boost::optional<Ordering> discrete = boost::none,
const Eliminate& function = EliminationTraitsType::DefaultEliminate,
OptionalVariableIndex variableIndex = boost::none) const;

boost::shared_ptr<BayesNetType> eliminateSequential(
OptionalOrderingType orderingType = boost::none,
const Eliminate& function = EliminationTraitsType::DefaultEliminate,
OptionalVariableIndex variableIndex = boost::none) const;

boost::shared_ptr<BayesNetType> eliminateSequential(
const Ordering& ordering,
const Eliminate& function = EliminationTraitsType::DefaultEliminate,
OptionalVariableIndex variableIndex = boost::none) const;
AlgebraicDecisionTree<Key> probPrimeTree(reversed_discrete_keys,
probPrimes);
return probPrimeTree;
}

std::pair<Ordering, Ordering> separateContinuousDiscreteOrdering(
const Ordering& ordering) const;



/**
* @brief Return a Colamd constrained ordering where the discrete keys are
Expand Down
3 changes: 2 additions & 1 deletion gtsam/hybrid/HybridJunctionTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,11 @@ class HybridEliminationTree;
*/
class GTSAM_EXPORT HybridJunctionTree
: public JunctionTree<HybridBayesTree, HybridGaussianFactorGraph> {

public:
typedef JunctionTree<HybridBayesTree, HybridGaussianFactorGraph>
Base; ///< Base class
typedef HybridJunctionTree This; ///< This class
typedef HybridJunctionTree This; ///< This class
typedef boost::shared_ptr<This> shared_ptr; ///< Shared pointer to this class

/**
Expand Down
Loading