Skip to content

Commit

Permalink
Add wrapping for hybrid nonlinear
Browse files Browse the repository at this point in the history
  • Loading branch information
ProfFan committed Aug 26, 2022
1 parent ff20a50 commit 1bb7a00
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 1 deletion.
6 changes: 6 additions & 0 deletions gtsam/hybrid/HybridNonlinearFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ void HybridNonlinearFactorGraph::add(
FactorGraph::add(boost::make_shared<HybridNonlinearFactor>(factor));
}

/* ************************************************************************* */
void HybridNonlinearFactorGraph::add(
boost::shared_ptr<DiscreteFactor> factor) {
FactorGraph::add(boost::make_shared<HybridDiscreteFactor>(factor));
}

/* ************************************************************************* */
void HybridNonlinearFactorGraph::print(const std::string& s,
const KeyFormatter& keyFormatter) const {
Expand Down
3 changes: 3 additions & 0 deletions gtsam/hybrid/HybridNonlinearFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph {
/// Add a nonlinear factor as a shared ptr.
void add(boost::shared_ptr<NonlinearFactor> factor);

/// Add a discrete factor as a shared ptr.
void add(boost::shared_ptr<DiscreteFactor> factor);

/// Print the factor graph.
void print(
const std::string& s = "HybridNonlinearFactorGraph",
Expand Down
59 changes: 58 additions & 1 deletion gtsam/hybrid/hybrid.i
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,17 @@ virtual class HybridConditional {
bool equals(const gtsam::HybridConditional& other, double tol = 1e-9) const;
size_t nrFrontals() const;
size_t nrParents() const;
Factor* inner();
gtsam::Factor* inner();
};

#include <gtsam/hybrid/HybridDiscreteFactor.h>
virtual class HybridDiscreteFactor {
HybridDiscreteFactor(gtsam::DecisionTreeFactor dtf);
void print(string s = "HybridDiscreteFactor\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::HybridDiscreteFactor& other, double tol = 1e-9) const;
gtsam::Factor* inner();
};

#include <gtsam/hybrid/GaussianMixtureFactor.h>
Expand Down Expand Up @@ -132,6 +142,7 @@ class HybridGaussianFactorGraph {
void add(gtsam::JacobianFactor* factor);

bool empty() const;
void remove(size_t i);
size_t size() const;
gtsam::KeySet keys() const;
const gtsam::HybridFactor* at(size_t i) const;
Expand Down Expand Up @@ -159,4 +170,50 @@ class HybridGaussianFactorGraph {
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
};

#include <gtsam/hybrid/HybridNonlinearFactorGraph.h>
class HybridNonlinearFactorGraph {
HybridNonlinearFactorGraph();
HybridNonlinearFactorGraph(const gtsam::HybridNonlinearFactorGraph& graph);
void push_back(gtsam::HybridFactor* factor);
void push_back(gtsam::NonlinearFactor* factor);
void push_back(gtsam::HybridDiscreteFactor* factor);
void add(gtsam::NonlinearFactor* factor);
void add(gtsam::DiscreteFactor* factor);
gtsam::HybridGaussianFactorGraph linearize(const gtsam::Values& continuousValues) const;

bool empty() const;
void remove(size_t i);
size_t size() const;
gtsam::KeySet keys() const;
const gtsam::HybridFactor* at(size_t i) const;

void print(string s = "HybridNonlinearFactorGraph\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
};

#include <gtsam/hybrid/MixtureFactor.h>
class MixtureFactor : gtsam::HybridFactor {
MixtureFactor(const gtsam::KeyVector& keys, const gtsam::DiscreteKeys& discreteKeys,
const gtsam::DecisionTree<gtsam::Key, gtsam::NonlinearFactor*>& factors, bool normalized = false);

template <FACTOR = {gtsam::NonlinearFactor}>
MixtureFactor(const gtsam::KeyVector& keys, const gtsam::DiscreteKeys& discreteKeys,
const std::vector<FACTOR*>& factors,
bool normalized = false);

double error(const gtsam::Values& continuousVals,
const gtsam::DiscreteValues& discreteVals) const;

double nonlinearFactorLogNormalizingConstant(const gtsam::NonlinearFactor* factor,
const gtsam::Values& values) const;

GaussianMixtureFactor* linearize(
const gtsam::Values& continuousVals) const;

void print(string s = "MixtureFactor\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
};

} // namespace gtsam
7 changes: 7 additions & 0 deletions python/gtsam/preamble/hybrid.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,12 @@
* Without this they will be automatically converted to a Python object, and all
* mutations on Python side will not be reflected on C++.
*/
#include <pybind11/stl.h>

#ifdef GTSAM_ALLOCATOR_TBB
PYBIND11_MAKE_OPAQUE(std::vector<gtsam::Key, tbb::tbb_allocator<gtsam::Key>>);
#else
PYBIND11_MAKE_OPAQUE(std::vector<gtsam::Key>);
#endif

PYBIND11_MAKE_OPAQUE(std::vector<gtsam::GaussianFactor::shared_ptr>);
55 changes: 55 additions & 0 deletions python/gtsam/tests/test_HybridNonlinearFactorGraph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""
GTSAM Copyright 2010-2019, Georgia Tech Research Corporation,
Atlanta, Georgia 30332-0415
All Rights Reserved
See LICENSE for the license information
Unit tests for Hybrid Nonlinear Factor Graphs.
Author: Fan Jiang
"""
# pylint: disable=invalid-name, no-name-in-module, no-member

from __future__ import print_function

import unittest

import gtsam
import numpy as np
from gtsam.symbol_shorthand import C, X
from gtsam.utils.test_case import GtsamTestCase


class TestHybridGaussianFactorGraph(GtsamTestCase):
"""Unit tests for HybridGaussianFactorGraph."""

def test_nonlinear_hybrid(self):
nlfg = gtsam.HybridNonlinearFactorGraph()
dk = gtsam.DiscreteKeys()
dk.push_back((10, 2))
nlfg.add(gtsam.BetweenFactorPoint3(1, 2, gtsam.Point3(1, 2, 3), gtsam.noiseModel.Diagonal.Variances([1, 1, 1])))
nlfg.add(
gtsam.PriorFactorPoint3(2, gtsam.Point3(1, 2, 3), gtsam.noiseModel.Diagonal.Variances([0.5, 0.5, 0.5])))
nlfg.push_back(
gtsam.MixtureFactor([1], dk, [
gtsam.PriorFactorPoint3(1, gtsam.Point3(0, 0, 0),
gtsam.noiseModel.Unit.Create(3)),
gtsam.PriorFactorPoint3(1, gtsam.Point3(1, 2, 1),
gtsam.noiseModel.Unit.Create(3))
]))
nlfg.add(gtsam.DecisionTreeFactor((10, 2), "1 3"))
values = gtsam.Values()
values.insert_point3(1, gtsam.Point3(0, 0, 0))
values.insert_point3(2, gtsam.Point3(2, 3, 1))
hfg = nlfg.linearize(values)
o = gtsam.Ordering()
o.push_back(1)
o.push_back(2)
o.push_back(10)
hbn = hfg.eliminateSequential(o)
hbv = hbn.optimize()
self.assertEqual(hbv.atDiscrete(10), 0)


if __name__ == "__main__":
unittest.main()

0 comments on commit 1bb7a00

Please sign in to comment.