Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.sql.query;

import com.google.common.collect.ImmutableList;
import io.trino.Session;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.execution.warnings.WarningCollector;
Expand Down Expand Up @@ -43,6 +44,7 @@
import java.util.function.BiFunction;
import java.util.stream.Collectors;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static io.airlift.testing.Assertions.assertEqualsIgnoreOrder;
import static io.trino.sql.planner.assertions.PlanAssert.assertPlan;
Expand Down Expand Up @@ -379,22 +381,27 @@ public QueryAssert isFullyPushedDown()
* <b>Note:</b> the primary intent of this assertion is to ensure the test is updated to {@link #isFullyPushedDown()}
* when pushdown capabilities are improved.
*/
public QueryAssert isNotFullyPushedDown(Class<? extends PlanNode> retainedNode)
public QueryAssert isNotFullyPushedDown(Class<? extends PlanNode>... retainedNodes)
{
checkArgument(retainedNodes.length > 0, "No retainedNodes");

// Compare the results with pushdown disabled, so that explicit matches() call is not needed
verifyResultsWithPushdownDisabled();

transaction(runner.getTransactionManager(), runner.getAccessControl())
.execute(session, session -> {
Plan plan = runner.createPlan(session, query, WarningCollector.NOOP);
PlanMatchPattern expectedPlan = PlanMatchPattern.node(TableScanNode.class);
for (Class<? extends PlanNode> retainedNode : ImmutableList.copyOf(retainedNodes).reverse()) {
expectedPlan = PlanMatchPattern.node(retainedNode, expectedPlan);
}
expectedPlan = PlanMatchPattern.anyTree(expectedPlan);
assertPlan(
session,
runner.getMetadata(),
(node, sourceStats, lookup, ignore, types) -> PlanNodeStatsEstimate.unknown(),
plan,
PlanMatchPattern.anyTree(
PlanMatchPattern.node(retainedNode,
PlanMatchPattern.node(TableScanNode.class))));
expectedPlan);
});

return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ public void testApplyFilterAfterAggregationPushdown()
ConnectorTableHandle aggregatedTable = applyCountAggregation(session, baseTableHandle, ImmutableList.of(ImmutableList.of(groupByColumn)));

Domain domain = Domain.singleValue(VARCHAR, utf8Slice("one"));
JdbcTableHandle tableHandleWithFilter = applyConstraint(session, aggregatedTable, new Constraint(TupleDomain.withColumnDomains(ImmutableMap.of(groupByColumn, domain))));
JdbcTableHandle tableHandleWithFilter = applyFilter(session, aggregatedTable, new Constraint(TupleDomain.withColumnDomains(ImmutableMap.of(groupByColumn, domain))));

assertEquals(tableHandleWithFilter.getConstraint().getDomains(), Optional.of(ImmutableMap.of(groupByColumn, domain)));
}
Expand All @@ -299,12 +299,12 @@ public void testCombineFiltersWithAggregationPushdown()
ConnectorTableHandle baseTableHandle = metadata.getTableHandle(session, new SchemaTableName("example", "numbers"));

Domain firstDomain = Domain.multipleValues(VARCHAR, ImmutableList.of(utf8Slice("one"), utf8Slice("two")));
JdbcTableHandle filterResult = applyConstraint(session, baseTableHandle, new Constraint(TupleDomain.withColumnDomains(ImmutableMap.of(groupByColumn, firstDomain))));
JdbcTableHandle filterResult = applyFilter(session, baseTableHandle, new Constraint(TupleDomain.withColumnDomains(ImmutableMap.of(groupByColumn, firstDomain))));

ConnectorTableHandle aggregatedTable = applyCountAggregation(session, filterResult, ImmutableList.of(ImmutableList.of(groupByColumn)));

Domain secondDomain = Domain.multipleValues(VARCHAR, ImmutableList.of(utf8Slice("one"), utf8Slice("three")));
JdbcTableHandle tableHandleWithFilter = applyConstraint(session, aggregatedTable, new Constraint(TupleDomain.withColumnDomains(ImmutableMap.of(groupByColumn, secondDomain))));
JdbcTableHandle tableHandleWithFilter = applyFilter(session, aggregatedTable, new Constraint(TupleDomain.withColumnDomains(ImmutableMap.of(groupByColumn, secondDomain))));
assertEquals(
tableHandleWithFilter.getConstraint().getDomains(),
Optional.of(ImmutableMap.of(groupByColumn, Domain.singleValue(VARCHAR, utf8Slice("one")))));
Expand Down Expand Up @@ -365,7 +365,7 @@ private JdbcTableHandle applyCountAggregation(ConnectorSession session, Connecto
return (JdbcTableHandle) aggResult.get().getHandle();
}

