Skip to content

Commit

Permalink
Merge pull request #1070 from borglab/feauture/dotwriter
Browse files Browse the repository at this point in the history
  • Loading branch information
dellaert authored Jan 28, 2022
2 parents 7530623 + 2123225 commit 84aed90
Show file tree
Hide file tree
Showing 23 changed files with 452 additions and 293 deletions.
31 changes: 18 additions & 13 deletions gtsam/discrete/DiscreteBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,12 @@

namespace gtsam {

/** A Bayes net made from discrete conditional distributions. */
class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional>
{
public:

/**
* A Bayes net made from discrete conditional distributions.
* @addtogroup discrete
*/
class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
public:
typedef BayesNet<DiscreteConditional> Base;
typedef DiscreteBayesNet This;
typedef DiscreteConditional ConditionalType;
Expand All @@ -49,16 +50,20 @@ namespace gtsam {
DiscreteBayesNet() {}

/** Construct from iterator over conditionals */
template<typename ITERATOR>
DiscreteBayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {}
template <typename ITERATOR>
DiscreteBayesNet(ITERATOR firstConditional, ITERATOR lastConditional)
: Base(firstConditional, lastConditional) {}

/** Construct from container of factors (shared_ptr or plain objects) */
template<class CONTAINER>
explicit DiscreteBayesNet(const CONTAINER& conditionals) : Base(conditionals) {}

/** Implicit copy/downcast constructor to override explicit template container constructor */
template<class DERIVEDCONDITIONAL>
DiscreteBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph) : Base(graph) {}
template <class CONTAINER>
explicit DiscreteBayesNet(const CONTAINER& conditionals)
: Base(conditionals) {}

/** Implicit copy/downcast constructor to override explicit template
* container constructor */
template <class DERIVEDCONDITIONAL>
DiscreteBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph)
: Base(graph) {}

/// Destructor
virtual ~DiscreteBayesNet() {}
Expand Down
29 changes: 17 additions & 12 deletions gtsam/discrete/discrete.i
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DiscreteConditional& other, double tol = 1e-9) const;
gtsam::Key firstFrontalKey() const;
size_t nrFrontals() const;
size_t nrParents() const;
void printSignature(
Expand Down Expand Up @@ -156,13 +157,17 @@ class DiscreteBayesNet {
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const;
string dot(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
void saveGraph(string s, const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteValues sample() const;
gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const;

string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
void saveGraph(
string s,
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
string markdown(const gtsam::KeyFormatter& keyFormatter,
Expand Down Expand Up @@ -252,14 +257,6 @@ class DiscreteFactorGraph {
void print(string s = "") const;
bool equals(const gtsam::DiscreteFactorGraph& fg, double tol = 1e-9) const;

string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& dotWriter = gtsam::DotWriter()) const;
void saveGraph(
string s,
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& dotWriter = gtsam::DotWriter()) const;

gtsam::DecisionTreeFactor product() const;
double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteValues optimize() const;
Expand All @@ -281,6 +278,14 @@ class DiscreteFactorGraph {
std::pair<gtsam::DiscreteBayesTree, gtsam::DiscreteFactorGraph>
eliminatePartialMultifrontal(const gtsam::Ordering& ordering);

string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
void saveGraph(
string s,
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;

string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
string markdown(const gtsam::KeyFormatter& keyFormatter,
Expand Down
19 changes: 14 additions & 5 deletions gtsam/discrete/tests/testDiscreteBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,21 @@ TEST(DiscreteBayesNet, Dot) {
fragment.add((Either | Tuberculosis, LungCancer) = "F T T T");

string actual = fragment.dot();
cout << actual << endl;
EXPECT(actual ==
"digraph G{\n"
"0->3\n"
"4->6\n"
"3->5\n"
"6->5\n"
"digraph {\n"
" size=\"5,5\";\n"
"\n"
" var0[label=\"0\"];\n"
" var3[label=\"3\"];\n"
" var4[label=\"4\"];\n"
" var5[label=\"5\"];\n"
" var6[label=\"6\"];\n"
"\n"
" var3->var5\n"
" var6->var5\n"
" var4->var6\n"
" var0->var3\n"
"}");
}

Expand Down
44 changes: 28 additions & 16 deletions gtsam/inference/BayesNet-inst.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,41 +10,51 @@
* -------------------------------------------------------------------------- */

/**
* @file BayesNet.h
* @brief Bayes network
* @author Frank Dellaert
* @author Richard Roberts
*/
* @file BayesNet.h
* @brief Bayes network
* @author Frank Dellaert
* @author Richard Roberts
*/

#pragma once

#include <gtsam/inference/FactorGraph-inst.h>
#include <gtsam/inference/BayesNet.h>
#include <gtsam/inference/FactorGraph-inst.h>

#include <boost/range/adaptor/reversed.hpp>
#include <fstream>
#include <string>

namespace gtsam {

/* ************************************************************************* */
template <class CONDITIONAL>
void BayesNet<CONDITIONAL>::print(
const std::string& s, const KeyFormatter& formatter) const {
void BayesNet<CONDITIONAL>::print(const std::string& s,
const KeyFormatter& formatter) const {
Base::print(s, formatter);
}

/* ************************************************************************* */
template <class CONDITIONAL>
void BayesNet<CONDITIONAL>::dot(std::ostream& os,
const KeyFormatter& keyFormatter) const {
os << "digraph G{\n";
const KeyFormatter& keyFormatter,
const DotWriter& writer) const {
writer.digraphPreamble(&os);

// Create nodes for each variable in the graph
for (Key key : this->keys()) {
auto position = writer.variablePos(key);
writer.drawVariable(key, keyFormatter, position, &os);
}
os << "\n";

for (auto conditional : *this) {
// Reverse order as typically Bayes nets stored in reverse topological sort.
for (auto conditional : boost::adaptors::reverse(*this)) {
auto frontals = conditional->frontals();
const Key me = frontals.front();
auto parents = conditional->parents();
for (const Key& p : parents)
os << keyFormatter(p) << "->" << keyFormatter(me) << "\n";
os << " var" << keyFormatter(p) << "->var" << keyFormatter(me) << "\n";
}

os << "}";
Expand All @@ -53,18 +63,20 @@ void BayesNet<CONDITIONAL>::dot(std::ostream& os,

/* ************************************************************************* */
template <class CONDITIONAL>
std::string BayesNet<CONDITIONAL>::dot(const KeyFormatter& keyFormatter) const {
std::string BayesNet<CONDITIONAL>::dot(const KeyFormatter& keyFormatter,
const DotWriter& writer) const {
std::stringstream ss;
dot(ss, keyFormatter);
dot(ss, keyFormatter, writer);
return ss.str();
}

/* ************************************************************************* */
template <class CONDITIONAL>
void BayesNet<CONDITIONAL>::saveGraph(const std::string& filename,
const KeyFormatter& keyFormatter) const {
const KeyFormatter& keyFormatter,
const DotWriter& writer) const {
std::ofstream of(filename.c_str());
dot(of, keyFormatter);
dot(of, keyFormatter, writer);
of.close();
}

Expand Down
104 changes: 53 additions & 51 deletions gtsam/inference/BayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,77 +10,79 @@
* -------------------------------------------------------------------------- */

/**
* @file BayesNet.h
* @brief Bayes network
* @author Frank Dellaert
* @author Richard Roberts
*/
* @file BayesNet.h
* @brief Bayes network
* @author Frank Dellaert
* @author Richard Roberts
*/

#pragma once

#include <boost/shared_ptr.hpp>

#include <gtsam/inference/FactorGraph.h>

namespace gtsam {
#include <boost/shared_ptr.hpp>
#include <string>

/**
* A BayesNet is a tree of conditionals, stored in elimination order.
*
* todo: how to handle Bayes nets with an optimize function? Currently using global functions.
* \nosubgrouping
*/
template<class CONDITIONAL>
class BayesNet : public FactorGraph<CONDITIONAL> {
namespace gtsam {

private:
/**
* A BayesNet is a tree of conditionals, stored in elimination order.
* @addtogroup inference
*/
template <class CONDITIONAL>
class BayesNet : public FactorGraph<CONDITIONAL> {
private:
typedef FactorGraph<CONDITIONAL> Base;

typedef FactorGraph<CONDITIONAL> Base;
public:
typedef typename boost::shared_ptr<CONDITIONAL>
sharedConditional; ///< A shared pointer to a conditional

public:
typedef typename boost::shared_ptr<CONDITIONAL> sharedConditional; ///< A shared pointer to a conditional
protected:
/// @name Standard Constructors
/// @{

protected:
/// @name Standard Constructors
/// @{
/** Default constructor as an empty BayesNet */
BayesNet() {}

/** Default constructor as an empty BayesNet */
BayesNet() {};
/** Construct from iterator over conditionals */
template <typename ITERATOR>
BayesNet(ITERATOR firstConditional, ITERATOR lastConditional)
: Base(firstConditional, lastConditional) {}

/** Construct from iterator over conditionals */
template<typename ITERATOR>
BayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {}
/// @}

/// @}
public:
/// @name Testable
/// @{

public:
/// @name Testable
/// @{
/** print out graph */
void print(
const std::string& s = "BayesNet",
const KeyFormatter& formatter = DefaultKeyFormatter) const override;

/** print out graph */
void print(
const std::string& s = "BayesNet",
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
/// @}

/// @}
/// @name Graph Display
/// @{

/// @name Graph Display
/// @{
/// Output to graphviz format, stream version.
void dot(std::ostream& os,
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DotWriter& writer = DotWriter()) const;

/// Output to graphviz format, stream version.
void dot(std::ostream& os, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// Output to graphviz format string.
std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DotWriter& writer = DotWriter()) const;

/// Output to graphviz format string.
std::string dot(
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// output to file with graphviz format.
void saveGraph(const std::string& filename,
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DotWriter& writer = DotWriter()) const;

/// output to file with graphviz format.
void saveGraph(const std::string& filename,
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;

/// @}
};
/// @}
};

}
} // namespace gtsam

#include <gtsam/inference/BayesNet-inst.h>
Loading

0 comments on commit 84aed90

Please sign in to comment.