Skip to content

Commit

Permalink
BUG(shells): fix partition() for multi-dimensional arrays (#132)
Browse files Browse the repository at this point in the history
Fix the `partition()` function for multi-dimensional arrays. The axis
corresponding to the shells is now the *first* one.

Fixed: partition() now works correctly with functions having extra axes.
Changed: The output of partition() now has the shells axis as its first.
  • Loading branch information
ntessore authored Oct 13, 2023
1 parent 8cefd56 commit 29938e2
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 25 deletions.
67 changes: 42 additions & 25 deletions glass/shells.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,8 +332,8 @@ def restrict(z: ArrayLike1D, f: ArrayLike1D, w: RadialWindow


def partition(z: ArrayLike,
f: ArrayLike,
ws: Sequence[RadialWindow],
fz: ArrayLike,
shells: Sequence[RadialWindow],
*,
method: str = "lstsq",
) -> ArrayLike:
Expand All @@ -345,27 +345,27 @@ def partition(z: ArrayLike,
:math:`f(z)`.
The function :math:`f(z)` is given by redshifts *z* of shape *(N,)*
and function values *f* of shape *(..., N)*, with any number of
and function values *fz* of shape *(..., N)*, with any number of
leading axes allowed.
The window functions are given by the sequence *ws* of
The window functions are given by the sequence *shells* of
:class:`RadialWindow` or compatible entries.
Parameters
----------
z, f : array_like
z, fz : array_like
The function to be partitioned. If *f* is multi-dimensional,
its last axis must agree with *z*.
ws : sequence of :class:`RadialWindow`
shells : sequence of :class:`RadialWindow`
Ordered sequence of window functions for the partition.
method : {"lstsq", "restrict"}
Method for the partition. See notes for description.
Returns
-------
x : array_like
Weights of the partition. If *f* is multi-dimensional, the
leading axes of *x* match those of *f*.
Weights of the partition, where the leading axis corresponds to
*shells*.
Notes
-----
Expand Down Expand Up @@ -403,43 +403,60 @@ def partition(z: ArrayLike,
partition_method = globals()[f"partition_{method}"]
except KeyError:
raise ValueError(f"invalid method: {method}") from None
return partition_method(z, f, ws)
return partition_method(z, fz, shells)


def partition_lstsq(z: ArrayLike, f: ArrayLike, ws: Sequence[RadialWindow]
) -> ArrayLike:
def partition_lstsq(
z: ArrayLike,
fz: ArrayLike,
shells: Sequence[RadialWindow],
) -> ArrayLike:
"""Least-squares partition."""

# compute the union of all given redshift grids
zp = z
for w in ws:
for w in shells:
zp = np.union1d(zp, w.za)

# get extra leading axes of fz
*dims, _ = np.shape(fz)

# compute grid spacing
dz = np.gradient(zp)

# create the window function matrix
a = [np.interp(zp, za, wa, left=0., right=0.) for za, wa, _ in ws]
a = [np.interp(zp, za, wa, left=0., right=0.) for za, wa, _ in shells]
a = a/np.trapz(a, zp, axis=-1)[..., None]
a = a*dz

# create the target vector of distribution values
b = ndinterp(zp, z, f, left=0., right=0.)
b = ndinterp(zp, z, fz, left=0., right=0.)
b = b*dz

# return least-squares fit
return np.linalg.lstsq(a.T, b.T, rcond=None)[0].T


def partition_restrict(z: ArrayLike, f: ArrayLike, ws: Sequence[RadialWindow]
) -> ArrayLike:
# now a is a matrix of shape (len(shells), len(zp))
# and b is a matrix of shape (*dims, len(zp))
# need to find weights x such that b == x @ a over all axes of b
# do the least-squares fit over partially flattened b, then reshape
x = np.linalg.lstsq(a.T, b.reshape(-1, zp.size).T, rcond=None)[0]
x = x.T.reshape(*dims, len(shells))
# roll the last axis of size len(shells) to the front
x = np.moveaxis(x, -1, 0)
# all done
return x


def partition_restrict(
z: ArrayLike,
fz: ArrayLike,
shells: Sequence[RadialWindow],
) -> ArrayLike:
"""Partition by restriction and integration."""

ngal = []
for w in ws:
zr, fr = restrict(z, f, w)
ngal.append(np.trapz(fr, zr, axis=-1))
return np.transpose(ngal)
part = np.empty((len(shells),) + np.shape(fz)[:-1])
for i, w in enumerate(shells):
zr, fr = restrict(z, fz, w)
part[i] = np.trapz(fr, zr, axis=-1)
return part


def redshift_grid(zmin, zmax, *, dz=None, num=None):
Expand Down
26 changes: 26 additions & 0 deletions glass/test/test_shells.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import pytest


def test_tophat_windows():
Expand Down Expand Up @@ -46,3 +47,28 @@ def test_restrict():
i = np.searchsorted(zr, zi)
assert zr[i] == zi
assert fr[i] == fi*np.interp(zi, w.za, w.wa)


@pytest.mark.parametrize("method", ["lstsq", "restrict"])
def test_partition(method):
import numpy as np
from glass.shells import RadialWindow, partition

shells = [
RadialWindow(np.array([0., 1.]), np.array([1., 0.]), 0.0),
RadialWindow(np.array([0., 1., 2.]), np.array([0., 1., 0.]), 0.5),
RadialWindow(np.array([1., 2., 3.]), np.array([0., 1., 0.]), 1.5),
RadialWindow(np.array([2., 3., 4.]), np.array([0., 1., 0.]), 2.5),
RadialWindow(np.array([3., 4., 5.]), np.array([0., 1., 0.]), 3.5),
RadialWindow(np.array([4., 5.]), np.array([0., 1.]), 5.0),
]

z = np.linspace(0., 5., 1000)
k = 1 + np.arange(6).reshape(3, 2, 1)
fz = np.exp(-z / k)

assert fz.shape == (3, 2, 1000)

part = partition(z, fz, shells, method=method)

assert part.shape == (len(shells), 3, 2)

0 comments on commit 29938e2

Please sign in to comment.