Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 126 additions & 45 deletions s2fft/precompute_transforms/fourier_wigner.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
from __future__ import annotations

from functools import partial

import jax.numpy as jnp
import numpy as np
from jax import jit

from s2fft import recursions
from s2fft.utils import quadrature, quadrature_jax


def inverse_transform(
flmn: np.ndarray,
DW: np.ndarray,
L: int,
N: int,
precomps: tuple[np.ndarray, np.ndarray] | None = None,
reality: bool = False,
sampling: str = "mw",
) -> np.ndarray:
Expand All @@ -18,10 +23,11 @@ def inverse_transform(

Args:
flmn (np.ndarray): Wigner coefficients.
DW (Tuple[np.ndarray, np.ndarray]): Fourier coefficients of the reduced
Wigner d-functions and the corresponding upsampled quadrature weights.
L (int): Harmonic band-limit.
N (int): Azimuthal band-limit.
precomps (tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the
reduced Wigner d-functions and the corresponding upsampled quadrature
weights. Defaults to None.
reality (bool, optional): Whether the signal on the sphere is real. If so,
conjugate symmetry is exploited to reduce computational costs.
Defaults to False.
Expand All @@ -37,9 +43,6 @@ def inverse_transform(
f"Fourier-Wigner algorithm does not support {sampling} sampling."
)

# EXTRACT VARIOUS PRECOMPUTES
Delta, _ = DW

# INDEX VALUES
n_start_ind = N - 1 if reality else 0
n_dim = N if reality else 2 * N - 1
Expand All @@ -52,15 +55,29 @@ def inverse_transform(
m = np.arange(-L + 1 - m_offset, L)
n = np.arange(n_start_ind - N + 1, N)

# Calculate fmna = i^(n-m)\sum_L Delta^l_am Delta^l_an f^l_mn(2l+1)/(8pi^2)
# Calculate fmna = i^(n-m)\sum_L delta^l_am delta^l_an f^l_mn(2l+1)/(8pi^2)
x = np.zeros((xnlm_size, xnlm_size, n_dim), dtype=flmn.dtype)
x[m_offset:, m_offset:] = np.einsum(
"nlm,lam,lan,l->amn",
flmn[n_start_ind:],
Delta,
Delta[:, :, L - 1 + n],
(2 * np.arange(L) + 1) / (8 * np.pi**2),
)
flmn = np.einsum("nlm,l->nlm", flmn, (2 * np.arange(L) + 1) / (8 * np.pi**2))

# PRECOMPUTE TRANSFORM
if precomps is not None:
delta, _ = precomps
x = np.zeros((xnlm_size, xnlm_size, n_dim), dtype=flmn.dtype)
x[m_offset:, m_offset:] = np.einsum(
"nlm,lam,lan->amn", flmn[n_start_ind:], delta, delta[:, :, L - 1 + n]
)

# OTF TRANSFORM
else:
delta_el = np.zeros((2 * L - 1, 2 * L - 1), dtype=np.float64)
for el in range(L):
delta_el = recursions.risbo.compute_full(delta_el, np.pi / 2, L, el)
x[m_offset:, m_offset:] += np.einsum(
"nm,am,an->amn",
flmn[n_start_ind:, el],
delta_el,
delta_el[:, L - 1 + n],
)

# APPLY SIGN FUNCTION AND PHASE SHIFT
x = np.einsum("amn,m,n,a->nam", x, 1j ** (-m), 1j ** (n), np.exp(1j * m * theta0))
Expand All @@ -77,12 +94,12 @@ def inverse_transform(
return np.fft.ifft2(x, axes=(0, 2), norm="forward")


@partial(jit, static_argnums=(2, 3, 4, 5))
@partial(jit, static_argnums=(1, 2, 4, 5))
def inverse_transform_jax(
flmn: jnp.ndarray,
DW: jnp.ndarray,
L: int,
N: int,
precomps: tuple[jnp.ndarray, jnp.ndarray] | None = None,
reality: bool = False,
sampling: str = "mw",
) -> jnp.ndarray:
Expand All @@ -91,10 +108,11 @@ def inverse_transform_jax(

Args:
flmn (jnp.ndarray): Wigner coefficients.
DW (Tuple[np.ndarray, np.ndarray]): Fourier coefficients of the reduced
Wigner d-functions and the corresponding upsampled quadrature weights.
L (int): Harmonic band-limit.
N (int): Azimuthal band-limit.
precomps (tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the
reduced Wigner d-functions and the corresponding upsampled quadrature
weights. Defaults to None.
reality (bool, optional): Whether the signal on the sphere is real. If so,
conjugate symmetry is exploited to reduce computational costs.
Defaults to False.
Expand All @@ -110,9 +128,6 @@ def inverse_transform_jax(
f"Fourier-Wigner algorithm does not support {sampling} sampling."
)

# EXTRACT VARIOUS PRECOMPUTES
Delta, _ = DW

# INDEX VALUES
n_start_ind = N - 1 if reality else 0
n_dim = N if reality else 2 * N - 1
Expand All @@ -125,14 +140,32 @@ def inverse_transform_jax(
m = jnp.arange(-L + 1 - m_offset, L)
n = jnp.arange(n_start_ind - N + 1, N)

# Calculate fmna = i^(n-m)\sum_L Delta^l_am Delta^l_an f^l_mn(2l+1)/(8pi^2)
# Calculate fmna = i^(n-m)\sum_L delta^l_am delta^l_an f^l_mn(2l+1)/(8pi^2)
x = jnp.zeros((xnlm_size, xnlm_size, n_dim), dtype=jnp.complex128)
flmn = jnp.einsum("nlm,l->nlm", flmn, (2 * jnp.arange(L) + 1) / (8 * jnp.pi**2))
x = x.at[m_offset:, m_offset:].set(
jnp.einsum(
"nlm,lam,lan->amn", flmn[n_start_ind:], Delta, Delta[:, :, L - 1 + n]

# PRECOMPUTE TRANSFORM
if precomps is not None:
delta, _ = precomps
x = x.at[m_offset:, m_offset:].set(
jnp.einsum(
"nlm,lam,lan->amn", flmn[n_start_ind:], delta, delta[:, :, L - 1 + n]
)
)
)

# OTF TRANSFORM
else:
delta_el = jnp.zeros((2 * L - 1, 2 * L - 1), dtype=jnp.float64)
for el in range(L):
delta_el = recursions.risbo_jax.compute_full(delta_el, jnp.pi / 2, L, el)
x = x.at[m_offset:, m_offset:].add(
jnp.einsum(
"nm,am,an->amn",
flmn[n_start_ind:, el],
delta_el,
delta_el[:, L - 1 + n],
)
)

# APPLY SIGN FUNCTION AND PHASE SHIFT
x = jnp.einsum("amn,m,n,a->nam", x, 1j ** (-m), 1j ** (n), jnp.exp(1j * m * theta0))
Expand All @@ -151,9 +184,9 @@ def inverse_transform_jax(

def forward_transform(
f: np.ndarray,
DW: np.ndarray,
L: int,
N: int,
precomps: tuple[np.ndarray, np.ndarray] | None = None,
reality: bool = False,
sampling: str = "mw",
) -> np.ndarray:
Expand All @@ -162,10 +195,11 @@ def forward_transform(

Args:
f (np.ndarray): Function sampled on the rotation group.
DW (Tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the reduced
Wigner d-functions and the corresponding upsampled quadrature weights.
L (int): Harmonic band-limit.
N (int): Azimuthal band-limit.
precomps (tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the
reduced Wigner d-functions and the corresponding upsampled quadrature
weights. Defaults to None.
reality (bool, optional): Whether the signal on the sphere is real. If so,
conjugate symmetry is exploited to reduce computational costs.
Defaults to False.
Expand All @@ -181,9 +215,6 @@ def forward_transform(
f"Fourier-Wigner algorithm does not support {sampling} sampling."
)

# EXTRACT VARIOUS PRECOMPUTES
Delta, Quads = DW

# INDEX VALUES
n_start_ind = N - 1 if reality else 0
m_offset = 1 if sampling.lower() == "mwss" else 0
Expand Down Expand Up @@ -223,14 +254,39 @@ def forward_transform(
# NB: Our convention here is conjugate to that of SSHT, in which
# the weights are conjugate but applied flipped and therefore are
# equivalent. To avoid flipping here we simply conjugate the weights.
x = np.einsum("nbm,b->nbm", x, Quads)

if precomps is not None:
# PRECOMPUTE TRANSFORM
delta, quads = precomps
else:
# OTF TRANSFORM
delta = None
# COMPUTE QUADRATURE WEIGHTS
quads = np.zeros(4 * L - 3, dtype=np.complex128)
for mm in range(-2 * (L - 1), 2 * (L - 1) + 1):
quads[mm + 2 * (L - 1)] = quadrature.mw_weights(-mm)
quads = np.fft.ifft(np.fft.ifftshift(quads), norm="forward")

# APPLY QUADRATURE
x = np.einsum("nbm,b->nbm", x, quads)

# COMPUTE GMM BY FFT
x = np.fft.fft(x, axis=1, norm="forward")
x = np.fft.fftshift(x, axes=1)[:, L - 1 : 3 * L - 2]

# Calculate flmn = i^(n-m)\sum_t Delta^l_tm Delta^l_tn G_mnt
x = np.einsum("nam,lam,lan->nlm", x, Delta, Delta[:, :, L - 1 + n])
# CALCULATE flmn = i^(n-m)\sum_t delta^l_tm delta^l_tn G_mnt
if delta is not None:
# PRECOMPUTE TRANSFORM
x = np.einsum("nam,lam,lan->nlm", x, delta, delta[:, :, L - 1 + n])
else:
# OTF TRANSFORM
delta_el = np.zeros((2 * L - 1, 2 * L - 1), dtype=np.float64)
xx = np.zeros((x.shape[0], L, x.shape[-1]), dtype=x.dtype)
for el in range(L):
delta_el = recursions.risbo.compute_full(delta_el, np.pi / 2, L, el)
xx[:, el] = np.einsum("nam,am,an->nm", x, delta_el, delta_el[:, L - 1 + n])
x = xx

x = np.einsum("nbm,m,n->nbm", x, 1j ** (m), 1j ** (-n))

# SYMMETRY REFLECT FOR N < 0
Expand All @@ -246,12 +302,12 @@ def forward_transform(
return x * (2.0 * np.pi) ** 2


@partial(jit, static_argnums=(2, 3, 4, 5))
@partial(jit, static_argnums=(1, 2, 4, 5))
def forward_transform_jax(
f: jnp.ndarray,
DW: jnp.ndarray,
L: int,
N: int,
precomps: tuple[jnp.ndarray, jnp.ndarray] | None = None,
reality: bool = False,
sampling: str = "mw",
) -> jnp.ndarray:
Expand All @@ -260,10 +316,11 @@ def forward_transform_jax(

Args:
f (jnp.ndarray): Function sampled on the rotation group.
DW (Tuple[jnp.ndarray, jnp.ndarray]): Fourier coefficients of the reduced
Wigner d-functions and the corresponding upsampled quadrature weights.
L (int): Harmonic band-limit.
N (int): Azimuthal band-limit.
precomps (tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the
reduced Wigner d-functions and the corresponding upsampled quadrature
weights. Defaults to None.
reality (bool, optional): Whether the signal on the sphere is real. If so,
conjugate symmetry is exploited to reduce computational costs.
Defaults to False.
Expand All @@ -279,9 +336,6 @@ def forward_transform_jax(
f"Fourier-Wigner algorithm does not support {sampling} sampling."
)

# EXTRACT VARIOUS PRECOMPUTES
Delta, Quads = DW

# INDEX VALUES
n_start_ind = N - 1 if reality else 0
m_offset = 1 if sampling.lower() == "mwss" else 0
Expand Down Expand Up @@ -321,14 +375,41 @@ def forward_transform_jax(
# NB: Our convention here is conjugate to that of SSHT, in which
# the weights are conjugate but applied flipped and therefore are
# equivalent. To avoid flipping here we simply conjugate the weights.
x = jnp.einsum("nbm,b->nbm", x, Quads)

if precomps is not None:
# PRECOMPUTE TRANSFORM
delta, quads = precomps
else:
# OTF TRANSFORM
delta = None
# COMPUTE QUADRATURE WEIGHTS
quads = jnp.zeros(4 * L - 3, dtype=jnp.complex128)
for mm in range(-2 * (L - 1), 2 * (L - 1) + 1):
quads = quads.at[mm + 2 * (L - 1)].set(quadrature_jax.mw_weights(-mm))
quads = jnp.fft.ifft(jnp.fft.ifftshift(quads), norm="forward")

# APPLY QUADRATURE
x = jnp.einsum("nbm,b->nbm", x, quads)

# COMPUTE GMM BY FFT
x = jnp.fft.fft(x, axis=1, norm="forward")
x = jnp.fft.fftshift(x, axes=1)[:, L - 1 : 3 * L - 2]

# Calculate flmn = i^(n-m)\sum_t Delta^l_tm Delta^l_tn G_mnt
x = jnp.einsum("nam,lam,lan->nlm", x, Delta, Delta[:, :, L - 1 + n])
# Calculate flmn = i^(n-m)\sum_t delta^l_tm delta^l_tn G_mnt
if delta is not None:
# PRECOMPUTE TRANSFORM
x = jnp.einsum("nam,lam,lan->nlm", x, delta, delta[:, :, L - 1 + n])
else:
# OTF TRANSFORM
delta_el = jnp.zeros((2 * L - 1, 2 * L - 1), dtype=jnp.float64)
xx = jnp.zeros((x.shape[0], L, x.shape[-1]), dtype=x.dtype)
for el in range(L):
delta_el = recursions.risbo_jax.compute_full(delta_el, jnp.pi / 2, L, el)
xx = xx.at[:, el].set(
jnp.einsum("nam,am,an->nm", x, delta_el, delta_el[:, L - 1 + n])
)
x = xx

x = jnp.einsum("nbm,m,n->nbm", x, 1j ** (m), 1j ** (-n))

# SYMMETRY REFLECT FOR N < 0
Expand Down
Loading
Loading