Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gh-268: add more tests for glass.core #263

Merged
merged 6 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions tests/core/test_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,21 @@ def test_nnls(rng):

from glass.core.algorithm import nnls as nnls_glass

# cross-check output with scipy's nnls

a = rng.standard_normal((100, 20))
b = rng.standard_normal((100,))

x_glass = nnls_glass(a, b)
x_scipy, _ = nnls_scipy(a, b)

np.testing.assert_allclose(x_glass, x_scipy)

# check matrix and vector's shape

with pytest.raises(ValueError, match="input `a` is not a matrix"):
nnls_glass(b, a)
with pytest.raises(ValueError, match="input `b` is not a vector"):
nnls_glass(a, a)
with pytest.raises(ValueError, match="the shapes of `a` and `b` do not match"):
nnls_glass(a.T, b)
92 changes: 92 additions & 0 deletions tests/core/test_array.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,46 @@
import numpy as np
import numpy.testing as npt
import pytest

# check if scipy is available for testing
try:
import scipy
except ImportError:
HAVE_SCIPY = False
else:
del scipy
HAVE_SCIPY = True


def broadcast_first():
from glass.core.array import broadcast_first
Saransh-cpp marked this conversation as resolved.
Show resolved Hide resolved

a = np.ones((2, 3, 4))
b = np.ones((2, 1))

# arrays with shape ((3, 4, 2)) and ((1, 2)) are passed
# to np.broadcast_arrays; hence it works
a_a, b_a = broadcast_first(a, b)
assert a_a.shape == (2, 3, 4)
assert b_a.shape == (2, 3, 4)

# plain np.broadcast_arrays will not work
with pytest.raises(ValueError, match="shape mismatch"):
np.broadcast_arrays(a, b)

# arrays with shape ((5, 6, 4)) and ((6, 5)) are passed
# to np.broadcast_arrays; hence it will not work
a = np.ones((4, 5, 6))
b = np.ones((5, 6))

with pytest.raises(ValueError, match="shape mismatch"):
broadcast_first(a, b)

# plain np.broadcast_arrays will work
a_a, b_a = broadcast_first(a, b)

assert a_a.shape == (4, 5, 6)
assert b_a.shape == (4, 5, 6)


def test_broadcast_leading_axes():
Expand Down Expand Up @@ -114,3 +155,54 @@ def test_trapz_product():
s = trapz_product((x1, f1), (x2, f2))

assert np.allclose(s, 1.0)


@pytest.mark.skipif(not HAVE_SCIPY, reason="test requires SciPy")
def test_cumtrapz():
Saransh-cpp marked this conversation as resolved.
Show resolved Hide resolved
from scipy.integrate import cumulative_trapezoid

from glass.core.array import cumtrapz

# 1D f and x

f = np.array([1, 2, 3, 4])
x = np.array([0, 1, 2, 3])

# default dtype (int - not supported by scipy)

glass_ct = cumtrapz(f, x)
npt.assert_allclose(glass_ct, np.array([0, 1, 4, 7]))

# explicit dtype (float)

glass_ct = cumtrapz(f, x, dtype=float)
scipy_ct = cumulative_trapezoid(f, x, initial=0)
npt.assert_allclose(glass_ct, scipy_ct)

# explicit return array

result = cumtrapz(f, x, dtype=float, out=np.zeros((4,)))
scipy_ct = cumulative_trapezoid(f, x, initial=0)
npt.assert_allclose(result, scipy_ct)

# 2D f and 1D x

f = np.array([[1, 4, 9, 16], [2, 3, 5, 7]])
x = np.array([0, 1, 2.5, 4])

# default dtype (int - not supported by scipy)

glass_ct = cumtrapz(f, x)
npt.assert_allclose(glass_ct, np.array([[0, 2, 12, 31], [0, 2, 8, 17]]))

# explicit dtype (float)

glass_ct = cumtrapz(f, x, dtype=float)
scipy_ct = cumulative_trapezoid(f, x, initial=0)
npt.assert_allclose(glass_ct, scipy_ct)

# explicit return array

glass_ct = cumtrapz(f, x, dtype=float, out=np.zeros((2, 4)))
scipy_ct = cumulative_trapezoid(f, x, initial=0)
npt.assert_allclose(glass_ct, scipy_ct)