Skip to content

Commit 07e4c9b

Browse files
authored
Bump to 0.16.0 (#1917)
* Bump to 0.16.0
1 parent f87f40e commit 07e4c9b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+76
-57
lines changed

examples/annotation.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ def main(args):
320320

321321

322322
if __name__ == "__main__":
323-
assert numpyro.__version__.startswith("0.15.3")
323+
assert numpyro.__version__.startswith("0.16.0")
324324
parser = argparse.ArgumentParser(description="Bayesian Models of Annotation")
325325
parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
326326
parser.add_argument("--num-warmup", nargs="?", default=1000, type=int)

examples/ar2.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def main(args):
138138

139139

140140
if __name__ == "__main__":
141-
assert numpyro.__version__.startswith("0.15.3")
141+
assert numpyro.__version__.startswith("0.16.0")
142142
parser = argparse.ArgumentParser(description="AR2 example")
143143
parser.add_argument("--num-data", nargs="?", default=142, type=int)
144144
parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)

examples/baseball.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def main(args):
210210

211211

212212
if __name__ == "__main__":
213-
assert numpyro.__version__.startswith("0.15.3")
213+
assert numpyro.__version__.startswith("0.16.0")
214214
parser = argparse.ArgumentParser(description="Baseball batting average using MCMC")
215215
parser.add_argument("-n", "--num-samples", nargs="?", default=3000, type=int)
216216
parser.add_argument("--num-warmup", nargs="?", default=1500, type=int)

examples/bnn.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def main(args):
160160

161161

162162
if __name__ == "__main__":
163-
assert numpyro.__version__.startswith("0.15.3")
163+
assert numpyro.__version__.startswith("0.16.0")
164164
parser = argparse.ArgumentParser(description="Bayesian neural network example")
165165
parser.add_argument("-n", "--num-samples", nargs="?", default=2000, type=int)
166166
parser.add_argument("--num-warmup", nargs="?", default=1000, type=int)

examples/covtype.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def main(args):
206206

207207

