Skip to content

Commit

Permalink
add HybridGaussianFactorGraph::probPrime method
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal committed Nov 8, 2022
1 parent cb55af3 commit eb94ad9
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 1 deletion.
8 changes: 8 additions & 0 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,14 @@ double HybridGaussianFactorGraph::error(
return error;
}

/* ************************************************************************ */
double HybridGaussianFactorGraph::probPrime(
const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const {
double error = this->error(continuousValues, discreteValues);
return std::exp(-error);
}

/* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::probPrime(
const VectorValues &continuousValues) const {
Expand Down
12 changes: 12 additions & 0 deletions gtsam/hybrid/HybridGaussianFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,18 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
AlgebraicDecisionTree<Key> probPrime(
const VectorValues& continuousValues) const;

/**
* @brief Compute the unnormalized posterior probability for a continuous
* vector values given a specific assignment.
*
* @param continuousValues The vector values for which to compute the
* posterior probability.
* @param discreteValues The specific assignment to use for the computation.
* @return double
*/
double probPrime(const VectorValues& continuousValues,
const DiscreteValues& discreteValues) const;

/**
* @brief Compute the VectorValues solution for the continuous variables for
* each mode.
Expand Down
2 changes: 1 addition & 1 deletion gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,7 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrime) {
EXPECT(assert_equal(expected_error, error, 1e-9));

double probs = exp(-error);
double expected_probs = exp(-expected_error);
double expected_probs = graph.probPrime(delta.continuous(), delta.discrete());

// regression
EXPECT(assert_equal(expected_probs, probs, 1e-7));
Expand Down

0 comments on commit eb94ad9

Please sign in to comment.