Skip to content

Commit

Permalink
Merge pull request #1352 from borglab/feature/HBN-evaluate
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal authored Dec 29, 2022
2 parents af6a4f2 + d537867 commit a849eab
Show file tree
Hide file tree
Showing 15 changed files with 349 additions and 100 deletions.
77 changes: 45 additions & 32 deletions gtsam/hybrid/HybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
for (auto &&conditional : *this) {
if (conditional->isDiscrete()) {
// Convert to a DecisionTreeFactor and add it to the main factor.
DecisionTreeFactor f(*conditional->asDiscreteConditional());
DecisionTreeFactor f(*conditional->asDiscrete());
dtFactor = dtFactor * f;
}
}
Expand Down Expand Up @@ -108,7 +108,7 @@ void HybridBayesNet::updateDiscreteConditionals(
HybridConditional::shared_ptr conditional = this->at(i);
if (conditional->isDiscrete()) {
// std::cout << demangle(typeid(conditional).name()) << std::endl;
auto discrete = conditional->asDiscreteConditional();
auto discrete = conditional->asDiscrete();
KeyVector frontals(discrete->frontals().begin(),
discrete->frontals().end());

Expand Down Expand Up @@ -150,16 +150,11 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {

// Go through all the conditionals in the
// Bayes Net and prune them as per decisionTree.
for (size_t i = 0; i < this->size(); i++) {
HybridConditional::shared_ptr conditional = this->at(i);

if (conditional->isHybrid()) {
GaussianMixture::shared_ptr gaussianMixture = conditional->asMixture();

for (auto &&conditional : *this) {
if (auto gm = conditional->asMixture()) {
// Make a copy of the Gaussian mixture and prune it!
auto prunedGaussianMixture =
boost::make_shared<GaussianMixture>(*gaussianMixture);
prunedGaussianMixture->prune(*decisionTree);
auto prunedGaussianMixture = boost::make_shared<GaussianMixture>(*gm);
prunedGaussianMixture->prune(*decisionTree); // imperative :-(

// Type-erase and add to the pruned Bayes Net fragment.
prunedBayesNetFragment.push_back(
Expand All @@ -186,24 +181,21 @@ GaussianConditional::shared_ptr HybridBayesNet::atGaussian(size_t i) const {

/* ************************************************************************* */
DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const {
return at(i)->asDiscreteConditional();
return at(i)->asDiscrete();
}

/* ************************************************************************* */
GaussianBayesNet HybridBayesNet::choose(
const DiscreteValues &assignment) const {
GaussianBayesNet gbn;
for (auto &&conditional : *this) {
if (conditional->isHybrid()) {
if (auto gm = conditional->asMixture()) {
// If conditional is hybrid, select based on assignment.
GaussianMixture gm = *conditional->asMixture();
gbn.push_back(gm(assignment));

} else if (conditional->isContinuous()) {
gbn.push_back((*gm)(assignment));
} else if (auto gc = conditional->asGaussian()) {
// If continuous only, add Gaussian conditional.
gbn.push_back((conditional->asGaussian()));

} else if (conditional->isDiscrete()) {
gbn.push_back(gc);
} else if (auto dc = conditional->asDiscrete()) {
// If conditional is discrete-only, we simply continue.
continue;
}
Expand All @@ -218,31 +210,55 @@ HybridValues HybridBayesNet::optimize() const {
DiscreteBayesNet discrete_bn;
for (auto &&conditional : *this) {
if (conditional->isDiscrete()) {
discrete_bn.push_back(conditional->asDiscreteConditional());
discrete_bn.push_back(conditional->asDiscrete());
}
}

DiscreteValues mpe = DiscreteFactorGraph(discrete_bn).optimize();

// Given the MPE, compute the optimal continuous values.
GaussianBayesNet gbn = this->choose(mpe);
GaussianBayesNet gbn = choose(mpe);
return HybridValues(mpe, gbn.optimize());
}

/* ************************************************************************* */
VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const {
GaussianBayesNet gbn = this->choose(assignment);
GaussianBayesNet gbn = choose(assignment);
return gbn.optimize();
}

/* ************************************************************************* */
double HybridBayesNet::evaluate(const HybridValues &values) const {
const DiscreteValues &discreteValues = values.discrete();
const VectorValues &continuousValues = values.continuous();

double logDensity = 0.0, probability = 1.0;

// Iterate over each conditional.
for (auto &&conditional : *this) {
if (auto gm = conditional->asMixture()) {
const auto component = (*gm)(discreteValues);
logDensity += component->logDensity(continuousValues);
} else if (auto gc = conditional->asGaussian()) {
// If continuous only, evaluate the probability and multiply.
logDensity += gc->logDensity(continuousValues);
} else if (auto dc = conditional->asDiscrete()) {
// Conditional is discrete-only, so return its probability.
probability *= dc->operator()(discreteValues);
}
}

return probability * exp(logDensity);
}

/* ************************************************************************* */
HybridValues HybridBayesNet::sample(const HybridValues &given,
std::mt19937_64 *rng) const {
DiscreteBayesNet dbn;
for (auto &&conditional : *this) {
if (conditional->isDiscrete()) {
// If conditional is discrete-only, we add to the discrete Bayes net.
dbn.push_back(conditional->asDiscreteConditional());
dbn.push_back(conditional->asDiscrete());
}
}
// Sample a discrete assignment.
Expand Down Expand Up @@ -273,7 +289,7 @@ HybridValues HybridBayesNet::sample() const {
/* ************************************************************************* */
double HybridBayesNet::error(const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const {
GaussianBayesNet gbn = this->choose(discreteValues);
GaussianBayesNet gbn = choose(discreteValues);
return gbn.error(continuousValues);
}

Expand All @@ -284,23 +300,20 @@ AlgebraicDecisionTree<Key> HybridBayesNet::error(

// Iterate over each conditional.
for (auto &&conditional : *this) {
if (conditional->isHybrid()) {
if (auto gm = conditional->asMixture()) {
// If conditional is hybrid, select based on assignment and compute error.
GaussianMixture::shared_ptr gm = conditional->asMixture();
AlgebraicDecisionTree<Key> conditional_error =
gm->error(continuousValues);

error_tree = error_tree + conditional_error;

} else if (conditional->isContinuous()) {
} else if (auto gc = conditional->asGaussian()) {
// If continuous only, get the (double) error
// and add it to the error_tree
double error = conditional->asGaussian()->error(continuousValues);
double error = gc->error(continuousValues);
// Add the computed error to every leaf of the error tree.
error_tree = error_tree.apply(
[error](double leaf_value) { return leaf_value + error; });

} else if (conditional->isDiscrete()) {
} else if (auto dc = conditional->asDiscrete()) {
// Conditional is discrete-only, we skip.
continue;
}
Expand Down
8 changes: 8 additions & 0 deletions gtsam/hybrid/HybridBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,14 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
*/
GaussianBayesNet choose(const DiscreteValues &assignment) const;

/// Evaluate hybrid probability density for given HybridValues.
double evaluate(const HybridValues &values) const;

/// Evaluate hybrid probability density for given HybridValues, sugar.
double operator()(const HybridValues &values) const {
return evaluate(values);
}

/**
* @brief Solve the HybridBayesNet by first computing the MPE of all the
* discrete variables and then optimizing the continuous variables based on
Expand Down
4 changes: 2 additions & 2 deletions gtsam/hybrid/HybridBayesTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ HybridValues HybridBayesTree::optimize() const {

// The root should be discrete only, we compute the MPE
if (root_conditional->isDiscrete()) {
dbn.push_back(root_conditional->asDiscreteConditional());
dbn.push_back(root_conditional->asDiscrete());
mpe = DiscreteFactorGraph(dbn).optimize();
} else {
throw std::runtime_error(
Expand Down Expand Up @@ -147,7 +147,7 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
/* ************************************************************************* */
void HybridBayesTree::prune(const size_t maxNrLeaves) {
auto decisionTree =
this->roots_.at(0)->conditional()->asDiscreteConditional();
this->roots_.at(0)->conditional()->asDiscrete();

DecisionTreeFactor prunedDecisionTree = decisionTree->prune(maxNrLeaves);
decisionTree->root_ = prunedDecisionTree.root_;
Expand Down
23 changes: 9 additions & 14 deletions gtsam/hybrid/HybridConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,34 +131,29 @@ class GTSAM_EXPORT HybridConditional

/**
* @brief Return HybridConditional as a GaussianMixture
*
* @return GaussianMixture::shared_ptr
* @return nullptr if not a mixture
* @return GaussianMixture::shared_ptr otherwise
*/
GaussianMixture::shared_ptr asMixture() {
if (!isHybrid()) throw std::invalid_argument("Not a mixture");
return boost::static_pointer_cast<GaussianMixture>(inner_);
return boost::dynamic_pointer_cast<GaussianMixture>(inner_);
}

/**
* @brief Return HybridConditional as a GaussianConditional
*
* @return GaussianConditional::shared_ptr
* @return nullptr if not a GaussianConditional
* @return GaussianConditional::shared_ptr otherwise
*/
GaussianConditional::shared_ptr asGaussian() {
if (!isContinuous())
throw std::invalid_argument("Not a continuous conditional");
return boost::static_pointer_cast<GaussianConditional>(inner_);
return boost::dynamic_pointer_cast<GaussianConditional>(inner_);
}

/**
* @brief Return conditional as a DiscreteConditional
*
* @return nullptr if not a DiscreteConditional
* @return DiscreteConditional::shared_ptr
*/
DiscreteConditional::shared_ptr asDiscreteConditional() {
if (!isDiscrete())
throw std::invalid_argument("Not a discrete conditional");
return boost::static_pointer_cast<DiscreteConditional>(inner_);
DiscreteConditional::shared_ptr asDiscrete() {
return boost::dynamic_pointer_cast<DiscreteConditional>(inner_);
}

/// @}
Expand Down
4 changes: 2 additions & 2 deletions gtsam/hybrid/HybridValues.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,10 @@ class GTSAM_EXPORT HybridValues {
/// @{

/// Return the discrete MPE assignment
DiscreteValues discrete() const { return discrete_; }
const DiscreteValues& discrete() const { return discrete_; }

/// Return the delta update for the continuous vectors
VectorValues continuous() const { return continuous_; }
const VectorValues& continuous() const { return continuous_; }

/// Check whether a variable with key \c j exists in DiscreteValue.
bool existsDiscrete(Key j) { return (discrete_.find(j) != discrete_.end()); };
Expand Down
Loading

0 comments on commit a849eab

Please sign in to comment.