Skip to content

Commit

Permalink
Merge pull request #1323 from borglab/hybrid/multifrontal
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal authored Dec 23, 2022
2 parents 6b6731a + 46380ca commit 583d121
Show file tree
Hide file tree
Showing 16 changed files with 311 additions and 244 deletions.
2 changes: 1 addition & 1 deletion gtsam/hybrid/GaussianMixture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ bool GaussianMixture::equals(const HybridFactor &lf, double tol) const {
/* *******************************************************************************/
void GaussianMixture::print(const std::string &s,
const KeyFormatter &formatter) const {
std::cout << s;
std::cout << (s.empty() ? "" : s + "\n");
if (isContinuous()) std::cout << "Continuous ";
if (isDiscrete()) std::cout << "Discrete ";
if (isHybrid()) std::cout << "Hybrid ";
Expand Down
31 changes: 23 additions & 8 deletions gtsam/hybrid/HybridBayesTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* @brief Hybrid Bayes Tree, the result of eliminating a
* HybridJunctionTree
* @date Mar 11, 2022
* @author Fan Jiang
* @author Fan Jiang, Varun Agrawal
*/

#include <gtsam/base/treeTraversal-inst.h>
Expand Down Expand Up @@ -73,6 +73,8 @@ struct HybridAssignmentData {
GaussianBayesTree::sharedNode parentClique_;
// The gaussian bayes tree that will be recursively created.
GaussianBayesTree* gaussianbayesTree_;
// Flag indicating if all the nodes are valid. Used in optimize().
bool valid_;

/**
* @brief Construct a new Hybrid Assignment Data object.
Expand All @@ -83,10 +85,13 @@ struct HybridAssignmentData {
*/
HybridAssignmentData(const DiscreteValues& assignment,
const GaussianBayesTree::sharedNode& parentClique,
GaussianBayesTree* gbt)
GaussianBayesTree* gbt, bool valid = true)
: assignment_(assignment),
parentClique_(parentClique),
gaussianbayesTree_(gbt) {}
gaussianbayesTree_(gbt),
valid_(valid) {}

bool isValid() const { return valid_; }

/**
* @brief A function used during tree traversal that operates on each node
Expand All @@ -101,6 +106,7 @@ struct HybridAssignmentData {
HybridAssignmentData& parentData) {
// Extract the gaussian conditional from the Hybrid clique
HybridConditional::shared_ptr hybrid_conditional = node->conditional();

GaussianConditional::shared_ptr conditional;
if (hybrid_conditional->isHybrid()) {
conditional = (*hybrid_conditional->asMixture())(parentData.assignment_);
Expand All @@ -111,15 +117,21 @@ struct HybridAssignmentData {
conditional = boost::make_shared<GaussianConditional>();
}

// Create the GaussianClique for the current node
auto clique = boost::make_shared<GaussianBayesTree::Node>(conditional);
// Add the current clique to the GaussianBayesTree.
parentData.gaussianbayesTree_->addClique(clique, parentData.parentClique_);
GaussianBayesTree::sharedNode clique;
if (conditional) {
// Create the GaussianClique for the current node
clique = boost::make_shared<GaussianBayesTree::Node>(conditional);
// Add the current clique to the GaussianBayesTree.
parentData.gaussianbayesTree_->addClique(clique,
parentData.parentClique_);
} else {
parentData.valid_ = false;
}

// Create new HybridAssignmentData where the current node is the parent
// This will be passed down to the children nodes
HybridAssignmentData data(parentData.assignment_, clique,
parentData.gaussianbayesTree_);
parentData.gaussianbayesTree_, parentData.valid_);
return data;
}
};
Expand All @@ -138,6 +150,9 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
visitorPost);
}

if (!rootData.isValid()) {
return VectorValues();
}
VectorValues result = gbt.optimize();

// Return the optimized bayes net result.
Expand Down
5 changes: 4 additions & 1 deletion gtsam/hybrid/HybridBayesTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,12 @@ class GTSAM_EXPORT HybridBayesTreeClique
typedef boost::shared_ptr<This> shared_ptr;
typedef boost::weak_ptr<This> weak_ptr;
HybridBayesTreeClique() {}
virtual ~HybridBayesTreeClique() {}
HybridBayesTreeClique(const boost::shared_ptr<HybridConditional>& conditional)
: Base(conditional) {}
///< Copy constructor
HybridBayesTreeClique(const HybridBayesTreeClique& clique) : Base(clique) {}

virtual ~HybridBayesTreeClique() {}
};

/* ************************************************************************* */
Expand Down
2 changes: 1 addition & 1 deletion gtsam/hybrid/HybridEliminationTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
namespace gtsam {

/**
* Elimination Tree type for Hybrid
* Elimination Tree type for Hybrid Factor Graphs.
*
* @ingroup hybrid
*/
Expand Down
126 changes: 8 additions & 118 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ GaussianMixtureFactor::Sum sumFrontals(
if (auto cgmf = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
sum = cgmf->add(sum);
}

if (auto gm = boost::dynamic_pointer_cast<HybridConditional>(f)) {
sum = gm->asMixture()->add(sum);
}
Expand Down Expand Up @@ -189,7 +188,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(),
discreteSeparatorSet.end());

// sum out frontals, this is the factor on the separator
// sum out frontals, this is the factor 𝜏 on the separator
GaussianMixtureFactor::Sum sum = sumFrontals(factors);

// If a tree leaf contains nullptr,
Expand Down Expand Up @@ -257,13 +256,14 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
// If there are no more continuous parents, then we should create here a
// DiscreteFactor, with the error for each discrete choice.
if (keysOfSeparator.empty()) {
// TODO(Varun) Use the math from the iMHS_Math-1-indexed document
VectorValues empty_values;
auto factorProb = [&](const GaussianFactor::shared_ptr &factor) {
if (!factor) {
return 0.0; // If nullptr, return 0.0 probability
} else {
return 1.0;
double error =
0.5 * std::abs(factor->augmentedInformation().determinant());
return std::exp(-error);
}
};
DecisionTree<Key, double> fdt(separatorFactors, factorProb);
Expand Down Expand Up @@ -529,122 +529,13 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::probPrime(
}

/* ************************************************************************ */
DecisionTree<Key, VectorValues::shared_ptr>
HybridGaussianFactorGraph::continuousDelta(
const DiscreteKeys &discrete_keys,
const boost::shared_ptr<BayesNetType> &continuousBayesNet,
const std::vector<DiscreteValues> &assignments) const {
// Create a decision tree of all the different VectorValues
std::vector<VectorValues::shared_ptr> vector_values;
for (const DiscreteValues &assignment : assignments) {
VectorValues values = continuousBayesNet->optimize(assignment);
vector_values.push_back(boost::make_shared<VectorValues>(values));
}
DecisionTree<Key, VectorValues::shared_ptr> delta_tree(discrete_keys,
vector_values);

return delta_tree;
}

/* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::continuousProbPrimes(
const DiscreteKeys &orig_discrete_keys,
const boost::shared_ptr<BayesNetType> &continuousBayesNet) const {
// Generate all possible assignments.
const std::vector<DiscreteValues> assignments =
DiscreteValues::CartesianProduct(orig_discrete_keys);

// Save a copy of the original discrete key ordering
DiscreteKeys discrete_keys(orig_discrete_keys);
// Reverse discrete keys order for correct tree construction
std::reverse(discrete_keys.begin(), discrete_keys.end());

// Create a decision tree of all the different VectorValues
DecisionTree<Key, VectorValues::shared_ptr> delta_tree =
this->continuousDelta(discrete_keys, continuousBayesNet, assignments);

// Get the probPrime tree with the correct leaf probabilities
std::vector<double> probPrimes;
for (const DiscreteValues &assignment : assignments) {
VectorValues delta = *delta_tree(assignment);

// If VectorValues is empty, it means this is a pruned branch.
// Set thr probPrime to 0.0.
if (delta.size() == 0) {
probPrimes.push_back(0.0);
continue;
}

// Compute the error given the delta and the assignment.
double error = this->error(delta, assignment);
probPrimes.push_back(exp(-error));
}

AlgebraicDecisionTree<Key> probPrimeTree(discrete_keys, probPrimes);
return probPrimeTree;
}

/* ************************************************************************ */
boost::shared_ptr<HybridGaussianFactorGraph::BayesNetType>
HybridGaussianFactorGraph::eliminateHybridSequential(
const boost::optional<Ordering> continuous,
const boost::optional<Ordering> discrete, const Eliminate &function,
OptionalVariableIndex variableIndex) const {
Ordering continuous_ordering =
continuous ? *continuous : Ordering(this->continuousKeys());
Ordering discrete_ordering =
discrete ? *discrete : Ordering(this->discreteKeys());

// Eliminate continuous
HybridBayesNet::shared_ptr bayesNet;
HybridGaussianFactorGraph::shared_ptr discreteGraph;
std::tie(bayesNet, discreteGraph) =
BaseEliminateable::eliminatePartialSequential(continuous_ordering,
function, variableIndex);

// Get the last continuous conditional which will have all the discrete keys
auto last_conditional = bayesNet->at(bayesNet->size() - 1);
DiscreteKeys discrete_keys = last_conditional->discreteKeys();

// If not discrete variables, return the eliminated bayes net.
if (discrete_keys.size() == 0) {
return bayesNet;
}

AlgebraicDecisionTree<Key> probPrimeTree =
this->continuousProbPrimes(discrete_keys, bayesNet);

discreteGraph->add(DecisionTreeFactor(discrete_keys, probPrimeTree));

// Perform discrete elimination
HybridBayesNet::shared_ptr discreteBayesNet =
discreteGraph->BaseEliminateable::eliminateSequential(
discrete_ordering, function, variableIndex);

bayesNet->add(*discreteBayesNet);

return bayesNet;
}

/* ************************************************************************ */
boost::shared_ptr<HybridGaussianFactorGraph::BayesNetType>
HybridGaussianFactorGraph::eliminateSequential(
OptionalOrderingType orderingType, const Eliminate &function,
OptionalVariableIndex variableIndex) const {
return BaseEliminateable::eliminateSequential(orderingType, function,
variableIndex);
}

/* ************************************************************************ */
boost::shared_ptr<HybridGaussianFactorGraph::BayesNetType>
HybridGaussianFactorGraph::eliminateSequential(
const Ordering &ordering, const Eliminate &function,
OptionalVariableIndex variableIndex) const {
std::pair<Ordering, Ordering>
HybridGaussianFactorGraph::separateContinuousDiscreteOrdering(
const Ordering &ordering) const {
KeySet all_continuous_keys = this->continuousKeys();
KeySet all_discrete_keys = this->discreteKeys();
Ordering continuous_ordering, discrete_ordering;

// Segregate the continuous and the discrete keys
for (auto &&key : ordering) {
if (std::find(all_continuous_keys.begin(), all_continuous_keys.end(),
key) != all_continuous_keys.end()) {
Expand All @@ -657,8 +548,7 @@ HybridGaussianFactorGraph::eliminateSequential(
}
}

return this->eliminateHybridSequential(continuous_ordering,
discrete_ordering);
return std::make_pair(continuous_ordering, discrete_ordering);
}

} // namespace gtsam
Loading

0 comments on commit 583d121

Please sign in to comment.