-
Notifications
You must be signed in to change notification settings - Fork 19
Regridder load/saving #130
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
bjlittle
merged 29 commits into
SciTools:unstructured_scheme
from
stephenworsley:regridder_saving
Nov 18, 2021
Merged
Changes from all commits
Commits
Show all changes
29 commits
Select commit
Hold shift + click to select a range
72bb985
add regridder saving
stephenworsley fe607f3
add regridder saving
stephenworsley a022853
avoid saving bug
stephenworsley 923e9ff
add docstrings, copy iris utils
stephenworsley 245d610
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] d3885f0
test/lint fixes
stephenworsley 9f52c02
test fix
stephenworsley c1b2f6e
add test
stephenworsley c1f7297
test functionality
stephenworsley 6764841
add comments and tests
stephenworsley 4ce372f
flake fix
stephenworsley d2b1253
refactor tests
stephenworsley 96d9160
fix tests
stephenworsley f4031a7
fix tests
stephenworsley 2431def
fix tests
stephenworsley 29d9f1c
use pytest fixture tmp_path
stephenworsley 1e7cee4
refresh nox cache
stephenworsley 172ef4a
fix test
stephenworsley 924b135
remove temp file architecture
stephenworsley fcb850c
remove imports
stephenworsley 851fe1e
update nox cache
stephenworsley 1674c2e
fix tests
stephenworsley 7d8c4db
increment CONDA_CACHE_BUILD
stephenworsley 37fbb7f
toggle nox environment reuse
stephenworsley 927265c
toggle nox environment reuse
stephenworsley fa4ea36
determine regridder_type generically
stephenworsley 471e311
fix saver
stephenworsley d063008
fix saver
stephenworsley 11bc7f5
fix loader
stephenworsley File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,149 @@ | ||
| """Provides load/save functions for regridders.""" | ||
|
|
||
| import iris | ||
| from iris.coords import AuxCoord | ||
| from iris.cube import Cube, CubeList | ||
| from iris.experimental.ugrid import PARSE_UGRID_ON_LOAD | ||
| import numpy as np | ||
| import scipy.sparse | ||
|
|
||
| from esmf_regrid.experimental.unstructured_scheme import ( | ||
| GridToMeshESMFRegridder, | ||
| MeshToGridESMFRegridder, | ||
| ) | ||
|
|
||
|
|
||
| SUPPORTED_REGRIDDERS = [ | ||
| GridToMeshESMFRegridder, | ||
| MeshToGridESMFRegridder, | ||
| ] | ||
| REGRIDDER_NAME_MAP = {rg_class.__name__: rg_class for rg_class in SUPPORTED_REGRIDDERS} | ||
|
|
||
|
|
||
| def save_regridder(rg, filename): | ||
| """ | ||
| Save a regridder scheme instance. | ||
|
|
||
| Saves either a `GridToMeshESMFRegridder` or a `MeshToGridESMFRegridder`. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| rg : GridToMeshESMFRegridder, MeshToGridESMFRegridder | ||
| The regridder instance to save. | ||
| filename : str | ||
| The file name to save to. | ||
| """ | ||
| src_name = "regridder source field" | ||
| tgt_name = "regridder target field" | ||
| regridder_type = rg.__class__.__name__ | ||
| if regridder_type == "GridToMeshESMFRegridder": | ||
| src_grid = (rg.grid_y, rg.grid_x) | ||
| src_shape = [len(coord.points) for coord in src_grid] | ||
| src_data = np.zeros(src_shape) | ||
| src_cube = Cube(src_data, long_name=src_name) | ||
| src_cube.add_dim_coord(src_grid[0], 0) | ||
| src_cube.add_dim_coord(src_grid[1], 1) | ||
|
|
||
| tgt_mesh = rg.mesh | ||
| tgt_data = np.zeros(tgt_mesh.face_node_connectivity.indices.shape[0]) | ||
| tgt_cube = Cube(tgt_data, long_name=tgt_name) | ||
| for coord in tgt_mesh.to_MeshCoords("face"): | ||
| tgt_cube.add_aux_coord(coord, 0) | ||
| elif regridder_type == "MeshToGridESMFRegridder": | ||
| src_mesh = rg.mesh | ||
| src_data = np.zeros(src_mesh.face_node_connectivity.indices.shape[0]) | ||
| src_cube = Cube(src_data, long_name=src_name) | ||
| for coord in src_mesh.to_MeshCoords("face"): | ||
| src_cube.add_aux_coord(coord, 0) | ||
|
|
||
| tgt_grid = (rg.grid_y, rg.grid_x) | ||
| tgt_shape = [len(coord.points) for coord in tgt_grid] | ||
| tgt_data = np.zeros(tgt_shape) | ||
| tgt_cube = Cube(tgt_data, long_name=tgt_name) | ||
| tgt_cube.add_dim_coord(tgt_grid[0], 0) | ||
| tgt_cube.add_dim_coord(tgt_grid[1], 1) | ||
| else: | ||
| msg = ( | ||
| f"Expected a regridder of type `GridToMeshESMFRegridder` or " | ||
| f"`MeshToGridESMFRegridder`, got type {regridder_type}." | ||
| ) | ||
| raise TypeError(msg) | ||
|
|
||
| metadata_name = "regridder weights and metadata" | ||
|
|
||
| weight_matrix = rg.regridder.weight_matrix | ||
| reformatted_weight_matrix = weight_matrix.tocoo() | ||
| weight_data = reformatted_weight_matrix.data | ||
| weight_rows = reformatted_weight_matrix.row | ||
| weight_cols = reformatted_weight_matrix.col | ||
| weight_shape = reformatted_weight_matrix.shape | ||
|
|
||
| mdtol = rg.mdtol | ||
| attributes = { | ||
| "regridder type": regridder_type, | ||
| "mdtol": mdtol, | ||
| "weights shape": weight_shape, | ||
| } | ||
|
|
||
| metadata_cube = Cube(weight_data, long_name=metadata_name, attributes=attributes) | ||
| row_name = "weight matrix rows" | ||
| row_coord = AuxCoord(weight_rows, long_name=row_name) | ||
| col_name = "weight matrix columns" | ||
| col_coord = AuxCoord(weight_cols, long_name=col_name) | ||
| metadata_cube.add_aux_coord(row_coord, 0) | ||
| metadata_cube.add_aux_coord(col_coord, 0) | ||
|
|
||
| # Avoid saving bug by placing the mesh cube second. | ||
| # TODO: simplify this when this bug is fixed in iris. | ||
| if regridder_type == "GridToMeshESMFRegridder": | ||
| cube_list = CubeList([src_cube, tgt_cube, metadata_cube]) | ||
| elif regridder_type == "MeshToGridESMFRegridder": | ||
| cube_list = CubeList([tgt_cube, src_cube, metadata_cube]) | ||
| iris.fileformats.netcdf.save(cube_list, filename) | ||
|
|
||
|
|
||
| def load_regridder(filename): | ||
| """ | ||
| Load a regridder scheme instance. | ||
|
|
||
| Loads either a `GridToMeshESMFRegridder` or a `MeshToGridESMFRegridder`. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| filename : str | ||
| The file name to load from. | ||
| """ | ||
| with PARSE_UGRID_ON_LOAD.context(): | ||
| cubes = iris.load(filename) | ||
|
|
||
| src_name = "regridder source field" | ||
| tgt_name = "regridder target field" | ||
| metadata_name = "regridder weights and metadata" | ||
|
|
||
| # Extract the source, target and metadata information. | ||
| src_cube = cubes.extract_cube(src_name) | ||
| tgt_cube = cubes.extract_cube(tgt_name) | ||
| metadata_cube = cubes.extract_cube(metadata_name) | ||
|
|
||
| # Determine the regridder type. | ||
| regridder_type = metadata_cube.attributes["regridder type"] | ||
| assert regridder_type in REGRIDDER_NAME_MAP.keys() | ||
| scheme = REGRIDDER_NAME_MAP[regridder_type] | ||
|
|
||
| # Reconstruct the weight matrix. | ||
| weight_data = metadata_cube.data | ||
| row_name = "weight matrix rows" | ||
| weight_rows = metadata_cube.coord(row_name).points | ||
| col_name = "weight matrix columns" | ||
| weight_cols = metadata_cube.coord(col_name).points | ||
| weight_shape = metadata_cube.attributes["weights shape"] | ||
| weight_matrix = scipy.sparse.csr_matrix( | ||
| (weight_data, (weight_rows, weight_cols)), shape=weight_shape | ||
| ) | ||
|
|
||
| mdtol = metadata_cube.attributes["mdtol"] | ||
|
|
||
| regridder = scheme( | ||
| src_cube, tgt_cube, mdtol=mdtol, precomputed_weights=weight_matrix | ||
| ) | ||
| return regridder | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| """Unit tests for :mod:`esmf_regrid.experimental.io`.""" |
118 changes: 118 additions & 0 deletions
118
esmf_regrid/tests/unit/experimental/io/test_round_tripping.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,118 @@ | ||
| """Unit tests for round tripping (saving then loading) with :mod:`esmf_regrid.experimental.io`.""" | ||
|
|
||
| from iris.cube import Cube | ||
| import numpy as np | ||
| from numpy import ma | ||
|
|
||
| from esmf_regrid.experimental.io import load_regridder, save_regridder | ||
| from esmf_regrid.experimental.unstructured_scheme import ( | ||
| GridToMeshESMFRegridder, | ||
| MeshToGridESMFRegridder, | ||
| ) | ||
| from esmf_regrid.tests.unit.experimental.unstructured_scheme.test__cube_to_GridInfo import ( | ||
| _grid_cube, | ||
| ) | ||
| from esmf_regrid.tests.unit.experimental.unstructured_scheme.test__mesh_to_MeshInfo import ( | ||
| _gridlike_mesh, | ||
| ) | ||
|
|
||
|
|
||
| def _make_grid_to_mesh_regridder(): | ||
| src_lons = 3 | ||
| src_lats = 4 | ||
| tgt_lons = 5 | ||
| tgt_lats = 6 | ||
| lon_bounds = (-180, 180) | ||
| lat_bounds = (-90, 90) | ||
| # TODO check that circularity is preserved. | ||
| src = _grid_cube(src_lons, src_lats, lon_bounds, lat_bounds, circular=True) | ||
| src.coord("longitude").var_name = "longitude" | ||
| src.coord("latitude").var_name = "latitude" | ||
| mesh = _gridlike_mesh(tgt_lons, tgt_lats) | ||
| mesh_coord_x, mesh_coord_y = mesh.to_MeshCoords("face") | ||
| tgt_data = np.zeros(tgt_lons * tgt_lats) | ||
| tgt = Cube(tgt_data) | ||
| tgt.add_aux_coord(mesh_coord_x, 0) | ||
| tgt.add_aux_coord(mesh_coord_y, 0) | ||
|
|
||
| rg = GridToMeshESMFRegridder(src, tgt, mdtol=0.5) | ||
| return rg, src | ||
|
|
||
|
|
||
| def _make_mesh_to_grid_regridder(): | ||
| src_lons = 3 | ||
| src_lats = 4 | ||
| tgt_lons = 5 | ||
| tgt_lats = 6 | ||
| lon_bounds = (-180, 180) | ||
| lat_bounds = (-90, 90) | ||
| # TODO check that circularity is preserved. | ||
| tgt = _grid_cube(tgt_lons, tgt_lats, lon_bounds, lat_bounds, circular=True) | ||
| tgt.coord("longitude").var_name = "longitude" | ||
| tgt.coord("latitude").var_name = "latitude" | ||
| mesh = _gridlike_mesh(src_lons, src_lats) | ||
| mesh_coord_x, mesh_coord_y = mesh.to_MeshCoords("face") | ||
| src_data = np.zeros(src_lons * src_lats) | ||
| src = Cube(src_data) | ||
| src.add_aux_coord(mesh_coord_x, 0) | ||
| src.add_aux_coord(mesh_coord_y, 0) | ||
|
|
||
| rg = MeshToGridESMFRegridder(src, tgt, mdtol=0.5) | ||
| return rg, src | ||
|
|
||
|
|
||
| def test_GridToMeshESMFRegridder_round_trip(tmp_path): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @stephenworsley Lovely stuff 👨🎤 |
||
| """Test save/load round tripping for `GridToMeshESMFRegridder`.""" | ||
| original_rg, src = _make_grid_to_mesh_regridder() | ||
| filename = tmp_path / "regridder.nc" | ||
| save_regridder(original_rg, filename) | ||
| loaded_rg = load_regridder(str(filename)) | ||
|
|
||
| assert original_rg.mdtol == loaded_rg.mdtol | ||
| assert original_rg.grid_x == loaded_rg.grid_x | ||
| assert original_rg.grid_y == loaded_rg.grid_y | ||
| # TODO: uncomment when iris mesh comparison becomes available. | ||
| # assert original_rg.mesh == loaded_rg.mesh | ||
|
|
||
| # Compare the weight matrices. | ||
| original_matrix = original_rg.regridder.weight_matrix | ||
| loaded_matrix = loaded_rg.regridder.weight_matrix | ||
| # Ensure the original and loaded weight matrix have identical type. | ||
| assert type(original_matrix) is type(loaded_matrix) # noqa E721 | ||
| assert np.array_equal(original_matrix.todense(), loaded_matrix.todense()) | ||
|
|
||
| # Demonstrate regridding still gives the same results. | ||
| src_data = np.arange(np.product(src.data.shape)).reshape(src.data.shape) | ||
| src_mask = np.zeros(src.data.shape) | ||
| src_mask[0, 0] = 1 | ||
| src.data = ma.array(src_data, mask=src_mask) | ||
| # TODO: make this a cube comparison when mesh comparison becomes available. | ||
| assert np.array_equal(original_rg(src).data, loaded_rg(src).data) | ||
|
|
||
|
|
||
| def test_MeshToGridESMFRegridder_round_trip(tmp_path): | ||
| """Test save/load round tripping for `MeshToGridESMFRegridder`.""" | ||
| original_rg, src = _make_mesh_to_grid_regridder() | ||
| filename = tmp_path / "regridder.nc" | ||
| save_regridder(original_rg, filename) | ||
| loaded_rg = load_regridder(str(filename)) | ||
|
|
||
| assert original_rg.mdtol == loaded_rg.mdtol | ||
| assert original_rg.grid_x == loaded_rg.grid_x | ||
| assert original_rg.grid_y == loaded_rg.grid_y | ||
| # TODO: uncomment when iris mesh comparison becomes available. | ||
| # assert original_rg.mesh == loaded_rg.mesh | ||
|
|
||
| # Compare the weight matrices. | ||
| original_matrix = original_rg.regridder.weight_matrix | ||
| loaded_matrix = loaded_rg.regridder.weight_matrix | ||
| # Ensure the original and loaded weight matrix have identical type. | ||
| assert type(original_matrix) is type(loaded_matrix) # noqa E721 | ||
| assert np.array_equal(original_matrix.todense(), loaded_matrix.todense()) | ||
|
|
||
| # Demonstrate regridding still gives the same results. | ||
| src_data = np.arange(np.product(src.data.shape)).reshape(src.data.shape) | ||
| src_mask = np.zeros(src.data.shape) | ||
| src_mask[0] = 1 | ||
| src.data = ma.array(src_data, mask=src_mask) | ||
| assert original_rg(src) == loaded_rg(src) | ||
13 changes: 13 additions & 0 deletions
13
esmf_regrid/tests/unit/experimental/io/test_save_regridder.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,13 @@ | ||
| """Unit tests for :mod:`esmf_regrid.experimental.io.save_regridder`.""" | ||
|
|
||
| import pytest | ||
|
|
||
| from esmf_regrid.experimental.io import save_regridder | ||
|
|
||
|
|
||
| def test_invalid_type(tmp_path): | ||
| """Test that `save_regridder` raises a TypeError where appropriate.""" | ||
| invalid_obj = None | ||
| filename = tmp_path / "regridder.nc" | ||
| with pytest.raises(TypeError): | ||
| save_regridder(invalid_obj, filename) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.