@@ -3184,6 +3184,81 @@ impl<A, D: Dimension> ArrayRef<A, D>
31843184 f ( & * prev, & mut * curr)
31853185 } ) ;
31863186 }
3187+
3188+ /// Return a partitioned copy of the array.
3189+ ///
3190+ /// Creates a copy of the array and partially sorts it around the k-th element along the given axis.
3191+ /// The k-th element will be in its sorted position, with:
3192+ /// - All elements smaller than the k-th element to its left
3193+ /// - All elements equal or greater than the k-th element to its right
3194+ /// - The ordering within each partition is undefined
3195+ ///
3196+ /// # Parameters
3197+ ///
3198+ /// * `kth` - Index to partition by. The k-th element will be in its sorted position.
3199+ /// * `axis` - Axis along which to partition.
3200+ ///
3201+ /// # Returns
3202+ ///
3203+ /// A new array of the same shape and type as the input array, with elements partitioned.
3204+ ///
3205+ /// # Examples
3206+ ///
3207+ /// ```
3208+ /// use ndarray::prelude::*;
3209+ ///
3210+ /// let a = array![7, 1, 5, 2, 6, 0, 3, 4];
3211+ /// let p = a.partition(3, Axis(0));
3212+ ///
3213+ /// // The element at position 3 is now 3, with smaller elements to the left
3214+ /// // and greater elements to the right
3215+ /// assert_eq!(p[3], 3);
3216+ /// assert!(p.slice(s![..3]).iter().all(|&x| x <= 3));
3217+ /// assert!(p.slice(s![4..]).iter().all(|&x| x >= 3));
3218+ /// ```
3219+ pub fn partition ( & self , kth : usize , axis : Axis ) -> Array < A , D >
3220+ where
3221+ A : Clone + Ord + num_traits:: Zero ,
3222+ D : Dimension ,
3223+ {
3224+ // Bounds checking
3225+ let axis_len = self . len_of ( axis) ;
3226+ if kth >= axis_len {
3227+ panic ! ( "partition index {} is out of bounds for axis of length {}" , kth, axis_len) ;
3228+ }
3229+
3230+ let mut result = self . to_owned ( ) ;
3231+
3232+ // Check if the first lane is contiguous
3233+ let is_contiguous = result
3234+ . lanes_mut ( axis)
3235+ . into_iter ( )
3236+ . next ( )
3237+ . unwrap ( )
3238+ . is_contiguous ( ) ;
3239+
3240+ if is_contiguous {
3241+ Zip :: from ( result. lanes_mut ( axis) ) . for_each ( |mut lane| {
3242+ lane. as_slice_mut ( ) . unwrap ( ) . select_nth_unstable ( kth) ;
3243+ } ) ;
3244+ } else {
3245+ let mut temp_vec = vec ! [ A :: zero( ) ; axis_len] ;
3246+
3247+ Zip :: from ( result. lanes_mut ( axis) ) . for_each ( |mut lane| {
3248+ Zip :: from ( & mut temp_vec) . and ( & lane) . for_each ( |dest, src| {
3249+ * dest = src. clone ( ) ;
3250+ } ) ;
3251+
3252+ temp_vec. select_nth_unstable ( kth) ;
3253+
3254+ Zip :: from ( & mut lane) . and ( & temp_vec) . for_each ( |dest, src| {
3255+ * dest = src. clone ( ) ;
3256+ } ) ;
3257+ } ) ;
3258+ }
3259+
3260+ result
3261+ }
31873262}
31883263
31893264/// Transmute from A to B.
@@ -3277,4 +3352,121 @@ mod tests
32773352 let _a2 = a. clone ( ) ;
32783353 assert_first ! ( a) ;
32793354 }
3355+
3356+ #[ test]
3357+ fn test_partition_1d ( )
3358+ {
3359+ // Test partitioning a 1D array
3360+ let array = arr1 ( & [ 3 , 1 , 4 , 1 , 5 , 9 , 2 , 6 ] ) ;
3361+ let result = array. partition ( 3 , Axis ( 0 ) ) ;
3362+ // After partitioning, the element at index 3 should be in its final sorted position
3363+ assert ! ( result. slice( s![ ..3 ] ) . iter( ) . all( |& x| x <= result[ 3 ] ) ) ;
3364+ assert ! ( result. slice( s![ 4 ..] ) . iter( ) . all( |& x| x >= result[ 3 ] ) ) ;
3365+ }
3366+
3367+ #[ test]
3368+ fn test_partition_2d ( )
3369+ {
3370+ // Test partitioning a 2D array along both axes
3371+ let array = arr2 ( & [ [ 3 , 1 , 4 ] , [ 1 , 5 , 9 ] , [ 2 , 6 , 5 ] ] ) ;
3372+
3373+ // Partition along axis 0 (rows)
3374+ let result0 = array. partition ( 1 , Axis ( 0 ) ) ;
3375+ // After partitioning along axis 0, each column should have its middle element in the correct position
3376+ assert ! ( result0[ [ 0 , 0 ] ] <= result0[ [ 1 , 0 ] ] && result0[ [ 2 , 0 ] ] >= result0[ [ 1 , 0 ] ] ) ;
3377+ assert ! ( result0[ [ 0 , 1 ] ] <= result0[ [ 1 , 1 ] ] && result0[ [ 2 , 1 ] ] >= result0[ [ 1 , 1 ] ] ) ;
3378+ assert ! ( result0[ [ 0 , 2 ] ] <= result0[ [ 1 , 2 ] ] && result0[ [ 2 , 2 ] ] >= result0[ [ 1 , 2 ] ] ) ;
3379+
3380+ // Partition along axis 1 (columns)
3381+ let result1 = array. partition ( 1 , Axis ( 1 ) ) ;
3382+ // After partitioning along axis 1, each row should have its middle element in the correct position
3383+ assert ! ( result1[ [ 0 , 0 ] ] <= result1[ [ 0 , 1 ] ] && result1[ [ 0 , 2 ] ] >= result1[ [ 0 , 1 ] ] ) ;
3384+ assert ! ( result1[ [ 1 , 0 ] ] <= result1[ [ 1 , 1 ] ] && result1[ [ 1 , 2 ] ] >= result1[ [ 1 , 1 ] ] ) ;
3385+ assert ! ( result1[ [ 2 , 0 ] ] <= result1[ [ 2 , 1 ] ] && result1[ [ 2 , 2 ] ] >= result1[ [ 2 , 1 ] ] ) ;
3386+ }
3387+
3388+ #[ test]
3389+ fn test_partition_3d ( )
3390+ {
3391+ // Test partitioning a 3D array
3392+ let array = arr3 ( & [ [ [ 3 , 1 ] , [ 4 , 1 ] ] , [ [ 5 , 9 ] , [ 2 , 6 ] ] ] ) ;
3393+
3394+ // Partition along axis 0
3395+ let result = array. partition ( 0 , Axis ( 0 ) ) ;
3396+ // After partitioning, each 2x2 slice should have its first element in the correct position
3397+ assert ! ( result[ [ 0 , 0 , 0 ] ] <= result[ [ 1 , 0 , 0 ] ] ) ;
3398+ assert ! ( result[ [ 0 , 0 , 1 ] ] <= result[ [ 1 , 0 , 1 ] ] ) ;
3399+ assert ! ( result[ [ 0 , 1 , 0 ] ] <= result[ [ 1 , 1 , 0 ] ] ) ;
3400+ assert ! ( result[ [ 0 , 1 , 1 ] ] <= result[ [ 1 , 1 , 1 ] ] ) ;
3401+ }
3402+
3403+ #[ test]
3404+ #[ should_panic]
3405+ fn test_partition_invalid_kth ( )
3406+ {
3407+ let a = array ! [ 1 , 2 , 3 , 4 ] ;
3408+ // This should panic because kth=4 is out of bounds
3409+ let _ = a. partition ( 4 , Axis ( 0 ) ) ;
3410+ }
3411+
3412+ #[ test]
3413+ #[ should_panic]
3414+ fn test_partition_invalid_axis ( )
3415+ {
3416+ let a = array ! [ 1 , 2 , 3 , 4 ] ;
3417+ // This should panic because axis=1 is out of bounds for a 1D array
3418+ let _ = a. partition ( 0 , Axis ( 1 ) ) ;
3419+ }
3420+
3421+ #[ test]
3422+ fn test_partition_contiguous_or_not ( )
3423+ {
3424+ // Test contiguous case (C-order)
3425+ let a = array ! [
3426+ [ 7 , 1 , 5 ] ,
3427+ [ 2 , 6 , 0 ] ,
3428+ [ 3 , 4 , 8 ]
3429+ ] ;
3430+
3431+ // Partition along axis 0 (contiguous)
3432+ let p_axis0 = a. partition ( 1 , Axis ( 0 ) ) ;
3433+
3434+ // For each column, verify the partitioning:
3435+ // - First row should be <= middle row (kth element)
3436+ // - Last row should be >= middle row (kth element)
3437+ for col in 0 ..3 {
3438+ let kth = p_axis0[ [ 1 , col] ] ;
3439+ assert ! ( p_axis0[ [ 0 , col] ] <= kth,
3440+ "Column {}: First row {} should be <= middle row {}" ,
3441+ col, p_axis0[ [ 0 , col] ] , kth) ;
3442+ assert ! ( p_axis0[ [ 2 , col] ] >= kth,
3443+ "Column {}: Last row {} should be >= middle row {}" ,
3444+ col, p_axis0[ [ 2 , col] ] , kth) ;
3445+ }
3446+
3447+ // Test non-contiguous case (F-order)
3448+ let a = array ! [
3449+ [ 7 , 1 , 5 ] ,
3450+ [ 2 , 6 , 0 ] ,
3451+ [ 3 , 4 , 8 ]
3452+ ] ;
3453+
3454+ // Make array non-contiguous by transposing
3455+ let a = a. t ( ) . to_owned ( ) ;
3456+
3457+ // Partition along axis 1 (non-contiguous)
3458+ let p_axis1 = a. partition ( 1 , Axis ( 1 ) ) ;
3459+
3460+ // For each row, verify the partitioning:
3461+ // - First column should be <= middle column
3462+ // - Last column should be >= middle column
3463+ for row in 0 ..3 {
3464+ assert ! ( p_axis1[ [ row, 0 ] ] <= p_axis1[ [ row, 1 ] ] ,
3465+ "Row {}: First column {} should be <= middle column {}" ,
3466+ row, p_axis1[ [ row, 0 ] ] , p_axis1[ [ row, 1 ] ] ) ;
3467+ assert ! ( p_axis1[ [ row, 2 ] ] >= p_axis1[ [ row, 1 ] ] ,
3468+ "Row {}: Last column {} should be >= middle column {}" ,
3469+ row, p_axis1[ [ row, 2 ] ] , p_axis1[ [ row, 1 ] ] ) ;
3470+ }
3471+ }
32803472}
0 commit comments