Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hybrid Multifrontal #1323

Merged
merged 23 commits into from
Dec 23, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
0938159
overload multifrontal elimination
varunagrawal Nov 9, 2022
98d3186
add copy constructor for HybridBayesTreeClique
varunagrawal Nov 10, 2022
7ae4e57
fix HybridBayesTree::optimize to account for pruned nodes
varunagrawal Nov 10, 2022
d54cf48
fix creation of updatedBayesTree
varunagrawal Nov 10, 2022
318f738
fixup the final tests
varunagrawal Nov 10, 2022
6e6bbff
update docstring for Ordering::+=
varunagrawal Nov 10, 2022
5e2cdfd
make continuousProbPrimes and continuousDeltas as templates
varunagrawal Nov 13, 2022
2394129
address review comments
varunagrawal Nov 15, 2022
05b2d31
better printing
varunagrawal Dec 3, 2022
3eaf4cc
move multifrontal optimize test to testHybridBayesTree and update doc…
varunagrawal Dec 3, 2022
cd3cfa0
moved sequential elimination code to HybridEliminationTree
varunagrawal Dec 3, 2022
15fffeb
add getters to HybridEliminationTree
varunagrawal Dec 4, 2022
addbe2a
override eliminate in HybridJunctionTree
varunagrawal Dec 4, 2022
ae0b3e3
split up the eliminate method to constituent parts
varunagrawal Dec 4, 2022
bed56e0
mark helper methods as protected and add docstrings
varunagrawal Dec 4, 2022
5fc114f
more unit tests
varunagrawal Dec 4, 2022
22e4a73
Add details about the role of the HybridEliminationTree in hybrid eli…
varunagrawal Dec 4, 2022
0596b2f
remove unrequired code
varunagrawal Dec 10, 2022
62bc9f2
update hybrid elimination and corresponding tests
varunagrawal Dec 10, 2022
6beffeb
remove commented out code
varunagrawal Dec 10, 2022
da5d3a2
Merge pull request #1339 from borglab/hybrid/new-elimination
varunagrawal Dec 10, 2022
812bf52
minor cleanup
varunagrawal Dec 21, 2022
46380ca
Merge branch 'hybrid/tests' into hybrid/multifrontal
varunagrawal Dec 21, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 24 additions & 7 deletions gtsam/hybrid/HybridBayesTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ struct HybridAssignmentData {
GaussianBayesTree::sharedNode parentClique_;
// The gaussian bayes tree that will be recursively created.
GaussianBayesTree* gaussianbayesTree_;
// Flag indicating if all the nodes are valid. Used in optimize().
bool valid_;

/**
* @brief Construct a new Hybrid Assignment Data object.
Expand All @@ -83,10 +85,13 @@ struct HybridAssignmentData {
*/
HybridAssignmentData(const DiscreteValues& assignment,
const GaussianBayesTree::sharedNode& parentClique,
GaussianBayesTree* gbt)
GaussianBayesTree* gbt, bool valid = true)
: assignment_(assignment),
parentClique_(parentClique),
gaussianbayesTree_(gbt) {}
gaussianbayesTree_(gbt),
valid_(valid) {}

bool isValid() const { return valid_; }

/**
* @brief A function used during tree traversal that operates on each node
Expand All @@ -101,6 +106,7 @@ struct HybridAssignmentData {
HybridAssignmentData& parentData) {
// Extract the gaussian conditional from the Hybrid clique
HybridConditional::shared_ptr hybrid_conditional = node->conditional();

GaussianConditional::shared_ptr conditional;
if (hybrid_conditional->isHybrid()) {
conditional = (*hybrid_conditional->asMixture())(parentData.assignment_);
Expand All @@ -111,15 +117,21 @@ struct HybridAssignmentData {
conditional = boost::make_shared<GaussianConditional>();
}

// Create the GaussianClique for the current node
auto clique = boost::make_shared<GaussianBayesTree::Node>(conditional);
// Add the current clique to the GaussianBayesTree.
parentData.gaussianbayesTree_->addClique(clique, parentData.parentClique_);
GaussianBayesTree::sharedNode clique;
if (conditional) {
// Create the GaussianClique for the current node
clique = boost::make_shared<GaussianBayesTree::Node>(conditional);
// Add the current clique to the GaussianBayesTree.
parentData.gaussianbayesTree_->addClique(clique,
parentData.parentClique_);
} else {
parentData.valid_ = false;
}

// Create new HybridAssignmentData where the current node is the parent
// This will be passed down to the children nodes
HybridAssignmentData data(parentData.assignment_, clique,
parentData.gaussianbayesTree_);
parentData.gaussianbayesTree_, parentData.valid_);
return data;
}
};
Expand All @@ -138,6 +150,9 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
visitorPost);
}

if (!rootData.isValid()) {
return VectorValues();
}
VectorValues result = gbt.optimize();

// Return the optimized bayes net result.
Expand All @@ -151,6 +166,8 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) {

DecisionTreeFactor prunedDecisionTree = decisionTree->prune(maxNrLeaves);
decisionTree->root_ = prunedDecisionTree.root_;
// this->print();
// decisionTree->print("", DefaultKeyFormatter);

/// Helper struct for pruning the hybrid bayes tree.
struct HybridPrunerData {
Expand Down
5 changes: 4 additions & 1 deletion gtsam/hybrid/HybridBayesTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,12 @@ class GTSAM_EXPORT HybridBayesTreeClique
typedef boost::shared_ptr<This> shared_ptr;
typedef boost::weak_ptr<This> weak_ptr;
HybridBayesTreeClique() {}
virtual ~HybridBayesTreeClique() {}
HybridBayesTreeClique(const boost::shared_ptr<HybridConditional>& conditional)
: Base(conditional) {}
///< Copy constructor
HybridBayesTreeClique(const HybridBayesTreeClique& clique) : Base(clique) {}

virtual ~HybridBayesTreeClique() {}
};

/* ************************************************************************* */
Expand Down
179 changes: 166 additions & 13 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,24 @@ HybridGaussianFactorGraph::continuousDelta(
return delta_tree;
}

