Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
72bb985
add regridder saving
stephenworsley Nov 9, 2021
fe607f3
add regridder saving
stephenworsley Nov 9, 2021
a022853
avoid saving bug
stephenworsley Nov 9, 2021
923e9ff
add docstrings, copy iris utils
stephenworsley Nov 10, 2021
245d610
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 10, 2021
d3885f0
test/lint fixes
stephenworsley Nov 10, 2021
9f52c02
test fix
stephenworsley Nov 10, 2021
c1b2f6e
add test
stephenworsley Nov 10, 2021
c1f7297
test functionality
stephenworsley Nov 11, 2021
6764841
add comments and tests
stephenworsley Nov 11, 2021
4ce372f
flake fix
stephenworsley Nov 12, 2021
d2b1253
refactor tests
stephenworsley Nov 12, 2021
96d9160
fix tests
stephenworsley Nov 12, 2021
f4031a7
fix tests
stephenworsley Nov 12, 2021
2431def
fix tests
stephenworsley Nov 12, 2021
29d9f1c
use pytest fixture tmp_path
stephenworsley Nov 15, 2021
1e7cee4
refresh nox cache
stephenworsley Nov 15, 2021
172ef4a
fix test
stephenworsley Nov 15, 2021
924b135
remove temp file architecture
stephenworsley Nov 16, 2021
fcb850c
remove imports
stephenworsley Nov 16, 2021
851fe1e
update nox cache
stephenworsley Nov 16, 2021
1674c2e
fix tests
stephenworsley Nov 16, 2021
7d8c4db
increment CONDA_CACHE_BUILD
stephenworsley Nov 16, 2021
37fbb7f
toggle nox environment reuse
stephenworsley Nov 16, 2021
927265c
toggle nox environment reuse
stephenworsley Nov 17, 2021
fa4ea36
determine regridder_type generically
stephenworsley Nov 17, 2021
471e311
fix saver
stephenworsley Nov 17, 2021
d063008
fix saver
stephenworsley Nov 17, 2021
11bc7f5
fix loader
stephenworsley Nov 17, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .cirrus.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ env:
# Maximum cache period (in weeks) before forcing a new cache upload.
CACHE_PERIOD: "2"
# Increment the build number to force new conda cache upload.
CONDA_CACHE_BUILD: "0"
CONDA_CACHE_BUILD: "1"
# Increment the build number to force new nox cache upload.
NOX_CACHE_BUILD: "0"
NOX_CACHE_BUILD: "2"
# Increment the build number to force new pip cache upload.
PIP_CACHE_BUILD: "0"
# Pip package to be installed.
Expand Down
149 changes: 149 additions & 0 deletions esmf_regrid/experimental/io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
"""Provides load/save functions for regridders."""

import iris
from iris.coords import AuxCoord
from iris.cube import Cube, CubeList
from iris.experimental.ugrid import PARSE_UGRID_ON_LOAD
import numpy as np
import scipy.sparse

from esmf_regrid.experimental.unstructured_scheme import (
GridToMeshESMFRegridder,
MeshToGridESMFRegridder,
)


SUPPORTED_REGRIDDERS = [
GridToMeshESMFRegridder,
MeshToGridESMFRegridder,
]
REGRIDDER_NAME_MAP = {rg_class.__name__: rg_class for rg_class in SUPPORTED_REGRIDDERS}


def save_regridder(rg, filename):
"""
Save a regridder scheme instance.

Saves either a `GridToMeshESMFRegridder` or a `MeshToGridESMFRegridder`.

Parameters
----------
rg : GridToMeshESMFRegridder, MeshToGridESMFRegridder
The regridder instance to save.
filename : str
The file name to save to.
"""
src_name = "regridder source field"
tgt_name = "regridder target field"
regridder_type = rg.__class__.__name__
if regridder_type == "GridToMeshESMFRegridder":
src_grid = (rg.grid_y, rg.grid_x)
src_shape = [len(coord.points) for coord in src_grid]
src_data = np.zeros(src_shape)
src_cube = Cube(src_data, long_name=src_name)
src_cube.add_dim_coord(src_grid[0], 0)
src_cube.add_dim_coord(src_grid[1], 1)

tgt_mesh = rg.mesh
tgt_data = np.zeros(tgt_mesh.face_node_connectivity.indices.shape[0])
tgt_cube = Cube(tgt_data, long_name=tgt_name)
for coord in tgt_mesh.to_MeshCoords("face"):
tgt_cube.add_aux_coord(coord, 0)
elif regridder_type == "MeshToGridESMFRegridder":
src_mesh = rg.mesh
src_data = np.zeros(src_mesh.face_node_connectivity.indices.shape[0])
src_cube = Cube(src_data, long_name=src_name)
for coord in src_mesh.to_MeshCoords("face"):
src_cube.add_aux_coord(coord, 0)

