12 changes: 11 additions & 1 deletion aeppl/opt.py
@@ -244,10 +244,20 @@ def naive_bcast_rv_lift(fgraph, node):
]
bcasted_node = lifted_node.op.make_node(rng, size, dtype, *new_dist_params)

rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)

new_bcast_out = bcasted_node.outputs[1]

if rv_map_feature is not None and rv_var in rv_map_feature.rv_values:
val_var = rv_map_feature.rv_values[rv_var]
new_val_var = at.broadcast_to(val_var, tuple(bcast_shape))
rv_map_feature.rv_values[new_bcast_out] = new_val_var
ricardoV94 (Contributor), Feb 9, 2022:
Won't this add an extra logp term? If I am parsing this correctly, calling joint_logp on your first new test case would now sum 3 terms corresponding to (x_vv, new_value_var, y_vv), no?

brandonwillard (Member Author) replied:
> Won't this add an extra logp term? If I am parsing this correctly, calling joint_logp on your first new test case would now sum 3 terms corresponding to (x_vv, new_value_var, y_vv), no?

Have you tried it? We can always check the sum in the tests.

brandonwillard (Member Author) replied:
The first test case might actually add an extra term; I'll try it out.
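
A minimal sketch of such a check against the first new test case (illustrative only, not part of the PR's test suite), assuming `factorized_joint_logprob` returns exactly one logp entry per requested value variable:

import aesara.tensor as at
from aeppl import factorized_joint_logprob

# Rebuild the graph from the first new test case: x is valued and its
# broadcast feeds y.
x_rv = at.random.normal(name="x")
y_rv = at.random.normal(at.broadcast_to(x_rv, (2,)), name="y")
x_vv, y_vv = x_rv.clone(), y_rv.clone()

logp_map = factorized_joint_logprob({x_rv: x_vv, y_rv: y_vv})

# If the lift registered the broadcasted value as an extra measurable term,
# the map would contain a third entry and this check would fail.
assert set(logp_map.keys()) == {x_vv, y_vv}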

rv_map_feature.original_values[new_val_var] = new_val_var

if aesara.config.compute_test_value != "off":
compute_test_value(bcasted_node)

- return [bcasted_node.outputs[1]]
+ return [new_bcast_out]


logprob_rewrites_db = SequenceDB()
5 changes: 0 additions & 5 deletions docs/source/conf.py
@@ -1,10 +1,5 @@
import sys
import os
import pathlib


# import local version of library instead of installed one
sys.path.insert(0, str(pathlib.Path(__file__).parent.resolve().parent.parent / "src"))
import aeppl

# -- Project information
38 changes: 37 additions & 1 deletion tests/test_opt.py
@@ -1,17 +1,20 @@
import aesara
import aesara.tensor as at
import numpy as np
import scipy.stats as st
from aesara.graph.opt import in2out
from aesara.graph.opt_utils import optimize_graph
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.extra_ops import BroadcastTo
from aesara.tensor.subtensor import Subtensor

from aeppl import factorized_joint_logprob
from aeppl.dists import DiracDelta, dirac_delta
from aeppl.opt import local_lift_DiracDelta, naive_bcast_rv_lift


def test_naive_bcast_rv_lift():
r"""Make sure `test_naive_bcast_rv_lift` can handle useless scalar `BroadcastTo`\s."""
r"""Make sure `naive_bcast_rv_lift` can handle useless scalar `BroadcastTo`\s."""
X_rv = at.random.normal()
Z_at = BroadcastTo()(X_rv, ())

@@ -22,6 +25,39 @@ def test_naive_bcast_rv_lift():
assert res is X_rv


def test_naive_bcast_rv_lift_valued_var():
r"""Check that `naive_bcast_rv_lift` handles valued variables correctly."""

x_rv = at.random.normal(name="x")

mu = at.broadcast_to(x_rv, (2,))
y_rv = at.random.normal(mu, name="y")

x_vv = x_rv.clone()
y_vv = y_rv.clone()
logp_map = factorized_joint_logprob({x_rv: x_vv, y_rv: y_vv})

assert x_vv in logp_map
assert y_vv in logp_map
y_val = np.array([0, 0])
assert np.allclose(
logp_map[y_vv].eval({x_vv: 0, y_vv: y_val}), st.norm(0).logpdf(y_val)
)

# Lifting should also work when `BroadcastTo`s are directly assigned a value variable
ricardoV94 (Contributor), Feb 9, 2022:
I don't think this second example below should be valid. We are measuring the same thing twice (the pre-broadcast and post-broadcast value variables)

Besides, the joint_logprob would have been fine with saying {x_vv: 0, z_vv: [[1, 1], [1, 1]]} which does not match the original graph

ricardoV94 (Contributor), Feb 9, 2022:
If only one of the pre- and post-broadcast variables is valued, then #116 already worked.

... although the post-broadcast case is still a bit iffy and depends on what we mean by joint-logprob. Strictly speaking, the original generative graph of z_rv would correspond to a logp of something like:

def broadcastTo_logp(value, original_shape, dist_op, dist_params):

  # e.g., undo_broadcast((3, 3, 3), (1,)) -> (3,)
  pre_broadcast_value = undo_broadcast(value, original_shape)
  
  return switch(
    at.eq(value, at.broadcast_to(pre_broadcast_value, value.shape)),
    logprob(dist_op, pre_broadcast_value, *dist_params),
    -inf,
  )  

But that's a separate question from the comment above
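
The `undo_broadcast` helper in that sketch is hypothetical; one purely illustrative way it could be written, assuming the pre-broadcast shape is concrete (not part of the PR), is:

import numpy as np

def undo_broadcast(value, original_shape):
    # Hypothetical helper: recover a pre-broadcast array by indexing away the
    # leading dimensions broadcasting added and collapsing expanded size-1 dims.
    extra_dims = value.ndim - len(original_shape)
    idx = (0,) * extra_dims + tuple(
        slice(0, 1) if s == 1 else slice(None) for s in original_shape
    )
    return value[idx]

# e.g. undo_broadcast(np.broadcast_to(np.arange(3), (4, 3)), (3,)) -> array([0, 1, 2])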

brandonwillard (Member Author) replied:

> If only one of the pre- and post-broadcast variables is valued, then #116 already worked.

Only if it broadcasted something that wasn't valued, but that's the limitation I've demonstrated.

> ... although the post-broadcast case is still a bit iffy and depends on what we mean by joint-logprob. Strictly speaking, the original generative graph of z_rv would correspond to a logp of something like:

> def broadcastTo_logp(value, original_shape, dist_op, dist_params):
>   # e.g., undo_broadcast((3, 3, 3), (1,)) -> (3,)
>   pre_broadcast_value = undo_broadcast(value, original_shape)
>   return switch(
>     at.eq(value, at.broadcast_to(pre_broadcast_value, value.shape)),
>     logprob(dist_op, pre_broadcast_value, *dist_params),
>     -inf,
>   )
>
> But that's a separate question from the comment above

You're going to need to clarify this, because, right now, I don't see the relevance of this function. Can you demonstrate the underlying issue by deriving an incorrect log-probability calculation under these changes?

brandonwillard (Member Author), Feb 9, 2022:

> I don't think this second example below should be valid. We are measuring the same thing twice (the pre-broadcast and post-broadcast value variables)

Are the log-probability values incorrect, or do you simply think people shouldn't be allowed to do this? Is there some other inconsistency implied by this that you can demonstrate?

Regardless, it's a minimal example, so there may be other cases that involve performing the same rewrite on a valued variable in order to derive a log-probability for another valued variable (e.g. mixtures). Can you guarantee that this isn't possible, so that we can justify the limitations you're proposing?

> Besides, the joint_logprob would have been fine with saying {x_vv: 0, z_vv: [[1, 1], [1, 1]]} which does not match the original graph

What exactly doesn't match the original graph, and how is that relevant?
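
An illustrative sketch of the "does not match the original graph" point, using the same `x_rv`/`z_rv` setup as the second test case below (not part of the PR; the particular values are made up):

import numpy as np
import aesara.tensor as at
from aeppl import factorized_joint_logprob

x_rv = at.random.normal(name="x")
z_rv = at.broadcast_to(x_rv, (2, 2))
x_vv, z_vv = x_rv.clone(), z_rv.clone()

logp_map = factorized_joint_logprob({x_rv: x_vv, z_rv: z_vv})

# The generative graph can only produce z values that are a broadcast of x,
# yet each factor evaluates without complaint on a pair like
# (x = 0, z = [[1, 1], [1, 1]]) that no draw from the graph could yield.
print(logp_map[x_vv].eval({x_vv: 0.0}))
print(logp_map[z_vv].eval({z_vv: np.array([[1.0, 1.0], [1.0, 1.0]])}))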

# variables
z_rv = at.broadcast_to(x_rv, (2, 2))
z_rv.name = "Z"
z_vv = z_rv.clone()
z_vv.name = "z"

logp_map = factorized_joint_logprob({x_rv: x_vv, z_rv: z_vv})
assert x_vv in logp_map
assert z_vv in logp_map
z_val = np.array([[0, 0], [0, 0]])
assert np.allclose(logp_map[z_vv].eval({z_vv: z_val}), st.norm(0).logpdf(z_val))


def test_local_lift_DiracDelta():
c_at = at.vector()
dd_at = dirac_delta(c_at)