Skip to content

Commit

Permalink
Merge pull request #1203 from varunagrawal/fan/prototype-hybrid-tr
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal authored Jun 3, 2022
2 parents 12aed1f + eeecb27 commit 9a1eb02
Show file tree
Hide file tree
Showing 43 changed files with 3,374 additions and 24 deletions.
1 change: 1 addition & 0 deletions gtsam/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ set (gtsam_subdirs
inference
symbolic
discrete
hybrid
linear
nonlinear
sam
Expand Down
8 changes: 8 additions & 0 deletions gtsam/hybrid/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Install headers
set(subdir hybrid)
file(GLOB hybrid_headers "*.h")
# FIXME: exclude headers
install(FILES ${hybrid_headers} DESTINATION include/gtsam/hybrid)

# Add all tests
add_subdirectory(tests)
110 changes: 110 additions & 0 deletions gtsam/hybrid/GaussianMixture.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */

/**
* @file GaussianMixture.cpp
* @brief A hybrid conditional in the Conditional Linear Gaussian scheme
* @author Fan Jiang
* @author Varun Agrawal
* @author Frank Dellaert
* @date Mar 12, 2022
*/

#include <gtsam/base/utilities.h>
#include <gtsam/discrete/DecisionTree-inl.h>
#include <gtsam/hybrid/GaussianMixture.h>
#include <gtsam/inference/Conditional-inst.h>
#include <gtsam/linear/GaussianFactorGraph.h>

namespace gtsam {

GaussianMixture::GaussianMixture(
const KeyVector &continuousFrontals, const KeyVector &continuousParents,
const DiscreteKeys &discreteParents,
const GaussianMixture::Conditionals &conditionals)
: BaseFactor(CollectKeys(continuousFrontals, continuousParents),
discreteParents),
BaseConditional(continuousFrontals.size()),
conditionals_(conditionals) {}

/* *******************************************************************************/
const GaussianMixture::Conditionals &
GaussianMixture::conditionals() {
return conditionals_;
}

/* *******************************************************************************/
GaussianMixture GaussianMixture::FromConditionals(
const KeyVector &continuousFrontals, const KeyVector &continuousParents,
const DiscreteKeys &discreteParents,
const std::vector<GaussianConditional::shared_ptr> &conditionalsList) {
Conditionals dt(discreteParents, conditionalsList);

return GaussianMixture(continuousFrontals, continuousParents,
discreteParents, dt);
}

/* *******************************************************************************/
GaussianMixture::Sum GaussianMixture::add(
const GaussianMixture::Sum &sum) const {
using Y = GaussianFactorGraph;
auto add = [](const Y &graph1, const Y &graph2) {
auto result = graph1;
result.push_back(graph2);
return result;
};
const Sum tree = asGaussianFactorGraphTree();
return sum.empty() ? tree : sum.apply(tree, add);
}

/* *******************************************************************************/
GaussianMixture::Sum
GaussianMixture::asGaussianFactorGraphTree() const {
auto lambda = [](const GaussianFactor::shared_ptr &factor) {
GaussianFactorGraph result;
result.push_back(factor);
return result;
};
return {conditionals_, lambda};
}

/* *******************************************************************************/
bool GaussianMixture::equals(const HybridFactor &lf,
double tol) const {
const This *e = dynamic_cast<const This *>(&lf);
return e != nullptr && BaseFactor::equals(*e, tol);
}

/* *******************************************************************************/
void GaussianMixture::print(const std::string &s,
const KeyFormatter &formatter) const {
std::cout << s;
if (isContinuous()) std::cout << "Continuous ";
if (isDiscrete()) std::cout << "Discrete ";
if (isHybrid()) std::cout << "Hybrid ";
BaseConditional::print("", formatter);
std::cout << "\nDiscrete Keys = ";
for (auto &dk : discreteKeys()) {
std::cout << "(" << formatter(dk.first) << ", " << dk.second << "), ";
}
std::cout << "\n";
conditionals_.print(
"", [&](Key k) { return formatter(k); },
[&](const GaussianConditional::shared_ptr &gf) -> std::string {
RedirectCout rd;
if (!gf->empty())
gf->print("", formatter);
else
return {"nullptr"};
return rd.str();
});
}
} // namespace gtsam
133 changes: 133 additions & 0 deletions gtsam/hybrid/GaussianMixture.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */

/**
* @file GaussianMixture.h
* @brief A hybrid conditional in the Conditional Linear Gaussian scheme
* @author Fan Jiang
* @author Varun Agrawal
* @date Mar 12, 2022
*/

#pragma once

#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/inference/Conditional.h>
#include <gtsam/linear/GaussianConditional.h>

namespace gtsam {

/**
* @brief A conditional of gaussian mixtures indexed by discrete variables, as
* part of a Bayes Network.
*
* Represents the conditional density P(X | M, Z) where X is a continuous random
* variable, M is the selection of discrete variables corresponding to a subset
* of the Gaussian variables and Z is parent of this node
*
* The probability P(x|y,z,...) is proportional to
* \f$ \sum_i k_i \exp - \frac{1}{2} |R_i x - (d_i - S_i y - T_i z - ...)|^2 \f$
* where i indexes the components and k_i is a component-wise normalization
* constant.
*
*/
class GTSAM_EXPORT GaussianMixture
: public HybridFactor,
public Conditional<HybridFactor, GaussianMixture> {
public:
using This = GaussianMixture;
using shared_ptr = boost::shared_ptr<GaussianMixture>;
using BaseFactor = HybridFactor;
using BaseConditional = Conditional<HybridFactor, GaussianMixture>;

/// Alias for DecisionTree of GaussianFactorGraphs
using Sum = DecisionTree<Key, GaussianFactorGraph>;

/// typedef for Decision Tree of Gaussian Conditionals
using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;

private:
Conditionals conditionals_;

/**
* @brief Convert a DecisionTree of factors into a DT of Gaussian FGs.
*/
Sum asGaussianFactorGraphTree() const;

public:
/// @name Constructors
/// @{

/// Defaut constructor, mainly for serialization.
GaussianMixture() = default;

/**
* @brief Construct a new GaussianMixture object.
*
* @param continuousFrontals the continuous frontals.
* @param continuousParents the continuous parents.
* @param discreteParents the discrete parents. Will be placed last.
* @param conditionals a decision tree of GaussianConditionals. The number of
* conditionals should be C^(number of discrete parents), where C is the
* cardinality of the DiscreteKeys in discreteParents, since the
* discreteParents will be used as the labels in the decision tree.
*/
GaussianMixture(const KeyVector &continuousFrontals,
const KeyVector &continuousParents,
const DiscreteKeys &discreteParents,
const Conditionals &conditionals);

/**
* @brief Make a Gaussian Mixture from a list of Gaussian conditionals
*
* @param continuousFrontals The continuous frontal variables
* @param continuousParents The continuous parent variables
* @param discreteParents Discrete parents variables
* @param conditionals List of conditionals
*/
static This FromConditionals(
const KeyVector &continuousFrontals, const KeyVector &continuousParents,
const DiscreteKeys &discreteParents,
const std::vector<GaussianConditional::shared_ptr> &conditionals);

/// @}
/// @name Testable
/// @{

/// Test equality with base HybridFactor
bool equals(const HybridFactor &lf, double tol = 1e-9) const override;

/* print utility */
void print(
const std::string &s = "GaussianMixture\n",
const KeyFormatter &formatter = DefaultKeyFormatter) const override;

/// @}

/// Getter for the underlying Conditionals DecisionTree
const Conditionals &conditionals();

/**
* @brief Merge the Gaussian Factor Graphs in `this` and `sum` while
* maintaining the decision tree structure.
*
* @param sum Decision Tree of Gaussian Factor Graphs
* @return Sum
*/
Sum add(const Sum &sum) const;
};

// traits
template <>
struct traits<GaussianMixture> : public Testable<GaussianMixture> {};

} // namespace gtsam
94 changes: 94 additions & 0 deletions gtsam/hybrid/GaussianMixtureFactor.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */

/**
* @file GaussianMixtureFactor.cpp
* @brief A set of Gaussian factors indexed by a set of discrete keys.
* @author Fan Jiang
* @author Varun Agrawal
* @author Frank Dellaert
* @date Mar 12, 2022
*/

#include <gtsam/base/utilities.h>
#include <gtsam/discrete/DecisionTree-inl.h>
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/linear/GaussianFactorGraph.h>

namespace gtsam {

/* *******************************************************************************/
GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys,
const Factors &factors)
: Base(continuousKeys, discreteKeys), factors_(factors) {}

/* *******************************************************************************/
bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const {
const This *e = dynamic_cast<const This *>(&lf);
return e != nullptr && Base::equals(*e, tol);
}

/* *******************************************************************************/
GaussianMixtureFactor GaussianMixtureFactor::FromFactors(
const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys,
const std::vector<GaussianFactor::shared_ptr> &factors) {
Factors dt(discreteKeys, factors);

return GaussianMixtureFactor(continuousKeys, discreteKeys, dt);
}

/* *******************************************************************************/
void GaussianMixtureFactor::print(const std::string &s,
const KeyFormatter &formatter) const {
HybridFactor::print(s, formatter);
factors_.print(
"mixture = ", [&](Key k) { return formatter(k); },
[&](const GaussianFactor::shared_ptr &gf) -> std::string {
RedirectCout rd;
if (!gf->empty())
gf->print("", formatter);
else
return {"nullptr"};
return rd.str();
});
}

/* *******************************************************************************/
const GaussianMixtureFactor::Factors &GaussianMixtureFactor::factors() {
return factors_;
}

/* *******************************************************************************/
GaussianMixtureFactor::Sum GaussianMixtureFactor::add(
const GaussianMixtureFactor::Sum &sum) const {
using Y = GaussianFactorGraph;
auto add = [](const Y &graph1, const Y &graph2) {
auto result = graph1;
result.push_back(graph2);
return result;
};
const Sum tree = asGaussianFactorGraphTree();
return sum.empty() ? tree : sum.apply(tree, add);
}

/* *******************************************************************************/
GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree()
const {
auto wrap = [](const GaussianFactor::shared_ptr &factor) {
GaussianFactorGraph result;
result.push_back(factor);
return result;
};
return {factors_, wrap};
}
} // namespace gtsam
Loading

0 comments on commit 9a1eb02

Please sign in to comment.