diff --git a/java/src/main/java/ai/rapids/cudf/RollingAggregation.java b/java/src/main/java/ai/rapids/cudf/RollingAggregation.java index 408c93ff0a1..718e17c7f5c 100644 --- a/java/src/main/java/ai/rapids/cudf/RollingAggregation.java +++ b/java/src/main/java/ai/rapids/cudf/RollingAggregation.java @@ -217,4 +217,14 @@ public static RollingAggregation collectSet() { public static RollingAggregation collectSet(NullPolicy nullPolicy, NullEquality nullEquality, NaNEquality nanEquality) { return new RollingAggregation(Aggregation.collectSet(nullPolicy, nullEquality, nanEquality)); } + + /** + * Select the nth element from a specified window. + * + * @param n Indicates the index of the element to be selected from the window + * @param nullPolicy Indicates whether null elements are to be skipped, or not + */ + public static RollingAggregation nth(int n, NullPolicy nullPolicy) { + return new RollingAggregation(Aggregation.nth(n, nullPolicy)); + } } diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index af28cfb6d6c..557bd8f289a 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -5315,6 +5315,56 @@ void testWindowingMean() { } } + @Test + void testWindowingNthElement() { + final Integer X = null; + try (Table unsorted = new Table.TestBuilder() + .column( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // 0: GBY Key + .column( 3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1) // 1: GBY Key + .column( 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0) // 2: OBY Key + .column( X, 4, 0, X, 4, X, 9, 7, 7, 3, 5, 7) // 3: Agg Column + .build()) { + try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(2)); + ColumnVector expectedSortedAggCol = ColumnVector.fromBoxedInts(7, 5, 3, 7, 7, 9, X, 4, X, 0, 4, X)) { + ColumnVector sortedAggColumn = sorted.getColumn(3); + assertColumnsAreEqual(expectedSortedAggCol, sortedAggColumn); + + try (Scalar one = Scalar.fromInt(1); + Scalar two = Scalar.fromInt(2); + WindowOptions window = WindowOptions.builder() + .minPeriods(1) + .window(two, one) + .build()) { + + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows( + RollingAggregation.nth(0, NullPolicy.INCLUDE).onColumn(3).overWindow(window), + RollingAggregation.nth(-1, NullPolicy.INCLUDE).onColumn(3).overWindow(window), + RollingAggregation.nth(1, NullPolicy.INCLUDE).onColumn(3).overWindow(window), + RollingAggregation.nth(0, NullPolicy.EXCLUDE).onColumn(3).overWindow(window), + RollingAggregation.nth(-1, NullPolicy.EXCLUDE).onColumn(3).overWindow(window), + RollingAggregation.nth(1, NullPolicy.EXCLUDE).onColumn(3).overWindow(window)); + ColumnVector expect_first = ColumnVector.fromBoxedInts(7, 7, 5, 3, 7, 7, 9, X, X, X, 0, 4); + ColumnVector expect_last = ColumnVector.fromBoxedInts(5, 3, 7, 7, 9, X, 4, 4, 0, 4, X, X); + ColumnVector expect_1th = ColumnVector.fromBoxedInts(5, 5, 3, 7, 9, 9, X, 4, 0, 0, 4, X); + ColumnVector expect_first_skip_null = + ColumnVector.fromBoxedInts(7, 7, 5, 3, 7, 7, 9, 4, 0, 0, 0, 4); + ColumnVector expect_last_skip_null = + ColumnVector.fromBoxedInts(5, 3, 7, 7, 9, 9, 4, 4, 0, 4, 4, 4); + ColumnVector expect_1th_skip_null = + ColumnVector.fromBoxedInts(5, 5, 3, 7, 9, 9, 4, X, X, 4, 4, X)) { + assertColumnsAreEqual(expect_first, windowAggResults.getColumn(0)); + assertColumnsAreEqual(expect_last, windowAggResults.getColumn(1)); + assertColumnsAreEqual(expect_1th, windowAggResults.getColumn(2)); + assertColumnsAreEqual(expect_first_skip_null, windowAggResults.getColumn(3)); + assertColumnsAreEqual(expect_last_skip_null, windowAggResults.getColumn(4)); + assertColumnsAreEqual(expect_1th_skip_null, windowAggResults.getColumn(5)); + } + } + } + } + } + @Test void testWindowingOnMultipleDifferentColumns() { try (Table unsorted = new Table.TestBuilder()