-
Notifications
You must be signed in to change notification settings - Fork 765
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1203 from varunagrawal/fan/prototype-hybrid-tr
- Loading branch information
Showing
43 changed files
with
3,374 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,6 +10,7 @@ set (gtsam_subdirs | |
inference | ||
symbolic | ||
discrete | ||
hybrid | ||
linear | ||
nonlinear | ||
sam | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# Install headers | ||
set(subdir hybrid) | ||
file(GLOB hybrid_headers "*.h") | ||
# FIXME: exclude headers | ||
install(FILES ${hybrid_headers} DESTINATION include/gtsam/hybrid) | ||
|
||
# Add all tests | ||
add_subdirectory(tests) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
/* ---------------------------------------------------------------------------- | ||
* GTSAM Copyright 2010, Georgia Tech Research Corporation, | ||
* Atlanta, Georgia 30332-0415 | ||
* All Rights Reserved | ||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list) | ||
* See LICENSE for the license information | ||
* -------------------------------------------------------------------------- */ | ||
|
||
/** | ||
* @file GaussianMixture.cpp | ||
* @brief A hybrid conditional in the Conditional Linear Gaussian scheme | ||
* @author Fan Jiang | ||
* @author Varun Agrawal | ||
* @author Frank Dellaert | ||
* @date Mar 12, 2022 | ||
*/ | ||
|
||
#include <gtsam/base/utilities.h> | ||
#include <gtsam/discrete/DecisionTree-inl.h> | ||
#include <gtsam/hybrid/GaussianMixture.h> | ||
#include <gtsam/inference/Conditional-inst.h> | ||
#include <gtsam/linear/GaussianFactorGraph.h> | ||
|
||
namespace gtsam { | ||
|
||
GaussianMixture::GaussianMixture( | ||
const KeyVector &continuousFrontals, const KeyVector &continuousParents, | ||
const DiscreteKeys &discreteParents, | ||
const GaussianMixture::Conditionals &conditionals) | ||
: BaseFactor(CollectKeys(continuousFrontals, continuousParents), | ||
discreteParents), | ||
BaseConditional(continuousFrontals.size()), | ||
conditionals_(conditionals) {} | ||
|
||
/* *******************************************************************************/ | ||
const GaussianMixture::Conditionals & | ||
GaussianMixture::conditionals() { | ||
return conditionals_; | ||
} | ||
|
||
/* *******************************************************************************/ | ||
GaussianMixture GaussianMixture::FromConditionals( | ||
const KeyVector &continuousFrontals, const KeyVector &continuousParents, | ||
const DiscreteKeys &discreteParents, | ||
const std::vector<GaussianConditional::shared_ptr> &conditionalsList) { | ||
Conditionals dt(discreteParents, conditionalsList); | ||
|
||
return GaussianMixture(continuousFrontals, continuousParents, | ||
discreteParents, dt); | ||
} | ||
|
||
/* *******************************************************************************/ | ||
GaussianMixture::Sum GaussianMixture::add( | ||
const GaussianMixture::Sum &sum) const { | ||
using Y = GaussianFactorGraph; | ||
auto add = [](const Y &graph1, const Y &graph2) { | ||
auto result = graph1; | ||
result.push_back(graph2); | ||
return result; | ||
}; | ||
const Sum tree = asGaussianFactorGraphTree(); | ||
return sum.empty() ? tree : sum.apply(tree, add); | ||
} | ||
|
||
/* *******************************************************************************/ | ||
GaussianMixture::Sum | ||
GaussianMixture::asGaussianFactorGraphTree() const { | ||
auto lambda = [](const GaussianFactor::shared_ptr &factor) { | ||
GaussianFactorGraph result; | ||
result.push_back(factor); | ||
return result; | ||
}; | ||
return {conditionals_, lambda}; | ||
} | ||
|
||
/* *******************************************************************************/ | ||
bool GaussianMixture::equals(const HybridFactor &lf, | ||
double tol) const { | ||
const This *e = dynamic_cast<const This *>(&lf); | ||
return e != nullptr && BaseFactor::equals(*e, tol); | ||
} | ||
|
||
/* *******************************************************************************/ | ||
void GaussianMixture::print(const std::string &s, | ||
const KeyFormatter &formatter) const { | ||
std::cout << s; | ||
if (isContinuous()) std::cout << "Continuous "; | ||
if (isDiscrete()) std::cout << "Discrete "; | ||
if (isHybrid()) std::cout << "Hybrid "; | ||
BaseConditional::print("", formatter); | ||
std::cout << "\nDiscrete Keys = "; | ||
for (auto &dk : discreteKeys()) { | ||
std::cout << "(" << formatter(dk.first) << ", " << dk.second << "), "; | ||
} | ||
std::cout << "\n"; | ||
conditionals_.print( | ||
"", [&](Key k) { return formatter(k); }, | ||
[&](const GaussianConditional::shared_ptr &gf) -> std::string { | ||
RedirectCout rd; | ||
if (!gf->empty()) | ||
gf->print("", formatter); | ||
else | ||
return {"nullptr"}; | ||
return rd.str(); | ||
}); | ||
} | ||
} // namespace gtsam |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
/* ---------------------------------------------------------------------------- | ||
* GTSAM Copyright 2010, Georgia Tech Research Corporation, | ||
* Atlanta, Georgia 30332-0415 | ||
* All Rights Reserved | ||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list) | ||
* See LICENSE for the license information | ||
* -------------------------------------------------------------------------- */ | ||
|
||
/** | ||
* @file GaussianMixture.h | ||
* @brief A hybrid conditional in the Conditional Linear Gaussian scheme | ||
* @author Fan Jiang | ||
* @author Varun Agrawal | ||
* @date Mar 12, 2022 | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <gtsam/discrete/DecisionTree.h> | ||
#include <gtsam/hybrid/HybridFactor.h> | ||
#include <gtsam/inference/Conditional.h> | ||
#include <gtsam/linear/GaussianConditional.h> | ||
|
||
namespace gtsam { | ||
|
||
/** | ||
* @brief A conditional of gaussian mixtures indexed by discrete variables, as | ||
* part of a Bayes Network. | ||
* | ||
* Represents the conditional density P(X | M, Z) where X is a continuous random | ||
* variable, M is the selection of discrete variables corresponding to a subset | ||
* of the Gaussian variables and Z is parent of this node | ||
* | ||
* The probability P(x|y,z,...) is proportional to | ||
* \f$ \sum_i k_i \exp - \frac{1}{2} |R_i x - (d_i - S_i y - T_i z - ...)|^2 \f$ | ||
* where i indexes the components and k_i is a component-wise normalization | ||
* constant. | ||
* | ||
*/ | ||
class GTSAM_EXPORT GaussianMixture | ||
: public HybridFactor, | ||
public Conditional<HybridFactor, GaussianMixture> { | ||
public: | ||
using This = GaussianMixture; | ||
using shared_ptr = boost::shared_ptr<GaussianMixture>; | ||
using BaseFactor = HybridFactor; | ||
using BaseConditional = Conditional<HybridFactor, GaussianMixture>; | ||
|
||
/// Alias for DecisionTree of GaussianFactorGraphs | ||
using Sum = DecisionTree<Key, GaussianFactorGraph>; | ||
|
||
/// typedef for Decision Tree of Gaussian Conditionals | ||
using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>; | ||
|
||
private: | ||
Conditionals conditionals_; | ||
|
||
/** | ||
* @brief Convert a DecisionTree of factors into a DT of Gaussian FGs. | ||
*/ | ||
Sum asGaussianFactorGraphTree() const; | ||
|
||
public: | ||
/// @name Constructors | ||
/// @{ | ||
|
||
/// Defaut constructor, mainly for serialization. | ||
GaussianMixture() = default; | ||
|
||
/** | ||
* @brief Construct a new GaussianMixture object. | ||
* | ||
* @param continuousFrontals the continuous frontals. | ||
* @param continuousParents the continuous parents. | ||
* @param discreteParents the discrete parents. Will be placed last. | ||
* @param conditionals a decision tree of GaussianConditionals. The number of | ||
* conditionals should be C^(number of discrete parents), where C is the | ||
* cardinality of the DiscreteKeys in discreteParents, since the | ||
* discreteParents will be used as the labels in the decision tree. | ||
*/ | ||
GaussianMixture(const KeyVector &continuousFrontals, | ||
const KeyVector &continuousParents, | ||
const DiscreteKeys &discreteParents, | ||
const Conditionals &conditionals); | ||
|
||
/** | ||
* @brief Make a Gaussian Mixture from a list of Gaussian conditionals | ||
* | ||
* @param continuousFrontals The continuous frontal variables | ||
* @param continuousParents The continuous parent variables | ||
* @param discreteParents Discrete parents variables | ||
* @param conditionals List of conditionals | ||
*/ | ||
static This FromConditionals( | ||
const KeyVector &continuousFrontals, const KeyVector &continuousParents, | ||
const DiscreteKeys &discreteParents, | ||
const std::vector<GaussianConditional::shared_ptr> &conditionals); | ||
|
||
/// @} | ||
/// @name Testable | ||
/// @{ | ||
|
||
/// Test equality with base HybridFactor | ||
bool equals(const HybridFactor &lf, double tol = 1e-9) const override; | ||
|
||
/* print utility */ | ||
void print( | ||
const std::string &s = "GaussianMixture\n", | ||
const KeyFormatter &formatter = DefaultKeyFormatter) const override; | ||
|
||
/// @} | ||
|
||
/// Getter for the underlying Conditionals DecisionTree | ||
const Conditionals &conditionals(); | ||
|
||
/** | ||
* @brief Merge the Gaussian Factor Graphs in `this` and `sum` while | ||
* maintaining the decision tree structure. | ||
* | ||
* @param sum Decision Tree of Gaussian Factor Graphs | ||
* @return Sum | ||
*/ | ||
Sum add(const Sum &sum) const; | ||
}; | ||
|
||
// traits | ||
template <> | ||
struct traits<GaussianMixture> : public Testable<GaussianMixture> {}; | ||
|
||
} // namespace gtsam |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
/* ---------------------------------------------------------------------------- | ||
* GTSAM Copyright 2010, Georgia Tech Research Corporation, | ||
* Atlanta, Georgia 30332-0415 | ||
* All Rights Reserved | ||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list) | ||
* See LICENSE for the license information | ||
* -------------------------------------------------------------------------- */ | ||
|
||
/** | ||
* @file GaussianMixtureFactor.cpp | ||
* @brief A set of Gaussian factors indexed by a set of discrete keys. | ||
* @author Fan Jiang | ||
* @author Varun Agrawal | ||
* @author Frank Dellaert | ||
* @date Mar 12, 2022 | ||
*/ | ||
|
||
#include <gtsam/base/utilities.h> | ||
#include <gtsam/discrete/DecisionTree-inl.h> | ||
#include <gtsam/discrete/DecisionTree.h> | ||
#include <gtsam/hybrid/GaussianMixtureFactor.h> | ||
#include <gtsam/linear/GaussianFactorGraph.h> | ||
|
||
namespace gtsam { | ||
|
||
/* *******************************************************************************/ | ||
GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys, | ||
const DiscreteKeys &discreteKeys, | ||
const Factors &factors) | ||
: Base(continuousKeys, discreteKeys), factors_(factors) {} | ||
|
||
/* *******************************************************************************/ | ||
bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const { | ||
const This *e = dynamic_cast<const This *>(&lf); | ||
return e != nullptr && Base::equals(*e, tol); | ||
} | ||
|
||
/* *******************************************************************************/ | ||
GaussianMixtureFactor GaussianMixtureFactor::FromFactors( | ||
const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys, | ||
const std::vector<GaussianFactor::shared_ptr> &factors) { | ||
Factors dt(discreteKeys, factors); | ||
|
||
return GaussianMixtureFactor(continuousKeys, discreteKeys, dt); | ||
} | ||
|
||
/* *******************************************************************************/ | ||
void GaussianMixtureFactor::print(const std::string &s, | ||
const KeyFormatter &formatter) const { | ||
HybridFactor::print(s, formatter); | ||
factors_.print( | ||
"mixture = ", [&](Key k) { return formatter(k); }, | ||
[&](const GaussianFactor::shared_ptr &gf) -> std::string { | ||
RedirectCout rd; | ||
if (!gf->empty()) | ||
gf->print("", formatter); | ||
else | ||
return {"nullptr"}; | ||
return rd.str(); | ||
}); | ||
} | ||
|
||
/* *******************************************************************************/ | ||
const GaussianMixtureFactor::Factors &GaussianMixtureFactor::factors() { | ||
return factors_; | ||
} | ||
|
||
/* *******************************************************************************/ | ||
GaussianMixtureFactor::Sum GaussianMixtureFactor::add( | ||
const GaussianMixtureFactor::Sum &sum) const { | ||
using Y = GaussianFactorGraph; | ||
auto add = [](const Y &graph1, const Y &graph2) { | ||
auto result = graph1; | ||
result.push_back(graph2); | ||
return result; | ||
}; | ||
const Sum tree = asGaussianFactorGraphTree(); | ||
return sum.empty() ? tree : sum.apply(tree, add); | ||
} | ||
|
||
/* *******************************************************************************/ | ||
GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree() | ||
const { | ||
auto wrap = [](const GaussianFactor::shared_ptr &factor) { | ||
GaussianFactorGraph result; | ||
result.push_back(factor); | ||
return result; | ||
}; | ||
return {factors_, wrap}; | ||
} | ||
} // namespace gtsam |
Oops, something went wrong.