diff --git a/testing/trino-tests/src/test/java/io/trino/tests/TestAggregations.java b/testing/trino-tests/src/test/java/io/trino/tests/TestAggregations.java index 1355594176a9..8e753ff67c5e 100644 --- a/testing/trino-tests/src/test/java/io/trino/tests/TestAggregations.java +++ b/testing/trino-tests/src/test/java/io/trino/tests/TestAggregations.java @@ -90,18 +90,6 @@ public void testPreAggregate() "VALUES (22, 2, 11, 64)", plan -> assertAggregationNodeCount(plan, 4)); - assertQuery( - memorySession, - "SELECT " + - "sum(CASE WHEN sequence = 0 THEN value END), " + - "min(CASE WHEN sequence = 1 THEN value ELSE null END), " + - "max(CASE WHEN sequence = 0 THEN value END), " + - "sum(CASE WHEN sequence = 1 THEN value * 2 ELSE 0 END) " + - "FROM test_table " + - "WHERE sequence = 42", - "VALUES (null, null, null, null)", - plan -> assertAggregationNodeCount(plan, 4)); - assertQuery( memorySession, "SELECT " + @@ -155,6 +143,22 @@ public void testPreAggregate() plan -> assertAggregationNodeCount(plan, 4)); } + @Test + public void testPreAggregateWithFilter() + { + assertQuery( + memorySession, + "SELECT " + + "sum(CASE WHEN sequence = 0 THEN value END), " + + "min(CASE WHEN sequence = 1 THEN value ELSE null END), " + + "max(CASE WHEN sequence = 0 THEN value END), " + + "sum(CASE WHEN sequence = 1 THEN value * 2 ELSE 0 END) " + + "FROM test_table " + + "WHERE sequence = 42", + "VALUES (null, null, null, null)", + plan -> assertAggregationNodeCount(plan, 4)); + } + private void assertAggregationNodeCount(Plan plan, int count) { assertThat(countOfMatchingNodes(plan, AggregationNode.class::isInstance)).isEqualTo(count);