diff --git a/core/trino-main/pom.xml b/core/trino-main/pom.xml index 1363a9fa9a7c..73b9cd58314f 100644 --- a/core/trino-main/pom.xml +++ b/core/trino-main/pom.xml @@ -13,15 +13,6 @@ ${project.parent.basedir} - - - instances @@ -390,12 +381,6 @@ provided - - org.testng - testng - provided - - com.squareup.okhttp3 okhttp @@ -533,25 +518,35 @@ - - - - org.apache.maven.plugins - maven-surefire-plugin - - - - org.apache.maven.surefire - surefire-junit-platform - ${dep.plugin.surefire.version} - - - org.apache.maven.surefire - surefire-testng - ${dep.plugin.surefire.version} - - - - - + + + benchmarks + + + + org.codehaus.mojo + exec-maven-plugin + + ${java.home}/bin/java + + -DoutputDirectory=benchmark_outputs + -classpath + + io.trino.benchmark.BenchmarkSuite + + test + + + + benchmarks + + exec + + + + + + + + diff --git a/core/trino-main/src/test/java/io/trino/operator/TestGroupedTopNRankAccumulator.java b/core/trino-main/src/test/java/io/trino/operator/TestGroupedTopNRankAccumulator.java index 7cf8f1f114f5..ca21f4cca5e5 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestGroupedTopNRankAccumulator.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestGroupedTopNRankAccumulator.java @@ -15,14 +15,12 @@ import io.trino.array.LongBigArray; import it.unimi.dsi.fastutil.longs.LongArrayList; -import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import static com.google.common.collect.Lists.cartesianProduct; import static java.lang.Math.min; import static org.assertj.core.api.Assertions.assertThat; @@ -43,27 +41,20 @@ public long hashCode(long rowId) } }; - @DataProvider - public static Object[][] parameters() - { - List topNs = Arrays.asList(1, 2, 3); - List valueCounts = Arrays.asList(0, 1, 2, 4, 8); - List groupCounts = Arrays.asList(1, 2, 3); - List drainWithRankings = Arrays.asList(true, false); - return 
to2DArray(cartesianProduct(topNs, valueCounts, groupCounts, drainWithRankings)); - } - - private static Object[][] to2DArray(List> nestedList) + @Test + public void testSinglePeerGroupInsert() { - Object[][] array = new Object[nestedList.size()][]; - for (int i = 0; i < nestedList.size(); i++) { - array[i] = nestedList.get(i).toArray(); + for (int topN : Arrays.asList(1, 2, 3)) { + for (int valueCount : Arrays.asList(0, 1, 2, 4, 8)) { + for (int groupCount : Arrays.asList(1, 2, 3)) { + testSinglePeerGroupInsert(topN, valueCount, groupCount, true); + testSinglePeerGroupInsert(topN, valueCount, groupCount, false); + } + } } - return array; } - @Test(dataProvider = "parameters") - public void testSinglePeerGroupInsert(int topN, long valueCount, long groupCount, boolean drainWithRanking) + private void testSinglePeerGroupInsert(int topN, long valueCount, long groupCount, boolean drainWithRanking) { List evicted = new LongArrayList(); GroupedTopNRankAccumulator accumulator = new GroupedTopNRankAccumulator(STRATEGY, topN, evicted::add); @@ -103,8 +94,20 @@ public void testSinglePeerGroupInsert(int topN, long valueCount, long groupCount } } - @Test(dataProvider = "parameters") - public void testIncreasingAllUniqueValues(int topN, long valueCount, long groupCount, boolean drainWithRanking) + @Test + public void testIncreasingAllUniqueValues() + { + for (int topN : Arrays.asList(1, 2, 3)) { + for (int valueCount : Arrays.asList(0, 1, 2, 4, 8)) { + for (int groupCount : Arrays.asList(1, 2, 3)) { + testIncreasingAllUniqueValues(topN, valueCount, groupCount, true); + testIncreasingAllUniqueValues(topN, valueCount, groupCount, false); + } + } + } + } + + private void testIncreasingAllUniqueValues(int topN, long valueCount, long groupCount, boolean drainWithRanking) { List evicted = new LongArrayList(); GroupedTopNRankAccumulator accumulator = new GroupedTopNRankAccumulator(STRATEGY, topN, evicted::add); @@ -144,8 +147,20 @@ public void testIncreasingAllUniqueValues(int topN, 
long valueCount, long groupC } } - @Test(dataProvider = "parameters") - public void testDecreasingAllUniqueValues(int topN, long valueCount, long groupCount, boolean drainWithRanking) + @Test + public void testDecreasingAllUniqueValues() + { + for (int topN : Arrays.asList(1, 2, 3)) { + for (int valueCount : Arrays.asList(0, 1, 2, 4, 8)) { + for (int groupCount : Arrays.asList(1, 2, 3)) { + testDecreasingAllUniqueValues(topN, valueCount, groupCount, true); + testDecreasingAllUniqueValues(topN, valueCount, groupCount, false); + } + } + } + } + + private void testDecreasingAllUniqueValues(int topN, long valueCount, long groupCount, boolean drainWithRanking) { List evicted = new LongArrayList(); GroupedTopNRankAccumulator accumulator = new GroupedTopNRankAccumulator(STRATEGY, topN, evicted::add); diff --git a/core/trino-main/src/test/java/io/trino/operator/TestGroupedTopNRankBuilder.java b/core/trino-main/src/test/java/io/trino/operator/TestGroupedTopNRankBuilder.java index 5b5a67a8b7d5..4f6c6a0a43a0 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestGroupedTopNRankBuilder.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestGroupedTopNRankBuilder.java @@ -19,8 +19,7 @@ import io.trino.spi.type.TypeOperators; import io.trino.sql.gen.JoinCompiler; import io.trino.type.BlockTypeOperators; -import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import java.util.List; import java.util.concurrent.atomic.AtomicBoolean; @@ -39,12 +38,6 @@ public class TestGroupedTopNRankBuilder { - @DataProvider - public static Object[][] produceRanking() - { - return new Object[][] {{true}, {false}}; - } - @Test public void testEmptyInput() { @@ -74,8 +67,14 @@ public long hashCode(Page page, int position) assertThat(groupedTopNBuilder.buildResult().hasNext()).isFalse(); } - @Test(dataProvider = "produceRanking") - public void testSingleGroupTopN(boolean produceRanking) + @Test + public void 
testSingleGroupTopN() + { + testSingleGroupTopN(true); + testSingleGroupTopN(false); + } + + private void testSingleGroupTopN(boolean produceRanking) { TypeOperators typeOperators = new TypeOperators(); BlockTypeOperators blockTypeOperators = new BlockTypeOperators(typeOperators); @@ -133,8 +132,14 @@ public void testSingleGroupTopN(boolean produceRanking) assertPageEquals(outputTypes, getOnlyElement(output), expected); } - @Test(dataProvider = "produceRanking") - public void testMultiGroupTopN(boolean produceRanking) + @Test + public void testMultiGroupTopN() + { + testMultiGroupTopN(true); + testMultiGroupTopN(false); + } + + private void testMultiGroupTopN(boolean produceRanking) { TypeOperators typeOperators = new TypeOperators(); BlockTypeOperators blockTypeOperators = new BlockTypeOperators(typeOperators); diff --git a/core/trino-main/src/test/java/io/trino/operator/TestGroupedTopNRowNumberBuilder.java b/core/trino-main/src/test/java/io/trino/operator/TestGroupedTopNRowNumberBuilder.java index 06d1acce7219..a150a4e47dcf 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestGroupedTopNRowNumberBuilder.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestGroupedTopNRowNumberBuilder.java @@ -18,8 +18,7 @@ import io.trino.spi.type.Type; import io.trino.spi.type.TypeOperators; import io.trino.sql.gen.JoinCompiler; -import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import java.util.List; import java.util.concurrent.atomic.AtomicBoolean; @@ -36,19 +35,6 @@ public class TestGroupedTopNRowNumberBuilder { private static final TypeOperators TYPE_OPERATORS_CACHE = new TypeOperators(); - @DataProvider - public static Object[][] produceRowNumbers() - { - return new Object[][] {{true}, {false}}; - } - - @DataProvider - public static Object[][] pageRowCounts() - { - // make either page or row count > 1024 to expand the big arrays - return new Object[][] {{10000, 20}, {20, 10000}}; - } - 
@Test public void testEmptyInput() { @@ -64,8 +50,14 @@ public void testEmptyInput() assertThat(groupedTopNBuilder.buildResult().hasNext()).isFalse(); } - @Test(dataProvider = "produceRowNumbers") - public void testMultiGroupTopN(boolean produceRowNumbers) + @Test + public void testMultiGroupTopN() + { + testMultiGroupTopN(true); + testMultiGroupTopN(false); + } + + private void testMultiGroupTopN(boolean produceRowNumbers) { List types = ImmutableList.of(BIGINT, DOUBLE); List input = rowPagesBuilder(types) @@ -131,8 +123,14 @@ public void testMultiGroupTopN(boolean produceRowNumbers) } } - @Test(dataProvider = "produceRowNumbers") - public void testSingleGroupTopN(boolean produceRowNumbers) + @Test + public void testSingleGroupTopN() + { + testSingleGroupTopN(true); + testSingleGroupTopN(false); + } + + private void testSingleGroupTopN(boolean produceRowNumbers) { List types = ImmutableList.of(BIGINT, DOUBLE); List input = rowPagesBuilder(types) diff --git a/core/trino-main/src/test/java/io/trino/operator/TestHashAggregationOperator.java b/core/trino-main/src/test/java/io/trino/operator/TestHashAggregationOperator.java index babf7dbdf044..cdeb3e748778 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestHashAggregationOperator.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestHashAggregationOperator.java @@ -40,10 +40,10 @@ import io.trino.sql.planner.plan.PlanNodeId; import io.trino.testing.MaterializedResult; import io.trino.testing.TestingTaskContext; -import org.testng.annotations.AfterMethod; -import org.testng.annotations.BeforeMethod; -import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.parallel.Execution; import java.io.IOException; import java.util.ArrayList; @@ -94,8 +94,11 @@ import static java.util.concurrent.Executors.newScheduledThreadPool; import 
static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; +import static org.junit.jupiter.api.parallel.ExecutionMode.CONCURRENT; -@Test(singleThreaded = true) +@TestInstance(PER_CLASS) +@Execution(CONCURRENT) public class TestHashAggregationOperator { private static final TestingFunctionResolution FUNCTION_RESOLUTION = new TestingFunctionResolution(); @@ -107,58 +110,36 @@ public class TestHashAggregationOperator private static final int MAX_BLOCK_SIZE_IN_BYTES = 64 * 1024; - private ExecutorService executor; - private ScheduledExecutorService scheduledExecutor; + private final ExecutorService executor = newCachedThreadPool(daemonThreadsNamed(getClass().getSimpleName() + "-%s")); + private final ScheduledExecutorService scheduledExecutor = newScheduledThreadPool(2, daemonThreadsNamed(getClass().getSimpleName() + "-scheduledExecutor-%s")); private final TypeOperators typeOperators = new TypeOperators(); private final JoinCompiler joinCompiler = new JoinCompiler(typeOperators); - private DummySpillerFactory spillerFactory; - @BeforeMethod - public void setUp() - { - executor = newCachedThreadPool(daemonThreadsNamed(getClass().getSimpleName() + "-%s")); - scheduledExecutor = newScheduledThreadPool(2, daemonThreadsNamed(getClass().getSimpleName() + "-scheduledExecutor-%s")); - spillerFactory = new DummySpillerFactory(); - } - - @AfterMethod(alwaysRun = true) + @AfterAll public void tearDown() { - spillerFactory = null; executor.shutdownNow(); scheduledExecutor.shutdownNow(); } - @DataProvider(name = "hashEnabled") - public static Object[][] hashEnabled() - { - return new Object[][] {{true}, {false}}; - } - - @DataProvider(name = "hashEnabledAndMemoryLimitForMergeValues") - public static Object[][] hashEnabledAndMemoryLimitForMergeValuesProvider() + @Test + public void testHashAggregation() { - return new Object[][] { - {true, true, true, 8, 
Integer.MAX_VALUE}, - {true, true, false, 8, Integer.MAX_VALUE}, - {false, false, false, 0, 0}, - {false, true, true, 0, 0}, - {false, true, false, 0, 0}, - {false, true, true, 8, 0}, - {false, true, false, 8, 0}, - {false, true, true, 8, Integer.MAX_VALUE}, - {false, true, false, 8, Integer.MAX_VALUE}}; + testHashAggregation(true, true, true, 8, Integer.MAX_VALUE); + testHashAggregation(true, true, false, 8, Integer.MAX_VALUE); + testHashAggregation(false, false, false, 0, 0); + testHashAggregation(false, true, true, 0, 0); + testHashAggregation(false, true, false, 0, 0); + testHashAggregation(false, true, true, 8, 0); + testHashAggregation(false, true, false, 8, 0); + testHashAggregation(false, true, true, 8, Integer.MAX_VALUE); + testHashAggregation(false, true, false, 8, Integer.MAX_VALUE); } - @DataProvider - public Object[][] dataType() + private void testHashAggregation(boolean hashEnabled, boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimitForMerge, long memoryLimitForMergeWithMemory) { - return new Object[][] {{VARCHAR}, {BIGINT}}; - } + DummySpillerFactory spillerFactory = new DummySpillerFactory(); - @Test(dataProvider = "hashEnabledAndMemoryLimitForMergeValues") - public void testHashAggregation(boolean hashEnabled, boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimitForMerge, long memoryLimitForMergeWithMemory) - { // make operator produce multiple pages during finish phase int numberOfRows = 40_000; TestingAggregationFunction countVarcharColumn = FUNCTION_RESOLUTION.getAggregateFunction("count", fromTypes(VARCHAR)); @@ -215,9 +196,24 @@ public void testHashAggregation(boolean hashEnabled, boolean spillEnabled, boole .isTrue(); } - @Test(dataProvider = "hashEnabledAndMemoryLimitForMergeValues") - public void testHashAggregationWithGlobals(boolean hashEnabled, boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimitForMerge, long memoryLimitForMergeWithMemory) + @Test + public void 
testHashAggregationWithGlobals() + { + testHashAggregationWithGlobals(true, true, true, 8, Integer.MAX_VALUE); + testHashAggregationWithGlobals(true, true, false, 8, Integer.MAX_VALUE); + testHashAggregationWithGlobals(false, false, false, 0, 0); + testHashAggregationWithGlobals(false, true, true, 0, 0); + testHashAggregationWithGlobals(false, true, false, 0, 0); + testHashAggregationWithGlobals(false, true, true, 8, 0); + testHashAggregationWithGlobals(false, true, false, 8, 0); + testHashAggregationWithGlobals(false, true, true, 8, Integer.MAX_VALUE); + testHashAggregationWithGlobals(false, true, false, 8, Integer.MAX_VALUE); + } + + private void testHashAggregationWithGlobals(boolean hashEnabled, boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimitForMerge, long memoryLimitForMergeWithMemory) { + DummySpillerFactory spillerFactory = new DummySpillerFactory(); + TestingAggregationFunction countVarcharColumn = FUNCTION_RESOLUTION.getAggregateFunction("count", fromTypes(VARCHAR)); TestingAggregationFunction countBooleanColumn = FUNCTION_RESOLUTION.getAggregateFunction("count", fromTypes(BOOLEAN)); TestingAggregationFunction maxVarcharColumn = FUNCTION_RESOLUTION.getAggregateFunction("max", fromTypes(VARCHAR)); @@ -263,9 +259,24 @@ public void testHashAggregationWithGlobals(boolean hashEnabled, boolean spillEna assertOperatorEqualsIgnoreOrder(operatorFactory, driverContext, input, expected, hashEnabled, Optional.of(groupByChannels.size()), revokeMemoryWhenAddingPages); } - @Test(dataProvider = "hashEnabledAndMemoryLimitForMergeValues") - public void testHashAggregationMemoryReservation(boolean hashEnabled, boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimitForMerge, long memoryLimitForMergeWithMemory) + @Test + public void testHashAggregationMemoryReservation() + { + testHashAggregationMemoryReservation(true, true, true, 8, Integer.MAX_VALUE); + testHashAggregationMemoryReservation(true, true, false, 8, 
Integer.MAX_VALUE); + testHashAggregationMemoryReservation(false, false, false, 0, 0); + testHashAggregationMemoryReservation(false, true, true, 0, 0); + testHashAggregationMemoryReservation(false, true, false, 0, 0); + testHashAggregationMemoryReservation(false, true, true, 8, 0); + testHashAggregationMemoryReservation(false, true, false, 8, 0); + testHashAggregationMemoryReservation(false, true, true, 8, Integer.MAX_VALUE); + testHashAggregationMemoryReservation(false, true, false, 8, Integer.MAX_VALUE); + } + + private void testHashAggregationMemoryReservation(boolean hashEnabled, boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimitForMerge, long memoryLimitForMergeWithMemory) { + DummySpillerFactory spillerFactory = new DummySpillerFactory(); + TestingAggregationFunction arrayAggColumn = FUNCTION_RESOLUTION.getAggregateFunction("array_agg", fromTypes(BIGINT)); List hashChannels = Ints.asList(1); @@ -308,8 +319,19 @@ public void testHashAggregationMemoryReservation(boolean hashEnabled, boolean sp assertThat(getOnlyElement(operator.getOperatorContext().getNestedOperatorStats()).getRevocableMemoryReservation().toBytes()).isEqualTo(0); } - @Test(dataProvider = "hashEnabled", expectedExceptions = ExceededMemoryLimitException.class, expectedExceptionsMessageRegExp = "Query exceeded per-node memory limit of 10B.*") - public void testMemoryLimit(boolean hashEnabled) + @Test + public void testMemoryLimit() + { + assertThatThrownBy(() -> testMemoryLimit(true)) + .isInstanceOf(ExceededMemoryLimitException.class) + .hasMessageMatching("Query exceeded per-node memory limit of 10B.*"); + + assertThatThrownBy(() -> testMemoryLimit(false)) + .isInstanceOf(ExceededMemoryLimitException.class) + .hasMessageMatching("Query exceeded per-node memory limit of 10B.*"); + } + + private void testMemoryLimit(boolean hashEnabled) { TestingAggregationFunction maxVarcharColumn = FUNCTION_RESOLUTION.getAggregateFunction("max", fromTypes(VARCHAR)); @@ -347,9 +369,24 @@ 
public void testMemoryLimit(boolean hashEnabled) toPages(operatorFactory, driverContext, input); } - @Test(dataProvider = "hashEnabledAndMemoryLimitForMergeValues") - public void testHashBuilderResize(boolean hashEnabled, boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimitForMerge, long memoryLimitForMergeWithMemory) + @Test + public void testHashBuilderResize() + { + testHashBuilderResize(true, true, true, 8, Integer.MAX_VALUE); + testHashBuilderResize(true, true, false, 8, Integer.MAX_VALUE); + testHashBuilderResize(false, false, false, 0, 0); + testHashBuilderResize(false, true, true, 0, 0); + testHashBuilderResize(false, true, false, 0, 0); + testHashBuilderResize(false, true, true, 8, 0); + testHashBuilderResize(false, true, false, 8, 0); + testHashBuilderResize(false, true, true, 8, Integer.MAX_VALUE); + testHashBuilderResize(false, true, false, 8, Integer.MAX_VALUE); + } + + private void testHashBuilderResize(boolean hashEnabled, boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimitForMerge, long memoryLimitForMergeWithMemory) { + DummySpillerFactory spillerFactory = new DummySpillerFactory(); + BlockBuilder builder = VARCHAR.createBlockBuilder(null, 1, MAX_BLOCK_SIZE_IN_BYTES); VARCHAR.writeSlice(builder, Slices.allocate(200_000)); // this must be larger than MAX_BLOCK_SIZE_IN_BYTES, 64K builder.build(); @@ -388,7 +425,13 @@ public void testHashBuilderResize(boolean hashEnabled, boolean spillEnabled, boo toPages(operatorFactory, driverContext, input, revokeMemoryWhenAddingPages); } - @Test(dataProvider = "dataType") + @Test + public void testMemoryReservationYield() + { + testMemoryReservationYield(VARCHAR); + testMemoryReservationYield(BIGINT); + } + public void testMemoryReservationYield(Type type) { List input = createPagesWithDistinctHashKeys(type, 6_000, 600); @@ -426,8 +469,19 @@ public void testMemoryReservationYield(Type type) assertThat(count).isEqualTo(6_000 * 600); } - @Test(dataProvider = 
"hashEnabled", expectedExceptions = ExceededMemoryLimitException.class, expectedExceptionsMessageRegExp = "Query exceeded per-node memory limit of 3MB.*") - public void testHashBuilderResizeLimit(boolean hashEnabled) + @Test + public void testHashBuilderResizeLimit() + { + assertThatThrownBy(() -> testHashBuilderResizeLimit(true)) + .isInstanceOf(ExceededMemoryLimitException.class) + .hasMessageMatching("Query exceeded per-node memory limit of 3MB.*"); + + assertThatThrownBy(() -> testHashBuilderResizeLimit(false)) + .isInstanceOf(ExceededMemoryLimitException.class) + .hasMessageMatching("Query exceeded per-node memory limit of 3MB.*"); + } + + private void testHashBuilderResizeLimit(boolean hashEnabled) { BlockBuilder builder = VARCHAR.createBlockBuilder(null, 1, MAX_BLOCK_SIZE_IN_BYTES); VARCHAR.writeSlice(builder, Slices.allocate(5_000_000)); // this must be larger than MAX_BLOCK_SIZE_IN_BYTES, 64K @@ -464,8 +518,14 @@ public void testHashBuilderResizeLimit(boolean hashEnabled) toPages(operatorFactory, driverContext, input); } - @Test(dataProvider = "hashEnabled") - public void testMultiSliceAggregationOutput(boolean hashEnabled) + @Test + public void testMultiSliceAggregationOutput() + { + testMultiSliceAggregationOutput(true); + testMultiSliceAggregationOutput(false); + } + + private void testMultiSliceAggregationOutput(boolean hashEnabled) { // estimate the number of entries required to create 1.5 pages of results // See InMemoryHashAggregationBuilder.buildTypes() @@ -499,8 +559,15 @@ public void testMultiSliceAggregationOutput(boolean hashEnabled) assertThat(toPages(operatorFactory, createDriverContext(), input).size()).isEqualTo(2); } - @Test(dataProvider = "hashEnabled") - public void testMultiplePartialFlushes(boolean hashEnabled) + @Test + public void testMultiplePartialFlushes() + throws Exception + { + testMultiplePartialFlushes(true); + testMultiplePartialFlushes(false); + } + + private void testMultiplePartialFlushes(boolean hashEnabled) throws 
Exception { List hashChannels = Ints.asList(0); @@ -584,6 +651,8 @@ public void testMultiplePartialFlushes(boolean hashEnabled) @Test public void testMergeWithMemorySpill() { + DummySpillerFactory spillerFactory = new DummySpillerFactory(); + RowPagesBuilder rowPagesBuilder = rowPagesBuilder(BIGINT); int smallPagesSpillThresholdSize = 150000; diff --git a/core/trino-main/src/test/java/io/trino/operator/TestHashSemiJoinOperator.java b/core/trino-main/src/test/java/io/trino/operator/TestHashSemiJoinOperator.java index c43a6c412aa9..3171e6c59a7b 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestHashSemiJoinOperator.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestHashSemiJoinOperator.java @@ -25,10 +25,11 @@ import io.trino.sql.gen.JoinCompiler; import io.trino.sql.planner.plan.PlanNodeId; import io.trino.testing.MaterializedResult; -import org.testng.annotations.AfterMethod; -import org.testng.annotations.BeforeMethod; -import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.parallel.Execution; import java.util.List; import java.util.Optional; @@ -46,8 +47,12 @@ import static io.trino.testing.TestingTaskContext.createTaskContext; import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.Executors.newScheduledThreadPool; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_METHOD; +import static org.junit.jupiter.api.parallel.ExecutionMode.SAME_THREAD; -@Test(singleThreaded = true) +@TestInstance(PER_METHOD) +@Execution(SAME_THREAD) public class TestHashSemiJoinOperator { private ExecutorService executor; @@ -55,7 +60,7 @@ public class TestHashSemiJoinOperator private TaskContext taskContext; private 
TypeOperators typeOperators; - @BeforeMethod + @BeforeEach public void setUp() { executor = newCachedThreadPool(daemonThreadsNamed(getClass().getSimpleName() + "-%s")); @@ -64,27 +69,21 @@ public void setUp() typeOperators = new TypeOperators(); } - @AfterMethod(alwaysRun = true) + @AfterEach public void tearDown() { executor.shutdownNow(); scheduledExecutor.shutdownNow(); } - @DataProvider(name = "hashEnabledValues") - public static Object[][] hashEnabledValuesProvider() + @Test + public void testSemiJoin() { - return new Object[][] {{true}, {false}}; + testSemiJoin(true); + testSemiJoin(false); } - @DataProvider - public Object[][] dataType() - { - return new Object[][] {{VARCHAR}, {BIGINT}}; - } - - @Test(dataProvider = "hashEnabledValues") - public void testSemiJoin(boolean hashEnabled) + private void testSemiJoin(boolean hashEnabled) { DriverContext driverContext = taskContext.addPipelineContext(0, true, true, false).addDriverContext(); @@ -148,8 +147,14 @@ public void testSemiJoin(boolean hashEnabled) OperatorAssertion.assertOperatorEquals(joinOperatorFactory, driverContext, probeInput, expected, hashEnabled, ImmutableList.of(probeTypes.size())); } - @Test(dataProvider = "hashEnabledValues") - public void testSemiJoinOnVarcharType(boolean hashEnabled) + @Test + public void testSemiJoinOnVarcharType() + { + testSemiJoinOnVarcharType(true); + testSemiJoinOnVarcharType(false); + } + + private void testSemiJoinOnVarcharType(boolean hashEnabled) { DriverContext driverContext = taskContext.addPipelineContext(0, true, true, false).addDriverContext(); @@ -213,8 +218,14 @@ public void testSemiJoinOnVarcharType(boolean hashEnabled) OperatorAssertion.assertOperatorEquals(joinOperatorFactory, driverContext, probeInput, expected, hashEnabled, ImmutableList.of(probeTypes.size())); } - @Test(dataProvider = "hashEnabledValues") - public void testBuildSideNulls(boolean hashEnabled) + @Test + public void testBuildSideNulls() + { + testBuildSideNulls(true); + 
testBuildSideNulls(false); + } + + private void testBuildSideNulls(boolean hashEnabled) { DriverContext driverContext = taskContext.addPipelineContext(0, true, true, false).addDriverContext(); @@ -272,8 +283,14 @@ public void testBuildSideNulls(boolean hashEnabled) OperatorAssertion.assertOperatorEquals(joinOperatorFactory, driverContext, probeInput, expected, hashEnabled, ImmutableList.of(probeTypes.size())); } - @Test(dataProvider = "hashEnabledValues") - public void testProbeSideNulls(boolean hashEnabled) + @Test + public void testProbeSideNulls() + { + testProbeSideNulls(true); + testProbeSideNulls(false); + } + + private void testProbeSideNulls(boolean hashEnabled) { DriverContext driverContext = taskContext.addPipelineContext(0, true, true, false).addDriverContext(); @@ -331,8 +348,14 @@ public void testProbeSideNulls(boolean hashEnabled) OperatorAssertion.assertOperatorEquals(joinOperatorFactory, driverContext, probeInput, expected, hashEnabled, ImmutableList.of(probeTypes.size())); } - @Test(dataProvider = "hashEnabledValues") - public void testProbeAndBuildNulls(boolean hashEnabled) + @Test + public void testProbeAndBuildNulls() + { + testProbeAndBuildNulls(true); + testProbeAndBuildNulls(false); + } + + private void testProbeAndBuildNulls(boolean hashEnabled) { DriverContext driverContext = taskContext.addPipelineContext(0, true, true, false).addDriverContext(); @@ -391,8 +414,19 @@ public void testProbeAndBuildNulls(boolean hashEnabled) OperatorAssertion.assertOperatorEquals(joinOperatorFactory, driverContext, probeInput, expected, hashEnabled, ImmutableList.of(probeTypes.size())); } - @Test(dataProvider = "hashEnabledValues", expectedExceptions = ExceededMemoryLimitException.class, expectedExceptionsMessageRegExp = "Query exceeded per-node memory limit of.*") - public void testMemoryLimit(boolean hashEnabled) + @Test + public void testMemoryLimit() + { + assertThatThrownBy(() -> testMemoryLimit(true)) + .isInstanceOf(ExceededMemoryLimitException.class) 
+ .hasMessageMatching("Query exceeded per-node memory limit of.*"); + + assertThatThrownBy(() -> testMemoryLimit(false)) + .isInstanceOf(ExceededMemoryLimitException.class) + .hasMessageMatching("Query exceeded per-node memory limit of.*"); + } + + private void testMemoryLimit(boolean hashEnabled) { DriverContext driverContext = createTaskContext(executor, scheduledExecutor, TEST_SESSION, DataSize.ofBytes(100)) .addPipelineContext(0, true, true, false) diff --git a/core/trino-main/src/test/java/io/trino/operator/TestMarkDistinctOperator.java b/core/trino-main/src/test/java/io/trino/operator/TestMarkDistinctOperator.java index 1657f51fb5d7..f79b8b979bce 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestMarkDistinctOperator.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestMarkDistinctOperator.java @@ -25,10 +25,10 @@ import io.trino.sql.gen.JoinCompiler; import io.trino.sql.planner.plan.PlanNodeId; import io.trino.testing.MaterializedResult; -import org.testng.annotations.AfterMethod; -import org.testng.annotations.BeforeMethod; -import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.parallel.Execution; import java.util.List; import java.util.Optional; @@ -51,47 +51,33 @@ import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.Executors.newScheduledThreadPool; import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; +import static org.junit.jupiter.api.parallel.ExecutionMode.CONCURRENT; -@Test(singleThreaded = true) +@TestInstance(PER_CLASS) +@Execution(CONCURRENT) public class TestMarkDistinctOperator { - private ExecutorService executor; - private ScheduledExecutorService scheduledExecutor; - private DriverContext driverContext; + private final 
ExecutorService executor = newCachedThreadPool(daemonThreadsNamed(getClass().getSimpleName() + "-%s")); + private final ScheduledExecutorService scheduledExecutor = newScheduledThreadPool(2, daemonThreadsNamed(getClass().getSimpleName() + "-scheduledExecutor-%s")); private final TypeOperators typeOperators = new TypeOperators(); private final JoinCompiler joinCompiler = new JoinCompiler(typeOperators); - @BeforeMethod - public void setUp() - { - executor = newCachedThreadPool(daemonThreadsNamed(getClass().getSimpleName() + "-%s")); - scheduledExecutor = newScheduledThreadPool(2, daemonThreadsNamed(getClass().getSimpleName() + "-scheduledExecutor-%s")); - driverContext = createTaskContext(executor, scheduledExecutor, TEST_SESSION) - .addPipelineContext(0, true, true, false) - .addDriverContext(); - } - - @AfterMethod(alwaysRun = true) + @AfterAll public void tearDown() { executor.shutdownNow(); scheduledExecutor.shutdownNow(); } - @DataProvider - public Object[][] dataType() + @Test + public void testMarkDistinct() { - return new Object[][] {{VARCHAR}, {BIGINT}}; + testMarkDistinct(true, newDriverContext()); + testMarkDistinct(false, newDriverContext()); } - @DataProvider(name = "hashEnabledValues") - public static Object[][] hashEnabledValuesProvider() - { - return new Object[][] {{true}, {false}}; - } - - @Test(dataProvider = "hashEnabledValues") - public void testMarkDistinct(boolean hashEnabled) + private void testMarkDistinct(boolean hashEnabled, DriverContext driverContext) { RowPagesBuilder rowPagesBuilder = rowPagesBuilder(hashEnabled, Ints.asList(0), BIGINT); List input = rowPagesBuilder @@ -116,8 +102,14 @@ public void testMarkDistinct(boolean hashEnabled) OperatorAssertion.assertOperatorEqualsIgnoreOrder(operatorFactory, driverContext, input, expected.build(), hashEnabled, Optional.of(1)); } - @Test(dataProvider = "hashEnabledValues") - public void testRleDistinctMask(boolean hashEnabled) + @Test + public void testRleDistinctMask() + { + 
testRleDistinctMask(true, newDriverContext()); + testRleDistinctMask(false, newDriverContext()); + } + + private void testRleDistinctMask(boolean hashEnabled, DriverContext driverContext) { RowPagesBuilder rowPagesBuilder = rowPagesBuilder(hashEnabled, Ints.asList(0), BIGINT); List inputs = rowPagesBuilder @@ -180,8 +172,14 @@ public void testRleDistinctMask(boolean hashEnabled) } } - @Test(dataProvider = "dataType") - public void testMemoryReservationYield(Type type) + @Test + public void testMemoryReservationYield() + { + testMemoryReservationYield(BIGINT); + testMemoryReservationYield(VARCHAR); + } + + private void testMemoryReservationYield(Type type) { List input = createPagesWithDistinctHashKeys(type, 6_000, 600); @@ -202,4 +200,11 @@ public void testMemoryReservationYield(Type type) } assertThat(count).isEqualTo(6_000 * 600); } + + private DriverContext newDriverContext() + { + return createTaskContext(executor, scheduledExecutor, TEST_SESSION) + .addPipelineContext(0, true, true, false) + .addDriverContext(); + } } diff --git a/core/trino-main/src/test/java/io/trino/operator/TestOrderByOperator.java b/core/trino-main/src/test/java/io/trino/operator/TestOrderByOperator.java index ed23a1d2860d..2fbe0ac96793 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestOrderByOperator.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestOrderByOperator.java @@ -23,10 +23,10 @@ import io.trino.sql.planner.plan.PlanNodeId; import io.trino.testing.MaterializedResult; import io.trino.testing.TestingTaskContext; -import org.testng.annotations.AfterMethod; -import org.testng.annotations.BeforeMethod; -import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.parallel.Execution; import java.util.List; import java.util.Optional; @@ -53,45 +53,38 @@ import static 
java.util.concurrent.Executors.newScheduledThreadPool; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; +import static org.junit.jupiter.api.parallel.ExecutionMode.CONCURRENT; -@Test(singleThreaded = true) +@TestInstance(PER_CLASS) +@Execution(CONCURRENT) public class TestOrderByOperator { - private ExecutorService executor; - private ScheduledExecutorService scheduledExecutor; - private DummySpillerFactory spillerFactory; + private final ExecutorService executor = newCachedThreadPool(daemonThreadsNamed(getClass().getSimpleName() + "-%s")); + private final ScheduledExecutorService scheduledExecutor = newScheduledThreadPool(2, daemonThreadsNamed(getClass().getSimpleName() + "-scheduledExecutor-%s")); private final TypeOperators typeOperators = new TypeOperators(); - @DataProvider - public static Object[][] spillEnabled() + @AfterAll + public void tearDown() { - return new Object[][] { - {false, false, 0}, - {true, false, 8}, - {true, true, 8}, - {true, false, 0}, - {true, true, 0}}; + executor.shutdownNow(); + scheduledExecutor.shutdownNow(); } - @BeforeMethod - public void setUp() + @Test + public void testMultipleOutputPages() { - executor = newCachedThreadPool(daemonThreadsNamed(getClass().getSimpleName() + "-%s")); - scheduledExecutor = newScheduledThreadPool(2, daemonThreadsNamed(getClass().getSimpleName() + "-scheduledExecutor-%s")); - spillerFactory = new DummySpillerFactory(); + testMultipleOutputPages(false, false, 0); + testMultipleOutputPages(true, false, 8); + testMultipleOutputPages(true, true, 8); + testMultipleOutputPages(true, false, 0); + testMultipleOutputPages(true, true, 0); } - @AfterMethod(alwaysRun = true) - public void tearDown() + private void testMultipleOutputPages(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) { - executor.shutdownNow(); - 
scheduledExecutor.shutdownNow(); - spillerFactory = null; - } + DummySpillerFactory spillerFactory = new DummySpillerFactory(); - @Test(dataProvider = "spillEnabled") - public void testMultipleOutputPages(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) - { // make operator produce multiple pages during finish phase int numberOfRows = 80_000; List input = rowPagesBuilder(BIGINT, DOUBLE) @@ -129,8 +122,17 @@ public void testMultipleOutputPages(boolean spillEnabled, boolean revokeMemoryWh .isTrue(); } - @Test(dataProvider = "spillEnabled") - public void testSingleFieldKey(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) + @Test + public void testSingleFieldKey() + { + testSingleFieldKey(false, false, 0); + testSingleFieldKey(true, false, 8); + testSingleFieldKey(true, true, 8); + testSingleFieldKey(true, false, 0); + testSingleFieldKey(true, true, 0); + } + + private void testSingleFieldKey(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) { List input = rowPagesBuilder(BIGINT, DOUBLE) .row(1L, 0.1) @@ -150,7 +152,7 @@ public void testSingleFieldKey(boolean spillEnabled, boolean revokeMemoryWhenAdd ImmutableList.of(ASC_NULLS_LAST), new PagesIndex.TestingFactory(false), spillEnabled, - Optional.of(spillerFactory), + Optional.of(new DummySpillerFactory()), new OrderingCompiler(typeOperators)); DriverContext driverContext = createDriverContext(memoryLimit); @@ -164,8 +166,17 @@ public void testSingleFieldKey(boolean spillEnabled, boolean revokeMemoryWhenAdd assertOperatorEquals(operatorFactory, driverContext, input, expected, revokeMemoryWhenAddingPages); } - @Test(dataProvider = "spillEnabled") - public void testMultiFieldKey(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) + @Test + public void testMultiFieldKey() + { + testMultiFieldKey(false, false, 0); + testMultiFieldKey(true, false, 8); + testMultiFieldKey(true, true, 8); + testMultiFieldKey(true, false, 
0); + testMultiFieldKey(true, true, 0); + } + + private void testMultiFieldKey(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) { List input = rowPagesBuilder(VARCHAR, BIGINT) .row("a", 1L) @@ -185,7 +196,7 @@ public void testMultiFieldKey(boolean spillEnabled, boolean revokeMemoryWhenAddi ImmutableList.of(ASC_NULLS_LAST, DESC_NULLS_LAST), new PagesIndex.TestingFactory(false), spillEnabled, - Optional.of(spillerFactory), + Optional.of(new DummySpillerFactory()), new OrderingCompiler(typeOperators)); DriverContext driverContext = createDriverContext(memoryLimit); @@ -199,8 +210,17 @@ public void testMultiFieldKey(boolean spillEnabled, boolean revokeMemoryWhenAddi assertOperatorEquals(operatorFactory, driverContext, input, expected, revokeMemoryWhenAddingPages); } - @Test(dataProvider = "spillEnabled") - public void testReverseOrder(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) + @Test + public void testReverseOrder() + { + testReverseOrder(false, false, 0); + testReverseOrder(true, false, 8); + testReverseOrder(true, true, 8); + testReverseOrder(true, false, 0); + testReverseOrder(true, true, 0); + } + + private void testReverseOrder(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) { List input = rowPagesBuilder(BIGINT, DOUBLE) .row(1L, 0.1) @@ -220,7 +240,7 @@ public void testReverseOrder(boolean spillEnabled, boolean revokeMemoryWhenAddin ImmutableList.of(DESC_NULLS_LAST), new PagesIndex.TestingFactory(false), spillEnabled, - Optional.of(spillerFactory), + Optional.of(new DummySpillerFactory()), new OrderingCompiler(typeOperators)); DriverContext driverContext = createDriverContext(memoryLimit); @@ -259,7 +279,7 @@ public void testMemoryLimit() ImmutableList.of(ASC_NULLS_LAST), new PagesIndex.TestingFactory(false), false, - Optional.of(spillerFactory), + Optional.of(new DummySpillerFactory()), new OrderingCompiler(typeOperators)); assertThatThrownBy(() -> 
toPages(operatorFactory, driverContext, input)) diff --git a/core/trino-main/src/test/java/io/trino/operator/TestTopNPeerGroupLookup.java b/core/trino-main/src/test/java/io/trino/operator/TestTopNPeerGroupLookup.java index d88a5cb00c7a..a2c6c192026c 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestTopNPeerGroupLookup.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestTopNPeerGroupLookup.java @@ -13,13 +13,10 @@ */ package io.trino.operator; -import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import java.util.Arrays; -import java.util.List; -import static com.google.common.collect.Lists.cartesianProduct; import static org.assertj.core.api.Assertions.assertThat; public class TestTopNPeerGroupLookup @@ -41,28 +38,20 @@ public long hashCode(long rowId) private static final long UNMAPPED_GROUP_ID = Long.MIN_VALUE; private static final long DEFAULT_RETURN_VALUE = -1L; - @DataProvider - public static Object[][] parameters() + @Test + public void testCombinations() { - List expectedSizes = Arrays.asList(0, 1, 2, 3, 1_000); - List fillFactors = Arrays.asList(0.1f, 0.9f, 1f); - List totalGroupIds = Arrays.asList(1L, 10L); - List totalRowIds = Arrays.asList(1L, 1_000L); - - return to2DArray(cartesianProduct(expectedSizes, fillFactors, totalGroupIds, totalRowIds)); - } - - private static Object[][] to2DArray(List> nestedList) - { - Object[][] array = new Object[nestedList.size()][]; - for (int i = 0; i < nestedList.size(); i++) { - array[i] = nestedList.get(i).toArray(); + for (int expectedSize : Arrays.asList(0, 1, 2, 3, 1_000)) { + for (float fillFactor : Arrays.asList(0.1f, 0.9f, 1f)) { + testCombinations(expectedSize, fillFactor, 1L, 1L); + testCombinations(expectedSize, fillFactor, 10L, 1L); + testCombinations(expectedSize, fillFactor, 1L, 1_000L); + testCombinations(expectedSize, fillFactor, 10L, 1_000L); + } } - return array; } - @Test(dataProvider = "parameters") - public void 
testCombinations(int expectedSize, float fillFactor, long totalGroupIds, long totalRowIds) + private void testCombinations(int expectedSize, float fillFactor, long totalGroupIds, long totalRowIds) { TopNPeerGroupLookup lookup = new TopNPeerGroupLookup(expectedSize, fillFactor, HASH_STRATEGY, UNMAPPED_GROUP_ID, DEFAULT_RETURN_VALUE); diff --git a/core/trino-main/src/test/java/io/trino/operator/TestWindowOperator.java b/core/trino-main/src/test/java/io/trino/operator/TestWindowOperator.java index f9c3e5f6f4aa..1c55c5d36569 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestWindowOperator.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestWindowOperator.java @@ -38,10 +38,10 @@ import io.trino.sql.planner.plan.PlanNodeId; import io.trino.testing.MaterializedResult; import io.trino.testing.TestingTaskContext; -import org.testng.annotations.AfterMethod; -import org.testng.annotations.BeforeMethod; -import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.parallel.Execution; import java.util.List; import java.util.Optional; @@ -72,8 +72,12 @@ import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.Executors.newScheduledThreadPool; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; +import static org.junit.jupiter.api.parallel.ExecutionMode.CONCURRENT; -@Test(singleThreaded = true) +@TestInstance(PER_CLASS) +@Execution(CONCURRENT) public class TestWindowOperator { private static final TypeOperators TYPE_OPERATORS_CACHE = new TypeOperators(); @@ -100,40 +104,30 @@ public class TestWindowOperator private static final List LEAD = ImmutableList.of( window(new 
ReflectionWindowFunctionSupplier(3, LeadFunction.class), VARCHAR, UNBOUNDED_FRAME, false, ImmutableList.of(), 1, 3, 4)); - private ExecutorService executor; - private ScheduledExecutorService scheduledExecutor; - private DummySpillerFactory spillerFactory; + private final ExecutorService executor = newCachedThreadPool(daemonThreadsNamed(getClass().getSimpleName() + "-%s")); + private final ScheduledExecutorService scheduledExecutor = newScheduledThreadPool(2, daemonThreadsNamed(getClass().getSimpleName() + "-scheduledExecutor-%s")); - @BeforeMethod - public void setUp() - { - executor = newCachedThreadPool(daemonThreadsNamed(getClass().getSimpleName() + "-%s")); - scheduledExecutor = newScheduledThreadPool(2, daemonThreadsNamed(getClass().getSimpleName() + "-scheduledExecutor-%s")); - spillerFactory = new DummySpillerFactory(); - } - - @AfterMethod(alwaysRun = true) + @AfterAll public void tearDown() { executor.shutdownNow(); scheduledExecutor.shutdownNow(); - spillerFactory = null; } - @DataProvider - public static Object[][] spillEnabled() + @Test + public void testMultipleOutputPages() { - return new Object[][] { - {false, false, 0}, - {true, false, 8}, - {true, true, 8}, - {true, false, 0}, - {true, true, 0}}; + testMultipleOutputPages(false, false, 0); + testMultipleOutputPages(true, false, 8); + testMultipleOutputPages(true, true, 8); + testMultipleOutputPages(true, false, 0); + testMultipleOutputPages(true, true, 0); } - @Test(dataProvider = "spillEnabled") - public void testMultipleOutputPages(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) + private void testMultipleOutputPages(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) { + DummySpillerFactory spillerFactory = new DummySpillerFactory(); + // make operator produce multiple pages during finish phase int numberOfRows = 80_000; List input = rowPagesBuilder(BIGINT, DOUBLE) @@ -147,6 +141,7 @@ public void testMultipleOutputPages(boolean spillEnabled, 
boolean revokeMemoryWh Ints.asList(), Ints.asList(0), ImmutableList.copyOf(new SortOrder[] {SortOrder.DESC_NULLS_FIRST}), + spillerFactory, spillEnabled); DriverContext driverContext = createDriverContext(memoryLimit); @@ -167,8 +162,17 @@ public void testMultipleOutputPages(boolean spillEnabled, boolean revokeMemoryWh .isTrue(); } - @Test(dataProvider = "spillEnabled") - public void testRowNumber(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) + @Test + public void testRowNumber() + { + testRowNumber(false, false, 0); + testRowNumber(true, false, 8); + testRowNumber(true, true, 8); + testRowNumber(true, false, 0); + testRowNumber(true, true, 0); + } + + private void testRowNumber(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) { List input = rowPagesBuilder(BIGINT, DOUBLE) .row(2L, 0.3) @@ -186,6 +190,7 @@ public void testRowNumber(boolean spillEnabled, boolean revokeMemoryWhenAddingPa Ints.asList(), Ints.asList(0), ImmutableList.copyOf(new SortOrder[] {SortOrder.ASC_NULLS_LAST}), + new DummySpillerFactory(), spillEnabled); DriverContext driverContext = createDriverContext(memoryLimit); @@ -200,8 +205,17 @@ public void testRowNumber(boolean spillEnabled, boolean revokeMemoryWhenAddingPa assertOperatorEquals(operatorFactory, driverContext, input, expected, revokeMemoryWhenAddingPages); } - @Test(dataProvider = "spillEnabled") - public void testRowNumberPartition(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) + @Test + public void testRowNumberPartition() + { + testRowNumberPartition(false, false, 0); + testRowNumberPartition(true, false, 8); + testRowNumberPartition(true, true, 8); + testRowNumberPartition(true, false, 0); + testRowNumberPartition(true, true, 0); + } + + private void testRowNumberPartition(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) { List input = rowPagesBuilder(VARCHAR, BIGINT, DOUBLE, BOOLEAN) .row("b", -1L, -0.1, true) @@ 
-219,6 +233,7 @@ public void testRowNumberPartition(boolean spillEnabled, boolean revokeMemoryWhe Ints.asList(0), Ints.asList(1), ImmutableList.copyOf(new SortOrder[] {SortOrder.ASC_NULLS_LAST}), + new DummySpillerFactory(), spillEnabled); DriverContext driverContext = createDriverContext(memoryLimit); @@ -255,6 +270,7 @@ public void testRowNumberArbitrary() Ints.asList(), Ints.asList(), ImmutableList.copyOf(new SortOrder[] {}), + new DummySpillerFactory(), false); DriverContext driverContext = createDriverContext(); @@ -294,6 +310,7 @@ public void testRowNumberArbitraryWithSpill() Ints.asList(), Ints.asList(), ImmutableList.copyOf(new SortOrder[] {}), + new DummySpillerFactory(), true); DriverContext driverContext = createDriverContext(); @@ -311,7 +328,16 @@ public void testRowNumberArbitraryWithSpill() assertOperatorEquals(operatorFactory, driverContext, input, expected); } - @Test(dataProvider = "spillEnabled") + @Test + public void testDistinctPartitionAndPeers() + { + testDistinctPartitionAndPeers(false, false, 0); + testDistinctPartitionAndPeers(true, false, 8); + testDistinctPartitionAndPeers(true, true, 8); + testDistinctPartitionAndPeers(true, false, 0); + testDistinctPartitionAndPeers(true, true, 0); + } + - public void testDistinctPartitionAndPeers(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) + private void testDistinctPartitionAndPeers(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) { List input = rowPagesBuilder(DOUBLE, DOUBLE) @@ -344,6 +370,7 @@ public void testDistinctPartitionAndPeers(boolean spillEnabled, boolean revokeMe Ints.asList(0), Ints.asList(1), ImmutableList.copyOf(new SortOrder[] {SortOrder.ASC_NULLS_LAST}), + new DummySpillerFactory(), spillEnabled); DriverContext driverContext = createDriverContext(memoryLimit); @@ -372,35 +399,49 @@ public void testDistinctPartitionAndPeers(boolean spillEnabled, boolean revokeMe assertOperatorEquals(operatorFactory, driverContext, input, expected, revokeMemoryWhenAddingPages); } - @Test(expectedExceptions = ExceededMemoryLimitException.class,
expectedExceptionsMessageRegExp = "Query exceeded per-node memory limit of 10B.*") + @Test public void testMemoryLimit() { - List input = rowPagesBuilder(BIGINT, DOUBLE) - .row(1L, 0.1) - .row(2L, 0.2) - .pageBreak() - .row(-1L, -0.1) - .row(4L, 0.4) - .build(); - - DriverContext driverContext = createTaskContext(executor, scheduledExecutor, TEST_SESSION, DataSize.ofBytes(10)) - .addPipelineContext(0, true, true, false) - .addDriverContext(); - - WindowOperatorFactory operatorFactory = createFactoryUnbounded( - ImmutableList.of(BIGINT, DOUBLE), - Ints.asList(1), - ROW_NUMBER, - Ints.asList(), - Ints.asList(0), - ImmutableList.copyOf(new SortOrder[] {SortOrder.ASC_NULLS_LAST}), - false); + assertThatThrownBy(() -> { + List input = rowPagesBuilder(BIGINT, DOUBLE) + .row(1L, 0.1) + .row(2L, 0.2) + .pageBreak() + .row(-1L, -0.1) + .row(4L, 0.4) + .build(); + + DriverContext driverContext = createTaskContext(executor, scheduledExecutor, TEST_SESSION, DataSize.ofBytes(10)) + .addPipelineContext(0, true, true, false) + .addDriverContext(); + + WindowOperatorFactory operatorFactory = createFactoryUnbounded( + ImmutableList.of(BIGINT, DOUBLE), + Ints.asList(1), + ROW_NUMBER, + Ints.asList(), + Ints.asList(0), + ImmutableList.copyOf(new SortOrder[] {SortOrder.ASC_NULLS_LAST}), + new DummySpillerFactory(), + false); + + toPages(operatorFactory, driverContext, input); + }) + .isInstanceOf(ExceededMemoryLimitException.class) + .hasMessageMatching("Query exceeded per-node memory limit of 10B.*"); + } - toPages(operatorFactory, driverContext, input); + @Test + public void testFirstValuePartition() + { + testFirstValuePartition(false, false, 0); + testFirstValuePartition(true, false, 8); + testFirstValuePartition(true, true, 8); + testFirstValuePartition(true, false, 0); + testFirstValuePartition(true, true, 0); } - @Test(dataProvider = "spillEnabled") - public void testFirstValuePartition(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) + private void 
testFirstValuePartition(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) { List input = rowPagesBuilder(VARCHAR, VARCHAR, BIGINT, BOOLEAN, VARCHAR) .row("b", "A1", 1L, true, "") @@ -419,6 +460,7 @@ public void testFirstValuePartition(boolean spillEnabled, boolean revokeMemoryWh Ints.asList(0), Ints.asList(2), ImmutableList.copyOf(new SortOrder[] {SortOrder.ASC_NULLS_LAST}), + new DummySpillerFactory(), spillEnabled); DriverContext driverContext = createDriverContext(memoryLimit); @@ -454,6 +496,7 @@ public void testClose() Ints.asList(0), Ints.asList(1), ImmutableList.copyOf(new SortOrder[] {SortOrder.ASC_NULLS_LAST}), + new DummySpillerFactory(), false); DriverContext driverContext = createDriverContext(1000); @@ -469,8 +512,17 @@ public void testClose() operator.close(); } - @Test(dataProvider = "spillEnabled") - public void testLastValuePartition(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) + @Test + public void testLastValuePartition() + { + testLastValuePartition(false, false, 0); + testLastValuePartition(true, false, 8); + testLastValuePartition(true, true, 8); + testLastValuePartition(true, false, 0); + testLastValuePartition(true, true, 0); + } + + private void testLastValuePartition(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) { List input = rowPagesBuilder(VARCHAR, VARCHAR, BIGINT, BOOLEAN, VARCHAR) .row("b", "A1", 1L, true, "") @@ -490,6 +542,7 @@ public void testLastValuePartition(boolean spillEnabled, boolean revokeMemoryWhe Ints.asList(0), Ints.asList(2), ImmutableList.copyOf(new SortOrder[] {SortOrder.ASC_NULLS_LAST}), + new DummySpillerFactory(), spillEnabled); MaterializedResult expected = resultBuilder(driverContext.getSession(), VARCHAR, VARCHAR, BIGINT, BOOLEAN, VARCHAR) @@ -503,8 +556,17 @@ public void testLastValuePartition(boolean spillEnabled, boolean revokeMemoryWhe assertOperatorEquals(operatorFactory, driverContext, input, expected, 
revokeMemoryWhenAddingPages); } - @Test(dataProvider = "spillEnabled") - public void testNthValuePartition(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) + @Test + public void testNthValuePartition() + { + testNthValuePartition(false, false, 0); + testNthValuePartition(true, false, 8); + testNthValuePartition(true, true, 8); + testNthValuePartition(true, false, 0); + testNthValuePartition(true, true, 0); + } + + private void testNthValuePartition(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) { List input = rowPagesBuilder(VARCHAR, VARCHAR, BIGINT, BIGINT, BOOLEAN, VARCHAR) .row("b", "A1", 1L, 2L, true, "") @@ -523,6 +585,7 @@ public void testNthValuePartition(boolean spillEnabled, boolean revokeMemoryWhen Ints.asList(0), Ints.asList(2), ImmutableList.copyOf(new SortOrder[] {SortOrder.ASC_NULLS_LAST}), + new DummySpillerFactory(), spillEnabled); DriverContext driverContext = createDriverContext(memoryLimit); @@ -538,8 +601,17 @@ public void testNthValuePartition(boolean spillEnabled, boolean revokeMemoryWhen assertOperatorEquals(operatorFactory, driverContext, input, expected, revokeMemoryWhenAddingPages); } - @Test(dataProvider = "spillEnabled") - public void testLagPartition(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) + @Test + public void testLagPartition() + { + testLagPartition(false, false, 0); + testLagPartition(true, false, 8); + testLagPartition(true, true, 8); + testLagPartition(true, false, 0); + testLagPartition(true, true, 0); + } + + private void testLagPartition(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) { List input = rowPagesBuilder(VARCHAR, VARCHAR, BIGINT, BIGINT, VARCHAR, BOOLEAN, VARCHAR) .row("b", "A1", 1L, 1L, "D", true, "") @@ -558,6 +630,7 @@ public void testLagPartition(boolean spillEnabled, boolean revokeMemoryWhenAddin Ints.asList(0), Ints.asList(2), ImmutableList.copyOf(new SortOrder[] 
{SortOrder.ASC_NULLS_LAST}), + new DummySpillerFactory(), spillEnabled); DriverContext driverContext = createDriverContext(memoryLimit); @@ -573,8 +646,17 @@ public void testLagPartition(boolean spillEnabled, boolean revokeMemoryWhenAddin assertOperatorEquals(operatorFactory, driverContext, input, expected, revokeMemoryWhenAddingPages); } - @Test(dataProvider = "spillEnabled") - public void testLeadPartition(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) + @Test + public void testLeadPartition() + { + testLeadPartition(false, false, 0); + testLeadPartition(true, false, 8); + testLeadPartition(true, true, 8); + testLeadPartition(true, false, 0); + testLeadPartition(true, true, 0); + } + + private void testLeadPartition(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) { List input = rowPagesBuilder(VARCHAR, VARCHAR, BIGINT, BIGINT, VARCHAR, BOOLEAN, VARCHAR) .row("b", "A1", 1L, 1L, "D", true, "") @@ -593,6 +675,7 @@ public void testLeadPartition(boolean spillEnabled, boolean revokeMemoryWhenAddi Ints.asList(0), Ints.asList(2), ImmutableList.copyOf(new SortOrder[] {SortOrder.ASC_NULLS_LAST}), + new DummySpillerFactory(), spillEnabled); DriverContext driverContext = createDriverContext(memoryLimit); @@ -608,8 +691,17 @@ public void testLeadPartition(boolean spillEnabled, boolean revokeMemoryWhenAddi assertOperatorEquals(operatorFactory, driverContext, input, expected, revokeMemoryWhenAddingPages); } - @Test(dataProvider = "spillEnabled") - public void testPartiallyPreGroupedPartitionWithEmptyInput(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) + @Test + public void testPartiallyPreGroupedPartitionWithEmptyInput() + { + testPartiallyPreGroupedPartitionWithEmptyInput(false, false, 0); + testPartiallyPreGroupedPartitionWithEmptyInput(true, false, 8); + testPartiallyPreGroupedPartitionWithEmptyInput(true, true, 8); + testPartiallyPreGroupedPartitionWithEmptyInput(true, false, 0); + 
testPartiallyPreGroupedPartitionWithEmptyInput(true, true, 0); + } + + private void testPartiallyPreGroupedPartitionWithEmptyInput(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) { List input = rowPagesBuilder(BIGINT, VARCHAR, BIGINT, VARCHAR) .pageBreak() @@ -625,6 +717,7 @@ public void testPartiallyPreGroupedPartitionWithEmptyInput(boolean spillEnabled, Ints.asList(3), ImmutableList.of(SortOrder.ASC_NULLS_LAST), 0, + new DummySpillerFactory(), spillEnabled); DriverContext driverContext = createDriverContext(memoryLimit); @@ -634,8 +727,17 @@ public void testPartiallyPreGroupedPartitionWithEmptyInput(boolean spillEnabled, assertOperatorEquals(operatorFactory, driverContext, input, expected, revokeMemoryWhenAddingPages); } - @Test(dataProvider = "spillEnabled") - public void testPartiallyPreGroupedPartition(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) + @Test + public void testPartiallyPreGroupedPartition() + { + testPartiallyPreGroupedPartition(false, false, 0); + testPartiallyPreGroupedPartition(true, false, 8); + testPartiallyPreGroupedPartition(true, true, 8); + testPartiallyPreGroupedPartition(true, false, 0); + testPartiallyPreGroupedPartition(true, true, 0); + } + + private void testPartiallyPreGroupedPartition(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) { List input = rowPagesBuilder(BIGINT, VARCHAR, BIGINT, VARCHAR) .pageBreak() @@ -659,6 +761,7 @@ public void testPartiallyPreGroupedPartition(boolean spillEnabled, boolean revok Ints.asList(3), ImmutableList.of(SortOrder.ASC_NULLS_LAST), 0, + new DummySpillerFactory(), spillEnabled); DriverContext driverContext = createDriverContext(memoryLimit); @@ -674,8 +777,17 @@ public void testPartiallyPreGroupedPartition(boolean spillEnabled, boolean revok assertOperatorEqualsIgnoreOrder(operatorFactory, driverContext, input, expected, revokeMemoryWhenAddingPages); } - @Test(dataProvider = "spillEnabled") - public void 
testFullyPreGroupedPartition(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) + @Test + public void testFullyPreGroupedPartition() + { + testFullyPreGroupedPartition(false, false, 0); + testFullyPreGroupedPartition(true, false, 8); + testFullyPreGroupedPartition(true, true, 8); + testFullyPreGroupedPartition(true, false, 0); + testFullyPreGroupedPartition(true, true, 0); + } + + private void testFullyPreGroupedPartition(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) { List input = rowPagesBuilder(BIGINT, VARCHAR, BIGINT, VARCHAR) .pageBreak() @@ -700,6 +812,7 @@ public void testFullyPreGroupedPartition(boolean spillEnabled, boolean revokeMem Ints.asList(3), ImmutableList.of(SortOrder.ASC_NULLS_LAST), 0, + new DummySpillerFactory(), spillEnabled); DriverContext driverContext = createDriverContext(memoryLimit); @@ -716,8 +829,17 @@ public void testFullyPreGroupedPartition(boolean spillEnabled, boolean revokeMem assertOperatorEqualsIgnoreOrder(operatorFactory, driverContext, input, expected, revokeMemoryWhenAddingPages); } - @Test(dataProvider = "spillEnabled") - public void testFullyPreGroupedAndPartiallySortedPartition(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) + @Test + public void testFullyPreGroupedAndPartiallySortedPartition() + { + testFullyPreGroupedAndPartiallySortedPartition(false, false, 0); + testFullyPreGroupedAndPartiallySortedPartition(true, false, 8); + testFullyPreGroupedAndPartiallySortedPartition(true, true, 8); + testFullyPreGroupedAndPartiallySortedPartition(true, false, 0); + testFullyPreGroupedAndPartiallySortedPartition(true, true, 0); + } + + private void testFullyPreGroupedAndPartiallySortedPartition(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) { List input = rowPagesBuilder(BIGINT, VARCHAR, BIGINT, VARCHAR) .pageBreak() @@ -743,6 +865,7 @@ public void testFullyPreGroupedAndPartiallySortedPartition(boolean spillEnabled, 
Ints.asList(3, 2), ImmutableList.of(SortOrder.ASC_NULLS_LAST, SortOrder.ASC_NULLS_LAST), 1, + new DummySpillerFactory(), spillEnabled); DriverContext driverContext = createDriverContext(memoryLimit); @@ -760,8 +883,17 @@ public void testFullyPreGroupedAndPartiallySortedPartition(boolean spillEnabled, assertOperatorEqualsIgnoreOrder(operatorFactory, driverContext, input, expected, revokeMemoryWhenAddingPages); } - @Test(dataProvider = "spillEnabled") - public void testFullyPreGroupedAndFullySortedPartition(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) + @Test + public void testFullyPreGroupedAndFullySortedPartition() + { + testFullyPreGroupedAndFullySortedPartition(false, false, 0); + testFullyPreGroupedAndFullySortedPartition(true, false, 8); + testFullyPreGroupedAndFullySortedPartition(true, true, 8); + testFullyPreGroupedAndFullySortedPartition(true, false, 0); + testFullyPreGroupedAndFullySortedPartition(true, true, 0); + } + + private void testFullyPreGroupedAndFullySortedPartition(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit) { List input = rowPagesBuilder(BIGINT, VARCHAR, BIGINT, VARCHAR) .pageBreak() @@ -787,6 +919,7 @@ public void testFullyPreGroupedAndFullySortedPartition(boolean spillEnabled, boo Ints.asList(3), ImmutableList.of(SortOrder.ASC_NULLS_LAST), 1, + new DummySpillerFactory(), spillEnabled); DriverContext driverContext = createDriverContext(memoryLimit); @@ -844,6 +977,7 @@ private WindowOperatorFactory createFactoryUnbounded( List partitionChannels, List sortChannels, List sortOrder, + SpillerFactory spillerFactory, boolean spillEnabled) { return createFactoryUnbounded( @@ -855,6 +989,7 @@ private WindowOperatorFactory createFactoryUnbounded( sortChannels, sortOrder, 0, + spillerFactory, spillEnabled); } @@ -867,6 +1002,7 @@ private WindowOperatorFactory createFactoryUnbounded( List sortChannels, List sortOrder, int preSortedChannelPrefix, + SpillerFactory spillerFactory,
boolean spillEnabled) { return new WindowOperatorFactory( diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/AbstractTestApproximateCountDistinct.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/AbstractTestApproximateCountDistinct.java index a697a552f7e5..6ee64e55a1c9 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/AbstractTestApproximateCountDistinct.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/AbstractTestApproximateCountDistinct.java @@ -21,8 +21,7 @@ import io.trino.spi.block.BlockBuilder; import io.trino.spi.type.Type; import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics; -import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import java.util.ArrayList; import java.util.Collections; @@ -52,35 +51,35 @@ protected int getUniqueValuesCount() return 20000; } - @DataProvider(name = "provideStandardErrors") - public Object[][] provideStandardErrors() + @Test + public void testNoPositions() { - return new Object[][] { - {0.0230}, // 2k buckets - {0.0115}, // 8k buckets - }; + assertCount(ImmutableList.of(), 0.0230, 0); + assertCount(ImmutableList.of(), 0.0115, 0); } - @Test(dataProvider = "provideStandardErrors") - public void testNoPositions(double maxStandardError) + @Test + public void testSinglePosition() { - assertCount(ImmutableList.of(), maxStandardError, 0); + assertCount(ImmutableList.of(randomValue()), 0.0230, 1); + assertCount(ImmutableList.of(randomValue()), 0.0115, 1); } - @Test(dataProvider = "provideStandardErrors") - public void testSinglePosition(double maxStandardError) + @Test + public void testAllPositionsNull() { - assertCount(ImmutableList.of(randomValue()), maxStandardError, 1); + assertCount(Collections.nCopies(100, null), 0.0230, 0); + assertCount(Collections.nCopies(100, null), 0.0115, 0); } - @Test(dataProvider = "provideStandardErrors") - public void 
testAllPositionsNull(double maxStandardError) + @Test + public void testMixedNullsAndNonNulls() { - assertCount(Collections.nCopies(100, null), maxStandardError, 0); + testMixedNullsAndNonNulls(0.0230); + testMixedNullsAndNonNulls(0.0115); } - @Test(dataProvider = "provideStandardErrors") - public void testMixedNullsAndNonNulls(double maxStandardError) + private void testMixedNullsAndNonNulls(double maxStandardError) { int uniques = getUniqueValuesCount(); List baseline = createRandomSample(uniques, (int) (uniques * 1.5)); @@ -96,8 +95,14 @@ public void testMixedNullsAndNonNulls(double maxStandardError) assertCount(mixed, maxStandardError, estimateGroupByCount(baseline, maxStandardError)); } - @Test(dataProvider = "provideStandardErrors") - public void testMultiplePositions(double maxStandardError) + @Test + public void testMultiplePositions() + { + testMultiplePositions(0.0230); + testMultiplePositions(0.0115); + } + + private void testMultiplePositions(double maxStandardError) { DescriptiveStatistics stats = new DescriptiveStatistics(); @@ -116,8 +121,14 @@ public void testMultiplePositions(double maxStandardError) assertLessThan(stats.getStandardDeviation(), 1.0e-2 + maxStandardError); } - @Test(dataProvider = "provideStandardErrors") - public void testMultiplePositionsPartial(double maxStandardError) + @Test + public void testMultiplePositionsPartial() + { + testMultiplePositionsPartial(0.0230); + testMultiplePositionsPartial(0.0115); + } + + private void testMultiplePositionsPartial(double maxStandardError) { for (int i = 0; i < 100; ++i) { int uniques = ThreadLocalRandom.current().nextInt(getUniqueValuesCount()) + 1; diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAggregationMaskCompiler.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAggregationMaskCompiler.java index 322028b2075f..f9623ca02748 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAggregationMaskCompiler.java +++ 
b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAggregationMaskCompiler.java @@ -19,8 +19,7 @@ import io.trino.spi.block.IntArrayBlock; import io.trino.spi.block.RunLengthEncodedBlock; import io.trino.spi.block.ShortArrayBlock; -import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import java.util.Arrays; import java.util.Optional; @@ -32,23 +31,24 @@ public class TestAggregationMaskCompiler { - @DataProvider - public Object[][] maskBuilderSuppliers() + private static final Supplier INTERPRETED_MASK_BUILDER_SUPPLIER = () -> new InterpretedAggregationMaskBuilder(1); + private static final Supplier COMPILED_MASK_BUILDER_SUPPLIER = () -> { + try { + return generateAggregationMaskBuilder(1).newInstance(); + } + catch (ReflectiveOperationException e) { + throw new RuntimeException(e); + } + }; + + @Test + public void testSupplier() { - Supplier interpretedMaskBuilderSupplier = () -> new InterpretedAggregationMaskBuilder(1); - Supplier compiledMaskBuilderSupplier = () -> { - try { - return generateAggregationMaskBuilder(1).newInstance(); - } - catch (ReflectiveOperationException e) { - throw new RuntimeException(e); - } - }; - return new Object[][] {{compiledMaskBuilderSupplier}, {interpretedMaskBuilderSupplier}}; + testSupplier(INTERPRETED_MASK_BUILDER_SUPPLIER); + testSupplier(COMPILED_MASK_BUILDER_SUPPLIER); } - @Test(dataProvider = "maskBuilderSuppliers") - public void testSupplier(Supplier maskBuilderSupplier) + private void testSupplier(Supplier maskBuilderSupplier) { // each builder produced from a supplier could be completely independent assertThat(maskBuilderSupplier.get()).isNotSameAs(maskBuilderSupplier.get()); @@ -74,8 +74,14 @@ public void testSupplier(Supplier maskBuilderSupplier) .isSameAs(maskBuilder.buildAggregationMask(pageWithNulls, Optional.empty()).getSelectedPositions()); } - @Test(dataProvider = "maskBuilderSuppliers") - public void testUnsetNulls(Supplier 
maskBuilderSupplier) + @Test + public void testUnsetNulls() + { + testUnsetNulls(INTERPRETED_MASK_BUILDER_SUPPLIER); + testUnsetNulls(COMPILED_MASK_BUILDER_SUPPLIER); + } + + private void testUnsetNulls(Supplier maskBuilderSupplier) { AggregationMaskBuilder maskBuilder = maskBuilderSupplier.get(); AggregationMask aggregationMask = maskBuilder.buildAggregationMask(buildSingleColumnPage(0), Optional.empty()); @@ -107,8 +113,14 @@ public void testUnsetNulls(Supplier maskBuilderSupplier) } } - @Test(dataProvider = "maskBuilderSuppliers") - public void testApplyMask(Supplier maskBuilderSupplier) + @Test + public void testApplyMask() + { + testApplyMask(INTERPRETED_MASK_BUILDER_SUPPLIER); + testApplyMask(COMPILED_MASK_BUILDER_SUPPLIER); + } + + private void testApplyMask(Supplier maskBuilderSupplier) { AggregationMaskBuilder maskBuilder = maskBuilderSupplier.get(); @@ -135,8 +147,14 @@ public void testApplyMask(Supplier maskBuilderSupplier) } } - @Test(dataProvider = "maskBuilderSuppliers") - public void testApplyMaskNulls(Supplier maskBuilderSupplier) + @Test + public void testApplyMaskNulls() + { + testApplyMaskNulls(INTERPRETED_MASK_BUILDER_SUPPLIER); + testApplyMaskNulls(COMPILED_MASK_BUILDER_SUPPLIER); + } + + private void testApplyMaskNulls(Supplier maskBuilderSupplier) { AggregationMaskBuilder maskBuilder = maskBuilderSupplier.get(); diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestApproximateCountDistinctBoolean.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestApproximateCountDistinctBoolean.java index a766c9c019e5..ede51acce968 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestApproximateCountDistinctBoolean.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestApproximateCountDistinctBoolean.java @@ -14,13 +14,10 @@ package io.trino.operator.aggregation; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableSet; import 
com.google.common.primitives.Booleans; import io.trino.spi.type.Type; -import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; -import java.util.List; import java.util.concurrent.ThreadLocalRandom; import static io.trino.spi.type.BooleanType.BOOLEAN; @@ -40,24 +37,15 @@ protected Object randomValue() return ThreadLocalRandom.current().nextBoolean(); } - @DataProvider(name = "inputSequences") - public Object[][] inputSequences() - { - return new Object[][] { - {true}, - {false}, - {true, false}, - {true, true, true}, - {false, false, false}, - {true, false, true, false}, - }; - } - - @Test(dataProvider = "inputSequences") - public void testNonEmptyInputs(boolean... inputSequence) + @Test + public void testNonEmptyInputs() { - List values = Booleans.asList(inputSequence); - assertCount(values, 0, distinctCount(values)); + assertCount(Booleans.asList(true), 0, 1); + assertCount(Booleans.asList(false), 0, 1); + assertCount(Booleans.asList(true, false), 0, 2); + assertCount(Booleans.asList(true, true, true), 0, 1); + assertCount(Booleans.asList(false, false, false), 0, 1); + assertCount(Booleans.asList(true, false, true, false), 0, 2); } @Test @@ -66,11 +54,6 @@ public void testNoInput() assertCount(ImmutableList.of(), 0, 0); } - private long distinctCount(List inputSequence) - { - return ImmutableSet.copyOf(inputSequence).size(); - } - @Override protected int getUniqueValuesCount() { diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestApproximateSetGenericBoolean.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestApproximateSetGenericBoolean.java index 624ba42722b5..ff49acd651cb 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestApproximateSetGenericBoolean.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestApproximateSetGenericBoolean.java @@ -14,11 +14,9 @@ package io.trino.operator.aggregation; import 
com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableSet; import com.google.common.primitives.Booleans; import io.trino.spi.type.Type; -import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import java.util.List; import java.util.concurrent.ThreadLocalRandom; @@ -40,29 +38,15 @@ protected Object randomValue() return ThreadLocalRandom.current().nextBoolean(); } - @DataProvider(name = "inputSequences") - public Object[][] inputSequences() + @Test + public void testNonEmptyInputs() { - return new Object[][] { - {true}, - {false}, - {true, false}, - {true, true, true}, - {false, false, false}, - {true, false, true, false}, - }; - } - - @Test(dataProvider = "inputSequences") - public void testNonEmptyInputs(boolean... inputSequence) - { - List values = Booleans.asList(inputSequence); - assertCount(values, distinctCount(values)); - } - - private long distinctCount(List inputSequence) - { - return ImmutableSet.copyOf(inputSequence).size(); + assertCount(Booleans.asList(true), 1); + assertCount(Booleans.asList(false), 1); + assertCount(Booleans.asList(true, false), 2); + assertCount(Booleans.asList(true, true, true), 1); + assertCount(Booleans.asList(false, false, false), 1); + assertCount(Booleans.asList(true, false, true, false), 2); } @Override diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDecimalAverageAggregation.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDecimalAverageAggregation.java index a81785db935c..8897405f6150 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDecimalAverageAggregation.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDecimalAverageAggregation.java @@ -21,9 +21,7 @@ import io.trino.spi.type.DecimalType; import io.trino.spi.type.Decimals; import io.trino.spi.type.Int128; -import org.testng.annotations.BeforeMethod; -import 
org.testng.annotations.DataProvider; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import java.math.BigDecimal; import java.math.BigInteger; @@ -38,7 +36,6 @@ import static java.math.RoundingMode.HALF_UP; import static org.assertj.core.api.Assertions.assertThat; -@Test(singleThreaded = true) public class TestDecimalAverageAggregation { private static final BigInteger TWO = new BigInteger("2"); @@ -46,17 +43,11 @@ public class TestDecimalAverageAggregation private static final BigInteger TWO_HUNDRED = new BigInteger("200"); private static final DecimalType TYPE = createDecimalType(38, 0); - private LongDecimalWithOverflowAndLongState state; - - @BeforeMethod - public void setUp() - { - state = new LongDecimalWithOverflowAndLongStateFactory().createSingleState(); - } - @Test public void testOverflow() { + LongDecimalWithOverflowAndLongState state = new LongDecimalWithOverflowAndLongStateFactory().createSingleState(); + addToState(state, TWO.pow(126)); assertThat(state.getLong()).isEqualTo(1); @@ -69,12 +60,14 @@ public void testOverflow() assertThat(state.getOverflow()).isEqualTo(1); assertThat(getDecimal(state)).isEqualTo(Int128.valueOf(1L << 63, 0)); - assertAverageEquals(TWO.pow(126)); + assertAverageEquals(state, TWO.pow(126)); } @Test public void testUnderflow() { + LongDecimalWithOverflowAndLongState state = new LongDecimalWithOverflowAndLongStateFactory().createSingleState(); + addToState(state, Decimals.MIN_UNSCALED_DECIMAL.toBigInteger()); assertThat(state.getLong()).isEqualTo(1); @@ -87,12 +80,14 @@ public void testUnderflow() assertThat(state.getOverflow()).isEqualTo(-1); assertThat(getDecimal(state)).isEqualTo(Int128.valueOf(0x698966AF4AF2770BL, 0xECEBBB8000000002L)); - assertAverageEquals(Decimals.MIN_UNSCALED_DECIMAL.toBigInteger()); + assertAverageEquals(state, Decimals.MIN_UNSCALED_DECIMAL.toBigInteger()); } @Test public void testUnderflowAfterOverflow() { + LongDecimalWithOverflowAndLongState state = new 
LongDecimalWithOverflowAndLongStateFactory().createSingleState(); + addToState(state, TWO.pow(126)); addToState(state, TWO.pow(126)); addToState(state, TWO.pow(125)); @@ -107,12 +102,14 @@ public void testUnderflowAfterOverflow() assertThat(state.getOverflow()).isEqualTo(0); assertThat(getDecimal(state)).isEqualTo(Int128.valueOf(TWO.pow(125).negate())); - assertAverageEquals(TWO.pow(125).negate().divide(BigInteger.valueOf(6))); + assertAverageEquals(state, TWO.pow(125).negate().divide(BigInteger.valueOf(6))); } @Test public void testCombineOverflow() { + LongDecimalWithOverflowAndLongState state = new LongDecimalWithOverflowAndLongStateFactory().createSingleState(); + addToState(state, TWO.pow(126)); addToState(state, TWO.pow(126)); @@ -133,12 +130,14 @@ public void testCombineOverflow() .add(TWO.pow(126)) .divide(BigInteger.valueOf(4)); - assertAverageEquals(expectedAverage); + assertAverageEquals(state, expectedAverage); } @Test public void testCombineUnderflow() { + LongDecimalWithOverflowAndLongState state = new LongDecimalWithOverflowAndLongStateFactory().createSingleState(); + addToState(state, TWO.pow(125).negate()); addToState(state, TWO.pow(126).negate()); @@ -160,14 +159,39 @@ public void testCombineUnderflow() .negate() .divide(BigInteger.valueOf(4)); - assertAverageEquals(expectedAverage); + assertAverageEquals(state, expectedAverage); } - @Test(dataProvider = "testNoOverflowDataProvider") - public void testNoOverflow(List numbers) + @Test + public void testNoOverflow() { - testNoOverflow(createDecimalType(38, 0), numbers); - testNoOverflow(createDecimalType(38, 2), numbers); + testNoOverflow(createDecimalType(38, 0), ImmutableList.of(TEN.pow(37), ZERO)); + testNoOverflow(createDecimalType(38, 0), ImmutableList.of(TEN.pow(37).negate(), ZERO)); + testNoOverflow(createDecimalType(38, 0), ImmutableList.of(TWO, ONE)); + testNoOverflow(createDecimalType(38, 0), ImmutableList.of(ZERO, ONE)); + testNoOverflow(createDecimalType(38, 0), 
ImmutableList.of(TWO.negate(), ONE.negate())); + testNoOverflow(createDecimalType(38, 0), ImmutableList.of(ONE.negate(), ZERO)); + testNoOverflow(createDecimalType(38, 0), ImmutableList.of(ONE.negate(), ZERO, ZERO)); + testNoOverflow(createDecimalType(38, 0), ImmutableList.of(TWO.negate(), ZERO, ZERO)); + testNoOverflow(createDecimalType(38, 0), ImmutableList.of(TWO.negate(), ZERO)); + testNoOverflow(createDecimalType(38, 0), ImmutableList.of(TWO_HUNDRED, ONE_HUNDRED)); + testNoOverflow(createDecimalType(38, 0), ImmutableList.of(ZERO, ONE_HUNDRED)); + testNoOverflow(createDecimalType(38, 0), ImmutableList.of(TWO_HUNDRED.negate(), ONE_HUNDRED.negate())); + testNoOverflow(createDecimalType(38, 0), ImmutableList.of(ONE_HUNDRED.negate(), ZERO)); + + testNoOverflow(createDecimalType(38, 2), ImmutableList.of(TEN.pow(37), ZERO)); + testNoOverflow(createDecimalType(38, 2), ImmutableList.of(TEN.pow(37).negate(), ZERO)); + testNoOverflow(createDecimalType(38, 2), ImmutableList.of(TWO, ONE)); + testNoOverflow(createDecimalType(38, 2), ImmutableList.of(ZERO, ONE)); + testNoOverflow(createDecimalType(38, 2), ImmutableList.of(TWO.negate(), ONE.negate())); + testNoOverflow(createDecimalType(38, 2), ImmutableList.of(ONE.negate(), ZERO)); + testNoOverflow(createDecimalType(38, 2), ImmutableList.of(ONE.negate(), ZERO, ZERO)); + testNoOverflow(createDecimalType(38, 2), ImmutableList.of(TWO.negate(), ZERO, ZERO)); + testNoOverflow(createDecimalType(38, 2), ImmutableList.of(TWO.negate(), ZERO)); + testNoOverflow(createDecimalType(38, 2), ImmutableList.of(TWO_HUNDRED, ONE_HUNDRED)); + testNoOverflow(createDecimalType(38, 2), ImmutableList.of(ZERO, ONE_HUNDRED)); + testNoOverflow(createDecimalType(38, 2), ImmutableList.of(TWO_HUNDRED.negate(), ONE_HUNDRED.negate())); + testNoOverflow(createDecimalType(38, 2), ImmutableList.of(ONE_HUNDRED.negate(), ZERO)); } private void testNoOverflow(DecimalType type, List numbers) @@ -185,40 +209,15 @@ private void testNoOverflow(DecimalType type, List 
numbers) assertThat(decodeBigDecimal(type, average(state, type))).isEqualTo(expectedAverage); } - @DataProvider - public static Object[][] testNoOverflowDataProvider() - { - return new Object[][] { - {ImmutableList.of(TEN.pow(37), ZERO)}, - {ImmutableList.of(TEN.pow(37).negate(), ZERO)}, - {ImmutableList.of(TWO, ONE)}, - {ImmutableList.of(ZERO, ONE)}, - {ImmutableList.of(TWO.negate(), ONE.negate())}, - {ImmutableList.of(ONE.negate(), ZERO)}, - {ImmutableList.of(ONE.negate(), ZERO, ZERO)}, - {ImmutableList.of(TWO.negate(), ZERO, ZERO)}, - {ImmutableList.of(TWO.negate(), ZERO)}, - {ImmutableList.of(TWO_HUNDRED, ONE_HUNDRED)}, - {ImmutableList.of(ZERO, ONE_HUNDRED)}, - {ImmutableList.of(TWO_HUNDRED.negate(), ONE_HUNDRED.negate())}, - {ImmutableList.of(ONE_HUNDRED.negate(), ZERO)} - }; - } - private static BigDecimal decodeBigDecimal(DecimalType type, Int128 average) { BigInteger unscaledVal = average.toBigInteger(); return new BigDecimal(unscaledVal, type.getScale(), new MathContext(type.getPrecision())); } - private void assertAverageEquals(BigInteger expectedAverage) - { - assertAverageEquals(expectedAverage, TYPE); - } - - private void assertAverageEquals(BigInteger expectedAverage, DecimalType type) + private void assertAverageEquals(LongDecimalWithOverflowAndLongState state, BigInteger expectedAverage) { - assertThat(average(state, type).toBigInteger()).isEqualTo(expectedAverage); + assertThat(average(state, TYPE).toBigInteger()).isEqualTo(expectedAverage); } private static void addToState(LongDecimalWithOverflowAndLongState state, BigInteger value) diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/state/TestLongDecimalWithOverflowAndLongStateSerializer.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/state/TestLongDecimalWithOverflowAndLongStateSerializer.java index 257d63587ba7..fbdafaa434bd 100644 --- 
a/core/trino-main/src/test/java/io/trino/operator/aggregation/state/TestLongDecimalWithOverflowAndLongStateSerializer.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/state/TestLongDecimalWithOverflowAndLongStateSerializer.java @@ -16,8 +16,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.VariableWidthBlockBuilder; -import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import static org.assertj.core.api.Assertions.assertThat; @@ -25,8 +24,28 @@ public class TestLongDecimalWithOverflowAndLongStateSerializer { private static final LongDecimalWithOverflowAndLongStateFactory STATE_FACTORY = new LongDecimalWithOverflowAndLongStateFactory(); - @Test(dataProvider = "input") - public void testSerde(long low, long high, long overflow, long count, int expectedLength) + @Test + public void testSerde() + { + testSerde(3, 0, 0, 1, 1); + testSerde(3, 5, 0, 1, 2); + testSerde(3, 5, 7, 1, 4); + testSerde(3, 0, 0, 2, 3); + testSerde(3, 5, 0, 2, 4); + testSerde(3, 5, 7, 2, 4); + testSerde(3, 0, 7, 1, 3); + testSerde(3, 0, 7, 2, 3); + testSerde(0, 0, 0, 1, 1); + testSerde(0, 5, 0, 1, 2); + testSerde(0, 5, 7, 1, 4); + testSerde(0, 0, 0, 2, 3); + testSerde(0, 5, 0, 2, 4); + testSerde(0, 5, 7, 2, 4); + testSerde(0, 0, 7, 1, 3); + testSerde(0, 0, 7, 2, 3); + } + + private void testSerde(long low, long high, long overflow, long count, int expectedLength) { LongDecimalWithOverflowAndLongState state = STATE_FACTORY.createSingleState(); state.getDecimalArray()[0] = high; @@ -66,27 +85,4 @@ private LongDecimalWithOverflowAndLongState roundTrip(LongDecimalWithOverflowAnd serializer.deserialize(serialized, 0, outState); return outState; } - - @DataProvider - public Object[][] input() - { - return new Object[][] { - {3, 0, 0, 1, 1}, - {3, 5, 0, 1, 2}, - {3, 5, 7, 1, 4}, - {3, 0, 0, 2, 3}, - {3, 5, 0, 2, 4}, - {3, 5, 7, 2, 4}, - {3, 0, 7, 1, 3}, - {3, 0, 7, 
2, 3}, - {0, 0, 0, 1, 1}, - {0, 5, 0, 1, 2}, - {0, 5, 7, 1, 4}, - {0, 0, 0, 2, 3}, - {0, 5, 0, 2, 4}, - {0, 5, 7, 2, 4}, - {0, 0, 7, 1, 3}, - {0, 0, 7, 2, 3} - }; - } } diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/state/TestLongDecimalWithOverflowStateSerializer.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/state/TestLongDecimalWithOverflowStateSerializer.java index 57638bd7bfd7..4c33f3016bfe 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/state/TestLongDecimalWithOverflowStateSerializer.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/state/TestLongDecimalWithOverflowStateSerializer.java @@ -16,8 +16,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.VariableWidthBlockBuilder; -import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import static org.assertj.core.api.Assertions.assertThat; @@ -25,8 +24,20 @@ public class TestLongDecimalWithOverflowStateSerializer { private static final LongDecimalWithOverflowStateFactory STATE_FACTORY = new LongDecimalWithOverflowStateFactory(); - @Test(dataProvider = "input") - public void testSerde(long low, long high, long overflow, int expectedLength) + @Test + public void testSerde() + { + testSerde(3, 0, 0, 1); + testSerde(3, 5, 0, 2); + testSerde(3, 5, 7, 3); + testSerde(3, 0, 7, 3); + testSerde(0, 0, 0, 1); + testSerde(0, 5, 0, 2); + testSerde(0, 5, 7, 3); + testSerde(0, 0, 7, 3); + } + + private void testSerde(long low, long high, long overflow, int expectedLength) { LongDecimalWithOverflowState state = STATE_FACTORY.createSingleState(); state.getDecimalArray()[0] = high; @@ -66,19 +77,4 @@ private LongDecimalWithOverflowState roundTrip(LongDecimalWithOverflowState stat serializer.deserialize(serialized, 0, outState); return outState; } - - @DataProvider - public Object[][] input() - { - return new Object[][] { - 
{3, 0, 0, 1}, - {3, 5, 0, 2}, - {3, 5, 7, 3}, - {3, 0, 7, 3}, - {0, 0, 0, 1}, - {0, 5, 0, 2}, - {0, 5, 7, 3}, - {0, 0, 7, 3} - }; - } } diff --git a/core/trino-main/src/test/java/io/trino/operator/exchange/TestLocalExchange.java b/core/trino-main/src/test/java/io/trino/operator/exchange/TestLocalExchange.java index 6f5b3763addc..9f8cdedef0c0 100644 --- a/core/trino-main/src/test/java/io/trino/operator/exchange/TestLocalExchange.java +++ b/core/trino-main/src/test/java/io/trino/operator/exchange/TestLocalExchange.java @@ -41,9 +41,10 @@ import io.trino.sql.planner.PartitioningHandle; import io.trino.testing.TestingTransactionHandle; import io.trino.util.FinalizerService; -import org.testng.annotations.BeforeMethod; -import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.parallel.Execution; import java.util.List; import java.util.Optional; @@ -73,8 +74,11 @@ import static io.trino.testing.TestingSession.testSessionBuilder; import static java.util.stream.IntStream.range; import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_METHOD; +import static org.junit.jupiter.api.parallel.ExecutionMode.SAME_THREAD; -@Test(singleThreaded = true) +@TestInstance(PER_METHOD) +@Execution(SAME_THREAD) public class TestLocalExchange { private static final List TYPES = ImmutableList.of(BIGINT); @@ -88,8 +92,9 @@ public class TestLocalExchange private final ConcurrentMap partitionManagers = new ConcurrentHashMap<>(); private NodePartitioningManager nodePartitioningManager; + private final PartitioningHandle customScalingPartitioningHandle = getCustomScalingPartitioningHandle(); - @BeforeMethod + @BeforeEach public void setUp() { NodeScheduler nodeScheduler = new NodeScheduler(new UniformNodeSelectorFactory( @@ -332,8 +337,14 @@ public void 
testNoWriterScalingWhenOnlyBufferSizeLimitIsExceeded() }); } - @Test(dataProvider = "scalingPartitionHandles") - public void testScalingWithTwoDifferentPartitions(PartitioningHandle partitioningHandle) + @Test + public void testScalingWithTwoDifferentPartitions() + { + testScalingWithTwoDifferentPartitions(customScalingPartitioningHandle); + testScalingWithTwoDifferentPartitions(SCALED_WRITER_HASH_DISTRIBUTION); + } + + private void testScalingWithTwoDifferentPartitions(PartitioningHandle partitioningHandle) { LocalExchange localExchange = new LocalExchange( nodePartitioningManager, @@ -533,8 +544,14 @@ public void testNoWriterScalingWhenOnlyWriterScalingMinDataProcessedLimitIsExcee }); } - @Test(dataProvider = "scalingPartitionHandles") - public void testScalingForSkewedWriters(PartitioningHandle partitioningHandle) + @Test + public void testScalingForSkewedWriters() + { + testScalingForSkewedWriters(customScalingPartitioningHandle); + testScalingForSkewedWriters(SCALED_WRITER_HASH_DISTRIBUTION); + } + + private void testScalingForSkewedWriters(PartitioningHandle partitioningHandle) { LocalExchange localExchange = new LocalExchange( nodePartitioningManager, @@ -623,8 +640,14 @@ public void testScalingForSkewedWriters(PartitioningHandle partitioningHandle) }); } - @Test(dataProvider = "scalingPartitionHandles") - public void testNoScalingWhenDataWrittenIsLessThanMinFileSize(PartitioningHandle partitioningHandle) + @Test + public void testNoScalingWhenDataWrittenIsLessThanMinFileSize() + { + testNoScalingWhenDataWrittenIsLessThanMinFileSize(customScalingPartitioningHandle); + testNoScalingWhenDataWrittenIsLessThanMinFileSize(SCALED_WRITER_HASH_DISTRIBUTION); + } + + private void testNoScalingWhenDataWrittenIsLessThanMinFileSize(PartitioningHandle partitioningHandle) { LocalExchange localExchange = new LocalExchange( nodePartitioningManager, @@ -687,8 +710,14 @@ public void testNoScalingWhenDataWrittenIsLessThanMinFileSize(PartitioningHandle }); } - 
@Test(dataProvider = "scalingPartitionHandles") - public void testNoScalingWhenBufferUtilizationIsLessThanLimit(PartitioningHandle partitioningHandle) + @Test + public void testNoScalingWhenBufferUtilizationIsLessThanLimit() + { + testNoScalingWhenBufferUtilizationIsLessThanLimit(customScalingPartitioningHandle); + testNoScalingWhenBufferUtilizationIsLessThanLimit(SCALED_WRITER_HASH_DISTRIBUTION); + } + + private void testNoScalingWhenBufferUtilizationIsLessThanLimit(PartitioningHandle partitioningHandle) { LocalExchange localExchange = new LocalExchange( nodePartitioningManager, @@ -751,8 +780,14 @@ public void testNoScalingWhenBufferUtilizationIsLessThanLimit(PartitioningHandle }); } - @Test(dataProvider = "scalingPartitionHandles") - public void testNoScalingWhenTotalMemoryUsedIsGreaterThanLimit(PartitioningHandle partitioningHandle) + @Test + public void testNoScalingWhenTotalMemoryUsedIsGreaterThanLimit() + { + testNoScalingWhenTotalMemoryUsedIsGreaterThanLimit(customScalingPartitioningHandle); + testNoScalingWhenTotalMemoryUsedIsGreaterThanLimit(SCALED_WRITER_HASH_DISTRIBUTION); + } + + private void testNoScalingWhenTotalMemoryUsedIsGreaterThanLimit(PartitioningHandle partitioningHandle) { AtomicLong totalMemoryUsed = new AtomicLong(); LocalExchange localExchange = new LocalExchange( @@ -832,8 +867,14 @@ public void testNoScalingWhenTotalMemoryUsedIsGreaterThanLimit(PartitioningHandl }); } - @Test(dataProvider = "scalingPartitionHandles") - public void testDoNotUpdateScalingStateWhenMemoryIsAboveLimit(PartitioningHandle partitioningHandle) + @Test + public void testDoNotUpdateScalingStateWhenMemoryIsAboveLimit() + { + testDoNotUpdateScalingStateWhenMemoryIsAboveLimit(customScalingPartitioningHandle); + testDoNotUpdateScalingStateWhenMemoryIsAboveLimit(SCALED_WRITER_HASH_DISTRIBUTION); + } + + private void testDoNotUpdateScalingStateWhenMemoryIsAboveLimit(PartitioningHandle partitioningHandle) { AtomicLong totalMemoryUsed = new AtomicLong(); LocalExchange 
localExchange = new LocalExchange( @@ -1316,12 +1357,6 @@ public void writeUnblockWhenAllReadersFinishAndPagesConsumed() }); } - @DataProvider - public Object[][] scalingPartitionHandles() - { - return new Object[][] {{SCALED_WRITER_HASH_DISTRIBUTION}, {getCustomScalingPartitioningHandle()}}; - } - private PartitioningHandle getCustomScalingPartitioningHandle() { ConnectorPartitioningHandle connectorPartitioningHandle = new ConnectorPartitioningHandle() {}; diff --git a/core/trino-main/src/test/java/io/trino/operator/join/JoinTestUtils.java b/core/trino-main/src/test/java/io/trino/operator/join/JoinTestUtils.java index 3f119282e95f..514c2dcc5c6c 100644 --- a/core/trino-main/src/test/java/io/trino/operator/join/JoinTestUtils.java +++ b/core/trino-main/src/test/java/io/trino/operator/join/JoinTestUtils.java @@ -318,14 +318,16 @@ public static class DummySpillerFactory private volatile boolean failSpill; private volatile boolean failUnspill; - public void failSpill() + public DummySpillerFactory failSpill() { failSpill = true; + return this; } - public void failUnspill() + public DummySpillerFactory failUnspill() { failUnspill = true; + return this; } @Override diff --git a/core/trino-main/src/test/java/io/trino/operator/join/TestHashJoinOperator.java b/core/trino-main/src/test/java/io/trino/operator/join/TestHashJoinOperator.java index 091b7f3b4973..e4ac6f8ec309 100644 --- a/core/trino-main/src/test/java/io/trino/operator/join/TestHashJoinOperator.java +++ b/core/trino-main/src/test/java/io/trino/operator/join/TestHashJoinOperator.java @@ -62,20 +62,18 @@ import io.trino.testing.MaterializedResult; import io.trino.testing.TestingTaskContext; import io.trino.util.FinalizerService; -import org.testng.annotations.AfterMethod; -import org.testng.annotations.BeforeMethod; -import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; 
+import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.parallel.Execution; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; import java.util.Optional; import java.util.OptionalInt; import java.util.concurrent.ExecutorService; import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.SynchronousQueue; -import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; @@ -106,19 +104,21 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; -import static java.lang.String.format; import static java.util.Arrays.asList; import static java.util.Collections.nCopies; import static java.util.Collections.singletonList; import static java.util.Objects.requireNonNull; +import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.Executors.newScheduledThreadPool; import static java.util.concurrent.TimeUnit.NANOSECONDS; import static java.util.concurrent.TimeUnit.SECONDS; -import static java.util.stream.Collectors.toList; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; +import static org.junit.jupiter.api.parallel.ExecutionMode.CONCURRENT; -@Test(singleThreaded = true) +@TestInstance(PER_CLASS) +@Execution(CONCURRENT) public class TestHashJoinOperator { private static final int PARTITION_COUNT = 4; @@ -126,62 +126,37 @@ public class TestHashJoinOperator private static final PartitioningSpillerFactory PARTITIONING_SPILLER_FACTORY = new GenericPartitioningSpillerFactory(SINGLE_STREAM_SPILLER_FACTORY); private static final TypeOperators TYPE_OPERATORS = new TypeOperators(); - private ExecutorService executor; - private ScheduledExecutorService 
scheduledExecutor; - private NodePartitioningManager nodePartitioningManager; - - @BeforeMethod - public void setUp() - { - // Before/AfterMethod is chosen here because the executor needs to be shutdown - // after every single test case to terminate outstanding threads, if any. - - // The line below is the same as newCachedThreadPool(daemonThreadsNamed(...)) except RejectionExecutionHandler. - // RejectionExecutionHandler is set to DiscardPolicy (instead of the default AbortPolicy) here. - // Otherwise, a large number of RejectedExecutionException will flood logging, resulting in Travis failure. - executor = new ThreadPoolExecutor( - 0, - Integer.MAX_VALUE, - 60L, - SECONDS, - new SynchronousQueue<>(), - daemonThreadsNamed("test-executor-%s"), - new ThreadPoolExecutor.DiscardPolicy()); - scheduledExecutor = newScheduledThreadPool(2, daemonThreadsNamed(getClass().getSimpleName() + "-scheduledExecutor-%s")); - - NodeScheduler nodeScheduler = new NodeScheduler(new UniformNodeSelectorFactory( - new InMemoryNodeManager(), - new NodeSchedulerConfig().setIncludeCoordinator(true), - new NodeTaskMap(new FinalizerService()))); - nodePartitioningManager = new NodePartitioningManager( - nodeScheduler, - TYPE_OPERATORS, - CatalogServiceProvider.fail()); - } - - @AfterMethod(alwaysRun = true) + private final ExecutorService executor = newCachedThreadPool(daemonThreadsNamed("test-executor-%s")); + private final ScheduledExecutorService scheduledExecutor = newScheduledThreadPool(2, daemonThreadsNamed(getClass().getSimpleName() + "-scheduledExecutor-%s")); + private final NodePartitioningManager nodePartitioningManager = new NodePartitioningManager( + new NodeScheduler(new UniformNodeSelectorFactory( + new InMemoryNodeManager(), + new NodeSchedulerConfig().setIncludeCoordinator(true), + new NodeTaskMap(new FinalizerService()))), + TYPE_OPERATORS, + CatalogServiceProvider.fail()); + + @AfterAll public void tearDown() { executor.shutdownNow(); scheduledExecutor.shutdownNow(); } - 
@DataProvider(name = "hashJoinTestValues") - public static Object[][] hashJoinTestValuesProvider() + @Test + public void testInnerJoin() { - return new Object[][] { - {true, true, true}, - {true, true, false}, - {true, false, true}, - {true, false, false}, - {false, true, true}, - {false, true, false}, - {false, false, true}, - {false, false, false}}; + testInnerJoin(true, true, true); + testInnerJoin(true, true, false); + testInnerJoin(true, false, true); + testInnerJoin(true, false, false); + testInnerJoin(false, true, true); + testInnerJoin(false, true, false); + testInnerJoin(false, false, true); + testInnerJoin(false, false, false); } - @Test(dataProvider = "hashJoinTestValues") - public void testInnerJoin(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled) + private void testInnerJoin(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled) { TaskContext taskContext = createTaskContext(); @@ -382,90 +357,76 @@ private enum WhenSpill DURING_BUILD, AFTER_BUILD, DURING_USAGE, NEVER } - private enum WhenSpillFails - { - SPILL_BUILD, SPILL_JOIN, UNSPILL_BUILD, UNSPILL_JOIN - } - - @DataProvider - public Object[][] joinWithSpillValues() - { - return joinWithSpillParameters(true).stream() - .map(List::toArray) - .toArray(Object[][]::new); - } - - @DataProvider - public Object[][] joinWithFailingSpillValues() + @Test + public void testInnerJoinWithSpill() + throws Exception { - List> spillFailValues = Arrays.stream(WhenSpillFails.values()) - .map(ImmutableList::of) - .collect(toList()); - return product(joinWithSpillParameters(false), spillFailValues).stream() - .map(List::toArray) - .toArray(Object[][]::new); + for (boolean probeHashEnabled : ImmutableList.of(false, true)) { + // spill all + innerJoinWithSpill(probeHashEnabled, nCopies(PARTITION_COUNT, WhenSpill.NEVER), SINGLE_STREAM_SPILLER_FACTORY, PARTITIONING_SPILLER_FACTORY); + innerJoinWithSpill(probeHashEnabled, nCopies(PARTITION_COUNT, WhenSpill.DURING_BUILD), 
SINGLE_STREAM_SPILLER_FACTORY, PARTITIONING_SPILLER_FACTORY); + innerJoinWithSpill(probeHashEnabled, nCopies(PARTITION_COUNT, WhenSpill.AFTER_BUILD), SINGLE_STREAM_SPILLER_FACTORY, PARTITIONING_SPILLER_FACTORY); + innerJoinWithSpill(probeHashEnabled, nCopies(PARTITION_COUNT, WhenSpill.DURING_USAGE), SINGLE_STREAM_SPILLER_FACTORY, PARTITIONING_SPILLER_FACTORY); + + // spill one + innerJoinWithSpill(probeHashEnabled, concat(singletonList(WhenSpill.DURING_BUILD), nCopies(PARTITION_COUNT - 1, WhenSpill.NEVER)), SINGLE_STREAM_SPILLER_FACTORY, PARTITIONING_SPILLER_FACTORY); + innerJoinWithSpill(probeHashEnabled, concat(singletonList(WhenSpill.AFTER_BUILD), nCopies(PARTITION_COUNT - 1, WhenSpill.NEVER)), SINGLE_STREAM_SPILLER_FACTORY, PARTITIONING_SPILLER_FACTORY); + innerJoinWithSpill(probeHashEnabled, concat(singletonList(WhenSpill.DURING_USAGE), nCopies(PARTITION_COUNT - 1, WhenSpill.NEVER)), SINGLE_STREAM_SPILLER_FACTORY, PARTITIONING_SPILLER_FACTORY); + + innerJoinWithSpill(probeHashEnabled, concat(asList(WhenSpill.DURING_BUILD, WhenSpill.AFTER_BUILD), nCopies(PARTITION_COUNT - 2, WhenSpill.NEVER)), SINGLE_STREAM_SPILLER_FACTORY, PARTITIONING_SPILLER_FACTORY); + innerJoinWithSpill(probeHashEnabled, concat(asList(WhenSpill.DURING_BUILD, WhenSpill.DURING_USAGE), nCopies(PARTITION_COUNT - 2, WhenSpill.NEVER)), SINGLE_STREAM_SPILLER_FACTORY, PARTITIONING_SPILLER_FACTORY); + } } - private static List> joinWithSpillParameters(boolean allowNoSpill) + @Test + public void testInnerJoinWithFailingSpill() { - List> result = new ArrayList<>(); for (boolean probeHashEnabled : ImmutableList.of(false, true)) { - for (WhenSpill whenSpill : WhenSpill.values()) { - // spill all - if (allowNoSpill || whenSpill != WhenSpill.NEVER) { - result.add(ImmutableList.of(probeHashEnabled, nCopies(PARTITION_COUNT, whenSpill))); - } - - if (whenSpill != WhenSpill.NEVER) { - // spill one - result.add(ImmutableList.of(probeHashEnabled, concat(singletonList(whenSpill), nCopies(PARTITION_COUNT - 1, 
WhenSpill.NEVER)))); - } - } - - result.add(ImmutableList.of(probeHashEnabled, concat(asList(WhenSpill.DURING_BUILD, WhenSpill.AFTER_BUILD), nCopies(PARTITION_COUNT - 2, WhenSpill.NEVER)))); - result.add(ImmutableList.of(probeHashEnabled, concat(asList(WhenSpill.DURING_BUILD, WhenSpill.DURING_USAGE), nCopies(PARTITION_COUNT - 2, WhenSpill.NEVER)))); + // spill all + testInnerJoinWithFailingSpill(probeHashEnabled, nCopies(PARTITION_COUNT, WhenSpill.DURING_USAGE)); + testInnerJoinWithFailingSpill(probeHashEnabled, nCopies(PARTITION_COUNT, WhenSpill.DURING_BUILD)); + testInnerJoinWithFailingSpill(probeHashEnabled, nCopies(PARTITION_COUNT, WhenSpill.AFTER_BUILD)); + + // spill one + testInnerJoinWithFailingSpill(probeHashEnabled, concat(singletonList(WhenSpill.DURING_USAGE), nCopies(PARTITION_COUNT - 1, WhenSpill.NEVER))); + testInnerJoinWithFailingSpill(probeHashEnabled, concat(singletonList(WhenSpill.DURING_BUILD), nCopies(PARTITION_COUNT - 1, WhenSpill.NEVER))); + testInnerJoinWithFailingSpill(probeHashEnabled, concat(singletonList(WhenSpill.AFTER_BUILD), nCopies(PARTITION_COUNT - 1, WhenSpill.NEVER))); + + testInnerJoinWithFailingSpill(probeHashEnabled, concat(asList(WhenSpill.DURING_BUILD, WhenSpill.AFTER_BUILD), nCopies(PARTITION_COUNT - 2, WhenSpill.NEVER))); + testInnerJoinWithFailingSpill(probeHashEnabled, concat(asList(WhenSpill.DURING_BUILD, WhenSpill.DURING_USAGE), nCopies(PARTITION_COUNT - 2, WhenSpill.NEVER))); } - return result; } - @Test(dataProvider = "joinWithSpillValues") - public void testInnerJoinWithSpill(boolean probeHashEnabled, List whenSpill) - throws Exception + private void testInnerJoinWithFailingSpill(boolean probeHashEnabled, List whenSpill) { - innerJoinWithSpill(probeHashEnabled, whenSpill, SINGLE_STREAM_SPILLER_FACTORY, PARTITIONING_SPILLER_FACTORY); - } + assertThatThrownBy(() -> innerJoinWithSpill( + probeHashEnabled, + whenSpill, + new DummySpillerFactory().failSpill(), + new GenericPartitioningSpillerFactory(new 
DummySpillerFactory()))) + .isInstanceOf(RuntimeException.class) + .hasMessage("Spill failed"); - @Test(dataProvider = "joinWithFailingSpillValues") - public void testInnerJoinWithFailingSpill(boolean probeHashEnabled, List whenSpill, WhenSpillFails whenSpillFails) - { - DummySpillerFactory buildSpillerFactory = new DummySpillerFactory(); - DummySpillerFactory joinSpillerFactory = new DummySpillerFactory(); - PartitioningSpillerFactory partitioningSpillerFactory = new GenericPartitioningSpillerFactory(joinSpillerFactory); - - String expectedMessage; - switch (whenSpillFails) { - case SPILL_BUILD: - buildSpillerFactory.failSpill(); - expectedMessage = "Spill failed"; - break; - case SPILL_JOIN: - joinSpillerFactory.failSpill(); - expectedMessage = "Spill failed"; - break; - case UNSPILL_BUILD: - buildSpillerFactory.failUnspill(); - expectedMessage = "Unspill failed"; - break; - case UNSPILL_JOIN: - joinSpillerFactory.failUnspill(); - expectedMessage = "Unspill failed"; - break; - default: - throw new IllegalArgumentException(format("Unsupported option: %s", whenSpillFails)); - } - assertThatThrownBy(() -> innerJoinWithSpill(probeHashEnabled, whenSpill, buildSpillerFactory, partitioningSpillerFactory)) + assertThatThrownBy(() -> innerJoinWithSpill(probeHashEnabled, + whenSpill, + new DummySpillerFactory(), + new GenericPartitioningSpillerFactory(new DummySpillerFactory().failSpill()))) + .isInstanceOf(RuntimeException.class) + .hasMessage("Spill failed"); + + assertThatThrownBy(() -> innerJoinWithSpill(probeHashEnabled, + whenSpill, + new DummySpillerFactory().failUnspill(), + new GenericPartitioningSpillerFactory(new DummySpillerFactory()))) + .isInstanceOf(RuntimeException.class) + .hasMessage("Unspill failed"); + + assertThatThrownBy(() -> innerJoinWithSpill(probeHashEnabled, + whenSpill, + new DummySpillerFactory(), + new GenericPartitioningSpillerFactory(new DummySpillerFactory().failUnspill()))) .isInstanceOf(RuntimeException.class) - 
.hasMessage(expectedMessage); + .hasMessage("Unspill failed"); } private void innerJoinWithSpill(boolean probeHashEnabled, List whenSpill, SingleStreamSpillerFactory buildSpillerFactory, PartitioningSpillerFactory joinSpillerFactory) @@ -644,7 +605,8 @@ private static MaterializedResult getProperColumns(Operator joinOperator, List(), - daemonThreadsNamed("test-executor-%s"), - new ThreadPoolExecutor.DiscardPolicy()); - scheduledExecutor = newScheduledThreadPool(2, daemonThreadsNamed(getClass().getSimpleName() + "-scheduledExecutor-%s")); - - NodeScheduler nodeScheduler = new NodeScheduler(new UniformNodeSelectorFactory( - new InMemoryNodeManager(), - new NodeSchedulerConfig().setIncludeCoordinator(true), - new NodeTaskMap(new FinalizerService()))); - nodePartitioningManager = new NodePartitioningManager( - nodeScheduler, - TYPE_OPERATORS, - CatalogServiceProvider.fail()); - } - - @AfterMethod(alwaysRun = true) + private final ExecutorService executor = newCachedThreadPool(daemonThreadsNamed("test-executor-%s")); + private final ScheduledExecutorService scheduledExecutor = newScheduledThreadPool(2, daemonThreadsNamed(getClass().getSimpleName() + "-scheduledExecutor-%s")); + private final NodePartitioningManager nodePartitioningManager = new NodePartitioningManager( + new NodeScheduler(new UniformNodeSelectorFactory( + new InMemoryNodeManager(), + new NodeSchedulerConfig().setIncludeCoordinator(true), + new NodeTaskMap(new FinalizerService()))), + TYPE_OPERATORS, + CatalogServiceProvider.fail()); + + @AfterAll public void tearDown() { executor.shutdownNow(); scheduledExecutor.shutdownNow(); } - @Test(dataProvider = "hashJoinTestValues") - public void testInnerJoin(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled) + @Test + public void testInnerJoin() + { + testInnerJoin(false, false, false); + testInnerJoin(false, false, true); + testInnerJoin(false, true, false); + testInnerJoin(false, true, true); + testInnerJoin(true, false, false); + 
testInnerJoin(true, false, true); + testInnerJoin(true, true, false); + testInnerJoin(true, true, true); + } + + private void testInnerJoin(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled) { TaskContext taskContext = createTaskContext(); @@ -183,8 +171,20 @@ public void testInnerJoin(boolean parallelBuild, boolean probeHashEnabled, boole assertOperatorEquals(joinOperatorFactory, taskContext.addPipelineContext(0, true, true, false).addDriverContext(), probeInput, expected, true, getHashChannels(probePages, buildPages)); } - @Test(dataProvider = "hashJoinRleProbeTestValues") - public void testInnerJoinWithRunLengthEncodedProbe(boolean withFilter, boolean probeHashEnabled, boolean singleBigintLookupSource) + @Test + public void testInnerJoinWithRunLengthEncodedProbe() + { + testInnerJoinWithRunLengthEncodedProbe(false, false, false); + testInnerJoinWithRunLengthEncodedProbe(false, false, true); + testInnerJoinWithRunLengthEncodedProbe(false, true, false); + testInnerJoinWithRunLengthEncodedProbe(false, true, true); + testInnerJoinWithRunLengthEncodedProbe(true, false, false); + testInnerJoinWithRunLengthEncodedProbe(true, false, true); + testInnerJoinWithRunLengthEncodedProbe(true, true, false); + testInnerJoinWithRunLengthEncodedProbe(true, true, true); + } + + private void testInnerJoinWithRunLengthEncodedProbe(boolean withFilter, boolean probeHashEnabled, boolean singleBigintLookupSource) { TaskContext taskContext = createTaskContext(); @@ -252,14 +252,14 @@ private JoinOperatorInfo getJoinOperatorInfo(DriverContext driverContext) return (JoinOperatorInfo) getOnlyElement(driverContext.getOperatorStats()).getInfo(); } - @DataProvider(name = "hashJoinRleProbeTestValues") - public static Object[][] hashJoinRleProbeTestValuesProvider() + @Test + public void testUnwrapsLazyBlocks() { - return cartesianProduct(trueFalse(), trueFalse(), trueFalse()); + testUnwrapsLazyBlocks(false); + testUnwrapsLazyBlocks(true); } - @Test(dataProvider = 
"singleBigintLookupSourceProvider") - public void testUnwrapsLazyBlocks(boolean singleBigintLookupSource) + private void testUnwrapsLazyBlocks(boolean singleBigintLookupSource) { TaskContext taskContext = createTaskContext(); DriverContext driverContext = taskContext.addPipelineContext(0, true, true, false).addDriverContext(); @@ -304,8 +304,14 @@ public void testUnwrapsLazyBlocks(boolean singleBigintLookupSource) assertThat(output.getBlock(1) instanceof LazyBlock).isFalse(); } - @Test(dataProvider = "singleBigintLookupSourceProvider") - public void testYield(boolean singleBigintLookupSource) + @Test + public void testYield() + { + testYield(false); + testYield(true); + } + + private void testYield(boolean singleBigintLookupSource) { // create a filter function that yields for every probe match // verify we will yield #match times totally @@ -375,8 +381,28 @@ public void testYield(boolean singleBigintLookupSource) assertThat(output.getPositionCount()).isEqualTo(entries); } - @Test(dataProvider = "hashJoinTestValuesAndsingleBigintLookupSourceProvider") - public void testInnerJoinWithNullProbe(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled, boolean singleBigintLookupSource) + @Test + public void testInnerJoinWithNullProbe() + { + testInnerJoinWithNullProbe(false, false, false, false); + testInnerJoinWithNullProbe(false, false, false, true); + testInnerJoinWithNullProbe(false, false, true, false); + testInnerJoinWithNullProbe(false, false, true, true); + testInnerJoinWithNullProbe(false, true, false, false); + testInnerJoinWithNullProbe(false, true, false, true); + testInnerJoinWithNullProbe(false, true, true, false); + testInnerJoinWithNullProbe(false, true, true, true); + testInnerJoinWithNullProbe(true, false, false, false); + testInnerJoinWithNullProbe(true, false, false, true); + testInnerJoinWithNullProbe(true, false, true, false); + testInnerJoinWithNullProbe(true, false, true, true); + testInnerJoinWithNullProbe(true, true, false, 
false); + testInnerJoinWithNullProbe(true, true, false, true); + testInnerJoinWithNullProbe(true, true, true, false); + testInnerJoinWithNullProbe(true, true, true, true); + } + + private void testInnerJoinWithNullProbe(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled, boolean singleBigintLookupSource) { TaskContext taskContext = createTaskContext(); @@ -415,8 +441,28 @@ public void testInnerJoinWithNullProbe(boolean parallelBuild, boolean probeHashE assertOperatorEquals(joinOperatorFactory, taskContext.addPipelineContext(0, true, true, false).addDriverContext(), probeInput, expected, true, getHashChannels(probePages, buildPages)); } - @Test(dataProvider = "hashJoinTestValuesAndsingleBigintLookupSourceProvider") - public void testInnerJoinWithOutputSingleMatch(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled, boolean singleBigintLookupSource) + @Test + public void testInnerJoinWithOutputSingleMatch() + { + testInnerJoinWithOutputSingleMatch(false, false, false, false); + testInnerJoinWithOutputSingleMatch(false, false, false, true); + testInnerJoinWithOutputSingleMatch(false, false, true, false); + testInnerJoinWithOutputSingleMatch(false, false, true, true); + testInnerJoinWithOutputSingleMatch(false, true, false, false); + testInnerJoinWithOutputSingleMatch(false, true, false, true); + testInnerJoinWithOutputSingleMatch(false, true, true, false); + testInnerJoinWithOutputSingleMatch(false, true, true, true); + testInnerJoinWithOutputSingleMatch(true, false, false, false); + testInnerJoinWithOutputSingleMatch(true, false, false, true); + testInnerJoinWithOutputSingleMatch(true, false, true, false); + testInnerJoinWithOutputSingleMatch(true, false, true, true); + testInnerJoinWithOutputSingleMatch(true, true, false, false); + testInnerJoinWithOutputSingleMatch(true, true, false, true); + testInnerJoinWithOutputSingleMatch(true, true, true, false); + testInnerJoinWithOutputSingleMatch(true, true, true, true); + } + + 
private void testInnerJoinWithOutputSingleMatch(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled, boolean singleBigintLookupSource) { TaskContext taskContext = createTaskContext(); // build factory @@ -451,8 +497,20 @@ public void testInnerJoinWithOutputSingleMatch(boolean parallelBuild, boolean pr assertOperatorEquals(joinOperatorFactory, taskContext.addPipelineContext(0, true, true, false).addDriverContext(), probeInput, expected, true, getHashChannels(probePages, buildPages)); } - @Test(dataProvider = "hashJoinTestValuesAndsingleBigintLookupSourceProvider") - public void testInnerJoinWithNullBuild(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled, boolean singleBigintLookupSource) + @Test + public void testInnerJoinWithNullBuild() + { + testInnerJoinWithNullBuild(false, false, false); + testInnerJoinWithNullBuild(false, false, true); + testInnerJoinWithNullBuild(false, true, false); + testInnerJoinWithNullBuild(false, true, true); + testInnerJoinWithNullBuild(true, false, false); + testInnerJoinWithNullBuild(true, false, true); + testInnerJoinWithNullBuild(true, true, false); + testInnerJoinWithNullBuild(true, true, true); + } + + private void testInnerJoinWithNullBuild(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled) { TaskContext taskContext = createTaskContext(); @@ -491,8 +549,20 @@ public void testInnerJoinWithNullBuild(boolean parallelBuild, boolean probeHashE assertOperatorEquals(joinOperatorFactory, taskContext.addPipelineContext(0, true, true, false).addDriverContext(), probeInput, expected, true, getHashChannels(probePages, buildPages)); } - @Test(dataProvider = "hashJoinTestValuesAndsingleBigintLookupSourceProvider") - public void testInnerJoinWithNullOnBothSides(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled, boolean singleBigintLookupSource) + @Test + public void testInnerJoinWithNullOnBothSides() + { + testInnerJoinWithNullOnBothSides(false, false, 
false); + testInnerJoinWithNullOnBothSides(false, false, true); + testInnerJoinWithNullOnBothSides(false, true, false); + testInnerJoinWithNullOnBothSides(false, true, true); + testInnerJoinWithNullOnBothSides(true, false, false); + testInnerJoinWithNullOnBothSides(true, false, true); + testInnerJoinWithNullOnBothSides(true, true, false); + testInnerJoinWithNullOnBothSides(true, true, true); + } + + private void testInnerJoinWithNullOnBothSides(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled) { TaskContext taskContext = createTaskContext(); @@ -532,8 +602,20 @@ public void testInnerJoinWithNullOnBothSides(boolean parallelBuild, boolean prob assertOperatorEquals(joinOperatorFactory, taskContext.addPipelineContext(0, true, true, false).addDriverContext(), probeInput, expected, true, getHashChannels(probePages, buildPages)); } - @Test(dataProvider = "hashJoinTestValues") - public void testProbeOuterJoin(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled) + @Test + public void testProbeOuterJoin() + { + testProbeOuterJoin(false, false, false); + testProbeOuterJoin(false, false, true); + testProbeOuterJoin(false, true, false); + testProbeOuterJoin(false, true, true); + testProbeOuterJoin(true, false, false); + testProbeOuterJoin(true, false, true); + testProbeOuterJoin(true, true, false); + testProbeOuterJoin(true, true, true); + } + + private void testProbeOuterJoin(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled) { TaskContext taskContext = createTaskContext(); @@ -578,8 +660,20 @@ public void testProbeOuterJoin(boolean parallelBuild, boolean probeHashEnabled, assertOperatorEquals(joinOperatorFactory, taskContext.addPipelineContext(0, true, true, false).addDriverContext(), probeInput, expected, true, getHashChannels(probePages, buildPages)); } - @Test(dataProvider = "hashJoinTestValues") - public void testProbeOuterJoinWithFilterFunction(boolean parallelBuild, boolean probeHashEnabled, boolean 
buildHashEnabled) + @Test + public void testProbeOuterJoinWithFilterFunction() + { + testProbeOuterJoinWithFilterFunction(false, false, false); + testProbeOuterJoinWithFilterFunction(false, false, true); + testProbeOuterJoinWithFilterFunction(false, true, false); + testProbeOuterJoinWithFilterFunction(false, true, true); + testProbeOuterJoinWithFilterFunction(true, false, false); + testProbeOuterJoinWithFilterFunction(true, false, true); + testProbeOuterJoinWithFilterFunction(true, true, false); + testProbeOuterJoinWithFilterFunction(true, true, true); + } + + private void testProbeOuterJoinWithFilterFunction(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled) { TaskContext taskContext = createTaskContext(); @@ -627,8 +721,28 @@ public void testProbeOuterJoinWithFilterFunction(boolean parallelBuild, boolean assertOperatorEquals(joinOperatorFactory, taskContext.addPipelineContext(0, true, true, false).addDriverContext(), probeInput, expected, true, getHashChannels(probePages, buildPages)); } - @Test(dataProvider = "hashJoinTestValuesAndsingleBigintLookupSourceProvider") - public void testOuterJoinWithNullProbe(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled, boolean singleBigintLookupSource) + @Test + public void testOuterJoinWithNullProbe() + { + testOuterJoinWithNullProbe(false, false, false, false); + testOuterJoinWithNullProbe(false, false, false, true); + testOuterJoinWithNullProbe(false, false, true, false); + testOuterJoinWithNullProbe(false, false, true, true); + testOuterJoinWithNullProbe(false, true, false, false); + testOuterJoinWithNullProbe(false, true, false, true); + testOuterJoinWithNullProbe(false, true, true, false); + testOuterJoinWithNullProbe(false, true, true, true); + testOuterJoinWithNullProbe(true, false, false, false); + testOuterJoinWithNullProbe(true, false, false, true); + testOuterJoinWithNullProbe(true, false, true, false); + testOuterJoinWithNullProbe(true, false, true, true); + 
testOuterJoinWithNullProbe(true, true, false, false); + testOuterJoinWithNullProbe(true, true, false, true); + testOuterJoinWithNullProbe(true, true, true, false); + testOuterJoinWithNullProbe(true, true, true, true); + } + + private void testOuterJoinWithNullProbe(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled, boolean singleBigintLookupSource) { TaskContext taskContext = createTaskContext(); @@ -669,8 +783,28 @@ public void testOuterJoinWithNullProbe(boolean parallelBuild, boolean probeHashE assertOperatorEquals(joinOperatorFactory, taskContext.addPipelineContext(0, true, true, false).addDriverContext(), probeInput, expected, true, getHashChannels(probePages, buildPages)); } - @Test(dataProvider = "hashJoinTestValuesAndsingleBigintLookupSourceProvider") - public void testOuterJoinWithNullProbeAndFilterFunction(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled, boolean singleBigintLookupSource) + @Test + public void testOuterJoinWithNullProbeAndFilterFunction() + { + testOuterJoinWithNullProbeAndFilterFunction(false, false, false, false); + testOuterJoinWithNullProbeAndFilterFunction(false, false, false, true); + testOuterJoinWithNullProbeAndFilterFunction(false, false, true, false); + testOuterJoinWithNullProbeAndFilterFunction(false, false, true, true); + testOuterJoinWithNullProbeAndFilterFunction(false, true, false, false); + testOuterJoinWithNullProbeAndFilterFunction(false, true, false, true); + testOuterJoinWithNullProbeAndFilterFunction(false, true, true, false); + testOuterJoinWithNullProbeAndFilterFunction(false, true, true, true); + testOuterJoinWithNullProbeAndFilterFunction(true, false, false, false); + testOuterJoinWithNullProbeAndFilterFunction(true, false, false, true); + testOuterJoinWithNullProbeAndFilterFunction(true, false, true, false); + testOuterJoinWithNullProbeAndFilterFunction(true, false, true, true); + testOuterJoinWithNullProbeAndFilterFunction(true, true, false, false); + 
testOuterJoinWithNullProbeAndFilterFunction(true, true, false, true); + testOuterJoinWithNullProbeAndFilterFunction(true, true, true, false); + testOuterJoinWithNullProbeAndFilterFunction(true, true, true, true); + } + + private void testOuterJoinWithNullProbeAndFilterFunction(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled, boolean singleBigintLookupSource) { TaskContext taskContext = createTaskContext(); @@ -714,8 +848,28 @@ public void testOuterJoinWithNullProbeAndFilterFunction(boolean parallelBuild, b assertOperatorEquals(joinOperatorFactory, taskContext.addPipelineContext(0, true, true, false).addDriverContext(), probeInput, expected, true, getHashChannels(probePages, buildPages)); } - @Test(dataProvider = "hashJoinTestValuesAndsingleBigintLookupSourceProvider") - public void testOuterJoinWithNullBuild(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled, boolean singleBigintLookupSource) + @Test + public void testOuterJoinWithNullBuild() + { + testOuterJoinWithNullBuild(false, false, false, false); + testOuterJoinWithNullBuild(false, false, false, true); + testOuterJoinWithNullBuild(false, false, true, false); + testOuterJoinWithNullBuild(false, false, true, true); + testOuterJoinWithNullBuild(false, true, false, false); + testOuterJoinWithNullBuild(false, true, false, true); + testOuterJoinWithNullBuild(false, true, true, false); + testOuterJoinWithNullBuild(false, true, true, true); + testOuterJoinWithNullBuild(true, false, false, false); + testOuterJoinWithNullBuild(true, false, false, true); + testOuterJoinWithNullBuild(true, false, true, false); + testOuterJoinWithNullBuild(true, false, true, true); + testOuterJoinWithNullBuild(true, true, false, false); + testOuterJoinWithNullBuild(true, true, false, true); + testOuterJoinWithNullBuild(true, true, true, false); + testOuterJoinWithNullBuild(true, true, true, true); + } + + private void testOuterJoinWithNullBuild(boolean parallelBuild, boolean 
probeHashEnabled, boolean buildHashEnabled, boolean singleBigintLookupSource) { TaskContext taskContext = createTaskContext(); @@ -755,8 +909,28 @@ public void testOuterJoinWithNullBuild(boolean parallelBuild, boolean probeHashE assertOperatorEquals(joinOperatorFactory, taskContext.addPipelineContext(0, true, true, false).addDriverContext(), probeInput, expected, true, getHashChannels(probePages, buildPages)); } - @Test(dataProvider = "hashJoinTestValuesAndsingleBigintLookupSourceProvider") - public void testOuterJoinWithNullBuildAndFilterFunction(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled, boolean singleBigintLookupSource) + @Test + public void testOuterJoinWithNullBuildAndFilterFunction() + { + testOuterJoinWithNullBuildAndFilterFunction(false, false, false, false); + testOuterJoinWithNullBuildAndFilterFunction(false, false, false, true); + testOuterJoinWithNullBuildAndFilterFunction(false, false, true, false); + testOuterJoinWithNullBuildAndFilterFunction(false, false, true, true); + testOuterJoinWithNullBuildAndFilterFunction(false, true, false, false); + testOuterJoinWithNullBuildAndFilterFunction(false, true, false, true); + testOuterJoinWithNullBuildAndFilterFunction(false, true, true, false); + testOuterJoinWithNullBuildAndFilterFunction(false, true, true, true); + testOuterJoinWithNullBuildAndFilterFunction(true, false, false, false); + testOuterJoinWithNullBuildAndFilterFunction(true, false, false, true); + testOuterJoinWithNullBuildAndFilterFunction(true, false, true, false); + testOuterJoinWithNullBuildAndFilterFunction(true, false, true, true); + testOuterJoinWithNullBuildAndFilterFunction(true, true, false, false); + testOuterJoinWithNullBuildAndFilterFunction(true, true, false, true); + testOuterJoinWithNullBuildAndFilterFunction(true, true, true, false); + testOuterJoinWithNullBuildAndFilterFunction(true, true, true, true); + } + + private void testOuterJoinWithNullBuildAndFilterFunction(boolean parallelBuild, boolean 
probeHashEnabled, boolean buildHashEnabled, boolean singleBigintLookupSource) { TaskContext taskContext = createTaskContext(); @@ -800,8 +974,28 @@ public void testOuterJoinWithNullBuildAndFilterFunction(boolean parallelBuild, b assertOperatorEquals(joinOperatorFactory, taskContext.addPipelineContext(0, true, true, false).addDriverContext(), probeInput, expected, true, getHashChannels(probePages, buildPages)); } - @Test(dataProvider = "hashJoinTestValuesAndsingleBigintLookupSourceProvider") - public void testOuterJoinWithNullOnBothSides(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled, boolean singleBigintLookupSource) + @Test + public void testOuterJoinWithNullOnBothSides() + { + testOuterJoinWithNullOnBothSides(false, false, false, false); + testOuterJoinWithNullOnBothSides(false, false, false, true); + testOuterJoinWithNullOnBothSides(false, false, true, false); + testOuterJoinWithNullOnBothSides(false, false, true, true); + testOuterJoinWithNullOnBothSides(false, true, false, false); + testOuterJoinWithNullOnBothSides(false, true, false, true); + testOuterJoinWithNullOnBothSides(false, true, true, false); + testOuterJoinWithNullOnBothSides(false, true, true, true); + testOuterJoinWithNullOnBothSides(true, false, false, false); + testOuterJoinWithNullOnBothSides(true, false, false, true); + testOuterJoinWithNullOnBothSides(true, false, true, false); + testOuterJoinWithNullOnBothSides(true, false, true, true); + testOuterJoinWithNullOnBothSides(true, true, false, false); + testOuterJoinWithNullOnBothSides(true, true, false, true); + testOuterJoinWithNullOnBothSides(true, true, true, false); + testOuterJoinWithNullOnBothSides(true, true, true, true); + } + + private void testOuterJoinWithNullOnBothSides(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled, boolean singleBigintLookupSource) { TaskContext taskContext = createTaskContext(); @@ -842,8 +1036,28 @@ public void testOuterJoinWithNullOnBothSides(boolean 
parallelBuild, boolean prob assertOperatorEquals(joinOperatorFactory, taskContext.addPipelineContext(0, true, true, false).addDriverContext(), probeInput, expected, true, getHashChannels(probePages, buildPages)); } - @Test(dataProvider = "hashJoinTestValuesAndsingleBigintLookupSourceProvider") - public void testOuterJoinWithNullOnBothSidesAndFilterFunction(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled, boolean singleBigintLookupSource) + @Test + public void testOuterJoinWithNullOnBothSidesAndFilterFunction() + { + testOuterJoinWithNullOnBothSidesAndFilterFunction(false, false, false, false); + testOuterJoinWithNullOnBothSidesAndFilterFunction(false, false, false, true); + testOuterJoinWithNullOnBothSidesAndFilterFunction(false, false, true, false); + testOuterJoinWithNullOnBothSidesAndFilterFunction(false, false, true, true); + testOuterJoinWithNullOnBothSidesAndFilterFunction(false, true, false, false); + testOuterJoinWithNullOnBothSidesAndFilterFunction(false, true, false, true); + testOuterJoinWithNullOnBothSidesAndFilterFunction(false, true, true, false); + testOuterJoinWithNullOnBothSidesAndFilterFunction(false, true, true, true); + testOuterJoinWithNullOnBothSidesAndFilterFunction(true, false, false, false); + testOuterJoinWithNullOnBothSidesAndFilterFunction(true, false, false, true); + testOuterJoinWithNullOnBothSidesAndFilterFunction(true, false, true, false); + testOuterJoinWithNullOnBothSidesAndFilterFunction(true, false, true, true); + testOuterJoinWithNullOnBothSidesAndFilterFunction(true, true, false, false); + testOuterJoinWithNullOnBothSidesAndFilterFunction(true, true, false, true); + testOuterJoinWithNullOnBothSidesAndFilterFunction(true, true, true, false); + testOuterJoinWithNullOnBothSidesAndFilterFunction(true, true, true, true); + } + + private void testOuterJoinWithNullOnBothSidesAndFilterFunction(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled, boolean singleBigintLookupSource) { 
TaskContext taskContext = createTaskContext(); @@ -888,8 +1102,16 @@ public void testOuterJoinWithNullOnBothSidesAndFilterFunction(boolean parallelBu assertOperatorEquals(joinOperatorFactory, taskContext.addPipelineContext(0, true, true, false).addDriverContext(), probeInput, expected, true, getHashChannels(probePages, buildPages)); } - @Test(dataProvider = "testMemoryLimitProvider") - public void testMemoryLimit(boolean parallelBuild, boolean buildHashEnabled) + @Test + public void testMemoryLimit() + { + testMemoryLimit(false, false); + testMemoryLimit(false, true); + testMemoryLimit(true, false); + testMemoryLimit(true, true); + } + + private void testMemoryLimit(boolean parallelBuild, boolean buildHashEnabled) { TaskContext taskContext = TestingTaskContext.createTaskContext(executor, scheduledExecutor, TEST_SESSION, DataSize.ofBytes(100)); @@ -903,8 +1125,28 @@ public void testMemoryLimit(boolean parallelBuild, boolean buildHashEnabled) .hasMessageMatching("Query exceeded per-node memory limit of.*"); } - @Test(dataProvider = "hashJoinTestValuesAndsingleBigintLookupSourceProvider") - public void testInnerJoinWithEmptyLookupSource(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled, boolean singleBigintLookupSource) + @Test + public void testInnerJoinWithEmptyLookupSource() + { + testInnerJoinWithEmptyLookupSource(false, false, false, false); + testInnerJoinWithEmptyLookupSource(false, false, false, true); + testInnerJoinWithEmptyLookupSource(false, false, true, false); + testInnerJoinWithEmptyLookupSource(false, false, true, true); + testInnerJoinWithEmptyLookupSource(false, true, false, false); + testInnerJoinWithEmptyLookupSource(false, true, false, true); + testInnerJoinWithEmptyLookupSource(false, true, true, false); + testInnerJoinWithEmptyLookupSource(false, true, true, true); + testInnerJoinWithEmptyLookupSource(true, false, false, false); + testInnerJoinWithEmptyLookupSource(true, false, false, true); + 
testInnerJoinWithEmptyLookupSource(true, false, true, false); + testInnerJoinWithEmptyLookupSource(true, false, true, true); + testInnerJoinWithEmptyLookupSource(true, true, false, false); + testInnerJoinWithEmptyLookupSource(true, true, false, true); + testInnerJoinWithEmptyLookupSource(true, true, true, false); + testInnerJoinWithEmptyLookupSource(true, true, true, true); + } + + private void testInnerJoinWithEmptyLookupSource(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled, boolean singleBigintLookupSource) { TaskContext taskContext = createTaskContext(); @@ -940,8 +1182,28 @@ public void testInnerJoinWithEmptyLookupSource(boolean parallelBuild, boolean pr assertThat(outputPage).isNull(); } - @Test(dataProvider = "hashJoinTestValuesAndsingleBigintLookupSourceProvider") - public void testLookupOuterJoinWithEmptyLookupSource(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled, boolean singleBigintLookupSource) + @Test + public void testLookupOuterJoinWithEmptyLookupSource() + { + testLookupOuterJoinWithEmptyLookupSource(false, false, false, false); + testLookupOuterJoinWithEmptyLookupSource(false, false, false, true); + testLookupOuterJoinWithEmptyLookupSource(false, false, true, false); + testLookupOuterJoinWithEmptyLookupSource(false, false, true, true); + testLookupOuterJoinWithEmptyLookupSource(false, true, false, false); + testLookupOuterJoinWithEmptyLookupSource(false, true, false, true); + testLookupOuterJoinWithEmptyLookupSource(false, true, true, false); + testLookupOuterJoinWithEmptyLookupSource(false, true, true, true); + testLookupOuterJoinWithEmptyLookupSource(true, false, false, false); + testLookupOuterJoinWithEmptyLookupSource(true, false, false, true); + testLookupOuterJoinWithEmptyLookupSource(true, false, true, false); + testLookupOuterJoinWithEmptyLookupSource(true, false, true, true); + testLookupOuterJoinWithEmptyLookupSource(true, true, false, false); + 
testLookupOuterJoinWithEmptyLookupSource(true, true, false, true); + testLookupOuterJoinWithEmptyLookupSource(true, true, true, false); + testLookupOuterJoinWithEmptyLookupSource(true, true, true, true); + } + + private void testLookupOuterJoinWithEmptyLookupSource(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled, boolean singleBigintLookupSource) { TaskContext taskContext = createTaskContext(); @@ -977,8 +1239,28 @@ public void testLookupOuterJoinWithEmptyLookupSource(boolean parallelBuild, bool assertThat(outputPage).isNull(); } - @Test(dataProvider = "hashJoinTestValuesAndsingleBigintLookupSourceProvider") - public void testProbeOuterJoinWithEmptyLookupSource(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled, boolean singleBigintLookupSource) + @Test + public void testProbeOuterJoinWithEmptyLookupSource() + { + testProbeOuterJoinWithEmptyLookupSource(false, false, false, false); + testProbeOuterJoinWithEmptyLookupSource(false, false, false, true); + testProbeOuterJoinWithEmptyLookupSource(false, false, true, false); + testProbeOuterJoinWithEmptyLookupSource(false, false, true, true); + testProbeOuterJoinWithEmptyLookupSource(false, true, false, false); + testProbeOuterJoinWithEmptyLookupSource(false, true, false, true); + testProbeOuterJoinWithEmptyLookupSource(false, true, true, false); + testProbeOuterJoinWithEmptyLookupSource(false, true, true, true); + testProbeOuterJoinWithEmptyLookupSource(true, false, false, false); + testProbeOuterJoinWithEmptyLookupSource(true, false, false, true); + testProbeOuterJoinWithEmptyLookupSource(true, false, true, false); + testProbeOuterJoinWithEmptyLookupSource(true, false, true, true); + testProbeOuterJoinWithEmptyLookupSource(true, true, false, false); + testProbeOuterJoinWithEmptyLookupSource(true, true, false, true); + testProbeOuterJoinWithEmptyLookupSource(true, true, true, false); + testProbeOuterJoinWithEmptyLookupSource(true, true, true, true); + } + + private void 
testProbeOuterJoinWithEmptyLookupSource(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled, boolean singleBigintLookupSource) { TaskContext taskContext = createTaskContext(); @@ -1023,8 +1305,28 @@ public void testProbeOuterJoinWithEmptyLookupSource(boolean parallelBuild, boole assertOperatorEquals(joinOperatorFactory, taskContext.addPipelineContext(0, true, true, false).addDriverContext(), probeInput, expected, true, getHashChannels(probePages, buildPages)); } - @Test(dataProvider = "hashJoinTestValuesAndsingleBigintLookupSourceProvider") - public void testFullOuterJoinWithEmptyLookupSource(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled, boolean singleBigintLookupSource) + @Test + public void testFullOuterJoinWithEmptyLookupSource() + { + testFullOuterJoinWithEmptyLookupSource(false, false, false, false); + testFullOuterJoinWithEmptyLookupSource(false, false, false, true); + testFullOuterJoinWithEmptyLookupSource(false, false, true, false); + testFullOuterJoinWithEmptyLookupSource(false, false, true, true); + testFullOuterJoinWithEmptyLookupSource(false, true, false, false); + testFullOuterJoinWithEmptyLookupSource(false, true, false, true); + testFullOuterJoinWithEmptyLookupSource(false, true, true, false); + testFullOuterJoinWithEmptyLookupSource(false, true, true, true); + testFullOuterJoinWithEmptyLookupSource(true, false, false, false); + testFullOuterJoinWithEmptyLookupSource(true, false, false, true); + testFullOuterJoinWithEmptyLookupSource(true, false, true, false); + testFullOuterJoinWithEmptyLookupSource(true, false, true, true); + testFullOuterJoinWithEmptyLookupSource(true, true, false, false); + testFullOuterJoinWithEmptyLookupSource(true, true, false, true); + testFullOuterJoinWithEmptyLookupSource(true, true, true, false); + testFullOuterJoinWithEmptyLookupSource(true, true, true, true); + } + + private void testFullOuterJoinWithEmptyLookupSource(boolean parallelBuild, boolean probeHashEnabled, 
boolean buildHashEnabled, boolean singleBigintLookupSource) { TaskContext taskContext = createTaskContext(); @@ -1069,8 +1371,28 @@ public void testFullOuterJoinWithEmptyLookupSource(boolean parallelBuild, boolea assertOperatorEquals(joinOperatorFactory, taskContext.addPipelineContext(0, true, true, false).addDriverContext(), probeInput, expected, true, getHashChannels(probePages, buildPages)); } - @Test(dataProvider = "hashJoinTestValuesAndsingleBigintLookupSourceProvider") - public void testInnerJoinWithNonEmptyLookupSourceAndEmptyProbe(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled, boolean singleBigintLookupSource) + @Test + public void testInnerJoinWithNonEmptyLookupSourceAndEmptyProbe() + { + testInnerJoinWithNonEmptyLookupSourceAndEmptyProbe(false, false, false, false); + testInnerJoinWithNonEmptyLookupSourceAndEmptyProbe(false, false, false, true); + testInnerJoinWithNonEmptyLookupSourceAndEmptyProbe(false, false, true, false); + testInnerJoinWithNonEmptyLookupSourceAndEmptyProbe(false, false, true, true); + testInnerJoinWithNonEmptyLookupSourceAndEmptyProbe(false, true, false, false); + testInnerJoinWithNonEmptyLookupSourceAndEmptyProbe(false, true, false, true); + testInnerJoinWithNonEmptyLookupSourceAndEmptyProbe(false, true, true, false); + testInnerJoinWithNonEmptyLookupSourceAndEmptyProbe(false, true, true, true); + testInnerJoinWithNonEmptyLookupSourceAndEmptyProbe(true, false, false, false); + testInnerJoinWithNonEmptyLookupSourceAndEmptyProbe(true, false, false, true); + testInnerJoinWithNonEmptyLookupSourceAndEmptyProbe(true, false, true, false); + testInnerJoinWithNonEmptyLookupSourceAndEmptyProbe(true, false, true, true); + testInnerJoinWithNonEmptyLookupSourceAndEmptyProbe(true, true, false, false); + testInnerJoinWithNonEmptyLookupSourceAndEmptyProbe(true, true, false, true); + testInnerJoinWithNonEmptyLookupSourceAndEmptyProbe(true, true, true, false); + testInnerJoinWithNonEmptyLookupSourceAndEmptyProbe(true, true, 
true, true); + } + + private void testInnerJoinWithNonEmptyLookupSourceAndEmptyProbe(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled, boolean singleBigintLookupSource) { TaskContext taskContext = createTaskContext(); @@ -1109,8 +1431,21 @@ public void testInnerJoinWithNonEmptyLookupSourceAndEmptyProbe(boolean parallelB assertOperatorEquals(joinOperatorFactory, taskContext.addPipelineContext(0, true, true, false).addDriverContext(), probeInput, expected, true, getHashChannels(probePages, buildPages)); } - @Test(dataProvider = "hashJoinTestValues") - public void testInnerJoinWithBlockingLookupSourceAndEmptyProbe(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled) + @Test + public void testInnerJoinWithBlockingLookupSourceAndEmptyProbe() + throws Exception + { + testInnerJoinWithBlockingLookupSourceAndEmptyProbe(false, false, false); + testInnerJoinWithBlockingLookupSourceAndEmptyProbe(false, false, true); + testInnerJoinWithBlockingLookupSourceAndEmptyProbe(false, true, false); + testInnerJoinWithBlockingLookupSourceAndEmptyProbe(false, true, true); + testInnerJoinWithBlockingLookupSourceAndEmptyProbe(true, false, false); + testInnerJoinWithBlockingLookupSourceAndEmptyProbe(true, false, true); + testInnerJoinWithBlockingLookupSourceAndEmptyProbe(true, true, false); + testInnerJoinWithBlockingLookupSourceAndEmptyProbe(true, true, true); + } + + private void testInnerJoinWithBlockingLookupSourceAndEmptyProbe(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled) throws Exception { // join that waits for build side to be collected @@ -1145,8 +1480,21 @@ public void testInnerJoinWithBlockingLookupSourceAndEmptyProbe(boolean parallelB } } - @Test(dataProvider = "hashJoinTestValues") - public void testInnerJoinWithBlockingLookupSource(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled) + @Test + public void testInnerJoinWithBlockingLookupSource() + throws Exception + { + 
testInnerJoinWithBlockingLookupSource(false, false, false); + testInnerJoinWithBlockingLookupSource(false, false, true); + testInnerJoinWithBlockingLookupSource(false, true, false); + testInnerJoinWithBlockingLookupSource(false, true, true); + testInnerJoinWithBlockingLookupSource(true, false, false); + testInnerJoinWithBlockingLookupSource(true, false, true); + testInnerJoinWithBlockingLookupSource(true, true, false); + testInnerJoinWithBlockingLookupSource(true, true, true); + } + + private void testInnerJoinWithBlockingLookupSource(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled) throws Exception { RowPagesBuilder probePages = rowPagesBuilder(probeHashEnabled, Ints.asList(0), ImmutableList.of(VARCHAR)); @@ -1296,39 +1644,6 @@ private OperatorFactory createJoinOperatorFactoryWithBlockingLookupSource(TaskCo return joinOperatorFactory; } - @DataProvider(name = "hashJoinTestValues") - public static Object[][] hashJoinTestValuesProvider() - { - return DataProviders.cartesianProduct( - new Object[][] {{true}, {false}}, - new Object[][] {{true}, {false}}, - new Object[][] {{true}, {false}}); - } - - @DataProvider - public static Object[][] testMemoryLimitProvider() - { - return DataProviders.cartesianProduct( - new Object[][] {{true}, {false}}, - new Object[][] {{true}, {false}}); - } - - @DataProvider(name = "singleBigintLookupSourceProvider") - public static Object[][] singleBigintLookupSourceProvider() - { - return new Object[][] {{true}, {false}}; - } - - @DataProvider(name = "hashJoinTestValuesAndsingleBigintLookupSourceProvider") - public static Object[][] hashJoinTestValuesAndsingleBigintLookupSourceProvider() - { - return DataProviders.cartesianProduct( - new Object[][] {{true}, {false}}, - new Object[][] {{true}, {false}}, - new Object[][] {{true}, {false}}, - new Object[][] {{true}, {false}}); - } - private TaskContext createTaskContext() { return TestingTaskContext.createTaskContext(executor, scheduledExecutor, TEST_SESSION); diff 
--git a/core/trino-main/src/test/java/io/trino/operator/output/TestPagePartitioner.java b/core/trino-main/src/test/java/io/trino/operator/output/TestPagePartitioner.java index 444e02965633..c16c6c200c7b 100644 --- a/core/trino-main/src/test/java/io/trino/operator/output/TestPagePartitioner.java +++ b/core/trino-main/src/test/java/io/trino/operator/output/TestPagePartitioner.java @@ -42,7 +42,6 @@ import io.trino.spi.block.RunLengthEncodedBlock; import io.trino.spi.block.TestingBlockEncodingSerde; import io.trino.spi.predicate.NullableValue; -import io.trino.spi.type.AbstractType; import io.trino.spi.type.ArrayType; import io.trino.spi.type.Decimals; import io.trino.spi.type.TimestampType; @@ -50,11 +49,10 @@ import io.trino.sql.planner.plan.PlanNodeId; import io.trino.testing.TestingTaskContext; import io.trino.type.BlockTypeOperators; -import org.testng.annotations.AfterClass; -import org.testng.annotations.BeforeClass; -import org.testng.annotations.BeforeMethod; -import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.parallel.Execution; import java.util.ArrayList; import java.util.Collection; @@ -98,8 +96,11 @@ import static java.util.concurrent.Executors.newScheduledThreadPool; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; +import static org.junit.jupiter.api.parallel.ExecutionMode.SAME_THREAD; -@Test(singleThreaded = true) +@TestInstance(PER_CLASS) +@Execution(SAME_THREAD) public class TestPagePartitioner { private static final DataSize MAX_MEMORY = DataSize.of(50, MEGABYTE); @@ -111,36 +112,21 @@ public class TestPagePartitioner private static final PagesSerdeFactory PAGES_SERDE_FACTORY = new PagesSerdeFactory(new 
TestingBlockEncodingSerde(), false); private static final PageDeserializer PAGE_DESERIALIZER = PAGES_SERDE_FACTORY.createDeserializer(Optional.empty()); - private ExecutorService executor; - private ScheduledExecutorService scheduledExecutor; - private TestOutputBuffer outputBuffer; + private final ExecutorService executor = newCachedThreadPool(daemonThreadsNamed(getClass().getSimpleName() + "-executor-%s")); + private final ScheduledExecutorService scheduledExecutor = newScheduledThreadPool(1, daemonThreadsNamed(getClass().getSimpleName() + "-scheduledExecutor-%s")); - @BeforeClass - public void setUpClass() - { - executor = newCachedThreadPool(daemonThreadsNamed(getClass().getSimpleName() + "-executor-%s")); - scheduledExecutor = newScheduledThreadPool(1, daemonThreadsNamed(getClass().getSimpleName() + "-scheduledExecutor-%s")); - } - - @AfterClass(alwaysRun = true) + @AfterAll public void tearDownClass() { executor.shutdownNow(); - executor = null; scheduledExecutor.shutdownNow(); - scheduledExecutor = null; - } - - @BeforeMethod - public void setUp() - { - outputBuffer = new TestOutputBuffer(); } @Test public void testOutputForEmptyPage() { - PagePartitioner pagePartitioner = pagePartitioner(BIGINT).build(); + TestOutputBuffer outputBuffer = new TestOutputBuffer(); + PagePartitioner pagePartitioner = pagePartitioner(outputBuffer, BIGINT).build(); Page page = new Page(createLongsBlock(ImmutableList.of())); pagePartitioner.partitionPage(page, operatorContext()); @@ -156,10 +142,18 @@ private OperatorContext operatorContext() .addOperatorContext(0, new PlanNodeId("plan-node-0"), PartitionedOutputOperator.class.getSimpleName()); } - @Test(dataProvider = "partitioningMode") - public void testOutputEqualsInput(PartitioningMode partitioningMode) + @Test + public void testOutputEqualsInput() { - PagePartitioner pagePartitioner = pagePartitioner(BIGINT).build(); + testOutputEqualsInput(PartitioningMode.ROW_WISE); + testOutputEqualsInput(PartitioningMode.COLUMNAR); + } + 
+ private void testOutputEqualsInput(PartitioningMode partitioningMode) + { + TestOutputBuffer outputBuffer = new TestOutputBuffer(); + + PagePartitioner pagePartitioner = pagePartitioner(outputBuffer, BIGINT).build(); Page page = new Page(createLongSequenceBlock(0, POSITIONS_PER_PAGE)); List expected = readLongs(Stream.of(page), 0); @@ -169,10 +163,18 @@ public void testOutputEqualsInput(PartitioningMode partitioningMode) assertThat(partitioned).containsExactlyInAnyOrderElementsOf(expected); // order is different due to 2 partitions joined } - @Test(dataProvider = "partitioningMode") - public void testOutputForPageWithNoBlockPartitionFunction(PartitioningMode partitioningMode) + @Test + public void testOutputForPageWithNoBlockPartitionFunction() { - PagePartitioner pagePartitioner = pagePartitioner(BIGINT) + testOutputForPageWithNoBlockPartitionFunction(PartitioningMode.ROW_WISE); + testOutputForPageWithNoBlockPartitionFunction(PartitioningMode.COLUMNAR); + } + + private void testOutputForPageWithNoBlockPartitionFunction(PartitioningMode partitioningMode) + { + TestOutputBuffer outputBuffer = new TestOutputBuffer(); + + PagePartitioner pagePartitioner = pagePartitioner(outputBuffer, BIGINT) .withPartitionFunction(new BucketPartitionFunction( ROUND_ROBIN.createBucketFunction(null, false, PARTITION_COUNT, null), IntStream.range(0, PARTITION_COUNT).toArray())) @@ -188,10 +190,18 @@ public void testOutputForPageWithNoBlockPartitionFunction(PartitioningMode parti assertThat(partition1).containsExactly(1L, 3L, 5L, 7L); } - @Test(dataProvider = "partitioningMode") - public void testOutputForMultipleSimplePages(PartitioningMode partitioningMode) + @Test + public void testOutputForMultipleSimplePages() { - PagePartitioner pagePartitioner = pagePartitioner(BIGINT).build(); + testOutputForMultipleSimplePages(PartitioningMode.ROW_WISE); + testOutputForMultipleSimplePages(PartitioningMode.COLUMNAR); + } + + private void testOutputForMultipleSimplePages(PartitioningMode 
partitioningMode) + { + TestOutputBuffer outputBuffer = new TestOutputBuffer(); + + PagePartitioner pagePartitioner = pagePartitioner(outputBuffer, BIGINT).build(); Page page1 = new Page(createLongSequenceBlock(0, POSITIONS_PER_PAGE)); Page page2 = new Page(createLongSequenceBlock(1, POSITIONS_PER_PAGE)); Page page3 = new Page(createLongSequenceBlock(2, POSITIONS_PER_PAGE)); @@ -203,10 +213,17 @@ public void testOutputForMultipleSimplePages(PartitioningMode partitioningMode) assertThat(partitioned).containsExactlyInAnyOrderElementsOf(expected); // order is different due to 2 partitions joined } - @Test(dataProvider = "partitioningMode") - public void testOutputForSimplePageWithReplication(PartitioningMode partitioningMode) + @Test + public void testOutputForSimplePageWithReplication() + { + testOutputForSimplePageWithReplication(PartitioningMode.ROW_WISE); + testOutputForSimplePageWithReplication(PartitioningMode.COLUMNAR); + } + + private void testOutputForSimplePageWithReplication(PartitioningMode partitioningMode) { - PagePartitioner pagePartitioner = pagePartitioner(BIGINT).replicate().build(); + TestOutputBuffer outputBuffer = new TestOutputBuffer(); + PagePartitioner pagePartitioner = pagePartitioner(outputBuffer, BIGINT).replicate().build(); Page page = new Page(createLongsBlock(0L, 1L, 2L, 3L, null)); processPages(pagePartitioner, partitioningMode, page); @@ -217,10 +234,17 @@ public void testOutputForSimplePageWithReplication(PartitioningMode partitioning assertThat(partition1).containsExactly(0L, 1L, 3L); // position 0 copied to all partitions } - @Test(dataProvider = "partitioningMode") - public void testOutputForSimplePageWithNullChannel(PartitioningMode partitioningMode) + @Test + public void testOutputForSimplePageWithNullChannel() { - PagePartitioner pagePartitioner = pagePartitioner(BIGINT).withNullChannel(0).build(); + testOutputForSimplePageWithNullChannel(PartitioningMode.ROW_WISE); + 
testOutputForSimplePageWithNullChannel(PartitioningMode.COLUMNAR); + } + + private void testOutputForSimplePageWithNullChannel(PartitioningMode partitioningMode) + { + TestOutputBuffer outputBuffer = new TestOutputBuffer(); + PagePartitioner pagePartitioner = pagePartitioner(outputBuffer, BIGINT).withNullChannel(0).build(); Page page = new Page(createLongsBlock(0L, 1L, 2L, 3L, null)); processPages(pagePartitioner, partitioningMode, page); @@ -231,10 +255,17 @@ public void testOutputForSimplePageWithNullChannel(PartitioningMode partitioning assertThat(partition1).containsExactlyInAnyOrder(1L, 3L, null); // null copied to all partitions } - @Test(dataProvider = "partitioningMode") - public void testOutputForSimplePageWithPartitionConstant(PartitioningMode partitioningMode) + @Test + public void testOutputForSimplePageWithPartitionConstant() { - PagePartitioner pagePartitioner = pagePartitioner(BIGINT) + testOutputForSimplePageWithPartitionConstant(PartitioningMode.ROW_WISE); + testOutputForSimplePageWithPartitionConstant(PartitioningMode.COLUMNAR); + } + + private void testOutputForSimplePageWithPartitionConstant(PartitioningMode partitioningMode) + { + TestOutputBuffer outputBuffer = new TestOutputBuffer(); + PagePartitioner pagePartitioner = pagePartitioner(outputBuffer, BIGINT) .withPartitionConstants(ImmutableList.of(Optional.of(new NullableValue(BIGINT, 1L)))) .withPartitionChannels(-1) .build(); @@ -249,10 +280,17 @@ public void testOutputForSimplePageWithPartitionConstant(PartitioningMode partit assertThat(partition1).containsExactlyElementsOf(allValues); } - @Test(dataProvider = "partitioningMode") - public void testOutputForSimplePageWithPartitionConstantAndHashBlock(PartitioningMode partitioningMode) + @Test + public void testOutputForSimplePageWithPartitionConstantAndHashBlock() + { + testOutputForSimplePageWithPartitionConstantAndHashBlock(PartitioningMode.ROW_WISE); + testOutputForSimplePageWithPartitionConstantAndHashBlock(PartitioningMode.COLUMNAR); + 
} + + private void testOutputForSimplePageWithPartitionConstantAndHashBlock(PartitioningMode partitioningMode) { - PagePartitioner pagePartitioner = pagePartitioner(BIGINT) + TestOutputBuffer outputBuffer = new TestOutputBuffer(); + PagePartitioner pagePartitioner = pagePartitioner(outputBuffer, BIGINT) .withPartitionConstants(ImmutableList.of(Optional.empty(), Optional.of(new NullableValue(BIGINT, 1L)))) .withPartitionChannels(0, -1) // use first block and constant block at index 1 as input to partitionFunction .withHashChannels(0, 1) // use both channels to calculate partition (a+b) mod 2 @@ -267,10 +305,17 @@ public void testOutputForSimplePageWithPartitionConstantAndHashBlock(Partitionin assertThat(partition1).containsExactly(0L, 2L); } - @Test(dataProvider = "partitioningMode") - public void testPartitionPositionsWithRleNotNull(PartitioningMode partitioningMode) + @Test + public void testPartitionPositionsWithRleNotNull() { - PagePartitioner pagePartitioner = pagePartitioner(BIGINT, BIGINT).build(); + testPartitionPositionsWithRleNotNull(PartitioningMode.ROW_WISE); + testPartitionPositionsWithRleNotNull(PartitioningMode.COLUMNAR); + } + + private void testPartitionPositionsWithRleNotNull(PartitioningMode partitioningMode) + { + TestOutputBuffer outputBuffer = new TestOutputBuffer(); + PagePartitioner pagePartitioner = pagePartitioner(outputBuffer, BIGINT, BIGINT).build(); Page page = new Page(createRepeatedValuesBlock(0, POSITIONS_PER_PAGE), createLongSequenceBlock(0, POSITIONS_PER_PAGE)); processPages(pagePartitioner, partitioningMode, page); @@ -282,10 +327,17 @@ public void testPartitionPositionsWithRleNotNull(PartitioningMode partitioningMo assertThat(outputBuffer.getEnqueuedDeserialized(1)).isEmpty(); } - @Test(dataProvider = "partitioningMode") - public void testPartitionPositionsWithRleNotNullWithReplication(PartitioningMode partitioningMode) + @Test + public void testPartitionPositionsWithRleNotNullWithReplication() { - PagePartitioner pagePartitioner 
= pagePartitioner(BIGINT, BIGINT).replicate().build(); + testPartitionPositionsWithRleNotNullWithReplication(PartitioningMode.ROW_WISE); + testPartitionPositionsWithRleNotNullWithReplication(PartitioningMode.COLUMNAR); + } + + private void testPartitionPositionsWithRleNotNullWithReplication(PartitioningMode partitioningMode) + { + TestOutputBuffer outputBuffer = new TestOutputBuffer(); + PagePartitioner pagePartitioner = pagePartitioner(outputBuffer, BIGINT, BIGINT).replicate().build(); Page page = new Page(createRepeatedValuesBlock(0, POSITIONS_PER_PAGE), createLongSequenceBlock(0, POSITIONS_PER_PAGE)); processPages(pagePartitioner, partitioningMode, page); @@ -296,10 +348,17 @@ public void testPartitionPositionsWithRleNotNullWithReplication(PartitioningMode assertThat(partition1).containsExactly(0L); // position 0 copied to all partitions } - @Test(dataProvider = "partitioningMode") - public void testPartitionPositionsWithRleNullWithNullChannel(PartitioningMode partitioningMode) + @Test + public void testPartitionPositionsWithRleNullWithNullChannel() { - PagePartitioner pagePartitioner = pagePartitioner(BIGINT, BIGINT).withNullChannel(0).build(); + testPartitionPositionsWithRleNullWithNullChannel(PartitioningMode.ROW_WISE); + testPartitionPositionsWithRleNullWithNullChannel(PartitioningMode.COLUMNAR); + } + + private void testPartitionPositionsWithRleNullWithNullChannel(PartitioningMode partitioningMode) + { + TestOutputBuffer outputBuffer = new TestOutputBuffer(); + PagePartitioner pagePartitioner = pagePartitioner(outputBuffer, BIGINT, BIGINT).withNullChannel(0).build(); Page page = new Page(RunLengthEncodedBlock.create(createLongsBlock((Long) null), POSITIONS_PER_PAGE), createLongSequenceBlock(0, POSITIONS_PER_PAGE)); processPages(pagePartitioner, partitioningMode, page); @@ -310,10 +369,17 @@ public void testPartitionPositionsWithRleNullWithNullChannel(PartitioningMode pa assertThat(partition1).containsExactlyElementsOf(readLongs(Stream.of(page), 1)); } - 
@Test(dataProvider = "partitioningMode") - public void testOutputForDictionaryBlock(PartitioningMode partitioningMode) + @Test + public void testOutputForDictionaryBlock() + { + testOutputForDictionaryBlock(PartitioningMode.ROW_WISE); + testOutputForDictionaryBlock(PartitioningMode.COLUMNAR); + } + + private void testOutputForDictionaryBlock(PartitioningMode partitioningMode) { - PagePartitioner pagePartitioner = pagePartitioner(BIGINT).build(); + TestOutputBuffer outputBuffer = new TestOutputBuffer(); + PagePartitioner pagePartitioner = pagePartitioner(outputBuffer, BIGINT).build(); Page page = new Page(createLongDictionaryBlock(0, 10)); // must have at least 10 position to have non-trivial dict processPages(pagePartitioner, partitioningMode, page); @@ -324,10 +390,17 @@ public void testOutputForDictionaryBlock(PartitioningMode partitioningMode) assertThat(partition1).containsExactlyElementsOf(nCopies(5, 1L)); } - @Test(dataProvider = "partitioningMode") - public void testOutputForOneValueDictionaryBlock(PartitioningMode partitioningMode) + @Test + public void testOutputForOneValueDictionaryBlock() { - PagePartitioner pagePartitioner = pagePartitioner(BIGINT).build(); + testOutputForOneValueDictionaryBlock(PartitioningMode.ROW_WISE); + testOutputForOneValueDictionaryBlock(PartitioningMode.COLUMNAR); + } + + private void testOutputForOneValueDictionaryBlock(PartitioningMode partitioningMode) + { + TestOutputBuffer outputBuffer = new TestOutputBuffer(); + PagePartitioner pagePartitioner = pagePartitioner(outputBuffer, BIGINT).build(); Page page = new Page(DictionaryBlock.create(4, createLongsBlock(0), new int[] {0, 0, 0, 0})); processPages(pagePartitioner, partitioningMode, page); @@ -338,10 +411,17 @@ public void testOutputForOneValueDictionaryBlock(PartitioningMode partitioningMo assertThat(partition1).isEmpty(); } - @Test(dataProvider = "partitioningMode") - public void testOutputForViewDictionaryBlock(PartitioningMode partitioningMode) + @Test + public void 
testOutputForViewDictionaryBlock() + { + testOutputForViewDictionaryBlock(PartitioningMode.ROW_WISE); + testOutputForViewDictionaryBlock(PartitioningMode.COLUMNAR); + } + + private void testOutputForViewDictionaryBlock(PartitioningMode partitioningMode) { - PagePartitioner pagePartitioner = pagePartitioner(BIGINT).build(); + TestOutputBuffer outputBuffer = new TestOutputBuffer(); + PagePartitioner pagePartitioner = pagePartitioner(outputBuffer, BIGINT).build(); Page page = new Page(DictionaryBlock.create(4, createLongSequenceBlock(4, 8), new int[] {1, 0, 3, 2})); processPages(pagePartitioner, partitioningMode, page); @@ -352,10 +432,48 @@ public void testOutputForViewDictionaryBlock(PartitioningMode partitioningMode) assertThat(partition1).containsExactlyInAnyOrder(5L, 7L); } - @Test(dataProvider = "typesWithPartitioningMode") - public void testOutputForSimplePageWithType(Type type, PartitioningMode partitioningMode) - { - PagePartitioner pagePartitioner = pagePartitioner(BIGINT, type).build(); + @Test + public void testOutputForSimplePageWithType() + { + testOutputForSimplePageWithType(BIGINT, PartitioningMode.ROW_WISE); + testOutputForSimplePageWithType(BOOLEAN, PartitioningMode.ROW_WISE); + testOutputForSimplePageWithType(INTEGER, PartitioningMode.ROW_WISE); + testOutputForSimplePageWithType(createCharType(10), PartitioningMode.ROW_WISE); + testOutputForSimplePageWithType(createUnboundedVarcharType(), PartitioningMode.ROW_WISE); + testOutputForSimplePageWithType(DOUBLE, PartitioningMode.ROW_WISE); + testOutputForSimplePageWithType(SMALLINT, PartitioningMode.ROW_WISE); + testOutputForSimplePageWithType(TINYINT, PartitioningMode.ROW_WISE); + testOutputForSimplePageWithType(UUID, PartitioningMode.ROW_WISE); + testOutputForSimplePageWithType(VARBINARY, PartitioningMode.ROW_WISE); + testOutputForSimplePageWithType(createDecimalType(1), PartitioningMode.ROW_WISE); + testOutputForSimplePageWithType(createDecimalType(Decimals.MAX_SHORT_PRECISION + 1), 
PartitioningMode.ROW_WISE); + testOutputForSimplePageWithType(new ArrayType(BIGINT), PartitioningMode.ROW_WISE); + testOutputForSimplePageWithType(TimestampType.createTimestampType(9), PartitioningMode.ROW_WISE); + testOutputForSimplePageWithType(TimestampType.createTimestampType(3), PartitioningMode.ROW_WISE); + testOutputForSimplePageWithType(IPADDRESS, PartitioningMode.ROW_WISE); + + testOutputForSimplePageWithType(BIGINT, PartitioningMode.COLUMNAR); + testOutputForSimplePageWithType(BOOLEAN, PartitioningMode.COLUMNAR); + testOutputForSimplePageWithType(INTEGER, PartitioningMode.COLUMNAR); + testOutputForSimplePageWithType(createCharType(10), PartitioningMode.COLUMNAR); + testOutputForSimplePageWithType(createUnboundedVarcharType(), PartitioningMode.COLUMNAR); + testOutputForSimplePageWithType(DOUBLE, PartitioningMode.COLUMNAR); + testOutputForSimplePageWithType(SMALLINT, PartitioningMode.COLUMNAR); + testOutputForSimplePageWithType(TINYINT, PartitioningMode.COLUMNAR); + testOutputForSimplePageWithType(UUID, PartitioningMode.COLUMNAR); + testOutputForSimplePageWithType(VARBINARY, PartitioningMode.COLUMNAR); + testOutputForSimplePageWithType(createDecimalType(1), PartitioningMode.COLUMNAR); + testOutputForSimplePageWithType(createDecimalType(Decimals.MAX_SHORT_PRECISION + 1), PartitioningMode.COLUMNAR); + testOutputForSimplePageWithType(new ArrayType(BIGINT), PartitioningMode.COLUMNAR); + testOutputForSimplePageWithType(TimestampType.createTimestampType(9), PartitioningMode.COLUMNAR); + testOutputForSimplePageWithType(TimestampType.createTimestampType(3), PartitioningMode.COLUMNAR); + testOutputForSimplePageWithType(IPADDRESS, PartitioningMode.COLUMNAR); + } + + private void testOutputForSimplePageWithType(Type type, PartitioningMode partitioningMode) + { + TestOutputBuffer outputBuffer = new TestOutputBuffer(); + PagePartitioner pagePartitioner = pagePartitioner(outputBuffer, BIGINT, type).build(); Page page = new Page( createLongSequenceBlock(0, 
POSITIONS_PER_PAGE), // partition block createBlockForType(type, POSITIONS_PER_PAGE)); @@ -367,18 +485,56 @@ public void testOutputForSimplePageWithType(Type type, PartitioningMode partitio assertThat(partitioned).containsExactlyInAnyOrderElementsOf(expected); // order is different due to 2 partitions joined } - @Test(dataProvider = "types") - public void testOutputWithMixedRowWiseAndColumnarPartitioning(Type type) + @Test + public void testOutputWithMixedRowWiseAndColumnarPartitioning() + { + testOutputEqualsInput(BIGINT, PartitioningMode.COLUMNAR, PartitioningMode.ROW_WISE); + testOutputEqualsInput(BOOLEAN, PartitioningMode.COLUMNAR, PartitioningMode.ROW_WISE); + testOutputEqualsInput(INTEGER, PartitioningMode.COLUMNAR, PartitioningMode.ROW_WISE); + testOutputEqualsInput(createCharType(10), PartitioningMode.COLUMNAR, PartitioningMode.ROW_WISE); + testOutputEqualsInput(createUnboundedVarcharType(), PartitioningMode.COLUMNAR, PartitioningMode.ROW_WISE); + testOutputEqualsInput(DOUBLE, PartitioningMode.COLUMNAR, PartitioningMode.ROW_WISE); + testOutputEqualsInput(SMALLINT, PartitioningMode.COLUMNAR, PartitioningMode.ROW_WISE); + testOutputEqualsInput(TINYINT, PartitioningMode.COLUMNAR, PartitioningMode.ROW_WISE); + testOutputEqualsInput(UUID, PartitioningMode.COLUMNAR, PartitioningMode.ROW_WISE); + testOutputEqualsInput(VARBINARY, PartitioningMode.COLUMNAR, PartitioningMode.ROW_WISE); + testOutputEqualsInput(createDecimalType(1), PartitioningMode.COLUMNAR, PartitioningMode.ROW_WISE); + testOutputEqualsInput(createDecimalType(Decimals.MAX_SHORT_PRECISION + 1), PartitioningMode.COLUMNAR, PartitioningMode.ROW_WISE); + testOutputEqualsInput(new ArrayType(BIGINT), PartitioningMode.COLUMNAR, PartitioningMode.ROW_WISE); + testOutputEqualsInput(TimestampType.createTimestampType(9), PartitioningMode.COLUMNAR, PartitioningMode.ROW_WISE); + testOutputEqualsInput(TimestampType.createTimestampType(3), PartitioningMode.COLUMNAR, PartitioningMode.ROW_WISE); + 
testOutputEqualsInput(IPADDRESS, PartitioningMode.COLUMNAR, PartitioningMode.ROW_WISE); + + testOutputEqualsInput(BIGINT, PartitioningMode.ROW_WISE, PartitioningMode.COLUMNAR); + testOutputEqualsInput(BOOLEAN, PartitioningMode.ROW_WISE, PartitioningMode.COLUMNAR); + testOutputEqualsInput(INTEGER, PartitioningMode.ROW_WISE, PartitioningMode.COLUMNAR); + testOutputEqualsInput(createCharType(10), PartitioningMode.ROW_WISE, PartitioningMode.COLUMNAR); + testOutputEqualsInput(createUnboundedVarcharType(), PartitioningMode.ROW_WISE, PartitioningMode.COLUMNAR); + testOutputEqualsInput(DOUBLE, PartitioningMode.ROW_WISE, PartitioningMode.COLUMNAR); + testOutputEqualsInput(SMALLINT, PartitioningMode.ROW_WISE, PartitioningMode.COLUMNAR); + testOutputEqualsInput(TINYINT, PartitioningMode.ROW_WISE, PartitioningMode.COLUMNAR); + testOutputEqualsInput(UUID, PartitioningMode.ROW_WISE, PartitioningMode.COLUMNAR); + testOutputEqualsInput(VARBINARY, PartitioningMode.ROW_WISE, PartitioningMode.COLUMNAR); + testOutputEqualsInput(createDecimalType(1), PartitioningMode.ROW_WISE, PartitioningMode.COLUMNAR); + testOutputEqualsInput(createDecimalType(Decimals.MAX_SHORT_PRECISION + 1), PartitioningMode.ROW_WISE, PartitioningMode.COLUMNAR); + testOutputEqualsInput(new ArrayType(BIGINT), PartitioningMode.ROW_WISE, PartitioningMode.COLUMNAR); + testOutputEqualsInput(TimestampType.createTimestampType(9), PartitioningMode.ROW_WISE, PartitioningMode.COLUMNAR); + testOutputEqualsInput(TimestampType.createTimestampType(3), PartitioningMode.ROW_WISE, PartitioningMode.COLUMNAR); + testOutputEqualsInput(IPADDRESS, PartitioningMode.ROW_WISE, PartitioningMode.COLUMNAR); + } + + @Test + public void testMemoryReleased() { - testOutputEqualsInput(type, PartitioningMode.COLUMNAR, PartitioningMode.ROW_WISE); - testOutputEqualsInput(type, PartitioningMode.ROW_WISE, PartitioningMode.COLUMNAR); + testMemoryReleased(PartitioningMode.ROW_WISE); + testMemoryReleased(PartitioningMode.COLUMNAR); } - 
@Test(dataProvider = "partitioningMode") - public void testMemoryReleased(PartitioningMode partitioningMode) + private void testMemoryReleased(PartitioningMode partitioningMode) { AggregatedMemoryContext memoryContext = newSimpleAggregatedMemoryContext(); - PagePartitioner pagePartitioner = pagePartitioner(BIGINT).withMemoryContext(memoryContext).build(); + TestOutputBuffer outputBuffer = new TestOutputBuffer(); + PagePartitioner pagePartitioner = pagePartitioner(outputBuffer, BIGINT).withMemoryContext(memoryContext).build(); Page page = new Page(createLongsBlock(0L, 1L, 2L, 3L, null)); processPages(pagePartitioner, partitioningMode, page); @@ -386,13 +542,20 @@ public void testMemoryReleased(PartitioningMode partitioningMode) assertThat(memoryContext.getBytes()).isEqualTo(0); } - @Test(dataProvider = "partitioningMode") - public void testMemoryReleasedOnFailure(PartitioningMode partitioningMode) + @Test + public void testMemoryReleasedOnFailure() + { + testMemoryReleasedOnFailure(PartitioningMode.ROW_WISE); + testMemoryReleasedOnFailure(PartitioningMode.COLUMNAR); + } + + private void testMemoryReleasedOnFailure(PartitioningMode partitioningMode) { AggregatedMemoryContext memoryContext = newSimpleAggregatedMemoryContext(); RuntimeException exception = new RuntimeException(); + TestOutputBuffer outputBuffer = new TestOutputBuffer(); outputBuffer.throwOnEnqueue(exception); - PagePartitioner pagePartitioner = pagePartitioner(BIGINT).withMemoryContext(memoryContext).build(); + PagePartitioner pagePartitioner = pagePartitioner(outputBuffer, BIGINT).withMemoryContext(memoryContext).build(); Page page = new Page(createLongsBlock(0L, 1L, 2L, 3L, null)); partitioningMode.partitionPage(pagePartitioner, page); @@ -403,7 +566,8 @@ public void testMemoryReleasedOnFailure(PartitioningMode partitioningMode) private void testOutputEqualsInput(Type type, PartitioningMode mode1, PartitioningMode mode2) { - PagePartitionerBuilder pagePartitionerBuilder = pagePartitioner(BIGINT, 
type, type); + TestOutputBuffer outputBuffer = new TestOutputBuffer(); + PagePartitionerBuilder pagePartitionerBuilder = pagePartitioner(outputBuffer, BIGINT, type, type); PagePartitioner pagePartitioner = pagePartitionerBuilder.build(); Page input = new Page( createLongSequenceBlock(0, POSITIONS_PER_PAGE), // partition block @@ -422,48 +586,6 @@ private void testOutputEqualsInput(Type type, PartitioningMode mode1, Partitioni outputBuffer.clear(); } - @DataProvider(name = "partitioningMode") - public static Object[][] partitioningMode() - { - return new Object[][] {{PartitioningMode.ROW_WISE}, {PartitioningMode.COLUMNAR}}; - } - - @DataProvider(name = "types") - public static Object[][] types() - { - return getTypes().stream().map(type -> new Object[] {type}).toArray(Object[][]::new); - } - - @DataProvider(name = "typesWithPartitioningMode") - public static Object[][] typesWithPartitioningMode() - { - return getTypes().stream() - .flatMap(type -> Stream.of(PartitioningMode.values()) - .map(partitioningMode -> new Object[] {type, partitioningMode})) - .toArray(Object[][]::new); - } - - private static ImmutableList getTypes() - { - return ImmutableList.of( - BIGINT, - BOOLEAN, - INTEGER, - createCharType(10), - createUnboundedVarcharType(), - DOUBLE, - SMALLINT, - TINYINT, - UUID, - VARBINARY, - createDecimalType(1), - createDecimalType(Decimals.MAX_SHORT_PRECISION + 1), - new ArrayType(BIGINT), - TimestampType.createTimestampType(9), - TimestampType.createTimestampType(3), - IPADDRESS); - } - private static Block createBlockForType(Type type, int positionsPerPage) { return createRandomBlockForType(type, positionsPerPage, 0.2F); @@ -500,17 +622,17 @@ private static List readChannel(Stream pages, int channel, Type ty return unmodifiableList(result); } - private PagePartitionerBuilder pagePartitioner(Type... types) + private PagePartitionerBuilder pagePartitioner(TestOutputBuffer outputBuffer, Type... 
types) { - return pagePartitioner(ImmutableList.copyOf(types)); + return pagePartitioner(ImmutableList.copyOf(types), outputBuffer); } - private PagePartitionerBuilder pagePartitioner(List types) + private PagePartitionerBuilder pagePartitioner(List types, TestOutputBuffer outputBuffer) { - return pagePartitioner().withTypes(types); + return pagePartitioner(outputBuffer).withTypes(types); } - private PagePartitionerBuilder pagePartitioner() + private PagePartitionerBuilder pagePartitioner(TestOutputBuffer outputBuffer) { return new PagePartitionerBuilder(executor, scheduledExecutor, outputBuffer); } diff --git a/core/trino-main/src/test/java/io/trino/operator/output/TestPositionsAppender.java b/core/trino-main/src/test/java/io/trino/operator/output/TestPositionsAppender.java index a41bd26a5334..01ece57858d9 100644 --- a/core/trino-main/src/test/java/io/trino/operator/output/TestPositionsAppender.java +++ b/core/trino-main/src/test/java/io/trino/operator/output/TestPositionsAppender.java @@ -40,12 +40,9 @@ import io.trino.spi.type.VarbinaryType; import io.trino.spi.type.VarcharType; import io.trino.type.BlockTypeOperators; -import io.trino.type.UnknownType; import it.unimi.dsi.fastutil.ints.IntArrayList; -import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; -import java.util.Arrays; import java.util.List; import java.util.Optional; import java.util.function.Function; @@ -83,72 +80,97 @@ public class TestPositionsAppender { private static final PositionsAppenderFactory POSITIONS_APPENDER_FACTORY = new PositionsAppenderFactory(new BlockTypeOperators()); - @Test(dataProvider = "types") - public void testMixedBlockTypes(TestType type) - { - List input = ImmutableList.of( - input(emptyBlock(type)), - input(nullBlock(type, 3), 0, 2), - input(notNullBlock(type, 3), 1, 2), - input(partiallyNullBlock(type, 4), 0, 1, 2, 3), - input(partiallyNullBlock(type, 4)), // empty position list - input(rleBlock(type, 4), 0, 
2), - input(rleBlock(type, 2), 0, 1), // rle all positions - input(nullRleBlock(type, 4), 1, 2), - input(dictionaryBlock(type, 4, 2, 0), 0, 3), // dict not null - input(dictionaryBlock(type, 8, 4, 0.5F), 1, 3, 5), // dict mixed - input(dictionaryBlock(type, 8, 4, 1), 1, 3, 5), // dict null - input(rleBlock(dictionaryBlock(type, 1, 2, 0), 3), 2), // rle -> dict - input(rleBlock(dictionaryBlock(notNullBlock(type, 2), new int[] {1}), 3), 2), // rle -> dict with position 0 mapped to > 0 - input(rleBlock(dictionaryBlock(rleBlock(type, 4), 1), 3), 1), // rle -> dict -> rle - input(dictionaryBlock(dictionaryBlock(type, 5, 4, 0.5F), 3), 2), // dict -> dict - input(dictionaryBlock(dictionaryBlock(dictionaryBlock(type, 5, 4, 0.5F), 3), 3), 2), // dict -> dict -> dict - input(dictionaryBlock(rleBlock(type, 4), 3), 0, 2), // dict -> rle - input(notNullBlock(type, 4).getRegion(2, 2), 0, 1), // not null block with offset - input(partiallyNullBlock(type, 4).getRegion(2, 2), 0, 1), // nullable block with offset - input(rleBlock(notNullBlock(type, 4).getRegion(2, 1), 3), 1)); // rle block with offset - - testAppend(type, input); + @Test + public void testMixedBlockTypes() + { + for (TestType type : TestType.values()) { + List input = ImmutableList.of( + input(emptyBlock(type)), + input(nullBlock(type, 3), 0, 2), + input(notNullBlock(type, 3), 1, 2), + input(partiallyNullBlock(type, 4), 0, 1, 2, 3), + input(partiallyNullBlock(type, 4)), // empty position list + input(rleBlock(type, 4), 0, 2), + input(rleBlock(type, 2), 0, 1), // rle all positions + input(nullRleBlock(type, 4), 1, 2), + input(dictionaryBlock(type, 4, 2, 0), 0, 3), // dict not null + input(dictionaryBlock(type, 8, 4, 0.5F), 1, 3, 5), // dict mixed + input(dictionaryBlock(type, 8, 4, 1), 1, 3, 5), // dict null + input(rleBlock(dictionaryBlock(type, 1, 2, 0), 3), 2), // rle -> dict + input(rleBlock(dictionaryBlock(notNullBlock(type, 2), new int[] {1}), 3), 2), // rle -> dict with position 0 mapped to > 0 + 
input(rleBlock(dictionaryBlock(rleBlock(type, 4), 1), 3), 1), // rle -> dict -> rle + input(dictionaryBlock(dictionaryBlock(type, 5, 4, 0.5F), 3), 2), // dict -> dict + input(dictionaryBlock(dictionaryBlock(dictionaryBlock(type, 5, 4, 0.5F), 3), 3), 2), // dict -> dict -> dict + input(dictionaryBlock(rleBlock(type, 4), 3), 0, 2), // dict -> rle + input(notNullBlock(type, 4).getRegion(2, 2), 0, 1), // not null block with offset + input(partiallyNullBlock(type, 4).getRegion(2, 2), 0, 1), // nullable block with offset + input(rleBlock(notNullBlock(type, 4).getRegion(2, 1), 3), 1)); // rle block with offset + + testAppend(type, input); + } } - @Test(dataProvider = "types") - public void testNullRle(TestType type) + @Test + public void testNullRle() { - testNullRle(type.getType(), nullBlock(type, 2)); - testNullRle(type.getType(), nullRleBlock(type, 2)); - testNullRle(type.getType(), createRandomBlockForType(type, 4, 0.5f)); + for (TestType type : TestType.values()) { + testNullRle(type.getType(), nullBlock(type, 2)); + testNullRle(type.getType(), nullRleBlock(type, 2)); + testNullRle(type.getType(), createRandomBlockForType(type, 4, 0.5f)); + } } - @Test(dataProvider = "types") - public void testRleSwitchToFlat(TestType type) - { - List inputs = ImmutableList.of( - input(rleBlock(type, 3), 0, 1), - input(notNullBlock(type, 2), 0, 1)); - testAppend(type, inputs); + @Test + public void testRleSwitchToFlat() + { + for (TestType type : TestType.values()) { + List inputs = ImmutableList.of( + input(rleBlock(type, 3), 0, 1), + input(notNullBlock(type, 2), 0, 1)); + testAppend(type, inputs); + + List dictionaryInputs = ImmutableList.of( + input(rleBlock(type, 3), 0, 1), + input(dictionaryBlock(type, 2, 4, 0), 0, 1)); + testAppend(type, dictionaryInputs); + } + } - List dictionaryInputs = ImmutableList.of( - input(rleBlock(type, 3), 0, 1), - input(dictionaryBlock(type, 2, 4, 0), 0, 1)); - testAppend(type, dictionaryInputs); + @Test + public void testFlatAppendRle() + { + for 
(TestType type : TestType.values()) { + List inputs = ImmutableList.of( + input(notNullBlock(type, 2), 0, 1), + input(rleBlock(type, 3), 0, 1)); + testAppend(type, inputs); + + List dictionaryInputs = ImmutableList.of( + input(dictionaryBlock(type, 2, 4, 0), 0, 1), + input(rleBlock(type, 3), 0, 1)); + testAppend(type, dictionaryInputs); + } } - @Test(dataProvider = "types") - public void testFlatAppendRle(TestType type) + @Test + public void testMultipleRleBlocksWithDifferentValues() { - List inputs = ImmutableList.of( - input(notNullBlock(type, 2), 0, 1), - input(rleBlock(type, 3), 0, 1)); - testAppend(type, inputs); - - List dictionaryInputs = ImmutableList.of( - input(dictionaryBlock(type, 2, 4, 0), 0, 1), - input(rleBlock(type, 3), 0, 1)); - testAppend(type, dictionaryInputs); + testMultipleRleBlocksWithDifferentValues(TestType.BIGINT, createLongsBlock(0), createLongsBlock(1)); + testMultipleRleBlocksWithDifferentValues(TestType.BOOLEAN, createBooleansBlock(true), createBooleansBlock(false)); + testMultipleRleBlocksWithDifferentValues(TestType.INTEGER, createIntsBlock(0), createIntsBlock(1)); + testMultipleRleBlocksWithDifferentValues(TestType.CHAR_10, createStringsBlock("0"), createStringsBlock("1")); + testMultipleRleBlocksWithDifferentValues(TestType.VARCHAR, createStringsBlock("0"), createStringsBlock("1")); + testMultipleRleBlocksWithDifferentValues(TestType.DOUBLE, createDoublesBlock(0.0), createDoublesBlock(1.0)); + testMultipleRleBlocksWithDifferentValues(TestType.SMALLINT, createSmallintsBlock(0), createSmallintsBlock(1)); + testMultipleRleBlocksWithDifferentValues(TestType.TINYINT, createTinyintsBlock(0), createTinyintsBlock(1)); + testMultipleRleBlocksWithDifferentValues(TestType.VARBINARY, createSlicesBlock(Slices.allocate(Long.BYTES)), createSlicesBlock(Slices.allocate(Long.BYTES).getOutput().appendLong(1).slice())); + testMultipleRleBlocksWithDifferentValues(TestType.LONG_DECIMAL, createLongDecimalsBlock("0"), createLongDecimalsBlock("1")); + 
testMultipleRleBlocksWithDifferentValues(TestType.ARRAY_BIGINT, createArrayBigintBlock(ImmutableList.of(ImmutableList.of(0L))), createArrayBigintBlock(ImmutableList.of(ImmutableList.of(1L)))); + testMultipleRleBlocksWithDifferentValues(TestType.LONG_TIMESTAMP, createLongTimestampBlock(createTimestampType(9), new LongTimestamp(0, 0)), createLongTimestampBlock(createTimestampType(9), new LongTimestamp(1, 0))); + testMultipleRleBlocksWithDifferentValues(TestType.VARCHAR_WITH_TEST_BLOCK, adapt(createStringsBlock("0")), adapt(createStringsBlock("1"))); } - @Test(dataProvider = "differentValues") - public void testMultipleRleBlocksWithDifferentValues(TestType type, Block value1, Block value2) + private void testMultipleRleBlocksWithDifferentValues(TestType type, Block value1, Block value2) { List input = ImmutableList.of( input(rleBlock(value1, 3), 0, 1), @@ -156,44 +178,30 @@ public void testMultipleRleBlocksWithDifferentValues(TestType type, Block value1 testAppend(type, input); } - @DataProvider(name = "differentValues") - public static Object[][] differentValues() + @Test + public void testMultipleRleWithTheSameValueProduceRle() { - return new Object[][] - { - {TestType.BIGINT, createLongsBlock(0), createLongsBlock(1)}, - {TestType.BOOLEAN, createBooleansBlock(true), createBooleansBlock(false)}, - {TestType.INTEGER, createIntsBlock(0), createIntsBlock(1)}, - {TestType.CHAR_10, createStringsBlock("0"), createStringsBlock("1")}, - {TestType.VARCHAR, createStringsBlock("0"), createStringsBlock("1")}, - {TestType.DOUBLE, createDoublesBlock(0.0), createDoublesBlock(1.0)}, - {TestType.SMALLINT, createSmallintsBlock(0), createSmallintsBlock(1)}, - {TestType.TINYINT, createTinyintsBlock(0), createTinyintsBlock(1)}, - {TestType.VARBINARY, createSlicesBlock(Slices.allocate(Long.BYTES)), createSlicesBlock(Slices.allocate(Long.BYTES).getOutput().appendLong(1).slice())}, - {TestType.LONG_DECIMAL, createLongDecimalsBlock("0"), createLongDecimalsBlock("1")}, - 
{TestType.ARRAY_BIGINT, createArrayBigintBlock(ImmutableList.of(ImmutableList.of(0L))), createArrayBigintBlock(ImmutableList.of(ImmutableList.of(1L)))}, - {TestType.LONG_TIMESTAMP, createLongTimestampBlock(createTimestampType(9), new LongTimestamp(0, 0)), - createLongTimestampBlock(createTimestampType(9), new LongTimestamp(1, 0))}, - {TestType.VARCHAR_WITH_TEST_BLOCK, adapt(createStringsBlock("0")), adapt(createStringsBlock("1"))} - }; - } + for (TestType type : TestType.values()) { + UnnestingPositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type.getType(), 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); - @Test(dataProvider = "types") - public void testMultipleRleWithTheSameValueProduceRle(TestType type) - { - UnnestingPositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type.getType(), 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); + Block value = notNullBlock(type, 1); + positionsAppender.append(allPositions(3), rleBlock(value, 3)); + positionsAppender.append(allPositions(2), rleBlock(value, 2)); - Block value = notNullBlock(type, 1); - positionsAppender.append(allPositions(3), rleBlock(value, 3)); - positionsAppender.append(allPositions(2), rleBlock(value, 2)); + Block actual = positionsAppender.build(); + assertThat(actual.getPositionCount()).isEqualTo(5); + assertInstanceOf(actual, RunLengthEncodedBlock.class); + } + } - Block actual = positionsAppender.build(); - assertThat(actual.getPositionCount()).isEqualTo(5); - assertInstanceOf(actual, RunLengthEncodedBlock.class); + @Test + public void testRleAppendForComplexTypeWithNullElement() + { + testRleAppendForComplexTypeWithNullElement(TestType.ROW_BIGINT_VARCHAR, RowBlock.fromFieldBlocks(1, new Block[] {nullBlock(BIGINT, 1), nullBlock(VARCHAR, 1)})); + testRleAppendForComplexTypeWithNullElement(TestType.ARRAY_BIGINT, ArrayBlock.fromElementBlock(1, Optional.empty(), new int[] {0, 1}, nullBlock(BIGINT, 1))); } - @Test(dataProvider = "complexTypesWithNullElementBlock") - public void 
testRleAppendForComplexTypeWithNullElement(TestType type, Block value) + private void testRleAppendForComplexTypeWithNullElement(TestType type, Block value) { checkArgument(value.getPositionCount() == 1); UnnestingPositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type.getType(), 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); @@ -207,31 +215,35 @@ public void testRleAppendForComplexTypeWithNullElement(TestType type, Block valu assertBlockEquals(type.getType(), actual, RunLengthEncodedBlock.create(value, 5)); } - @Test(dataProvider = "types") - public void testRleAppendedWithSinglePositionDoesNotProduceRle(TestType type) + @Test + public void testRleAppendedWithSinglePositionDoesNotProduceRle() { - UnnestingPositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type.getType(), 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); + for (TestType type : TestType.values()) { + UnnestingPositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type.getType(), 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); - Block value = notNullBlock(type, 1); - positionsAppender.append(allPositions(3), rleBlock(value, 3)); - positionsAppender.append(allPositions(2), rleBlock(value, 2)); - positionsAppender.append(0, rleBlock(value, 2)); + Block value = notNullBlock(type, 1); + positionsAppender.append(allPositions(3), rleBlock(value, 3)); + positionsAppender.append(allPositions(2), rleBlock(value, 2)); + positionsAppender.append(0, rleBlock(value, 2)); - Block actual = positionsAppender.build(); - assertThat(actual.getPositionCount()).isEqualTo(6); - assertThat(actual instanceof RunLengthEncodedBlock) - .describedAs(actual.getClass().getSimpleName()) - .isFalse(); + Block actual = positionsAppender.build(); + assertThat(actual.getPositionCount()).isEqualTo(6); + assertThat(actual instanceof RunLengthEncodedBlock) + .describedAs(actual.getClass().getSimpleName()) + .isFalse(); + } } - @Test(dataProvider = "types") - public static void 
testMultipleTheSameDictionariesProduceDictionary(TestType type) + @Test + public void testMultipleTheSameDictionariesProduceDictionary() { - UnnestingPositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type.getType(), 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); + for (TestType type : TestType.values()) { + UnnestingPositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type.getType(), 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); - testMultipleTheSameDictionariesProduceDictionary(type, positionsAppender); - // test if appender can accept different dictionary after a build - testMultipleTheSameDictionariesProduceDictionary(type, positionsAppender); + testMultipleTheSameDictionariesProduceDictionary(type, positionsAppender); + // test if appender can accept different dictionary after a build + testMultipleTheSameDictionariesProduceDictionary(type, positionsAppender); + } } private static void testMultipleTheSameDictionariesProduceDictionary(TestType type, UnnestingPositionsAppender positionsAppender) @@ -246,91 +258,98 @@ private static void testMultipleTheSameDictionariesProduceDictionary(TestType ty assertThat(((DictionaryBlock) actual).getDictionary()).isEqualTo(dictionary); } - @Test(dataProvider = "types") - public void testDictionarySwitchToFlat(TestType type) + @Test + public void testDictionarySwitchToFlat() { - List inputs = ImmutableList.of( - input(dictionaryBlock(type, 3, 4, 0), 0, 1), - input(notNullBlock(type, 2), 0, 1)); - testAppend(type, inputs); + for (TestType type : TestType.values()) { + List inputs = ImmutableList.of( + input(dictionaryBlock(type, 3, 4, 0), 0, 1), + input(notNullBlock(type, 2), 0, 1)); + testAppend(type, inputs); + } } - @Test(dataProvider = "types") - public void testFlatAppendDictionary(TestType type) + @Test + public void testFlatAppendDictionary() { - List inputs = ImmutableList.of( - input(notNullBlock(type, 2), 0, 1), - input(dictionaryBlock(type, 3, 4, 0), 0, 1)); - testAppend(type, inputs); + for 
(TestType type : TestType.values()) { + List inputs = ImmutableList.of( + input(notNullBlock(type, 2), 0, 1), + input(dictionaryBlock(type, 3, 4, 0), 0, 1)); + testAppend(type, inputs); + } } - @Test(dataProvider = "types") - public void testDictionaryAppendDifferentDictionary(TestType type) + @Test + public void testDictionaryAppendDifferentDictionary() { - List dictionaryInputs = ImmutableList.of( - input(dictionaryBlock(type, 3, 4, 0), 0, 1), - input(dictionaryBlock(type, 2, 4, 0), 0, 1)); - testAppend(type, dictionaryInputs); + for (TestType type : TestType.values()) { + List dictionaryInputs = ImmutableList.of( + input(dictionaryBlock(type, 3, 4, 0), 0, 1), + input(dictionaryBlock(type, 2, 4, 0), 0, 1)); + testAppend(type, dictionaryInputs); + } } - @Test(dataProvider = "types") - public void testDictionarySingleThenFlat(TestType type) + @Test + public void testDictionarySingleThenFlat() { - BlockView firstInput = input(dictionaryBlock(type, 1, 4, 0), 0); - BlockView secondInput = input(dictionaryBlock(type, 2, 4, 0), 0, 1); - UnnestingPositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type.getType(), 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); - long initialRetainedSize = positionsAppender.getRetainedSizeInBytes(); + for (TestType type : TestType.values()) { + BlockView firstInput = input(dictionaryBlock(type, 1, 4, 0), 0); + BlockView secondInput = input(dictionaryBlock(type, 2, 4, 0), 0, 1); + UnnestingPositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type.getType(), 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); + long initialRetainedSize = positionsAppender.getRetainedSizeInBytes(); - firstInput.positions().forEach((int position) -> positionsAppender.append(position, firstInput.block())); - positionsAppender.append(secondInput.positions(), secondInput.block()); + firstInput.positions().forEach((int position) -> positionsAppender.append(position, firstInput.block())); + positionsAppender.append(secondInput.positions(), 
secondInput.block()); - assertBuildResult(type, ImmutableList.of(firstInput, secondInput), positionsAppender, initialRetainedSize); + assertBuildResult(type, ImmutableList.of(firstInput, secondInput), positionsAppender, initialRetainedSize); + } } - @Test(dataProvider = "types") - public void testConsecutiveBuilds(TestType type) - { - UnnestingPositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type.getType(), 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); + @Test + public void testConsecutiveBuilds() + { + for (TestType type : TestType.values()) { + UnnestingPositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type.getType(), 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); + + // empty block + positionsAppender.append(positions(), emptyBlock(type)); + assertThat(positionsAppender.build().getPositionCount()).isEqualTo(0); + + Block block = createRandomBlockForType(type, 2, 0.5f); + // append only null position + int nullPosition = block.isNull(0) ? 0 : 1; + positionsAppender.append(positions(nullPosition), block); + Block actualNullBlock = positionsAppender.build(); + assertThat(actualNullBlock.getPositionCount()).isEqualTo(1); + assertThat(actualNullBlock.isNull(0)).isTrue(); + + // append null and not null position + positionsAppender.append(allPositions(2), block); + assertBlockEquals(type.getType(), positionsAppender.build(), block); + + // append not null rle + Block rleBlock = rleBlock(type, 10); + positionsAppender.append(allPositions(10), rleBlock); + assertBlockEquals(type.getType(), positionsAppender.build(), rleBlock); + + // append null rle + Block nullRleBlock = nullRleBlock(type, 10); + positionsAppender.append(allPositions(10), nullRleBlock); + assertBlockEquals(type.getType(), positionsAppender.build(), nullRleBlock); + + // append dictionary + Block dictionaryBlock = dictionaryBlock(type, 10, 5, 0); + positionsAppender.append(allPositions(10), dictionaryBlock); + assertBlockEquals(type.getType(), positionsAppender.build(), 
dictionaryBlock); + + // just build to confirm appender was reset + assertThat(positionsAppender.build().getPositionCount()).isEqualTo(0); + } + } - // empty block - positionsAppender.append(positions(), emptyBlock(type)); - assertThat(positionsAppender.build().getPositionCount()).isEqualTo(0); - - Block block = createRandomBlockForType(type, 2, 0.5f); - // append only null position - int nullPosition = block.isNull(0) ? 0 : 1; - positionsAppender.append(positions(nullPosition), block); - Block actualNullBlock = positionsAppender.build(); - assertThat(actualNullBlock.getPositionCount()).isEqualTo(1); - assertThat(actualNullBlock.isNull(0)).isTrue(); - - // append null and not null position - positionsAppender.append(allPositions(2), block); - assertBlockEquals(type.getType(), positionsAppender.build(), block); - - // append not null rle - Block rleBlock = rleBlock(type, 10); - positionsAppender.append(allPositions(10), rleBlock); - assertBlockEquals(type.getType(), positionsAppender.build(), rleBlock); - - // append null rle - Block nullRleBlock = nullRleBlock(type, 10); - positionsAppender.append(allPositions(10), nullRleBlock); - assertBlockEquals(type.getType(), positionsAppender.build(), nullRleBlock); - - // append dictionary - Block dictionaryBlock = dictionaryBlock(type, 10, 5, 0); - positionsAppender.append(allPositions(10), dictionaryBlock); - assertBlockEquals(type.getType(), positionsAppender.build(), dictionaryBlock); - - // just build to confirm appender was reset - assertThat(positionsAppender.build().getPositionCount()).isEqualTo(0); - } - - // testcase for jit bug described https://github.com/trinodb/trino/issues/12821. - // this test needs to be run first (hence the lowest priority) as the test order - // influences jit compilation, making this problem to not occur if other tests are run first. 
- @Test(priority = Integer.MIN_VALUE) + @Test public void testSliceRle() { UnnestingPositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(VARCHAR, 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); @@ -362,23 +381,6 @@ public void testRowWithNestedFields() assertBlockEquals(type, actual, rowBLock); } - @DataProvider(name = "complexTypesWithNullElementBlock") - public static Object[][] complexTypesWithNullElementBlock() - { - return new Object[][] { - {TestType.ROW_BIGINT_VARCHAR, RowBlock.fromFieldBlocks(1, new Block[] {nullBlock(BIGINT, 1), nullBlock(VARCHAR, 1)})}, - {TestType.ARRAY_BIGINT, ArrayBlock.fromElementBlock(1, Optional.empty(), new int[] {0, 1}, nullBlock(BIGINT, 1))}}; - } - - @DataProvider(name = "types") - public static Object[][] types() - { - return Arrays.stream(TestType.values()) - .filter(testType -> testType != TestType.UNKNOWN) - .map(type -> new Object[] {type}) - .toArray(Object[][]::new); - } - private static ValueBlock singleValueBlock(String value) { BlockBuilder blockBuilder = VARCHAR.createBlockBuilder(null, 1); @@ -577,8 +579,8 @@ private enum TestType LONG_TIMESTAMP(createTimestampType(9)), ROW_BIGINT_VARCHAR(anonymousRow(BigintType.BIGINT, VarcharType.VARCHAR)), ARRAY_BIGINT(new ArrayType(BigintType.BIGINT)), - VARCHAR_WITH_TEST_BLOCK(VarcharType.VARCHAR, adaptation()), - UNKNOWN(UnknownType.UNKNOWN); + VARCHAR_WITH_TEST_BLOCK(VarcharType.VARCHAR, adaptation()); +// UNKNOWN(UnknownType.UNKNOWN); private final Type type; private final Function blockAdaptation; diff --git a/core/trino-main/src/test/java/io/trino/operator/project/TestDictionaryAwarePageProjection.java b/core/trino-main/src/test/java/io/trino/operator/project/TestDictionaryAwarePageProjection.java index bdcc3daca7a7..069feb0b1169 100644 --- a/core/trino-main/src/test/java/io/trino/operator/project/TestDictionaryAwarePageProjection.java +++ b/core/trino-main/src/test/java/io/trino/operator/project/TestDictionaryAwarePageProjection.java @@ -23,11 +23,13 @@ import 
io.trino.spi.block.LazyBlock; import io.trino.spi.block.LongArrayBlock; import io.trino.spi.block.RunLengthEncodedBlock; +import io.trino.spi.block.ValueBlock; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.type.Type; -import org.testng.annotations.AfterClass; -import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.parallel.Execution; import java.util.Arrays; import java.util.concurrent.ScheduledExecutorService; @@ -44,21 +46,16 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.Fail.fail; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; +import static org.junit.jupiter.api.parallel.ExecutionMode.CONCURRENT; +@TestInstance(PER_CLASS) +@Execution(CONCURRENT) public class TestDictionaryAwarePageProjection { private static final ScheduledExecutorService executor = newSingleThreadScheduledExecutor(daemonThreadsNamed("TestDictionaryAwarePageProjection-%s")); - @DataProvider(name = "forceYield") - public static Object[][] forceYieldAndProduceLazyBlock() - { - return new Object[][] { - {true, false}, - {false, true}, - {false, false}}; - } - - @AfterClass(alwaysRun = true) + @AfterAll public void tearDown() { executor.shutdownNow(); @@ -73,54 +70,66 @@ public void testDelegateMethods() assertThat(projection.getType()).isEqualTo(BIGINT); } - @Test(dataProvider = "forceYield") - public void testSimpleBlock(boolean forceYield, boolean produceLazyBlock) + @Test + public void testSimpleBlock() { - Block block = createLongSequenceBlock(0, 100); - testProject(block, block.getClass(), forceYield, produceLazyBlock); + ValueBlock block = createLongSequenceBlock(0, 100); + testProject(block, block.getClass(), true, false); + testProject(block, 
block.getClass(), false, true); + testProject(block, block.getClass(), false, false); } - @Test(dataProvider = "forceYield") - public void testRleBlock(boolean forceYield, boolean produceLazyBlock) + @Test + public void testRleBlock() { Block value = createLongSequenceBlock(42, 43); RunLengthEncodedBlock block = (RunLengthEncodedBlock) RunLengthEncodedBlock.create(value, 100); - testProject(block, RunLengthEncodedBlock.class, forceYield, produceLazyBlock); + testProject(block, RunLengthEncodedBlock.class, true, false); + testProject(block, RunLengthEncodedBlock.class, false, true); + testProject(block, RunLengthEncodedBlock.class, false, false); } - @Test(dataProvider = "forceYield") - public void testRleBlockWithFailure(boolean forceYield, boolean produceLazyBlock) + @Test + public void testRleBlockWithFailure() { Block value = createLongSequenceBlock(-43, -42); RunLengthEncodedBlock block = (RunLengthEncodedBlock) RunLengthEncodedBlock.create(value, 100); - testProjectFails(block, RunLengthEncodedBlock.class, forceYield, produceLazyBlock); + testProjectFails(block, RunLengthEncodedBlock.class, true, false); + testProjectFails(block, RunLengthEncodedBlock.class, false, true); + testProjectFails(block, RunLengthEncodedBlock.class, false, false); } - @Test(dataProvider = "forceYield") - public void testDictionaryBlock(boolean forceYield, boolean produceLazyBlock) + @Test + public void testDictionaryBlock() { Block block = createDictionaryBlock(10, 100); - testProject(block, DictionaryBlock.class, forceYield, produceLazyBlock); + testProject(block, DictionaryBlock.class, true, false); + testProject(block, DictionaryBlock.class, false, true); + testProject(block, DictionaryBlock.class, false, false); } - @Test(dataProvider = "forceYield") - public void testDictionaryBlockWithFailure(boolean forceYield, boolean produceLazyBlock) + @Test + public void testDictionaryBlockWithFailure() { Block block = createDictionaryBlockWithFailure(10, 100); - testProjectFails(block, 
DictionaryBlock.class, forceYield, produceLazyBlock); + testProjectFails(block, DictionaryBlock.class, true, false); + testProjectFails(block, DictionaryBlock.class, false, true); + testProjectFails(block, DictionaryBlock.class, false, false); } - @Test(dataProvider = "forceYield") - public void testDictionaryBlockProcessingWithUnusedFailure(boolean forceYield, boolean produceLazyBlock) + @Test + public void testDictionaryBlockProcessingWithUnusedFailure() { Block block = createDictionaryBlockWithUnusedEntries(10, 100); // failures in the dictionary processing will cause a fallback to normal columnar processing - testProject(block, LongArrayBlock.class, forceYield, produceLazyBlock); + testProject(block, LongArrayBlock.class, true, false); + testProject(block, LongArrayBlock.class, false, true); + testProject(block, LongArrayBlock.class, false, false); } @Test @@ -136,8 +145,15 @@ public void testDictionaryProcessingIgnoreYield() testProjectFastReturnIgnoreYield(block, projection, false); } - @Test(dataProvider = "forceYield") - public void testDictionaryProcessingEnableDisable(boolean forceYield, boolean produceLazyBlock) + @Test + public void testDictionaryProcessingEnableDisable() + { + testDictionaryProcessingEnableDisable(true, false); + testDictionaryProcessingEnableDisable(false, true); + testDictionaryProcessingEnableDisable(false, false); + } + + private void testDictionaryProcessingEnableDisable(boolean forceYield, boolean produceLazyBlock) { DictionaryAwarePageProjection projection = createProjection(produceLazyBlock); diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/TestScalarValidation.java b/core/trino-main/src/test/java/io/trino/operator/scalar/TestScalarValidation.java index 3e4927d96387..e73030b32fa1 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/TestScalarValidation.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/TestScalarValidation.java @@ -24,15 +24,19 @@ import 
io.trino.spi.type.StandardTypes; import io.trino.spi.type.Type; import jakarta.annotation.Nullable; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; @SuppressWarnings("UtilityClassWithoutPrivateConstructor") public class TestScalarValidation { - @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "Parametric class method .* is annotated with @ScalarFunction") + @Test public void testBogusParametricMethodAnnotation() { - extractParametricScalar(BogusParametricMethodAnnotation.class); + assertThatThrownBy(() -> extractParametricScalar(BogusParametricMethodAnnotation.class)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageMatching("Parametric class method .* is annotated with @ScalarFunction"); } @ScalarFunction @@ -42,20 +46,24 @@ public static final class BogusParametricMethodAnnotation public static void bad() {} } - @Test(expectedExceptions = TrinoException.class, expectedExceptionsMessageRegExp = "Parametric class .* does not have any annotated methods") + @Test public void testNoParametricMethods() { - extractParametricScalar(NoParametricMethods.class); + assertThatThrownBy(() -> extractParametricScalar(NoParametricMethods.class)) + .isInstanceOf(TrinoException.class) + .hasMessageMatching("Parametric class .* does not have any annotated methods"); } @SuppressWarnings("EmptyClass") @ScalarFunction public static final class NoParametricMethods {} - @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "Method .* is missing @SqlType annotation") + @Test public void testMethodMissingReturnAnnotation() { - extractScalars(MethodMissingReturnAnnotation.class); + assertThatThrownBy(() -> extractScalars(MethodMissingReturnAnnotation.class)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageMatching("Method .* is missing @SqlType annotation"); } public static final class 
MethodMissingReturnAnnotation @@ -64,10 +72,12 @@ public static final class MethodMissingReturnAnnotation public static void bad() {} } - @Test(expectedExceptions = TrinoException.class, expectedExceptionsMessageRegExp = "Method .* annotated with @SqlType is missing @ScalarFunction or @ScalarOperator") + @Test public void testMethodMissingScalarAnnotation() { - extractScalars(MethodMissingScalarAnnotation.class); + assertThatThrownBy(() -> extractScalars(MethodMissingScalarAnnotation.class)) + .isInstanceOf(TrinoException.class) + .hasMessageMatching("Method .* annotated with @SqlType is missing @ScalarFunction or @ScalarOperator"); } public static final class MethodMissingScalarAnnotation @@ -77,10 +87,12 @@ public static final class MethodMissingScalarAnnotation public static void bad() {} } - @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "Method .* has wrapper return type Long but is missing @SqlNullable") + @Test public void testPrimitiveWrapperReturnWithoutNullable() { - extractScalars(PrimitiveWrapperReturnWithoutNullable.class); + assertThatThrownBy(() -> extractScalars(PrimitiveWrapperReturnWithoutNullable.class)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageMatching("Method .* has wrapper return type Long but is missing @SqlNullable"); } public static final class PrimitiveWrapperReturnWithoutNullable @@ -93,10 +105,12 @@ public static Long bad() } } - @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "Method .* annotated with @SqlNullable has primitive return type long") + @Test public void testPrimitiveReturnWithNullable() { - extractScalars(PrimitiveReturnWithNullable.class); + assertThatThrownBy(() -> extractScalars(PrimitiveReturnWithNullable.class)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageMatching("Method .* annotated with @SqlNullable has primitive return type long"); } public static final class PrimitiveReturnWithNullable @@ 
-110,10 +124,12 @@ public static long bad() } } - @Test(expectedExceptions = TrinoException.class, expectedExceptionsMessageRegExp = "A parameter with USE_NULL_FLAG or RETURN_NULL_ON_NULL convention must not use wrapper type. Found in method .*") + @Test public void testPrimitiveWrapperParameterWithoutNullable() { - extractScalars(PrimitiveWrapperParameterWithoutNullable.class); + assertThatThrownBy(() -> extractScalars(PrimitiveWrapperParameterWithoutNullable.class)) + .isInstanceOf(TrinoException.class) + .hasMessageMatching("A parameter with USE_NULL_FLAG or RETURN_NULL_ON_NULL convention must not use wrapper type. Found in method .*"); } public static final class PrimitiveWrapperParameterWithoutNullable @@ -126,10 +142,12 @@ public static long bad(@SqlType(StandardTypes.BOOLEAN) Boolean boxed) } } - @Test(expectedExceptions = TrinoException.class, expectedExceptionsMessageRegExp = "Method .* has parameter with primitive type double annotated with @SqlNullable") + @Test public void testPrimitiveParameterWithNullable() { - extractScalars(PrimitiveParameterWithNullable.class); + assertThatThrownBy(() -> extractScalars(PrimitiveParameterWithNullable.class)) + .isInstanceOf(TrinoException.class) + .hasMessageMatching("Method .* has parameter with primitive type double annotated with @SqlNullable"); } public static final class PrimitiveParameterWithNullable @@ -142,10 +160,12 @@ public static long bad(@SqlNullable @SqlType(StandardTypes.DOUBLE) double primit } } - @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "Method .* is missing @SqlType annotation for parameter") + @Test public void testParameterWithoutType() { - extractScalars(ParameterWithoutType.class); + assertThatThrownBy(() -> extractScalars(ParameterWithoutType.class)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageMatching("Method .* is missing @SqlType annotation for parameter"); } public static final class ParameterWithoutType @@ -158,10 
+178,12 @@ public static long bad(long missing) } } - @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "Method .* annotated with @ScalarFunction must be public") + @Test public void testNonPublicAnnnotatedMethod() { - extractScalars(NonPublicAnnnotatedMethod.class); + assertThatThrownBy(() -> extractScalars(NonPublicAnnnotatedMethod.class)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageMatching("Method .* annotated with @ScalarFunction must be public"); } public static final class NonPublicAnnnotatedMethod @@ -174,10 +196,12 @@ private static long bad() } } - @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "Method .* is annotated with @Nullable but not @SqlNullable") + @Test public void testMethodWithLegacyNullable() { - extractScalars(MethodWithLegacyNullable.class); + assertThatThrownBy(() -> extractScalars(MethodWithLegacyNullable.class)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageMatching("Method .* is annotated with @Nullable but not @SqlNullable"); } public static final class MethodWithLegacyNullable @@ -191,10 +215,12 @@ public static Long bad() } } - @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "Method .* has @IsNull parameter that does not follow a @SqlType parameter") + @Test public void testParameterWithConnectorAndIsNull() { - extractScalars(ParameterWithConnectorAndIsNull.class); + assertThatThrownBy(() -> extractScalars(ParameterWithConnectorAndIsNull.class)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageMatching("Method .* has @IsNull parameter that does not follow a @SqlType parameter"); } public static final class ParameterWithConnectorAndIsNull @@ -207,10 +233,12 @@ public static long bad(ConnectorSession session, @IsNull boolean isNull) } } - @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "Method .* has @IsNull 
parameter that does not follow a @SqlType parameter") + @Test public void testParameterWithOnlyIsNull() { - extractScalars(ParameterWithOnlyIsNull.class); + assertThatThrownBy(() -> extractScalars(ParameterWithOnlyIsNull.class)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageMatching("Method .* has @IsNull parameter that does not follow a @SqlType parameter"); } public static final class ParameterWithOnlyIsNull @@ -223,10 +251,12 @@ public static long bad(@IsNull boolean isNull) } } - @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "Method .* has non-boolean parameter with @IsNull") + @Test public void testParameterWithNonBooleanIsNull() { - extractScalars(ParameterWithNonBooleanIsNull.class); + assertThatThrownBy(() -> extractScalars(ParameterWithNonBooleanIsNull.class)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageMatching("Method .* has non-boolean parameter with @IsNull"); } public static final class ParameterWithNonBooleanIsNull @@ -239,10 +269,12 @@ public static long bad(@SqlType(StandardTypes.BIGINT) long value, @IsNull int is } } - @Test(expectedExceptions = TrinoException.class, expectedExceptionsMessageRegExp = "A parameter with USE_NULL_FLAG or RETURN_NULL_ON_NULL convention must not use wrapper type. Found in method .*") + @Test public void testParameterWithBoxedPrimitiveIsNull() { - extractScalars(ParameterWithBoxedPrimitiveIsNull.class); + assertThatThrownBy(() -> extractScalars(ParameterWithBoxedPrimitiveIsNull.class)) + .isInstanceOf(TrinoException.class) + .hasMessageMatching("A parameter with USE_NULL_FLAG or RETURN_NULL_ON_NULL convention must not use wrapper type. 
Found in method .*"); } public static final class ParameterWithBoxedPrimitiveIsNull @@ -255,10 +287,12 @@ public static long bad(@SqlType(StandardTypes.BIGINT) Long value, @IsNull boolea } } - @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "Method .* has @IsNull parameter that has other annotations") + @Test public void testParameterWithOtherAnnotationsWithIsNull() { - extractScalars(ParameterWithOtherAnnotationsWithIsNull.class); + assertThatThrownBy(() -> extractScalars(ParameterWithOtherAnnotationsWithIsNull.class)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageMatching("Method .* has @IsNull parameter that has other annotations"); } public static final class ParameterWithOtherAnnotationsWithIsNull @@ -271,10 +305,12 @@ public static long bad(@SqlType(StandardTypes.BIGINT) long value, @IsNull @SqlNu } } - @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "Expected type parameter to only contain A-Z and 0-9 \\(starting with A-Z\\), but got bad on method .*") + @Test public void testNonUpperCaseTypeParameters() { - extractScalars(TypeParameterWithNonUpperCaseAnnotation.class); + assertThatThrownBy(() -> extractScalars(TypeParameterWithNonUpperCaseAnnotation.class)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageMatching("Expected type parameter to only contain A-Z and 0-9 \\(starting with A-Z\\), but got bad on method .*"); } public static final class TypeParameterWithNonUpperCaseAnnotation @@ -288,10 +324,12 @@ public static long bad(@TypeParameter("array(bad)") Type type, @SqlType(Standard } } - @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "Expected type parameter to only contain A-Z and 0-9 \\(starting with A-Z\\), but got 1E on method .*") + @Test public void testLeadingNumericTypeParameters() { - extractScalars(TypeParameterWithLeadingNumbers.class); + assertThatThrownBy(() -> 
extractScalars(TypeParameterWithLeadingNumbers.class)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageMatching("Expected type parameter to only contain A-Z and 0-9 \\(starting with A-Z\\), but got 1E on method .*"); } public static final class TypeParameterWithLeadingNumbers @@ -305,10 +343,12 @@ public static long bad(@TypeParameter("array(1E)") Type type, @SqlType(StandardT } } - @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "Expected type parameter not to take parameters, but got 'e' on method .*") + @Test public void testNonPrimitiveTypeParameters() { - extractScalars(TypeParameterWithNonPrimitiveAnnotation.class); + assertThatThrownBy(() -> extractScalars(TypeParameterWithNonPrimitiveAnnotation.class)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageMatching("Expected type parameter not to take parameters, but got 'e' on method .*"); } public static final class TypeParameterWithNonPrimitiveAnnotation @@ -357,10 +397,12 @@ public void testValidTypeParametersForConstructors() extractParametricScalar(ConstructorWithValidTypeParameters.class); } - @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "Expected type parameter not to take parameters, but got 'k' on method .*") + @Test public void testInvalidTypeParametersForConstructors() { - extractParametricScalar(ConstructorWithInvalidTypeParameters.class); + assertThatThrownBy(() -> extractParametricScalar(ConstructorWithInvalidTypeParameters.class)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageMatching("Expected type parameter not to take parameters, but got 'k' on method .*"); } private static void extractParametricScalar(Class clazz) diff --git a/core/trino-main/src/test/java/io/trino/server/security/TestResourceSecurity.java b/core/trino-main/src/test/java/io/trino/server/security/TestResourceSecurity.java index c1d7139e5e02..e46f87352446 100644 --- 
a/core/trino-main/src/test/java/io/trino/server/security/TestResourceSecurity.java +++ b/core/trino-main/src/test/java/io/trino/server/security/TestResourceSecurity.java @@ -58,9 +58,10 @@ import okhttp3.OkHttpClient; import okhttp3.Request; import okhttp3.Response; -import org.testng.annotations.BeforeClass; -import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.parallel.Execution; import javax.crypto.SecretKey; @@ -122,7 +123,11 @@ import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MINUTES; import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; +import static org.junit.jupiter.api.parallel.ExecutionMode.CONCURRENT; +@TestInstance(PER_CLASS) +@Execution(CONCURRENT) public class TestResourceSecurity { private static final String LOCALHOST_KEYSTORE = Resources.getResource("cert/localhost.pem").getPath(); @@ -164,7 +169,7 @@ public class TestResourceSecurity private OkHttpClient client; private Path passwordConfigDummy; - @BeforeClass + @BeforeAll public void setup() throws IOException { @@ -784,8 +789,16 @@ public HttpCookie getNonceCookie() } } - @Test(dataProvider = "groups") - public void testOAuth2Groups(Optional> groups) + @Test + public void testOAuth2Groups() + throws Exception + { + testOAuth2Groups(Optional.empty()); + testOAuth2Groups(Optional.of(ImmutableSet.of())); + testOAuth2Groups(Optional.of(ImmutableSet.of("admin", "public"))); + } + + private void testOAuth2Groups(Optional> groups) throws Exception { try (TokenServer tokenServer = new TokenServer(Optional.empty()); @@ -855,18 +868,15 @@ public List loadForRequest(HttpUrl url) } } - @DataProvider(name = "groups") - public static Object[][] groups() + @Test + public void testJwtAndOAuth2AuthenticatorsSeparation() 
+ throws Exception { - return new Object[][] { - {Optional.empty()}, - {Optional.of(ImmutableSet.of())}, - {Optional.of(ImmutableSet.of("admin", "public"))} - }; + testJwtAndOAuth2AuthenticatorsSeparation("jwt,oauth2"); + testJwtAndOAuth2AuthenticatorsSeparation("oauth2,jwt"); } - @Test(dataProvider = "authenticators") - public void testJwtAndOAuth2AuthenticatorsSeparation(String authenticators) + private void testJwtAndOAuth2AuthenticatorsSeparation(String authenticators) throws Exception { TestingHttpServer jwkServer = createTestingJwkServer(); @@ -914,15 +924,6 @@ public void testJwtAndOAuth2AuthenticatorsSeparation(String authenticators) } } - @DataProvider(name = "authenticators") - public static Object[][] authenticators() - { - return new Object[][] { - {"jwt,oauth2"}, - {"oauth2,jwt"} - }; - } - @Test public void testJwtWithRefreshTokensForOAuth2Enabled() throws Exception diff --git a/core/trino-main/src/test/java/io/trino/server/security/oauth2/TestJweTokenSerializer.java b/core/trino-main/src/test/java/io/trino/server/security/oauth2/TestJweTokenSerializer.java index adfeb76d2c9a..918cd48e784e 100644 --- a/core/trino-main/src/test/java/io/trino/server/security/oauth2/TestJweTokenSerializer.java +++ b/core/trino-main/src/test/java/io/trino/server/security/oauth2/TestJweTokenSerializer.java @@ -18,8 +18,7 @@ import io.jsonwebtoken.ExpiredJwtException; import io.jsonwebtoken.Jwts; import io.trino.server.security.oauth2.TokenPairSerializer.TokenPair; -import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import java.net.URI; import java.security.GeneralSecurityException; @@ -60,24 +59,29 @@ public void testSerialization() assertThat(deserializedTokenPair.refreshToken()).isEqualTo(Optional.of("refresh_token")); } - @Test(dataProvider = "wrongSecretsProvider") - public void testDeserializationWithWrongSecret(String encryptionSecret, String decryptionSecret) + @Test + public void 
testDeserializationWithWrongSecret() { - assertThatThrownBy(() -> assertRoundTrip(Optional.ofNullable(encryptionSecret), Optional.ofNullable(decryptionSecret))) + assertThatThrownBy(() -> assertRoundTrip(Optional.of(randomEncodedSecret()), Optional.of(randomEncodedSecret()))) .isInstanceOf(RuntimeException.class) .hasMessageContaining("decryption failed: Tag mismatch"); - } - @DataProvider - public Object[][] wrongSecretsProvider() - { - return new Object[][]{ - {randomEncodedSecret(), randomEncodedSecret()}, - {randomEncodedSecret(16), randomEncodedSecret(24)}, - {null, null}, // This will generate two different secret keys - {null, randomEncodedSecret()}, - {randomEncodedSecret(), null} - }; + assertThatThrownBy(() -> assertRoundTrip(Optional.of(randomEncodedSecret(16)), Optional.of(randomEncodedSecret(24)))) + .isInstanceOf(RuntimeException.class) + .hasMessageContaining("decryption failed: Tag mismatch"); + + // This will generate two different secret keys + assertThatThrownBy(() -> assertRoundTrip(Optional.empty(), Optional.empty())) + .isInstanceOf(RuntimeException.class) + .hasMessageContaining("decryption failed: Tag mismatch"); + + assertThatThrownBy(() -> assertRoundTrip(Optional.empty(), Optional.of(randomEncodedSecret()))) + .isInstanceOf(RuntimeException.class) + .hasMessageContaining("decryption failed: Tag mismatch"); + + assertThatThrownBy(() -> assertRoundTrip(Optional.of(randomEncodedSecret()), Optional.empty())) + .isInstanceOf(RuntimeException.class) + .hasMessageContaining("decryption failed: Tag mismatch"); } @Test diff --git a/core/trino-main/src/test/java/io/trino/server/security/oauth2/TestOidcDiscovery.java b/core/trino-main/src/test/java/io/trino/server/security/oauth2/TestOidcDiscovery.java index bac63b99be72..ee2871dcf284 100644 --- a/core/trino-main/src/test/java/io/trino/server/security/oauth2/TestOidcDiscovery.java +++ b/core/trino-main/src/test/java/io/trino/server/security/oauth2/TestOidcDiscovery.java @@ -29,8 +29,7 @@ import 
jakarta.servlet.http.HttpServlet; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; -import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import java.io.IOException; import java.net.URI; @@ -49,8 +48,15 @@ public class TestOidcDiscovery { - @Test(dataProvider = "staticConfiguration") - public void testStaticConfiguration(Optional accessTokenPath, Optional userinfoPath) + @Test + public void testStaticConfiguration() + throws Exception + { + testStaticConfiguration(Optional.empty(), Optional.empty()); + testStaticConfiguration(Optional.of("/access-token-issuer"), Optional.of("/userinfo")); + } + + private void testStaticConfiguration(Optional accessTokenPath, Optional userinfoPath) throws Exception { try (MetadataServer metadataServer = new MetadataServer(ImmutableMap.of("/jwks.json", "jwk/jwk-public.json"))) { @@ -72,17 +78,16 @@ public void testStaticConfiguration(Optional accessTokenPath, Optional accessTokenIssuer, Optional userinfoUrl) + private void testOidcDiscovery(String configuration, Optional accessTokenIssuer, Optional userinfoUrl) throws Exception { try (MetadataServer metadataServer = new MetadataServer( @@ -100,16 +105,6 @@ public void testOidcDiscovery(String configuration, Optional accessToken } } - @DataProvider(name = "oidcDiscovery") - public static Object[][] oidcDiscovery() - { - return new Object[][] { - {"openid-configuration.json", Optional.empty(), Optional.of("/connect/userinfo")}, - {"openid-configuration-without-userinfo.json", Optional.empty(), Optional.empty()}, - {"openid-configuration-with-access-token-issuer.json", Optional.of("http://access-token-issuer.com/adfs/services/trust"), Optional.of("/connect/userinfo")}, - }; - } - @Test public void testIssuerCheck() { diff --git a/core/trino-main/src/test/java/io/trino/sql/ExpressionTestUtils.java b/core/trino-main/src/test/java/io/trino/sql/ExpressionTestUtils.java index 
f289a9e287dc..008dd7f0e21e 100644 --- a/core/trino-main/src/test/java/io/trino/sql/ExpressionTestUtils.java +++ b/core/trino-main/src/test/java/io/trino/sql/ExpressionTestUtils.java @@ -44,9 +44,6 @@ import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType; import static io.trino.sql.planner.TypeAnalyzer.createTestingTypeAnalyzer; import static io.trino.transaction.TransactionBuilder.transaction; -import static org.testng.internal.EclipseInterface.ASSERT_LEFT; -import static org.testng.internal.EclipseInterface.ASSERT_MIDDLE; -import static org.testng.internal.EclipseInterface.ASSERT_RIGHT; public final class ExpressionTestUtils { @@ -73,7 +70,7 @@ private static void failNotEqual(Object actual, Object expected, String message) if (message != null) { formatted = message + " "; } - throw new AssertionError(formatted + ASSERT_LEFT + expected + ASSERT_MIDDLE + actual + ASSERT_RIGHT); + throw new AssertionError(formatted + " expected [" + expected + "] but found [" + actual + "]"); } public static Expression createExpression(Session session, String expression, TransactionManager transactionManager, PlannerContext plannerContext, TypeProvider symbolTypes) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestEffectivePredicateExtractor.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestEffectivePredicateExtractor.java index 32a3e4bc0321..d46dcfd2abd7 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestEffectivePredicateExtractor.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestEffectivePredicateExtractor.java @@ -76,8 +76,10 @@ import io.trino.testing.TestingSession; import io.trino.testing.TestingTransactionHandle; import io.trino.transaction.TestingTransactionManager; -import org.testng.annotations.BeforeMethod; -import org.testng.annotations.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import 
org.junit.jupiter.api.parallel.Execution; import java.util.Arrays; import java.util.Collection; @@ -115,8 +117,11 @@ import static io.trino.transaction.TransactionBuilder.transaction; import static io.trino.type.UnknownType.UNKNOWN; import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_METHOD; +import static org.junit.jupiter.api.parallel.ExecutionMode.SAME_THREAD; -@Test(singleThreaded = true) +@TestInstance(PER_METHOD) +@Execution(SAME_THREAD) public class TestEffectivePredicateExtractor { private static final Symbol A = new Symbol("a"); @@ -176,7 +181,7 @@ public TableProperties getTableProperties(Session session, TableHandle handle) private TableScanNode baseTableScan; private ExpressionIdentityNormalizer expressionNormalizer; - @BeforeMethod + @BeforeEach public void setUp() { scanAssignments = ImmutableMap.builder() diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/planprinter/TestCounterBasedAnonymizer.java b/core/trino-main/src/test/java/io/trino/sql/planner/planprinter/TestCounterBasedAnonymizer.java index d343eceee623..6380bffab53e 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/planprinter/TestCounterBasedAnonymizer.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/planprinter/TestCounterBasedAnonymizer.java @@ -22,7 +22,6 @@ import io.trino.sql.tree.DoubleLiteral; import io.trino.sql.tree.GenericLiteral; import io.trino.sql.tree.IntervalLiteral; -import io.trino.sql.tree.Literal; import io.trino.sql.tree.LogicalExpression; import io.trino.sql.tree.LongLiteral; import io.trino.sql.tree.NullLiteral; @@ -30,8 +29,7 @@ import io.trino.sql.tree.SymbolReference; import io.trino.sql.tree.TimeLiteral; import io.trino.sql.tree.TimestampLiteral; -import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import java.util.Optional; @@ -64,34 +62,51 @@ public void testSymbolReferenceAnonymization() 
.isEqualTo("((\"symbol_1\" > 'long_literal_1') AND (\"symbol_2\" < 'long_literal_2') AND (\"symbol_3\" = 'long_literal_3'))"); } - @Test(dataProvider = "literals") - public void testLiteralAnonymization(Literal actual, String expected) + @Test + public void testLiteralAnonymization() { CounterBasedAnonymizer anonymizer = new CounterBasedAnonymizer(); - assertThat(anonymizer.anonymize(actual)).isEqualTo(expected); - } - @DataProvider - public static Object[][] literals() - { - return new Object[][] { - {new BinaryLiteral("DEF321"), "'binary_literal_1'"}, - {new StringLiteral("abc"), "'string_literal_1'"}, - {new GenericLiteral("bigint", "1"), "'bigint_literal_1'"}, - {new CharLiteral("a"), "'char_literal_1'"}, - {new DecimalLiteral("123"), "'decimal_literal_1'"}, - {new DoubleLiteral(String.valueOf(6554)), "'double_literal_1'"}, - {new DoubleLiteral(String.valueOf(Double.MAX_VALUE)), "'double_literal_1'"}, - {new LongLiteral(String.valueOf(6554)), "'long_literal_1'"}, - {new LongLiteral(String.valueOf(Long.MAX_VALUE)), "'long_literal_1'"}, - {new BooleanLiteral("true"), "true"}, - {new TimeLiteral("03:04:05"), "'time_literal_1'"}, - {new TimestampLiteral("2012-10-31 01:00 UTC"), "'timestamp_literal_1'"}, - {new NullLiteral(), "null"}, - { - new IntervalLiteral("33", IntervalLiteral.Sign.POSITIVE, IntervalLiteral.IntervalField.DAY, Optional.empty()), - "'interval_literal_1'" - } - }; + assertThat(anonymizer.anonymize(new BinaryLiteral("DEF321"))) + .isEqualTo("'binary_literal_1'"); + + assertThat(anonymizer.anonymize(new StringLiteral("abc"))) + .isEqualTo("'string_literal_2'"); + + assertThat(anonymizer.anonymize(new GenericLiteral("bigint", "1"))) + .isEqualTo("'bigint_literal_3'"); + + assertThat(anonymizer.anonymize(new CharLiteral("a"))) + .isEqualTo("'char_literal_4'"); + + assertThat(anonymizer.anonymize(new DecimalLiteral("123"))) + .isEqualTo("'decimal_literal_5'"); + + assertThat(anonymizer.anonymize(new DoubleLiteral(String.valueOf(6554)))) + 
.isEqualTo("'double_literal_6'"); + + assertThat(anonymizer.anonymize(new DoubleLiteral(String.valueOf(Double.MAX_VALUE)))) + .isEqualTo("'double_literal_7'"); + + assertThat(anonymizer.anonymize(new LongLiteral(String.valueOf(6554)))) + .isEqualTo("'long_literal_8'"); + + assertThat(anonymizer.anonymize(new LongLiteral(String.valueOf(Long.MAX_VALUE)))) + .isEqualTo("'long_literal_9'"); + + assertThat(anonymizer.anonymize(new BooleanLiteral("true"))) + .isEqualTo("true"); + + assertThat(anonymizer.anonymize(new TimeLiteral("03:04:05"))) + .isEqualTo("'time_literal_10'"); + + assertThat(anonymizer.anonymize(new TimestampLiteral("2012-10-31 01:00 UTC"))) + .isEqualTo("'timestamp_literal_11'"); + + assertThat(anonymizer.anonymize(new NullLiteral())) + .isEqualTo("null"); + + assertThat(anonymizer.anonymize(new IntervalLiteral("33", IntervalLiteral.Sign.POSITIVE, IntervalLiteral.IntervalField.DAY, Optional.empty()))) + .isEqualTo("'interval_literal_12'"); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/sanity/TestValidateScaledWritersUsage.java b/core/trino-main/src/test/java/io/trino/sql/planner/sanity/TestValidateScaledWritersUsage.java index b980fe60c51e..0012f7cd6218 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/sanity/TestValidateScaledWritersUsage.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/sanity/TestValidateScaledWritersUsage.java @@ -39,10 +39,11 @@ import io.trino.sql.planner.plan.TableScanNode; import io.trino.testing.LocalQueryRunner; import io.trino.testing.TestingTransactionHandle; -import org.testng.annotations.AfterClass; -import org.testng.annotations.BeforeClass; -import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.parallel.Execution; import java.util.Optional; @@ -55,10 +56,20 @@ 
import static io.trino.testing.TestingHandles.TEST_CATALOG_HANDLE; import static io.trino.testing.TestingHandles.createTestCatalogHandle; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; +import static org.junit.jupiter.api.parallel.ExecutionMode.CONCURRENT; +@TestInstance(PER_CLASS) +@Execution(CONCURRENT) public class TestValidateScaledWritersUsage extends BasePlanTest { + private static final PartitioningHandle CUSTOM_HANDLE = new PartitioningHandle( + Optional.of(TEST_CATALOG_HANDLE), + Optional.of(new ConnectorTransactionHandle() { }), + new ConnectorPartitioningHandle() { }, + true); + private LocalQueryRunner queryRunner; private PlannerContext plannerContext; private PlanBuilder planBuilder; @@ -67,7 +78,7 @@ public class TestValidateScaledWritersUsage private CatalogHandle catalog; private SchemaTableName schemaTableName; - @BeforeClass + @BeforeAll public void setup() { schemaTableName = new SchemaTableName("any", "any"); @@ -85,7 +96,7 @@ public void setup() tableScanNode = planBuilder.tableScan(nationTableHandle, ImmutableList.of(symbol), ImmutableMap.of(symbol, nationkeyColumnHandle)); } - @AfterClass(alwaysRun = true) + @AfterAll public void tearDown() { queryRunner.close(); @@ -104,8 +115,15 @@ private MockConnectorFactory createConnectorFactory(String name) .build(); } - @Test(dataProvider = "scaledWriterPartitioningHandles") - public void testScaledWritersUsedAndTargetSupportsIt(PartitioningHandle scaledWriterPartitionHandle) + @Test + public void testScaledWritersUsedAndTargetSupportsIt() + { + testScaledWritersUsedAndTargetSupportsIt(SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION); + testScaledWritersUsedAndTargetSupportsIt(SCALED_WRITER_HASH_DISTRIBUTION); + testScaledWritersUsedAndTargetSupportsIt(CUSTOM_HANDLE); + } + + private void testScaledWritersUsedAndTargetSupportsIt(PartitioningHandle scaledWriterPartitionHandle) { PlanNode tableWriterSource = 
planBuilder.exchange(ex -> ex @@ -125,8 +143,15 @@ public void testScaledWritersUsedAndTargetSupportsIt(PartitioningHandle scaledWr validatePlan(root); } - @Test(dataProvider = "scaledWriterPartitioningHandles") - public void testScaledWritersUsedAndTargetDoesNotSupportScalingPerTask(PartitioningHandle scaledWriterPartitionHandle) + @Test + public void testScaledWritersUsedAndTargetDoesNotSupportScalingPerTask() + { + testScaledWritersUsedAndTargetDoesNotSupportScalingPerTask(SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION); + testScaledWritersUsedAndTargetDoesNotSupportScalingPerTask(SCALED_WRITER_HASH_DISTRIBUTION); + testScaledWritersUsedAndTargetDoesNotSupportScalingPerTask(CUSTOM_HANDLE); + } + + private void testScaledWritersUsedAndTargetDoesNotSupportScalingPerTask(PartitioningHandle scaledWriterPartitionHandle) { PlanNode tableWriterSource = planBuilder.exchange(ex -> ex @@ -149,8 +174,15 @@ public void testScaledWritersUsedAndTargetDoesNotSupportScalingPerTask(Partition .hasMessage("The scaled writer per task partitioning scheme is set but writer target catalog:INSTANCE doesn't support it"); } - @Test(dataProvider = "scaledWriterPartitioningHandles") - public void testScaledWritersUsedAndTargetDoesNotSupportScalingAcrossTasks(PartitioningHandle scaledWriterPartitionHandle) + @Test + public void testScaledWritersUsedAndTargetDoesNotSupportScalingAcrossTasks() + { + testScaledWritersUsedAndTargetDoesNotSupportScalingAcrossTasks(SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION); + testScaledWritersUsedAndTargetDoesNotSupportScalingAcrossTasks(SCALED_WRITER_HASH_DISTRIBUTION); + testScaledWritersUsedAndTargetDoesNotSupportScalingAcrossTasks(CUSTOM_HANDLE); + } + + private void testScaledWritersUsedAndTargetDoesNotSupportScalingAcrossTasks(PartitioningHandle scaledWriterPartitionHandle) { PlanNode tableWriterSource = planBuilder.exchange(ex -> ex @@ -173,8 +205,15 @@ public void testScaledWritersUsedAndTargetDoesNotSupportScalingAcrossTasks(Parti .hasMessage("The scaled writer 
across tasks partitioning scheme is set but writer target catalog:INSTANCE doesn't support it"); } - @Test(dataProvider = "scaledWriterPartitioningHandles") - public void testScaledWriterUsedAndTargetDoesNotSupportMultipleWritersPerPartition(PartitioningHandle scaledWriterPartitionHandle) + @Test + public void testScaledWriterUsedAndTargetDoesNotSupportMultipleWritersPerPartition() + { + testScaledWriterUsedAndTargetDoesNotSupportMultipleWritersPerPartition(SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION); + testScaledWriterUsedAndTargetDoesNotSupportMultipleWritersPerPartition(SCALED_WRITER_HASH_DISTRIBUTION); + testScaledWriterUsedAndTargetDoesNotSupportMultipleWritersPerPartition(CUSTOM_HANDLE); + } + + private void testScaledWriterUsedAndTargetDoesNotSupportMultipleWritersPerPartition(PartitioningHandle scaledWriterPartitionHandle) { PlanNode tableWriterSource = planBuilder.exchange(ex -> ex @@ -202,8 +241,15 @@ public void testScaledWriterUsedAndTargetDoesNotSupportMultipleWritersPerPartiti } } - @Test(dataProvider = "scaledWriterPartitioningHandles") - public void testScaledWriterWithMultipleSourceExchangesAndTargetDoesNotSupportMultipleWritersPerPartition(PartitioningHandle scaledWriterPartitionHandle) + @Test + public void testScaledWriterWithMultipleSourceExchangesAndTargetDoesNotSupportMultipleWritersPerPartition() + { + testScaledWriterWithMultipleSourceExchangesAndTargetDoesNotSupportMultipleWritersPerPartition(SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION); + testScaledWriterWithMultipleSourceExchangesAndTargetDoesNotSupportMultipleWritersPerPartition(SCALED_WRITER_HASH_DISTRIBUTION); + testScaledWriterWithMultipleSourceExchangesAndTargetDoesNotSupportMultipleWritersPerPartition(CUSTOM_HANDLE); + } + + private void testScaledWriterWithMultipleSourceExchangesAndTargetDoesNotSupportMultipleWritersPerPartition(PartitioningHandle scaledWriterPartitionHandle) { PlanNode tableWriterSource = planBuilder.exchange(ex -> ex @@ -237,20 +283,6 @@ public void 
testScaledWriterWithMultipleSourceExchangesAndTargetDoesNotSupportMu } } - @DataProvider - public Object[][] scaledWriterPartitioningHandles() - { - return new Object[][] { - {SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION}, - {SCALED_WRITER_HASH_DISTRIBUTION}, - {new PartitioningHandle( - Optional.of(TEST_CATALOG_HANDLE), - Optional.of(new ConnectorTransactionHandle() {}), - new ConnectorPartitioningHandle() {}, - true)} - }; - } - private void validatePlan(PlanNode root) { queryRunner.inTransaction(session -> { diff --git a/core/trino-main/src/test/java/io/trino/util/TestLongLong2LongOpenCustomBigHashMap.java b/core/trino-main/src/test/java/io/trino/util/TestLongLong2LongOpenCustomBigHashMap.java index e51ccfa23b87..b4a1b0164dba 100644 --- a/core/trino-main/src/test/java/io/trino/util/TestLongLong2LongOpenCustomBigHashMap.java +++ b/core/trino-main/src/test/java/io/trino/util/TestLongLong2LongOpenCustomBigHashMap.java @@ -13,8 +13,7 @@ */ package io.trino.util; -import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; import java.util.Arrays; import java.util.List; @@ -38,14 +37,16 @@ public boolean equals(long a1, long a2, long b1, long b2) } }; - @DataProvider - public static Object[][] nullKeyValues() + @Test + public void testBasicOps() { - return new Object[][] {{0L, 0L}, {1L, 1L}, {-1L, -1L}, {0L, -1L}}; + testBasicOps(0L, 0L); + testBasicOps(1L, 1L); + testBasicOps(-1L, -1L); + testBasicOps(0L, -1L); } - @Test(dataProvider = "nullKeyValues") - public void testBasicOps(long nullKey1, long nullKey2) + private void testBasicOps(long nullKey1, long nullKey2) { int expected = 100_000; LongLong2LongOpenCustomBigHashMap map = new LongLong2LongOpenCustomBigHashMap(expected, DEFAULT_STRATEGY, nullKey1, nullKey2); @@ -101,8 +102,16 @@ public void testBasicOps(long nullKey1, long nullKey2) } } - @Test(dataProvider = "nullKeyValues") - public void testHashCollision(long nullKey1, long nullKey2) + @Test + public void 
testHashCollision() + { + testHashCollision(0L, 0L); + testHashCollision(1L, 1L); + testHashCollision(-1L, -1L); + testHashCollision(0L, -1L); + } + + private void testHashCollision(long nullKey1, long nullKey2) { LongLong2LongOpenCustomBigHashMap.HashStrategy collisionHashStrategy = new LongLong2LongOpenCustomBigHashMap.HashStrategy() { @@ -168,8 +177,16 @@ public boolean equals(long a1, long a2, long b1, long b2) } } - @Test(dataProvider = "nullKeyValues") - public void testRehash(long nullKey1, long nullKey2) + @Test + public void testRehash() + { + testRehash(0L, 0L); + testRehash(1L, 1L); + testRehash(-1L, -1L); + testRehash(0L, -1L); + } + + private void testRehash(long nullKey1, long nullKey2) { int initialCapacity = 1; LongLong2LongOpenCustomBigHashMap map = new LongLong2LongOpenCustomBigHashMap(initialCapacity, DEFAULT_STRATEGY, nullKey1, nullKey2); diff --git a/testing/trino-faulttolerant-tests/pom.xml b/testing/trino-faulttolerant-tests/pom.xml index 266fb41468e8..b0135bbd5c0b 100644 --- a/testing/trino-faulttolerant-tests/pom.xml +++ b/testing/trino-faulttolerant-tests/pom.xml @@ -13,16 +13,6 @@ ${project.parent.basedir} - - - instances @@ -434,33 +424,10 @@ testcontainers test - - - org.testng - testng - test - - - org.apache.maven.plugins - maven-surefire-plugin - - - - org.apache.maven.surefire - surefire-junit-platform - ${dep.plugin.surefire.version} - - - org.apache.maven.surefire - surefire-testng - ${dep.plugin.surefire.version} - - - org.basepom.maven duplicate-finder-maven-plugin diff --git a/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/TestFaultTolerantExecutionDynamicFiltering.java b/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/TestFaultTolerantExecutionDynamicFiltering.java index 9247c1ce3c60..facbec507a3c 100644 --- a/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/TestFaultTolerantExecutionDynamicFiltering.java +++ 
b/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/TestFaultTolerantExecutionDynamicFiltering.java @@ -25,7 +25,9 @@ import io.trino.testing.FaultTolerantExecutionConnectorTestHelper; import io.trino.testing.QueryRunner; import io.trino.testing.TestingMetadata; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.parallel.Execution; import java.util.Set; @@ -34,8 +36,9 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.sql.planner.OptimizerConfig.JoinDistributionType; import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.parallel.ExecutionMode.SAME_THREAD; -@Test(singleThreaded = true) +@Execution(SAME_THREAD) public class TestFaultTolerantExecutionDynamicFiltering extends AbstractTestCoordinatorDynamicFiltering { @@ -71,8 +74,7 @@ protected RetryPolicy getRetryPolicy() // results in each instance of DynamicFilterSourceOperator receiving fewer input rows. Therefore, testing max-distinct-values-per-driver // requires larger build side and the assertions on the collected domain are adjusted for multiple ranges instead of single range. 
@Override - @Test(timeOut = 30_000, dataProvider = "testJoinDistributionType") - public void testSemiJoinWithNonSelectiveBuildSide(JoinDistributionType joinDistributionType, boolean coordinatorDynamicFiltersDistribution) + protected void testSemiJoinWithNonSelectiveBuildSide(JoinDistributionType joinDistributionType, boolean coordinatorDynamicFiltersDistribution) { assertQueryDynamicFilters( noJoinReordering(joinDistributionType, coordinatorDynamicFiltersDistribution), @@ -90,8 +92,7 @@ public void testSemiJoinWithNonSelectiveBuildSide(JoinDistributionType joinDistr } @Override - @Test(timeOut = 30_000, dataProvider = "testJoinDistributionType") - public void testJoinWithNonSelectiveBuildSide(JoinDistributionType joinDistributionType, boolean coordinatorDynamicFiltersDistribution) + protected void testJoinWithNonSelectiveBuildSide(JoinDistributionType joinDistributionType, boolean coordinatorDynamicFiltersDistribution) { assertQueryDynamicFilters( noJoinReordering(joinDistributionType, coordinatorDynamicFiltersDistribution), @@ -109,7 +110,8 @@ public void testJoinWithNonSelectiveBuildSide(JoinDistributionType joinDistribut } @Override - @Test(timeOut = 30_000) + @Test + @Timeout(30) public void testRightJoinWithNonSelectiveBuildSide() { assertQueryDynamicFilters( diff --git a/testing/trino-tests/src/test/java/io/trino/execution/AbstractTestCoordinatorDynamicFiltering.java b/testing/trino-tests/src/test/java/io/trino/execution/AbstractTestCoordinatorDynamicFiltering.java index 3ceec0e49c32..0c441452f3e5 100644 --- a/testing/trino-tests/src/test/java/io/trino/execution/AbstractTestCoordinatorDynamicFiltering.java +++ b/testing/trino-tests/src/test/java/io/trino/execution/AbstractTestCoordinatorDynamicFiltering.java @@ -48,9 +48,11 @@ import io.trino.testing.TestingPageSinkProvider; import io.trino.testing.TestingTransactionHandle; import org.intellij.lang.annotations.Language; -import org.testng.annotations.BeforeClass; -import org.testng.annotations.DataProvider; 
-import org.testng.annotations.Test; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.parallel.Execution; import java.util.List; import java.util.Map; @@ -84,7 +86,11 @@ import static java.util.Objects.requireNonNull; import static java.util.concurrent.CompletableFuture.completedFuture; import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; +import static org.junit.jupiter.api.parallel.ExecutionMode.CONCURRENT; +@TestInstance(PER_CLASS) +@Execution(CONCURRENT) public abstract class AbstractTestCoordinatorDynamicFiltering extends AbstractTestQueryFramework { @@ -97,7 +103,7 @@ public abstract class AbstractTestCoordinatorDynamicFiltering private volatile Consumer> expectedCoordinatorDynamicFilterAssertion; private volatile Consumer> expectedTableScanDynamicFilterAssertion; - @BeforeClass + @BeforeAll public void setup() { // create lineitem table in test connector @@ -119,8 +125,16 @@ public void setup() protected abstract RetryPolicy getRetryPolicy(); - @Test(timeOut = 30_000, dataProvider = "testJoinDistributionType") - public void testJoinWithEmptyBuildSide(JoinDistributionType joinDistributionType, boolean coordinatorDynamicFiltersDistribution) + @Test + @Timeout(30) + public void testJoinWithEmptyBuildSide() + { + testJoinWithEmptyBuildSide(BROADCAST, true); + testJoinWithEmptyBuildSide(PARTITIONED, true); + testJoinWithEmptyBuildSide(PARTITIONED, false); + } + + private void testJoinWithEmptyBuildSide(JoinDistributionType joinDistributionType, boolean coordinatorDynamicFiltersDistribution) { assertQueryDynamicFilters( noJoinReordering(joinDistributionType, coordinatorDynamicFiltersDistribution), @@ -129,8 +143,16 @@ public void testJoinWithEmptyBuildSide(JoinDistributionType joinDistributionType TupleDomain.none()); } - @Test(timeOut = 30_000, 
dataProvider = "testJoinDistributionType") - public void testJoinWithLargeBuildSide(JoinDistributionType joinDistributionType, boolean coordinatorDynamicFiltersDistribution) + @Test + @Timeout(30) + public void testJoinWithLargeBuildSide() + { + testJoinWithLargeBuildSide(BROADCAST, true); + testJoinWithLargeBuildSide(PARTITIONED, true); + testJoinWithLargeBuildSide(PARTITIONED, false); + } + + private void testJoinWithLargeBuildSide(JoinDistributionType joinDistributionType, boolean coordinatorDynamicFiltersDistribution) { assertQueryDynamicFilters( noJoinReordering(joinDistributionType, coordinatorDynamicFiltersDistribution), @@ -139,8 +161,16 @@ public void testJoinWithLargeBuildSide(JoinDistributionType joinDistributionType TupleDomain.all()); } - @Test(timeOut = 30_000, dataProvider = "testJoinDistributionType") - public void testMultiColumnJoinWithDifferentCardinalitiesInBuildSide(JoinDistributionType joinDistributionType, boolean coordinatorDynamicFiltersDistribution) + @Test + @Timeout(30) + public void testMultiColumnJoinWithDifferentCardinalitiesInBuildSide() + { + testMultiColumnJoinWithDifferentCardinalitiesInBuildSide(BROADCAST, true); + testMultiColumnJoinWithDifferentCardinalitiesInBuildSide(PARTITIONED, true); + testMultiColumnJoinWithDifferentCardinalitiesInBuildSide(PARTITIONED, false); + } + + private void testMultiColumnJoinWithDifferentCardinalitiesInBuildSide(JoinDistributionType joinDistributionType, boolean coordinatorDynamicFiltersDistribution) { // orderkey has high cardinality, suppkey has low cardinality due to filter assertQueryDynamicFilters( @@ -154,8 +184,16 @@ public void testMultiColumnJoinWithDifferentCardinalitiesInBuildSide(JoinDistrib multipleValues(BIGINT, LongStream.rangeClosed(1L, 10L).boxed().collect(toImmutableList()))))); } - @Test(timeOut = 30_000, dataProvider = "testJoinDistributionType") - public void testJoinWithSelectiveBuildSide(JoinDistributionType joinDistributionType, boolean 
coordinatorDynamicFiltersDistribution) + @Test + @Timeout(30) + public void testJoinWithSelectiveBuildSide() + { + testJoinWithSelectiveBuildSide(BROADCAST, true); + testJoinWithSelectiveBuildSide(PARTITIONED, true); + testJoinWithSelectiveBuildSide(PARTITIONED, false); + } + + private void testJoinWithSelectiveBuildSide(JoinDistributionType joinDistributionType, boolean coordinatorDynamicFiltersDistribution) { assertQueryDynamicFilters( noJoinReordering(joinDistributionType, coordinatorDynamicFiltersDistribution), @@ -166,7 +204,8 @@ public void testJoinWithSelectiveBuildSide(JoinDistributionType joinDistribution singleValue(BIGINT, 1L)))); } - @Test(timeOut = 30_000) + @Test + @Timeout(30) public void testInequalityJoinWithSelectiveBuildSide() { assertQueryDynamicFilters( @@ -195,7 +234,8 @@ public void testInequalityJoinWithSelectiveBuildSide() Domain.create(ValueSet.ofRanges(Range.greaterThan(BIGINT, 1L)), false)))); } - @Test(timeOut = 30_000) + @Test + @Timeout(30) public void testIsNotDistinctFromJoinWithSelectiveBuildSide() { assertQueryDynamicFilters( @@ -218,7 +258,8 @@ public void testIsNotDistinctFromJoinWithSelectiveBuildSide() Domain.onlyNull(BIGINT)))); } - @Test(timeOut = 30_000) + @Test + @Timeout(30) public void testJoinWithImplicitCoercion() { // setup fact table with integer suppkey @@ -246,8 +287,16 @@ public void testJoinWithImplicitCoercion() multipleValues(createVarcharType(40), values)))); } - @Test(timeOut = 30_000, dataProvider = "testJoinDistributionType") - public void testJoinWithNonSelectiveBuildSide(JoinDistributionType joinDistributionType, boolean coordinatorDynamicFiltersDistribution) + @Test + @Timeout(30) + public void testJoinWithNonSelectiveBuildSide() + { + testJoinWithNonSelectiveBuildSide(BROADCAST, true); + testJoinWithNonSelectiveBuildSide(PARTITIONED, true); + testJoinWithNonSelectiveBuildSide(PARTITIONED, false); + } + + protected void testJoinWithNonSelectiveBuildSide(JoinDistributionType joinDistributionType, boolean 
coordinatorDynamicFiltersDistribution) { assertQueryDynamicFilters( noJoinReordering(joinDistributionType, coordinatorDynamicFiltersDistribution), @@ -258,8 +307,16 @@ public void testJoinWithNonSelectiveBuildSide(JoinDistributionType joinDistribut Domain.create(ValueSet.ofRanges(range(BIGINT, 1L, true, 100L, true)), false)))); } - @Test(timeOut = 30_000, dataProvider = "testJoinDistributionType") - public void testJoinWithMultipleDynamicFiltersOnProbe(JoinDistributionType joinDistributionType, boolean coordinatorDynamicFiltersDistribution) + @Test + @Timeout(30) + public void testJoinWithMultipleDynamicFiltersOnProbe() + { + testJoinWithMultipleDynamicFiltersOnProbe(BROADCAST, true); + testJoinWithMultipleDynamicFiltersOnProbe(PARTITIONED, true); + testJoinWithMultipleDynamicFiltersOnProbe(PARTITIONED, false); + } + + private void testJoinWithMultipleDynamicFiltersOnProbe(JoinDistributionType joinDistributionType, boolean coordinatorDynamicFiltersDistribution) { // supplier names Supplier#000000001 and Supplier#000000002 match suppkey 1 and 2 assertQueryDynamicFilters( @@ -274,7 +331,8 @@ public void testJoinWithMultipleDynamicFiltersOnProbe(JoinDistributionType joinD singleValue(BIGINT, 2L)))); } - @Test(timeOut = 30_000) + @Test + @Timeout(30) public void testRightJoinWithEmptyBuildSide() { assertQueryDynamicFilters( @@ -283,7 +341,8 @@ public void testRightJoinWithEmptyBuildSide() TupleDomain.none()); } - @Test(timeOut = 30_000) + @Test + @Timeout(30) public void testRightJoinWithNonSelectiveBuildSide() { assertQueryDynamicFilters( @@ -294,7 +353,8 @@ public void testRightJoinWithNonSelectiveBuildSide() Domain.create(ValueSet.ofRanges(range(BIGINT, 1L, true, 100L, true)), false)))); } - @Test(timeOut = 30_000) + @Test + @Timeout(30) public void testRightJoinWithSelectiveBuildSide() { assertQueryDynamicFilters( @@ -305,8 +365,16 @@ public void testRightJoinWithSelectiveBuildSide() singleValue(BIGINT, 1L)))); } - @Test(timeOut = 30_000, dataProvider = 
"testJoinDistributionType") - public void testSemiJoinWithEmptyBuildSide(JoinDistributionType joinDistributionType, boolean coordinatorDynamicFiltersDistribution) + @Test + @Timeout(30) + public void testSemiJoinWithEmptyBuildSide() + { + testSemiJoinWithEmptyBuildSide(BROADCAST, true); + testSemiJoinWithEmptyBuildSide(PARTITIONED, true); + testSemiJoinWithEmptyBuildSide(PARTITIONED, false); + } + + private void testSemiJoinWithEmptyBuildSide(JoinDistributionType joinDistributionType, boolean coordinatorDynamicFiltersDistribution) { assertQueryDynamicFilters( noJoinReordering(joinDistributionType, coordinatorDynamicFiltersDistribution), @@ -315,8 +383,16 @@ public void testSemiJoinWithEmptyBuildSide(JoinDistributionType joinDistribution TupleDomain.none()); } - @Test(timeOut = 30_000, dataProvider = "testJoinDistributionType") - public void testSemiJoinWithLargeBuildSide(JoinDistributionType joinDistributionType, boolean coordinatorDynamicFiltersDistribution) + @Test + @Timeout(30) + public void testSemiJoinWithLargeBuildSide() + { + testSemiJoinWithLargeBuildSide(BROADCAST, true); + testSemiJoinWithLargeBuildSide(PARTITIONED, true); + testSemiJoinWithLargeBuildSide(PARTITIONED, false); + } + + private void testSemiJoinWithLargeBuildSide(JoinDistributionType joinDistributionType, boolean coordinatorDynamicFiltersDistribution) { assertQueryDynamicFilters( noJoinReordering(joinDistributionType, coordinatorDynamicFiltersDistribution), @@ -325,8 +401,16 @@ public void testSemiJoinWithLargeBuildSide(JoinDistributionType joinDistribution TupleDomain.all()); } - @Test(timeOut = 30_000, dataProvider = "testJoinDistributionType") - public void testSemiJoinWithSelectiveBuildSide(JoinDistributionType joinDistributionType, boolean coordinatorDynamicFiltersDistribution) + @Test + @Timeout(30) + public void testSemiJoinWithSelectiveBuildSide() + { + testSemiJoinWithSelectiveBuildSide(BROADCAST, true); + testSemiJoinWithSelectiveBuildSide(PARTITIONED, true); + 
testSemiJoinWithSelectiveBuildSide(PARTITIONED, false); + } + + private void testSemiJoinWithSelectiveBuildSide(JoinDistributionType joinDistributionType, boolean coordinatorDynamicFiltersDistribution) { assertQueryDynamicFilters( noJoinReordering(joinDistributionType, coordinatorDynamicFiltersDistribution), @@ -337,8 +421,16 @@ public void testSemiJoinWithSelectiveBuildSide(JoinDistributionType joinDistribu singleValue(BIGINT, 1L)))); } - @Test(timeOut = 30_000, dataProvider = "testJoinDistributionType") - public void testSemiJoinWithNonSelectiveBuildSide(JoinDistributionType joinDistributionType, boolean coordinatorDynamicFiltersDistribution) + @Test + @Timeout(30) + public void testSemiJoinWithNonSelectiveBuildSide() + { + testSemiJoinWithNonSelectiveBuildSide(BROADCAST, true); + testSemiJoinWithNonSelectiveBuildSide(PARTITIONED, true); + testSemiJoinWithNonSelectiveBuildSide(PARTITIONED, false); + } + + protected void testSemiJoinWithNonSelectiveBuildSide(JoinDistributionType joinDistributionType, boolean coordinatorDynamicFiltersDistribution) { assertQueryDynamicFilters( noJoinReordering(joinDistributionType, coordinatorDynamicFiltersDistribution), @@ -349,8 +441,16 @@ public void testSemiJoinWithNonSelectiveBuildSide(JoinDistributionType joinDistr Domain.create(ValueSet.ofRanges(range(BIGINT, 1L, true, 100L, true)), false)))); } - @Test(timeOut = 30_000, dataProvider = "testJoinDistributionType") - public void testSemiJoinWithMultipleDynamicFiltersOnProbe(JoinDistributionType joinDistributionType, boolean coordinatorDynamicFiltersDistribution) + @Test + @Timeout(30) + public void testSemiJoinWithMultipleDynamicFiltersOnProbe() + { + testSemiJoinWithMultipleDynamicFiltersOnProbe(BROADCAST, true); + testSemiJoinWithMultipleDynamicFiltersOnProbe(PARTITIONED, true); + testSemiJoinWithMultipleDynamicFiltersOnProbe(PARTITIONED, false); + } + + private void testSemiJoinWithMultipleDynamicFiltersOnProbe(JoinDistributionType joinDistributionType, boolean 
coordinatorDynamicFiltersDistribution) { // supplier names Supplier#000000001 and Supplier#000000002 match suppkey 1 and 2 assertQueryDynamicFilters( @@ -378,15 +478,6 @@ protected Session getDefaultSession() .build(); } - @DataProvider - public Object[][] testJoinDistributionType() - { - return new Object[][] { - {BROADCAST, true}, - {PARTITIONED, true}, - {PARTITIONED, false}}; - } - protected Session noJoinReordering(JoinDistributionType distributionType, boolean coordinatorDynamicFiltersDistribution) { return Session.builder(noJoinReordering(distributionType)) diff --git a/testing/trino-tests/src/test/java/io/trino/execution/TestCoordinatorDynamicFiltering.java b/testing/trino-tests/src/test/java/io/trino/execution/TestCoordinatorDynamicFiltering.java index 1a52a4fe8fc9..d85e6430513d 100644 --- a/testing/trino-tests/src/test/java/io/trino/execution/TestCoordinatorDynamicFiltering.java +++ b/testing/trino-tests/src/test/java/io/trino/execution/TestCoordinatorDynamicFiltering.java @@ -17,11 +17,12 @@ import io.trino.operator.RetryPolicy; import io.trino.testing.DistributedQueryRunner; import io.trino.testing.QueryRunner; -import org.testng.annotations.Test; +import org.junit.jupiter.api.parallel.Execution; import static io.trino.operator.RetryPolicy.NONE; +import static org.junit.jupiter.api.parallel.ExecutionMode.SAME_THREAD; -@Test(singleThreaded = true) +@Execution(SAME_THREAD) public class TestCoordinatorDynamicFiltering extends AbstractTestCoordinatorDynamicFiltering {