Skip to content

Commit 5957107

Browse files
authored
Merge pull request #215 from danielward27/avoid_array_doctest
Avoid comparing array values doctest
2 parents 262d7a9 + ed95e56 commit 5957107

File tree

5 files changed

+38
-40
lines changed

5 files changed

+38
-40
lines changed

docs/getting_started.rst

+6-4
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,12 @@ Additionally, we can evaluate the log probabilities of these samples
3636

3737
.. doctest::
3838

39-
>>> normal.log_prob(sample)
40-
Array(-3.4016984, dtype=float32)
41-
>>> normal.log_prob(batch)
42-
Array([-4.8808994, -5.0121717, -3.2557464, -4.131773 ], dtype=float32)
39+
>>> log_prob = normal.log_prob(sample)
40+
>>> log_prob.shape
41+
()
42+
>>> log_probs = normal.log_prob(batch)
43+
>>> log_probs.shape
44+
(4,)
4345

4446
When ``sample.shape == distribution.shape``, a scalar log probability is returned. For
4547
a batch of samples, the shape of the returned log probabilities matches the shape

flowjax/bijections/affine.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -180,8 +180,7 @@ class AdditiveCondition(AbstractBijection):
180180
>>> bijection = AdditiveCondition(
181181
... Linear(2, 3, key=jr.key(0)), shape=(3,), cond_shape=(2,)
182182
... )
183-
>>> bijection.transform(jnp.ones(3), condition=jnp.ones(2))
184-
Array([1.9670618, 0.8156546, 1.7763454], dtype=float32)
183+
>>> y = bijection.transform(jnp.ones(3), condition=jnp.ones(2))
185184
186185
"""
187186

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ license = { file = "LICENSE" }
2323
name = "flowjax"
2424
readme = "README.md"
2525
requires-python = ">=3.10"
26-
version = "17.1.0"
26+
version = "17.1.1"
2727

2828
[project.urls]
2929
repository = "https://github.com/danielward27/flowjax"

tests/test_train/test_data_fit.py

-31
This file was deleted.

tests/test_train/test_loops.py

+30-2
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,39 @@
1+
import equinox as eqx
12
import jax.numpy as jnp
23
import jax.random as jr
34
import pytest
5+
from jax import random
46

5-
from flowjax.distributions import Normal, StandardNormal
6-
from flowjax.train.loops import fit_to_key_based_loss
7+
from flowjax.bijections import Affine
8+
from flowjax.distributions import Normal, StandardNormal, Transformed
9+
from flowjax.train.loops import fit_to_data, fit_to_key_based_loss
710
from flowjax.train.losses import ElboLoss
811

12+
13+
def test_data_fit():
14+
dim = 3
15+
mean, std = jnp.ones(dim), jnp.ones(dim)
16+
base_dist = Normal(mean, std)
17+
flow = Transformed(base_dist, Affine(jnp.ones(dim), jnp.ones(dim)))
18+
19+
# All params should change by default
20+
before = eqx.filter(flow, eqx.is_inexact_array)
21+
x = random.normal(random.key(0), (100, dim))
22+
flow, losses = fit_to_data(
23+
random.key(0),
24+
dist=flow,
25+
x=x,
26+
max_epochs=1,
27+
batch_size=50,
28+
)
29+
after = eqx.filter(flow, eqx.is_inexact_array)
30+
31+
assert jnp.all(before.base_dist.bijection.loc != after.base_dist.bijection.loc)
32+
assert jnp.all(before.bijection.loc != after.bijection.loc)
33+
assert isinstance(losses["train"][0], float)
34+
assert isinstance(losses["val"][0], float)
35+
36+
937
test_shapes = [(), (2,), (2, 3, 4)]
1038

1139

0 commit comments

Comments
 (0)