@@ -491,39 +491,30 @@ def _regrid_area_weighted_array(
491491 cached_x_bounds .append (x_bounds )
492492 cached_x_indices .append (x_indices )
493493
494- # Move y_dim and x_dim to last dimensions
494+ # Ensure we have x_dim and y_dim.
495495 x_dim_orig = copy .copy (x_dim )
496496 y_dim_orig = copy .copy (y_dim )
497- if x_dim is None and y_dim is None :
498- # e.g. a scalar point such as a vertical profile
499- pass
500- elif x_dim is not None and y_dim is None :
501- # test cross_section along line latitude
502- src_data = np .moveaxis (src_data , x_dim , - 1 )
503- x_dim = src_data .ndim - 1
504- elif y_dim is not None and x_dim is None :
505- # test cross_section along line longitude
506- src_data = np .moveaxis (src_data , y_dim , - 1 )
497+ if y_dim is None :
498+ src_data = np .expand_dims (src_data , axis = src_data .ndim )
507499 y_dim = src_data .ndim - 1
508- elif x_dim < y_dim :
509- src_data = np .moveaxis (src_data , x_dim , - 1 )
510- src_data = np .moveaxis (src_data , y_dim - 1 , - 2 )
500+ if x_dim is None :
501+ src_data = np .expand_dims (src_data , axis = src_data .ndim )
511502 x_dim = src_data .ndim - 1
512- y_dim = src_data .ndim - 2
513- else :
514- src_data = np .moveaxis (src_data , x_dim , - 1 )
503+ # Move y_dim and x_dim to last dimensions
504+ src_data = np .moveaxis (src_data , x_dim , - 1 )
505+ if x_dim < y_dim :
506+ src_data = np .moveaxis (src_data , y_dim - 1 , - 2 )
507+ elif x_dim > y_dim :
515508 src_data = np .moveaxis (src_data , y_dim , - 2 )
516- x_dim = src_data .ndim - 1
517- y_dim = src_data .ndim - 2
509+ x_dim = src_data .ndim - 1
510+ y_dim = src_data .ndim - 2
518511
519512 # Create empty data array to match the new grid.
520513 # Note that dtype is not preserved and that the array is
521514 # masked to allow for regions that do not overlap.
522515 new_shape = list (src_data .shape )
523- if x_dim is not None :
524- new_shape [x_dim ] = grid_x_bounds .shape [0 ]
525- if y_dim is not None :
526- new_shape [y_dim ] = grid_y_bounds .shape [0 ]
516+ new_shape [x_dim ] = grid_x_bounds .shape [0 ]
517+ new_shape [y_dim ] = grid_y_bounds .shape [0 ]
527518
528519 # Use input cube dtype or convert values to the smallest possible float
529520 # dtype when necessary.
@@ -541,15 +532,9 @@ def _regrid_area_weighted_array(
541532 new_data .mask = False
542533
543534 # Axes of data over which the weighted mean is calculated.
544- axes = []
545- if y_dim is not None :
546- axes .append (y_dim )
547- if x_dim is not None :
548- axes .append (x_dim )
549- axis = tuple (axes )
535+ axis = (y_dim , x_dim )
550536
551537 # Simple for loop approach.
552- indices = [slice (None )] * new_data .ndim
553538 for j , (y_0 , y_1 ) in enumerate (grid_y_bounds ):
554539 # Reverse lower and upper if dest grid is decreasing.
555540 if grid_y_decreasing :
@@ -575,11 +560,7 @@ def _regrid_area_weighted_array(
575560 or not x_within_bounds [i ]
576561 ):
577562 # Mask out element(s) in new_data
578- if x_dim is not None :
579- indices [x_dim ] = i
580- if y_dim is not None :
581- indices [y_dim ] = j
582- new_data [tuple (indices )] = ma .masked
563+ new_data [..., j , i ] = ma .masked
583564 else :
584565 # Calculate weighted mean of data points.
585566 # Slice out relevant data (this may or may not be a view()
@@ -593,22 +574,16 @@ def _regrid_area_weighted_array(
593574 # Calculate weights based on areas of cropped bounds.
594575 weights = area_func (y_bounds , x_bounds )
595576
596- if x_dim is not None :
597- indices [x_dim ] = x_indices
598- if y_dim is not None :
599- indices [y_dim ] = y_indices
600- data = src_data [tuple (indices )]
577+ data = src_data [..., y_indices , x_indices ]
601578
602579 # Transpose weights to match dim ordering in data.
603580 weights_shape_y = weights .shape [0 ]
604581 weights_shape_x = weights .shape [1 ]
605582 # Broadcast the weights array to allow numpy's ma.average
606583 # to be called.
607584 weights_padded_shape = [1 ] * data .ndim
608- if y_dim is not None :
609- weights_padded_shape [y_dim ] = weights_shape_y
610- if x_dim is not None :
611- weights_padded_shape [x_dim ] = weights_shape_x
585+ weights_padded_shape [y_dim ] = weights_shape_y
586+ weights_padded_shape [x_dim ] = weights_shape_x
612587 # Assign new shape to raise error on copy.
613588 weights .shape = weights_padded_shape
614589 # Broadcast weights to match shape of data.
@@ -620,11 +595,7 @@ def _regrid_area_weighted_array(
620595 )
621596
622597 # Insert data (and mask) values into new array.
623- if x_dim is not None :
624- indices [x_dim ] = i
625- if y_dim is not None :
626- indices [y_dim ] = j
627- new_data [tuple (indices )] = new_data_pt
598+ new_data [..., j , i ] = new_data_pt
628599
629600 # Remove new mask if original data was not masked
630601 # and no values in the new array are masked.
@@ -633,10 +604,13 @@ def _regrid_area_weighted_array(
633604
634605 # Restore axis to original order
635606 if x_dim_orig is None and y_dim_orig is None :
636- pass
637- elif x_dim_orig is not None and y_dim_orig is None :
607+ new_data = np .squeeze (new_data , axis = x_dim )
608+ new_data = np .squeeze (new_data , axis = y_dim )
609+ elif y_dim_orig is None :
610+ new_data = np .squeeze (new_data , axis = y_dim )
638611 new_data = np .moveaxis (new_data , - 1 , x_dim_orig )
639- elif y_dim_orig is not None and x_dim_orig is None :
612+ elif x_dim_orig is None :
613+ new_data = np .squeeze (new_data , axis = x_dim )
640614 new_data = np .moveaxis (new_data , - 1 , y_dim_orig )
641615 elif x_dim_orig < y_dim_orig :
642616 new_data = np .moveaxis (new_data , - 1 , x_dim_orig )
0 commit comments