Skip to content

Commit

Permalink
Merge pull request #1221 from borglab/hybrid/hybrid-factor-graph
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal authored Aug 24, 2022
2 parents 525e512 + df64982 commit ff20a50
Show file tree
Hide file tree
Showing 38 changed files with 3,647 additions and 278 deletions.
10 changes: 10 additions & 0 deletions gtsam/discrete/discrete.i
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,16 @@ class DiscreteBayesTree {
};

#include <gtsam/discrete/DiscreteLookupDAG.h>

class DiscreteLookupTable : gtsam::DiscreteConditional{
DiscreteLookupTable(size_t nFrontals, const gtsam::DiscreteKeys& keys,
const gtsam::DecisionTreeFactor::ADT& potentials);
void print(
const std::string& s = "Discrete Lookup Table: ",
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const;
size_t argmax(const gtsam::DiscreteValues& parentsValues) const;
};

class DiscreteLookupDAG {
DiscreteLookupDAG();
void push_back(const gtsam::DiscreteLookupTable* table);
Expand Down
27 changes: 26 additions & 1 deletion gtsam/hybrid/GaussianMixture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,36 @@ void GaussianMixture::print(const std::string &s,
"", [&](Key k) { return formatter(k); },
[&](const GaussianConditional::shared_ptr &gf) -> std::string {
RedirectCout rd;
if (!gf->empty())
if (gf && !gf->empty())
gf->print("", formatter);
else
return {"nullptr"};
return rd.str();
});
}

/* *******************************************************************************/
void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) {
// Functional which loops over all assignments and create a set of
// GaussianConditionals
auto pruner = [&decisionTree](
const Assignment<Key> &choices,
const GaussianConditional::shared_ptr &conditional)
-> GaussianConditional::shared_ptr {
// typecast so we can use this to get probability value
DiscreteValues values(choices);

if (decisionTree(values) == 0.0) {
// empty aka null pointer
boost::shared_ptr<GaussianConditional> null;
return null;
} else {
return conditional;
}
};

auto pruned_conditionals = conditionals_.apply(pruner);
conditionals_.root_ = pruned_conditionals.root_;
}

} // namespace gtsam
23 changes: 18 additions & 5 deletions gtsam/hybrid/GaussianMixture.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include <gtsam/discrete/DecisionTree-inl.h>
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/inference/Conditional.h>
Expand All @@ -30,11 +31,14 @@ namespace gtsam {

/**
* @brief A conditional of gaussian mixtures indexed by discrete variables, as
* part of a Bayes Network.
* part of a Bayes Network. This is the result of the elimination of a
* continuous variable in a hybrid scheme, such that the remaining variables are
* discrete+continuous.
*
* Represents the conditional density P(X | M, Z) where X is a continuous random
* variable, M is the selection of discrete variables corresponding to a subset
* of the Gaussian variables and Z is parent of this node
* Represents the conditional density P(X | M, Z) where X is the set of
* continuous random variables, M is the selection of discrete variables
* corresponding to a subset of the Gaussian variables and Z is parent of this
* node .
*
* The probability P(x|y,z,...) is proportional to
* \f$ \sum_i k_i \exp - \frac{1}{2} |R_i x - (d_i - S_i y - T_i z - ...)|^2 \f$
Expand Down Expand Up @@ -118,7 +122,7 @@ class GTSAM_EXPORT GaussianMixture
/// Test equality with base HybridFactor
bool equals(const HybridFactor &lf, double tol = 1e-9) const override;

/* print utility */
/// Print utility
void print(
const std::string &s = "GaussianMixture\n",
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
Expand All @@ -128,6 +132,15 @@ class GTSAM_EXPORT GaussianMixture
/// Getter for the underlying Conditionals DecisionTree
const Conditionals &conditionals();

/**
* @brief Prune the decision tree of Gaussian factors as per the discrete
* `decisionTree`.
*
* @param decisionTree A pruned decision tree of discrete keys where the
* leaves are probabilities.
*/
void prune(const DecisionTreeFactor &decisionTree);

/**
* @brief Merge the Gaussian Factor Graphs in `this` and `sum` while
* maintaining the decision tree structure.
Expand Down
2 changes: 1 addition & 1 deletion gtsam/hybrid/GaussianMixtureFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ GaussianMixtureFactor GaussianMixtureFactor::FromFactors(
void GaussianMixtureFactor::print(const std::string &s,
const KeyFormatter &formatter) const {
HybridFactor::print(s, formatter);
std::cout << "]{\n";
std::cout << "{\n";
factors_.print(
"", [&](Key k) { return formatter(k); },
[&](const GaussianFactor::shared_ptr &gf) -> std::string {
Expand Down
4 changes: 2 additions & 2 deletions gtsam/hybrid/GaussianMixtureFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/linear/GaussianFactor.h>

namespace gtsam {
Expand Down Expand Up @@ -108,7 +108,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
bool equals(const HybridFactor &lf, double tol = 1e-9) const override;

void print(
const std::string &s = "HybridFactor\n",
const std::string &s = "GaussianMixtureFactor\n",
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
/// @}

Expand Down
125 changes: 125 additions & 0 deletions gtsam/hybrid/HybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,132 @@
* @file HybridBayesNet.cpp
* @brief A bayes net of Gaussian Conditionals indexed by discrete keys.
* @author Fan Jiang
* @author Varun Agrawal
* @author Shangjie Xue
* @date January 2022
*/

#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/hybrid/HybridLookupDAG.h>

namespace gtsam {

/* ************************************************************************* */
/// Return the DiscreteKey vector as a set.
static std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) {
std::set<DiscreteKey> s;
s.insert(dkeys.begin(), dkeys.end());
return s;
}

/* ************************************************************************* */
HybridBayesNet HybridBayesNet::prune(
const DecisionTreeFactor::shared_ptr &discreteFactor) const {
/* To Prune, we visitWith every leaf in the GaussianMixture.
* For each leaf, using the assignment we can check the discrete decision tree
* for 0.0 probability, then just set the leaf to a nullptr.
*
* We can later check the GaussianMixture for just nullptrs.
*/

HybridBayesNet prunedBayesNetFragment;

// Functional which loops over all assignments and create a set of
// GaussianConditionals
auto pruner = [&](const Assignment<Key> &choices,
const GaussianConditional::shared_ptr &conditional)
-> GaussianConditional::shared_ptr {
// typecast so we can use this to get probability value
DiscreteValues values(choices);

if ((*discreteFactor)(values) == 0.0) {
// empty aka null pointer
boost::shared_ptr<GaussianConditional> null;
return null;
} else {
return conditional;
}
};

// Go through all the conditionals in the
// Bayes Net and prune them as per discreteFactor.
for (size_t i = 0; i < this->size(); i++) {
HybridConditional::shared_ptr conditional = this->at(i);

GaussianMixture::shared_ptr gaussianMixture =
boost::dynamic_pointer_cast<GaussianMixture>(conditional->inner());

if (gaussianMixture) {
// We may have mixtures with less discrete keys than discreteFactor so we
// skip those since the label assignment does not exist.
auto gmKeySet = DiscreteKeysAsSet(gaussianMixture->discreteKeys());
auto dfKeySet = DiscreteKeysAsSet(discreteFactor->discreteKeys());
if (gmKeySet != dfKeySet) {
// Add the gaussianMixture which doesn't have to be pruned.
prunedBayesNetFragment.push_back(
boost::make_shared<HybridConditional>(gaussianMixture));
continue;
}

// Run the pruning to get a new, pruned tree
GaussianMixture::Conditionals prunedTree =
gaussianMixture->conditionals().apply(pruner);

DiscreteKeys discreteKeys = gaussianMixture->discreteKeys();
// reverse keys to get a natural ordering
std::reverse(discreteKeys.begin(), discreteKeys.end());

// Convert from boost::iterator_range to KeyVector
// so we can pass it to constructor.
KeyVector frontals(gaussianMixture->frontals().begin(),
gaussianMixture->frontals().end()),
parents(gaussianMixture->parents().begin(),
gaussianMixture->parents().end());

// Create the new gaussian mixture and add it to the bayes net.
auto prunedGaussianMixture = boost::make_shared<GaussianMixture>(
frontals, parents, discreteKeys, prunedTree);

// Type-erase and add to the pruned Bayes Net fragment.
prunedBayesNetFragment.push_back(
boost::make_shared<HybridConditional>(prunedGaussianMixture));

} else {
// Add the non-GaussianMixture conditional
prunedBayesNetFragment.push_back(conditional);
}
}

return prunedBayesNetFragment;
}

/* ************************************************************************* */
GaussianMixture::shared_ptr HybridBayesNet::atGaussian(size_t i) const {
return boost::dynamic_pointer_cast<GaussianMixture>(factors_.at(i)->inner());
}

/* ************************************************************************* */
DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const {
return boost::dynamic_pointer_cast<DiscreteConditional>(
factors_.at(i)->inner());
}

/* ************************************************************************* */
GaussianBayesNet HybridBayesNet::choose(
const DiscreteValues &assignment) const {
GaussianBayesNet gbn;
for (size_t idx = 0; idx < size(); idx++) {
GaussianMixture gm = *this->atGaussian(idx);
gbn.push_back(gm(assignment));
}
return gbn;
}

/* *******************************************************************************/
HybridValues HybridBayesNet::optimize() const {
auto dag = HybridLookupDAG::FromBayesNet(*this);
return dag.argmax();
}

} // namespace gtsam
36 changes: 36 additions & 0 deletions gtsam/hybrid/HybridBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@

#pragma once

#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/BayesNet.h>
#include <gtsam/linear/GaussianBayesNet.h>

namespace gtsam {

Expand All @@ -36,6 +39,39 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {

/** Construct empty bayes net */
HybridBayesNet() = default;

/// Prune the Hybrid Bayes Net given the discrete decision tree.
HybridBayesNet prune(
const DecisionTreeFactor::shared_ptr &discreteFactor) const;

/// Add HybridConditional to Bayes Net
using Base::add;

/// Add a discrete conditional to the Bayes Net.
void add(const DiscreteKey &key, const std::string &table) {
push_back(
HybridConditional(boost::make_shared<DiscreteConditional>(key, table)));
}

/// Get a specific Gaussian mixture by index `i`.
GaussianMixture::shared_ptr atGaussian(size_t i) const;

/// Get a specific discrete conditional by index `i`.
DiscreteConditional::shared_ptr atDiscrete(size_t i) const;

/**
* @brief Get the Gaussian Bayes Net which corresponds to a specific discrete
* value assignment.
*
* @param assignment The discrete value assignment for the discrete keys.
* @return GaussianBayesNet
*/
GaussianBayesNet choose(const DiscreteValues &assignment) const;

/// Solve the HybridBayesNet by back-substitution.
/// TODO(Shangjie) do we need to create a HybridGaussianBayesNet class, and
/// put this method there?
HybridValues optimize() const;
};

} // namespace gtsam
5 changes: 4 additions & 1 deletion gtsam/hybrid/HybridBayesTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,10 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree<HybridBayesTreeClique> {
/**
* @brief Class for Hybrid Bayes tree orphan subtrees.
*
* This does special stuff for the hybrid case
* This object stores parent keys in our base type factor so that
* eliminating those parent keys will pull this subtree into the
* elimination.
* This does special stuff for the hybrid case.
*
* @tparam CLIQUE
*/
Expand Down
17 changes: 12 additions & 5 deletions gtsam/hybrid/HybridFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ bool HybridFactor::equals(const HybridFactor &lf, double tol) const {
const This *e = dynamic_cast<const This *>(&lf);
return e != nullptr && Base::equals(*e, tol) &&
isDiscrete_ == e->isDiscrete_ && isContinuous_ == e->isContinuous_ &&
isHybrid_ == e->isHybrid_ && nrContinuous_ == e->nrContinuous_;
isHybrid_ == e->isHybrid_ && continuousKeys_ == e->continuousKeys_ &&
discreteKeys_ == e->discreteKeys_;
}

/* ************************************************************************ */
Expand All @@ -88,17 +89,23 @@ void HybridFactor::print(const std::string &s,
if (isContinuous_) std::cout << "Continuous ";
if (isDiscrete_) std::cout << "Discrete ";
if (isHybrid_) std::cout << "Hybrid ";
for (size_t c=0; c<continuousKeys_.size(); c++) {
std::cout << "[";
for (size_t c = 0; c < continuousKeys_.size(); c++) {
std::cout << formatter(continuousKeys_.at(c));
if (c < continuousKeys_.size() - 1) {
std::cout << " ";
} else {
std::cout << "; ";
std::cout << (discreteKeys_.size() > 0 ? "; " : "");
}
}
for(auto && discreteKey: discreteKeys_) {
std::cout << formatter(discreteKey.first) << " ";
for (size_t d = 0; d < discreteKeys_.size(); d++) {
std::cout << formatter(discreteKeys_.at(d).first);
if (d < discreteKeys_.size() - 1) {
std::cout << " ";
}

}
std::cout << "]";
}

} // namespace gtsam
11 changes: 7 additions & 4 deletions gtsam/hybrid/HybridFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ class GTSAM_EXPORT HybridFactor : public Factor {
size_t nrContinuous_ = 0;

protected:
// Set of DiscreteKeys for this factor.
DiscreteKeys discreteKeys_;

/// Record continuous keys for book-keeping
KeyVector continuousKeys_;

Expand Down Expand Up @@ -120,10 +120,13 @@ class GTSAM_EXPORT HybridFactor : public Factor {
bool isHybrid() const { return isHybrid_; }

/// Return the number of continuous variables in this factor.
size_t nrContinuous() const { return nrContinuous_; }
size_t nrContinuous() const { return continuousKeys_.size(); }

/// Return the discrete keys for this factor.
const DiscreteKeys &discreteKeys() const { return discreteKeys_; }

/// Return vector of discrete keys.
DiscreteKeys discreteKeys() const { return discreteKeys_; }
/// Return only the continuous keys for this factor.
const KeyVector &continuousKeys() const { return continuousKeys_; }

/// @}
};
Expand Down
Loading

0 comments on commit ff20a50

Please sign in to comment.