Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for positioning in BayesNet plotting #1070

Merged
merged 10 commits into from
Jan 28, 2022
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 18 additions & 13 deletions gtsam/discrete/DiscreteBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,12 @@

namespace gtsam {

/** A Bayes net made from discrete conditional distributions. */
class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional>
{
public:

/**
* A Bayes net made from discrete conditional distributions.
* @addtogroup discrete
*/
class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
public:
typedef BayesNet<DiscreteConditional> Base;
typedef DiscreteBayesNet This;
typedef DiscreteConditional ConditionalType;
Expand All @@ -49,16 +50,20 @@ namespace gtsam {
DiscreteBayesNet() {}

/** Construct from iterator over conditionals */
template<typename ITERATOR>
DiscreteBayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {}
template <typename ITERATOR>
DiscreteBayesNet(ITERATOR firstConditional, ITERATOR lastConditional)
: Base(firstConditional, lastConditional) {}

/** Construct from container of factors (shared_ptr or plain objects) */
template<class CONTAINER>
explicit DiscreteBayesNet(const CONTAINER& conditionals) : Base(conditionals) {}

/** Implicit copy/downcast constructor to override explicit template container constructor */
template<class DERIVEDCONDITIONAL>
DiscreteBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph) : Base(graph) {}
template <class CONTAINER>
explicit DiscreteBayesNet(const CONTAINER& conditionals)
: Base(conditionals) {}

/** Implicit copy/downcast constructor to override explicit template
* container constructor */
template <class DERIVEDCONDITIONAL>
DiscreteBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph)
: Base(graph) {}

/// Destructor
virtual ~DiscreteBayesNet() {}
Expand Down
12 changes: 8 additions & 4 deletions gtsam/discrete/discrete.i
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DiscreteConditional& other, double tol = 1e-9) const;
gtsam::Key firstFrontalKey() const;
size_t nrFrontals() const;
size_t nrParents() const;
void printSignature(
Expand Down Expand Up @@ -156,10 +157,13 @@ class DiscreteBayesNet {
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const;
string dot(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
void saveGraph(string s, const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
void saveGraph(
string s,
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteValues sample() const;
gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const;
Expand Down
19 changes: 14 additions & 5 deletions gtsam/discrete/tests/testDiscreteBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,21 @@ TEST(DiscreteBayesNet, Dot) {
fragment.add((Either | Tuberculosis, LungCancer) = "F T T T");

string actual = fragment.dot();
cout << actual << endl;
EXPECT(actual ==
"digraph G{\n"
"0->3\n"
"4->6\n"
"3->5\n"
"6->5\n"
"digraph {\n"
" size=\"5,5\";\n"
"\n"
" var0[label=\"0\"];\n"
" var3[label=\"3\"];\n"
" var4[label=\"4\"];\n"
" var5[label=\"5\"];\n"
" var6[label=\"6\"];\n"
"\n"
" var3->var5\n"
" var6->var5\n"
" var4->var6\n"
" var0->var3\n"
"}");
}

Expand Down
43 changes: 27 additions & 16 deletions gtsam/inference/BayesNet-inst.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,41 +10,50 @@
* -------------------------------------------------------------------------- */

/**
* @file BayesNet.h
* @brief Bayes network
* @author Frank Dellaert
* @author Richard Roberts
*/
* @file BayesNet.h
* @brief Bayes network
* @author Frank Dellaert
* @author Richard Roberts
*/

#pragma once

#include <gtsam/inference/FactorGraph-inst.h>
#include <gtsam/inference/BayesNet.h>
#include <gtsam/inference/FactorGraph-inst.h>

#include <boost/range/adaptor/reversed.hpp>
#include <fstream>
#include <string>

namespace gtsam {

/* ************************************************************************* */
template <class CONDITIONAL>
void BayesNet<CONDITIONAL>::print(
const std::string& s, const KeyFormatter& formatter) const {
void BayesNet<CONDITIONAL>::print(const std::string& s,
const KeyFormatter& formatter) const {
Base::print(s, formatter);
}

/* ************************************************************************* */
template <class CONDITIONAL>
void BayesNet<CONDITIONAL>::dot(std::ostream& os,
const KeyFormatter& keyFormatter) const {
os << "digraph G{\n";
const KeyFormatter& keyFormatter,
const DotWriter& writer) const {
writer.digraphPreamble(&os);

// Create nodes for each variable in the graph
for (Key key : this->keys()) {
auto position = writer.variablePos(key);
writer.drawVariable(key, keyFormatter, position, &os);
}
os << "\n";

for (auto conditional : *this) {
for (auto conditional : boost::adaptors::reverse(*this)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe a comment here why reversing the order? Like following a top-down order in printing?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is because Bayes nets are typically specified in reverse topological sort order, that's how they come out of elimination. Will add comments in the next PR to not restart CI for this one thing

auto frontals = conditional->frontals();
const Key me = frontals.front();
auto parents = conditional->parents();
for (const Key& p : parents)
os << keyFormatter(p) << "->" << keyFormatter(me) << "\n";
os << " var" << keyFormatter(p) << "->var" << keyFormatter(me) << "\n";
}

os << "}";
Expand All @@ -53,18 +62,20 @@ void BayesNet<CONDITIONAL>::dot(std::ostream& os,

/* ************************************************************************* */
template <class CONDITIONAL>
std::string BayesNet<CONDITIONAL>::dot(const KeyFormatter& keyFormatter) const {
std::string BayesNet<CONDITIONAL>::dot(const KeyFormatter& keyFormatter,
const DotWriter& writer) const {
std::stringstream ss;
dot(ss, keyFormatter);
dot(ss, keyFormatter, writer);
return ss.str();
}

/* ************************************************************************* */
template <class CONDITIONAL>
void BayesNet<CONDITIONAL>::saveGraph(const std::string& filename,
const KeyFormatter& keyFormatter) const {
const KeyFormatter& keyFormatter,
const DotWriter& writer) const {
std::ofstream of(filename.c_str());
dot(of, keyFormatter);
dot(of, keyFormatter, writer);
of.close();
}

Expand Down
104 changes: 53 additions & 51 deletions gtsam/inference/BayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,77 +10,79 @@
* -------------------------------------------------------------------------- */

/**
* @file BayesNet.h
* @brief Bayes network
* @author Frank Dellaert
* @author Richard Roberts
*/
* @file BayesNet.h
* @brief Bayes network
* @author Frank Dellaert
* @author Richard Roberts
*/

#pragma once

#include <boost/shared_ptr.hpp>

#include <gtsam/inference/FactorGraph.h>

namespace gtsam {
#include <boost/shared_ptr.hpp>
#include <string>

/**
* A BayesNet is a tree of conditionals, stored in elimination order.
*
* todo: how to handle Bayes nets with an optimize function? Currently using global functions.
* \nosubgrouping
*/
template<class CONDITIONAL>
class BayesNet : public FactorGraph<CONDITIONAL> {
namespace gtsam {

private:
/**
* A BayesNet is a tree of conditionals, stored in elimination order.
* @addtogroup inference
*/
template <class CONDITIONAL>
class BayesNet : public FactorGraph<CONDITIONAL> {
private:
typedef FactorGraph<CONDITIONAL> Base;

typedef FactorGraph<CONDITIONAL> Base;
public:
typedef typename boost::shared_ptr<CONDITIONAL>
sharedConditional; ///< A shared pointer to a conditional

public:
typedef typename boost::shared_ptr<CONDITIONAL> sharedConditional; ///< A shared pointer to a conditional
protected:
/// @name Standard Constructors
/// @{

protected:
/// @name Standard Constructors
/// @{
/** Default constructor as an empty BayesNet */
BayesNet() {}

/** Default constructor as an empty BayesNet */
BayesNet() {};
/** Construct from iterator over conditionals */
template <typename ITERATOR>
BayesNet(ITERATOR firstConditional, ITERATOR lastConditional)
: Base(firstConditional, lastConditional) {}

/** Construct from iterator over conditionals */
template<typename ITERATOR>
BayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {}
/// @}

/// @}
public:
/// @name Testable
/// @{

public:
/// @name Testable
/// @{
/** print out graph */
void print(
const std::string& s = "BayesNet",
const KeyFormatter& formatter = DefaultKeyFormatter) const override;

/** print out graph */
void print(
const std::string& s = "BayesNet",
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
/// @}

/// @}
/// @name Graph Display
/// @{

/// @name Graph Display
/// @{
/// Output to graphviz format, stream version.
void dot(std::ostream& os,
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DotWriter& writer = DotWriter()) const;

/// Output to graphviz format, stream version.
void dot(std::ostream& os, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// Output to graphviz format string.
std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DotWriter& writer = DotWriter()) const;

/// Output to graphviz format string.
std::string dot(
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// output to file with graphviz format.
void saveGraph(const std::string& filename,
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DotWriter& writer = DotWriter()) const;

/// output to file with graphviz format.
void saveGraph(const std::string& filename,
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;

/// @}
};
/// @}
};

}
} // namespace gtsam

#include <gtsam/inference/BayesNet-inst.h>
Loading