Skip to content
8 changes: 6 additions & 2 deletions lib/iris/analysis/_area_weighted.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ def __init__(self, src_grid_cube, target_grid_cube, mdtol=1):
# current usage of the experimental regrid function.
self._target_grid_cube_cache = None

self._regrid_info = eregrid._regrid_area_weighted_rectilinear_src_and_grid__prepare(
src_grid_cube, self._target_grid_cube
)

@property
def _target_grid_cube(self):
if self._target_grid_cube_cache is None:
Expand Down Expand Up @@ -97,6 +101,6 @@ def __call__(self, cube):
"The given cube is not defined on the same "
"source grid as this regridder."
)
return eregrid.regrid_area_weighted_rectilinear_src_and_grid(
cube, self._target_grid_cube, mdtol=self._mdtol
return eregrid._regrid_area_weighted_rectilinear_src_and_grid__perform(
cube, self._regrid_info, mdtol=self._mdtol
)
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,17 @@ def extract_grid(self, cube):

def check_mdtol(self, mdtol=None):
src_grid, target_grid = self.grids()
if mdtol is None:
regridder = AreaWeightedRegridder(src_grid, target_grid)
mdtol = 1
else:
regridder = AreaWeightedRegridder(
src_grid, target_grid, mdtol=mdtol
)
with mock.patch(
"iris.experimental.regrid."
"_regrid_area_weighted_rectilinear_src_and_grid__prepare"
) as prepare:
if mdtol is None:
regridder = AreaWeightedRegridder(src_grid, target_grid)
mdtol = 1
else:
regridder = AreaWeightedRegridder(
src_grid, target_grid, mdtol=mdtol
)

# Make a new cube to regrid with different data so we can
# distinguish between regridding the original src grid
Expand All @@ -58,18 +62,22 @@ def check_mdtol(self, mdtol=None):

with mock.patch(
"iris.experimental.regrid."
"regrid_area_weighted_rectilinear_src_and_grid",
"_regrid_area_weighted_rectilinear_src_and_grid__perform",
return_value=mock.sentinel.result,
) as regrid:
) as perform:
result = regridder(src)

self.assertEqual(regrid.call_count, 1)
_, args, kwargs = regrid.mock_calls[0]

self.assertEqual(args[0], src)
# Prepare:
self.assertEqual(prepare.call_count, 1)
_, args, kwargs = prepare.mock_calls[0]
self.assertEqual(
self.extract_grid(args[1]), self.extract_grid(target_grid)
)

# Perform:
self.assertEqual(perform.call_count, 1)
_, args, kwargs = perform.mock_calls[0]
self.assertEqual(args[0], src)
self.assertEqual(kwargs, {"mdtol": mdtol})
self.assertIs(result, mock.sentinel.result)

Expand Down