Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix HybridBayesTree Optimize #1287

Merged
merged 2 commits into from
Sep 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions gtsam/base/treeTraversal-inst.h
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,6 @@ void PrintForest(const FOREST& forest, std::string str,
PrintForestVisitorPre visitor(keyFormatter);
DepthFirstForest(forest, str, visitor);
}
}
} // namespace treeTraversal

}
} // namespace gtsam
174 changes: 94 additions & 80 deletions gtsam/hybrid/HybridBayesTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/inference/BayesTree-inst.h>
#include <gtsam/inference/BayesTreeCliqueBase-inst.h>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move inst to bottom

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried but the formatter keeps moving it back 😕

#include <gtsam/linear/GaussianJunctionTree.h>

namespace gtsam {

Expand All @@ -39,95 +40,108 @@ bool HybridBayesTree::equals(const This& other, double tol) const {

/* ************************************************************************* */
HybridValues HybridBayesTree::optimize() const {
HybridBayesNet hbn;
DiscreteBayesNet dbn;

KeyVector added_keys;

// Iterate over all the nodes in the BayesTree
for (auto&& node : nodes()) {
// Check if conditional being added is already in the Bayes net.
if (std::find(added_keys.begin(), added_keys.end(), node.first) ==
added_keys.end()) {
// Access the clique and get the underlying hybrid conditional
HybridBayesTreeClique::shared_ptr clique = node.second;
HybridConditional::shared_ptr conditional = clique->conditional();

// Record the key being added
added_keys.insert(added_keys.end(), conditional->frontals().begin(),
conditional->frontals().end());

if (conditional->isDiscrete()) {
// If discrete, we use it to compute the MPE
dbn.push_back(conditional->asDiscreteConditional());

} else {
// Else conditional is hybrid or continuous-only,
// so we directly add it to the Hybrid Bayes net.
hbn.push_back(conditional);
}
}
DiscreteValues mpe;

auto root = roots_.at(0);
// Access the clique and get the underlying hybrid conditional
HybridConditional::shared_ptr root_conditional = root->conditional();

// The root should be discrete only, we compute the MPE
if (root_conditional->isDiscrete()) {
dbn.push_back(root_conditional->asDiscreteConditional());
mpe = DiscreteFactorGraph(dbn).optimize();
} else {
throw std::runtime_error(
"HybridBayesTree root is not discrete-only. Please check elimination "
"ordering or use continuous factor graph.");
}
// Get the MPE
DiscreteValues mpe = DiscreteFactorGraph(dbn).optimize();
// Given the MPE, compute the optimal continuous values.
GaussianBayesNet gbn = hbn.choose(mpe);

// If TBB is enabled, the bayes net order gets reversed,
// so we pre-reverse it
#ifdef GTSAM_USE_TBB
auto reversed = boost::adaptors::reverse(gbn);
gbn = GaussianBayesNet(reversed.begin(), reversed.end());
#endif

return HybridValues(mpe, gbn.optimize());

VectorValues values = optimize(mpe);
return HybridValues(mpe, values);
}

/* ************************************************************************* */
VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
GaussianBayesNet gbn;

KeyVector added_keys;

// Iterate over all the nodes in the BayesTree
for (auto&& node : nodes()) {
// Check if conditional being added is already in the Bayes net.
if (std::find(added_keys.begin(), added_keys.end(), node.first) ==
added_keys.end()) {
// Access the clique and get the underlying hybrid conditional
HybridBayesTreeClique::shared_ptr clique = node.second;
HybridConditional::shared_ptr conditional = clique->conditional();

// Record the key being added
added_keys.insert(added_keys.end(), conditional->frontals().begin(),
conditional->frontals().end());

// If conditional is hybrid (and not discrete-only), we get the Gaussian
// Conditional corresponding to the assignment and add it to the Gaussian
// Bayes Net.
if (conditional->isHybrid()) {
auto gm = conditional->asMixture();
GaussianConditional::shared_ptr gaussian_conditional =
(*gm)(assignment);

gbn.push_back(gaussian_conditional);

} else if (conditional->isContinuous()) {
// If conditional is Gaussian, we simply add it to the Bayes net.
gbn.push_back(conditional->asGaussian());
}
/**
* @brief Helper class for Depth First Forest traversal on the HybridBayesTree.
*
* When traversing the tree, the pre-order visitor will receive an instance of
* this class with the parent clique data.
*/
struct HybridAssignmentData {
const DiscreteValues assignment_;
GaussianBayesTree::sharedNode parentClique_;
// The gaussian bayes tree that will be recursively created.
GaussianBayesTree* gaussianbayesTree_;

/**
* @brief Construct a new Hybrid Assignment Data object.
*
* @param assignment The MPE assignment for the optimal Gaussian cliques.
* @param parentClique The clique from the parent node of the current node.
* @param gbt The Gaussian Bayes Tree being generated during tree traversal.
*/
HybridAssignmentData(const DiscreteValues& assignment,
const GaussianBayesTree::sharedNode& parentClique,
GaussianBayesTree* gbt)
: assignment_(assignment),
parentClique_(parentClique),
gaussianbayesTree_(gbt) {}

/**
* @brief A function used during tree traversal that operators on each node
* before visiting the node's children.
*
* @param node The current node being visited.
* @param parentData The HybridAssignmentData from the parent node.
* @return HybridAssignmentData
*/
static HybridAssignmentData AssignmentPreOrderVisitor(
const HybridBayesTree::sharedNode& node,
HybridAssignmentData& parentData) {
// Extract the gaussian conditional from the Hybrid clique
HybridConditional::shared_ptr hybrid_conditional = node->conditional();
GaussianConditional::shared_ptr conditional;
if (hybrid_conditional->isHybrid()) {
conditional = (*hybrid_conditional->asMixture())(parentData.assignment_);
} else if (hybrid_conditional->isContinuous()) {
conditional = hybrid_conditional->asGaussian();
} else {
// Discrete only conditional, so we set to empty gaussian conditional
conditional = boost::make_shared<GaussianConditional>();
}

// Create the GaussianClique for the current node
auto clique = boost::make_shared<GaussianBayesTree::Node>(conditional);
// Add the current clique to the GaussianBayesTree.
parentData.gaussianbayesTree_->addClique(clique, parentData.parentClique_);

// Create new HybridAssignmentData where the current node is the parent
// This will be passed down to the children nodes
HybridAssignmentData data(parentData.assignment_, clique,
parentData.gaussianbayesTree_);
return data;
}
};

/* *************************************************************************
*/
VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
GaussianBayesTree gbt;
HybridAssignmentData rootData(assignment, 0, &gbt);
{
treeTraversal::no_op visitorPost;
// Limits OpenMP threads since we're mixing TBB and OpenMP
TbbOpenMPMixedScope threadLimiter;
treeTraversal::DepthFirstForestParallel(
*this, rootData, HybridAssignmentData::AssignmentPreOrderVisitor,
visitorPost);
}

// If TBB is enabled, the bayes net order gets reversed,
// so we pre-reverse it
#ifdef GTSAM_USE_TBB
auto reversed = boost::adaptors::reverse(gbn);
gbn = GaussianBayesNet(reversed.begin(), reversed.end());
#endif
VectorValues result = gbt.optimize();

// Return the optimized bayes net.
return gbn.optimize();
// Return the optimized bayes net result.
return result;
}

} // namespace gtsam
20 changes: 20 additions & 0 deletions gtsam/hybrid/tests/testHybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/nonlinear/NonlinearFactorGraph.h>

#include "Switching.h"
Expand Down Expand Up @@ -149,6 +150,25 @@ TEST(HybridBayesNet, Optimize) {
EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5));
}

/* ****************************************************************************/
// Test bayes net multifrontal optimize
TEST(HybridBayesNet, OptimizeMultifrontal) {
Switching s(4);

Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesTree::shared_ptr hybridBayesTree =
s.linearizedFactorGraph.eliminateMultifrontal(hybridOrdering);
HybridValues delta = hybridBayesTree->optimize();

VectorValues expectedValues;
expectedValues.insert(X(1), -0.999904 * Vector1::Ones());
expectedValues.insert(X(2), -0.99029 * Vector1::Ones());
expectedValues.insert(X(3), -1.00971 * Vector1::Ones());
expectedValues.insert(X(4), -1.0001 * Vector1::Ones());

EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5));
}

/* ****************************************************************************/
// Test HybridBayesNet serialization.
TEST(HybridBayesNet, Serialization) {
Expand Down