1+ from __future__ import annotations
2+
13from functools import partial
24
35import jax .numpy as jnp
@@ -12,7 +14,7 @@ def inverse_transform(
1214 flmn : np .ndarray ,
1315 L : int ,
1416 N : int ,
15- DW : np .ndarray = None ,
17+ precomps : tuple [ np .ndarray , np . ndarray ] | None = None ,
1618 reality : bool = False ,
1719 sampling : str = "mw" ,
1820) -> np .ndarray :
@@ -23,9 +25,9 @@ def inverse_transform(
2325 flmn (np.ndarray): Wigner coefficients.
2426 L (int): Harmonic band-limit.
2527 N (int): Azimuthal band-limit.
26- DW (Tuple [np.ndarray, np.ndarray], optional): Fourier coefficients of the reduced
27- Wigner d-functions and the corresponding upsampled quadrature weights.
28- Defaults to None.
28+ precomps (tuple [np.ndarray, np.ndarray], optional): Fourier coefficients of the
29+ reduced Wigner d-functions and the corresponding upsampled quadrature
30+ weights. Defaults to None.
2931 reality (bool, optional): Whether the signal on the sphere is real. If so,
3032 conjugate symmetry is exploited to reduce computational costs.
3133 Defaults to False.
@@ -53,28 +55,28 @@ def inverse_transform(
5355 m = np .arange (- L + 1 - m_offset , L )
5456 n = np .arange (n_start_ind - N + 1 , N )
5557
56- # Calculate fmna = i^(n-m)\sum_L Delta ^l_am Delta ^l_an f^l_mn(2l+1)/(8pi^2)
58+ # Calculate fmna = i^(n-m)\sum_L delta ^l_am delta ^l_an f^l_mn(2l+1)/(8pi^2)
5759 x = np .zeros ((xnlm_size , xnlm_size , n_dim ), dtype = flmn .dtype )
5860 flmn = np .einsum ("nlm,l->nlm" , flmn , (2 * np .arange (L ) + 1 ) / (8 * np .pi ** 2 ))
5961
6062 # PRECOMPUTE TRANSFORM
61- if DW is not None :
62- Delta , _ = DW
63+ if precomps is not None :
64+ delta , _ = precomps
6365 x = np .zeros ((xnlm_size , xnlm_size , n_dim ), dtype = flmn .dtype )
6466 x [m_offset :, m_offset :] = np .einsum (
65- "nlm,lam,lan->amn" , flmn [n_start_ind :], Delta , Delta [:, :, L - 1 + n ]
67+ "nlm,lam,lan->amn" , flmn [n_start_ind :], delta , delta [:, :, L - 1 + n ]
6668 )
6769
6870 # OTF TRANSFORM
6971 else :
70- Delta_el = np .zeros ((2 * L - 1 , 2 * L - 1 ), dtype = np .float64 )
72+ delta_el = np .zeros ((2 * L - 1 , 2 * L - 1 ), dtype = np .float64 )
7173 for el in range (L ):
72- Delta_el = recursions .risbo .compute_full (Delta_el , np .pi / 2 , L , el )
74+ delta_el = recursions .risbo .compute_full (delta_el , np .pi / 2 , L , el )
7375 x [m_offset :, m_offset :] += np .einsum (
7476 "nm,am,an->amn" ,
7577 flmn [n_start_ind :, el ],
76- Delta_el ,
77- Delta_el [:, L - 1 + n ],
78+ delta_el ,
79+ delta_el [:, L - 1 + n ],
7880 )
7981
8082 # APPLY SIGN FUNCTION AND PHASE SHIFT
@@ -97,7 +99,7 @@ def inverse_transform_jax(
9799 flmn : jnp .ndarray ,
98100 L : int ,
99101 N : int ,
100- DW : jnp .ndarray = None ,
102+ precomps : tuple [ jnp .ndarray , jnp . ndarray ] | None = None ,
101103 reality : bool = False ,
102104 sampling : str = "mw" ,
103105) -> jnp .ndarray :
@@ -108,9 +110,9 @@ def inverse_transform_jax(
108110 flmn (jnp.ndarray): Wigner coefficients.
109111 L (int): Harmonic band-limit.
110112 N (int): Azimuthal band-limit.
111- DW (Tuple [np.ndarray, np.ndarray], optional): Fourier coefficients of the reduced
112- Wigner d-functions and the corresponding upsampled quadrature weights.
113- Defaults to None.
113+ precomps (tuple [np.ndarray, np.ndarray], optional): Fourier coefficients of the
114+ reduced Wigner d-functions and the corresponding upsampled quadrature
115+ weights. Defaults to None.
114116 reality (bool, optional): Whether the signal on the sphere is real. If so,
115117 conjugate symmetry is exploited to reduce computational costs.
116118 Defaults to False.
@@ -138,30 +140,30 @@ def inverse_transform_jax(
138140 m = jnp .arange (- L + 1 - m_offset , L )
139141 n = jnp .arange (n_start_ind - N + 1 , N )
140142
141- # Calculate fmna = i^(n-m)\sum_L Delta ^l_am Delta ^l_an f^l_mn(2l+1)/(8pi^2)
143+ # Calculate fmna = i^(n-m)\sum_L delta ^l_am delta ^l_an f^l_mn(2l+1)/(8pi^2)
142144 x = jnp .zeros ((xnlm_size , xnlm_size , n_dim ), dtype = jnp .complex128 )
143145 flmn = jnp .einsum ("nlm,l->nlm" , flmn , (2 * jnp .arange (L ) + 1 ) / (8 * jnp .pi ** 2 ))
144146
145147 # PRECOMPUTE TRANSFORM
146- if DW is not None :
147- Delta , _ = DW
148+ if precomps is not None :
149+ delta , _ = precomps
148150 x = x .at [m_offset :, m_offset :].set (
149151 jnp .einsum (
150- "nlm,lam,lan->amn" , flmn [n_start_ind :], Delta , Delta [:, :, L - 1 + n ]
152+ "nlm,lam,lan->amn" , flmn [n_start_ind :], delta , delta [:, :, L - 1 + n ]
151153 )
152154 )
153155
154156 # OTF TRANSFORM
155157 else :
156- Delta_el = jnp .zeros ((2 * L - 1 , 2 * L - 1 ), dtype = jnp .float64 )
158+ delta_el = jnp .zeros ((2 * L - 1 , 2 * L - 1 ), dtype = jnp .float64 )
157159 for el in range (L ):
158- Delta_el = recursions .risbo_jax .compute_full (Delta_el , jnp .pi / 2 , L , el )
160+ delta_el = recursions .risbo_jax .compute_full (delta_el , jnp .pi / 2 , L , el )
159161 x = x .at [m_offset :, m_offset :].add (
160162 jnp .einsum (
161163 "nm,am,an->amn" ,
162164 flmn [n_start_ind :, el ],
163- Delta_el ,
164- Delta_el [:, L - 1 + n ],
165+ delta_el ,
166+ delta_el [:, L - 1 + n ],
165167 )
166168 )
167169
@@ -184,7 +186,7 @@ def forward_transform(
184186 f : np .ndarray ,
185187 L : int ,
186188 N : int ,
187- DW : np .ndarray = None ,
189+ precomps : tuple [ np .ndarray , np . ndarray ] | None = None ,
188190 reality : bool = False ,
189191 sampling : str = "mw" ,
190192) -> np .ndarray :
@@ -195,9 +197,9 @@ def forward_transform(
195197 f (np.ndarray): Function sampled on the rotation group.
196198 L (int): Harmonic band-limit.
197199 N (int): Azimuthal band-limit.
198- DW (Tuple [np.ndarray, np.ndarray], optional): Fourier coefficients of the reduced
199- Wigner d-functions and the corresponding upsampled quadrature weights.
200- Defaults to None.
200+ precomps (tuple [np.ndarray, np.ndarray], optional): Fourier coefficients of the
201+ reduced Wigner d-functions and the corresponding upsampled quadrature
202+ weights. Defaults to None.
201203 reality (bool, optional): Whether the signal on the sphere is real. If so,
202204 conjugate symmetry is exploited to reduce computational costs.
203205 Defaults to False.
@@ -253,43 +255,38 @@ def forward_transform(
253255 # the weights are conjugate but applied flipped and therefore are
254256 # equivalent. To avoid flipping here we simply conjugate the weights.
255257
256- # PRECOMPUTE TRANSFORM
257- if DW is not None :
258- # EXTRACT VARIOUS PRECOMPUTES
259- Delta , Quads = DW
260-
261- # APPLY QUADRATURE
262- x = np .einsum ("nbm,b->nbm" , x , Quads )
263-
264- # COMPUTE GMM BY FFT
265- x = np .fft .fft (x , axis = 1 , norm = "forward" )
266- x = np .fft .fftshift (x , axes = 1 )[:, L - 1 : 3 * L - 2 ]
267-
268- # CALCULATE flmn = i^(n-m)\sum_t Delta^l_tm Delta^l_tn G_mnt
269- x = np .einsum ("nam,lam,lan->nlm" , x , Delta , Delta [:, :, L - 1 + n ])
270-
271- # OTF TRANSFORM
258+ if precomps is not None :
259+ # PRECOMPUTE TRANSFORM
260+ delta , quads = precomps
272261 else :
262+ # OTF TRANSFORM
263+ delta = None
273264 # COMPUTE QUADRATURE WEIGHTS
274- Quads = np .zeros (4 * L - 3 , dtype = np .complex128 )
265+ quads = np .zeros (4 * L - 3 , dtype = np .complex128 )
275266 for mm in range (- 2 * (L - 1 ), 2 * (L - 1 ) + 1 ):
276- Quads [mm + 2 * (L - 1 )] = quadrature .mw_weights (- mm )
277- Quads = np .fft .ifft (np .fft .ifftshift (Quads ), norm = "forward" )
267+ quads [mm + 2 * (L - 1 )] = quadrature .mw_weights (- mm )
268+ quads = np .fft .ifft (np .fft .ifftshift (quads ), norm = "forward" )
278269
279- # APPLY QUADRATURE
280- x = np .einsum ("nbm,b->nbm" , x , Quads )
270+ # APPLY QUADRATURE
271+ x = np .einsum ("nbm,b->nbm" , x , quads )
281272
282- # COMPUTE GMM BY FFT
283- x = np .fft .fft (x , axis = 1 , norm = "forward" )
284- x = np .fft .fftshift (x , axes = 1 )[:, L - 1 : 3 * L - 2 ]
273+ # COMPUTE GMM BY FFT
274+ x = np .fft .fft (x , axis = 1 , norm = "forward" )
275+ x = np .fft .fftshift (x , axes = 1 )[:, L - 1 : 3 * L - 2 ]
285276
286- # CALCULATE flmn = i^(n-m)\sum_t Delta^l_tm Delta^l_tn G_mnt
287- Delta_el = np .zeros ((2 * L - 1 , 2 * L - 1 ), dtype = np .float64 )
277+ # CALCULATE flmn = i^(n-m)\sum_t delta^l_tm delta^l_tn G_mnt
278+ if delta is not None :
279+ # PRECOMPUTE TRANSFORM
280+ x = np .einsum ("nam,lam,lan->nlm" , x , delta , delta [:, :, L - 1 + n ])
281+ else :
282+ # OTF TRANSFORM
283+ delta_el = np .zeros ((2 * L - 1 , 2 * L - 1 ), dtype = np .float64 )
288284 xx = np .zeros ((x .shape [0 ], L , x .shape [- 1 ]), dtype = x .dtype )
289285 for el in range (L ):
290- Delta_el = recursions .risbo .compute_full (Delta_el , np .pi / 2 , L , el )
291- xx [:, el ] = np .einsum ("nam,am,an->nm" , x , Delta_el , Delta_el [:, L - 1 + n ])
286+ delta_el = recursions .risbo .compute_full (delta_el , np .pi / 2 , L , el )
287+ xx [:, el ] = np .einsum ("nam,am,an->nm" , x , delta_el , delta_el [:, L - 1 + n ])
292288 x = xx
289+
293290 x = np .einsum ("nbm,m,n->nbm" , x , 1j ** (m ), 1j ** (- n ))
294291
295292 # SYMMETRY REFLECT FOR N < 0
@@ -310,7 +307,7 @@ def forward_transform_jax(
310307 f : jnp .ndarray ,
311308 L : int ,
312309 N : int ,
313- DW : jnp .ndarray = None ,
310+ precomps : tuple [ jnp .ndarray , jnp . ndarray ] | None = None ,
314311 reality : bool = False ,
315312 sampling : str = "mw" ,
316313) -> jnp .ndarray :
@@ -321,9 +318,9 @@ def forward_transform_jax(
321318 f (jnp.ndarray): Function sampled on the rotation group.
322319 L (int): Harmonic band-limit.
323320 N (int): Azimuthal band-limit.
324- DW (Tuple [np.ndarray, np.ndarray], optional): Fourier coefficients of the reduced
325- Wigner d-functions and the corresponding upsampled quadrature weights.
326- Defaults to None.
321+ precomps (tuple [np.ndarray, np.ndarray], optional): Fourier coefficients of the
322+ reduced Wigner d-functions and the corresponding upsampled quadrature
323+ weights. Defaults to None.
327324 reality (bool, optional): Whether the signal on the sphere is real. If so,
328325 conjugate symmetry is exploited to reduce computational costs.
329326 Defaults to False.
@@ -379,41 +376,37 @@ def forward_transform_jax(
379376 # the weights are conjugate but applied flipped and therefore are
380377 # equivalent. To avoid flipping here we simply conjugate the weights.
381378
382- # PRECOMPUTE TRANSFORM
383- if DW is not None :
384- # EXTRACT VARIOUS PRECOMPUTES
385- Delta , Quads = DW
386-
387- # APPLY QUADRATURE
388- x = jnp .einsum ("nbm,b->nbm" , x , Quads )
389-
390- # COMPUTE GMM BY FFT
391- x = jnp .fft .fft (x , axis = 1 , norm = "forward" )
392- x = jnp .fft .fftshift (x , axes = 1 )[:, L - 1 : 3 * L - 2 ]
393-
394- # Calculate flmn = i^(n-m)\sum_t Delta^l_tm Delta^l_tn G_mnt
395- x = jnp .einsum ("nam,lam,lan->nlm" , x , Delta , Delta [:, :, L - 1 + n ])
396-
379+ if precomps is not None :
380+ # PRECOMPUTE TRANSFORM
381+ delta , quads = precomps
397382 else :
398- Quads = jnp .zeros (4 * L - 3 , dtype = jnp .complex128 )
383+ # OTF TRANSFORM
384+ delta = None
385+ # COMPUTE QUADRATURE WEIGHTS
386+ quads = jnp .zeros (4 * L - 3 , dtype = jnp .complex128 )
399387 for mm in range (- 2 * (L - 1 ), 2 * (L - 1 ) + 1 ):
400- Quads = Quads .at [mm + 2 * (L - 1 )].set (quadrature_jax .mw_weights (- mm ))
401- Quads = jnp .fft .ifft (jnp .fft .ifftshift (Quads ), norm = "forward" )
388+ quads = quads .at [mm + 2 * (L - 1 )].set (quadrature_jax .mw_weights (- mm ))
389+ quads = jnp .fft .ifft (jnp .fft .ifftshift (quads ), norm = "forward" )
402390
403- # APPLY QUADRATURE
404- x = jnp .einsum ("nbm,b->nbm" , x , Quads )
391+ # APPLY QUADRATURE
392+ x = jnp .einsum ("nbm,b->nbm" , x , quads )
405393
406- # COMPUTE GMM BY FFT
407- x = jnp .fft .fft (x , axis = 1 , norm = "forward" )
408- x = jnp .fft .fftshift (x , axes = 1 )[:, L - 1 : 3 * L - 2 ]
394+ # COMPUTE GMM BY FFT
395+ x = jnp .fft .fft (x , axis = 1 , norm = "forward" )
396+ x = jnp .fft .fftshift (x , axes = 1 )[:, L - 1 : 3 * L - 2 ]
409397
410- # CALCULATE flmn = i^(n-m)\sum_t Delta^l_tm Delta^l_tn G_mnt
411- Delta_el = jnp .zeros ((2 * L - 1 , 2 * L - 1 ), dtype = jnp .float64 )
398+ # Calculate flmn = i^(n-m)\sum_t delta^l_tm delta^l_tn G_mnt
399+ if delta is not None :
400+ # PRECOMPUTE TRANSFORM
401+ x = jnp .einsum ("nam,lam,lan->nlm" , x , delta , delta [:, :, L - 1 + n ])
402+ else :
403+ # OTF TRANSFORM
404+ delta_el = jnp .zeros ((2 * L - 1 , 2 * L - 1 ), dtype = jnp .float64 )
412405 xx = jnp .zeros ((x .shape [0 ], L , x .shape [- 1 ]), dtype = x .dtype )
413406 for el in range (L ):
414- Delta_el = recursions .risbo_jax .compute_full (Delta_el , jnp .pi / 2 , L , el )
407+ delta_el = recursions .risbo_jax .compute_full (delta_el , jnp .pi / 2 , L , el )
415408 xx = xx .at [:, el ].set (
416- jnp .einsum ("nam,am,an->nm" , x , Delta_el , Delta_el [:, L - 1 + n ])
409+ jnp .einsum ("nam,am,an->nm" , x , delta_el , delta_el [:, L - 1 + n ])
417410 )
418411 x = xx
419412
0 commit comments