Skip to content

Commit cd44a3e

Browse files
authored
PI-2472: Tweak area weighting regrid enforce xdim ydim (#3595)
* _regrid_area_weighted_array: Set axis order to y_dim, x_dim last dimensions * _regrid_area_weighted_array: Extra tests for axes ordering * _regrid_area_weighted_array: Ensure x_dim and y_dim
1 parent e3e61b3 commit cd44a3e

File tree

1 file changed

+26
-52
lines changed

1 file changed

+26
-52
lines changed

lib/iris/experimental/regrid.py

Lines changed: 26 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -491,39 +491,30 @@ def _regrid_area_weighted_array(
491491
cached_x_bounds.append(x_bounds)
492492
cached_x_indices.append(x_indices)
493493

494-
# Move y_dim and x_dim to last dimensions
494+
# Ensure we have x_dim and y_dim.
495495
x_dim_orig = copy.copy(x_dim)
496496
y_dim_orig = copy.copy(y_dim)
497-
if x_dim is None and y_dim is None:
498-
# e.g. a scalar point such as a vertical profile
499-
pass
500-
elif x_dim is not None and y_dim is None:
501-
# test cross_section along line latitude
502-
src_data = np.moveaxis(src_data, x_dim, -1)
503-
x_dim = src_data.ndim - 1
504-
elif y_dim is not None and x_dim is None:
505-
# test cross_section along line longitude
506-
src_data = np.moveaxis(src_data, y_dim, -1)
497+
if y_dim is None:
498+
src_data = np.expand_dims(src_data, axis=src_data.ndim)
507499
y_dim = src_data.ndim - 1
508-
elif x_dim < y_dim:
509-
src_data = np.moveaxis(src_data, x_dim, -1)
510-
src_data = np.moveaxis(src_data, y_dim - 1, -2)
500+
if x_dim is None:
501+
src_data = np.expand_dims(src_data, axis=src_data.ndim)
511502
x_dim = src_data.ndim - 1
512-
y_dim = src_data.ndim - 2
513-
else:
514-
src_data = np.moveaxis(src_data, x_dim, -1)
503+
# Move y_dim and x_dim to last dimensions
504+
src_data = np.moveaxis(src_data, x_dim, -1)
505+
if x_dim < y_dim:
506+
src_data = np.moveaxis(src_data, y_dim - 1, -2)
507+
elif x_dim > y_dim:
515508
src_data = np.moveaxis(src_data, y_dim, -2)
516-
x_dim = src_data.ndim - 1
517-
y_dim = src_data.ndim - 2
509+
x_dim = src_data.ndim - 1
510+
y_dim = src_data.ndim - 2
518511

519512
# Create empty data array to match the new grid.
520513
# Note that dtype is not preserved and that the array is
521514
# masked to allow for regions that do not overlap.
522515
new_shape = list(src_data.shape)
523-
if x_dim is not None:
524-
new_shape[x_dim] = grid_x_bounds.shape[0]
525-
if y_dim is not None:
526-
new_shape[y_dim] = grid_y_bounds.shape[0]
516+
new_shape[x_dim] = grid_x_bounds.shape[0]
517+
new_shape[y_dim] = grid_y_bounds.shape[0]
527518

528519
# Use input cube dtype or convert values to the smallest possible float
529520
# dtype when necessary.
@@ -541,15 +532,9 @@ def _regrid_area_weighted_array(
541532
new_data.mask = False
542533

543534
# Axes of data over which the weighted mean is calculated.
544-
axes = []
545-
if y_dim is not None:
546-
axes.append(y_dim)
547-
if x_dim is not None:
548-
axes.append(x_dim)
549-
axis = tuple(axes)
535+
axis = (y_dim, x_dim)
550536

551537
# Simple for loop approach.
552-
indices = [slice(None)] * new_data.ndim
553538
for j, (y_0, y_1) in enumerate(grid_y_bounds):
554539
# Reverse lower and upper if dest grid is decreasing.
555540
if grid_y_decreasing:
@@ -575,11 +560,7 @@ def _regrid_area_weighted_array(
575560
or not x_within_bounds[i]
576561
):
577562
# Mask out element(s) in new_data
578-
if x_dim is not None:
579-
indices[x_dim] = i
580-
if y_dim is not None:
581-
indices[y_dim] = j
582-
new_data[tuple(indices)] = ma.masked
563+
new_data[..., j, i] = ma.masked
583564
else:
584565
# Calculate weighted mean of data points.
585566
# Slice out relevant data (this may or may not be a view()
@@ -593,22 +574,16 @@ def _regrid_area_weighted_array(
593574
# Calculate weights based on areas of cropped bounds.
594575
weights = area_func(y_bounds, x_bounds)
595576

596-
if x_dim is not None:
597-
indices[x_dim] = x_indices
598-
if y_dim is not None:
599-
indices[y_dim] = y_indices
600-
data = src_data[tuple(indices)]
577+
data = src_data[..., y_indices, x_indices]
601578

602579
# Transpose weights to match dim ordering in data.
603580
weights_shape_y = weights.shape[0]
604581
weights_shape_x = weights.shape[1]
605582
# Broadcast the weights array to allow numpy's ma.average
606583
# to be called.
607584
weights_padded_shape = [1] * data.ndim
608-
if y_dim is not None:
609-
weights_padded_shape[y_dim] = weights_shape_y
610-
if x_dim is not None:
611-
weights_padded_shape[x_dim] = weights_shape_x
585+
weights_padded_shape[y_dim] = weights_shape_y
586+
weights_padded_shape[x_dim] = weights_shape_x
612587
# Assign new shape to raise error on copy.
613588
weights.shape = weights_padded_shape
614589
# Broadcast weights to match shape of data.
@@ -620,11 +595,7 @@ def _regrid_area_weighted_array(
620595
)
621596

622597
# Insert data (and mask) values into new array.
623-
if x_dim is not None:
624-
indices[x_dim] = i
625-
if y_dim is not None:
626-
indices[y_dim] = j
627-
new_data[tuple(indices)] = new_data_pt
598+
new_data[..., j, i] = new_data_pt
628599

629600
# Remove new mask if original data was not masked
630601
# and no values in the new array are masked.
@@ -633,10 +604,13 @@ def _regrid_area_weighted_array(
633604

634605
# Restore axis to original order
635606
if x_dim_orig is None and y_dim_orig is None:
636-
pass
637-
elif x_dim_orig is not None and y_dim_orig is None:
607+
new_data = np.squeeze(new_data, axis=x_dim)
608+
new_data = np.squeeze(new_data, axis=y_dim)
609+
elif y_dim_orig is None:
610+
new_data = np.squeeze(new_data, axis=y_dim)
638611
new_data = np.moveaxis(new_data, -1, x_dim_orig)
639-
elif y_dim_orig is not None and x_dim_orig is None:
612+
elif x_dim_orig is None:
613+
new_data = np.squeeze(new_data, axis=x_dim)
640614
new_data = np.moveaxis(new_data, -1, y_dim_orig)
641615
elif x_dim_orig < y_dim_orig:
642616
new_data = np.moveaxis(new_data, -1, x_dim_orig)

0 commit comments

Comments
 (0)