
Commit 9341906
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Sep 15, 2024
1 parent 378dbe4 commit 9341906
Showing 6 changed files with 70 additions and 42 deletions.
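All six diffs below are mechanical style fixes: typing.Sequence imports are replaced by collections.abc.Sequence (a pyupgrade-style rewrite), imports are sorted and regrouped (isort-style), statements longer than the line-length limit are wrapped black-style, and an f-string prefix with no placeholders is dropped. The commit itself does not show which hooks the repository configures; purely as an illustration, a minimal .pre-commit-config.yaml that would produce fixes like these, assuming Ruff's linter and formatter (an assumption, not the repository's confirmed setup), could look like:

    # Hypothetical sketch of a .pre-commit-config.yaml; the hook ids are real,
    # but this is not necessarily the configuration this repository uses.
    repos:
      - repo: https://github.com/astral-sh/ruff-pre-commit
        rev: v0.6.5         # pin whatever release is current
        hooks:
          - id: ruff        # linter; with --fix it applies import sorting and
            args: [--fix]   # pyupgrade-style rewrites automatically
          - id: ruff-format # black-compatible formatter (line wrapping, trailing commas)

pre-commit.ci runs the configured hooks against each pull request and, when a hook modifies files, pushes an auto-fix commit like this one.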
19 changes: 8 additions & 11 deletions pymc_experimental/model/marginal/distributions.py
@@ -1,21 +1,18 @@
-from typing import Sequence
+from collections.abc import Sequence

 import numpy as np
 import pytensor.tensor as pt
-from pymc.distributions import (
-    Bernoulli,
-    Categorical,
-    DiscreteUniform,
-    SymbolicRandomVariable
-)
-from pymc.logprob.basic import conditional_logp, logp
+
+from pymc.distributions import Bernoulli, Categorical, DiscreteUniform, SymbolicRandomVariable
 from pymc.logprob.abstract import _logprob
+from pymc.logprob.basic import conditional_logp, logp
 from pymc.pytensorf import constant_fold
-from pytensor.graph.replace import clone_replace, graph_replace
-from pytensor.scan import scan, map as scan_map
 from pytensor.compile.mode import Mode
 from pytensor.graph import vectorize_graph
-from pytensor.tensor import TensorVariable, TensorType
+from pytensor.graph.replace import clone_replace, graph_replace
+from pytensor.scan import map as scan_map
+from pytensor.scan import scan
+from pytensor.tensor import TensorType, TensorVariable

 from pymc_experimental.distributions import DiscreteMarkovChain
44 changes: 25 additions & 19 deletions pymc_experimental/model/marginal/graph_analysis.py
@@ -1,17 +1,17 @@
-from itertools import zip_longest, chain
-from typing import Sequence
+from collections.abc import Sequence
+from itertools import chain, zip_longest

 from pymc import SymbolicRandomVariable
 from pytensor.compile import SharedVariable
-from pytensor.graph import ancestors, Constant, graph_inputs, Variable
+from pytensor.graph import Constant, Variable, ancestors, graph_inputs
 from pytensor.graph.basic import io_toposort
-from pytensor.tensor import TensorVariable, TensorType
+from pytensor.tensor import TensorType, TensorVariable
 from pytensor.tensor.blockwise import Blockwise
-from pytensor.tensor.elemwise import DimShuffle, Elemwise, CAReduce
+from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.rewriting.subtensor import is_full_slice
 from pytensor.tensor.shape import Shape
-from pytensor.tensor.subtensor import Subtensor, get_idx_list, AdvancedSubtensor
+from pytensor.tensor.subtensor import AdvancedSubtensor, Subtensor, get_idx_list
 from pytensor.tensor.type_other import NoneTypeT

@@ -58,7 +58,6 @@ def find_conditional_dependent_rvs(dependable_rv, all_rvs):
 ]


-
 def collect_shared_vars(outputs, blockers):
     return [
         inp for inp in graph_inputs(outputs, blockers=blockers) if isinstance(inp, SharedVariable)
@@ -86,18 +85,22 @@ def _advanced_indexing_axis_and_ndim(idxs) -> tuple[int, int]:
     return adv_group_axis, adv_group_ndim


-def _broadcast_dims(inputs_dims: Sequence[tuple[tuple[int, ...], ...]]) -> tuple[tuple[int, ...], ...]:
+def _broadcast_dims(
+    inputs_dims: Sequence[tuple[tuple[int, ...], ...]],
+) -> tuple[tuple[int, ...], ...]:
     output_ndim = max((len(input_dim) for input_dim in inputs_dims), default=0)
     # Add missing dims
-    inputs_dims = [
-        ((),) * (output_ndim - len(input_dim)) + input_dim for input_dim in inputs_dims
-    ]
+    inputs_dims = [((),) * (output_ndim - len(input_dim)) + input_dim for input_dim in inputs_dims]
     # Combine aligned dims
-    output_dims = tuple(tuple(sorted(set(chain.from_iterable(inputs_dim)))) for inputs_dim in zip(*inputs_dims))
+    output_dims = tuple(
+        tuple(sorted(set(chain.from_iterable(inputs_dim)))) for inputs_dim in zip(*inputs_dims)
+    )
     return output_dims


-def subgraph_dim_connection(input_var, other_inputs, output_vars) -> list[tuple[tuple[int, ...], ...]]:
+def subgraph_dim_connection(
+    input_var, other_inputs, output_vars
+) -> list[tuple[tuple[int, ...], ...]]:
     """Identify how the dims of rv_to_marginalize are consumed by the dims of the output_rvs.

     Raises
@@ -135,13 +138,16 @@ def subgraph_dim_connection(input_var, other_inputs, output_vars) -> list[tuple[
         op_batch_ndim = node.op.batch_ndim(node)

         # Collapse all core_dims
-        core_dims = tuple(sorted(chain.from_iterable([i for input_dim in inputs_dims for i in input_dim[op_batch_ndim:]])))
-        batch_dims = _broadcast_dims(
-            tuple(
-                input_dims[:op_batch_ndim]
-                for input_dims in inputs_dims
-            )
+        core_dims = tuple(
+            sorted(
+                chain.from_iterable(
+                    [i for input_dim in inputs_dims for i in input_dim[op_batch_ndim:]]
+                )
+            )
+        )
+        batch_dims = _broadcast_dims(
+            tuple(input_dims[:op_batch_ndim] for input_dims in inputs_dims)
         )
         # Add batch dims to each output_dims
         batch_dims = tuple(batch_dim + core_dims for batch_dim in batch_dims)
         for out in node.outputs:
@@ -221,7 +227,7 @@ def subgraph_dim_connection(input_var, other_inputs, output_vars) -> list[tuple[
             elif value_dim:
                 # We are trying to partially slice or index a known dimension
                 raise NotImplementedError(
-                    f"Partial slicing or advanced integer indexing of known dimensions not supported"
+                    "Partial slicing or advanced integer indexing of known dimensions not supported"
                 )
             elif isinstance(idx, slice):
                 # Unknown dimensions kept by partial slice.
18 changes: 13 additions & 5 deletions pymc_experimental/model/marginal/marginal_model.py
@@ -23,10 +23,19 @@
 __all__ = ["MarginalModel", "marginalize"]

 from pymc_experimental.distributions import DiscreteMarkovChain
-from pymc_experimental.model.marginal.distributions import FiniteDiscreteMarginalRV, DiscreteMarginalMarkovChainRV, \
-    get_domain_of_finite_discrete_rv, _add_reduce_batch_dependent_logps
-from pymc_experimental.model.marginal.graph_analysis import find_conditional_input_rvs, is_conditional_dependent, \
-    find_conditional_dependent_rvs, subgraph_dim_connection, collect_shared_vars
+from pymc_experimental.model.marginal.distributions import (
+    DiscreteMarginalMarkovChainRV,
+    FiniteDiscreteMarginalRV,
+    _add_reduce_batch_dependent_logps,
+    get_domain_of_finite_discrete_rv,
+)
+from pymc_experimental.model.marginal.graph_analysis import (
+    collect_shared_vars,
+    find_conditional_dependent_rvs,
+    find_conditional_input_rvs,
+    is_conditional_dependent,
+    subgraph_dim_connection,
+)

 ModelRVs = TensorVariable | Sequence[TensorVariable] | str | Sequence[str]
@@ -613,4 +622,3 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
     marginalized_rvs = marginalization_op(*inputs)
     fgraph.replace_all(tuple(zip(rvs_to_marginalize, marginalized_rvs)))
     return rvs_to_marginalize, marginalized_rvs
-
2 changes: 1 addition & 1 deletion tests/model/marginal/test_distributions.py
@@ -1,13 +1,13 @@
 import numpy as np
 import pymc as pm
 import pytest

 from pymc.logprob.abstract import _logprob
 from pytensor import tensor as pt
 from scipy.stats import norm

 from pymc_experimental import MarginalModel
 from pymc_experimental.distributions import DiscreteMarkovChain
-
 from pymc_experimental.model.marginal.distributions import FiniteDiscreteMarginalRV

+
Expand Down
27 changes: 22 additions & 5 deletions tests/model/marginal/test_graph_analysis.py
@@ -1,12 +1,12 @@
 import pytensor.tensor as pt
 import pytest

 from pymc.distributions import CustomDist

 from pymc_experimental.model.marginal.graph_analysis import subgraph_dim_connection

+
 class TestSubgraphDimConnection:
-
     def test_dimshuffle(self):
         inp = pt.zeros(shape=(5, 1, 4, 3))
         out1 = pt.matrix_transpose(inp)
@@ -31,11 +31,16 @@ def test_subtensor(self):
         assert dims == ((1,),)

         invalid_out = inp[0, :1]
-        with pytest.raises(NotImplementedError, match="Partial slicing of known dimensions not supported"):
+        with pytest.raises(
+            NotImplementedError, match="Partial slicing of known dimensions not supported"
+        ):
             subgraph_dim_connection(inp, [], [invalid_out])

         # If we are slicing a dummy / unknown dimension that's fine
-        valid_out = pt.expand_dims(inp[:, 0], 1)[0, :1,]
+        valid_out = pt.expand_dims(inp[:, 0], 1)[
+            0,
+            :1,
+        ]
         [dims] = subgraph_dim_connection(inp, [], [valid_out])
         assert dims == ((), (2,))
@@ -53,11 +58,23 @@ def test_elemwise(self):
         # By removing the last dimension, we align the first and the last in the addition
         out = inp + inp[:, 0]
         [dims] = subgraph_dim_connection(inp, [], [out])
-        assert dims == ((0,), (0, 1,))
+        assert dims == (
+            (0,),
+            (
+                0,
+                1,
+            ),
+        )

         out = inp + inp.T
         [dims] = subgraph_dim_connection(inp, [], [out])
-        assert dims == ((0, 1), (0, 1,))
+        assert dims == (
+            (0, 1),
+            (
+                0,
+                1,
+            ),
+        )

     def test_blockwise(self):
         inp = pt.zeros(shape=(5, 4, 3, 2))
2 changes: 1 addition & 1 deletion tests/model/marginal/test_marginal_model.py
@@ -17,11 +17,11 @@
 from scipy.special import log_softmax, logsumexp
 from scipy.stats import halfnorm, norm

+from pymc_experimental.model.marginal.graph_analysis import is_conditional_dependent
 from pymc_experimental.model.marginal.marginal_model import (
     MarginalModel,
     marginalize,
 )
-from pymc_experimental.model.marginal.graph_analysis import is_conditional_dependent
 from tests.utils import equal_computations_up_to_root



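To reproduce fixes like these locally before pushing, assuming the repository defines hooks along the lines sketched above, the standard pre-commit CLI workflow applies:

    # one-time setup: install the git hook into .git/hooks
    pre-commit install
    # run every configured hook across the repository, applying auto-fixes
    pre-commit run --all-files

Files modified by the hooks then show up as unstaged changes, matching the content of this commit.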