Skip to content

Commit 9fcf1a0

Browse files
committed
Compute logprob & grad at the same time
1 parent bd7d149 commit 9fcf1a0

File tree

2 files changed

+11
-9
lines changed

2 files changed

+11
-9
lines changed

blackjax/mcmc/diffusion.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
class DiffusionState(NamedTuple):
1313
position: PyTree
14+
logprob: float
1415
logprob_grad: PyTree
1516

1617

@@ -24,7 +25,7 @@ def overdamped_langevin(logprob_grad_fn):
2425
"""Euler solver for overdamped Langevin diffusion."""
2526

2627
def one_step(rng_key, state: DiffusionState, step_size: float, batch: tuple = ()):
27-
position, logprob_grad = state
28+
position, _, logprob_grad = state
2829
noise = generate_gaussian_noise(rng_key, position)
2930
position = jax.tree_util.tree_multimap(
3031
lambda p, g, n: p + step_size * g + jnp.sqrt(2 * step_size) * n,
@@ -33,7 +34,7 @@ def one_step(rng_key, state: DiffusionState, step_size: float, batch: tuple = ()
3334
noise,
3435
)
3536

36-
logprob_grad = logprob_grad_fn(position, *batch)
37-
return DiffusionState(position, logprob_grad)
37+
logprob, logprob_grad = logprob_grad_fn(position, *batch)
38+
return DiffusionState(position, logprob, logprob_grad)
3839

3940
return one_step

blackjax/mcmc/mala.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ class MALAState(NamedTuple):
2121
"""
2222

2323
position: PyTree
24+
logprob: float
2425
logprob_grad: PyTree
2526

2627

@@ -43,9 +44,9 @@ class MALAInfo(NamedTuple):
4344

4445

4546
def init(position: PyTree, logprob_fn: Callable) -> MALAState:
46-
grad_fn = jax.grad(logprob_fn)
47-
logprob_grad = grad_fn(position)
48-
return MALAState(position, logprob_grad)
47+
grad_fn = jax.value_and_grad(logprob_fn)
48+
logprob, logprob_grad = grad_fn(position)
49+
return MALAState(position, logprob, logprob_grad)
4950

5051

5152
def kernel():
@@ -72,16 +73,16 @@ def one_step(
7273
TODO expand the docstring.
7374
7475
"""
75-
grad_fn = jax.grad(logprob_fn)
76+
grad_fn = jax.value_and_grad(logprob_fn)
7677
integrator = overdamped_langevin(grad_fn)
7778

7879
key_integrator, key_rmh = jax.random.split(rng_key)
7980

8081
new_state = integrator(key_integrator, state, step_size)
8182

8283
delta = (
83-
logprob_fn(new_state.position)
84-
- logprob_fn(state.position)
84+
new_state.logprob
85+
- state.logprob
8586
+ transition_probability(new_state, state, step_size)
8687
- transition_probability(state, new_state, step_size)
8788
)

0 commit comments

Comments
 (0)