@@ -172,23 +172,51 @@ def test_fit_to_bic_numerical_stability():
172
172
np .testing .assert_allclose (predictions , expected_predictions )
173
173
174
174
175
- @pytest .mark .parametrize (
176
- "coefficients" ,
177
- [
178
- np .array ([0 , - 1 , 0 , - 2 ]),
179
- np .array ([1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 ]),
180
- np .array ([1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 ]),
181
- ],
182
- )
183
- def test_fit_to_bic_xr (coefficients ):
184
- yearly_predictor = trend_data_2D (n_timesteps = 10 , n_lat = 3 , n_lon = 2 )
175
+ def get_2D_coefficients (order_per_cell , n_lat = 3 , n_lon = 2 ):
176
+ n_cells = n_lat * n_lon
177
+ max_order = 6
178
+
179
+ # generate coefficients that resemble real ones
180
+ # generate rapidly decreasing coefficients for increasing orders
181
+ trend = np .repeat (np .linspace (1.2 , 0.2 , max_order ) ** 2 , 4 )
182
+ # the first coefficients are rather small (scaling of seasonal variability with temperature change)
183
+ # while the second ones are large (constant distance of each month from the yearly mean)
184
+ scale = np .tile ([0.01 , 5.0 ], (n_cells , max_order * 2 ))
185
+ # generate some variability so not all coefficients are exactly the same
186
+ rng = np .random .default_rng (0 )
187
+ variability = rng .normal (loc = 0 , scale = 0.1 , size = (n_cells , max_order * 4 ))
188
+ # put it together
189
+ coeffs = trend * scale + variability
190
+ coeffs = np .round (coeffs , 1 )
191
+
192
+ # replace superfluous orders with nans
193
+ for cell , order in enumerate (order_per_cell ):
194
+ coeffs [cell , order * 4 :] = np .nan
195
+
196
+ LON , LAT = np .meshgrid (np .arange (n_lon ), np .arange (n_lat ))
197
+
198
+ coords = {
199
+ "lon" : ("cells" , LON .flatten ()),
200
+ "lat" : ("cells" , LAT .flatten ()),
201
+ }
202
+
203
+ return xr .DataArray (coeffs , dims = ("cells" , "coeff" ), coords = coords )
204
+
205
+
206
+ def test_fit_to_bic_xr ():
207
+ n_ts = 10
208
+ orders = [1 , 2 , 3 , 4 , 5 , 6 ]
209
+
210
+ coefficients = get_2D_coefficients (order_per_cell = orders , n_lat = 3 , n_lon = 2 )
211
+
212
+ yearly_predictor = trend_data_2D (n_timesteps = n_ts , n_lat = 3 , n_lon = 2 )
185
213
186
214
freq = "AS" if Version (pd .__version__ ) < Version ("2.2" ) else "YS"
187
215
yearly_predictor ["time" ] = xr .cftime_range (
188
- start = "2000-01-01" , periods = 10 , freq = freq
216
+ start = "2000-01-01" , periods = n_ts , freq = freq
189
217
)
190
218
191
- time = xr .cftime_range (start = "2000-01-01" , periods = 10 * 12 , freq = "MS" )
219
+ time = xr .cftime_range (start = "2000-01-01" , periods = n_ts * 12 , freq = "MS" )
192
220
monthly_time = xr .DataArray (
193
221
time ,
194
222
dims = ["time" ],
@@ -200,17 +228,27 @@ def test_fit_to_bic_xr(coefficients):
200
228
monthly_target = xr .apply_ufunc (
201
229
generate_fourier_series_np ,
202
230
upsampled_yearly_predictor ,
203
- input_core_dims = [["time" ]],
231
+ coefficients ,
232
+ input_core_dims = [["time" ], ["coeff" ]],
204
233
output_core_dims = [["time" ]],
205
234
vectorize = True ,
206
235
output_dtypes = [float ],
207
- kwargs = {"coeffs" : coefficients , " months" : months },
236
+ kwargs = {"months" : months },
208
237
)
209
238
239
+ # test if the model can recover the monthly target from perfect fourier series
210
240
result = fit_to_bic_xr (yearly_predictor , monthly_target )
211
-
241
+ np . testing . assert_equal ( result . n_sel . values , orders )
212
242
xr .testing .assert_allclose (result ["predictions" ], monthly_target , atol = 0.1 )
213
243
244
+ # test if the model can recover the underlying cycle with noise on top of monthly target
245
+ rng = np .random .default_rng (0 )
246
+ noisy_monthly_target = monthly_target + rng .normal (
247
+ loc = 0 , scale = 0.1 , size = monthly_target .values .shape
248
+ )
249
+ result = fit_to_bic_xr (yearly_predictor , noisy_monthly_target )
250
+ xr .testing .assert_allclose (result ["predictions" ], monthly_target , atol = 0.2 )
251
+
214
252
215
253
def test_fit_to_bix_xr_instance_checks ():
216
254
yearly_predictor = trend_data_2D (n_timesteps = 10 , n_lat = 3 , n_lon = 2 )
0 commit comments