Skip to content

Commit dc60db6

Browse files
veni-vidi-vici-dormivipre-commit-ci[bot]mathause
authored
expand harmonic model xarray test (#458)
* add more sophisticated xarray test * delete Shrutis test file * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Mathias Hauser <[email protected]>
1 parent c5df229 commit dc60db6

File tree

2 files changed

+53
-268
lines changed

2 files changed

+53
-268
lines changed

mesmer/mesmer_m/tests_harmonic_model.py

-253
This file was deleted.

tests/unit/test_harmonic_model.py

+53-15
Original file line numberDiff line numberDiff line change
@@ -172,23 +172,51 @@ def test_fit_to_bic_numerical_stability():
172172
np.testing.assert_allclose(predictions, expected_predictions)
173173

174174

175-
@pytest.mark.parametrize(
176-
"coefficients",
177-
[
178-
np.array([0, -1, 0, -2]),
179-
np.array([1, 2, 3, 4, 5, 6, 7, 8]),
180-
np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]),
181-
],
182-
)
183-
def test_fit_to_bic_xr(coefficients):
184-
yearly_predictor = trend_data_2D(n_timesteps=10, n_lat=3, n_lon=2)
175+
def get_2D_coefficients(order_per_cell, n_lat=3, n_lon=2):
176+
n_cells = n_lat * n_lon
177+
max_order = 6
178+
179+
# generate coefficients that resemble real ones
180+
# generate rapidly decreasing coefficients for increasing orders
181+
trend = np.repeat(np.linspace(1.2, 0.2, max_order) ** 2, 4)
182+
# the first coefficients are rather small (scaling of seasonal variability with temperature change)
183+
# while the second ones are large (constant distance of each month from the yearly mean)
184+
scale = np.tile([0.01, 5.0], (n_cells, max_order * 2))
185+
# generate some variability so not all coefficients are exactly the same
186+
rng = np.random.default_rng(0)
187+
variability = rng.normal(loc=0, scale=0.1, size=(n_cells, max_order * 4))
188+
# put it together
189+
coeffs = trend * scale + variability
190+
coeffs = np.round(coeffs, 1)
191+
192+
# replace superfluous orders with nans
193+
for cell, order in enumerate(order_per_cell):
194+
coeffs[cell, order * 4 :] = np.nan
195+
196+
LON, LAT = np.meshgrid(np.arange(n_lon), np.arange(n_lat))
197+
198+
coords = {
199+
"lon": ("cells", LON.flatten()),
200+
"lat": ("cells", LAT.flatten()),
201+
}
202+
203+
return xr.DataArray(coeffs, dims=("cells", "coeff"), coords=coords)
204+
205+
206+
def test_fit_to_bic_xr():
207+
n_ts = 10
208+
orders = [1, 2, 3, 4, 5, 6]
209+
210+
coefficients = get_2D_coefficients(order_per_cell=orders, n_lat=3, n_lon=2)
211+
212+
yearly_predictor = trend_data_2D(n_timesteps=n_ts, n_lat=3, n_lon=2)
185213

186214
freq = "AS" if Version(pd.__version__) < Version("2.2") else "YS"
187215
yearly_predictor["time"] = xr.cftime_range(
188-
start="2000-01-01", periods=10, freq=freq
216+
start="2000-01-01", periods=n_ts, freq=freq
189217
)
190218

191-
time = xr.cftime_range(start="2000-01-01", periods=10 * 12, freq="MS")
219+
time = xr.cftime_range(start="2000-01-01", periods=n_ts * 12, freq="MS")
192220
monthly_time = xr.DataArray(
193221
time,
194222
dims=["time"],
@@ -200,17 +228,27 @@ def test_fit_to_bic_xr(coefficients):
200228
monthly_target = xr.apply_ufunc(
201229
generate_fourier_series_np,
202230
upsampled_yearly_predictor,
203-
input_core_dims=[["time"]],
231+
coefficients,
232+
input_core_dims=[["time"], ["coeff"]],
204233
output_core_dims=[["time"]],
205234
vectorize=True,
206235
output_dtypes=[float],
207-
kwargs={"coeffs": coefficients, "months": months},
236+
kwargs={"months": months},
208237
)
209238

239+
# test if the model can recover the monthly target from perfect fourier series
210240
result = fit_to_bic_xr(yearly_predictor, monthly_target)
211-
241+
np.testing.assert_equal(result.n_sel.values, orders)
212242
xr.testing.assert_allclose(result["predictions"], monthly_target, atol=0.1)
213243

244+
# test if the model can recover the underlying cycle with noise on top of monthly target
245+
rng = np.random.default_rng(0)
246+
noisy_monthly_target = monthly_target + rng.normal(
247+
loc=0, scale=0.1, size=monthly_target.values.shape
248+
)
249+
result = fit_to_bic_xr(yearly_predictor, noisy_monthly_target)
250+
xr.testing.assert_allclose(result["predictions"], monthly_target, atol=0.2)
251+
214252

215253
def test_fit_to_bix_xr_instance_checks():
216254
yearly_predictor = trend_data_2D(n_timesteps=10, n_lat=3, n_lon=2)

0 commit comments

Comments
 (0)