Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@
import io.trino.sql.planner.SubPlan;
import io.trino.sql.planner.TypeAnalyzer;
import io.trino.sql.planner.optimizations.PlanOptimizer;
import io.trino.sql.planner.plan.OutputNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.sql.planner.plan.TableScanNode;
Expand Down Expand Up @@ -892,6 +893,7 @@ private MaterializedResultWithPlan executeInternal(Session session, @Language("S
}

verify(builder.get() != null, "Output operator was not created");
builder.get().columnNames(((OutputNode) plan.getRoot()).getColumnNames());
return new MaterializedResultWithPlan(builder.get().build(), plan);
}
catch (IOException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,12 @@ public class MaterializedResult

public MaterializedResult(List<MaterializedRow> rows, List<? extends Type> types)
{
this(rows, types, ImmutableList.of(), ImmutableMap.of(), ImmutableSet.of(), Optional.empty(), OptionalLong.empty(), ImmutableList.of(), Optional.empty());
this(rows, types, Optional.empty());
}

public MaterializedResult(List<MaterializedRow> rows, List<? extends Type> types, Optional<List<String>> columnNames)
{
this(rows, types, columnNames.orElse(ImmutableList.of()), ImmutableMap.of(), ImmutableSet.of(), Optional.empty(), OptionalLong.empty(), ImmutableList.of(), Optional.empty());
}

public MaterializedResult(
Expand Down Expand Up @@ -457,6 +462,7 @@ public static class Builder
private final ConnectorSession session;
private final List<Type> types;
private final ImmutableList.Builder<MaterializedRow> rows = ImmutableList.builder();
private Optional<List<String>> columnNames = Optional.empty();

Builder(ConnectorSession session, List<Type> types)
{
Expand Down Expand Up @@ -512,9 +518,15 @@ public synchronized Builder page(Page page)
return this;
}

public synchronized Builder columnNames(List<String> columnNames)
{
this.columnNames = Optional.of(ImmutableList.copyOf(requireNonNull(columnNames, "columnNames is null")));
return this;
}

public synchronized MaterializedResult build()
{
return new MaterializedResult(rows.build(), types);
return new MaterializedResult(rows.build(), types, columnNames);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import com.google.common.base.Joiner;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import io.trino.Session;
import io.trino.execution.warnings.WarningCollector;
Expand Down Expand Up @@ -51,11 +52,13 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.airlift.testing.Assertions.assertEqualsIgnoreOrder;
Expand Down Expand Up @@ -323,30 +326,59 @@ private QueryAssert(
this.skipResultsCorrectnessCheckForPushdown = skipResultsCorrectnessCheckForPushdown;
}

// TODO for better readability, replace this with `exceptColumns(String... columnNamesToExclude)` leveraging MaterializedResult.getColumnNames
@Deprecated
public QueryAssert projected(int... columns)
public QueryAssert exceptColumns(String... columnNamesToExclude)
Comment thread
krvikash marked this conversation as resolved.
Outdated
{
validateIfColumnsPresent(columnNamesToExclude);
checkArgument(columnNamesToExclude.length > 0, "At least one column must be excluded");
checkArgument(columnNamesToExclude.length < actual.getColumnNames().size(), "All columns cannot be excluded");
return projected(((Predicate<String>) Set.of(columnNamesToExclude)::contains).negate());
}

public QueryAssert projected(String... columnNamesToInclude)
{
validateIfColumnsPresent(columnNamesToInclude);
checkArgument(columnNamesToInclude.length > 0, "At least one column must be projected");
return projected(Set.of(columnNamesToInclude)::contains);
}

private QueryAssert projected(Predicate<String> columnFilter)
{
List<String> columnNames = actual.getColumnNames();
Map<Integer, String> columnsIndexToNameMap = new HashMap<>();
for (int i = 0; i < columnNames.size(); i++) {
String columnName = columnNames.get(i);
if (columnFilter.test(columnName)) {
columnsIndexToNameMap.put(i, columnName);
}
}

return new QueryAssert(
runner,
session,
format("%s projected with %s", query, Arrays.toString(columns)),
format("%s projected with %s", query, columnsIndexToNameMap.values()),
new MaterializedResult(
actual.getMaterializedRows().stream()
.map(row -> new MaterializedRow(
row.getPrecision(),
IntStream.of(columns)
.mapToObj(row::getField)
columnsIndexToNameMap.keySet().stream()
.map(row::getField)
.collect(toList()))) // values are nullable
.collect(toImmutableList()),
IntStream.of(columns)
.mapToObj(actual.getTypes()::get)
columnsIndexToNameMap.keySet().stream()
.map(actual.getTypes()::get)
.collect(toImmutableList())),
ordered,
skipTypesCheck,
skipResultsCorrectnessCheckForPushdown);
}

private void validateIfColumnsPresent(String... columns)
{
Set<String> columnNames = ImmutableSet.copyOf(actual.getColumnNames());
Arrays.stream(columns)
.forEach(column -> checkArgument(columnNames.contains(column), "[%s] column is not present in %s".formatted(column, columnNames)));
}

public QueryAssert matches(BiFunction<Session, QueryRunner, MaterializedResult> evaluator)
{
MaterializedResult expected = evaluator.apply(session, runner);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ public void testStatsWithPredicatePushdown()

assertThat(query("SHOW STATS FOR (" + query + ")"))
// Not testing average length and min/max, as this would make the test less reusable and is not that important to test.
.projected(0, 2, 3, 4)
.exceptColumns("data_size", "low_value", "high_value")
.skippingTypesCheck()
.matches("VALUES " +
"('nationkey', 5e0, 0e0, null)," +
Expand All @@ -161,7 +161,7 @@ public void testStatsWithVarcharPredicatePushdown()
// Predicate on a varchar column. May or may not be pushed down, may or may not be subsumed.
assertThat(query("SHOW STATS FOR (SELECT * FROM nation WHERE name = 'PERU')"))
// Not testing average length and min/max, as this would make the test less reusable and is not that important to test.
.projected(0, 2, 3, 4)
.exceptColumns("data_size", "low_value", "high_value")
.skippingTypesCheck()
.matches("VALUES " +
"('nationkey', 1e0, 0e0, null)," +
Expand All @@ -178,7 +178,7 @@ public void testStatsWithVarcharPredicatePushdown()
gatherStats(table.getName());

assertThat(query("SHOW STATS FOR (SELECT * FROM " + table.getName() + " WHERE fl = 'B')"))
.projected(0, 2, 3, 4)
.exceptColumns("data_size", "low_value", "high_value")
.skippingTypesCheck()
.matches("VALUES " +
"('nationkey', 5e0, 0e0, null)," +
Expand All @@ -205,7 +205,7 @@ public void testStatsWithPredicatePushdownWithStatsPrecalculationDisabled()

assertThat(query(session, "SHOW STATS FOR (" + query + ")"))
// Not testing average length and min/max, as this would make the test less reusable and is not that important to test.
.projected(0, 2, 3, 4)
.exceptColumns("data_size", "low_value", "high_value")
.skippingTypesCheck()
.matches("VALUES " +
"('nationkey', 25e0, 0e0, null)," +
Expand All @@ -227,7 +227,7 @@ public void testStatsWithLimitPushdown()

assertThat(query("SHOW STATS FOR (" + query + ")"))
// Not testing average length and min/max, as this would make the test less reusable and is not that important to test.
.projected(0, 2, 3, 4)
.exceptColumns("data_size", "low_value", "high_value")
.skippingTypesCheck()
.matches("VALUES " +
"('regionkey', 2e0, 0e0, null)," +
Expand All @@ -247,7 +247,7 @@ public void testStatsWithTopNPushdown()

assertThat(query("SHOW STATS FOR (" + query + ")"))
// Not testing average length and min/max, as this would make the test less reusable and is not that important to test.
.projected(0, 2, 3, 4)
.exceptColumns("data_size", "low_value", "high_value")
.skippingTypesCheck()
.matches("VALUES " +
"('regionkey', 2e0, 0e0, null)," +
Expand All @@ -266,7 +266,7 @@ public void testStatsWithDistinctPushdown()

assertThat(query("SHOW STATS FOR (" + query + ")"))
// Not testing average length and min/max, as this would make the test less reusable and is not that important to test.
.projected(0, 2, 3, 4)
.exceptColumns("data_size", "low_value", "high_value")
Comment thread
krvikash marked this conversation as resolved.
Outdated
.skippingTypesCheck()
.matches("VALUES " +
"('regionkey', 5e0, 0e0, null)," +
Expand All @@ -285,7 +285,7 @@ public void testStatsWithDistinctLimitPushdown()

assertThat(query("SHOW STATS FOR (" + query + ")"))
// Not testing average length and min/max, as this would make the test less reusable and is not that important to test.
.projected(0, 2, 3, 4)
.exceptColumns("data_size", "low_value", "high_value")
.skippingTypesCheck()
.matches("VALUES " +
"('regionkey', 3e0, 0e0, null)," +
Expand All @@ -303,7 +303,7 @@ public void testStatsWithAggregationPushdown()

assertThat(query("SHOW STATS FOR (" + query + ")"))
// Not testing average length and min/max, as this would make the test less reusable and is not that important to test.
.projected(0, 2, 3, 4)
.exceptColumns("data_size", "low_value", "high_value")
.skippingTypesCheck()
.matches("VALUES " +
"('regionkey', 5e0, 0e0, null)," +
Expand All @@ -323,7 +323,7 @@ public void testStatsWithSimpleJoinPushdown()

assertThat(query("SHOW STATS FOR (" + query + ")"))
// Not testing average length and min/max, as this would make the test less reusable and is not that important to test.
.projected(0, 2, 3, 4)
.exceptColumns("data_size", "low_value", "high_value")
.skippingTypesCheck()
.matches("VALUES " +
"('n_name', 5e0, 0e0, null)," +
Expand All @@ -341,7 +341,7 @@ public void testStatsWithJoinPushdown()

assertThat(query("SHOW STATS FOR (" + query + ")"))
// Not testing average length and min/max, as this would make the test less reusable and is not that important to test.
.projected(0, 2, 3, 4)
.exceptColumns("data_size", "low_value", "high_value")
.skippingTypesCheck()
.matches("VALUES " +
"('regionkey', 1e0, 0e0, null)," +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ public void testPartitionDateColumn()
assertQuery(format("SELECT value FROM %s WHERE \"$partition_date\" = DATE '2159-12-31'", table.getName()), "VALUES 2");

// Verify DESCRIBE result doesn't have hidden columns
assertThat(query("DESCRIBE " + table.getName())).projected(0).skippingTypesCheck().matches("VALUES 'value'");
assertThat(query("DESCRIBE " + table.getName())).projected("Column").skippingTypesCheck().matches("VALUES 'value'");
}
}

Expand All @@ -272,7 +272,7 @@ public void testPartitionTimeColumn()
assertQuery(format("SELECT value FROM %s WHERE \"$partition_time\" = CAST('2159-12-31 23:00:00 UTC' AS TIMESTAMP(6) WITH TIME ZONE)", table.getName()), "VALUES 2");

// Verify DESCRIBE result doesn't have hidden columns
assertThat(query("DESCRIBE " + table.getName())).projected(0).skippingTypesCheck().matches("VALUES 'value'");
assertThat(query("DESCRIBE " + table.getName())).projected("Column").skippingTypesCheck().matches("VALUES 'value'");
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3198,7 +3198,7 @@ protected void testBucketTransformForType(

assertThat(query("SHOW STATS FOR " + tableName))
.skippingTypesCheck()
.projected(0, 2, 3, 4) // data size, min and max may vary between types
.exceptColumns("data_size", "low_value", "high_value") // these may vary between types
.matches("VALUES " +
" ('d', 3e0, " + (format == AVRO ? "0.1e0" : "0.25e0") + ", NULL), " +
" (NULL, NULL, NULL, 4e0)");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ public void testTinyintType()
Files.delete(orcFilePath.resolveSibling(format(".%s.crc", orcFilePath.getFileName())));

assertThat(query("DESCRIBE " + table.getName()))
.projected(1)
.projected("Type")
.matches("VALUES varchar 'integer'");
assertQuery("SELECT * FROM " + table.getName(), "VALUES 127, NULL");
}
Expand All @@ -81,7 +81,7 @@ public void testSmallintType()
Files.delete(orcFilePath.resolveSibling(format(".%s.crc", orcFilePath.getFileName())));

assertThat(query("DESCRIBE " + table.getName()))
.projected(1)
.projected("Type")
.matches("VALUES varchar 'integer'");
assertQuery("SELECT * FROM " + table.getName(), "VALUES 32767, NULL");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,14 +142,12 @@ public void testCreateTableAsSelectWithUnicode()
public void testReadFromLateBindingView(String redshiftType, String trinoType)
{
try (TestView view = new TestView(onRemoteDatabase(), TEST_SCHEMA + ".late_schema_binding", "SELECT CAST(NULL AS %s) AS value WITH NO SCHEMA BINDING".formatted(redshiftType))) {
assertThat(query("SELECT value, true FROM %s WHERE value IS NULL".formatted(view.getName())))
.projected(1)
Comment thread
krvikash marked this conversation as resolved.
Outdated
assertThat(query("SELECT true FROM %s WHERE value IS NULL".formatted(view.getName())))
.containsAll("VALUES (true)");

assertThat(query("SHOW COLUMNS FROM %s LIKE 'value'".formatted(view.getName())))
.projected(1)
Comment thread
krvikash marked this conversation as resolved.
Outdated
.skippingTypesCheck()
.containsAll("VALUES ('%s')".formatted(trinoType));
.containsAll("VALUES ('value', '%s', '', '')".formatted(trinoType));
}
}

Expand All @@ -165,9 +163,8 @@ public void testReadNullFromView(String redshiftType, String trinoType, boolean
.matches("VALUES null");

assertThat(query("SHOW COLUMNS FROM %s LIKE 'value'".formatted(view.getName())))
.projected(1)
.skippingTypesCheck()
.matches("VALUES '%s'".formatted(trinoType));
.matches("VALUES ('value', '%s', '', '')".formatted(trinoType));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6596,6 +6596,34 @@ public void testJsonArrayFunction()
.matches("VALUES (CAST('[\"AFRICA\",0]' AS varchar(100))), ('[\"AMERICA\",1]'), ('[\"ASIA\",2]'), ('[\"EUROPE\",3]'), ('[\"MIDDLE EAST\",4]')");
}

@Test
public void testColumnNames()
{
MaterializedResult showFunctionsResult = computeActual("SHOW FUNCTIONS");
assertEquals(showFunctionsResult.getColumnNames(), ImmutableList.of("Function", "Return Type", "Argument Types", "Function Type", "Deterministic", "Description"));

MaterializedResult showCatalogsResult = computeActual("SHOW CATALOGS");
assertEquals(showCatalogsResult.getColumnNames(), ImmutableList.of("Catalog"));

MaterializedResult selectAllResult = computeActual("SELECT * FROM nation");
assertEquals(selectAllResult.getColumnNames(), ImmutableList.of("nationkey", "name", "regionkey", "comment"));

MaterializedResult selectResult = computeActual("SELECT nationkey, regionkey FROM nation");
assertEquals(selectResult.getColumnNames(), ImmutableList.of("nationkey", "regionkey"));

MaterializedResult selectJsonArrayResult = computeActual("SELECT json_array(name, regionkey) from nation");
assertEquals(selectJsonArrayResult.getColumnNames(), ImmutableList.of("_col0"));

MaterializedResult selectJsonArrayAsResult = computeActual("SELECT json_array(name, regionkey) result from nation");
assertEquals(selectJsonArrayAsResult.getColumnNames(), ImmutableList.of("result"));

MaterializedResult showColumnResult = computeActual("SHOW COLUMNS FROM nation");
assertEquals(showColumnResult.getColumnNames(), ImmutableList.of("Column", "Type", "Extra", "Comment"));

MaterializedResult showCreateTableResult = computeActual("SHOW CREATE TABLE nation");
assertEquals(showCreateTableResult.getColumnNames(), ImmutableList.of("Create Table"));
}

private static ZonedDateTime zonedDateTime(String value)
{
return ZONED_DATE_TIME_FORMAT.parse(value, ZonedDateTime::from);
Expand Down
Loading