/* ************************************************************************ */
DecisionTree<Key, VectorValues::shared_ptr>
HybridGaussianFactorGraph::continuousDelta(
const DiscreteKeys &discrete_keys,
const boost::shared_ptr<BayesTreeType> &continuousBayesTree,
const std::vector<DiscreteValues> &assignments) const {
// Create a decision tree of all the different VectorValues
std::vector<VectorValues::shared_ptr> vector_values;
for (const DiscreteValues &assignment : assignments) {
VectorValues values = continuousBayesTree->optimize(assignment);
vector_values.push_back(boost::make_shared<VectorValues>(values));
}
DecisionTree<Key, VectorValues::shared_ptr> delta_tree(discrete_keys,
vector_values);

return delta_tree;
}

/* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::continuousProbPrimes(
const DiscreteKeys &orig_discrete_keys,
Expand Down Expand Up @@ -584,6 +602,67 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::continuousProbPrimes(
return probPrimeTree;
}

/* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::continuousProbPrimes(
const DiscreteKeys &orig_discrete_keys,
const boost::shared_ptr<BayesTreeType> &continuousBayesTree) const {
// Generate all possible assignments.
const std::vector<DiscreteValues> assignments =
DiscreteValues::CartesianProduct(orig_discrete_keys);

// Save a copy of the original discrete key ordering
DiscreteKeys discrete_keys(orig_discrete_keys);
// Reverse discrete keys order for correct tree construction
std::reverse(discrete_keys.begin(), discrete_keys.end());

// Create a decision tree of all the different VectorValues
DecisionTree<Key, VectorValues::shared_ptr> delta_tree =
this->continuousDelta(discrete_keys, continuousBayesTree, assignments);

// Get the probPrime tree with the correct leaf probabilities
std::vector<double> probPrimes;
for (const DiscreteValues &assignment : assignments) {
VectorValues delta = *delta_tree(assignment);

// If VectorValues is empty, it means this is a pruned branch.
// Set the probPrime to 0.0.
if (delta.size() == 0) {
probPrimes.push_back(0.0);
continue;
}

// Compute the error given the delta and the assignment.
double error = this->error(delta, assignment);
probPrimes.push_back(exp(-error));
}

AlgebraicDecisionTree<Key> probPrimeTree(discrete_keys, probPrimes);
return probPrimeTree;
}

/* ************************************************************************ */
std::pair<Ordering, Ordering>
HybridGaussianFactorGraph::separateContinuousDiscreteOrdering(
const Ordering &ordering) const {
KeySet all_continuous_keys = this->continuousKeys();
KeySet all_discrete_keys = this->discreteKeys();
Ordering continuous_ordering, discrete_ordering;

for (auto &&key : ordering) {
if (std::find(all_continuous_keys.begin(), all_continuous_keys.end(),
key) != all_continuous_keys.end()) {
continuous_ordering.push_back(key);
} else if (std::find(all_discrete_keys.begin(), all_discrete_keys.end(),
key) != all_discrete_keys.end()) {
discrete_ordering.push_back(key);
} else {
throw std::runtime_error("Key in ordering not present in factors.");
}
}

return std::make_pair(continuous_ordering, discrete_ordering);
}

