diff --git a/wren-base/src/main/java/io/wren/base/dto/Measure.java b/wren-base/src/main/java/io/wren/base/dto/Measure.java index 1592fc5f1..27f29d6ae 100644 --- a/wren-base/src/main/java/io/wren/base/dto/Measure.java +++ b/wren-base/src/main/java/io/wren/base/dto/Measure.java @@ -81,6 +81,11 @@ public Map getProperties() return properties; } + public Column toColumn() + { + return new Column(name, type, null, false, false, refColumn, properties); + } + @Override public boolean equals(Object o) { diff --git a/wren-base/src/main/java/io/wren/base/dto/Window.java b/wren-base/src/main/java/io/wren/base/dto/Window.java index a31b5c0b6..606ed0c19 100644 --- a/wren-base/src/main/java/io/wren/base/dto/Window.java +++ b/wren-base/src/main/java/io/wren/base/dto/Window.java @@ -17,6 +17,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableMap; +import io.wren.base.WrenTypes; import java.util.Map; import java.util.Objects; @@ -95,6 +96,11 @@ public Map getProperties() return properties; } + public Column toColumn() + { + return new Column(name, WrenTypes.TIMESTAMP, null, false, false, refColumn, properties); + } + @Override public String toString() { diff --git a/wren-base/src/main/java/io/wren/base/sqlrewrite/CacheRewrite.java b/wren-base/src/main/java/io/wren/base/sqlrewrite/CacheRewrite.java deleted file mode 100644 index 4b2ec8c95..000000000 --- a/wren-base/src/main/java/io/wren/base/sqlrewrite/CacheRewrite.java +++ /dev/null @@ -1,230 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.wren.base.sqlrewrite; - -import com.google.common.collect.ImmutableSet; -import io.airlift.log.Logger; -import io.trino.sql.SqlFormatter; -import io.trino.sql.parser.SqlBaseLexer; -import io.trino.sql.tree.DereferenceExpression; -import io.trino.sql.tree.Expression; -import io.trino.sql.tree.Identifier; -import io.trino.sql.tree.Join; -import io.trino.sql.tree.Node; -import io.trino.sql.tree.QualifiedName; -import io.trino.sql.tree.Query; -import io.trino.sql.tree.QuerySpecification; -import io.trino.sql.tree.Statement; -import io.trino.sql.tree.Table; -import io.trino.sql.tree.With; -import io.trino.sql.tree.WithQuery; -import io.wren.base.CatalogSchemaTableName; -import io.wren.base.SessionContext; -import io.wren.base.WrenMDL; -import io.wren.base.sqlrewrite.analyzer.CacheAnalysis; -import io.wren.base.sqlrewrite.analyzer.Field; -import io.wren.base.sqlrewrite.analyzer.Scope; - -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.Set; -import java.util.function.Function; - -import static io.trino.sql.QueryUtil.getQualifiedName; -import static io.trino.sql.SqlFormatter.Dialect.DUCKDB; -import static io.wren.base.sqlrewrite.Utils.analyzeFrom; -import static io.wren.base.sqlrewrite.Utils.parseSql; -import static io.wren.base.sqlrewrite.Utils.toCatalogSchemaTableName; -import static java.lang.String.format; -import static java.util.Locale.ENGLISH; -import static java.util.Objects.requireNonNull; - -public class CacheRewrite -{ - private static final Logger LOG = Logger.get(CacheRewrite.class); - private static final Set KEYWORDS = ImmutableSet.copyOf(SqlBaseLexer.ruleNames); - - private CacheRewrite() {} - - public static Optional rewrite( - SessionContext sessionContext, - String sql, - Function> converter, - WrenMDL wrenMDL) - { - try { - Statement statement = parseSql(sql); - CacheAnalysis aggregationAnalysis = new CacheAnalysis(); - Statement rewritten = (Statement) new Rewriter(sessionContext, converter, wrenMDL, aggregationAnalysis).process(statement, Optional.empty()); - if (rewritten instanceof Query - && aggregationAnalysis.onlyCachedTables()) { - return Optional.of(SqlFormatter.formatSql(rewritten, DUCKDB)); - } - } - catch (Exception e) { - LOG.warn(e, "Failed to rewrite query: %s", sql); - } - return Optional.empty(); - } - - private static class Rewriter - extends BaseRewriter> - { - private final SessionContext sessionContext; - private final Function> converter; - private final Map visitedAggregationTables = new HashMap<>(); - private final WrenMDL wrenMDL; - private final CacheAnalysis aggregationAnalysis; - - public Rewriter( - SessionContext sessionContext, - Function> converter, - WrenMDL wrenMDL, - CacheAnalysis aggregationAnalysis) - { - this.sessionContext = requireNonNull(sessionContext, "sessionContext is null"); - this.converter = requireNonNull(converter, "converter is null"); - this.wrenMDL = requireNonNull(wrenMDL, "wrenMDL is null"); - this.aggregationAnalysis = requireNonNull(aggregationAnalysis, "aggregationAnalysis is null"); - } - - @Override - protected Node visitQuery(Query node, Optional scope) - { - Optional withScope = analyzeWith(node, scope); - return super.visitQuery(node, withScope); - } - - @Override - protected Node visitQuerySpecification(QuerySpecification node, Optional scope) - { - Optional relationScope; - if (node.getFrom().isPresent()) { - relationScope = Optional.of(analyzeFrom(wrenMDL, sessionContext, node.getFrom().get(), scope)); - } - else { - relationScope = scope; - } - return super.visitQuerySpecification(node, relationScope); - } - - @Override - protected Node visitJoin(Join node, Optional scope) - { - return new Join( - node.getType(), - visitAndCast(node.getLeft(), scope), - visitAndCast(node.getRight(), scope), - node.getCriteria().map(criteria -> visitJoinCriteria(criteria, scope))); - } - - @Override - protected Node visitDereferenceExpression(DereferenceExpression node, Optional scope) - { - Expression base = node.getBase(); - if (scope.isPresent()) { - List field = scope.get().getRelationType().resolveFields(getQualifiedName(node)); - if (field.size() == 1) { - QualifiedName qualifiedName = getQualifiedName(base); - if (field.get(0).getRelationAlias().isEmpty() - && visitedAggregationTables.containsKey(qualifiedName)) { - return new DereferenceExpression( - node.getLocation(), - DereferenceExpression.from(QualifiedName.of(visitedAggregationTables.get(qualifiedName))), - node.getField()); - } - } - } - return new DereferenceExpression( - node.getLocation(), - base, - node.getField()); - } - - @Override - protected Node visitTable(Table node, Optional scope) - { - if (scope.isPresent()) { - Optional withQuery = scope.get().getNamedQuery(node.getName().getSuffix()); - if (withQuery.isPresent()) { - return node; - } - } - - CatalogSchemaTableName catalogSchemaTableName = toCatalogSchemaTableName(sessionContext, node.getName()); - aggregationAnalysis.addTable(catalogSchemaTableName); - if (wrenMDL.getCacheInfo(catalogSchemaTableName).isPresent()) { - Optional cachedTableOpt = convertTable(catalogSchemaTableName); - if (cachedTableOpt.isPresent()) { - aggregationAnalysis.addCachedTables(catalogSchemaTableName); - String cachedTable = cachedTableOpt.get(); - String schemaName = catalogSchemaTableName.getSchemaTableName().getSchemaName(); - String tableName = catalogSchemaTableName.getSchemaTableName().getTableName(); - visitedAggregationTables.put(QualifiedName.of(tableName), cachedTable); - visitedAggregationTables.put(QualifiedName.of(schemaName, tableName), cachedTable); - visitedAggregationTables.put(QualifiedName.of(catalogSchemaTableName.getCatalogName(), schemaName, tableName), cachedTable); - if (node.getLocation().isPresent()) { - return new Table( - node.getLocation().get(), - QualifiedName.of(cachedTable)); - } - return new Table(QualifiedName.of(cachedTable)); - } - } - return node; - } - - private Optional convertTable(CatalogSchemaTableName cachedTable) - { - return converter.apply(cachedTable); - } - - private Optional analyzeWith(Query node, Optional scope) - { - if (node.getWith().isEmpty()) { - return Optional.of(Scope.builder().parent(scope).build()); - } - - With with = node.getWith().get(); - Scope.Builder withScopeBuilder = Scope.builder().parent(scope); - - for (WithQuery withQuery : with.getQueries()) { - String name = withQuery.getName().getValue(); - if (withScopeBuilder.containsNamedQuery(name)) { - throw new IllegalArgumentException(format("WITH query name '%s' specified more than once", name)); - } - if (with.isRecursive()) { - withScopeBuilder.namedQuery(name, withQuery); - visitAndCast(withQuery.getQuery(), Optional.of(withScopeBuilder.build())); - } - else { - visitAndCast(withQuery.getQuery(), Optional.of(withScopeBuilder.build())); - withScopeBuilder.namedQuery(name, withQuery); - } - } - - return Optional.of(withScopeBuilder.build()); - } - } - - protected static Identifier identifier(String name) - { - if (KEYWORDS.contains(name.toUpperCase(ENGLISH))) { - return new Identifier(name, true); - } - return new Identifier(name); - } -} diff --git a/wren-base/src/main/java/io/wren/base/sqlrewrite/analyzer/Field.java b/wren-base/src/main/java/io/wren/base/sqlrewrite/analyzer/Field.java index df41d7ab9..e770a9931 100644 --- a/wren-base/src/main/java/io/wren/base/sqlrewrite/analyzer/Field.java +++ b/wren-base/src/main/java/io/wren/base/sqlrewrite/analyzer/Field.java @@ -16,6 +16,7 @@ import io.trino.sql.tree.QualifiedName; import io.wren.base.CatalogSchemaTableName; +import io.wren.base.dto.Column; import io.wren.base.sqlrewrite.Utils; import java.util.Optional; @@ -33,7 +34,7 @@ public class Field private final CatalogSchemaTableName tableName; private final String columnName; private final Optional sourceDatasetName; - private final Optional sourceColumnName; + private final Optional sourceColumn; private final Optional name; private Field( @@ -42,14 +43,14 @@ private Field( String columnName, String name, String sourceDatasetName, - String sourceColumnName) + Column sourceColumn) { this.relationAlias = Optional.ofNullable(relationAlias); this.tableName = requireNonNull(tableName, "modelName is null"); this.columnName = requireNonNull(columnName, "columnName is null"); this.name = Optional.ofNullable(name); this.sourceDatasetName = Optional.ofNullable(sourceDatasetName); - this.sourceColumnName = Optional.ofNullable(sourceColumnName); + this.sourceColumn = Optional.ofNullable(sourceColumn); } public Optional getRelationAlias() @@ -77,9 +78,9 @@ public Optional getSourceDatasetName() return sourceDatasetName; } - public Optional getSourceColumnName() + public Optional getSourceColumn() { - return sourceColumnName; + return sourceColumn; } public boolean matchesPrefix(Optional prefix) @@ -128,7 +129,7 @@ public String toString() ", columnName='" + columnName + '\'' + ", name=" + name + ", sourceDatasetName=" + sourceDatasetName + - ", sourceColumnName=" + sourceColumnName + + ", sourceColumn=" + sourceColumn + '}'; } @@ -144,7 +145,7 @@ public static class Builder private String columnName; private String name; private String sourceModelName; - private String sourceColumnName; + private Column sourceColumn; public Builder() {} @@ -155,7 +156,7 @@ public Builder like(Field field) this.columnName = field.columnName; this.name = field.name.orElse(null); this.sourceModelName = field.sourceDatasetName.orElse(null); - this.sourceColumnName = field.sourceColumnName.orElse(null); + this.sourceColumn = field.sourceColumn.orElse(null); return this; } @@ -189,15 +190,15 @@ public Builder sourceModelName(String sourceModelName) return this; } - public Builder sourceColumnName(String sourceColumnName) + public Builder sourceColumn(Column sourceColumn) { - this.sourceColumnName = sourceColumnName; + this.sourceColumn = sourceColumn; return this; } public Field build() { - return new Field(relationAlias, tableName, columnName, name, sourceModelName, sourceColumnName); + return new Field(relationAlias, tableName, columnName, name, sourceModelName, sourceColumn); } } } diff --git a/wren-base/src/main/java/io/wren/base/sqlrewrite/analyzer/RelationType.java b/wren-base/src/main/java/io/wren/base/sqlrewrite/analyzer/RelationType.java index 6984d9c81..1602f0c3f 100644 --- a/wren-base/src/main/java/io/wren/base/sqlrewrite/analyzer/RelationType.java +++ b/wren-base/src/main/java/io/wren/base/sqlrewrite/analyzer/RelationType.java @@ -47,6 +47,7 @@ public List getFields() public List resolveFields(QualifiedName name) { return fields.stream() + .filter(input -> input.getSourceColumn().stream().anyMatch(column -> column.getRelationship().isEmpty())) .filter(input -> input.canResolve(name)) .collect(toImmutableList()); } @@ -54,6 +55,7 @@ public List resolveFields(QualifiedName name) public Optional resolveAnyField(QualifiedName name) { return fields.stream() + .filter(input -> input.getSourceColumn().stream().anyMatch(column -> column.getRelationship().isEmpty())) .filter(input -> input.canResolve(name)) .findAny(); } diff --git a/wren-base/src/main/java/io/wren/base/sqlrewrite/analyzer/StatementAnalyzer.java b/wren-base/src/main/java/io/wren/base/sqlrewrite/analyzer/StatementAnalyzer.java index 9e4124597..1bad25f1a 100644 --- a/wren-base/src/main/java/io/wren/base/sqlrewrite/analyzer/StatementAnalyzer.java +++ b/wren-base/src/main/java/io/wren/base/sqlrewrite/analyzer/StatementAnalyzer.java @@ -237,7 +237,7 @@ private List createScopeForQuery(Query query, QualifiedName scopeName, Op .name(f.getName().orElse(f.getColumnName())) .tableName(toCatalogSchemaTableName(sessionContext, scopeName)) .sourceModelName(f.getSourceDatasetName().orElse(null)) - .sourceColumnName(f.getSourceColumnName().orElse(null)) + .sourceColumn(f.getSourceColumn().orElse(null)) .build()))); } else { @@ -254,7 +254,7 @@ private List createScopeForQuery(Query query, QualifiedName scopeName, Op .name(name) .tableName(toCatalogSchemaTableName(sessionContext, scopeName)) .sourceModelName(f.getSourceDatasetName().orElse(null)) - .sourceColumnName(f.getSourceColumnName().orElse(null)) + .sourceColumn(f.getSourceColumn().orElse(null)) .build()); continue; } @@ -282,7 +282,7 @@ private List collectFieldFromMDL(CatalogSchemaTableName tableName) .columnName(column.getName()) .name(column.getName()) .sourceModelName(tableName.getSchemaTableName().getTableName()) - .sourceColumnName(column.getName()) + .sourceColumn(column) .build()) .collect(toImmutableList()); } @@ -296,7 +296,7 @@ else if (wrenMDL.getMetric(tableName.getSchemaTableName().getTableName()).isPres .columnName(column.getName()) .name(column.getName()) .sourceModelName(tableName.getSchemaTableName().getTableName()) - .sourceColumnName(column.getName()) + .sourceColumn(column) .build()) .collect(toImmutableList()); } @@ -308,14 +308,14 @@ else if (wrenMDL.getCumulativeMetric(tableName.getSchemaTableName().getTableName .columnName(cumulativeMetric.getWindow().getName()) .name(cumulativeMetric.getWindow().getName()) .sourceModelName(tableName.getSchemaTableName().getTableName()) - .sourceColumnName(cumulativeMetric.getWindow().getName()) + .sourceColumn(cumulativeMetric.getWindow().toColumn()) .build(), Field.builder() .tableName(tableName) .columnName(cumulativeMetric.getMeasure().getName()) .name(cumulativeMetric.getMeasure().getName()) .sourceModelName(tableName.getSchemaTableName().getTableName()) - .sourceColumnName(cumulativeMetric.getMeasure().getName()) + .sourceColumn(cumulativeMetric.getMeasure().toColumn()) .build()); } return ImmutableList.of(); diff --git a/wren-base/src/main/java/io/wren/base/sqlrewrite/analyzer/decisionpoint/DecisionPointAnalyzer.java b/wren-base/src/main/java/io/wren/base/sqlrewrite/analyzer/decisionpoint/DecisionPointAnalyzer.java index 8f4a126a9..464bfe184 100644 --- a/wren-base/src/main/java/io/wren/base/sqlrewrite/analyzer/decisionpoint/DecisionPointAnalyzer.java +++ b/wren-base/src/main/java/io/wren/base/sqlrewrite/analyzer/decisionpoint/DecisionPointAnalyzer.java @@ -33,6 +33,7 @@ import io.wren.base.CatalogSchemaTableName; import io.wren.base.SessionContext; import io.wren.base.WrenMDL; +import io.wren.base.dto.Column; import io.wren.base.sqlrewrite.analyzer.Analysis; import io.wren.base.sqlrewrite.analyzer.Field; import io.wren.base.sqlrewrite.analyzer.StatementAnalyzer; @@ -131,6 +132,7 @@ protected Void visitAllColumns(AllColumns node, DecisionPointContext decisionPoi else { scopedFields.stream() .filter(field -> field.getRelationAlias().filter(alias -> alias.toString().equals(target)).isPresent() || field.getTableName().equals(catalogSchemaTableName)) + .filter(field -> field.getSourceColumn().stream().anyMatch(column -> !column.isCalculated() && column.getRelationship().isEmpty())) .forEach(field -> { decisionPointContext.getBuilder().addSelectItem(new ColumnAnalysis( Optional.empty(), @@ -140,7 +142,7 @@ protected Void visitAllColumns(AllColumns node, DecisionPointContext decisionPoi List.of(new ExprSource( field.getName().orElse(field.getColumnName()), field.getTableName().getSchemaTableName().getTableName(), - field.getSourceColumnName().orElse(null), + field.getSourceColumn().map(Column::getName).orElse(null), node.getLocation().orElse(null))))); }); } @@ -156,18 +158,20 @@ protected Void visitAllColumns(AllColumns node, DecisionPointContext decisionPoi List.of())); } else { - scopedFields.forEach(field -> { - decisionPointContext.getBuilder().addSelectItem(new ColumnAnalysis( - Optional.empty(), - field.getName().orElse(field.getColumnName()), - DEFAULT_ANALYSIS.toMap(), - node.getLocation().orElse(null), - List.of(new ExprSource( + scopedFields.stream() + .filter(field -> field.getSourceColumn().stream().anyMatch(column -> !column.isCalculated() && column.getRelationship().isEmpty())) + .forEach(field -> { + decisionPointContext.getBuilder().addSelectItem(new ColumnAnalysis( + Optional.empty(), field.getName().orElse(field.getColumnName()), - field.getTableName().getSchemaTableName().getTableName(), - field.getSourceColumnName().orElse(null), - node.getLocation().orElse(null))))); - }); + DEFAULT_ANALYSIS.toMap(), + node.getLocation().orElse(null), + List.of(new ExprSource( + field.getName().orElse(field.getColumnName()), + field.getTableName().getSchemaTableName().getTableName(), + field.getSourceColumn().map(Column::getName).orElse(null), + node.getLocation().orElse(null))))); + }); } } return null; diff --git a/wren-base/src/main/java/io/wren/base/sqlrewrite/analyzer/decisionpoint/RelationAnalyzer.java b/wren-base/src/main/java/io/wren/base/sqlrewrite/analyzer/decisionpoint/RelationAnalyzer.java index aeb1cafdc..d993b2e31 100644 --- a/wren-base/src/main/java/io/wren/base/sqlrewrite/analyzer/decisionpoint/RelationAnalyzer.java +++ b/wren-base/src/main/java/io/wren/base/sqlrewrite/analyzer/decisionpoint/RelationAnalyzer.java @@ -42,6 +42,7 @@ import io.trino.sql.tree.Values; import io.wren.base.SessionContext; import io.wren.base.WrenMDL; +import io.wren.base.dto.Column; import io.wren.base.sqlrewrite.analyzer.Analysis; import io.wren.base.sqlrewrite.analyzer.Scope; @@ -261,7 +262,7 @@ protected Void visitIdentifier(Identifier node, Void context) { scope.getRelationType().resolveFields(QualifiedName.of(node.getValue())) .stream().filter(field -> field.getSourceDatasetName().isPresent()) - .forEach(field -> exprSources.add(new ExprSource(node.getValue(), field.getSourceDatasetName().get(), field.getSourceColumnName().orElse(null), node.getLocation().orElse(null)))); + .forEach(field -> exprSources.add(new ExprSource(node.getValue(), field.getSourceDatasetName().get(), field.getSourceColumn().map(Column::getName).orElse(null), node.getLocation().orElse(null)))); return null; } @@ -271,7 +272,7 @@ protected Void visitDereferenceExpression(DereferenceExpression node, Void conte Optional.ofNullable(getQualifiedName(node)).ifPresent(qualifiedName -> scope.getRelationType().resolveFields(qualifiedName) .stream().filter(field -> field.getSourceDatasetName().isPresent()) - .forEach(field -> exprSources.add(new ExprSource(qualifiedName.toString(), field.getSourceDatasetName().get(), field.getSourceColumnName().orElse(null), node.getLocation().orElse(null))))); + .forEach(field -> exprSources.add(new ExprSource(qualifiedName.toString(), field.getSourceDatasetName().get(), field.getSourceColumn().map(Column::getName).orElse(null), node.getLocation().orElse(null))))); return null; } } diff --git a/wren-base/src/test/java/io/wren/base/sqlrewrite/TestCacheRewrite.java b/wren-base/src/test/java/io/wren/base/sqlrewrite/TestCacheRewrite.java deleted file mode 100644 index 9c5380e10..000000000 --- a/wren-base/src/test/java/io/wren/base/sqlrewrite/TestCacheRewrite.java +++ /dev/null @@ -1,554 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.wren.base.sqlrewrite; - -import com.google.common.collect.ImmutableMap; -import io.trino.sql.tree.Statement; -import io.wren.base.CatalogSchemaTableName; -import io.wren.base.SessionContext; -import io.wren.base.WrenMDL; -import io.wren.base.WrenTypes; -import io.wren.base.dto.Column; -import io.wren.base.dto.Metric; -import io.wren.base.dto.Model; -import io.wren.base.dto.TimeGrain; -import io.wren.base.dto.TimeUnit; -import org.testng.annotations.BeforeClass; -import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; - -import java.text.MessageFormat; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.function.Function; - -import static com.google.common.base.MoreObjects.toStringHelper; -import static io.trino.sql.SqlFormatter.Dialect.DUCKDB; -import static io.trino.sql.SqlFormatter.formatSql; -import static io.wren.base.sqlrewrite.Utils.parseSql; -import static java.lang.String.format; -import static org.assertj.core.api.Assertions.assertThat; - -public class TestCacheRewrite -{ - private WrenMDL wrenMDL; - - private static final Map METRIC_CACHE_NAME_MAPPING = - ImmutableMap.builder() - .put(new CatalogSchemaTableName("wren", "test", "Collection"), "table_Collection") - .put(new CatalogSchemaTableName("wren", "test", "AvgCollection"), "table_AvgCollection") - .put(new CatalogSchemaTableName("wren", "test", "t-1"), "table_t-1") - .put(new CatalogSchemaTableName("wren", "test", "Album"), "table_Album") - .put(new CatalogSchemaTableName("wren", "test", "Tag"), "table_Tag") - .build(); - - @BeforeClass - public void init() - { - wrenMDL = WrenMDL.fromManifest(AbstractTestFramework.withDefaultCatalogSchema() - .setModels(List.of( - Model.model("Album", - "select * from (values (1, 'Gusare', 'ZUTOMAYO', 2560, DATE '2023-03-29', TIMESTAMP '2023-04-27 06:06:06'), " + - "(2, 'HisoHiso Banashi', 'ZUTOMAYO', 1500, DATE '2023-04-29', TIMESTAMP '2023-05-27 07:07:07'), " + - "(3, 'Dakara boku wa ongaku o yameta', 'Yorushika', 2553, DATE '2023-05-29', TIMESTAMP '2023-06-27 08:08:08')) " + - "album(id, name, author, price, publish_date, release_date)", - List.of( - Column.column("id", WrenTypes.INTEGER, null, true), - Column.column("name", WrenTypes.VARCHAR, null, true), - Column.column("author", WrenTypes.VARCHAR, null, true), - Column.column("price", WrenTypes.INTEGER, null, true), - Column.column("publish_date", WrenTypes.DATE, null, true), - Column.column("release_date", WrenTypes.TIMESTAMP, null, true)), - true), - Model.model("Tag", - "select * from (VALUES\n" + - " (1, 'U2', 5),\n" + - " (2, 'Blur', 5),\n" + - " (3, 'Oasis', 5),\n" + - " (4, '2Pac', 6),\n" + - " (5, 'Rock', 7),\n" + - " (6, 'Rap', 7),\n" + - " (7, 'Music', 9),\n" + - " (8, 'Movies', 9),\n" + - " (9, 'Art', NULL))\n" + - "tag(id, name, subclassof)", - List.of( - Column.column("id", WrenTypes.INTEGER, null, true), - Column.column("name", WrenTypes.VARCHAR, null, true), - Column.column("subclassof", WrenTypes.INTEGER, null, false)), - true))) - .setMetrics(List.of( - Metric.metric( - "Collection", - "Album", - List.of( - Column.column("author", WrenTypes.VARCHAR, null, true), - Column.column("album_name", WrenTypes.VARCHAR, null, true, "Album.name")), - List.of(Column.column("price", WrenTypes.INTEGER, null, true, "sum(Album.price)")), - List.of( - TimeGrain.timeGrain("p_date", "Album.publish_date", List.of(TimeUnit.YEAR)), - TimeGrain.timeGrain("r_date", "Album.release_date", List.of(TimeUnit.YEAR))), - true), - Metric.metric( - "AvgCollection", - "Album", - List.of( - Column.column("author", WrenTypes.VARCHAR, null, true), - Column.column("album_name", WrenTypes.VARCHAR, null, true, "Album.name")), - List.of(Column.column("price", WrenTypes.DECIMAL, null, true, "avg(Album.price)")), - List.of( - TimeGrain.timeGrain("p_date", "Album.publish_date", List.of(TimeUnit.YEAR)), - TimeGrain.timeGrain("r_date", "Album.release_date", List.of(TimeUnit.YEAR))), - true), - Metric.metric( - "t-1", - "Album", - List.of( - Column.column("author", WrenTypes.VARCHAR, null, true), - Column.column("album_name", WrenTypes.VARCHAR, null, true, "Album.name")), - List.of(Column.column("price", WrenTypes.INTEGER, null, true, "avg(Album.price)")), - List.of( - TimeGrain.timeGrain("p_date", "Album.publish_date", List.of(TimeUnit.YEAR)), - TimeGrain.timeGrain("r_date", "Album.release_date", List.of(TimeUnit.YEAR))), - true))) - .build()); - } - - @DataProvider(name = "oneTableProvider") - public Object[][] oneTableProvider() - { - return new Object[][] { - {OneTableTestData.create("wren", "test", "wren.test.Collection")}, - {OneTableTestData.create("wren", "test", "test.Collection")}, - {OneTableTestData.create("wren", "test", "Collection")}, - {OneTableTestData.create("wren", "w2", "wren.test.Collection")}, - {OneTableTestData.create("wren", "w2", "test.Collection")}, - {OneTableTestData.create("other", "test", "wren.test.Collection")}, - {OneTableTestData.create("other", "w2", "wren.test.Collection")}, - }; - } - - @DataProvider(name = "twoTableProvider") - public Object[][] twoTableProvider() - { - return new Object[][] { - {TwoTableTestData.create("wren", "test", "wren.test.Collection", "wren.test.AvgCollection")}, - {TwoTableTestData.create("wren", "test", "wren.test.Collection", "test.AvgCollection")}, - {TwoTableTestData.create("wren", "test", "test.Collection", "test.AvgCollection")}, - {TwoTableTestData.create("wren", "test", "Collection", "AvgCollection")}, - }; - } - - @Test(dataProvider = "oneTableProvider") - public void testSelect(OneTableTestData testData) - { - assertOneTable("SELECT * FROM {0}", testData); - } - - @Test - public void testSelectModel() - { - assertRewrite("select * from Album", - "wren", - "test", - "select * from table_Album"); - } - - @Test - public void testSelectWithRecursive() - { - // sample from duckdb https://duckdb.org/docs/sql/query_syntax/with.html - assertRewrite("WITH RECURSIVE tag_hierarchy(id, source, path) AS (\n" + - " SELECT id, name, name AS path\n" + - " FROM Tag\n" + - " WHERE subclassof IS NULL\n" + - "UNION ALL\n" + - " SELECT Tag.id, Tag.name, CONCAT(Tag.name, ',', tag_hierarchy.path)\n" + - " FROM Tag, tag_hierarchy\n" + - " WHERE Tag.subclassof = tag_hierarchy.id\n" + - ")\n" + - "SELECT path\n" + - "FROM tag_hierarchy\n" + - "WHERE source = 'Oasis'", - "wren", - "test", - "WITH RECURSIVE tag_hierarchy(id, source, path) AS (\n" + - " SELECT id, name, name AS path\n" + - " FROM table_Tag\n" + - " WHERE subclassof IS NULL\n" + - "UNION ALL\n" + - " SELECT table_Tag.id, table_Tag.name, CONCAT(table_Tag.name, ',', tag_hierarchy.path)\n" + - " FROM table_Tag, tag_hierarchy\n" + - " WHERE table_Tag.subclassof = tag_hierarchy.id\n" + - ")\n" + - "SELECT path\n" + - "FROM tag_hierarchy\n" + - "WHERE source = 'Oasis'"); - } - - @Test(dataProvider = "twoTableProvider") - public void testJoin(TwoTableTestData testData) - { - assertTwoTables("SELECT * FROM {0} a LEFT JOIN {1} b ON a.author = b.author", testData); - } - - @Test - public void testJoinWithoutAlias() - { - String expectSql = "" + - "SELECT * FROM table_Collection " + - "JOIN table_AvgCollection " + - "ON table_Collection.author = table_AvgCollection.author"; - - assertRewrite( - "SELECT * FROM wren.test.Collection JOIN wren.test.AvgCollection ON Collection.author = AvgCollection.author", - "wren", - "test", - expectSql); - assertRewrite( - "SELECT * FROM test.Collection JOIN test.AvgCollection ON Collection.author = AvgCollection.author", - "wren", - "test", - expectSql); - assertRewrite( - "SELECT * FROM Collection JOIN AvgCollection ON Collection.author = AvgCollection.author", - "wren", - "test", - expectSql); - } - - @Test(dataProvider = "twoTableProvider") - public void testUnion(TwoTableTestData testData) - { - assertTwoTables("SELECT * FROM {0} UNION SELECT * FROM {1}", testData); - } - - @Test(dataProvider = "oneTableProvider") - public void testWithQuery(OneTableTestData testData) - { - assertOneTable("WITH table_alias AS (SELECT * FROM {0}) SELECT * FROM table_alias", testData); - } - - @Test(dataProvider = "twoTableProvider") - public void testWithQueryTwoTable(TwoTableTestData testData) - { - assertTwoTables("WITH " + - "table_alias1 AS (SELECT * FROM {0})," + - "table_alias2 AS (SELECT * FROM {1}) " + - "SELECT * FROM table_alias1 JOIN table_alias2 ON table_alias1.author = table_alias2.author", testData); - } - - @Test(dataProvider = "oneTableProvider") - public void testSubquery(OneTableTestData testData) - { - assertOneTable("SELECT * FROM (SELECT * FROM {0}) AS table_alias", testData); - } - - @Test(dataProvider = "oneTableProvider") - public void testInSubquery(OneTableTestData testData) - { - assertOneTable("SELECT * FROM {0} WHERE key IN (SELECT key FROM {0})", testData); - } - - @Test(dataProvider = "oneTableProvider") - public void testRewriteColumns(OneTableTestData testData) - { - assertOneTable("SELECT {0}.author FROM {0}", testData); - } - - @Test(dataProvider = "oneTableProvider") - public void testRewriteColumnsCallFunctionInWhere(OneTableTestData testData) - { - String sql = "SELECT count(*) AS \"count\" " + - "FROM {0} " + - "WHERE date_trunc('day', {0}.author) BETWEEN date_trunc('day', date_add('day', -30, now())) AND date_trunc('day', date_add('day', -1, now()))"; - assertOneTable(sql, testData); - } - - @Test - public void testEscapeDash() - { - assertRewrite( - "SELECT * FROM \"t-1\"", - "wren", - "test", - "SELECT * FROM \"table_t-1\""); - } - - @DataProvider(name = "aliasSameNameProvider") - public Object[][] aliasSameNameProvider() - { - return new Object[][] { - {"SELECT Collection.author author FROM {0} Collection"}, - {"SELECT Collection.column AS author FROM {0} Collection"}, - {"SELECT Collection.column AS author FROM {0} AS Collection"}, - {"SELECT \"Collection\".\"author\" AS \"author\" FROM {0} AS \"Collection\""}, - }; - } - - @Test(dataProvider = "aliasSameNameProvider") - public void testAliasSameName(String sql) - { - assertRewrite(MessageFormat.format(sql, "test.Collection"), - "wren", - "test", - MessageFormat.format(sql, - "table_Collection")); - } - - @DataProvider(name = "columnDereferenceProvider") - public Object[][] columnDereferenceProvider() - { - return new Object[][] { - {"SELECT Collection.author FROM Collection"}, - {"SELECT Collection.author FROM test.Collection"}, - {"SELECT test.Collection.author FROM test.Collection"}, - {"SELECT Collection.author FROM wren.test.Collection"}, - {"SELECT test.Collection.author FROM wren.test.Collection"}, - {"SELECT wren.test.Collection.author FROM wren.test.Collection"}, - }; - } - - @Test(dataProvider = "columnDereferenceProvider") - public void testColumnDereferenceRewrite(String sql) - { - assertRewrite( - sql, - "wren", - "test", - MessageFormat.format("SELECT {0}.author FROM {0}", "table_Collection")); - } - - @Test(dataProvider = "oneTableProvider") - public void testFunction(OneTableTestData testData) - { - assertOneTable("SELECT author, count(*) FROM {0} GROUP BY author", testData); - } - - @Test - public void testTableAliasScope() - { - assertRewrite( - "with test_a as (SELECT * FROM Collection Collection) select * from Collection", - "wren", - "test", - "with test_a as (SELECT * FROM table_Collection Collection) select * from table_Collection"); - - assertRewrite( - "with test_a as (with AvgCollection as (select * from Collection) select * from AvgCollection) select * from AvgCollection", - "wren", - "test", - "with test_a as (with AvgCollection as (select * from table_Collection) select * from AvgCollection) select * from table_AvgCollection"); - } - - @Test - public void testDecimalRewrite() - { - assertRewrite( - "SELECT * from AvgCollection where avg = DECIMAL '1.0'", - "wren", - "test", - "SELECT * FROM table_AvgCollection WHERE avg = 1.0"); - } - - @DataProvider(name = "unexpectedStatementProvider") - public Object[][] unexpectedStatementProvider() - { - return new Object[][] { - {"explain analyze select * from Collection"}, - {"prepare aa from select * from Collection"}, - {"execute aa"}, - {"deallocate prepare aa"}, - {"describe output aa"}, - {"describe input aa"}, - {"explain select * from Collection"}, - {"show tables from test"}, - {"show schemas from wren"}, - {"show catalogs"}, - {"show columns from Collection"}, - {"show stats for Collection"}, - {"show create table Collection"}, - {"show functions"}, - {"show session"}, - {"use wren.test"}, - {"use wren.test"}, - {"set session catalog.name = wren"}, - {"reset session optimize_hash_generation"}, - {"create view test_view as select * from Collection"}, - {"drop view if exists test_view"}, - {"insert into cities values (1, 'San Francisco')"}, - {"call test(name => 'apple', id => 123)"}, - {"delete from lineitem where shipmode = 'AIR'"}, - {"start transaction"}, - {"create role admin"}, - {"drop role admin"}, - {"grant bar to user foo"}, - {"revoke insert, select on orders from alice"}, - {"show grants"}, - {"show role grants from wren"}, - {"commit"}, - {"rollback"}, - {"select 1"}, - }; - } - - @Test(dataProvider = "unexpectedStatementProvider") - public void testUnexpectedStatement(String sql) - { - assertThat(rewriteCached(sql)).isEmpty(); - } - - private void assertOneTable(String sqlFormat, OneTableTestData testData) - { - assertRewrite(MessageFormat.format(sqlFormat, testData.table), - testData.defaultCatalog, - testData.defaultSchema, - MessageFormat.format(sqlFormat, "table_Collection")); - } - - private void assertTwoTables(String sqlFormat, TwoTableTestData testData) - { - assertRewrite(MessageFormat.format(sqlFormat, testData.table1, testData.table2), - testData.defaultCatalog, - testData.defaultSchema, - MessageFormat.format(sqlFormat, "table_Collection", "table_AvgCollection")); - - assertRewrite(MessageFormat.format(sqlFormat, testData.table2, testData.table1), - testData.defaultCatalog, - testData.defaultSchema, - MessageFormat.format(sqlFormat, "table_AvgCollection", "table_Collection")); - } - - private void assertRewrite( - String sql, - String defaultCatalog, - String defaultSchema, - String expectSql) - { - assertRewrite( - sql, - defaultCatalog, - defaultSchema, - expectSql, - this::toCacheTable); - } - - private void assertRewrite( - String sql, - String defaultCatalog, - String defaultSchema, - String expectSql, - Function> tableConverter) - { - String result = rewriteCached( - sql, - defaultCatalog, - defaultSchema, - tableConverter).orElseThrow(() -> new AssertionError("No rewrite result")); - - Statement expect = parseSql(expectSql); - Statement actualStatement = parseSql(result); - assertThat(result).isEqualTo(formatSql(expect, DUCKDB)); - assertThat(actualStatement).isEqualTo(expect); - } - - private Optional rewriteCached(String sql) - { - return rewriteCached( - sql, - "wren", - "test", - this::toCacheTable); - } - - private Optional rewriteCached( - String sql, - String defaultCatalog, - String defaultSchema, - Function> tableConverter) - { - SessionContext sessionContext = SessionContext.builder() - .setCatalog(defaultCatalog) - .setSchema(defaultSchema) - .build(); - return CacheRewrite.rewrite( - sessionContext, - sql, - tableConverter, - wrenMDL); - } - - private static class OneTableTestData - { - private final String defaultCatalog; - private final String defaultSchema; - private final String table; - - private static OneTableTestData create(String defaultCatalog, String defaultSchema, String table) - { - return new OneTableTestData(defaultCatalog, defaultSchema, table); - } - - private OneTableTestData(String defaultCatalog, String defaultSchema, String table) - { - this.defaultCatalog = defaultCatalog; - this.defaultSchema = defaultSchema; - this.table = table; - } - - @Override - public String toString() - { - return toStringHelper(this) - .add("table", format("%s.%s.%s", defaultCatalog, defaultSchema, table)) - .toString(); - } - } - - private static class TwoTableTestData - { - private final String defaultCatalog; - private final String defaultSchema; - private final String table1; - private final String table2; - - private static TwoTableTestData create(String defaultCatalog, String defaultSchema, String table1, String table2) - { - return new TwoTableTestData(defaultCatalog, defaultSchema, table1, table2); - } - - private TwoTableTestData(String defaultCatalog, String defaultSchema, String table1, String table2) - { - this.defaultCatalog = defaultCatalog; - this.defaultSchema = defaultSchema; - this.table1 = table1; - this.table2 = table2; - } - - @Override - public String toString() - { - return toStringHelper(this) - .add("table1", format("%s.%s.%s", defaultCatalog, defaultSchema, table1)) - .add("table2", format("%s.%s.%s", defaultCatalog, defaultSchema, table2)) - .toString(); - } - } - - private Optional toCacheTable(CatalogSchemaTableName tableName) - { - return Optional.ofNullable(METRIC_CACHE_NAME_MAPPING.get(tableName)); - } -} diff --git a/wren-base/src/test/java/io/wren/base/sqlrewrite/analyzer/TestDecisionPointAnalyzer.java b/wren-base/src/test/java/io/wren/base/sqlrewrite/analyzer/TestDecisionPointAnalyzer.java index 9cb2a113b..5d90c92a4 100644 --- a/wren-base/src/test/java/io/wren/base/sqlrewrite/analyzer/TestDecisionPointAnalyzer.java +++ b/wren-base/src/test/java/io/wren/base/sqlrewrite/analyzer/TestDecisionPointAnalyzer.java @@ -21,7 +21,9 @@ import io.wren.base.WrenMDL; import io.wren.base.WrenTypes; import io.wren.base.dto.Column; +import io.wren.base.dto.JoinType; import io.wren.base.dto.Manifest; +import io.wren.base.dto.Relationship; import io.wren.base.sqlrewrite.analyzer.decisionpoint.DecisionPointAnalyzer; import io.wren.base.sqlrewrite.analyzer.decisionpoint.ExprSource; import io.wren.base.sqlrewrite.analyzer.decisionpoint.FilterAnalysis; @@ -72,7 +74,9 @@ public TestDecisionPointAnalyzer() Column.column("orderpriority", WrenTypes.VARCHAR, null, true), Column.column("clerk", WrenTypes.VARCHAR, null, true), Column.column("shippriority", WrenTypes.INTEGER, null, true), - Column.column("comment", WrenTypes.VARCHAR, null, true)); + Column.column("comment", WrenTypes.VARCHAR, null, true), + Column.column("customer", "customer", "CustomerOrders", false), + Column.caluclatedColumn("customer_name", WrenTypes.VARCHAR, "customer.name")); List lineitemColumns = List.of( Column.column("orderkey", WrenTypes.INTEGER, null, true), Column.column("partkey", WrenTypes.INTEGER, null, true), @@ -97,6 +101,7 @@ public TestDecisionPointAnalyzer() .setModels(List.of(onTableReference("customer", tableReference(null, "main", "customer"), customerColumns, "custkey"), onTableReference("orders", tableReference(null, "main", "orders"), ordersColumns, "orderkey"), onTableReference("lineitem", tableReference(null, "main", "lineitem"), lineitemColumns, null))) + .setRelationships(List.of(Relationship.relationship("CustomerOrders", List.of("customer", "orders"), JoinType.ONE_TO_MANY, "customer.custkey = orders.custkey"))) .build()); } @@ -185,6 +190,20 @@ public void testSelectItem() assertThat(result.size()).isEqualTo(1); assertThat(result.get(0).getSelectItems().size()).isEqualTo(1); assertThat(result.get(0).getSelectItems().get(0).getExpression()).isEqualTo("c.*"); + + statement = parseSql("SELECT customer_name FROM orders"); + result = DecisionPointAnalyzer.analyze(statement, DEFAULT_SESSION_CONTEXT, mdl); + assertThat(result.size()).isEqualTo(1); + assertThat(result.get(0).getSelectItems().size()).isEqualTo(1); + assertThat(result.get(0).getSelectItems().get(0).getExpression()).isEqualTo("customer_name"); + assertThat(result.get(0).getSelectItems().get(0).getExprSources()).isEqualTo(List.of(new ExprSource("customer_name", "orders", "customer_name", new NodeLocation(1, 8)))); + + statement = parseSql("SELECT customer FROM orders"); + result = DecisionPointAnalyzer.analyze(statement, DEFAULT_SESSION_CONTEXT, mdl); + assertThat(result.size()).isEqualTo(1); + assertThat(result.get(0).getSelectItems().size()).isEqualTo(1); + assertThat(result.get(0).getSelectItems().get(0).getExpression()).isEqualTo("customer"); + assertThat(result.get(0).getSelectItems().get(0).getExprSources()).isEqualTo(List.of()); } @Test @@ -253,7 +272,6 @@ public void testRelation() assertThat(joinRelation.getRight().getNodeLocation()).isEqualTo(new NodeLocation(1, 30)); assertThat(((TableRelation) joinRelation.getRight()).getTableName()).isEqualTo("orders"); assertThat(joinRelation.getCriteria()).isEqualTo(joinCriteria("ON (customer.custkey = orders.custkey)", new NodeLocation(1, 40))); - assertThat(joinRelation.getExprSources().size()).isEqualTo(2); assertThat(Set.copyOf(joinRelation.getExprSources())).isEqualTo(Set.of( new ExprSource("customer.custkey", "customer", "custkey", new NodeLocation(1, 40)), new ExprSource("orders.custkey", "orders", "custkey", new NodeLocation(1, 59))));