From 236d00b74f3ca0e89bd2df1463c30b947ffb371e Mon Sep 17 00:00:00 2001 From: Ashhar Hasan Date: Fri, 22 Jan 2021 12:41:13 +0530 Subject: [PATCH] Cleanup SqlServer stats function pushdowns --- .../plugin/sqlserver/ImplementSqlServerStddevPop.java | 7 +++++-- .../io/trino/plugin/sqlserver/ImplementSqlServerStdev.java | 2 ++ .../trino/plugin/sqlserver/ImplementSqlServerVariance.java | 2 ++ .../plugin/sqlserver/ImplementSqlServerVariancePop.java | 4 +++- .../java/io/trino/plugin/sqlserver/SqlServerClient.java | 1 + 5 files changed, 13 insertions(+), 3 deletions(-) diff --git a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerStddevPop.java b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerStddevPop.java index 4071064bb721..8bf0f70e5c5f 100644 --- a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerStddevPop.java +++ b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerStddevPop.java @@ -19,6 +19,7 @@ import io.trino.plugin.jdbc.JdbcColumnHandle; import io.trino.plugin.jdbc.JdbcExpression; import io.trino.plugin.jdbc.expression.AggregateFunctionRule; +import io.trino.plugin.jdbc.expression.AggregateFunctionRule.RewriteContext; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.expression.Variable; import io.trino.spi.type.DoubleType; @@ -26,6 +27,7 @@ import java.util.Optional; import static com.google.common.base.Verify.verify; +import static com.google.common.base.Verify.verifyNotNull; import static io.trino.matching.Capture.newCapture; import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.basicAggregation; import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.expressionType; @@ -44,7 +46,7 @@ public class ImplementSqlServerStddevPop public Pattern getPattern() { return basicAggregation() - .with(functionName().equalTo("stddev_pop")) + .with(functionName().matching(name -> name.equalsIgnoreCase("stddev_pop"))) .with(singleInput().matching( variable() .with(expressionType().matching(DoubleType.class::isInstance)) @@ -52,10 +54,11 @@ public Pattern getPattern() } @Override - public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, AggregateFunctionRule.RewriteContext context) + public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) { Variable input = captures.get(INPUT); JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(input.getName()); + verifyNotNull(columnHandle, "Unbound variable: %s", input); verify(columnHandle.getColumnType().equals(DOUBLE)); verify(aggregateFunction.getOutputType().equals(DOUBLE)); diff --git a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerStdev.java b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerStdev.java index 716113310b54..8e779e02bd5c 100644 --- a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerStdev.java +++ b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerStdev.java @@ -28,6 +28,7 @@ import java.util.Optional; import static com.google.common.base.Verify.verify; +import static com.google.common.base.Verify.verifyNotNull; import static io.trino.matching.Capture.newCapture; import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.basicAggregation; import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.expressionType; @@ -59,6 +60,7 @@ public Optional rewrite(AggregateFunction aggregateFunction, Cap { Variable input = captures.get(INPUT); JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(input.getName()); + verifyNotNull(columnHandle, "Unbound variable: %s", input); verify(columnHandle.getColumnType().equals(DOUBLE)); verify(aggregateFunction.getOutputType().equals(DOUBLE)); diff --git a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerVariance.java b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerVariance.java index 317f9f41c7ec..7023a7a1935d 100644 --- a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerVariance.java +++ b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerVariance.java @@ -28,6 +28,7 @@ import java.util.Optional; import static com.google.common.base.Verify.verify; +import static com.google.common.base.Verify.verifyNotNull; import static io.trino.matching.Capture.newCapture; import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.basicAggregation; import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.expressionType; @@ -59,6 +60,7 @@ public Optional rewrite(AggregateFunction aggregateFunction, Cap { Variable input = captures.get(INPUT); JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(input.getName()); + verifyNotNull(columnHandle, "Unbound variable: %s", input); verify(columnHandle.getColumnType().equals(DOUBLE)); verify(aggregateFunction.getOutputType().equals(DOUBLE)); diff --git a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerVariancePop.java b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerVariancePop.java index 2f59c9356b85..faee50b9dc5d 100644 --- a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerVariancePop.java +++ b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementSqlServerVariancePop.java @@ -26,6 +26,7 @@ import java.util.Optional; import static com.google.common.base.Verify.verify; +import static com.google.common.base.Verify.verifyNotNull; import static io.trino.matching.Capture.newCapture; import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.basicAggregation; import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.expressionType; @@ -44,7 +45,7 @@ public class ImplementSqlServerVariancePop public Pattern getPattern() { return basicAggregation() - .with(functionName().equalTo("var_pop")) + .with(functionName().matching(name -> name.equalsIgnoreCase("var_pop"))) .with(singleInput().matching( variable() .with(expressionType().matching(DoubleType.class::isInstance)) @@ -56,6 +57,7 @@ public Optional rewrite(AggregateFunction aggregateFunction, Cap { Variable input = captures.get(INPUT); JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(input.getName()); + verifyNotNull(columnHandle, "Unbound variable: %s", input); verify(columnHandle.getColumnType().equals(DOUBLE)); verify(aggregateFunction.getOutputType().equals(DOUBLE)); diff --git a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java index 18c026f450a9..bfab7928d3a5 100644 --- a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java +++ b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java @@ -142,6 +142,7 @@ public SqlServerClient(BaseJdbcConfig config, ConnectionFactory connectionFactor .add(new ImplementSqlServerStddevPop()) .add(new ImplementSqlServerVariance()) .add(new ImplementSqlServerVariancePop()) + // SQL Server doesn't have covar_samp and covar_pop functions so we can't implement pushdown for them .build()); }