Skip to content

Commit

Permalink
Merge pull request #1280 from borglab/hybrid/optimize
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal authored Aug 29, 2022
2 parents 41b5354 + 7272268 commit a6b9554
Show file tree
Hide file tree
Showing 11 changed files with 278 additions and 17 deletions.
23 changes: 17 additions & 6 deletions gtsam/hybrid/HybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
*/

#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/hybrid/HybridLookupDAG.h>
#include <gtsam/hybrid/HybridValues.h>

namespace gtsam {

Expand Down Expand Up @@ -112,22 +112,27 @@ HybridBayesNet HybridBayesNet::prune(

/* ************************************************************************* */
GaussianMixture::shared_ptr HybridBayesNet::atGaussian(size_t i) const {
return boost::dynamic_pointer_cast<GaussianMixture>(factors_.at(i)->inner());
return factors_.at(i)->asMixture();
}

/* ************************************************************************* */
DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const {
return boost::dynamic_pointer_cast<DiscreteConditional>(
factors_.at(i)->inner());
return factors_.at(i)->asDiscreteConditional();
}

/* ************************************************************************* */
GaussianBayesNet HybridBayesNet::choose(
const DiscreteValues &assignment) const {
GaussianBayesNet gbn;
for (size_t idx = 0; idx < size(); idx++) {
GaussianMixture gm = *this->atGaussian(idx);
gbn.push_back(gm(assignment));
try {
GaussianMixture gm = *this->atGaussian(idx);
gbn.push_back(gm(assignment));

} catch (std::exception &exc) {
// if factor at `idx` is discrete-only, just continue.
continue;
}
}
return gbn;
}
Expand All @@ -138,4 +143,10 @@ HybridValues HybridBayesNet::optimize() const {
return dag.argmax();
}

/* *******************************************************************************/
VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const {
GaussianBayesNet gbn = this->choose(assignment);
return gbn.optimize();
}

} // namespace gtsam
9 changes: 9 additions & 0 deletions gtsam/hybrid/HybridBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,15 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
/// TODO(Shangjie) do we need to create a HybridGaussianBayesNet class, and
/// put this method there?
HybridValues optimize() const;

/**
* @brief Given the discrete assignment, return the optimized estimate for the
* selected Gaussian BayesNet.
*
* @param assignment An assignment of discrete values.
* @return Values
*/
VectorValues optimize(const DiscreteValues &assignment) const;
};

} // namespace gtsam
44 changes: 44 additions & 0 deletions gtsam/hybrid/HybridBayesTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,48 @@ bool HybridBayesTree::equals(const This& other, double tol) const {
return Base::equals(other, tol);
}

/* ************************************************************************* */
VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
GaussianBayesNet gbn;

KeyVector added_keys;

// Iterate over all the nodes in the BayesTree
for (auto&& node : nodes()) {
// Check if conditional being added is already in the Bayes net.
if (std::find(added_keys.begin(), added_keys.end(), node.first) ==
added_keys.end()) {
// Access the clique and get the underlying hybrid conditional
HybridBayesTreeClique::shared_ptr clique = node.second;
HybridConditional::shared_ptr conditional = clique->conditional();

KeyVector frontals(conditional->frontals().begin(),
conditional->frontals().end());

// Record the key being added
added_keys.insert(added_keys.end(), frontals.begin(), frontals.end());

// If conditional is hybrid (and not discrete-only), we get the Gaussian
// Conditional corresponding to the assignment and add it to the Gaussian
// Bayes Net.
if (conditional->isHybrid()) {
auto gm = conditional->asMixture();
GaussianConditional::shared_ptr gaussian_conditional =
(*gm)(assignment);

gbn.push_back(gaussian_conditional);
}
}
}
// If TBB is enabled, the bayes net order gets reversed,
// so we pre-reverse it
#ifdef GTSAM_USE_TBB
auto reversed = boost::adaptors::reverse(gbn);
gbn = GaussianBayesNet(reversed.begin(), reversed.end());
#endif

// Return the optimized bayes net.
return gbn.optimize();
}

} // namespace gtsam
9 changes: 9 additions & 0 deletions gtsam/hybrid/HybridBayesTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,15 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree<HybridBayesTreeClique> {
/** Check equality */
bool equals(const This& other, double tol = 1e-9) const;

/**
* @brief Recursively optimize the BayesTree to produce a vector solution.
*
* @param assignment The discrete values assignment to select the Gaussian
* mixtures.
* @return VectorValues
*/
VectorValues optimize(const DiscreteValues& assignment) const;

/// @}
};

Expand Down
9 changes: 4 additions & 5 deletions gtsam/hybrid/HybridConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class GTSAM_EXPORT HybridConditional
BaseConditional; ///< Typedef to our conditional base class

protected:
// Type-erased pointer to the inner type
/// Type-erased pointer to the inner type
boost::shared_ptr<Factor> inner_;

public:
Expand Down Expand Up @@ -127,8 +127,7 @@ class GTSAM_EXPORT HybridConditional
* @param gaussianMixture Gaussian Mixture Conditional used to create the
* HybridConditional.
*/
HybridConditional(
boost::shared_ptr<GaussianMixture> gaussianMixture);
HybridConditional(boost::shared_ptr<GaussianMixture> gaussianMixture);