208208
if __name__ == "__main__":
209-
assert numpyro.__version__.startswith("0.15.3")
209+
assert numpyro.__version__.startswith("0.16.0")
210210
parser = argparse.ArgumentParser(description="parse args")
211211
parser.add_argument(
212212
"-n", "--num-samples", default=1000, type=int, help="number of samples"

examples/funnel.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def main(args):
139139

140140

141141
if __name__ == "__main__":
142-
assert numpyro.__version__.startswith("0.15.3")
142+
assert numpyro.__version__.startswith("0.16.0")
143143
parser = argparse.ArgumentParser(
144144
description="Non-centered reparameterization example"
145145
)

examples/gaussian_shells.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def main(args):
120120

121121

122122
if __name__ == "__main__":
123-
assert numpyro.__version__.startswith("0.15.3")
123+
assert numpyro.__version__.startswith("0.16.0")
124124

125125
parser = argparse.ArgumentParser(description="Nested sampler for Gaussian shells")
126126
parser.add_argument("-n", "--num-samples", nargs="?", default=10000, type=int)

examples/gp.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def main(args):
180180

181181

182182
if __name__ == "__main__":
183-
assert numpyro.__version__.startswith("0.15.3")
183+
assert numpyro.__version__.startswith("0.16.0")
184184
parser = argparse.ArgumentParser(description="Gaussian Process example")
185185
parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
186186
parser.add_argument("--num-warmup", nargs="?", default=1000, type=int)

examples/hmm.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ def main(args):
263263

264264

265265
if __name__ == "__main__":
266-
assert numpyro.__version__.startswith("0.15.3")
266+
assert numpyro.__version__.startswith("0.16.0")
267267
parser = argparse.ArgumentParser(description="Semi-supervised Hidden Markov Model")
268268
parser.add_argument("--num-categories", default=3, type=int)
269269
parser.add_argument("--num-words", default=10, type=int)

examples/holt_winters.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def main(args):
180180

181181

182182
if __name__ == "__main__":
183-
assert numpyro.__version__.startswith("0.15.3")
183+
assert numpyro.__version__.startswith("0.16.0")
184184
parser = argparse.ArgumentParser(description="Holt-Winters")
185185
parser.add_argument("--T", nargs="?", default=6, type=int)
186186
parser.add_argument("--future", nargs="?", default=1, type=int)

examples/horseshoe_regression.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def main(args):
162162

163163

164164
if __name__ == "__main__":
165-
assert numpyro.__version__.startswith("0.15.3")
165+
assert numpyro.__version__.startswith("0.16.0")
166166
parser = argparse.ArgumentParser(description="Horseshoe regression example")
167167
parser.add_argument("-n", "--num-samples", nargs="?", default=2000, type=int)
168168
parser.add_argument("--num-warmup", nargs="?", default=1000, type=int)

examples/minipyro.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def body_fn(i, val):
5858

5959

6060
if __name__ == "__main__":
61-
assert numpyro.__version__.startswith("0.15.3")
61+
assert numpyro.__version__.startswith("0.16.0")
6262
parser = argparse.ArgumentParser(description="Mini Pyro demo")
6363
parser.add_argument("-f", "--full-pyro", action="store_true", default=False)
6464
parser.add_argument("-n", "--num-steps", default=1001, type=int)

examples/mortality.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ def main(args):
220220

221221

222222
if __name__ == "__main__":
223-
assert numpyro.__version__.startswith("0.15.3")
223+
assert numpyro.__version__.startswith("0.16.0")
224224

225225
parser = argparse.ArgumentParser(description="Mortality regression model")
226226
parser.add_argument("-n", "--num-samples", nargs="?", default=500, type=int)

examples/neutra.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ def main(args):
197197

198198

199199
if __name__ == "__main__":
200-
assert numpyro.__version__.startswith("0.15.3")
200+
assert numpyro.__version__.startswith("0.16.0")
201201
parser = argparse.ArgumentParser(description="NeuTra HMC")
202202
parser.add_argument("-n", "--num-samples", nargs="?", default=4000, type=int)
203203
parser.add_argument("--num-warmup", nargs="?", default=1000, type=int)

examples/ode.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def main(args):
117117

118118

119119
if __name__ == "__main__":
120-
assert numpyro.__version__.startswith("0.15.3")
120+
assert numpyro.__version__.startswith("0.16.0")
121121
parser = argparse.ArgumentParser(description="Predator-Prey Model")
122122
parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
123123
parser.add_argument("--num-warmup", nargs="?", default=1000, type=int)

examples/prodlda.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ def main(args):
315315

316316

317317
if __name__ == "__main__":
318-
assert numpyro.__version__.startswith("0.15.3")
318+
assert numpyro.__version__.startswith("0.16.0")
319319
parser = argparse.ArgumentParser(
320320
description="Probabilistic topic modelling with Flax and Haiku"
321321
)

examples/proportion_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def main(args):
158158

159159

160160
if __name__ == "__main__":
161-
assert numpyro.__version__.startswith("0.15.3")
161+
assert numpyro.__version__.startswith("0.16.0")
162162
parser = argparse.ArgumentParser(description="Testing whether ")
163163
parser.add_argument("-n", "--num-samples", nargs="?", default=500, type=int)
164164
parser.add_argument("--num-warmup", nargs="?", default=1500, type=int)

examples/sparse_regression.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ def main(args):
384384

385385

386386
if __name__ == "__main__":
387-
assert numpyro.__version__.startswith("0.15.3")
387+
assert numpyro.__version__.startswith("0.16.0")
388388
parser = argparse.ArgumentParser(description="Gaussian Process example")
389389
parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
390390
parser.add_argument("--num-warmup", nargs="?", default=500, type=int)

examples/stochastic_volatility.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def main(args):
122122

123123

124124
if __name__ == "__main__":
125-
assert numpyro.__version__.startswith("0.15.3")
125+
assert numpyro.__version__.startswith("0.16.0")
126126
parser = argparse.ArgumentParser(description="Stochastic Volatility Model")
127127
parser.add_argument("-n", "--num-samples", nargs="?", default=600, type=int)
128128
parser.add_argument("--num-warmup", nargs="?", default=600, type=int)

examples/thompson_sampling.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ def main(args):
292292

293293

294294
if __name__ == "__main__":
295-
assert numpyro.__version__.startswith("0.15.3")
295+
assert numpyro.__version__.startswith("0.16.0")
296296
parser = argparse.ArgumentParser(description="Thompson sampling example")
297297
parser.add_argument(
298298
"--num-random", nargs="?", default=2, type=int, help="number of random draws"

examples/toy_mixture_model_discrete_enumeration.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def get_true_pred_CPDs(CPD, posterior_param):
126126

127127

128128
if __name__ == "__main__":
129-
assert numpyro.__version__.startswith("0.15.3")
129+
assert numpyro.__version__.startswith("0.16.0")
130130
parser = argparse.ArgumentParser(description="Toy mixture model")
131131
parser.add_argument("-n", "--num-steps", default=4000, type=int)
132132
parser.add_argument("-o", "--num-obs", default=10000, type=int)

examples/ucbadmit.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def main(args):
151151

152152

153153
if __name__ == "__main__":
154-
assert numpyro.__version__.startswith("0.15.3")
154+
assert numpyro.__version__.startswith("0.16.0")
155155
parser = argparse.ArgumentParser(
156156
description="UCBadmit gender discrimination using HMC"
157157
)

examples/vae.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def reconstruct_img(epoch, rng_key):
160160

161161

162162
if __name__ == "__main__":
163-
assert numpyro.__version__.startswith("0.15.3")
163+
assert numpyro.__version__.startswith("0.16.0")
164164
parser = argparse.ArgumentParser(description="parse args")
165165
parser.add_argument(
166166
"-n", "--num-epochs", default=15, type=int, help="number of training epochs"

notebooks/source/bad_posterior_geometry.ipynb

+1-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
"import numpyro.distributions as dist\n",
5050
"from numpyro.infer import MCMC, NUTS\n",
5151
"\n",
52-
"assert numpyro.__version__.startswith(\"0.15.3\")\n",
52+
"assert numpyro.__version__.startswith(\"0.16.0\")\n",
5353
"\n",
5454
"# NB: replace cpu by gpu to run this notebook on gpu\n",
5555
"numpyro.set_platform(\"cpu\")"

notebooks/source/bayesian_hierarchical_linear_regression.ipynb

+1-1
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@
246246
"import numpyro.distributions as dist\n",
247247
"from numpyro.infer import MCMC, NUTS, Predictive\n",
248248
"\n",
249-
"assert numpyro.__version__.startswith(\"0.15.3\")"
249+
"assert numpyro.__version__.startswith(\"0.16.0\")"
250250
]
251251
},
252252
{

notebooks/source/bayesian_hierarchical_stacking.ipynb

+1-1
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@
9797
" set_matplotlib_formats(\"svg\")\n",
9898
"\n",
9999
"numpyro.set_host_device_count(4)\n",
100-
"assert numpyro.__version__.startswith(\"0.15.3\")"
100+
"assert numpyro.__version__.startswith(\"0.16.0\")"
101101
]
102102
},
103103
{

notebooks/source/bayesian_imputation.ipynb

+1-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
"if \"NUMPYRO_SPHINXBUILD\" in os.environ:\n",
5353
" set_matplotlib_formats(\"svg\")\n",
5454
"\n",
55-
"assert numpyro.__version__.startswith(\"0.15.3\")"
55+
"assert numpyro.__version__.startswith(\"0.16.0\")"
5656
]
5757
},
5858
{

notebooks/source/bayesian_regression.ipynb

+1-1
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@
9191
"if \"NUMPYRO_SPHINXBUILD\" in os.environ:\n",
9292
" set_matplotlib_formats(\"svg\")\n",
9393
"\n",
94-
"assert numpyro.__version__.startswith(\"0.15.3\")"
94+
"assert numpyro.__version__.startswith(\"0.16.0\")"
9595
]
9696
},
9797
{

notebooks/source/censoring.ipynb

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
"\n",
6161
"rng_key = random.PRNGKey(seed=0)\n",
6262
"\n",
63-
"assert numpyro.__version__.startswith(\"0.15.3\")\n",
63+
"assert numpyro.__version__.startswith(\"0.16.0\")\n",
6464
"\n",
6565
"%load_ext autoreload\n",
6666
"%autoreload 2\n",

notebooks/source/gmm.ipynb

+1-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
"%matplotlib inline\n",
5555
"\n",
5656
"smoke_test = \"CI\" in os.environ\n",
57-
"assert numpyro.__version__.startswith(\"0.15.3\")"
57+
"assert numpyro.__version__.startswith(\"0.16.0\")"
5858
]
5959
},
6060
{

notebooks/source/hsgp_example.ipynb

+1-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
"\n",
6363
"rng_key = random.PRNGKey(seed=42)\n",
6464
"\n",
65-
"assert numpyro.__version__.startswith(\"0.15.3\")\n",
65+
"assert numpyro.__version__.startswith(\"0.16.0\")\n",
6666
"\n",
6767
"%load_ext autoreload\n",
6868
"%autoreload 2\n",

notebooks/source/logistic_regression.ipynb

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
"from numpyro.examples.datasets import COVTYPE, load_dataset\n",
4242
"from numpyro.infer import HMC, MCMC, NUTS\n",
4343
"\n",
44-
"assert numpyro.__version__.startswith(\"0.15.3\")\n",
44+
"assert numpyro.__version__.startswith(\"0.16.0\")\n",
4545
"\n",
4646
"# NB: replace gpu by cpu to run this notebook in cpu\n",
4747
"numpyro.set_platform(\"gpu\")"

notebooks/source/model_rendering.ipynb

+1-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
"import numpyro.distributions as dist\n",
3939
"import numpyro.distributions.constraints as constraints\n",
4040
"\n",
41-
"assert numpyro.__version__.startswith(\"0.15.3\")"
41+
"assert numpyro.__version__.startswith(\"0.16.0\")"
4242
]
4343
},
4444
{

notebooks/source/ordinal_regression.ipynb

+1-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
"from numpyro.infer import MCMC, NUTS\n",
5555
"from numpyro.infer.reparam import TransformReparam\n",
5656
"\n",
57-
"assert numpyro.__version__.startswith(\"0.15.3\")"
57+
"assert numpyro.__version__.startswith(\"0.16.0\")"
5858
]
5959
},
6060
{

notebooks/source/other_samplers.ipynb

+1-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@
6767
"\n",
6868
"rng_key = random.PRNGKey(seed=42)\n",
6969
"\n",
70-
"assert numpyro.__version__.startswith(\"0.15.3\")\n",
70+
"assert numpyro.__version__.startswith(\"0.16.0\")\n",
7171
"\n",
7272
"%load_ext autoreload\n",
7373
"%autoreload 2\n",

notebooks/source/time_series_forecasting.ipynb

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
" set_matplotlib_formats(\"svg\")\n",
4949
"\n",
5050
"numpyro.set_host_device_count(4)\n",
51-
"assert numpyro.__version__.startswith(\"0.15.3\")"
51+
"assert numpyro.__version__.startswith(\"0.16.0\")"
5252
]
5353
},
5454
{

numpyro/contrib/control_flow/scan.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -198,9 +198,10 @@ def body_fn(wrapped_carry, x, prefix=None):
198198
)
199199
return (i + 1, rng_key, new_carry), (PytreeTrace(trace), y)
200200

