Skip to content

Commit

Permalink
Merge pull request #3460 from rpgoldman/grapher-dfs
Browse files Browse the repository at this point in the history
Fix for timeout in graph_model
  • Loading branch information
lucianopaz authored May 5, 2019
2 parents d113e41 + ca13c44 commit 05e3c39
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 23 deletions.
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
- Used `numpy.vectorize` in `distributions.distribution._compile_theano_function`. This enables `sample_prior_predictive` and `sample_posterior_predictive` to ask for tuples of samples instead of just integers. This fixes issue #3422.

### Maintenance
- Fixed an issue in `model_graph` that caused construction of the graph of the model for rendering to hang: replaced a search over the powerset of the nodes with a breadth-first search over the nodes. Fix for #3458.
- All occurrences of `sd` as a parameter name have been renamed to `sigma`. `sd` will continue to function for backwards compatibility.
- Made `BrokenPipeError` for parallel sampling more verbose on Windows.
- Added the `broadcast_distribution_samples` function that helps broadcasting arrays of drawn samples, taking into account the requested `size` and the inferred distribution shape. This sometimes is needed by distributions that call several `rvs` separately within their `random` method, such as the `ZeroInflatedPoisson` (Fix issue #3310).
Expand Down
57 changes: 34 additions & 23 deletions pymc3/model_graph.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
import itertools
from collections import deque
from typing import Iterator, Optional, MutableSet

from theano.gof.graph import ancestors
from theano.gof.graph import stack_search
from theano.compile import SharedVariable
from theano.tensor import Tensor

from .util import get_default_varnames
import pymc3 as pm

# this is a placeholder for a better characterization of the type
# of variables in a model.
RV = Tensor


def powerset(iterable):
"""All *nonempty* subsets of an iterable.
Expand All @@ -27,37 +34,41 @@ def __init__(self, model):
self._deterministics = None

def get_deterministics(self, var):
    """Compute the deterministic nodes of the graph, **not** including var itself.

    Parameters
    ----------
    var
        The variable to leave out of the result.

    Returns
    -------
    list
        Every entry of ``self.var_list`` other than ``var`` that has neither
        a ``transformed`` nor a ``logpt`` attribute, in ``var_list`` order.
    """
    # A variable carrying either of these attributes is not treated as a
    # deterministic node.
    attrs = ('transformed', 'logpt')
    return [
        v for v in self.var_list
        if v != var and not any(hasattr(v, attr) for attr in attrs)
    ]

def _ancestors(self, var, func, blockers=None):
    """Get ancestors of a function that are also named PyMC3 variables.

    Kept for backward compatibility; ``_get_ancestors`` no longer calls it.
    """
    return {
        j for j in ancestors([func], blockers=blockers)
        if j in self.var_list and j != var
    }

def _get_ancestors(self, var, func) -> MutableSet[RV]:
    """Get all ancestors of a function, doing some accounting for deterministics.

    The graph is walked upstream from ``func`` breadth-first. Whenever the
    walk reaches a variable from ``self.var_list`` (other than ``var``) that
    variable is recorded and the walk does not continue past it, so the
    inputs *to* an upstream model variable are not reported through it.

    Parameters
    ----------
    var
        The model variable whose parents are being computed; it is never
        part of the result.
    func
        The graph node at which the upstream search starts.

    Returns
    -------
    MutableSet[RV]
        The model variables reached by the search.
    """
    # Every variable in the model EXCEPT var. ``discard`` (rather than
    # ``remove``) keeps this working when var is not in var_list.
    candidates: MutableSet[RV] = set(self.var_list)
    candidates.discard(var)

    visited: MutableSet[RV] = set()
    found: MutableSet[RV] = set()

    def _expand(node) -> Optional[Iterator[Tensor]]:
        # Returning None tells the search not to go any further from node.
        if node in visited:
            return None
        visited.add(node)
        if node in candidates:
            # A named model variable: record it and stop here so that its
            # own inputs are not reached along this path.
            found.add(node)
            return None
        if node.owner:
            # An intermediate node: keep searching through its inputs.
            return reversed(node.owner.inputs)
        return None

    stack_search(start=deque([func]), expand=_expand, mode='bfs')
    return found

def _filter_parents(self, var, parents):
"""Get direct parents of a var, as strings"""
Expand Down

0 comments on commit 05e3c39

Please sign in to comment.