Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions lib/iris/analysis/_area_weighted.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def __init__(self, src_grid_cube, target_grid_cube, mdtol=1):
self.meshgrid_x,
self.meshgrid_y,
self.weights_info,
self.index_info,
) = _regrid_info

def __call__(self, cube):
Expand Down Expand Up @@ -124,6 +125,7 @@ def __call__(self, cube):
self.meshgrid_x,
self.meshgrid_y,
self.weights_info,
self.index_info,
)
return _regrid_area_weighted_rectilinear_src_and_grid__perform(
cube, _regrid_info, mdtol=self._mdtol
Expand Down
229 changes: 172 additions & 57 deletions lib/iris/experimental/regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,9 @@ def _weighted_mean_with_mdtol(data, weights, axis=None, mdtol=0):
return res


def _regrid_area_weighted_array(src_data, x_dim, y_dim, weights_info, mdtol=0):
def _regrid_area_weighted_array(
src_data, x_dim, y_dim, weights_info, index_info, mdtol=0
):
"""
Regrid the given data from its source grid to a new grid using
an area weighted mean to determine the resulting data values.
Expand Down Expand Up @@ -444,13 +446,19 @@ def _regrid_area_weighted_array(src_data, x_dim, y_dim, weights_info, mdtol=0):

"""
(
cached_x_indices,
cached_y_indices,
max_x_indices,
max_y_indices,
cached_weights,
blank_weights,
src_area_weights,
new_data_mask_basis,
) = weights_info

(
result_x_extent,
result_y_extent,
square_data_indices_y,
square_data_indices_x,
src_area_datas_required,
) = index_info

# Ensure we have x_dim and y_dim.
x_dim_orig = x_dim
y_dim_orig = y_dim
Expand Down Expand Up @@ -479,60 +487,47 @@ def _regrid_area_weighted_array(src_data, x_dim, y_dim, weights_info, mdtol=0):
# Note that dtype is not preserved and that the array mask
# allows for regions that do not overlap.
new_shape = list(src_data.shape)
new_shape[x_dim] = len(cached_x_indices)
new_shape[y_dim] = len(cached_y_indices)
num_target_pts = len(cached_y_indices) * len(cached_x_indices)
src_areas_shape = list(src_data.shape)
src_areas_shape[y_dim] = max_y_indices
src_areas_shape[x_dim] = max_x_indices
src_areas_shape += [num_target_pts]
new_shape[x_dim] = result_x_extent
new_shape[y_dim] = result_y_extent

# Use input cube dtype or convert values to the smallest possible float
# dtype when necessary.
dtype = np.promote_types(src_data.dtype, np.float16)
# Create empty arrays to hold src_data per target point, and weights
src_area_datas = np.zeros(src_areas_shape, dtype=np.float64)
src_area_weights = np.zeros(
list((max_y_indices, max_x_indices, num_target_pts))

# Axes of data over which the weighted mean is calculated.
axis = (y_dim, x_dim)

# Use previously established indices

src_area_datas_square = src_data[
..., square_data_indices_y, square_data_indices_x
]

_, src_area_datas_required = np.broadcast_arrays(
src_area_datas_square, src_area_datas_required
)

src_area_datas = np.where(
src_area_datas_required, src_area_datas_square, 0
)

# Flag to indicate whether the original data was a masked array.
src_masked = src_data.mask.any() if ma.isMaskedArray(src_data) else False
if src_masked:
src_area_masks = np.full(src_areas_shape, True, dtype=np.bool_)
else:
new_data_mask = np.full(new_shape, False, dtype=np.bool_)
src_area_masks_square = src_data.mask[
..., square_data_indices_y, square_data_indices_x
]
src_area_masks = np.where(
src_area_datas_required, src_area_masks_square, True
)

# Axes of data over which the weighted mean is calculated.
axis = (y_dim, x_dim)
else:
# Where the weights were originally blank, set them all to 1 to avoid a
# divide-by-zero error, and set up the new data mask used for making
# those values 0.
src_area_weights = np.where(blank_weights, 1, src_area_weights)

# Stack the src_area data and weights for each target point
target_pt_ji = -1
for j, y_indices in enumerate(cached_y_indices):
for i, x_indices in enumerate(cached_x_indices):
target_pt_ji += 1
# Determine whether to mask element i, j based on whether
# there are valid weights.
weights = cached_weights[j][i]
if isinstance(weights, bool) and not weights:
if not src_masked:
# Cheat! Fill the data with zeros and weights as one.
# The weighted average result will be the same, but
# we avoid dividing by zero.
src_area_weights[..., target_pt_ji] = 1
new_data_mask[..., j, i] = True
else:
# Calculate weighted mean of data points.
# Slice out relevant data (this may or may not be a view()
# depending on x_indices being a slice or not).
data = src_data[..., y_indices, x_indices]
len_x = data.shape[-1]
len_y = data.shape[-2]
src_area_datas[..., 0:len_y, 0:len_x, target_pt_ji] = data
src_area_weights[0:len_y, 0:len_x, target_pt_ji] = weights
if src_masked:
src_area_masks[
..., 0:len_y, 0:len_x, target_pt_ji
] = data.mask
new_data_mask = np.broadcast_to(new_data_mask_basis, new_shape)

# Broadcast the weights array to allow numpy's ma.average
# to be called.
Expand Down Expand Up @@ -770,9 +765,7 @@ def _calculate_regrid_area_weighted_weights(
):
"""
Compute the area weights used for area-weighted regridding.

Args:

* src_x_bounds:
A NumPy array of bounds along the X axis defining the source grid.
* src_y_bounds:
Expand All @@ -790,16 +783,12 @@ def _calculate_regrid_area_weighted_weights(
* area_func:
A function that returns a (p, q) array of weights given a (p, 2)
shaped array of Y bounds and a (q, 2) shaped array of X bounds.

Kwargs:

* circular:
A boolean indicating whether the `src_x_bounds` are periodic.
Default is False.

Returns:
The area weights to be used for area-weighted regridding.

"""
# Determine which grid bounds are within src extent.
y_within_bounds = _within_bounds(
Expand Down Expand Up @@ -886,7 +875,13 @@ def _calculate_regrid_area_weighted_weights(
tuple(cached_weights),
)

weights_info = _calculate_regrid_area_weighted_weights(
(
cached_x_indices,
cached_y_indices,
max_x_indices,
max_y_indices,
cached_weights,
) = _calculate_regrid_area_weighted_weights(
src_x_bounds,
src_y_bounds,
grid_x_bounds,
Expand All @@ -897,6 +892,123 @@ def _calculate_regrid_area_weighted_weights(
circular,
)

# Go further, calculating the full weights array that we'll need in the
# perform step and the indices we'll need to extract from the cube we're
# regridding (src_data)

result_y_extent = len(grid_y_bounds)
result_x_extent = len(grid_x_bounds)

# Total number of points
num_target_pts = result_y_extent * result_x_extent

# Create empty array to hold weights
src_area_weights = np.zeros(
list((max_y_indices, max_x_indices, num_target_pts))
)

# Built for the case where the source cube isn't masked
blank_weights = np.zeros((num_target_pts,))
new_data_mask_basis = np.full(
(len(cached_y_indices), len(cached_x_indices)), False, dtype=np.bool_
)

# To permit fancy indexing, we need to store our data in an array whose
# first two dimensions represent the indices needed for the target cell.
# Since target cells can require different numbers of indices, the size of
# these dimensions must be the maximum number required by any cell.
# This means we need to build those squared-off arrays and track which of
# their entries are actually required.
# TODO: Consider if a proper mask would be better.
src_area_datas_required = np.full(
(max_y_indices, max_x_indices, num_target_pts), False
)
square_data_indices_y = np.zeros(
(max_y_indices, max_x_indices, num_target_pts), dtype=int
)
square_data_indices_x = np.zeros(
(max_y_indices, max_x_indices, num_target_pts), dtype=int
)

# Stack the weights for each target point and build the indices we'll need
# to extract the src_area_data
target_pt_ji = -1
for j, y_indices in enumerate(cached_y_indices):
for i, x_indices in enumerate(cached_x_indices):
target_pt_ji += 1
# Determine whether to mask element i, j based on whether
# there are valid weights.
weights = cached_weights[j][i]
if weights is False:
# Prepare for the src_data not being masked by storing the
# information that will let us fill the data with zeros and
# weights as one. The weighted average result will be the same,
# but we avoid dividing by zero.
blank_weights[target_pt_ji] = True
new_data_mask_basis[j, i] = True
else:
# Establish which indices are actually in y_indices and x_indices
if isinstance(y_indices, slice):
y_indices = list(
range(
y_indices.start,
y_indices.stop,
y_indices.step or 1,
)
)
else:
y_indices = list(y_indices)

if isinstance(x_indices, slice):
x_indices = list(
range(
x_indices.start,
x_indices.stop,
x_indices.step or 1,
)
)
else:
x_indices = list(x_indices)

# For the weights, we just need the lengths of these as we're
# dropping them into a pre-made array

len_y = len(y_indices)
len_x = len(x_indices)

src_area_weights[0:len_y, 0:len_x, target_pt_ji] = weights

# To build the indices for the source cube, we need equally shaped
# arrays, so we pad with 0s and record which entries are real (not
# padding) in src_area_datas_required.
padded_y_indices = y_indices + [0] * (max_y_indices - len_y)
padded_x_indices = x_indices + [0] * (max_x_indices - len_x)

square_data_indices_y[..., target_pt_ji] = np.array(
padded_y_indices
)[:, np.newaxis]
square_data_indices_x[..., target_pt_ji] = padded_x_indices

src_area_datas_required[0:len_y, 0:len_x, target_pt_ji] = True

# Package up the return data

weights_info = (
blank_weights,
src_area_weights,
new_data_mask_basis,
)

index_info = (
result_x_extent,
result_y_extent,
square_data_indices_y,
square_data_indices_x,
src_area_datas_required,
)

# Now return it

return (
src_x,
src_y,
Expand All @@ -907,6 +1019,7 @@ def _calculate_regrid_area_weighted_weights(
meshgrid_x,
meshgrid_y,
weights_info,
index_info,
)


Expand All @@ -929,6 +1042,7 @@ def _regrid_area_weighted_rectilinear_src_and_grid__perform(
meshgrid_x,
meshgrid_y,
weights_info,
index_info,
) = regrid_info

# Calculate new data array for regridded cube.
Expand All @@ -937,6 +1051,7 @@ def _regrid_area_weighted_rectilinear_src_and_grid__perform(
x_dim=src_x_dim,
y_dim=src_y_dim,
weights_info=weights_info,
index_info=index_info,
mdtol=mdtol,
)

Expand Down
Loading