201-
with handlers.block(
202-
hide_fn=lambda site: not site["name"].startswith("_PREV_")
203-
), enum(first_available_dim=first_available_dim):
201+
with (
202+
handlers.block(hide_fn=lambda site: not site["name"].startswith("_PREV_")),
203+
enum(first_available_dim=first_available_dim),
204+
):
204205
wrapped_carry = (0, rng_key, init)
205206
y0s = []
206207
# We run unroll_steps + 1 where the last step is used for rolling with `lax.scan`

numpyro/infer/autoguide.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -1463,8 +1463,10 @@ def _sample_latent(self, *args, **kwargs):
14631463
if self.global_guide is not None:
14641464
global_latents = self.global_guide(*args, **kwargs)
14651465
rng_key = numpyro.prng_key()
1466-
with handlers.block(), handlers.seed(rng_seed=rng_key), handlers.substitute(
1467-
data=global_latents
1466+
with (
1467+
handlers.block(),
1468+
handlers.seed(rng_seed=rng_key),
1469+
handlers.substitute(data=global_latents),
14681470
):
14691471
global_outputs = self.global_guide.model(*args, **kwargs)
14701472
local_args = (global_outputs,)
@@ -1575,9 +1577,12 @@ def fn(x):
15751577
if self.local_guide is not None:
15761578
key = numpyro.prng_key()
15771579
subsample_guide = partial(_subsample_model, self.local_guide)
1578-
with handlers.block(), handlers.trace() as tr, handlers.seed(
1579-
rng_seed=key
1580-
), handlers.substitute(data=local_guide_params):
1580+
with (
1581+
handlers.block(),
1582+
handlers.trace() as tr,
1583+
handlers.seed(rng_seed=key),
1584+
handlers.substitute(data=local_guide_params),
1585+
):
15811586
with warnings.catch_warnings():
15821587
warnings.simplefilter("ignore")
15831588
subsample_guide(*local_args, **local_kwargs)

numpyro/infer/elbo.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -893,8 +893,11 @@ def get_importance_trace_enum(
893893
trace as _trace,
894894
)
895895

896-
with plate_to_enum_plate(), enum(
897-
first_available_dim=(-max_plate_nesting - 1) if max_plate_nesting else None
896+
with (
897+
plate_to_enum_plate(),
898+
enum(
899+
first_available_dim=(-max_plate_nesting - 1) if max_plate_nesting else None
900+
),
898901
):
899902
guide = substitute(guide, data=params)
900903
with _without_rsample_stop_gradient():

0 commit comments

Comments
 (0)