From c764a98e43487a2858205d164a6d721286b07372 Mon Sep 17 00:00:00 2001 From: CosmoMatt Date: Tue, 21 Jan 2025 09:45:08 +0000 Subject: [PATCH 1/3] add on-the-fly support for Fourier Wigner transforms --- s2fft/precompute_transforms/fourier_wigner.py | 186 +++++++++++++----- tests/test_fourier_wigner.py | 38 ++-- 2 files changed, 160 insertions(+), 64 deletions(-) diff --git a/s2fft/precompute_transforms/fourier_wigner.py b/s2fft/precompute_transforms/fourier_wigner.py index 3a0299eb..01b441e2 100644 --- a/s2fft/precompute_transforms/fourier_wigner.py +++ b/s2fft/precompute_transforms/fourier_wigner.py @@ -4,12 +4,15 @@ import numpy as np from jax import jit +from s2fft import recursions +from s2fft.utils import quadrature, quadrature_jax + def inverse_transform( flmn: np.ndarray, - DW: np.ndarray, L: int, N: int, + DW: np.ndarray = None, reality: bool = False, sampling: str = "mw", ) -> np.ndarray: @@ -18,10 +21,11 @@ def inverse_transform( Args: flmn (np.ndarray): Wigner coefficients. - DW (Tuple[np.ndarray, np.ndarray]): Fourier coefficients of the reduced - Wigner d-functions and the corresponding upsampled quadrature weights. L (int): Harmonic band-limit. N (int): Azimuthal band-limit. + DW (Tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the reduced + Wigner d-functions and the corresponding upsampled quadrature weights. + Defaults to None. reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. Defaults to False. @@ -37,9 +41,6 @@ def inverse_transform( f"Fourier-Wigner algorithm does not support {sampling} sampling." ) - # EXTRACT VARIOUS PRECOMPUTES - Delta, _ = DW - # INDEX VALUES n_start_ind = N - 1 if reality else 0 n_dim = N if reality else 2 * N - 1 @@ -54,13 +55,27 @@ def inverse_transform( # Calculate fmna = i^(n-m)\sum_L Delta^l_am Delta^l_an f^l_mn(2l+1)/(8pi^2) x = np.zeros((xnlm_size, xnlm_size, n_dim), dtype=flmn.dtype) - x[m_offset:, m_offset:] = np.einsum( - "nlm,lam,lan,l->amn", - flmn[n_start_ind:], - Delta, - Delta[:, :, L - 1 + n], - (2 * np.arange(L) + 1) / (8 * np.pi**2), - ) + flmn = np.einsum("nlm,l->nlm", flmn, (2 * np.arange(L) + 1) / (8 * np.pi**2)) + + # PRECOMPUTE TRANSFORM + if DW is not None: + Delta, _ = DW + x = np.zeros((xnlm_size, xnlm_size, n_dim), dtype=flmn.dtype) + x[m_offset:, m_offset:] = np.einsum( + "nlm,lam,lan->amn", flmn[n_start_ind:], Delta, Delta[:, :, L - 1 + n] + ) + + # OTF TRANSFORM + else: + Delta_el = np.zeros((2 * L - 1, 2 * L - 1), dtype=np.float64) + for el in range(L): + Delta_el = recursions.risbo.compute_full(Delta_el, np.pi / 2, L, el) + x[m_offset:, m_offset:] += np.einsum( + "nm,am,an->amn", + flmn[n_start_ind:, el], + Delta_el, + Delta_el[:, L - 1 + n], + ) # APPLY SIGN FUNCTION AND PHASE SHIFT x = np.einsum("amn,m,n,a->nam", x, 1j ** (-m), 1j ** (n), np.exp(1j * m * theta0)) @@ -77,12 +92,12 @@ def inverse_transform( return np.fft.ifft2(x, axes=(0, 2), norm="forward") -@partial(jit, static_argnums=(2, 3, 4, 5)) +@partial(jit, static_argnums=(1, 2, 4, 5)) def inverse_transform_jax( flmn: jnp.ndarray, - DW: jnp.ndarray, L: int, N: int, + DW: jnp.ndarray = None, reality: bool = False, sampling: str = "mw", ) -> jnp.ndarray: @@ -91,10 +106,11 @@ def inverse_transform_jax( Args: flmn (jnp.ndarray): Wigner coefficients. - DW (Tuple[np.ndarray, np.ndarray]): Fourier coefficients of the reduced - Wigner d-functions and the corresponding upsampled quadrature weights. L (int): Harmonic band-limit. N (int): Azimuthal band-limit. + DW (Tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the reduced + Wigner d-functions and the corresponding upsampled quadrature weights. + Defaults to None. reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. Defaults to False. @@ -110,9 +126,6 @@ def inverse_transform_jax( f"Fourier-Wigner algorithm does not support {sampling} sampling." ) - # EXTRACT VARIOUS PRECOMPUTES - Delta, _ = DW - # INDEX VALUES n_start_ind = N - 1 if reality else 0 n_dim = N if reality else 2 * N - 1 @@ -128,11 +141,29 @@ def inverse_transform_jax( # Calculate fmna = i^(n-m)\sum_L Delta^l_am Delta^l_an f^l_mn(2l+1)/(8pi^2) x = jnp.zeros((xnlm_size, xnlm_size, n_dim), dtype=jnp.complex128) flmn = jnp.einsum("nlm,l->nlm", flmn, (2 * jnp.arange(L) + 1) / (8 * jnp.pi**2)) - x = x.at[m_offset:, m_offset:].set( - jnp.einsum( - "nlm,lam,lan->amn", flmn[n_start_ind:], Delta, Delta[:, :, L - 1 + n] + + # PRECOMPUTE TRANSFORM + if DW is not None: + Delta, _ = DW + x = x.at[m_offset:, m_offset:].set( + jnp.einsum( + "nlm,lam,lan->amn", flmn[n_start_ind:], Delta, Delta[:, :, L - 1 + n] + ) ) - ) + + # OTF TRANSFORM + else: + Delta_el = jnp.zeros((2 * L - 1, 2 * L - 1), dtype=jnp.float64) + for el in range(L): + Delta_el = recursions.risbo_jax.compute_full(Delta_el, jnp.pi / 2, L, el) + x = x.at[m_offset:, m_offset:].add( + jnp.einsum( + "nm,am,an->amn", + flmn[n_start_ind:, el], + Delta_el, + Delta_el[:, L - 1 + n], + ) + ) # APPLY SIGN FUNCTION AND PHASE SHIFT x = jnp.einsum("amn,m,n,a->nam", x, 1j ** (-m), 1j ** (n), jnp.exp(1j * m * theta0)) @@ -151,9 +182,9 @@ def inverse_transform_jax( def forward_transform( f: np.ndarray, - DW: np.ndarray, L: int, N: int, + DW: np.ndarray = None, reality: bool = False, sampling: str = "mw", ) -> np.ndarray: @@ -162,10 +193,11 @@ def forward_transform( Args: f (np.ndarray): Function sampled on the rotation group. - DW (Tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the reduced - Wigner d-functions and the corresponding upsampled quadrature weights. L (int): Harmonic band-limit. N (int): Azimuthal band-limit. + DW (Tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the reduced + Wigner d-functions and the corresponding upsampled quadrature weights. + Defaults to None. reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. Defaults to False. @@ -181,9 +213,6 @@ def forward_transform( f"Fourier-Wigner algorithm does not support {sampling} sampling." ) - # EXTRACT VARIOUS PRECOMPUTES - Delta, Quads = DW - # INDEX VALUES n_start_ind = N - 1 if reality else 0 m_offset = 1 if sampling.lower() == "mwss" else 0 @@ -223,14 +252,44 @@ def forward_transform( # NB: Our convention here is conjugate to that of SSHT, in which # the weights are conjugate but applied flipped and therefore are # equivalent. To avoid flipping here we simply conjugate the weights. - x = np.einsum("nbm,b->nbm", x, Quads) - # COMPUTE GMM BY FFT - x = np.fft.fft(x, axis=1, norm="forward") - x = np.fft.fftshift(x, axes=1)[:, L - 1 : 3 * L - 2] + # PRECOMPUTE TRANSFORM + if DW is not None: + # EXTRACT VARIOUS PRECOMPUTES + Delta, Quads = DW - # Calculate flmn = i^(n-m)\sum_t Delta^l_tm Delta^l_tn G_mnt - x = np.einsum("nam,lam,lan->nlm", x, Delta, Delta[:, :, L - 1 + n]) + # APPLY QUADRATURE + x = np.einsum("nbm,b->nbm", x, Quads) + + # COMPUTE GMM BY FFT + x = np.fft.fft(x, axis=1, norm="forward") + x = np.fft.fftshift(x, axes=1)[:, L - 1 : 3 * L - 2] + + # CALCULATE flmn = i^(n-m)\sum_t Delta^l_tm Delta^l_tn G_mnt + x = np.einsum("nam,lam,lan->nlm", x, Delta, Delta[:, :, L - 1 + n]) + + # OTF TRANSFORM + else: + # COMPUTE QUADRATURE WEIGHTS + Quads = np.zeros(4 * L - 3, dtype=np.complex128) + for mm in range(-2 * (L - 1), 2 * (L - 1) + 1): + Quads[mm + 2 * (L - 1)] = quadrature.mw_weights(-mm) + Quads = np.fft.ifft(np.fft.ifftshift(Quads), norm="forward") + + # APPLY QUADRATURE + x = np.einsum("nbm,b->nbm", x, Quads) + + # COMPUTE GMM BY FFT + x = np.fft.fft(x, axis=1, norm="forward") + x = np.fft.fftshift(x, axes=1)[:, L - 1 : 3 * L - 2] + + # CALCULATE flmn = i^(n-m)\sum_t Delta^l_tm Delta^l_tn G_mnt + Delta_el = np.zeros((2 * L - 1, 2 * L - 1), dtype=np.float64) + xx = np.zeros((x.shape[0], L, x.shape[-1]), dtype=x.dtype) + for el in range(L): + Delta_el = recursions.risbo.compute_full(Delta_el, np.pi / 2, L, el) + xx[:, el] = np.einsum("nam,am,an->nm", x, Delta_el, Delta_el[:, L - 1 + n]) + x = xx x = np.einsum("nbm,m,n->nbm", x, 1j ** (m), 1j ** (-n)) # SYMMETRY REFLECT FOR N < 0 @@ -246,12 +305,12 @@ def forward_transform( return x * (2.0 * np.pi) ** 2 -@partial(jit, static_argnums=(2, 3, 4, 5)) +@partial(jit, static_argnums=(1, 2, 4, 5)) def forward_transform_jax( f: jnp.ndarray, - DW: jnp.ndarray, L: int, N: int, + DW: jnp.ndarray = None, reality: bool = False, sampling: str = "mw", ) -> jnp.ndarray: @@ -260,10 +319,11 @@ def forward_transform_jax( Args: f (jnp.ndarray): Function sampled on the rotation group. - DW (Tuple[jnp.ndarray, jnp.ndarray]): Fourier coefficients of the reduced - Wigner d-functions and the corresponding upsampled quadrature weights. L (int): Harmonic band-limit. N (int): Azimuthal band-limit. + DW (Tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the reduced + Wigner d-functions and the corresponding upsampled quadrature weights. + Defaults to None. reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. Defaults to False. @@ -279,9 +339,6 @@ def forward_transform_jax( f"Fourier-Wigner algorithm does not support {sampling} sampling." ) - # EXTRACT VARIOUS PRECOMPUTES - Delta, Quads = DW - # INDEX VALUES n_start_ind = N - 1 if reality else 0 m_offset = 1 if sampling.lower() == "mwss" else 0 @@ -321,14 +378,45 @@ def forward_transform_jax( # NB: Our convention here is conjugate to that of SSHT, in which # the weights are conjugate but applied flipped and therefore are # equivalent. To avoid flipping here we simply conjugate the weights. - x = jnp.einsum("nbm,b->nbm", x, Quads) - # COMPUTE GMM BY FFT - x = jnp.fft.fft(x, axis=1, norm="forward") - x = jnp.fft.fftshift(x, axes=1)[:, L - 1 : 3 * L - 2] + # PRECOMPUTE TRANSFORM + if DW is not None: + # EXTRACT VARIOUS PRECOMPUTES + Delta, Quads = DW + + # APPLY QUADRATURE + x = jnp.einsum("nbm,b->nbm", x, Quads) + + # COMPUTE GMM BY FFT + x = jnp.fft.fft(x, axis=1, norm="forward") + x = jnp.fft.fftshift(x, axes=1)[:, L - 1 : 3 * L - 2] + + # Calculate flmn = i^(n-m)\sum_t Delta^l_tm Delta^l_tn G_mnt + x = jnp.einsum("nam,lam,lan->nlm", x, Delta, Delta[:, :, L - 1 + n]) + + else: + Quads = jnp.zeros(4 * L - 3, dtype=jnp.complex128) + for mm in range(-2 * (L - 1), 2 * (L - 1) + 1): + Quads = Quads.at[mm + 2 * (L - 1)].set(quadrature_jax.mw_weights(-mm)) + Quads = jnp.fft.ifft(jnp.fft.ifftshift(Quads), norm="forward") + + # APPLY QUADRATURE + x = jnp.einsum("nbm,b->nbm", x, Quads) + + # COMPUTE GMM BY FFT + x = jnp.fft.fft(x, axis=1, norm="forward") + x = jnp.fft.fftshift(x, axes=1)[:, L - 1 : 3 * L - 2] + + # CALCULATE flmn = i^(n-m)\sum_t Delta^l_tm Delta^l_tn G_mnt + Delta_el = jnp.zeros((2 * L - 1, 2 * L - 1), dtype=jnp.float64) + xx = jnp.zeros((x.shape[0], L, x.shape[-1]), dtype=x.dtype) + for el in range(L): + Delta_el = recursions.risbo_jax.compute_full(Delta_el, jnp.pi / 2, L, el) + xx = xx.at[:, el].set( + jnp.einsum("nam,am,an->nm", x, Delta_el, Delta_el[:, L - 1 + n]) + ) + x = xx - # Calculate flmn = i^(n-m)\sum_t Delta^l_tm Delta^l_tn G_mnt - x = jnp.einsum("nam,lam,lan->nlm", x, Delta, Delta[:, :, L - 1 + n]) x = jnp.einsum("nbm,m,n->nbm", x, 1j ** (m), 1j ** (-n)) # SYMMETRY REFLECT FOR N < 0 diff --git a/tests/test_fourier_wigner.py b/tests/test_fourier_wigner.py index b7062f4e..9fa9414f 100644 --- a/tests/test_fourier_wigner.py +++ b/tests/test_fourier_wigner.py @@ -15,6 +15,7 @@ reality_to_test = [False, True] sampling_schemes = ["mw", "mwss"] methods_to_test = ["numpy", "jax"] +delta_method_to_test = ["otf", "precomp"] # Test tolerance atol = 1e-12 @@ -25,6 +26,7 @@ @pytest.mark.parametrize("sampling", sampling_schemes) @pytest.mark.parametrize("reality", reality_to_test) @pytest.mark.parametrize("method", methods_to_test) +@pytest.mark.parametrize("delta_method", delta_method_to_test) def test_inverse_fourier_wigner_transform( flmn_generator, s2fft_to_so3_sampling, @@ -33,6 +35,7 @@ def test_inverse_fourier_wigner_transform( sampling: str, reality: bool, method: str, + delta_method: str, ): flmn = flmn_generator(L=L, N=N, reality=reality) @@ -44,13 +47,15 @@ def test_inverse_fourier_wigner_transform( ) f = so3.inverse(samples.flmn_3d_to_1d(flmn, L, N), params) - delta = ( - c.fourier_wigner_kernel_jax(L) - if method == "jax" - else c.fourier_wigner_kernel(L) - ) + delta = None transform = fw.inverse_transform_jax if method == "jax" else fw.inverse_transform - f_check = transform(flmn, delta, L, N, reality, sampling) + if delta_method.lower() == "precomp": + delta = ( + c.fourier_wigner_kernel_jax(L) + if method == "jax" + else c.fourier_wigner_kernel(L) + ) + f_check = transform(flmn, L, N, delta, reality, sampling) np.testing.assert_allclose(f, f_check.flatten("C"), atol=atol) @@ -59,6 +64,7 @@ def test_inverse_fourier_wigner_transform( @pytest.mark.parametrize("sampling", sampling_schemes) @pytest.mark.parametrize("reality", reality_to_test) @pytest.mark.parametrize("method", methods_to_test) +@pytest.mark.parametrize("delta_method", delta_method_to_test) def test_forward_fourier_wigner_transform( flmn_generator, s2fft_to_so3_sampling, @@ -67,6 +73,7 @@ def test_forward_fourier_wigner_transform( sampling: str, reality: bool, method: str, + delta_method: str, ): flmn = flmn_generator(L=L, N=N, reality=reality) @@ -84,14 +91,15 @@ def test_forward_fourier_wigner_transform( ) flmn = samples.flmn_1d_to_3d(so3.forward(f, params), L, N) - delta = ( - c.fourier_wigner_kernel_jax(L) - if method == "jax" - else c.fourier_wigner_kernel(L) - ) + delta = None transform = fw.forward_transform_jax if method == "jax" else fw.forward_transform - - flmn_check = transform(f_3D, delta, L, N, reality, sampling) + if delta_method.lower() == "precomp": + delta = ( + c.fourier_wigner_kernel_jax(L) + if method == "jax" + else c.fourier_wigner_kernel(L) + ) + flmn_check = transform(f_3D, L, N, delta, reality, sampling) np.testing.assert_allclose(flmn, flmn_check, atol=atol) @@ -114,7 +122,7 @@ def test_inverse_fourier_wigner_transform_high_N( f = f.real if reality else f delta = c.fourier_wigner_kernel(L) - f_check = fw.inverse_transform(flmn, delta, L, N, reality, sampling) + f_check = fw.inverse_transform(flmn, L, N, delta, reality, sampling) np.testing.assert_allclose(f, f_check.flatten("C"), atol=atol) @@ -144,5 +152,5 @@ def test_forward_fourier_wigner_transform_high_N( flmn_so3 = samples.flmn_1d_to_3d(so3.forward(f_1D, params), L, N) delta = c.fourier_wigner_kernel_jax(L) - flmn_check = fw.forward_transform_jax(f_3D, delta, L, N, reality, sampling) + flmn_check = fw.forward_transform_jax(f_3D, L, N, delta, reality, sampling) np.testing.assert_allclose(flmn_so3, flmn_check, atol=atol) From 1d41923e55334f902fe9ae326f0b54e11591afdc Mon Sep 17 00:00:00 2001 From: CosmoMatt Date: Tue, 21 Jan 2025 12:22:39 +0000 Subject: [PATCH 2/3] update custom ops test for new fourier wigner variable ordering --- tests/test_lifting_transforms.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_lifting_transforms.py b/tests/test_lifting_transforms.py index 34b1c831..5e36f560 100644 --- a/tests/test_lifting_transforms.py +++ b/tests/test_lifting_transforms.py @@ -28,7 +28,7 @@ def test_custom_forward_from_s2( # GENERATE MOCK SIGNAL flmn = flmn_generator(L=L, N=N) DW = c.fourier_wigner_kernel(L) - f = fw.inverse_transform(flmn, DW, L, N, False, sampling) + f = fw.inverse_transform(flmn, L, N, DW, False, sampling) spins = -np.arange(-N + 1, N) # FUNCTION SWITCH @@ -77,7 +77,7 @@ def test_custom_forward_from_so3( # GENERATE MOCK SIGNAL flmn = flmn_generator(L=L, N=N) DW = c.fourier_wigner_kernel(L) - f = fw.inverse_transform(flmn, DW, L, N, False, sampling) + f = fw.inverse_transform(flmn, L, N, DW, False, sampling) spins = -np.arange(-N + 1, N) # FUNCTION SWITCH @@ -121,7 +121,7 @@ def test_custom_inverse_to_s2( # GENERATE MOCK SIGNAL flmn = flmn_generator(L=L, N=N) DW = c.fourier_wigner_kernel(L) - f = fw.inverse_transform(flmn, DW, L, N, False, sampling) + f = fw.inverse_transform(flmn, L, N, DW, False, sampling) spins = -np.arange(-N + 1, N) # FUNCTION SWITCH From fd1e8b1b68e6472ce72b9da865e86e34d96d0254 Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Wed, 23 Apr 2025 11:08:11 +0100 Subject: [PATCH 3/3] Some minor updates to #260 (#298) * Variable name and type hint cleanup * Refactor to remove repeated code --- s2fft/precompute_transforms/fourier_wigner.py | 167 +++++++++--------- tests/test_fourier_wigner.py | 22 +-- tests/test_lifting_transforms.py | 42 ++--- 3 files changed, 113 insertions(+), 118 deletions(-) diff --git a/s2fft/precompute_transforms/fourier_wigner.py b/s2fft/precompute_transforms/fourier_wigner.py index 01b441e2..78809965 100644 --- a/s2fft/precompute_transforms/fourier_wigner.py +++ b/s2fft/precompute_transforms/fourier_wigner.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from functools import partial import jax.numpy as jnp @@ -12,7 +14,7 @@ def inverse_transform( flmn: np.ndarray, L: int, N: int, - DW: np.ndarray = None, + precomps: tuple[np.ndarray, np.ndarray] | None = None, reality: bool = False, sampling: str = "mw", ) -> np.ndarray: @@ -23,9 +25,9 @@ def inverse_transform( flmn (np.ndarray): Wigner coefficients. L (int): Harmonic band-limit. N (int): Azimuthal band-limit. - DW (Tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the reduced - Wigner d-functions and the corresponding upsampled quadrature weights. - Defaults to None. + precomps (tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the + reduced Wigner d-functions and the corresponding upsampled quadrature + weights. Defaults to None. reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. Defaults to False. @@ -53,28 +55,28 @@ def inverse_transform( m = np.arange(-L + 1 - m_offset, L) n = np.arange(n_start_ind - N + 1, N) - # Calculate fmna = i^(n-m)\sum_L Delta^l_am Delta^l_an f^l_mn(2l+1)/(8pi^2) + # Calculate fmna = i^(n-m)\sum_L delta^l_am delta^l_an f^l_mn(2l+1)/(8pi^2) x = np.zeros((xnlm_size, xnlm_size, n_dim), dtype=flmn.dtype) flmn = np.einsum("nlm,l->nlm", flmn, (2 * np.arange(L) + 1) / (8 * np.pi**2)) # PRECOMPUTE TRANSFORM - if DW is not None: - Delta, _ = DW + if precomps is not None: + delta, _ = precomps x = np.zeros((xnlm_size, xnlm_size, n_dim), dtype=flmn.dtype) x[m_offset:, m_offset:] = np.einsum( - "nlm,lam,lan->amn", flmn[n_start_ind:], Delta, Delta[:, :, L - 1 + n] + "nlm,lam,lan->amn", flmn[n_start_ind:], delta, delta[:, :, L - 1 + n] ) # OTF TRANSFORM else: - Delta_el = np.zeros((2 * L - 1, 2 * L - 1), dtype=np.float64) + delta_el = np.zeros((2 * L - 1, 2 * L - 1), dtype=np.float64) for el in range(L): - Delta_el = recursions.risbo.compute_full(Delta_el, np.pi / 2, L, el) + delta_el = recursions.risbo.compute_full(delta_el, np.pi / 2, L, el) x[m_offset:, m_offset:] += np.einsum( "nm,am,an->amn", flmn[n_start_ind:, el], - Delta_el, - Delta_el[:, L - 1 + n], + delta_el, + delta_el[:, L - 1 + n], ) # APPLY SIGN FUNCTION AND PHASE SHIFT @@ -97,7 +99,7 @@ def inverse_transform_jax( flmn: jnp.ndarray, L: int, N: int, - DW: jnp.ndarray = None, + precomps: tuple[jnp.ndarray, jnp.ndarray] | None = None, reality: bool = False, sampling: str = "mw", ) -> jnp.ndarray: @@ -108,9 +110,9 @@ def inverse_transform_jax( flmn (jnp.ndarray): Wigner coefficients. L (int): Harmonic band-limit. N (int): Azimuthal band-limit. - DW (Tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the reduced - Wigner d-functions and the corresponding upsampled quadrature weights. - Defaults to None. + precomps (tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the + reduced Wigner d-functions and the corresponding upsampled quadrature + weights. Defaults to None. reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. Defaults to False. @@ -138,30 +140,30 @@ def inverse_transform_jax( m = jnp.arange(-L + 1 - m_offset, L) n = jnp.arange(n_start_ind - N + 1, N) - # Calculate fmna = i^(n-m)\sum_L Delta^l_am Delta^l_an f^l_mn(2l+1)/(8pi^2) + # Calculate fmna = i^(n-m)\sum_L delta^l_am delta^l_an f^l_mn(2l+1)/(8pi^2) x = jnp.zeros((xnlm_size, xnlm_size, n_dim), dtype=jnp.complex128) flmn = jnp.einsum("nlm,l->nlm", flmn, (2 * jnp.arange(L) + 1) / (8 * jnp.pi**2)) # PRECOMPUTE TRANSFORM - if DW is not None: - Delta, _ = DW + if precomps is not None: + delta, _ = precomps x = x.at[m_offset:, m_offset:].set( jnp.einsum( - "nlm,lam,lan->amn", flmn[n_start_ind:], Delta, Delta[:, :, L - 1 + n] + "nlm,lam,lan->amn", flmn[n_start_ind:], delta, delta[:, :, L - 1 + n] ) ) # OTF TRANSFORM else: - Delta_el = jnp.zeros((2 * L - 1, 2 * L - 1), dtype=jnp.float64) + delta_el = jnp.zeros((2 * L - 1, 2 * L - 1), dtype=jnp.float64) for el in range(L): - Delta_el = recursions.risbo_jax.compute_full(Delta_el, jnp.pi / 2, L, el) + delta_el = recursions.risbo_jax.compute_full(delta_el, jnp.pi / 2, L, el) x = x.at[m_offset:, m_offset:].add( jnp.einsum( "nm,am,an->amn", flmn[n_start_ind:, el], - Delta_el, - Delta_el[:, L - 1 + n], + delta_el, + delta_el[:, L - 1 + n], ) ) @@ -184,7 +186,7 @@ def forward_transform( f: np.ndarray, L: int, N: int, - DW: np.ndarray = None, + precomps: tuple[np.ndarray, np.ndarray] | None = None, reality: bool = False, sampling: str = "mw", ) -> np.ndarray: @@ -195,9 +197,9 @@ def forward_transform( f (np.ndarray): Function sampled on the rotation group. L (int): Harmonic band-limit. N (int): Azimuthal band-limit. - DW (Tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the reduced - Wigner d-functions and the corresponding upsampled quadrature weights. - Defaults to None. + precomps (tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the + reduced Wigner d-functions and the corresponding upsampled quadrature + weights. Defaults to None. reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. Defaults to False. @@ -253,43 +255,38 @@ def forward_transform( # the weights are conjugate but applied flipped and therefore are # equivalent. To avoid flipping here we simply conjugate the weights. - # PRECOMPUTE TRANSFORM - if DW is not None: - # EXTRACT VARIOUS PRECOMPUTES - Delta, Quads = DW - - # APPLY QUADRATURE - x = np.einsum("nbm,b->nbm", x, Quads) - - # COMPUTE GMM BY FFT - x = np.fft.fft(x, axis=1, norm="forward") - x = np.fft.fftshift(x, axes=1)[:, L - 1 : 3 * L - 2] - - # CALCULATE flmn = i^(n-m)\sum_t Delta^l_tm Delta^l_tn G_mnt - x = np.einsum("nam,lam,lan->nlm", x, Delta, Delta[:, :, L - 1 + n]) - - # OTF TRANSFORM + if precomps is not None: + # PRECOMPUTE TRANSFORM + delta, quads = precomps else: + # OTF TRANSFORM + delta = None # COMPUTE QUADRATURE WEIGHTS - Quads = np.zeros(4 * L - 3, dtype=np.complex128) + quads = np.zeros(4 * L - 3, dtype=np.complex128) for mm in range(-2 * (L - 1), 2 * (L - 1) + 1): - Quads[mm + 2 * (L - 1)] = quadrature.mw_weights(-mm) - Quads = np.fft.ifft(np.fft.ifftshift(Quads), norm="forward") + quads[mm + 2 * (L - 1)] = quadrature.mw_weights(-mm) + quads = np.fft.ifft(np.fft.ifftshift(quads), norm="forward") - # APPLY QUADRATURE - x = np.einsum("nbm,b->nbm", x, Quads) + # APPLY QUADRATURE + x = np.einsum("nbm,b->nbm", x, quads) - # COMPUTE GMM BY FFT - x = np.fft.fft(x, axis=1, norm="forward") - x = np.fft.fftshift(x, axes=1)[:, L - 1 : 3 * L - 2] + # COMPUTE GMM BY FFT + x = np.fft.fft(x, axis=1, norm="forward") + x = np.fft.fftshift(x, axes=1)[:, L - 1 : 3 * L - 2] - # CALCULATE flmn = i^(n-m)\sum_t Delta^l_tm Delta^l_tn G_mnt - Delta_el = np.zeros((2 * L - 1, 2 * L - 1), dtype=np.float64) + # CALCULATE flmn = i^(n-m)\sum_t delta^l_tm delta^l_tn G_mnt + if delta is not None: + # PRECOMPUTE TRANSFORM + x = np.einsum("nam,lam,lan->nlm", x, delta, delta[:, :, L - 1 + n]) + else: + # OTF TRANSFORM + delta_el = np.zeros((2 * L - 1, 2 * L - 1), dtype=np.float64) xx = np.zeros((x.shape[0], L, x.shape[-1]), dtype=x.dtype) for el in range(L): - Delta_el = recursions.risbo.compute_full(Delta_el, np.pi / 2, L, el) - xx[:, el] = np.einsum("nam,am,an->nm", x, Delta_el, Delta_el[:, L - 1 + n]) + delta_el = recursions.risbo.compute_full(delta_el, np.pi / 2, L, el) + xx[:, el] = np.einsum("nam,am,an->nm", x, delta_el, delta_el[:, L - 1 + n]) x = xx + x = np.einsum("nbm,m,n->nbm", x, 1j ** (m), 1j ** (-n)) # SYMMETRY REFLECT FOR N < 0 @@ -310,7 +307,7 @@ def forward_transform_jax( f: jnp.ndarray, L: int, N: int, - DW: jnp.ndarray = None, + precomps: tuple[jnp.ndarray, jnp.ndarray] | None = None, reality: bool = False, sampling: str = "mw", ) -> jnp.ndarray: @@ -321,9 +318,9 @@ def forward_transform_jax( f (jnp.ndarray): Function sampled on the rotation group. L (int): Harmonic band-limit. N (int): Azimuthal band-limit. - DW (Tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the reduced - Wigner d-functions and the corresponding upsampled quadrature weights. - Defaults to None. + precomps (tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the + reduced Wigner d-functions and the corresponding upsampled quadrature + weights. Defaults to None. reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. Defaults to False. @@ -379,41 +376,37 @@ def forward_transform_jax( # the weights are conjugate but applied flipped and therefore are # equivalent. To avoid flipping here we simply conjugate the weights. - # PRECOMPUTE TRANSFORM - if DW is not None: - # EXTRACT VARIOUS PRECOMPUTES - Delta, Quads = DW - - # APPLY QUADRATURE - x = jnp.einsum("nbm,b->nbm", x, Quads) - - # COMPUTE GMM BY FFT - x = jnp.fft.fft(x, axis=1, norm="forward") - x = jnp.fft.fftshift(x, axes=1)[:, L - 1 : 3 * L - 2] - - # Calculate flmn = i^(n-m)\sum_t Delta^l_tm Delta^l_tn G_mnt - x = jnp.einsum("nam,lam,lan->nlm", x, Delta, Delta[:, :, L - 1 + n]) - + if precomps is not None: + # PRECOMPUTE TRANSFORM + delta, quads = precomps else: - Quads = jnp.zeros(4 * L - 3, dtype=jnp.complex128) + # OTF TRANSFORM + delta = None + # COMPUTE QUADRATURE WEIGHTS + quads = jnp.zeros(4 * L - 3, dtype=jnp.complex128) for mm in range(-2 * (L - 1), 2 * (L - 1) + 1): - Quads = Quads.at[mm + 2 * (L - 1)].set(quadrature_jax.mw_weights(-mm)) - Quads = jnp.fft.ifft(jnp.fft.ifftshift(Quads), norm="forward") + quads = quads.at[mm + 2 * (L - 1)].set(quadrature_jax.mw_weights(-mm)) + quads = jnp.fft.ifft(jnp.fft.ifftshift(quads), norm="forward") - # APPLY QUADRATURE - x = jnp.einsum("nbm,b->nbm", x, Quads) + # APPLY QUADRATURE + x = jnp.einsum("nbm,b->nbm", x, quads) - # COMPUTE GMM BY FFT - x = jnp.fft.fft(x, axis=1, norm="forward") - x = jnp.fft.fftshift(x, axes=1)[:, L - 1 : 3 * L - 2] + # COMPUTE GMM BY FFT + x = jnp.fft.fft(x, axis=1, norm="forward") + x = jnp.fft.fftshift(x, axes=1)[:, L - 1 : 3 * L - 2] - # CALCULATE flmn = i^(n-m)\sum_t Delta^l_tm Delta^l_tn G_mnt - Delta_el = jnp.zeros((2 * L - 1, 2 * L - 1), dtype=jnp.float64) + # Calculate flmn = i^(n-m)\sum_t delta^l_tm delta^l_tn G_mnt + if delta is not None: + # PRECOMPUTE TRANSFORM + x = jnp.einsum("nam,lam,lan->nlm", x, delta, delta[:, :, L - 1 + n]) + else: + # OTF TRANSFORM + delta_el = jnp.zeros((2 * L - 1, 2 * L - 1), dtype=jnp.float64) xx = jnp.zeros((x.shape[0], L, x.shape[-1]), dtype=x.dtype) for el in range(L): - Delta_el = recursions.risbo_jax.compute_full(Delta_el, jnp.pi / 2, L, el) + delta_el = recursions.risbo_jax.compute_full(delta_el, jnp.pi / 2, L, el) xx = xx.at[:, el].set( - jnp.einsum("nam,am,an->nm", x, Delta_el, Delta_el[:, L - 1 + n]) + jnp.einsum("nam,am,an->nm", x, delta_el, delta_el[:, L - 1 + n]) ) x = xx diff --git a/tests/test_fourier_wigner.py b/tests/test_fourier_wigner.py index 9fa9414f..d36980e3 100644 --- a/tests/test_fourier_wigner.py +++ b/tests/test_fourier_wigner.py @@ -47,15 +47,16 @@ def test_inverse_fourier_wigner_transform( ) f = so3.inverse(samples.flmn_3d_to_1d(flmn, L, N), params) - delta = None transform = fw.inverse_transform_jax if method == "jax" else fw.inverse_transform if delta_method.lower() == "precomp": - delta = ( + precomps = ( c.fourier_wigner_kernel_jax(L) if method == "jax" else c.fourier_wigner_kernel(L) ) - f_check = transform(flmn, L, N, delta, reality, sampling) + else: + precomps = None + f_check = transform(flmn, L, N, precomps, reality, sampling) np.testing.assert_allclose(f, f_check.flatten("C"), atol=atol) @@ -91,15 +92,16 @@ def test_forward_fourier_wigner_transform( ) flmn = samples.flmn_1d_to_3d(so3.forward(f, params), L, N) - delta = None transform = fw.forward_transform_jax if method == "jax" else fw.forward_transform if delta_method.lower() == "precomp": - delta = ( + precomps = ( c.fourier_wigner_kernel_jax(L) if method == "jax" else c.fourier_wigner_kernel(L) ) - flmn_check = transform(f_3D, L, N, delta, reality, sampling) + else: + precomps = None + flmn_check = transform(f_3D, L, N, precomps, reality, sampling) np.testing.assert_allclose(flmn, flmn_check, atol=atol) @@ -121,8 +123,8 @@ def test_inverse_fourier_wigner_transform_high_N( f = so3.inverse(samples.flmn_3d_to_1d(flmn, L, N), params) f = f.real if reality else f - delta = c.fourier_wigner_kernel(L) - f_check = fw.inverse_transform(flmn, L, N, delta, reality, sampling) + precomps = c.fourier_wigner_kernel(L) + f_check = fw.inverse_transform(flmn, L, N, precomps, reality, sampling) np.testing.assert_allclose(f, f_check.flatten("C"), atol=atol) @@ -151,6 +153,6 @@ def test_forward_fourier_wigner_transform_high_N( ) flmn_so3 = samples.flmn_1d_to_3d(so3.forward(f_1D, params), L, N) - delta = c.fourier_wigner_kernel_jax(L) - flmn_check = fw.forward_transform_jax(f_3D, L, N, delta, reality, sampling) + precomps = c.fourier_wigner_kernel_jax(L) + flmn_check = fw.forward_transform_jax(f_3D, L, N, precomps, reality, sampling) np.testing.assert_allclose(flmn_so3, flmn_check, atol=atol) diff --git a/tests/test_lifting_transforms.py b/tests/test_lifting_transforms.py index 5e36f560..7e3dafc7 100644 --- a/tests/test_lifting_transforms.py +++ b/tests/test_lifting_transforms.py @@ -27,8 +27,8 @@ def test_custom_forward_from_s2( ): # GENERATE MOCK SIGNAL flmn = flmn_generator(L=L, N=N) - DW = c.fourier_wigner_kernel(L) - f = fw.inverse_transform(flmn, L, N, DW, False, sampling) + precomps = c.fourier_wigner_kernel(L) + f = fw.inverse_transform(flmn, L, N, precomps, False, sampling) spins = -np.arange(-N + 1, N) # FUNCTION SWITCH @@ -42,27 +42,27 @@ def test_custom_forward_from_s2( fn = fn.reshape((1,) + fn.shape + (1,)) # TEST: ALL UNIQUE SPINS - flmn_test = func(fn, spins, DW, L, sampling) + flmn_test = func(fn, spins, precomps, L, sampling) np.testing.assert_allclose(flmn, np.squeeze(flmn_test), atol=atol) # TEST: A SINGLE SPIN - flmn_test = func(fn[:, [0]], spins[[0]], DW, L, sampling) + flmn_test = func(fn[:, [0]], spins[[0]], precomps, L, sampling) np.testing.assert_allclose(flmn[0], np.squeeze(flmn_test), atol=atol) # TEST: SUBSET OF SPINS - flmn_test = func(fn[:, ::2], spins[::2], DW, L, sampling) + flmn_test = func(fn[:, ::2], spins[::2], precomps, L, sampling) np.testing.assert_allclose(flmn[::2], np.squeeze(flmn_test), atol=atol) # TEST: REPEATED SPINS fn_repeat = np.concatenate([fn, fn], axis=1) spins_repeat = np.concatenate([spins, spins]) - flmn_test = func(fn_repeat, spins_repeat, DW, L, sampling) + flmn_test = func(fn_repeat, spins_repeat, precomps, L, sampling) np.testing.assert_allclose(flmn, np.squeeze(flmn_test[:, : len(spins)]), atol=atol) np.testing.assert_allclose(flmn, np.squeeze(flmn_test[:, len(spins) :]), atol=atol) # TEST: SIMULATED BATCHING fnb = np.concatenate([fn, fn], axis=0) - flmn_test = func(fnb, spins, DW, L, sampling) + flmn_test = func(fnb, spins, precomps, L, sampling) for b in range(2): np.testing.assert_allclose(flmn, np.squeeze(flmn_test[b]), atol=atol) @@ -76,8 +76,8 @@ def test_custom_forward_from_so3( ): # GENERATE MOCK SIGNAL flmn = flmn_generator(L=L, N=N) - DW = c.fourier_wigner_kernel(L) - f = fw.inverse_transform(flmn, L, N, DW, False, sampling) + precomps = c.fourier_wigner_kernel(L) + f = fw.inverse_transform(flmn, L, N, precomps, False, sampling) spins = -np.arange(-N + 1, N) # FUNCTION SWITCH @@ -87,26 +87,26 @@ def test_custom_forward_from_so3( f = f.reshape((1,) + f.shape + (1,)) # TEST: ALL UNIQUE SPINS - flmn_test = func(f, spins, DW, L, N, sampling) + flmn_test = func(f, spins, precomps, L, N, sampling) np.testing.assert_allclose(flmn, np.squeeze(flmn_test), atol=atol) # TEST: A SINGLE SPIN - flmn_test = func(f, spins[[0]], DW, L, N, sampling) + flmn_test = func(f, spins[[0]], precomps, L, N, sampling) np.testing.assert_allclose(flmn[0], np.squeeze(flmn_test), atol=atol) # TEST: SUBSET OF SPINS - flmn_test = func(f, spins[::2], DW, L, N, sampling) + flmn_test = func(f, spins[::2], precomps, L, N, sampling) np.testing.assert_allclose(flmn[::2], np.squeeze(flmn_test), atol=atol) # TEST: REPEATED SPINS spins_repeat = np.concatenate([spins, spins]) - flmn_test = func(f, spins_repeat, DW, L, N, sampling) + flmn_test = func(f, spins_repeat, precomps, L, N, sampling) np.testing.assert_allclose(flmn, np.squeeze(flmn_test[:, : len(spins)]), atol=atol) np.testing.assert_allclose(flmn, np.squeeze(flmn_test[:, len(spins) :]), atol=atol) # TEST: SIMULATED BATCHING fb = np.concatenate([f, f], axis=0) - flmn_test = func(fb, spins, DW, L, N, sampling) + flmn_test = func(fb, spins, precomps, L, N, sampling) for b in range(2): np.testing.assert_allclose(flmn, np.squeeze(flmn_test[b]), atol=atol) @@ -120,8 +120,8 @@ def test_custom_inverse_to_s2( ): # GENERATE MOCK SIGNAL flmn = flmn_generator(L=L, N=N) - DW = c.fourier_wigner_kernel(L) - f = fw.inverse_transform(flmn, L, N, DW, False, sampling) + precomps = c.fourier_wigner_kernel(L) + f = fw.inverse_transform(flmn, L, N, precomps, False, sampling) spins = -np.arange(-N + 1, N) # FUNCTION SWITCH @@ -135,26 +135,26 @@ def test_custom_inverse_to_s2( flmn = flmn.reshape((1,) + flmn.shape + (1,)) # TEST: ALL UNIQUE SPINS - f_test = func(flmn, spins, DW, L, sampling) + f_test = func(flmn, spins, precomps, L, sampling) np.testing.assert_allclose(fn, np.squeeze(f_test), atol=atol) # TEST: A SINGLE SPIN - f_test = func(flmn[:, [0]], spins[[0]], DW, L, sampling) + f_test = func(flmn[:, [0]], spins[[0]], precomps, L, sampling) np.testing.assert_allclose(fn[0], np.squeeze(f_test), atol=atol) # TEST: SUBSET OF SPINS - f_test = func(flmn[:, ::2], spins[::2], DW, L, sampling) + f_test = func(flmn[:, ::2], spins[::2], precomps, L, sampling) np.testing.assert_allclose(fn[::2], np.squeeze(f_test), atol=atol) # TEST: REPEATED SPINS flmn_repeat = np.concatenate([flmn, flmn], axis=1) spins_repeat = np.concatenate([spins, spins]) - f_test = func(flmn_repeat, spins_repeat, DW, L, sampling) + f_test = func(flmn_repeat, spins_repeat, precomps, L, sampling) np.testing.assert_allclose(fn, np.squeeze(f_test[:, : len(spins)]), atol=atol) np.testing.assert_allclose(fn, np.squeeze(f_test[:, len(spins) :]), atol=atol) # TEST: SIMULATED BATCHING flmnb = np.concatenate([flmn, flmn], axis=0) - f_test = func(flmnb, spins, DW, L, sampling) + f_test = func(flmnb, spins, precomps, L, sampling) for b in range(2): np.testing.assert_allclose(fn, np.squeeze(f_test[b]), atol=atol)