/**
* @brief Return HybridConditional as a GaussianMixture
Expand Down Expand Up @@ -168,10 +167,10 @@ class GTSAM_EXPORT HybridConditional
/// Get the type-erased pointer to the inner type
boost::shared_ptr<Factor> inner() { return inner_; }

}; // DiscreteConditional
}; // HybridConditional

// traits
template <>
struct traits<HybridConditional> : public Testable<DiscreteConditional> {};
struct traits<HybridConditional> : public Testable<HybridConditional> {};

} // namespace gtsam
22 changes: 19 additions & 3 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,9 @@ continuousElimination(const HybridGaussianFactorGraph &factors,
for (auto &fp : factors) {
if (auto ptr = boost::dynamic_pointer_cast<HybridGaussianFactor>(fp)) {
gfg.push_back(ptr->inner());
} else if (auto p =
boost::static_pointer_cast<HybridConditional>(fp)->inner()) {
gfg.push_back(boost::static_pointer_cast<GaussianConditional>(p));
} else if (auto ptr = boost::static_pointer_cast<HybridConditional>(fp)) {
gfg.push_back(
boost::static_pointer_cast<GaussianConditional>(ptr->inner()));
} else {
// It is an orphan wrapped conditional
}
Expand Down Expand Up @@ -401,4 +401,20 @@ void HybridGaussianFactorGraph::add(DecisionTreeFactor::shared_ptr factor) {
FactorGraph::add(boost::make_shared<HybridDiscreteFactor>(factor));
}

/* ************************************************************************ */
const Ordering HybridGaussianFactorGraph::getHybridOrdering(
OptionalOrderingType orderingType) const {
KeySet discrete_keys;
for (auto &factor : factors_) {
for (const DiscreteKey &k : factor->discreteKeys()) {
discrete_keys.insert(k.first);
}
}

const VariableIndex index(factors_);
Ordering ordering = Ordering::ColamdConstrainedLast(
index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true);
return ordering;
}

} // namespace gtsam
9 changes: 9 additions & 0 deletions gtsam/hybrid/HybridGaussianFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,15 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
Base::push_back(sharedFactor);
}
}

/**
* @brief
*
* @param orderingType
* @return const Ordering
*/
const Ordering getHybridOrdering(
OptionalOrderingType orderingType = boost::none) const;
};

} // namespace gtsam
4 changes: 2 additions & 2 deletions gtsam/hybrid/tests/Switching.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ struct Switching {
// Add "motion models".
for (size_t k = 1; k < K; k++) {
KeyVector keys = {X(k), X(k + 1)};
auto motion_models = motionModels(k);
auto motion_models = motionModels(k, between_sigma);
std::vector<NonlinearFactor::shared_ptr> components;
for (auto &&f : motion_models) {
components.push_back(boost::dynamic_pointer_cast<NonlinearFactor>(f));
Expand All @@ -155,7 +155,7 @@ struct Switching {
}

// Add measurement factors
auto measurement_noise = noiseModel::Isotropic::Sigma(1, 0.1);
auto measurement_noise = noiseModel::Isotropic::Sigma(1, prior_sigma);
for (size_t k = 2; k <= K; k++) {
nonlinearFactorGraph.emplace_nonlinear<PriorFactor<double>>(
X(k), 1.0 * (k - 1), measurement_noise);
Expand Down
71 changes: 71 additions & 0 deletions gtsam/hybrid/tests/testHybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
*/

#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/nonlinear/NonlinearFactorGraph.h>

#include "Switching.h"

Expand Down Expand Up @@ -85,6 +86,76 @@ TEST(HybridBayesNet, Choose) {
*gbn.at(3)));
}

/* ****************************************************************************/
// Test bayes net optimize
TEST(HybridBayesNet, OptimizeAssignment) {
Switching s(4);

Ordering ordering;
for (auto&& kvp : s.linearizationPoint) {
ordering += kvp.key;
}

HybridBayesNet::shared_ptr hybridBayesNet;
HybridGaussianFactorGraph::shared_ptr remainingFactorGraph;
std::tie(hybridBayesNet, remainingFactorGraph) =
s.linearizedFactorGraph.eliminatePartialSequential(ordering);

DiscreteValues assignment;
assignment[M(1)] = 1;
assignment[M(2)] = 1;
assignment[M(3)] = 1;

VectorValues delta = hybridBayesNet->optimize(assignment);

// The linearization point has the same value as the key index,
// e.g. X(1) = 1, X(2) = 2,
// but the factors specify X(k) = k-1, so delta should be -1.
VectorValues expected_delta;
expected_delta.insert(make_pair(X(1), -Vector1::Ones()));
expected_delta.insert(make_pair(X(2), -Vector1::Ones()));
expected_delta.insert(make_pair(X(3), -Vector1::Ones()));
expected_delta.insert(make_pair(X(4), -Vector1::Ones()));

EXPECT(assert_equal(expected_delta, delta));
}

/* ****************************************************************************/
// Test bayes net optimize
TEST(HybridBayesNet, Optimize) {
Switching s(4);

Ordering ordering;
for (auto&& kvp : s.linearizationPoint) {
ordering += kvp.key;
}

Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesNet::shared_ptr hybridBayesNet =
s.linearizedFactorGraph.eliminateSequential(hybridOrdering);

HybridValues delta = hybridBayesNet->optimize();

delta.print();
VectorValues correct;
correct.insert(X(1), 0 * Vector1::Ones());
correct.insert(X(2), 1 * Vector1::Ones());
correct.insert(X(3), 2 * Vector1::Ones());
correct.insert(X(4), 3 * Vector1::Ones());

DiscreteValues assignment111;
assignment111[M(1)] = 1;
assignment111[M(2)] = 1;
assignment111[M(3)] = 1;
std::cout << hybridBayesNet->choose(assignment111).error(correct) << std::endl;

DiscreteValues assignment101;
assignment101[M(1)] = 1;
assignment101[M(2)] = 0;
assignment101[M(3)] = 1;
std::cout << hybridBayesNet->choose(assignment101).error(correct) << std::endl;
}

/* ************************************************************************* */
int main() {
TestResult tr;
Expand Down
Loading

0 comments on commit a6b9554

Please sign in to comment.