Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions java/src/main/java/ai/rapids/cudf/RollingAggregation.java
Original file line number Diff line number Diff line change
Expand Up @@ -217,4 +217,14 @@ public static RollingAggregation collectSet() {
public static RollingAggregation collectSet(NullPolicy nullPolicy, NullEquality nullEquality, NaNEquality nanEquality) {
return new RollingAggregation(Aggregation.collectSet(nullPolicy, nullEquality, nanEquality));
}

/**
* Select the nth element from a specified window.
*
* @param n Indicates the index of the element to be selected from the window
* @param nullPolicy Indicates whether null elements are to be skipped, or not
*/
public static RollingAggregation nth(int n, NullPolicy nullPolicy) {
return new RollingAggregation(Aggregation.nth(n, nullPolicy));
}
}
50 changes: 50 additions & 0 deletions java/src/test/java/ai/rapids/cudf/TableTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -5315,6 +5315,56 @@ void testWindowingMean() {
}
}

@Test
void testWindowingNthElement() {
final Integer X = null;
try (Table unsorted = new Table.TestBuilder()
.column( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // 0: GBY Key
.column( 3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1) // 1: GBY Key
.column( 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0) // 2: OBY Key
.column( X, 4, 0, X, 4, X, 9, 7, 7, 3, 5, 7) // 3: Agg Column
.build()) {
try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(2));
ColumnVector expectedSortedAggCol = ColumnVector.fromBoxedInts(7, 5, 3, 7, 7, 9, X, 4, X, 0, 4, X)) {
ColumnVector sortedAggColumn = sorted.getColumn(3);
assertColumnsAreEqual(expectedSortedAggCol, sortedAggColumn);

try (Scalar one = Scalar.fromInt(1);
Scalar two = Scalar.fromInt(2);
WindowOptions window = WindowOptions.builder()
.minPeriods(1)
.window(two, one)
.build()) {

try (Table windowAggResults = sorted.groupBy(0, 1)
.aggregateWindows(
RollingAggregation.nth(0, NullPolicy.INCLUDE).onColumn(3).overWindow(window),
RollingAggregation.nth(-1, NullPolicy.INCLUDE).onColumn(3).overWindow(window),
RollingAggregation.nth(1, NullPolicy.INCLUDE).onColumn(3).overWindow(window),
RollingAggregation.nth(0, NullPolicy.EXCLUDE).onColumn(3).overWindow(window),
RollingAggregation.nth(-1, NullPolicy.EXCLUDE).onColumn(3).overWindow(window),
RollingAggregation.nth(1, NullPolicy.EXCLUDE).onColumn(3).overWindow(window));
ColumnVector expect_first = ColumnVector.fromBoxedInts(7, 7, 5, 3, 7, 7, 9, X, X, X, 0, 4);
ColumnVector expect_last = ColumnVector.fromBoxedInts(5, 3, 7, 7, 9, X, 4, 4, 0, 4, X, X);
ColumnVector expect_1th = ColumnVector.fromBoxedInts(5, 5, 3, 7, 9, 9, X, 4, 0, 0, 4, X);
ColumnVector expect_first_skip_null =
ColumnVector.fromBoxedInts(7, 7, 5, 3, 7, 7, 9, 4, 0, 0, 0, 4);
ColumnVector expect_last_skip_null =
ColumnVector.fromBoxedInts(5, 3, 7, 7, 9, 9, 4, 4, 0, 4, 4, 4);
ColumnVector expect_1th_skip_null =
ColumnVector.fromBoxedInts(5, 5, 3, 7, 9, 9, 4, X, X, 4, 4, X)) {
assertColumnsAreEqual(expect_first, windowAggResults.getColumn(0));
assertColumnsAreEqual(expect_last, windowAggResults.getColumn(1));
assertColumnsAreEqual(expect_1th, windowAggResults.getColumn(2));
assertColumnsAreEqual(expect_first_skip_null, windowAggResults.getColumn(3));
assertColumnsAreEqual(expect_last_skip_null, windowAggResults.getColumn(4));
assertColumnsAreEqual(expect_1th_skip_null, windowAggResults.getColumn(5));
}
}
}
}
}

@Test
void testWindowingOnMultipleDifferentColumns() {
try (Table unsorted = new Table.TestBuilder()
Expand Down