Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

API(shells): partition() to fit windows to function #122

Merged
merged 1 commit into from
Aug 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 95 additions & 25 deletions glass/shells.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,45 +331,115 @@ def restrict(z: ArrayLike1D, f: ArrayLike1D, w: RadialWindow
return zr, fr


def partition(z: ArrayLike1D, f: ArrayLike1D, ws: Sequence[RadialWindow]
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
'''Partition a function by a sequence of windows.

Partitions the given function into a sequence of functions
restricted to each window function.

The function :math:`f(z)` is given by redshifts ``z`` of shape
*(N,)* and function values ``f`` of shape *(..., N)*, with any
number of leading axes allowed.

The window functions are given by the sequence ``ws`` of
def partition(z: ArrayLike,
f: ArrayLike,
ws: Sequence[RadialWindow],
*,
method: str = "lstsq",
) -> ArrayLike:
"""Partition a function by a sequence of windows.

Returns a vector of weights :math:`x_1, x_2, \\ldots` such that the
weighted sum of normalised radial window functions :math:`x_1 \\,
w_1(z) + x_2 \\, w_2(z) + \\ldots` approximates the given function
:math:`f(z)`.

The function :math:`f(z)` is given by redshifts *z* of shape *(N,)*
and function values *f* of shape *(..., N)*, with any number of
leading axes allowed.

The window functions are given by the sequence *ws* of
:class:`RadialWindow` or compatible entries.

The partitioned functions have redshifts that are the union of the
redshifts of the original function and each window over the support
of said window. Intermediate function values are found by linear
interpolation

Parameters
----------
z, f : array_like
The function to be partitioned.
The function to be partitioned. If *f* is multi-dimensional,
its last axis must agree with *z*.
ws : sequence of :class:`RadialWindow`
Ordered sequence of window functions for the partition.
method : {"lstsq", "restrict"}
Method for the partition. See notes for description.

Returns
-------
zp, fp : list of array
The partitioned functions, ordered as the given windows.
x : array_like
Weights of the partition. If *f* is multi-dimensional, the
leading axes of *x* match those of *f*.

Notes
-----
Formally, if :math:`w_i` are the normalised window functions,
:math:`f` is the target function, and :math:`z_i` is a redshift grid
with intervals :math:`\\Delta z_i`, the partition problem seeks an
approximate solution of

.. math::
\\begin{pmatrix}
w_1(z_1) \\Delta z_1 & w_2(z_1) \\, \\Delta z_1 & \\cdots \\\\
w_1(z_2) \\Delta z_2 & w_2(z_2) \\, \\Delta z_2 & \\cdots \\\\
\\vdots & \\vdots & \\ddots
\\end{pmatrix} \\, \\begin{pmatrix}
x_1 \\\\ x_2 \\\\ \\vdots
\\end{pmatrix} = \\begin{pmatrix}
f(z_1) \\, \\Delta z_1 \\\\ f(z_2) \\, \\Delta z_2 \\\\ \\vdots
\\end{pmatrix} \\;.

The redshift grid is the union of the given array *z* and the
redshift arrays of all window functions. Intermediate function
values are found by linear interpolation.

If ``method="lstsq"``, obtain a partition from a least-squares
solution. This will more closely match the shape of the input
function, but the normalisation might differ.

If ``method="restrict"``, obtain a partition by integrating the
restriction (using :func:`restrict`) of the function :math:`f` to
each window. This will more closely match the normalisation of the
input function, but the shape might differ.

"""
try:
partition_method = globals()[f"partition_{method}"]
except KeyError:
raise ValueError(f"invalid method: {method}") from None
return partition_method(z, f, ws)


def partition_lstsq(z: ArrayLike, f: ArrayLike, ws: Sequence[RadialWindow]
) -> ArrayLike:
"""Least-squares partition."""

# compute the union of all given redshift grids
zp = z
for w in ws:
zp = np.union1d(zp, w.za)

'''
# compute grid spacing
dz = np.gradient(zp)

# create the window function matrix
a = [np.interp(zp, za, wa, left=0., right=0.) for za, wa, _ in ws]
a = a/np.trapz(a, zp, axis=-1)[..., None]
a = a*dz

# create the target vector of distribution values
b = ndinterp(zp, z, f, left=0., right=0.)
b = b*dz

# return least-squares fit
return np.linalg.lstsq(a.T, b.T, rcond=None)[0].T


def partition_restrict(z: ArrayLike, f: ArrayLike, ws: Sequence[RadialWindow]
) -> ArrayLike:
"""Partition by restriction and integration."""

zp, fp = [], []
ngal = []
for w in ws:
zr, fr = restrict(z, f, w)
zp.append(zr)
fp.append(fr)
return zp, fp
ngal.append(np.trapz(fr, zr, axis=-1))
return np.transpose(ngal)


def redshift_grid(zmin, zmax, *, dz=None, num=None):
Expand Down
38 changes: 0 additions & 38 deletions glass/test/test_shells.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import numpy as np
import numpy.testing as npt


def test_tophat_windows():
Expand Down Expand Up @@ -47,40 +46,3 @@ def test_restrict():
i = np.searchsorted(zr, zi)
assert zr[i] == zi
assert fr[i] == fi*np.interp(zi, w.za, w.wa)


def test_partition():
from glass.shells import partition, RadialWindow

# Gaussian test function
z = np.linspace(0., 5., 1000)
f = np.exp(-((z - 2.)/0.5)**2/2)

# overlapping triangular weight functions
ws = [RadialWindow(za=[0., 1., 2.], wa=[0., 1., 0.], zeff=None),
RadialWindow(za=[1., 2., 3.], wa=[0., 1., 0.], zeff=None),
RadialWindow(za=[2., 3., 4.], wa=[0., 1., 0.], zeff=None),
RadialWindow(za=[3., 4., 5.], wa=[0., 1., 0.], zeff=None)]

zp, fp = partition(z, f, ws)

assert len(zp) == len(fp) == len(ws)

for zr, w in zip(zp, ws):
assert np.all((zr >= w.za[0]) & (zr <= w.za[-1]))

for zr, fr, w in zip(zp, fp, ws):
f_ = np.interp(zr, z, f, left=0., right=0.)
w_ = np.interp(zr, w.za, w.wa, left=0., right=0.)
npt.assert_allclose(fr, f_*w_)

f_ = sum(np.interp(z, zr, fr, left=0., right=0.)
for zr, fr in zip(zp, fp))

# first and last points have zero total weight
assert f_[0] == f_[-1] == 0.

# find first and last index where total weight becomes unity
i, j = np.searchsorted(z, [ws[0].za[1], ws[-1].za[1]])

npt.assert_allclose(f_[i:j], f[i:j], atol=1e-15)
Loading