diff --git a/CHANGELOG.md b/CHANGELOG.md index c5daaeec..52a60ab8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,11 @@ and this project adheres to [Semantic Versioning](http://semver.org/). benchmarks. [@stephenworsley](https://github.com/stephenworsley) +### Fixed +- [PR#239](https://github.com/SciTools-incubator/iris-esmf-regrid/pull/239) + Ensured dtype is preserved by regridding. + [@stephenworsley](https://github.com/stephenworsley) + ## [0.9] - 2023-11-03 ### Added diff --git a/esmf_regrid/esmf_regridder.py b/esmf_regrid/esmf_regridder.py index 07d81071..5322c415 100644 --- a/esmf_regrid/esmf_regridder.py +++ b/esmf_regrid/esmf_regridder.py @@ -134,6 +134,12 @@ def __init__( self.esmf_version = None self.weight_matrix = precomputed_weights + def _out_dtype(self, in_dtype): + """Return the expected output dtype for a given input dtype.""" + weight_dtype = self.weight_matrix.dtype + out_dtype = (np.ones(1, dtype=in_dtype) * np.ones(1, dtype=weight_dtype)).dtype + return out_dtype + def regrid(self, src_array, norm_type=Constants.NormType.FRACAREA, mdtol=1): """ Perform regridding on an array of data. @@ -175,12 +181,13 @@ def regrid(self, src_array, norm_type=Constants.NormType.FRACAREA, mdtol=1): extra_size = max(1, np.prod(extra_shape)) src_inverted_mask = self.src._array_to_matrix(~ma.getmaskarray(src_array)) weight_sums = self.weight_matrix @ src_inverted_mask + out_dtype = self._out_dtype(src_array.dtype) # Set the minimum mdtol to be slightly higher than 0 to account for rounding # errors. mdtol = max(mdtol, 1e-8) tgt_mask = weight_sums > 1 - mdtol masked_weight_sums = weight_sums * tgt_mask - normalisations = np.ones([self.tgt.size, extra_size]) + normalisations = np.ones([self.tgt.size, extra_size], dtype=out_dtype) if norm_type == Constants.NormType.FRACAREA: normalisations[tgt_mask] /= masked_weight_sums[tgt_mask] elif norm_type == Constants.NormType.DSTAREA: diff --git a/esmf_regrid/schemes.py b/esmf_regrid/schemes.py index bd60e9f9..2991b36f 100644 --- a/esmf_regrid/schemes.py +++ b/esmf_regrid/schemes.py @@ -275,7 +275,9 @@ def _regrid_along_dims(data, regridder, dims, num_out_dims, mdtol): return result -def _map_complete_blocks(src, func, active_dims, out_sizes, *args, **kwargs): +def _map_complete_blocks( + src, func, active_dims, out_sizes, *args, dtype=None, **kwargs +): """ Apply a function to complete blocks. @@ -299,6 +301,8 @@ def _map_complete_blocks(src, func, active_dims, out_sizes, *args, **kwargs): Dimensions that cannot be chunked. out_sizes : tuple of int Output size of dimensions that cannot be chunked. + dtype : type, optional + Type of the output array, if not given, the dtype of src is used. Returns ------- @@ -311,6 +315,8 @@ def _map_complete_blocks(src, func, active_dims, out_sizes, *args, **kwargs): return func(src.data, *args, **kwargs) data = src.lazy_data() + if dtype is None: + dtype = data.dtype # Ensure dims are not chunked in_chunks = list(data.chunks) @@ -373,7 +379,7 @@ def _map_complete_blocks(src, func, active_dims, out_sizes, *args, **kwargs): chunks=out_chunks, drop_axis=dropped_dims, new_axis=new_axis, - dtype=src.dtype, + dtype=dtype, **kwargs, ) @@ -557,6 +563,8 @@ def _regrid_rectilinear_to_rectilinear__perform(src_cube, regrid_info, mdtol): grid_x, grid_y = regrid_info.target regridder = regrid_info.regridder + out_dtype = regridder._out_dtype(src_cube.dtype) + # Apply regrid to all the chunks of src_cube, ensuring first that all # chunks cover the entire horizontal plane (otherwise they would break # the regrid function). @@ -574,6 +582,7 @@ def _regrid_rectilinear_to_rectilinear__perform(src_cube, regrid_info, mdtol): dims=[grid_x_dim, grid_y_dim], num_out_dims=2, mdtol=mdtol, + dtype=out_dtype, ) new_cube = _create_cube( @@ -636,6 +645,8 @@ def _regrid_unstructured_to_rectilinear__perform(src_cube, regrid_info, mdtol): grid_x, grid_y = regrid_info.target regridder = regrid_info.regridder + out_dtype = regridder._out_dtype(src_cube.dtype) + # Apply regrid to all the chunks of src_cube, ensuring first that all # chunks cover the entire horizontal plane (otherwise they would break # the regrid function). @@ -653,6 +664,7 @@ def _regrid_unstructured_to_rectilinear__perform(src_cube, regrid_info, mdtol): dims=[mesh_dim], num_out_dims=2, mdtol=mdtol, + dtype=out_dtype, ) new_cube = _create_cube( @@ -739,6 +751,8 @@ def _regrid_rectilinear_to_unstructured__perform(src_cube, regrid_info, mdtol): else: raise NotImplementedError(f"Unrecognised location {location}.") + out_dtype = regridder._out_dtype(src_cube.dtype) + # Apply regrid to all the chunks of src_cube, ensuring first that all # chunks cover the entire horizontal plane (otherwise they would break # the regrid function). @@ -751,6 +765,7 @@ def _regrid_rectilinear_to_unstructured__perform(src_cube, regrid_info, mdtol): dims=[grid_x_dim, grid_y_dim], num_out_dims=1, mdtol=mdtol, + dtype=out_dtype, ) new_cube = _create_cube( @@ -823,6 +838,8 @@ def _regrid_unstructured_to_unstructured__perform(src_cube, regrid_info, mdtol): mesh, location = regrid_info.target regridder = regrid_info.regridder + out_dtype = regridder._out_dtype(src_cube.dtype) + if location == "face": face_node = mesh.face_node_connectivity chunk_shape = (face_node.shape[face_node.location_axis],) @@ -840,6 +857,7 @@ def _regrid_unstructured_to_unstructured__perform(src_cube, regrid_info, mdtol): dims=[mesh_dim], num_out_dims=1, mdtol=mdtol, + dtype=out_dtype, ) new_cube = _create_cube( diff --git a/esmf_regrid/tests/conftest.py b/esmf_regrid/tests/conftest.py new file mode 100644 index 00000000..eec32b1d --- /dev/null +++ b/esmf_regrid/tests/conftest.py @@ -0,0 +1,22 @@ +"""Common testing infrastructure.""" + +import pytest + + +@pytest.fixture(params=["float32", "float64"]) +def in_dtype(request): + """Fixture for controlling dtype.""" + return request.param + + +@pytest.fixture( + params=[ + ("grid", "grid"), + ("grid", "mesh"), + ("mesh", "grid"), + ("mesh", "mesh"), + ] +) +def src_tgt_types(request): + """Fixture for controlling type of source and target.""" + return request.param diff --git a/esmf_regrid/tests/unit/esmf_regridder/test_Regridder.py b/esmf_regrid/tests/unit/esmf_regridder/test_Regridder.py index c29d6df8..81be1302 100644 --- a/esmf_regrid/tests/unit/esmf_regridder/test_Regridder.py +++ b/esmf_regrid/tests/unit/esmf_regridder/test_Regridder.py @@ -226,3 +226,29 @@ def _get_points(bounds): (weights_dict["weights"], (weights_dict["row_dst"], weights_dict["col_src"])) ) assert np.allclose(result.toarray(), expected_weights.toarray()) + + +def test_Regridder_dtype_handling(): + """ + Basic test for :meth:`~esmf_regrid.esmf_regridder.Regridder.regrid`. + + Tests that dtype is handled as expected. + """ + lon, lat, lon_bounds, lat_bounds = make_grid_args(2, 3) + src_grid = GridInfo(lon, lat, lon_bounds, lat_bounds) + + lon, lat, lon_bounds, lat_bounds = make_grid_args(3, 2) + tgt_grid = GridInfo(lon, lat, lon_bounds, lat_bounds) + + # Set up the regridder with precomputed weights. + rg_64 = Regridder(src_grid, tgt_grid, precomputed_weights=_expected_weights()) + weights_32 = _expected_weights().astype(np.float32) + rg_32 = Regridder(src_grid, tgt_grid, precomputed_weights=weights_32) + + src_32 = np.ones([3, 2], dtype=np.float32) + src_64 = np.ones([3, 2], dtype=np.float64) + + assert rg_64.regrid(src_64).dtype == np.float64 + assert rg_64.regrid(src_32).dtype == np.float64 + assert rg_32.regrid(src_64).dtype == np.float64 + assert rg_32.regrid(src_32).dtype == np.float32 diff --git a/esmf_regrid/tests/unit/schemes/__init__.py b/esmf_regrid/tests/unit/schemes/__init__.py index 22fc55be..c51deffd 100644 --- a/esmf_regrid/tests/unit/schemes/__init__.py +++ b/esmf_regrid/tests/unit/schemes/__init__.py @@ -1,5 +1,6 @@ """Unit tests for `esmf_regrid.schemes`.""" +import dask.array as da from iris.coord_systems import OSGB import numpy as np from numpy import ma @@ -215,3 +216,39 @@ def _test_non_degree_crs(scheme): # Check that the number of masked points is as expected. assert (1 - result.data.mask).sum() == expected_unmasked + + +def _test_dtype_handling(scheme, src_type, tgt_type, in_dtype): + """Test regridding scheme handles dtype as expected.""" + n_lons_src = 6 + n_lons_tgt = 3 + n_lats_src = 4 + n_lats_tgt = 2 + lon_bounds = (-180, 180) + lat_bounds = (-90, 90) + if in_dtype == "float32": + dtype = np.float32 + elif in_dtype == "float64": + dtype = np.float64 + + if src_type == "grid": + src = _grid_cube(n_lons_src, n_lats_src, lon_bounds, lat_bounds, circular=True) + src_data = np.zeros([n_lats_src, n_lons_src], dtype=dtype) + src.data = da.array(src_data) + elif src_type == "mesh": + src = _gridlike_mesh_cube(n_lons_src, n_lats_src) + src_data = np.zeros([n_lats_src * n_lons_src], dtype=dtype) + src.data = da.array(src_data) + if tgt_type == "grid": + tgt = _grid_cube(n_lons_tgt, n_lats_tgt, lon_bounds, lat_bounds, circular=True) + elif tgt_type == "mesh": + tgt = _gridlike_mesh_cube(n_lons_tgt, n_lats_tgt) + + result = src.regrid(tgt, scheme()) + + expected_dtype = np.float64 + + assert result.has_lazy_data() + + assert result.lazy_data().dtype == expected_dtype + assert result.data.dtype == expected_dtype diff --git a/esmf_regrid/tests/unit/schemes/test_ESMFAreaWeighted.py b/esmf_regrid/tests/unit/schemes/test_ESMFAreaWeighted.py index e73bee5d..50af91b3 100644 --- a/esmf_regrid/tests/unit/schemes/test_ESMFAreaWeighted.py +++ b/esmf_regrid/tests/unit/schemes/test_ESMFAreaWeighted.py @@ -5,6 +5,7 @@ from esmf_regrid.schemes import ESMFAreaWeighted from esmf_regrid.tests.unit.schemes.__init__ import ( _test_cube_regrid, + _test_dtype_handling, _test_invalid_mdtol, _test_mask_from_init, _test_mask_from_regridder, @@ -74,3 +75,9 @@ def test_invalid_tgt_location(): def test_non_degree_crs(): """Test for coordinates with non-degree units.""" _test_non_degree_crs(ESMFAreaWeighted) + + +def test_dtype_handling(src_tgt_types, in_dtype): + """Test regridding scheme handles dtype as expected.""" + src_type, tgt_type = src_tgt_types + _test_dtype_handling(ESMFAreaWeighted, src_type, tgt_type, in_dtype) diff --git a/esmf_regrid/tests/unit/schemes/test_ESMFBilinear.py b/esmf_regrid/tests/unit/schemes/test_ESMFBilinear.py index 393bbe0a..f6a9e30f 100644 --- a/esmf_regrid/tests/unit/schemes/test_ESMFBilinear.py +++ b/esmf_regrid/tests/unit/schemes/test_ESMFBilinear.py @@ -5,6 +5,7 @@ from esmf_regrid.schemes import ESMFBilinear from esmf_regrid.tests.unit.schemes.__init__ import ( _test_cube_regrid, + _test_dtype_handling, _test_invalid_mdtol, _test_mask_from_init, _test_mask_from_regridder, @@ -63,3 +64,9 @@ def test_mask_from_regridder(mask_keyword): def test_non_degree_crs(): """Test for coordinates with non-degree units.""" _test_non_degree_crs(ESMFBilinear) + + +def test_dtype_handling(src_tgt_types, in_dtype): + """Test regridding scheme handles dtype as expected.""" + src_type, tgt_type = src_tgt_types + _test_dtype_handling(ESMFBilinear, src_type, tgt_type, in_dtype) diff --git a/esmf_regrid/tests/unit/schemes/test_ESMFNearest.py b/esmf_regrid/tests/unit/schemes/test_ESMFNearest.py index b5812b5a..6f146e6c 100644 --- a/esmf_regrid/tests/unit/schemes/test_ESMFNearest.py +++ b/esmf_regrid/tests/unit/schemes/test_ESMFNearest.py @@ -6,6 +6,7 @@ from esmf_regrid.schemes import ESMFNearest from esmf_regrid.tests.unit.schemes.__init__ import ( + _test_dtype_handling, _test_mask_from_init, _test_mask_from_regridder, _test_non_degree_crs, @@ -106,3 +107,9 @@ def test_mask_from_regridder(mask_keyword): def test_non_degree_crs(): """Test for coordinates with non-degree units.""" _test_non_degree_crs(ESMFNearest) + + +def test_dtype_handling(src_tgt_types, in_dtype): + """Test regridding scheme handles dtype as expected.""" + src_type, tgt_type = src_tgt_types + _test_dtype_handling(ESMFNearest, src_type, tgt_type, in_dtype) diff --git a/esmf_regrid/tests/unit/schemes/test_regrid_rectilinear_to_rectilinear.py b/esmf_regrid/tests/unit/schemes/test_regrid_rectilinear_to_rectilinear.py index 5670d71f..115de52a 100644 --- a/esmf_regrid/tests/unit/schemes/test_regrid_rectilinear_to_rectilinear.py +++ b/esmf_regrid/tests/unit/schemes/test_regrid_rectilinear_to_rectilinear.py @@ -167,7 +167,9 @@ def test_laziness(src_transposed, tgt_transposed): lat_bounds = (-90, 90) grid = _grid_cube(n_lons, n_lats, lon_bounds, lat_bounds, circular=True) - src_data = np.arange(n_lats * n_lons * h).reshape([n_lats, n_lons, h]) + src_data = np.arange(n_lats * n_lons * h, dtype=np.float32).reshape( + [n_lats, n_lons, h] + ) src_data = da.from_array(src_data, chunks=[3, 5, 1]) src = Cube(src_data) src.add_dim_coord(grid.coord("latitude"), 0) @@ -185,6 +187,8 @@ def test_laziness(src_transposed, tgt_transposed): assert src.has_lazy_data() result = regrid_rectilinear_to_rectilinear(src, tgt) assert result.has_lazy_data() + assert result.lazy_data().dtype == np.float64 + assert result.data.dtype == np.float64 assert np.allclose(result.data, src_data) @@ -227,7 +231,7 @@ def test_laziness_curvilinear(src_transposed, tgt_transposed): extra = AuxCoord(np.arange(e), long_name="extra dim") src_data = np.empty([h, src_lats, t, src_lons, e]) - src_data[:] = np.arange(t * h * e).reshape([h, t, e])[ + src_data[:] = np.arange(t * h * e, dtype=np.float32).reshape([h, t, e])[ :, np.newaxis, :, np.newaxis, : ] src_data_lazy = da.array(src_data) @@ -253,6 +257,8 @@ def test_laziness_curvilinear(src_transposed, tgt_transposed): result_lazy = regrid_rectilinear_to_rectilinear(src_cube_lazy, tgt_grid) assert result_lazy.has_lazy_data() + assert result.lazy_data().dtype == np.float64 + assert result.data.dtype == np.float64 assert result_lazy == result