@@ -576,11 +576,7 @@ where
576576 pub fn slice_move < I > ( mut self , info : I ) -> ArrayBase < S , I :: OutDim >
577577 where I : SliceArg < D >
578578 {
579- assert_eq ! (
580- info. in_ndim( ) ,
581- self . ndim( ) ,
582- "The input dimension of `info` must match the array to be sliced." ,
583- ) ;
579+ assert_eq ! ( info. in_ndim( ) , self . ndim( ) , "The input dimension of `info` must match the array to be sliced." , ) ;
584580 let out_ndim = info. out_ndim ( ) ;
585581 let mut new_dim = I :: OutDim :: zeros ( out_ndim) ;
586582 let mut new_strides = I :: OutDim :: zeros ( out_ndim) ;
@@ -648,11 +644,7 @@ impl<A, D: Dimension> LayoutRef<A, D>
648644 pub fn slice_collapse < I > ( & mut self , info : I )
649645 where I : SliceArg < D >
650646 {
651- assert_eq ! (
652- info. in_ndim( ) ,
653- self . ndim( ) ,
654- "The input dimension of `info` must match the array to be sliced." ,
655- ) ;
647+ assert_eq ! ( info. in_ndim( ) , self . ndim( ) , "The input dimension of `info` must match the array to be sliced." , ) ;
656648 let mut axis = 0 ;
657649 info. as_ref ( ) . iter ( ) . for_each ( |& ax_info| match ax_info {
658650 SliceInfoElem :: Slice { start, end, step } => {
@@ -1120,8 +1112,7 @@ impl<A, D: Dimension> ArrayRef<A, D>
11201112 // bounds check the indices first
11211113 if let Some ( max_index) = indices. iter ( ) . cloned ( ) . max ( ) {
11221114 if max_index >= axis_len {
1123- panic ! ( "ndarray: index {} is out of bounds in array of len {}" ,
1124- max_index, self . len_of( axis) ) ;
1115+ panic ! ( "ndarray: index {} is out of bounds in array of len {}" , max_index, self . len_of( axis) ) ;
11251116 }
11261117 } // else: indices empty is ok
11271118 let view = self . view ( ) . into_dimensionality :: < Ix1 > ( ) . unwrap ( ) ;
@@ -1530,10 +1521,7 @@ impl<A, D: Dimension> ArrayRef<A, D>
15301521
15311522 ndassert ! (
15321523 axis_index < self . ndim( ) ,
1533- concat!(
1534- "Window axis {} does not match array dimension {} " ,
1535- "(with array of shape {:?})"
1536- ) ,
1524+ concat!( "Window axis {} does not match array dimension {} " , "(with array of shape {:?})" ) ,
15371525 axis_index,
15381526 self . ndim( ) ,
15391527 self . shape( )
@@ -3119,8 +3107,7 @@ where
31193107 /// ***Panics*** if not `index < self.len_of(axis)`.
31203108 pub fn remove_index ( & mut self , axis : Axis , index : usize )
31213109 {
3122- assert ! ( index < self . len_of( axis) , "index {} must be less than length of Axis({})" ,
3123- index, axis. index( ) ) ;
3110+ assert ! ( index < self . len_of( axis) , "index {} must be less than length of Axis({})" , index, axis. index( ) ) ;
31243111 let ( _, mut tail) = self . view_mut ( ) . split_at ( axis, index) ;
31253112 // shift elements to the front
31263113 Zip :: from ( tail. lanes_mut ( axis) ) . for_each ( |mut lane| lane. rotate1_front ( ) ) ;
@@ -3193,15 +3180,16 @@ impl<A, D: Dimension> ArrayRef<A, D>
31933180 /// - All elements equal or greater than the k-th element to its right
31943181 /// - The ordering within each partition is undefined
31953182 ///
3183+ /// Empty arrays (i.e., those with any zero-length axes) are considered partitioned already,
3184+ /// and will be returned unchanged.
3185+ ///
3186+ /// **Panics** if `k` is out of bounds for a non-zero axis length.
3187+ ///
31963188 /// # Parameters
31973189 ///
31983190 /// * `kth` - Index to partition by. The k-th element will be in its sorted position.
31993191 /// * `axis` - Axis along which to partition.
32003192 ///
3201- /// # Returns
3202- ///
3203- /// A new array of the same shape and type as the input array, with elements partitioned.
3204- ///
32053193 /// # Examples
32063194 ///
32073195 /// ```
@@ -3221,19 +3209,19 @@ impl<A, D: Dimension> ArrayRef<A, D>
32213209 A : Clone + Ord + num_traits:: Zero ,
32223210 D : Dimension ,
32233211 {
3224- // Bounds checking
3225- let axis_len = self . len_of ( axis) ;
3226- if kth >= axis_len {
3227- panic ! ( "partition index {} is out of bounds for axis of length {}" , kth, axis_len) ;
3228- }
3229-
32303212 let mut result = self . to_owned ( ) ;
32313213
3232- // Must guarantee that the array isn't empty before checking for contiguity
3233- if result . shape ( ) . iter ( ) . any ( |s| * s == 0 ) {
3214+ // Return early if the array has zero-length dimensions
3215+ if self . shape ( ) . iter ( ) . any ( |s| * s == 0 ) {
32343216 return result;
32353217 }
32363218
3219+ // Bounds checking. Panics if kth is out of bounds
3220+ let axis_len = self . len_of ( axis) ;
3221+ if kth >= axis_len {
3222+ panic ! ( "Partition index {} is out of bounds for axis {} of length {}" , kth, axis. 0 , axis_len) ;
3223+ }
3224+
32373225 // Check if the first lane is contiguous
32383226 let is_contiguous = result
32393227 . lanes_mut ( axis)
@@ -3428,11 +3416,7 @@ mod tests
34283416 fn test_partition_contiguous_or_not ( )
34293417 {
34303418 // Test contiguous case (C-order)
3431- let a = array ! [
3432- [ 7 , 1 , 5 ] ,
3433- [ 2 , 6 , 0 ] ,
3434- [ 3 , 4 , 8 ]
3435- ] ;
3419+ let a = array ! [ [ 7 , 1 , 5 ] , [ 2 , 6 , 0 ] , [ 3 , 4 , 8 ] ] ;
34363420
34373421 // Partition along axis 0 (contiguous)
34383422 let p_axis0 = a. partition ( 1 , Axis ( 0 ) ) ;
@@ -3442,20 +3426,24 @@ mod tests
34423426 // - Last row should be >= middle row (kth element)
34433427 for col in 0 ..3 {
34443428 let kth = p_axis0[ [ 1 , col] ] ;
3445- assert ! ( p_axis0[ [ 0 , col] ] <= kth,
3429+ assert ! (
3430+ p_axis0[ [ 0 , col] ] <= kth,
34463431 "Column {}: First row {} should be <= middle row {}" ,
3447- col, p_axis0[ [ 0 , col] ] , kth) ;
3448- assert ! ( p_axis0[ [ 2 , col] ] >= kth,
3432+ col,
3433+ p_axis0[ [ 0 , col] ] ,
3434+ kth
3435+ ) ;
3436+ assert ! (
3437+ p_axis0[ [ 2 , col] ] >= kth,
34493438 "Column {}: Last row {} should be >= middle row {}" ,
3450- col, p_axis0[ [ 2 , col] ] , kth) ;
3439+ col,
3440+ p_axis0[ [ 2 , col] ] ,
3441+ kth
3442+ ) ;
34513443 }
34523444
34533445 // Test non-contiguous case (F-order)
3454- let a = array ! [
3455- [ 7 , 1 , 5 ] ,
3456- [ 2 , 6 , 0 ] ,
3457- [ 3 , 4 , 8 ]
3458- ] ;
3446+ let a = array ! [ [ 7 , 1 , 5 ] , [ 2 , 6 , 0 ] , [ 3 , 4 , 8 ] ] ;
34593447
34603448 // Make array non-contiguous by transposing
34613449 let a = a. t ( ) . to_owned ( ) ;
@@ -3467,12 +3455,69 @@ mod tests
34673455 // - First column should be <= middle column
34683456 // - Last column should be >= middle column
34693457 for row in 0 ..3 {
3470- assert ! ( p_axis1[ [ row, 0 ] ] <= p_axis1[ [ row, 1 ] ] ,
3458+ assert ! (
3459+ p_axis1[ [ row, 0 ] ] <= p_axis1[ [ row, 1 ] ] ,
34713460 "Row {}: First column {} should be <= middle column {}" ,
3472- row, p_axis1[ [ row, 0 ] ] , p_axis1[ [ row, 1 ] ] ) ;
3473- assert ! ( p_axis1[ [ row, 2 ] ] >= p_axis1[ [ row, 1 ] ] ,
3461+ row,
3462+ p_axis1[ [ row, 0 ] ] ,
3463+ p_axis1[ [ row, 1 ] ]
3464+ ) ;
3465+ assert ! (
3466+ p_axis1[ [ row, 2 ] ] >= p_axis1[ [ row, 1 ] ] ,
34743467 "Row {}: Last column {} should be >= middle column {}" ,
3475- row, p_axis1[ [ row, 2 ] ] , p_axis1[ [ row, 1 ] ] ) ;
3468+ row,
3469+ p_axis1[ [ row, 2 ] ] ,
3470+ p_axis1[ [ row, 1 ] ]
3471+ ) ;
34763472 }
34773473 }
3474+
3475+ #[ test]
3476+ fn test_partition_empty ( )
3477+ {
3478+ // Test 1D empty array
3479+ let empty1d = Array1 :: < i32 > :: zeros ( 0 ) ;
3480+ let result1d = empty1d. partition ( 0 , Axis ( 0 ) ) ;
3481+ assert_eq ! ( result1d. len( ) , 0 ) ;
3482+
3483+ // Test 1D empty array with kth out of bounds
3484+ let result1d_out_of_bounds = empty1d. partition ( 1 , Axis ( 0 ) ) ;
3485+ assert_eq ! ( result1d_out_of_bounds. len( ) , 0 ) ;
3486+
3487+ // Test 2D empty array
3488+ let empty2d = Array2 :: < i32 > :: zeros ( ( 0 , 3 ) ) ;
3489+ let result2d = empty2d. partition ( 0 , Axis ( 0 ) ) ;
3490+ assert_eq ! ( result2d. shape( ) , & [ 0 , 3 ] ) ;
3491+
3492+ // Test 2D empty array with zero columns
3493+ let empty2d_cols = Array2 :: < i32 > :: zeros ( ( 2 , 0 ) ) ;
3494+ let result2d_cols = empty2d_cols. partition ( 0 , Axis ( 1 ) ) ;
3495+ assert_eq ! ( result2d_cols. shape( ) , & [ 2 , 0 ] ) ;
3496+
3497+ // Test 3D empty array
3498+ let empty3d = Array3 :: < i32 > :: zeros ( ( 0 , 2 , 3 ) ) ;
3499+ let result3d = empty3d. partition ( 0 , Axis ( 0 ) ) ;
3500+ assert_eq ! ( result3d. shape( ) , & [ 0 , 2 , 3 ] ) ;
3501+
3502+ // Test 3D empty array with zero in middle dimension
3503+ let empty3d_mid = Array3 :: < i32 > :: zeros ( ( 2 , 0 , 3 ) ) ;
3504+ let result3d_mid = empty3d_mid. partition ( 0 , Axis ( 1 ) ) ;
3505+ assert_eq ! ( result3d_mid. shape( ) , & [ 2 , 0 , 3 ] ) ;
3506+
3507+ // Test 4D empty array
3508+ let empty4d = Array4 :: < i32 > :: zeros ( ( 0 , 2 , 3 , 4 ) ) ;
3509+ let result4d = empty4d. partition ( 0 , Axis ( 0 ) ) ;
3510+ assert_eq ! ( result4d. shape( ) , & [ 0 , 2 , 3 , 4 ] ) ;
3511+
3512+ // Test empty array with non-zero dimensions in other axes
3513+ let empty_mixed = Array2 :: < i32 > :: zeros ( ( 0 , 5 ) ) ;
3514+ let result_mixed = empty_mixed. partition ( 0 , Axis ( 0 ) ) ;
3515+ assert_eq ! ( result_mixed. shape( ) , & [ 0 , 5 ] ) ;
3516+
3517+ // Test empty array with negative strides
3518+ let arr = Array2 :: < i32 > :: zeros ( ( 3 , 3 ) ) ;
3519+ let empty_slice = arr. slice ( s ! [ 0 ..0 , ..] ) ;
3520+ let result_slice = empty_slice. partition ( 0 , Axis ( 0 ) ) ;
3521+ assert_eq ! ( result_slice. shape( ) , & [ 0 , 3 ] ) ;
3522+ }
34783523}
0 commit comments