tgt_grid = (rg.grid_y, rg.grid_x)
tgt_shape = [len(coord.points) for coord in tgt_grid]
tgt_data = np.zeros(tgt_shape)
tgt_cube = Cube(tgt_data, long_name=tgt_name)
tgt_cube.add_dim_coord(tgt_grid[0], 0)
tgt_cube.add_dim_coord(tgt_grid[1], 1)
else:
msg = (
f"Expected a regridder of type `GridToMeshESMFRegridder` or "
f"`MeshToGridESMFRegridder`, got type {regridder_type}."
)
raise TypeError(msg)

metadata_name = "regridder weights and metadata"

weight_matrix = rg.regridder.weight_matrix
reformatted_weight_matrix = weight_matrix.tocoo()
weight_data = reformatted_weight_matrix.data
weight_rows = reformatted_weight_matrix.row
weight_cols = reformatted_weight_matrix.col
weight_shape = reformatted_weight_matrix.shape

mdtol = rg.mdtol
attributes = {
"regridder type": regridder_type,
"mdtol": mdtol,
"weights shape": weight_shape,
}

metadata_cube = Cube(weight_data, long_name=metadata_name, attributes=attributes)
row_name = "weight matrix rows"
row_coord = AuxCoord(weight_rows, long_name=row_name)
col_name = "weight matrix columns"
col_coord = AuxCoord(weight_cols, long_name=col_name)
metadata_cube.add_aux_coord(row_coord, 0)
metadata_cube.add_aux_coord(col_coord, 0)

# Avoid saving bug by placing the mesh cube second.
# TODO: simplify this when this bug is fixed in iris.
if regridder_type == "GridToMeshESMFRegridder":
cube_list = CubeList([src_cube, tgt_cube, metadata_cube])
elif regridder_type == "MeshToGridESMFRegridder":
cube_list = CubeList([tgt_cube, src_cube, metadata_cube])
iris.fileformats.netcdf.save(cube_list, filename)


def load_regridder(filename):
"""
Load a regridder scheme instance.

Loads either a `GridToMeshESMFRegridder` or a `MeshToGridESMFRegridder`.

Parameters
----------
filename : str
The file name to load from.
"""
with PARSE_UGRID_ON_LOAD.context():
cubes = iris.load(filename)

src_name = "regridder source field"
tgt_name = "regridder target field"
metadata_name = "regridder weights and metadata"

# Extract the source, target and metadata information.
src_cube = cubes.extract_cube(src_name)
tgt_cube = cubes.extract_cube(tgt_name)
metadata_cube = cubes.extract_cube(metadata_name)

# Determine the regridder type.
regridder_type = metadata_cube.attributes["regridder type"]
assert regridder_type in REGRIDDER_NAME_MAP.keys()
scheme = REGRIDDER_NAME_MAP[regridder_type]

# Reconstruct the weight matrix.
weight_data = metadata_cube.data
row_name = "weight matrix rows"
weight_rows = metadata_cube.coord(row_name).points
col_name = "weight matrix columns"
weight_cols = metadata_cube.coord(col_name).points
weight_shape = metadata_cube.attributes["weights shape"]
weight_matrix = scipy.sparse.csr_matrix(
(weight_data, (weight_rows, weight_cols)), shape=weight_shape
)

mdtol = metadata_cube.attributes["mdtol"]

regridder = scheme(
src_cube, tgt_cube, mdtol=mdtol, precomputed_weights=weight_matrix
)
return regridder
27 changes: 19 additions & 8 deletions esmf_regrid/experimental/unstructured_scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,9 @@ def copy_coords(src_coords, add_method):
return new_cube


def _regrid_unstructured_to_rectilinear__prepare(src_mesh_cube, target_grid_cube):
def _regrid_unstructured_to_rectilinear__prepare(
src_mesh_cube, target_grid_cube, precomputed_weights=None
):
"""
First (setup) part of 'regrid_unstructured_to_rectilinear'.

Expand Down Expand Up @@ -257,7 +259,7 @@ def _regrid_unstructured_to_rectilinear__prepare(src_mesh_cube, target_grid_cube
meshinfo = _mesh_to_MeshInfo(mesh)
gridinfo = _cube_to_GridInfo(target_grid_cube)

regridder = Regridder(meshinfo, gridinfo)
regridder = Regridder(meshinfo, gridinfo, precomputed_weights)

regrid_info = (mesh_dim, grid_x, grid_y, regridder)

Expand Down Expand Up @@ -350,7 +352,9 @@ def regrid_unstructured_to_rectilinear(src_cube, grid_cube, mdtol=0):
class MeshToGridESMFRegridder:
"""Regridder class for unstructured to rectilinear cubes."""

def __init__(self, src_mesh_cube, target_grid_cube, mdtol=1):
def __init__(
self, src_mesh_cube, target_grid_cube, mdtol=1, precomputed_weights=None
):
"""
Create regridder for conversions between source mesh and target grid.

