-
Notifications
You must be signed in to change notification settings - Fork 765
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Linear HybridBayesNet optimization #1270
Changes from all commits
5806850
36d6097
9564e32
7d36a9e
379a65f
746ca78
c4184e1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
/* ---------------------------------------------------------------------------- | ||
|
||
* GTSAM Copyright 2010, Georgia Tech Research Corporation, | ||
* Atlanta, Georgia 30332-0415 | ||
* All Rights Reserved | ||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list) | ||
|
||
* See LICENSE for the license information | ||
|
||
* -------------------------------------------------------------------------- */ | ||
|
||
/** | ||
* @file DiscreteLookupDAG.cpp | ||
* @date Aug, 2022 | ||
* @author Shangjie Xue | ||
*/ | ||
|
||
#include <gtsam/discrete/DiscreteBayesNet.h> | ||
#include <gtsam/discrete/DiscreteLookupDAG.h> | ||
#include <gtsam/discrete/DiscreteValues.h> | ||
#include <gtsam/hybrid/HybridBayesNet.h> | ||
#include <gtsam/hybrid/HybridConditional.h> | ||
#include <gtsam/hybrid/HybridLookupDAG.h> | ||
#include <gtsam/hybrid/HybridValues.h> | ||
#include <gtsam/linear/VectorValues.h> | ||
|
||
#include <string> | ||
#include <utility> | ||
|
||
using std::pair; | ||
using std::vector; | ||
|
||
namespace gtsam { | ||
|
||
/* ************************************************************************** */ | ||
void HybridLookupTable::argmaxInPlace(HybridValues* values) const { | ||
// For discrete conditional, uses argmaxInPlace() method in | ||
// DiscreteLookupTable. | ||
if (isDiscrete()) { | ||
boost::static_pointer_cast<DiscreteLookupTable>(inner_)->argmaxInPlace( | ||
&(values->discrete)); | ||
} else if (isContinuous()) { | ||
// For Gaussian conditional, uses solve() method in GaussianConditional. | ||
values->continuous.insert( | ||
boost::static_pointer_cast<GaussianConditional>(inner_)->solve( | ||
values->continuous)); | ||
} else if (isHybrid()) { | ||
// For hybrid conditional, since children should not contain discrete | ||
// variable, we can condition on the discrete variable in the parents and | ||
// solve the resulting GaussianConditional. | ||
auto conditional = | ||
boost::static_pointer_cast<GaussianMixture>(inner_)->conditionals()( | ||
values->discrete); | ||
values->continuous.insert(conditional->solve(values->continuous)); | ||
} | ||
} | ||
|
||
/* ************************************************************************** */ | ||
HybridLookupDAG HybridLookupDAG::FromBayesNet(const HybridBayesNet& bayesNet) { | ||
HybridLookupDAG dag; | ||
for (auto&& conditional : bayesNet) { | ||
HybridLookupTable hlt(*conditional); | ||
dag.push_back(hlt); | ||
} | ||
return dag; | ||
} | ||
|
||
/* ************************************************************************** */ | ||
HybridValues HybridLookupDAG::argmax(HybridValues result) const { | ||
// Argmax each node in turn in topological sort order (parents first). | ||
for (auto lookupTable : boost::adaptors::reverse(*this)) | ||
lookupTable->argmaxInPlace(&result); | ||
return result; | ||
} | ||
|
||
} // namespace gtsam |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
/* ---------------------------------------------------------------------------- | ||
|
||
* GTSAM Copyright 2010, Georgia Tech Research Corporation, | ||
* Atlanta, Georgia 30332-0415 | ||
* All Rights Reserved | ||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list) | ||
|
||
* See LICENSE for the license information | ||
|
||
* -------------------------------------------------------------------------- */ | ||
|
||
/** | ||
* @file HybridLookupDAG.h | ||
* @date Aug, 2022 | ||
* @author Shangjie Xue | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <gtsam/discrete/DiscreteDistribution.h> | ||
#include <gtsam/discrete/DiscreteLookupDAG.h> | ||
#include <gtsam/hybrid/HybridConditional.h> | ||
#include <gtsam/hybrid/HybridValues.h> | ||
#include <gtsam/inference/BayesNet.h> | ||
#include <gtsam/inference/FactorGraph.h> | ||
|
||
#include <boost/shared_ptr.hpp> | ||
#include <string> | ||
#include <utility> | ||
#include <vector> | ||
|
||
namespace gtsam { | ||
|
||
/** | ||
* @brief HybridLookupTable table for max-product | ||
* | ||
* Similar to DiscreteLookupTable, inherits from hybrid conditional for | ||
* convenience. Is used in the max-product algorithm. | ||
*/ | ||
class GTSAM_EXPORT HybridLookupTable : public HybridConditional { | ||
public: | ||
using Base = HybridConditional; | ||
using This = HybridLookupTable; | ||
using shared_ptr = boost::shared_ptr<This>; | ||
using BaseConditional = Conditional<DecisionTreeFactor, This>; | ||
|
||
/** | ||
* @brief Construct a new Hybrid Lookup Table object form a HybridConditional. | ||
* | ||
* @param conditional input hybrid conditional | ||
*/ | ||
HybridLookupTable(HybridConditional& conditional) : Base(conditional){}; | ||
|
||
/** | ||
* @brief Calculate assignment for frontal variables that maximizes value. | ||
* @param (in/out) parentsValues Known assignments for the parents. | ||
*/ | ||
void argmaxInPlace(HybridValues* parentsValues) const; | ||
}; | ||
|
||
/** A DAG made from hybrid lookup tables, as defined above. Similar to | ||
* DiscreteLookupDAG */ | ||
class GTSAM_EXPORT HybridLookupDAG : public BayesNet<HybridLookupTable> { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The docs for this class are severely lacking. @xsj01 you have to be better than just copy-pasting the same docstrings, since now the motivation for this class is lost. |
||
public: | ||
using Base = BayesNet<HybridLookupTable>; | ||
using This = HybridLookupDAG; | ||
using shared_ptr = boost::shared_ptr<This>; | ||
|
||
/// @name Standard Constructors | ||
/// @{ | ||
|
||
/// Construct empty DAG. | ||
HybridLookupDAG() {} | ||
|
||
/// Create from BayesNet with LookupTables | ||
static HybridLookupDAG FromBayesNet(const HybridBayesNet& bayesNet); | ||
|
||
/// Destructor | ||
virtual ~HybridLookupDAG() {} | ||
|
||
/// @} | ||
|
||
/// @name Standard Interface | ||
/// @{ | ||
|
||
/** Add a DiscreteLookupTable */ | ||
template <typename... Args> | ||
void add(Args&&... args) { | ||
emplace_shared<HybridLookupTable>(std::forward<Args>(args)...); | ||
} | ||
|
||
/** | ||
* @brief argmax by back-substitution, optionally given certain variables. | ||
* | ||
* Assumes the DAG is reverse topologically sorted, i.e. last | ||
* conditional will be optimized first *and* that the | ||
* DAG does not contain any conditionals for the given variables. If the DAG | ||
* resulted from eliminating a factor graph, this is true for the elimination | ||
* ordering. | ||
* | ||
* @return given assignment extended w. optimal assignment for all variables. | ||
*/ | ||
HybridValues argmax(HybridValues given = HybridValues()) const; | ||
/// @} | ||
|
||
private: | ||
/** Serialization function */ | ||
friend class boost::serialization::access; | ||
template <class ARCHIVE> | ||
void serialize(ARCHIVE& ar, const unsigned int /*version*/) { | ||
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); | ||
} | ||
}; | ||
|
||
// traits | ||
template <> | ||
struct traits<HybridLookupDAG> : public Testable<HybridLookupDAG> {}; | ||
|
||
} // namespace gtsam |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
/* ---------------------------------------------------------------------------- | ||
|
||
* GTSAM Copyright 2010, Georgia Tech Research Corporation, | ||
* Atlanta, Georgia 30332-0415 | ||
* All Rights Reserved | ||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list) | ||
|
||
* See LICENSE for the license information | ||
|
||
* -------------------------------------------------------------------------- */ | ||
|
||
/** | ||
* @file HybridValues.h | ||
* @date Jul 28, 2022 | ||
* @author Shangjie Xue | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <gtsam/discrete/Assignment.h> | ||
#include <gtsam/discrete/DiscreteKey.h> | ||
#include <gtsam/discrete/DiscreteValues.h> | ||
#include <gtsam/inference/Key.h> | ||
#include <gtsam/linear/VectorValues.h> | ||
#include <gtsam/nonlinear/Values.h> | ||
|
||
#include <map> | ||
#include <string> | ||
#include <vector> | ||
|
||
namespace gtsam { | ||
|
||
/** | ||
* HybridValues represents a collection of DiscreteValues and VectorValues. It | ||
* is typically used to store the variables of a HybridGaussianFactorGraph. | ||
* Optimizing a HybridGaussianBayesNet returns this class. | ||
*/ | ||
class GTSAM_EXPORT HybridValues { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. for thoughts: I always wondered can we just get rid of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't understand the use of this class. Can't we simply have There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, I think that's also an option. But I feel like the api is not very user friendly while accessing/assigning values. We can talk about it in our meeting. |
||
public: | ||
// DiscreteValue stored the discrete components of the HybridValues. | ||
DiscreteValues discrete; | ||
|
||
// VectorValue stored the continuous components of the HybridValues. | ||
VectorValues continuous; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we name the class There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe I can change VectorValues to Values? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this is more like a design question: Is this limited to linear hybrid systems? If so we should name it Gaussian. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we will extend that for nonlinear hybrid later, but currently it's only for linear hybrid system. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the spirit for |
||
|
||
// Default constructor creates an empty HybridValues. | ||
HybridValues() : discrete(), continuous(){}; | ||
|
||
// Construct from DiscreteValues and VectorValues. | ||
HybridValues(const DiscreteValues& dv, const VectorValues& cv) | ||
: discrete(dv), continuous(cv){}; | ||
|
||
// print required by Testable for unit testing | ||
void print(const std::string& s = "HybridValues", | ||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const { | ||
std::cout << s << ": \n"; | ||
discrete.print(" Discrete", keyFormatter); // print discrete components | ||
continuous.print(" Continuous", | ||
keyFormatter); // print continuous components | ||
}; | ||
|
||
// equals required by Testable for unit testing | ||
bool equals(const HybridValues& other, double tol = 1e-9) const { | ||
return discrete.equals(other.discrete, tol) && | ||
continuous.equals(other.continuous, tol); | ||
} | ||
|
||
// Check whether a variable with key \c j exists in DiscreteValue. | ||
bool existsDiscrete(Key j) { return (discrete.find(j) != discrete.end()); }; | ||
|
||
// Check whether a variable with key \c j exists in VectorValue. | ||
bool existsVector(Key j) { return continuous.exists(j); }; | ||
|
||
// Check whether a variable with key \c j exists. | ||
bool exists(Key j) { return existsDiscrete(j) || existsVector(j); }; | ||
|
||
/** Insert a discrete \c value with key \c j. Replaces the existing value if | ||
* the key \c j is already used. | ||
* @param value The vector to be inserted. | ||
* @param j The index with which the value will be associated. */ | ||
void insert(Key j, int value) { discrete[j] = value; }; | ||
|
||
/** Insert a vector \c value with key \c j. Throws an invalid_argument | ||
* exception if the key \c j is already used. | ||
* @param value The vector to be inserted. | ||
* @param j The index with which the value will be associated. */ | ||
void insert(Key j, const Vector& value) { continuous.insert(j, value); } | ||
|
||
// TODO(Shangjie)- update() and insert_or_assign() , similar to Values.h | ||
|
||
/** | ||
* Read/write access to the discrete value with key \c j, throws | ||
* std::out_of_range if \c j does not exist. | ||
*/ | ||
size_t& atDiscrete(Key j) { return discrete.at(j); }; | ||
|
||
/** | ||
* Read/write access to the vector value with key \c j, throws | ||
* std::out_of_range if \c j does not exist. | ||
*/ | ||
Vector& at(Key j) { return continuous.at(j); }; | ||
|
||
/// @name Wrapper support | ||
/// @{ | ||
|
||
/** | ||
* @brief Output as a html table. | ||
* | ||
* @param keyFormatter function that formats keys. | ||
* @return string html output. | ||
*/ | ||
std::string html( | ||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const { | ||
std::stringstream ss; | ||
ss << this->discrete.html(keyFormatter); | ||
ss << this->continuous.html(keyFormatter); | ||
return ss.str(); | ||
}; | ||
|
||
/// @} | ||
}; | ||
|
||
// traits | ||
template <> | ||
struct traits<HybridValues> : public Testable<HybridValues> {}; | ||
|
||
} // namespace gtsam |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@xsj01 What would be the frontal variables here? They should all be continuous variables since we eliminate those first, so then this doesn't make sense.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, it seems the HybridLookupTable is not necessary. Just had some discussions with Fan, I will submit a new PR to fix this.