diff --git a/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java b/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java index 931b31b9a576..d5b819ec80b5 100644 --- a/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java +++ b/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java @@ -179,6 +179,7 @@ import io.trino.sql.planner.SubPlan; import io.trino.sql.planner.TypeAnalyzer; import io.trino.sql.planner.optimizations.PlanOptimizer; +import io.trino.sql.planner.plan.OutputNode; import io.trino.sql.planner.plan.PlanNode; import io.trino.sql.planner.plan.PlanNodeId; import io.trino.sql.planner.plan.TableScanNode; @@ -892,6 +893,7 @@ private MaterializedResultWithPlan executeInternal(Session session, @Language("S } verify(builder.get() != null, "Output operator was not created"); + builder.get().columnNames(((OutputNode) plan.getRoot()).getColumnNames()); return new MaterializedResultWithPlan(builder.get().build(), plan); } catch (IOException e) { diff --git a/core/trino-main/src/main/java/io/trino/testing/MaterializedResult.java b/core/trino-main/src/main/java/io/trino/testing/MaterializedResult.java index ce0b3b5afdb0..6c39a5cf938c 100644 --- a/core/trino-main/src/main/java/io/trino/testing/MaterializedResult.java +++ b/core/trino-main/src/main/java/io/trino/testing/MaterializedResult.java @@ -101,7 +101,12 @@ public class MaterializedResult public MaterializedResult(List rows, List types) { - this(rows, types, ImmutableList.of(), ImmutableMap.of(), ImmutableSet.of(), Optional.empty(), OptionalLong.empty(), ImmutableList.of(), Optional.empty()); + this(rows, types, Optional.empty()); + } + + public MaterializedResult(List rows, List types, Optional> columnNames) + { + this(rows, types, columnNames.orElse(ImmutableList.of()), ImmutableMap.of(), ImmutableSet.of(), Optional.empty(), OptionalLong.empty(), ImmutableList.of(), Optional.empty()); } public MaterializedResult( @@ -457,6 +462,7 @@ public static class Builder private final ConnectorSession session; private final List types; private final ImmutableList.Builder rows = ImmutableList.builder(); + private Optional> columnNames = Optional.empty(); Builder(ConnectorSession session, List types) { @@ -512,9 +518,15 @@ public synchronized Builder page(Page page) return this; } + public synchronized Builder columnNames(List columnNames) + { + this.columnNames = Optional.of(ImmutableList.copyOf(requireNonNull(columnNames, "columnNames is null"))); + return this; + } + public synchronized MaterializedResult build() { - return new MaterializedResult(rows.build(), types); + return new MaterializedResult(rows.build(), types, columnNames); } } } diff --git a/core/trino-main/src/test/java/io/trino/sql/query/QueryAssertions.java b/core/trino-main/src/test/java/io/trino/sql/query/QueryAssertions.java index 91ca0442bda8..9d56b564feca 100644 --- a/core/trino-main/src/test/java/io/trino/sql/query/QueryAssertions.java +++ b/core/trino-main/src/test/java/io/trino/sql/query/QueryAssertions.java @@ -15,6 +15,7 @@ import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Lists; import io.trino.Session; import io.trino.execution.warnings.WarningCollector; @@ -51,11 +52,13 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Set; import java.util.function.BiFunction; import java.util.function.Consumer; +import java.util.function.Predicate; import java.util.stream.Collectors; -import java.util.stream.IntStream; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; import static io.airlift.testing.Assertions.assertEqualsIgnoreOrder; @@ -323,30 +326,59 @@ private QueryAssert( this.skipResultsCorrectnessCheckForPushdown = skipResultsCorrectnessCheckForPushdown; } - // TODO for better readability, replace this with `exceptColumns(String... columnNamesToExclude)` leveraging MaterializedResult.getColumnNames - @Deprecated - public QueryAssert projected(int... columns) + public QueryAssert exceptColumns(String... columnNamesToExclude) { + validateIfColumnsPresent(columnNamesToExclude); + checkArgument(columnNamesToExclude.length > 0, "At least one column must be excluded"); + checkArgument(columnNamesToExclude.length < actual.getColumnNames().size(), "All columns cannot be excluded"); + return projected(((Predicate) Set.of(columnNamesToExclude)::contains).negate()); + } + + public QueryAssert projected(String... columnNamesToInclude) + { + validateIfColumnsPresent(columnNamesToInclude); + checkArgument(columnNamesToInclude.length > 0, "At least one column must be projected"); + return projected(Set.of(columnNamesToInclude)::contains); + } + + private QueryAssert projected(Predicate columnFilter) + { + List columnNames = actual.getColumnNames(); + Map columnsIndexToNameMap = new HashMap<>(); + for (int i = 0; i < columnNames.size(); i++) { + String columnName = columnNames.get(i); + if (columnFilter.test(columnName)) { + columnsIndexToNameMap.put(i, columnName); + } + } + return new QueryAssert( runner, session, - format("%s projected with %s", query, Arrays.toString(columns)), + format("%s projected with %s", query, columnsIndexToNameMap.values()), new MaterializedResult( actual.getMaterializedRows().stream() .map(row -> new MaterializedRow( row.getPrecision(), - IntStream.of(columns) - .mapToObj(row::getField) + columnsIndexToNameMap.keySet().stream() + .map(row::getField) .collect(toList()))) // values are nullable .collect(toImmutableList()), - IntStream.of(columns) - .mapToObj(actual.getTypes()::get) + columnsIndexToNameMap.keySet().stream() + .map(actual.getTypes()::get) .collect(toImmutableList())), ordered, skipTypesCheck, skipResultsCorrectnessCheckForPushdown); } + private void validateIfColumnsPresent(String... columns) + { + Set columnNames = ImmutableSet.copyOf(actual.getColumnNames()); + Arrays.stream(columns) + .forEach(column -> checkArgument(columnNames.contains(column), "[%s] column is not present in %s".formatted(column, columnNames))); + } + public QueryAssert matches(BiFunction evaluator) { MaterializedResult expected = evaluator.apply(session, runner); diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcTableStatisticsTest.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcTableStatisticsTest.java index 42badbdf25e1..d4793de2664a 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcTableStatisticsTest.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcTableStatisticsTest.java @@ -145,7 +145,7 @@ public void testStatsWithPredicatePushdown() assertThat(query("SHOW STATS FOR (" + query + ")")) // Not testing average length and min/max, as this would make the test less reusable and is not that important to test. - .projected(0, 2, 3, 4) + .exceptColumns("data_size", "low_value", "high_value") .skippingTypesCheck() .matches("VALUES " + "('nationkey', 5e0, 0e0, null)," + @@ -161,7 +161,7 @@ public void testStatsWithVarcharPredicatePushdown() // Predicate on a varchar column. May or may not be pushed down, may or may not be subsumed. assertThat(query("SHOW STATS FOR (SELECT * FROM nation WHERE name = 'PERU')")) // Not testing average length and min/max, as this would make the test less reusable and is not that important to test. - .projected(0, 2, 3, 4) + .exceptColumns("data_size", "low_value", "high_value") .skippingTypesCheck() .matches("VALUES " + "('nationkey', 1e0, 0e0, null)," + @@ -178,7 +178,7 @@ public void testStatsWithVarcharPredicatePushdown() gatherStats(table.getName()); assertThat(query("SHOW STATS FOR (SELECT * FROM " + table.getName() + " WHERE fl = 'B')")) - .projected(0, 2, 3, 4) + .exceptColumns("data_size", "low_value", "high_value") .skippingTypesCheck() .matches("VALUES " + "('nationkey', 5e0, 0e0, null)," + @@ -205,7 +205,7 @@ public void testStatsWithPredicatePushdownWithStatsPrecalculationDisabled() assertThat(query(session, "SHOW STATS FOR (" + query + ")")) // Not testing average length and min/max, as this would make the test less reusable and is not that important to test. - .projected(0, 2, 3, 4) + .exceptColumns("data_size", "low_value", "high_value") .skippingTypesCheck() .matches("VALUES " + "('nationkey', 25e0, 0e0, null)," + @@ -227,7 +227,7 @@ public void testStatsWithLimitPushdown() assertThat(query("SHOW STATS FOR (" + query + ")")) // Not testing average length and min/max, as this would make the test less reusable and is not that important to test. - .projected(0, 2, 3, 4) + .exceptColumns("data_size", "low_value", "high_value") .skippingTypesCheck() .matches("VALUES " + "('regionkey', 2e0, 0e0, null)," + @@ -247,7 +247,7 @@ public void testStatsWithTopNPushdown() assertThat(query("SHOW STATS FOR (" + query + ")")) // Not testing average length and min/max, as this would make the test less reusable and is not that important to test. - .projected(0, 2, 3, 4) + .exceptColumns("data_size", "low_value", "high_value") .skippingTypesCheck() .matches("VALUES " + "('regionkey', 2e0, 0e0, null)," + @@ -266,7 +266,7 @@ public void testStatsWithDistinctPushdown() assertThat(query("SHOW STATS FOR (" + query + ")")) // Not testing average length and min/max, as this would make the test less reusable and is not that important to test. - .projected(0, 2, 3, 4) + .exceptColumns("data_size", "low_value", "high_value") .skippingTypesCheck() .matches("VALUES " + "('regionkey', 5e0, 0e0, null)," + @@ -285,7 +285,7 @@ public void testStatsWithDistinctLimitPushdown() assertThat(query("SHOW STATS FOR (" + query + ")")) // Not testing average length and min/max, as this would make the test less reusable and is not that important to test. - .projected(0, 2, 3, 4) + .exceptColumns("data_size", "low_value", "high_value") .skippingTypesCheck() .matches("VALUES " + "('regionkey', 3e0, 0e0, null)," + @@ -303,7 +303,7 @@ public void testStatsWithAggregationPushdown() assertThat(query("SHOW STATS FOR (" + query + ")")) // Not testing average length and min/max, as this would make the test less reusable and is not that important to test. - .projected(0, 2, 3, 4) + .exceptColumns("data_size", "low_value", "high_value") .skippingTypesCheck() .matches("VALUES " + "('regionkey', 5e0, 0e0, null)," + @@ -323,7 +323,7 @@ public void testStatsWithSimpleJoinPushdown() assertThat(query("SHOW STATS FOR (" + query + ")")) // Not testing average length and min/max, as this would make the test less reusable and is not that important to test. - .projected(0, 2, 3, 4) + .exceptColumns("data_size", "low_value", "high_value") .skippingTypesCheck() .matches("VALUES " + "('n_name', 5e0, 0e0, null)," + @@ -341,7 +341,7 @@ public void testStatsWithJoinPushdown() assertThat(query("SHOW STATS FOR (" + query + ")")) // Not testing average length and min/max, as this would make the test less reusable and is not that important to test. - .projected(0, 2, 3, 4) + .exceptColumns("data_size", "low_value", "high_value") .skippingTypesCheck() .matches("VALUES " + "('regionkey', 1e0, 0e0, null)," + diff --git a/plugin/trino-bigquery/src/test/java/io/trino/plugin/bigquery/BaseBigQueryConnectorTest.java b/plugin/trino-bigquery/src/test/java/io/trino/plugin/bigquery/BaseBigQueryConnectorTest.java index 0591ecc0e5fe..b0457809f5a7 100644 --- a/plugin/trino-bigquery/src/test/java/io/trino/plugin/bigquery/BaseBigQueryConnectorTest.java +++ b/plugin/trino-bigquery/src/test/java/io/trino/plugin/bigquery/BaseBigQueryConnectorTest.java @@ -253,7 +253,7 @@ public void testPartitionDateColumn() assertQuery(format("SELECT value FROM %s WHERE \"$partition_date\" = DATE '2159-12-31'", table.getName()), "VALUES 2"); // Verify DESCRIBE result doesn't have hidden columns - assertThat(query("DESCRIBE " + table.getName())).projected(0).skippingTypesCheck().matches("VALUES 'value'"); + assertThat(query("DESCRIBE " + table.getName())).projected("Column").skippingTypesCheck().matches("VALUES 'value'"); } } @@ -272,7 +272,7 @@ public void testPartitionTimeColumn() assertQuery(format("SELECT value FROM %s WHERE \"$partition_time\" = CAST('2159-12-31 23:00:00 UTC' AS TIMESTAMP(6) WITH TIME ZONE)", table.getName()), "VALUES 2"); // Verify DESCRIBE result doesn't have hidden columns - assertThat(query("DESCRIBE " + table.getName())).projected(0).skippingTypesCheck().matches("VALUES 'value'"); + assertThat(query("DESCRIBE " + table.getName())).projected("Column").skippingTypesCheck().matches("VALUES 'value'"); } } diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java index 481047da461a..8b58409ffec3 100644 --- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java +++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java @@ -3198,7 +3198,7 @@ protected void testBucketTransformForType( assertThat(query("SHOW STATS FOR " + tableName)) .skippingTypesCheck() - .projected(0, 2, 3, 4) // data size, min and max may vary between types + .exceptColumns("data_size", "low_value", "high_value") // these may vary between types .matches("VALUES " + " ('d', 3e0, " + (format == AVRO ? "0.1e0" : "0.25e0") + ", NULL), " + " (NULL, NULL, NULL, 4e0)"); diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergOrcConnectorTest.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergOrcConnectorTest.java index eca20205d4a0..ad2594271902 100644 --- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergOrcConnectorTest.java +++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergOrcConnectorTest.java @@ -65,7 +65,7 @@ public void testTinyintType() Files.delete(orcFilePath.resolveSibling(format(".%s.crc", orcFilePath.getFileName()))); assertThat(query("DESCRIBE " + table.getName())) - .projected(1) + .projected("Type") .matches("VALUES varchar 'integer'"); assertQuery("SELECT * FROM " + table.getName(), "VALUES 127, NULL"); } @@ -81,7 +81,7 @@ public void testSmallintType() Files.delete(orcFilePath.resolveSibling(format(".%s.crc", orcFilePath.getFileName()))); assertThat(query("DESCRIBE " + table.getName())) - .projected(1) + .projected("Type") .matches("VALUES varchar 'integer'"); assertQuery("SELECT * FROM " + table.getName(), "VALUES 32767, NULL"); } diff --git a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java index 59e3ff30788d..29127905b220 100644 --- a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java +++ b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java @@ -142,14 +142,12 @@ public void testCreateTableAsSelectWithUnicode() public void testReadFromLateBindingView(String redshiftType, String trinoType) { try (TestView view = new TestView(onRemoteDatabase(), TEST_SCHEMA + ".late_schema_binding", "SELECT CAST(NULL AS %s) AS value WITH NO SCHEMA BINDING".formatted(redshiftType))) { - assertThat(query("SELECT value, true FROM %s WHERE value IS NULL".formatted(view.getName()))) - .projected(1) + assertThat(query("SELECT true FROM %s WHERE value IS NULL".formatted(view.getName()))) .containsAll("VALUES (true)"); assertThat(query("SHOW COLUMNS FROM %s LIKE 'value'".formatted(view.getName()))) - .projected(1) .skippingTypesCheck() - .containsAll("VALUES ('%s')".formatted(trinoType)); + .containsAll("VALUES ('value', '%s', '', '')".formatted(trinoType)); } } @@ -165,9 +163,8 @@ public void testReadNullFromView(String redshiftType, String trinoType, boolean .matches("VALUES null"); assertThat(query("SHOW COLUMNS FROM %s LIKE 'value'".formatted(view.getName()))) - .projected(1) .skippingTypesCheck() - .matches("VALUES '%s'".formatted(trinoType)); + .matches("VALUES ('value', '%s', '', '')".formatted(trinoType)); } } diff --git a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestEngineOnlyQueries.java b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestEngineOnlyQueries.java index 14e5e19051d9..6290d1ef6215 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestEngineOnlyQueries.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestEngineOnlyQueries.java @@ -6596,6 +6596,34 @@ public void testJsonArrayFunction() .matches("VALUES (CAST('[\"AFRICA\",0]' AS varchar(100))), ('[\"AMERICA\",1]'), ('[\"ASIA\",2]'), ('[\"EUROPE\",3]'), ('[\"MIDDLE EAST\",4]')"); } + @Test + public void testColumnNames() + { + MaterializedResult showFunctionsResult = computeActual("SHOW FUNCTIONS"); + assertEquals(showFunctionsResult.getColumnNames(), ImmutableList.of("Function", "Return Type", "Argument Types", "Function Type", "Deterministic", "Description")); + + MaterializedResult showCatalogsResult = computeActual("SHOW CATALOGS"); + assertEquals(showCatalogsResult.getColumnNames(), ImmutableList.of("Catalog")); + + MaterializedResult selectAllResult = computeActual("SELECT * FROM nation"); + assertEquals(selectAllResult.getColumnNames(), ImmutableList.of("nationkey", "name", "regionkey", "comment")); + + MaterializedResult selectResult = computeActual("SELECT nationkey, regionkey FROM nation"); + assertEquals(selectResult.getColumnNames(), ImmutableList.of("nationkey", "regionkey")); + + MaterializedResult selectJsonArrayResult = computeActual("SELECT json_array(name, regionkey) from nation"); + assertEquals(selectJsonArrayResult.getColumnNames(), ImmutableList.of("_col0")); + + MaterializedResult selectJsonArrayAsResult = computeActual("SELECT json_array(name, regionkey) result from nation"); + assertEquals(selectJsonArrayAsResult.getColumnNames(), ImmutableList.of("result")); + + MaterializedResult showColumnResult = computeActual("SHOW COLUMNS FROM nation"); + assertEquals(showColumnResult.getColumnNames(), ImmutableList.of("Column", "Type", "Extra", "Comment")); + + MaterializedResult showCreateTableResult = computeActual("SHOW CREATE TABLE nation"); + assertEquals(showCreateTableResult.getColumnNames(), ImmutableList.of("Create Table")); + } + private static ZonedDateTime zonedDateTime(String value) { return ZONED_DATE_TIME_FORMAT.parse(value, ZonedDateTime::from); diff --git a/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorTest.java b/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorTest.java index 204e699f1e41..3737e88de3e1 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorTest.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorTest.java @@ -887,12 +887,12 @@ public void testView() // column listing assertThat(query("SHOW COLUMNS FROM " + testView)) - .projected(0) // column types can very between connectors + .projected("Column") // column types can very between connectors .skippingTypesCheck() .matches("VALUES 'orderkey', 'orderstatus', 'half'"); assertThat(query("DESCRIBE " + testView)) - .projected(0) // column types can very between connectors + .projected("Column") // column types can very between connectors .skippingTypesCheck() .matches("VALUES 'orderkey', 'orderstatus', 'half'"); @@ -1065,12 +1065,12 @@ public void testMaterializedView() // column listing assertThat(query("SHOW COLUMNS FROM " + view.getObjectName())) - .projected(0) // column types can very between connectors + .projected("Column") // column types can very between connectors .skippingTypesCheck() .matches("VALUES 'nationkey', 'name', 'regionkey', 'comment'"); assertThat(query("DESCRIBE " + view.getObjectName())) - .projected(0) // column types can very between connectors + .projected("Column") // column types can very between connectors .skippingTypesCheck() .matches("VALUES 'nationkey', 'name', 'regionkey', 'comment'"); @@ -4426,7 +4426,7 @@ public void testAddColumnConcurrently() .collect(toImmutableList()); assertThat(query("DESCRIBE " + tableName)) - .projected(0) + .projected("Column") .skippingTypesCheck() .matches(Stream.concat(Stream.of("col"), addedColumns.stream()) .map(value -> format("'%s'", value)) diff --git a/testing/trino-tests/src/test/java/io/trino/tests/BaseQueryAssertionsTest.java b/testing/trino-tests/src/test/java/io/trino/tests/BaseQueryAssertionsTest.java index 4e6f5de2bf57..0b9abeb8f2d4 100644 --- a/testing/trino-tests/src/test/java/io/trino/tests/BaseQueryAssertionsTest.java +++ b/testing/trino-tests/src/test/java/io/trino/tests/BaseQueryAssertionsTest.java @@ -344,4 +344,43 @@ public void testIsNotFullyPushedDown() "\n" + "Output[columnNames = [name]]\n"); } + + @Test + public void testProjectedColumns() + { + assertThat(query("SHOW COLUMNS FROM nation")) + .projected("Column") + .skippingTypesCheck() + .matches("VALUES 'nationkey', 'name', 'regionkey', 'comment'"); + + assertThat(query("SHOW COLUMNS FROM nation")) + .exceptColumns("Type", "Extra", "Comment") + .skippingTypesCheck() + .matches("VALUES 'nationkey', 'name', 'regionkey', 'comment'"); + + assertThatThrownBy( + () -> assertThat(query("SHOW COLUMNS FROM nation")) + .projected("Column", "Non_Existent")) + .hasMessageContaining("[Non_Existent] column is not present in [Column, Type, Extra, Comment]"); + + assertThatThrownBy( + () -> assertThat(query("SHOW COLUMNS FROM nation")) + .exceptColumns("Type", "Extra", "Comment", "Non_Existent")) + .hasMessageContaining("[Non_Existent] column is not present in [Column, Type, Extra, Comment]"); + + assertThatThrownBy( + () -> assertThat(query("SHOW COLUMNS FROM nation")) + .projected()) // project no columns + .hasMessageContaining("At least one column must be projected"); + + assertThatThrownBy( + () -> assertThat(query("SHOW COLUMNS FROM nation")) + .exceptColumns()) // exclude no columns + .hasMessageContaining("At least one column must be excluded"); + + assertThatThrownBy( + () -> assertThat(query("SHOW COLUMNS FROM nation")) + .exceptColumns("Column", "Type", "Extra", "Comment")) // exclude all columns + .hasMessageContaining("All columns cannot be excluded"); + } }