Skip to content

Commit

Permalink
check graphviz results with all four formatting options
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelosthege committed Nov 26, 2020
1 parent 185b609 commit 15ce7bf
Showing 1 changed file with 22 additions and 10 deletions.
32 changes: 22 additions & 10 deletions pymc3/tests/test_data_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,16 +179,28 @@ def test_model_to_graphviz_for_model_with_data_container(self):
pm.Normal("obs", beta * x, obs_sigma, observed=y)
pm.sample(1000, init=None, tune=1000, chains=1)

g = pm.model_to_graphviz(model)

# Data node rendered correctly?
text = 'x [label="x\n~\nData" shape=box style="rounded, filled"]'
assert text in g.source
# Didn't break ordinary variables?
text = 'beta [label="beta\n~\nNormal(mu=0.0, sigma=10.0)"]'
assert text in g.source
text = f'obs [label="obs\n~\nNormal(mu=f(f(beta), x), sigma={obs_sigma})" style=filled]'
assert text in g.source
for formatting in {"latex", "latex_with_params"}:
with pytest.raises(ValueError, match="Unsupported formatting"):
pm.model_to_graphviz(model, formatting=formatting)

exp_without = [
'x [label="x\n~\nData" shape=box style="rounded, filled"]',
'beta [label="beta\n~\nNormal"]',
'obs [label="obs\n~\nNormal" style=filled]',
]
exp_with = [
'x [label="x\n~\nData" shape=box style="rounded, filled"]',
'beta [label="beta\n~\nNormal(mu=0.0, sigma=10.0)"]',
f'obs [label="obs\n~\nNormal(mu=f(f(beta), x), sigma={obs_sigma})" style=filled]',
]
for formatting, expected_substrings in [
("plain", exp_without),
("plain_with_params", exp_with),
]:
g = pm.model_to_graphviz(model, formatting=formatting)
# check formatting of RV nodes
for expected in expected_substrings:
assert expected in g.source

def test_explicit_coords(self):
N_rows = 5
Expand Down

0 comments on commit 15ce7bf

Please sign in to comment.