/* ************************************************************************ */
boost::shared_ptr<HybridGaussianFactorGraph::BayesNetType>
HybridGaussianFactorGraph::eliminateHybridSequential(
Expand Down Expand Up @@ -640,25 +719,99 @@ boost::shared_ptr<HybridGaussianFactorGraph::BayesNetType>
HybridGaussianFactorGraph::eliminateSequential(
const Ordering &ordering, const Eliminate &function,
OptionalVariableIndex variableIndex) const {
KeySet all_continuous_keys = this->continuousKeys();
KeySet all_discrete_keys = this->discreteKeys();
// Segregate the continuous and the discrete keys
Ordering continuous_ordering, discrete_ordering;
std::tie(continuous_ordering, discrete_ordering) =
this->separateContinuousDiscreteOrdering(ordering);

return this->eliminateHybridSequential(continuous_ordering, discrete_ordering,
function, variableIndex);
}

/* ************************************************************************ */
boost::shared_ptr<HybridGaussianFactorGraph::BayesTreeType>
HybridGaussianFactorGraph::eliminateHybridMultifrontal(
const boost::optional<Ordering> continuous,
const boost::optional<Ordering> discrete, const Eliminate &function,
OptionalVariableIndex variableIndex) const {
Ordering continuous_ordering =
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

const

continuous ? *continuous : Ordering(this->continuousKeys());
Ordering discrete_ordering =
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

const

discrete ? *discrete : Ordering(this->discreteKeys());

// Eliminate continuous
HybridBayesTree::shared_ptr bayesTree;
HybridGaussianFactorGraph::shared_ptr discreteGraph;
std::tie(bayesTree, discreteGraph) =
BaseEliminateable::eliminatePartialMultifrontal(continuous_ordering,
function, variableIndex);

// Get the last continuous conditional which will have all the discrete
Key last_continuous_key =
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

const

continuous_ordering.at(continuous_ordering.size() - 1);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

back()

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very cool!

auto last_conditional = (*bayesTree)[last_continuous_key]->conditional();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the type of conditional? In this case, use const Type& instead of auto. I think it's a Bayes net, no? So, rather than conditional, name it something more revealing?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be a GaussianMixture since it is a continuous variable conditioned on discrete variables.

DiscreteKeys discrete_keys = last_conditional->discreteKeys();

// If there are no discrete variables, return the eliminated Bayes tree.
if (discrete_keys.size() == 0) {
return bayesTree;
}

AlgebraicDecisionTree<Key> probPrimeTree =
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

const, doc

this->continuousProbPrimes(discrete_keys, bayesTree);

discreteGraph->add(DecisionTreeFactor(discrete_keys, probPrimeTree));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are mutating here? Highly suspicious to me. Please explain and add explanation as doc.

Copy link
Collaborator Author

@varunagrawal varunagrawal Nov 15, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the addition of the model selection term P(M|Z). We've eliminated all the continuous variables so we have the q(\mu)*sqrt(det) terms as the probPrimeTree. I have another PR after this one where I am adding the sqrt(det) term.


auto updatedBayesTree =
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

doc, const auto&

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't be const since we add the discrete clique as the root so that the Bayes tree is valid.

discreteGraph->BaseEliminateable::eliminateMultifrontal(discrete_ordering,
function);

auto discrete_clique = (*updatedBayesTree)[discrete_ordering.at(0)];
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

doc, const auto&


std::set<HybridBayesTreeClique::shared_ptr> clique_set;
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I particularly want a review on this part. Am I creating the Bayes Tree with the discrete root correctly? Unit tests seem to say yes, but if there is a better way, I am all ears.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

without documentation (e.g. with math), I don't even know what is happening. In general true for this entire piece of code. Now that you have the math, might as well add it?

for (auto node : bayesTree->nodes()) {
clique_set.insert(node.second);
}

// Set the root of the bayes tree as the discrete clique
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

more mutating? Very suspicious.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wasn't sure how else to update the BayesTree so set the discrete clique to be the root. That is a requirement for the structure of the HybridBayesTree.

for (auto clique : clique_set) {
if (clique->conditional()->parents() ==
discrete_clique->conditional()->frontals()) {
updatedBayesTree->addClique(clique, discrete_clique);

// Segregate the continuous and the discrete keys
for (auto &&key : ordering) {
if (std::find(all_continuous_keys.begin(), all_continuous_keys.end(),
key) != all_continuous_keys.end()) {
continuous_ordering.push_back(key);
} else if (std::find(all_discrete_keys.begin(), all_discrete_keys.end(),
key) != all_discrete_keys.end()) {
discrete_ordering.push_back(key);
} else {
throw std::runtime_error("Key in ordering not present in factors.");
// Remove the clique from the children of the parents since it will get
// added again in addClique.
auto clique_it = std::find(clique->parent()->children.begin(),
clique->parent()->children.end(), clique);
clique->parent()->children.erase(clique_it);
updatedBayesTree->addClique(clique, clique->parent());
}
}
return updatedBayesTree;
}

/* ************************************************************************ */
boost::shared_ptr<HybridGaussianFactorGraph::BayesTreeType>
HybridGaussianFactorGraph::eliminateMultifrontal(
OptionalOrderingType orderingType, const Eliminate &function,
OptionalVariableIndex variableIndex) const {
return BaseEliminateable::eliminateMultifrontal(orderingType, function,
variableIndex);
}

/* ************************************************************************ */
boost::shared_ptr<HybridGaussianFactorGraph::BayesTreeType>
HybridGaussianFactorGraph::eliminateMultifrontal(
const Ordering &ordering, const Eliminate &function,
OptionalVariableIndex variableIndex) const {
// Segregate the continuous and the discrete keys
Ordering continuous_ordering, discrete_ordering;
std::tie(continuous_ordering, discrete_ordering) =
this->separateContinuousDiscreteOrdering(ordering);

return this->eliminateHybridSequential(continuous_ordering,
discrete_ordering);
return this->eliminateHybridMultifrontal(
continuous_ordering, discrete_ordering, function, variableIndex);
}

} // namespace gtsam
26 changes: 26 additions & 0 deletions gtsam/hybrid/HybridGaussianFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,10 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
const DiscreteKeys& discrete_keys,
const boost::shared_ptr<BayesNetType>& continuousBayesNet,
const std::vector<DiscreteValues>& assignments) const;
DecisionTree<Key, VectorValues::shared_ptr> continuousDelta(
const DiscreteKeys& discrete_keys,
const boost::shared_ptr<BayesTreeType>& continuousBayesTree,
const std::vector<DiscreteValues>& assignments) const;

/**
* @brief Compute the unnormalized probabilities of the continuous variables
Expand All @@ -244,6 +248,12 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
AlgebraicDecisionTree<Key> continuousProbPrimes(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not a great name. It refers to math that is not in the documentation. Suggest unnormalizedProbabilities documenting the fact that it is exp(-E). On a side note: I question the wisdom of exponentiating small numbers before they will somehow be normalized. Typically we avoid doing this until we know the minimum error and then subtract the minimum error so the corresponding maximum prob' == 1.0.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are unnormalized probabilities only on the continuous variables and probPrime is the method name for computing the unnormalized probabilities, so it seemed to make sense. If you still think I should rename it to unnormalizedProbabilities, will gladly do so.

Regarding the exponentiation, the minimum error term will always be 0 due to pruning.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update: this function will get deleted.

const DiscreteKeys& discrete_keys,
const boost::shared_ptr<BayesNetType>& continuousBayesNet) const;
AlgebraicDecisionTree<Key> continuousProbPrimes(
const DiscreteKeys& discrete_keys,
const boost::shared_ptr<BayesTreeType>& continuousBayesTree) const;

std::pair<Ordering, Ordering> separateContinuousDiscreteOrdering(
const Ordering& ordering) const;

/**
* @brief Custom elimination function which computes the correct
Expand All @@ -269,6 +279,22 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
const Eliminate& function = EliminationTraitsType::DefaultEliminate,
OptionalVariableIndex variableIndex = boost::none) const;

boost::shared_ptr<BayesTreeType> eliminateHybridMultifrontal(
const boost::optional<Ordering> continuous = boost::none,
const boost::optional<Ordering> discrete = boost::none,
const Eliminate& function = EliminationTraitsType::DefaultEliminate,
OptionalVariableIndex variableIndex = boost::none) const;

boost::shared_ptr<BayesTreeType> eliminateMultifrontal(
OptionalOrderingType orderingType = boost::none,
const Eliminate& function = EliminationTraitsType::DefaultEliminate,
OptionalVariableIndex variableIndex = boost::none) const;

boost::shared_ptr<BayesTreeType> eliminateMultifrontal(
const Ordering& ordering,
const Eliminate& function = EliminationTraitsType::DefaultEliminate,
OptionalVariableIndex variableIndex = boost::none) const;

/**
* @brief Return a Colamd constrained ordering where the discrete keys are
* eliminated after the continuous keys.
Expand Down
7 changes: 7 additions & 0 deletions gtsam/hybrid/tests/testHybridBayesTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,13 @@ TEST(HybridBayesTree, Optimize) {
dfg.push_back(
boost::dynamic_pointer_cast<DecisionTreeFactor>(factor->inner()));
}

// Add the probabilities for each branch
DiscreteKeys discrete_keys = {{M(0), 2}, {M(1), 2}, {M(2), 2}};
vector<double> probs = {0.012519475, 0.041280228, 0.075018647, 0.081663656,
0.037152205, 0.12248971, 0.07349729, 0.08};
AlgebraicDecisionTree<Key> potentials(discrete_keys, probs);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is added but not used?

dfg.emplace_shared<DecisionTreeFactor>(discrete_keys, probs);

DiscreteValues expectedMPE = dfg.optimize();
VectorValues expectedValues = hybridBayesNet->optimize(expectedMPE);
Expand Down
Loading