Expand Down Expand Up @@ -382,9 +386,12 @@ def __init__(self, src_mesh_cube, target_grid_cube, mdtol=1):
self.mdtol = mdtol

partial_regrid_info = _regrid_unstructured_to_rectilinear__prepare(
src_mesh_cube, target_grid_cube
src_mesh_cube, target_grid_cube, precomputed_weights=precomputed_weights
)

# Record source mesh.
self.mesh = src_mesh_cube.mesh

# Store regrid info.
_, self.grid_x, self.grid_y, self.regridder = partial_regrid_info

Expand Down Expand Up @@ -491,7 +498,9 @@ def copy_coords(src_coords, add_method):
return new_cube


def _regrid_rectilinear_to_unstructured__prepare(src_grid_cube, target_mesh_cube):
def _regrid_rectilinear_to_unstructured__prepare(
src_grid_cube, target_mesh_cube, precomputed_weights=None
):
"""
First (setup) part of 'regrid_rectilinear_to_unstructured'.

Expand All @@ -510,7 +519,7 @@ def _regrid_rectilinear_to_unstructured__prepare(src_grid_cube, target_mesh_cube
meshinfo = _mesh_to_MeshInfo(mesh)
gridinfo = _cube_to_GridInfo(src_grid_cube)

regridder = Regridder(gridinfo, meshinfo)
regridder = Regridder(gridinfo, meshinfo, precomputed_weights)

regrid_info = (grid_x_dim, grid_y_dim, grid_x, grid_y, mesh, regridder)

Expand Down Expand Up @@ -610,7 +619,9 @@ def regrid_rectilinear_to_unstructured(src_cube, mesh_cube, mdtol=0):
class GridToMeshESMFRegridder:
"""Regridder class for rectilinear to unstructured cubes."""

def __init__(self, src_mesh_cube, target_grid_cube, mdtol=1):
def __init__(
self, src_mesh_cube, target_grid_cube, mdtol=1, precomputed_weights=None
):
"""
Create regridder for conversions between source grid and target mesh.

Expand All @@ -637,7 +648,7 @@ def __init__(self, src_mesh_cube, target_grid_cube, mdtol=1):
self.mdtol = mdtol

partial_regrid_info = _regrid_rectilinear_to_unstructured__prepare(
src_mesh_cube, target_grid_cube
src_mesh_cube, target_grid_cube, precomputed_weights=precomputed_weights
)

# Store regrid info.
Expand Down
1 change: 1 addition & 0 deletions esmf_regrid/tests/unit/experimental/io/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Unit tests for :mod:`esmf_regrid.experimental.io`."""
118 changes: 118 additions & 0 deletions esmf_regrid/tests/unit/experimental/io/test_round_tripping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""Unit tests for round tripping (saving then loading) with :mod:`esmf_regrid.experimental.io`."""

from iris.cube import Cube
import numpy as np
from numpy import ma

from esmf_regrid.experimental.io import load_regridder, save_regridder
from esmf_regrid.experimental.unstructured_scheme import (
GridToMeshESMFRegridder,
MeshToGridESMFRegridder,
)
from esmf_regrid.tests.unit.experimental.unstructured_scheme.test__cube_to_GridInfo import (
_grid_cube,
)
from esmf_regrid.tests.unit.experimental.unstructured_scheme.test__mesh_to_MeshInfo import (
_gridlike_mesh,
)


def _make_grid_to_mesh_regridder():
src_lons = 3
src_lats = 4
tgt_lons = 5
tgt_lats = 6
lon_bounds = (-180, 180)
lat_bounds = (-90, 90)
# TODO check that circularity is preserved.
src = _grid_cube(src_lons, src_lats, lon_bounds, lat_bounds, circular=True)
src.coord("longitude").var_name = "longitude"
src.coord("latitude").var_name = "latitude"
mesh = _gridlike_mesh(tgt_lons, tgt_lats)
mesh_coord_x, mesh_coord_y = mesh.to_MeshCoords("face")
tgt_data = np.zeros(tgt_lons * tgt_lats)
tgt = Cube(tgt_data)
tgt.add_aux_coord(mesh_coord_x, 0)
tgt.add_aux_coord(mesh_coord_y, 0)

