Skip to content

Commit

Permalink
Update docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
jessegrabowski committed Oct 27, 2024
1 parent f9b6258 commit ad3abd9
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions pymc_experimental/inference/jax_find_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def fit_laplace(
)

f_logp, f_grad, f_hess, f_hessp = make_jax_funcs_from_graph(
logp,
cast(TensorVariable, logp),
use_grad=True,
use_hess=True,
use_hessp=False,
Expand Down Expand Up @@ -376,8 +376,8 @@ def find_MAP(
Seed for the random number generator or a numpy Generator for reproducibility
return_raw: bool | False, optinal
Whether to also return the full output of `scipy.optimize.minimize`
jitter : bool, optional
Whether to add jitter to the initial values. Defaults to False.
jitter_rvs : list of TensorVariables, optional
Variables whose initial values should be jittered. If None, all variables are jittered.
progressbar : bool, optional
Whether to display a progress bar during optimization. Defaults to True.
include_transformed: bool, optional
Expand Down

0 comments on commit ad3abd9

Please sign in to comment.