diff --git a/gtsam/hybrid/HybridNonlinearFactorGraph.cpp b/gtsam/hybrid/HybridNonlinearFactorGraph.cpp index a6abce15aa..a4218593b6 100644 --- a/gtsam/hybrid/HybridNonlinearFactorGraph.cpp +++ b/gtsam/hybrid/HybridNonlinearFactorGraph.cpp @@ -26,6 +26,12 @@ void HybridNonlinearFactorGraph::add( FactorGraph::add(boost::make_shared(factor)); } +/* ************************************************************************* */ +void HybridNonlinearFactorGraph::add( + boost::shared_ptr factor) { + FactorGraph::add(boost::make_shared(factor)); +} + /* ************************************************************************* */ void HybridNonlinearFactorGraph::print(const std::string& s, const KeyFormatter& keyFormatter) const { diff --git a/gtsam/hybrid/HybridNonlinearFactorGraph.h b/gtsam/hybrid/HybridNonlinearFactorGraph.h index 2ddb8bcea2..7a19c77555 100644 --- a/gtsam/hybrid/HybridNonlinearFactorGraph.h +++ b/gtsam/hybrid/HybridNonlinearFactorGraph.h @@ -112,6 +112,9 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph { /// Add a nonlinear factor as a shared ptr. void add(boost::shared_ptr factor); + /// Add a discrete factor as a shared ptr. + void add(boost::shared_ptr factor); + /// Print the factor graph. void print( const std::string& s = "HybridNonlinearFactorGraph", diff --git a/gtsam/hybrid/hybrid.i b/gtsam/hybrid/hybrid.i index 207f3ff635..aa63259d9d 100644 --- a/gtsam/hybrid/hybrid.i +++ b/gtsam/hybrid/hybrid.i @@ -39,7 +39,17 @@ virtual class HybridConditional { bool equals(const gtsam::HybridConditional& other, double tol = 1e-9) const; size_t nrFrontals() const; size_t nrParents() const; - Factor* inner(); + gtsam::Factor* inner(); +}; + +#include +virtual class HybridDiscreteFactor { + HybridDiscreteFactor(gtsam::DecisionTreeFactor dtf); + void print(string s = "HybridDiscreteFactor\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::HybridDiscreteFactor& other, double tol = 1e-9) const; + gtsam::Factor* inner(); }; #include @@ -132,6 +142,7 @@ class HybridGaussianFactorGraph { void add(gtsam::JacobianFactor* factor); bool empty() const; + void remove(size_t i); size_t size() const; gtsam::KeySet keys() const; const gtsam::HybridFactor* at(size_t i) const; @@ -159,4 +170,50 @@ class HybridGaussianFactorGraph { const gtsam::DotWriter& writer = gtsam::DotWriter()) const; }; +#include +class HybridNonlinearFactorGraph { + HybridNonlinearFactorGraph(); + HybridNonlinearFactorGraph(const gtsam::HybridNonlinearFactorGraph& graph); + void push_back(gtsam::HybridFactor* factor); + void push_back(gtsam::NonlinearFactor* factor); + void push_back(gtsam::HybridDiscreteFactor* factor); + void add(gtsam::NonlinearFactor* factor); + void add(gtsam::DiscreteFactor* factor); + gtsam::HybridGaussianFactorGraph linearize(const gtsam::Values& continuousValues) const; + + bool empty() const; + void remove(size_t i); + size_t size() const; + gtsam::KeySet keys() const; + const gtsam::HybridFactor* at(size_t i) const; + + void print(string s = "HybridNonlinearFactorGraph\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; +}; + +#include +class MixtureFactor : gtsam::HybridFactor { + MixtureFactor(const gtsam::KeyVector& keys, const gtsam::DiscreteKeys& discreteKeys, + const gtsam::DecisionTree& factors, bool normalized = false); + + template + MixtureFactor(const gtsam::KeyVector& keys, const gtsam::DiscreteKeys& discreteKeys, + const std::vector& factors, + bool normalized = false); + + double error(const gtsam::Values& continuousVals, + const gtsam::DiscreteValues& discreteVals) const; + + double nonlinearFactorLogNormalizingConstant(const gtsam::NonlinearFactor* factor, + const gtsam::Values& values) const; + + GaussianMixtureFactor* linearize( + const gtsam::Values& continuousVals) const; + + void print(string s = "MixtureFactor\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; +}; + } // namespace gtsam diff --git a/python/gtsam/preamble/hybrid.h b/python/gtsam/preamble/hybrid.h index 5e5a71e48d..bae636d6a5 100644 --- a/python/gtsam/preamble/hybrid.h +++ b/python/gtsam/preamble/hybrid.h @@ -10,5 +10,12 @@ * Without this they will be automatically converted to a Python object, and all * mutations on Python side will not be reflected on C++. */ +#include + +#ifdef GTSAM_ALLOCATOR_TBB +PYBIND11_MAKE_OPAQUE(std::vector>); +#else +PYBIND11_MAKE_OPAQUE(std::vector); +#endif PYBIND11_MAKE_OPAQUE(std::vector); diff --git a/python/gtsam/tests/test_HybridNonlinearFactorGraph.py b/python/gtsam/tests/test_HybridNonlinearFactorGraph.py new file mode 100644 index 0000000000..3ac0d5c6f9 --- /dev/null +++ b/python/gtsam/tests/test_HybridNonlinearFactorGraph.py @@ -0,0 +1,55 @@ +""" +GTSAM Copyright 2010-2019, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Unit tests for Hybrid Nonlinear Factor Graphs. +Author: Fan Jiang +""" +# pylint: disable=invalid-name, no-name-in-module, no-member + +from __future__ import print_function + +import unittest + +import gtsam +import numpy as np +from gtsam.symbol_shorthand import C, X +from gtsam.utils.test_case import GtsamTestCase + + +class TestHybridGaussianFactorGraph(GtsamTestCase): + """Unit tests for HybridGaussianFactorGraph.""" + + def test_nonlinear_hybrid(self): + nlfg = gtsam.HybridNonlinearFactorGraph() + dk = gtsam.DiscreteKeys() + dk.push_back((10, 2)) + nlfg.add(gtsam.BetweenFactorPoint3(1, 2, gtsam.Point3(1, 2, 3), gtsam.noiseModel.Diagonal.Variances([1, 1, 1]))) + nlfg.add( + gtsam.PriorFactorPoint3(2, gtsam.Point3(1, 2, 3), gtsam.noiseModel.Diagonal.Variances([0.5, 0.5, 0.5]))) + nlfg.push_back( + gtsam.MixtureFactor([1], dk, [ + gtsam.PriorFactorPoint3(1, gtsam.Point3(0, 0, 0), + gtsam.noiseModel.Unit.Create(3)), + gtsam.PriorFactorPoint3(1, gtsam.Point3(1, 2, 1), + gtsam.noiseModel.Unit.Create(3)) + ])) + nlfg.add(gtsam.DecisionTreeFactor((10, 2), "1 3")) + values = gtsam.Values() + values.insert_point3(1, gtsam.Point3(0, 0, 0)) + values.insert_point3(2, gtsam.Point3(2, 3, 1)) + hfg = nlfg.linearize(values) + o = gtsam.Ordering() + o.push_back(1) + o.push_back(2) + o.push_back(10) + hbn = hfg.eliminateSequential(o) + hbv = hbn.optimize() + self.assertEqual(hbv.atDiscrete(10), 0) + + +if __name__ == "__main__": + unittest.main()