rg = GridToMeshESMFRegridder(src, tgt, mdtol=0.5)
return rg, src


def _make_mesh_to_grid_regridder():
src_lons = 3
src_lats = 4
tgt_lons = 5
tgt_lats = 6
lon_bounds = (-180, 180)
lat_bounds = (-90, 90)
# TODO check that circularity is preserved.
tgt = _grid_cube(tgt_lons, tgt_lats, lon_bounds, lat_bounds, circular=True)
tgt.coord("longitude").var_name = "longitude"
tgt.coord("latitude").var_name = "latitude"
mesh = _gridlike_mesh(src_lons, src_lats)
mesh_coord_x, mesh_coord_y = mesh.to_MeshCoords("face")
src_data = np.zeros(src_lons * src_lats)
src = Cube(src_data)
src.add_aux_coord(mesh_coord_x, 0)
src.add_aux_coord(mesh_coord_y, 0)

rg = MeshToGridESMFRegridder(src, tgt, mdtol=0.5)
return rg, src


def test_GridToMeshESMFRegridder_round_trip(tmp_path):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@stephenworsley Lovely stuff 👨‍🎤

"""Test save/load round tripping for `GridToMeshESMFRegridder`."""
original_rg, src = _make_grid_to_mesh_regridder()
filename = tmp_path / "regridder.nc"
save_regridder(original_rg, filename)
loaded_rg = load_regridder(str(filename))

assert original_rg.mdtol == loaded_rg.mdtol
assert original_rg.grid_x == loaded_rg.grid_x
assert original_rg.grid_y == loaded_rg.grid_y
# TODO: uncomment when iris mesh comparison becomes available.
# assert original_rg.mesh == loaded_rg.mesh

# Compare the weight matrices.
original_matrix = original_rg.regridder.weight_matrix
loaded_matrix = loaded_rg.regridder.weight_matrix
# Ensure the original and loaded weight matrix have identical type.
assert type(original_matrix) is type(loaded_matrix) # noqa E721
assert np.array_equal(original_matrix.todense(), loaded_matrix.todense())

# Demonstrate regridding still gives the same results.
src_data = np.arange(np.product(src.data.shape)).reshape(src.data.shape)
src_mask = np.zeros(src.data.shape)
src_mask[0, 0] = 1
src.data = ma.array(src_data, mask=src_mask)
# TODO: make this a cube comparison when mesh comparison becomes available.
assert np.array_equal(original_rg(src).data, loaded_rg(src).data)


def test_MeshToGridESMFRegridder_round_trip(tmp_path):
"""Test save/load round tripping for `MeshToGridESMFRegridder`."""
original_rg, src = _make_mesh_to_grid_regridder()
filename = tmp_path / "regridder.nc"
save_regridder(original_rg, filename)
loaded_rg = load_regridder(str(filename))

assert original_rg.mdtol == loaded_rg.mdtol
assert original_rg.grid_x == loaded_rg.grid_x
assert original_rg.grid_y == loaded_rg.grid_y
# TODO: uncomment when iris mesh comparison becomes available.
# assert original_rg.mesh == loaded_rg.mesh

# Compare the weight matrices.
original_matrix = original_rg.regridder.weight_matrix
loaded_matrix = loaded_rg.regridder.weight_matrix
# Ensure the original and loaded weight matrix have identical type.
assert type(original_matrix) is type(loaded_matrix) # noqa E721
assert np.array_equal(original_matrix.todense(), loaded_matrix.todense())

# Demonstrate regridding still gives the same results.
src_data = np.arange(np.product(src.data.shape)).reshape(src.data.shape)
src_mask = np.zeros(src.data.shape)
src_mask[0] = 1
src.data = ma.array(src_data, mask=src_mask)
assert original_rg(src) == loaded_rg(src)
13 changes: 13 additions & 0 deletions esmf_regrid/tests/unit/experimental/io/test_save_regridder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""Unit tests for :mod:`esmf_regrid.experimental.io.save_regridder`."""

import pytest

from esmf_regrid.experimental.io import save_regridder


def test_invalid_type(tmp_path):
"""Test that `save_regridder` raises a TypeError where appropriate."""
invalid_obj = None
filename = tmp_path / "regridder.nc"
with pytest.raises(TypeError):
save_regridder(invalid_obj, filename)