@@ -21,6 +21,7 @@ class MALAState(NamedTuple):
21
21
"""
22
22
23
23
position : PyTree
24
+ logprob : float
24
25
logprob_grad : PyTree
25
26
26
27
@@ -43,9 +44,9 @@ class MALAInfo(NamedTuple):
43
44
44
45
45
46
def init (position : PyTree , logprob_fn : Callable ) -> MALAState :
46
- grad_fn = jax .grad (logprob_fn )
47
- logprob_grad = grad_fn (position )
48
- return MALAState (position , logprob_grad )
47
+ grad_fn = jax .value_and_grad (logprob_fn )
48
+ logprob , logprob_grad = grad_fn (position )
49
+ return MALAState (position , logprob , logprob_grad )
49
50
50
51
51
52
def kernel ():
@@ -72,16 +73,16 @@ def one_step(
72
73
TODO expand the docstring.
73
74
74
75
"""
75
- grad_fn = jax .grad (logprob_fn )
76
+ grad_fn = jax .value_and_grad (logprob_fn )
76
77
integrator = overdamped_langevin (grad_fn )
77
78
78
79
key_integrator , key_rmh = jax .random .split (rng_key )
79
80
80
81
new_state = integrator (key_integrator , state , step_size )
81
82
82
83
delta = (
83
- logprob_fn ( new_state .position )
84
- - logprob_fn ( state .position )
84
+ new_state .logprob
85
+ - state .logprob
85
86
+ transition_probability (new_state , state , step_size )
86
87
- transition_probability (state , new_state , step_size )
87
88
)
0 commit comments