diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/MaterializedViewUtils.java b/presto-main-base/src/main/java/com/facebook/presto/sql/MaterializedViewUtils.java index dcc80d61ab1ad..98399ccc19ef4 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/MaterializedViewUtils.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/MaterializedViewUtils.java @@ -15,6 +15,7 @@ package com.facebook.presto.sql; import com.facebook.presto.Session; +import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.common.predicate.NullableValue; import com.facebook.presto.common.predicate.TupleDomain; import com.facebook.presto.metadata.Metadata; @@ -40,7 +41,9 @@ import com.facebook.presto.sql.tree.IsNullPredicate; import com.facebook.presto.sql.tree.LogicalBinaryExpression; import com.facebook.presto.sql.tree.QualifiedName; +import com.facebook.presto.sql.tree.Relation; import com.facebook.presto.sql.tree.SymbolReference; +import com.facebook.presto.sql.tree.Table; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; @@ -60,6 +63,7 @@ import static com.facebook.presto.common.predicate.TupleDomain.extractFixedValues; import static com.facebook.presto.common.type.StandardTypes.HYPER_LOG_LOG; import static com.facebook.presto.common.type.StandardTypes.VARBINARY; +import static com.facebook.presto.metadata.MetadataUtil.createQualifiedObjectName; import static com.facebook.presto.sql.ExpressionUtils.combineDisjuncts; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.NOT_SUPPORTED; import static com.facebook.presto.sql.tree.ArithmeticBinaryExpression.Operator.DIVIDE; @@ -399,6 +403,15 @@ public static Expression convertMaterializedDataPredicatesToExpression( } } + public static Relation resolveTableName(Relation relation, Session session, Metadata metadata) + { + if (!(relation instanceof Table)) { + return relation; + } + QualifiedObjectName qualifiedTableName = createQualifiedObjectName(session, relation, ((Table) relation).getName(), metadata); + return new Table(QualifiedName.of(qualifiedTableName.getSchemaName(), qualifiedTableName.getObjectName())); + } + private static Expression convertSymbolReferencesToIdentifiers(Expression expression) { return ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter() diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/MaterializedViewQueryOptimizer.java b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/MaterializedViewQueryOptimizer.java index 3b6355bdfd317..ac0b7591ab38b 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/MaterializedViewQueryOptimizer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/MaterializedViewQueryOptimizer.java @@ -101,6 +101,7 @@ import static com.facebook.presto.sql.MaterializedViewUtils.COUNT; import static com.facebook.presto.sql.MaterializedViewUtils.NON_ASSOCIATIVE_REWRITE_FUNCTIONS; import static com.facebook.presto.sql.MaterializedViewUtils.SUM; +import static com.facebook.presto.sql.MaterializedViewUtils.resolveTableName; import static com.facebook.presto.sql.analyzer.MaterializedViewInformationExtractor.MaterializedViewInfo; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISSING_TABLE; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.NOT_SUPPORTED; @@ -641,7 +642,8 @@ protected Node visitAliasedRelation(AliasedRelation node, Void context) @Override protected Node visitRelation(Relation node, Void context) { - if (materializedViewInfo.getBaseTable().isPresent() && node.equals(materializedViewInfo.getBaseTable().get())) { + if (materializedViewInfo.getBaseTable().isPresent() && resolveTableName(node, session, metadata) + .equals(resolveTableName(materializedViewInfo.getBaseTable().get(), session, metadata))) { return materializedView; } throw new IllegalStateException("Mismatching table or non-supporting relation format in base query"); diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestMaterializedViewQueryOptimizer.java b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestMaterializedViewQueryOptimizer.java index 108ebb4da54e3..bf70110ea0eb5 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestMaterializedViewQueryOptimizer.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestMaterializedViewQueryOptimizer.java @@ -450,6 +450,42 @@ public void testWithTableAlias() assertOptimizedQuery(baseQuerySqlWithTablePrefix, expectedRewrittenSql, originalViewSqlWithTablePrefix, BASE_TABLE_1, VIEW_1); } + @Test + public void testWithSchemaQualifiedTableName() + { + String schemaQualifiedTable = SESSION_SCHEMA + "." + BASE_TABLE_1; + + String originalViewSql = format("SELECT a, b FROM %s", BASE_TABLE_1); + String baseQuerySql = format("SELECT a, b FROM %s", schemaQualifiedTable); + String expectedRewrittenSql = format("SELECT a, b FROM %s", VIEW_1); + + assertOptimizedQuery(baseQuerySql, expectedRewrittenSql, originalViewSql, BASE_TABLE_1, VIEW_1); + + originalViewSql = format("SELECT a, b FROM %s", schemaQualifiedTable); + baseQuerySql = format("SELECT a, b FROM %s", BASE_TABLE_1); + expectedRewrittenSql = format("SELECT a, b FROM %s", VIEW_1); + + assertOptimizedQuery(baseQuerySql, expectedRewrittenSql, originalViewSql, BASE_TABLE_1, VIEW_1); + + originalViewSql = format("SELECT a, b FROM %s", schemaQualifiedTable); + baseQuerySql = format("SELECT a, b FROM %s", schemaQualifiedTable); + expectedRewrittenSql = format("SELECT a, b FROM %s", VIEW_1); + + assertOptimizedQuery(baseQuerySql, expectedRewrittenSql, originalViewSql, BASE_TABLE_1, VIEW_1); + + originalViewSql = format("SELECT a, b, c FROM %s", BASE_TABLE_1); + baseQuerySql = format("SELECT a, b FROM %s WHERE c > 10", schemaQualifiedTable); + expectedRewrittenSql = format("SELECT a, b FROM %s WHERE c > 10", VIEW_1); + + assertOptimizedQuery(baseQuerySql, expectedRewrittenSql, originalViewSql, BASE_TABLE_1, VIEW_1); + + originalViewSql = format("SELECT SUM(a) as sum_a, b FROM %s GROUP BY b", BASE_TABLE_1); + baseQuerySql = format("SELECT SUM(a), b FROM %s GROUP BY b", schemaQualifiedTable); + expectedRewrittenSql = format("SELECT SUM(sum_a), b FROM %s GROUP BY b", VIEW_1); + + assertOptimizedQuery(baseQuerySql, expectedRewrittenSql, originalViewSql, BASE_TABLE_1, VIEW_1); + } + @Test public void testAggregationWithTableAlias() {