private JdbcTableHandle applyConstraint(ConnectorSession session, ConnectorTableHandle tableHandle, Constraint constraint)
private JdbcTableHandle applyFilter(ConnectorSession session, ConnectorTableHandle tableHandle, Constraint constraint)
{
Optional<ConstraintApplicationResult<ConnectorTableHandle>> filterResult = metadata.applyFilter(
session,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.VarcharType.VARCHAR;
import static io.trino.spi.type.VarcharType.createVarcharType;
import static io.trino.testing.TestingConnectorSession.SESSION;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertNotNull;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,15 @@
final class TestingDatabase
implements AutoCloseable
{
private final String databaseName;
private final Connection connection;
private final JdbcClient jdbcClient;

public TestingDatabase()
throws SQLException
{
String connectionUrl = "jdbc:h2:mem:test" + System.nanoTime() + ThreadLocalRandom.current().nextLong();
databaseName = "TEST" + System.nanoTime() + ThreadLocalRandom.current().nextLong();
String connectionUrl = "jdbc:h2:mem:" + databaseName;
jdbcClient = new TestingH2JdbcClient(
new BaseJdbcConfig(),
new DriverConnectionFactory(new Driver(), connectionUrl, new Properties(), new EmptyCredentialProvider()));
Expand Down Expand Up @@ -78,6 +80,11 @@ public void close()
connection.close();
}

public String getDatabaseName()
{
return databaseName;
}

public Connection getConnection()
{
return connection;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.plugin.memsql;

import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.testing.AbstractTestIntegrationSmokeTest;
import io.trino.testing.MaterializedResult;
Expand Down Expand Up @@ -243,6 +244,16 @@ public void testPredicatePushdown()
assertThat(query("SELECT orderkey FROM orders WHERE orderdate = DATE '1992-09-29'"))
.matches("VALUES BIGINT '1250', 34406, 38436, 57570")
.isFullyPushedDown();

// predicate over aggregation key (likely to be optimized before being pushed down into the connector)
assertThat(query("SELECT * FROM (SELECT regionkey, sum(nationkey) FROM nation GROUP BY regionkey) WHERE regionkey = 3"))
.matches("VALUES (BIGINT '3', BIGINT '77')")
.isNotFullyPushedDown(AggregationNode.class);

// predicate over aggregation result
assertThat(query("SELECT regionkey, sum(nationkey) FROM nation GROUP BY regionkey HAVING sum(nationkey) = 77"))
.matches("VALUES (BIGINT '3', BIGINT '77')")
.isNotFullyPushedDown(AggregationNode.class);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,14 @@ public void testAggregationPushdown()
assertThat(query("SELECT regionkey, sum(nationkey) FROM nation WHERE regionkey < 4 GROUP BY regionkey")).isFullyPushedDown();
assertThat(query("SELECT regionkey, sum(nationkey) FROM nation WHERE regionkey < 4 AND name > 'AAA' GROUP BY regionkey")).isNotFullyPushedDown(FilterNode.class);

// GROUP BY above WHERE and LIMIT
assertThat(query("" +
"SELECT regionkey, sum(nationkey) " +
"FROM (SELECT * FROM nation WHERE regionkey < 3 LIMIT 11) " +
"GROUP BY regionkey"))
.isNotFullyPushedDown(AggregationNode.class);

// decimals
try (AutoCloseable ignoreTable = withTable("tpch.test_aggregation_pushdown", "(short_decimal decimal(9, 3), long_decimal decimal(30, 10))")) {
execute("INSERT INTO tpch.test_aggregation_pushdown VALUES (100.000, 100000000.000000000)");
execute("INSERT INTO tpch.test_aggregation_pushdown VALUES (123.321, 123456789.987654321)");
Expand Down Expand Up @@ -332,6 +340,16 @@ public void testPredicatePushdown()
.isFullyPushedDown();

execute("DROP TABLE tpch.binary_test");

// predicate over aggregation key (likely to be optimized before being pushed down into the connector)
assertThat(query("SELECT * FROM (SELECT regionkey, sum(nationkey) FROM nation GROUP BY regionkey) WHERE regionkey = 3"))
.matches("VALUES (BIGINT '3', BIGINT '77')")
.isFullyPushedDown();

// predicate over aggregation result
assertThat(query("SELECT regionkey, sum(nationkey) FROM nation GROUP BY regionkey HAVING sum(nationkey) = 77"))
.matches("VALUES (BIGINT '3', BIGINT '77')")
.isNotFullyPushedDown(FilterNode.class);
}

private AutoCloseable withTable(String tableName, String tableDefinition)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@

import com.google.common.collect.ImmutableList;
import io.trino.Session;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.testing.AbstractTestIntegrationSmokeTest;
import io.trino.testing.MaterializedResult;
import io.trino.testing.sql.SqlExecutor;
Expand Down Expand Up @@ -121,6 +123,16 @@ public void testPredicatePushdown()
assertThat(query("SELECT orderkey FROM orders WHERE orderdate = DATE '1992-09-29'"))
.matches("VALUES CAST(1250 AS DECIMAL(19,0)), 34406, 38436, 57570")
.isFullyPushedDown();

// predicate over aggregation key (likely to be optimized before being pushed down into the connector)
assertThat(query("SELECT * FROM (SELECT regionkey, sum(nationkey) FROM nation GROUP BY regionkey) WHERE regionkey = 3"))
.matches("VALUES (CAST(3 AS decimal(19,0)), CAST(77 AS decimal(38,0)))")
.isNotFullyPushedDown(AggregationNode.class, ProjectNode.class);

// predicate over aggregation result
assertThat(query("SELECT regionkey, sum(nationkey) FROM nation GROUP BY regionkey HAVING sum(nationkey) = 77"))
.matches("VALUES (CAST(3 AS decimal(19,0)), CAST(77 AS decimal(38,0)))")
.isNotFullyPushedDown(AggregationNode.class, ProjectNode.class);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,16 @@ public void testPredicatePushdown()
assertThat(query("SELECT orderkey FROM orders WHERE orderdate = DATE '1992-09-29'"))
.matches("VALUES BIGINT '1250', 34406, 38436, 57570")
.isFullyPushedDown();

// predicate over aggregation key (likely to be optimized before being pushed down into the connector)
assertThat(query("SELECT * FROM (SELECT regionkey, sum(nationkey) FROM nation GROUP BY regionkey) WHERE regionkey = 3"))
.matches("VALUES (BIGINT '3', BIGINT '77')")
.isFullyPushedDown();

// predicate over aggregation result
assertThat(query("SELECT regionkey, sum(nationkey) FROM nation GROUP BY regionkey HAVING sum(nationkey) = 77"))
.matches("VALUES (BIGINT '3', BIGINT '77')")
.isNotFullyPushedDown(FilterNode.class);
}

@Test
Expand Down Expand Up @@ -386,6 +396,13 @@ public void testAggregationPushdown()
// GROUP BY and WHERE on "other" (not aggregation key, not aggregation input)
assertThat(query("SELECT regionkey, sum(nationkey) FROM nation WHERE regionkey < 4 AND name > 'AAA' GROUP BY regionkey")).isFullyPushedDown();

// GROUP BY above WHERE and LIMIT
assertThat(query("" +
"SELECT regionkey, sum(nationkey) " +
"FROM (SELECT * FROM nation WHERE regionkey < 3 LIMIT 11) " +
"GROUP BY regionkey"))
.isNotFullyPushedDown(AggregationNode.class);

// decimals
try (AutoCloseable ignore = withTable("tpch.test_aggregation_pushdown", "(short_decimal decimal(9, 3), long_decimal decimal(30, 10))")) {
execute("INSERT INTO tpch.test_aggregation_pushdown VALUES (100.000, 100000000.000000000)");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import com.google.common.collect.ImmutableMap;
import io.trino.Session;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.testing.AbstractTestIntegrationSmokeTest;
import io.trino.testing.QueryRunner;
Expand Down Expand Up @@ -108,6 +109,14 @@ public void testAggregationPushdown()

assertThat(query("SELECT regionkey, avg(nationkey) FROM nation GROUP BY regionkey")).isFullyPushedDown();

// GROUP BY above WHERE and LIMIT
assertThat(query("" +
"SELECT regionkey, sum(nationkey) " +
"FROM (SELECT * FROM nation WHERE regionkey < 3 LIMIT 11) " +
"GROUP BY regionkey"))
.isNotFullyPushedDown(AggregationNode.class);

// decimals
try (AutoCloseable ignoreTable = withTable("test_aggregation_pushdown", "(short_decimal decimal(9, 3), long_decimal decimal(30, 10), varchar_column varchar(10))")) {
sqlServer.execute("INSERT INTO test_aggregation_pushdown VALUES (100.000, 100000000.000000000, 'ala')");
sqlServer.execute("INSERT INTO test_aggregation_pushdown VALUES (123.321, 123456789.987654321, 'kot')");
Expand Down Expand Up @@ -140,8 +149,8 @@ public void testAggregationPushdown()

// not supported yet
assertThat(query("SELECT min(DISTINCT short_decimal) FROM test_aggregation_pushdown")).isNotFullyPushedDown(AggregationNode.class);
// TODO: Improve assertion framework. Here min(long_decimal) is pushed down. There remains ProjectNode above it which relates to DISTINCT in the query.
assertThat(query("SELECT DISTINCT short_decimal, min(long_decimal) FROM test_aggregation_pushdown GROUP BY short_decimal")).isNotFullyPushedDown(ProjectNode.class);
assertThat(query("SELECT DISTINCT short_decimal, min(long_decimal) FROM test_aggregation_pushdown GROUP BY short_decimal"))
.isNotFullyPushedDown(AggregationNode.class, ProjectNode.class);
}

// array_agg returns array, which is not supported
Expand Down Expand Up @@ -253,9 +262,63 @@ public void testColumnComment()
}

@Test
public void testDecimalPredicatePushdown()
public void testPredicatePushdown()
throws Exception
{
// varchar equality
assertThat(query("SELECT regionkey, nationkey, name FROM nation WHERE name = 'ROMANIA'"))
.matches("VALUES (BIGINT '3', BIGINT '19', CAST('ROMANIA' AS varchar(25)))")
.isFullyPushedDown();

// varchar range
assertThat(query("SELECT regionkey, nationkey, name FROM nation WHERE name BETWEEN 'POLAND' AND 'RPA'"))
.matches("VALUES (BIGINT '3', BIGINT '19', CAST('ROMANIA' AS varchar(25)))")
.isFullyPushedDown();

// varchar different case
assertThat(query("SELECT regionkey, nationkey, name FROM nation WHERE name = 'romania'"))
// TODO https://github.com/trinodb/trino/issues/6671: .returnsEmptyResult()
.matches("VALUES (BIGINT '3', BIGINT '19', CAST('ROMANIA' AS varchar(25)))")
// TODO https://github.com/trinodb/trino/issues/6671: isNotFullyPushedDown(FilterNode.class)
.isFullyPushedDown();

// bigint equality
assertThat(query("SELECT regionkey, nationkey, name FROM nation WHERE nationkey = 19"))
.matches("VALUES (BIGINT '3', BIGINT '19', CAST('ROMANIA' AS varchar(25)))")
.isFullyPushedDown();

// bigint equality with small compaction threshold
assertThat(query(
Session.builder(getSession())
.setCatalogSessionProperty("sqlserver", "domain_compaction_threshold", "1")
.build(),
"SELECT regionkey, nationkey, name FROM nation WHERE nationkey IN (19, 21)"))
.matches("VALUES " +
"(BIGINT '3', BIGINT '19', CAST('ROMANIA' AS varchar(25))), " +
"(BIGINT '2', BIGINT '21', CAST('VIETNAM' AS varchar(25)))")
.isNotFullyPushedDown(FilterNode.class);

// bigint range, with decimal to bigint simplification
assertThat(query("SELECT regionkey, nationkey, name FROM nation WHERE nationkey BETWEEN 18.5 AND 19.5"))
.matches("VALUES (BIGINT '3', BIGINT '19', CAST('ROMANIA' AS varchar(25)))")
.isFullyPushedDown();

// date equality
assertThat(query("SELECT orderkey FROM orders WHERE orderdate = DATE '1992-09-29'"))
.matches("VALUES BIGINT '1250', 34406, 38436, 57570")
.isFullyPushedDown();

// predicate over aggregation key (likely to be optimized before being pushed down into the connector)
assertThat(query("SELECT * FROM (SELECT regionkey, sum(nationkey) FROM nation GROUP BY regionkey) WHERE regionkey = 3"))
.matches("VALUES (BIGINT '3', BIGINT '77')")
.isFullyPushedDown();

// predicate over aggregation result
assertThat(query("SELECT regionkey, sum(nationkey) FROM nation GROUP BY regionkey HAVING sum(nationkey) = 77"))
.matches("VALUES (BIGINT '3', BIGINT '77')")
.isNotFullyPushedDown(FilterNode.class);

// decimals
try (AutoCloseable ignoreTable = withTable("test_decimal_pushdown",
"(short_decimal decimal(9, 3), long_decimal decimal(30, 10))")) {
sqlServer.execute("INSERT INTO test_decimal_pushdown VALUES (123.321, 123456789.987654321)");
Expand Down