Track value variables when lifting BroadcastTo Ops #121
```diff
@@ -1,17 +1,20 @@
 import aesara
 import aesara.tensor as at
 import numpy as np
 import scipy.stats as st
 from aesara.graph.opt import in2out
 from aesara.graph.opt_utils import optimize_graph
 from aesara.tensor.elemwise import DimShuffle, Elemwise
 from aesara.tensor.extra_ops import BroadcastTo
 from aesara.tensor.subtensor import Subtensor

 from aeppl import factorized_joint_logprob
 from aeppl.dists import DiracDelta, dirac_delta
 from aeppl.opt import local_lift_DiracDelta, naive_bcast_rv_lift


 def test_naive_bcast_rv_lift():
-    r"""Make sure `test_naive_bcast_rv_lift` can handle useless scalar `BroadcastTo`\s."""
+    r"""Make sure `naive_bcast_rv_lift` can handle useless scalar `BroadcastTo`\s."""
     X_rv = at.random.normal()
     Z_at = BroadcastTo()(X_rv, ())
@@ -22,6 +25,39 @@ def test_naive_bcast_rv_lift():
     assert res is X_rv


+def test_naive_bcast_rv_lift_valued_var():
+    r"""Check that `naive_bcast_rv_lift` handles valued variables correctly."""
+    x_rv = at.random.normal(name="x")
+
+    mu = at.broadcast_to(x_rv, (2,))
+    y_rv = at.random.normal(mu, name="y")
+
+    x_vv = x_rv.clone()
+    y_vv = y_rv.clone()
+    logp_map = factorized_joint_logprob({x_rv: x_vv, y_rv: y_vv})
+
+    assert x_vv in logp_map
+    assert y_vv in logp_map
+    y_val = np.array([0, 0])
+    assert np.allclose(
+        logp_map[y_vv].eval({x_vv: 0, y_vv: y_val}), st.norm(0).logpdf(y_val)
+    )
+
+    # Lifting should also work when `BroadcastTo`s are directly assigned value
```
Contributor:

I don't think this second example below should be valid. We are measuring the same thing twice (the pre-broadcast and the post-broadcast value variable). Besides, the …
Contributor:

If only one of the pre- and post-broadcast variables is valued, then #116 already worked. ... although the post-broadcast case is still a bit iffy and depends on what we mean by joint-logprob. Strictly speaking, the original generative graph of …

```python
def broadcastTo_logp(value, original_shape, dist_op, dist_params):
    # Sketch from the review comment; `undo_broadcast` and `logprob` are
    # placeholder helpers, e.g. undo_broadcast((3, 3, 3), (1,)) -> (3,)
    pre_broadcast_value = undo_broadcast(value, original_shape)
    return at.switch(
        at.eq(value, at.broadcast_to(pre_broadcast_value, value.shape)),
        logprob(dist_op, pre_broadcast_value, *dist_params),
        -np.inf,
    )
```

But that's a separate question from the comment above.
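For concreteness, `undo_broadcast` could plausibly behave like the following NumPy sketch; the implementation is a guess at the helper named in the comment above, and only the `(3, 3, 3) -> (3,)` example comes from the comment:

```python
import numpy as np


def undo_broadcast(value, original_shape):
    """Hypothetical helper: recover a pre-broadcast value, assuming `value`
    was produced by broadcasting an array of shape `original_shape`, so
    every broadcast copy holds the same data."""
    # Broadcasting prepends new dimensions; keep only their first entry.
    n_extra = value.ndim - len(original_shape)
    value = value[(0,) * n_extra]
    # Dimensions that were length 1 were tiled; keep only the first copy,
    # e.g. undo_broadcast(np.array([3, 3, 3]), (1,)) -> np.array([3])
    idx = tuple(slice(0, 1) if s == 1 else slice(None) for s in original_shape)
    return value[idx]


np.testing.assert_array_equal(
    undo_broadcast(np.array([3, 3, 3]), (1,)), np.array([3])
)
np.testing.assert_array_equal(
    undo_broadcast(np.broadcast_to(np.arange(3), (2, 3)), (3,)), np.arange(3)
)
```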
Member (Author):

Only if it broadcasted something that wasn't valued, but that's the limitation I've demonstrated.

You're going to need to clarify this, because, right now, I don't see the relevance of this function. Can you demonstrate the underlying issue by deriving an incorrect log-probability calculation under these changes?
Member (Author):

Are the log-probability values incorrect, or do you simply think people shouldn't be allowed to do this? Is there some other inconsistency implied by this that you can demonstrate? Regardless, it's a minimal example, so there may be other cases that involve performing the same rewrite on a valued variable in order to derive a log-probability for another valued variable (e.g., mixtures). Can you guarantee that this isn't possible, so that we can justify the limitations you're proposing?

What exactly doesn't match the original graph, and how is that relevant?
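To make the double-counting question concrete, here is a hypothetical sketch of what summing the factorized terms would yield when both the pre- and post-broadcast variables are valued. The summation step and the expected total are this sketch's assumptions, inferred from the test assertions below, not claims made in the PR:

```python
import aesara.tensor as at
import numpy as np
import scipy.stats as st

from aeppl import factorized_joint_logprob

x_rv = at.random.normal(name="x")
z_rv = at.broadcast_to(x_rv, (2, 2))

x_vv = x_rv.clone()
z_vv = z_rv.clone()

logp_map = factorized_joint_logprob({x_rv: x_vv, z_rv: z_vv})

# Summing the factorized terms counts the same draw once through `x_vv`
# and once per broadcast copy through `z_vv`.
joint = at.sum(logp_map[x_vv]) + at.sum(logp_map[z_vv])
val = joint.eval({x_vv: 0.0, z_vv: np.zeros((2, 2))})

# Under the test's assertions this would equal 5 * st.norm(0).logpdf(0):
# one term for `x` plus four for the broadcast copies, rather than a single
# term with a Dirac-like factor tying `z` to `x`.
print(val)
```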
```diff
+    # variables
+    z_rv = at.broadcast_to(x_rv, (2, 2))
+    z_rv.name = "Z"
+    z_vv = z_rv.clone()
+    z_vv.name = "z"
+
+    logp_map = factorized_joint_logprob({x_rv: x_vv, z_rv: z_vv})
+    assert x_vv in logp_map
+    assert z_vv in logp_map
+    z_val = np.array([[0, 0], [0, 0]])
+    assert np.allclose(logp_map[z_vv].eval({z_vv: z_val}), st.norm(0).logpdf(z_val))


 def test_local_lift_DiracDelta():
     c_at = at.vector()
     dd_at = dirac_delta(c_at)
```
Won't this add an extra logp term? If I am parsing this correctly, calling `joint_logp` on your first new test case would now sum 3 terms corresponding to `(x_vv, new_value_var, y_vv)`, no?
Have you tried it? We can always check the sum in the tests.
The first test case might actually add an extra term; I'll try it out.
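A minimal sketch of such a check, assuming `factorized_joint_logprob` returns exactly one entry per supplied value variable; the `len` assertion and the expected total are this sketch's assumptions, not part of the PR:

```python
import aesara.tensor as at
import numpy as np
import scipy.stats as st

from aeppl import factorized_joint_logprob

x_rv = at.random.normal(name="x")
y_rv = at.random.normal(at.broadcast_to(x_rv, (2,)), name="y")

x_vv = x_rv.clone()
y_vv = y_rv.clone()
logp_map = factorized_joint_logprob({x_rv: x_vv, y_rv: y_vv})

# If the lift introduced a term for an extra, internal value variable,
# the map would contain more than the two entries supplied above.
assert len(logp_map) == 2

# The summed joint logp should then be logp(x) + logp(y | x) and nothing more.
y_val = np.zeros(2)
joint = at.sum(logp_map[x_vv]) + at.sum(logp_map[y_vv])
expected = st.norm(0).logpdf(0) + st.norm(0).logpdf(y_val).sum()
assert np.isclose(joint.eval({x_vv: 0.0, y_vv: y_val}), expected)
```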