diff --git a/core/trino-main/src/main/java/io/trino/connector/ConnectorManager.java b/core/trino-main/src/main/java/io/trino/connector/ConnectorManager.java index 98d7f254ba6a..9179c3ccf41c 100644 --- a/core/trino-main/src/main/java/io/trino/connector/ConnectorManager.java +++ b/core/trino-main/src/main/java/io/trino/connector/ConnectorManager.java @@ -37,6 +37,7 @@ import io.trino.metadata.ProcedureRegistry; import io.trino.metadata.SchemaPropertyManager; import io.trino.metadata.SessionPropertyManager; +import io.trino.metadata.TableFunctionRegistry; import io.trino.metadata.TableProceduresPropertyManager; import io.trino.metadata.TableProceduresRegistry; import io.trino.metadata.TablePropertyManager; @@ -60,6 +61,7 @@ import io.trino.spi.connector.TableProcedureMetadata; import io.trino.spi.eventlistener.EventListener; import io.trino.spi.procedure.Procedure; +import io.trino.spi.ptf.ConnectorTableFunction; import io.trino.spi.session.PropertyMetadata; import io.trino.spi.type.TypeManager; import io.trino.split.PageSinkManager; @@ -119,6 +121,7 @@ public class ConnectorManager private final TypeManager typeManager; private final ProcedureRegistry procedureRegistry; private final TableProceduresRegistry tableProceduresRegistry; + private final TableFunctionRegistry tableFunctionRegistry; private final SessionPropertyManager sessionPropertyManager; private final SchemaPropertyManager schemaPropertyManager; private final ColumnPropertyManager columnPropertyManager; @@ -158,6 +161,7 @@ public ConnectorManager( TypeManager typeManager, ProcedureRegistry procedureRegistry, TableProceduresRegistry tableProceduresRegistry, + TableFunctionRegistry tableFunctionRegistry, SessionPropertyManager sessionPropertyManager, SchemaPropertyManager schemaPropertyManager, ColumnPropertyManager columnPropertyManager, @@ -186,6 +190,7 @@ public ConnectorManager( this.typeManager = typeManager; this.procedureRegistry = procedureRegistry; this.tableProceduresRegistry = 
tableProceduresRegistry; + this.tableFunctionRegistry = tableFunctionRegistry; this.sessionPropertyManager = sessionPropertyManager; this.schemaPropertyManager = schemaPropertyManager; this.columnPropertyManager = columnPropertyManager; @@ -333,6 +338,7 @@ private synchronized void addConnectorInternal(MaterializedConnector connector) procedureRegistry.addProcedures(catalogName, connector.getProcedures()); Set tableProcedures = connector.getTableProcedures(); tableProceduresRegistry.addTableProcedures(catalogName, tableProcedures); + tableFunctionRegistry.addTableFunctions(catalogName, connector.getTableFunctions()); connector.getAccessControl() .ifPresent(accessControl -> accessControlManager.addCatalogAccessControl(catalogName, accessControl)); @@ -369,6 +375,7 @@ private synchronized void removeConnectorInternal(CatalogName catalogName) nodePartitioningManager.removePartitioningProvider(catalogName); procedureRegistry.removeProcedures(catalogName); tableProceduresRegistry.removeProcedures(catalogName); + tableFunctionRegistry.removeTableFunctions(catalogName); accessControlManager.removeCatalogAccessControl(catalogName); tablePropertyManager.removeProperties(catalogName); materializedViewPropertyManager.removeProperties(catalogName); @@ -495,6 +502,7 @@ private static class MaterializedConnector private final Set systemTables; private final Set procedures; private final Set tableProcedures; + private final Set connectorTableFunctions; private final Optional splitManager; private final Optional pageSourceProvider; private final Optional pageSinkProvider; @@ -527,6 +535,10 @@ public MaterializedConnector(CatalogName catalogName, Connector connector, Runna requireNonNull(tableProcedures, format("Connector '%s' returned a null table procedures set", catalogName)); this.tableProcedures = ImmutableSet.copyOf(tableProcedures); + Set connectorTableFunctions = connector.getTableFunctions(); + requireNonNull(connectorTableFunctions, format("Connector '%s' returned a null 
table functions set", catalogName)); + this.connectorTableFunctions = ImmutableSet.copyOf(connectorTableFunctions); + ConnectorSplitManager splitManager = null; try { splitManager = connector.getSplitManager(); @@ -642,6 +654,11 @@ public Set getTableProcedures() return tableProcedures; } + public Set getTableFunctions() + { + return connectorTableFunctions; + } + public Optional getSplitManager() { return splitManager; diff --git a/core/trino-main/src/main/java/io/trino/metadata/FunctionResolver.java b/core/trino-main/src/main/java/io/trino/metadata/FunctionResolver.java index 22f4935b59a7..a2781957ac60 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/FunctionResolver.java +++ b/core/trino-main/src/main/java/io/trino/metadata/FunctionResolver.java @@ -188,7 +188,7 @@ FunctionBinding resolveFunction( throw new TrinoException(FUNCTION_NOT_FOUND, message); } - private static List toPath(Session session, QualifiedName name) + public static List toPath(Session session, QualifiedName name) { List parts = name.getParts(); checkArgument(parts.size() <= 3, "Function name can only have 3 parts: " + name); diff --git a/core/trino-main/src/main/java/io/trino/metadata/Metadata.java b/core/trino-main/src/main/java/io/trino/metadata/Metadata.java index f15fa85961b9..48d5e6b68582 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/Metadata.java +++ b/core/trino-main/src/main/java/io/trino/metadata/Metadata.java @@ -42,6 +42,7 @@ import io.trino.spi.connector.SortItem; import io.trino.spi.connector.SystemTable; import io.trino.spi.connector.TableColumnsMetadata; +import io.trino.spi.connector.TableFunctionApplicationResult; import io.trino.spi.connector.TableScanRedirectApplicationResult; import io.trino.spi.connector.TopNApplicationResult; import io.trino.spi.expression.ConnectorExpression; @@ -473,6 +474,8 @@ Optional> applyTopN( List sortItems, Map assignments); + Optional> applyTableFunction(Session session, TableFunctionHandle handle); + default void 
validateScan(Session session, TableHandle table) {} // diff --git a/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java b/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java index 096f36b39003..537985f73ef3 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java +++ b/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java @@ -72,6 +72,7 @@ import io.trino.spi.connector.SortItem; import io.trino.spi.connector.SystemTable; import io.trino.spi.connector.TableColumnsMetadata; +import io.trino.spi.connector.TableFunctionApplicationResult; import io.trino.spi.connector.TableScanRedirectApplicationResult; import io.trino.spi.connector.TopNApplicationResult; import io.trino.spi.expression.ConnectorExpression; @@ -1649,6 +1650,18 @@ public Optional> applyTopN( result.isPrecalculateStatistics())); } + @Override + public Optional> applyTableFunction(Session session, TableFunctionHandle handle) + { + CatalogName catalogName = handle.getCatalogName(); + ConnectorMetadata metadata = getMetadata(session, catalogName); + + return metadata.applyTableFunction(session.toConnectorSession(catalogName), handle.getFunctionHandle()) + .map(result -> new TableFunctionApplicationResult<>( + new TableHandle(catalogName, result.getTableHandle(), handle.getTransactionHandle()), + result.getColumnHandles())); + } + private void verifyProjection(TableHandle table, List projections, List assignments, int expectedProjectionSize) { projections.forEach(projection -> requireNonNull(projection, "one of the projections is null")); diff --git a/core/trino-main/src/main/java/io/trino/metadata/TableFunctionHandle.java b/core/trino-main/src/main/java/io/trino/metadata/TableFunctionHandle.java new file mode 100644 index 000000000000..85d330481cfc --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/metadata/TableFunctionHandle.java @@ -0,0 +1,58 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you 
may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.metadata; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import io.trino.connector.CatalogName; +import io.trino.spi.connector.ConnectorTransactionHandle; +import io.trino.spi.ptf.ConnectorTableFunctionHandle; + +import static java.util.Objects.requireNonNull; + +public class TableFunctionHandle +{ + private final CatalogName catalogName; + private final ConnectorTableFunctionHandle functionHandle; + private final ConnectorTransactionHandle transactionHandle; + + @JsonCreator + public TableFunctionHandle( + @JsonProperty("catalogName") CatalogName catalogName, + @JsonProperty("functionHandle") ConnectorTableFunctionHandle functionHandle, + @JsonProperty("transactionHandle") ConnectorTransactionHandle transactionHandle) + { + this.catalogName = requireNonNull(catalogName, "catalogName is null"); + this.functionHandle = requireNonNull(functionHandle, "functionHandle is null"); + this.transactionHandle = requireNonNull(transactionHandle, "transactionHandle is null"); + } + + @JsonProperty + public CatalogName getCatalogName() + { + return catalogName; + } + + @JsonProperty + public ConnectorTableFunctionHandle getFunctionHandle() + { + return functionHandle; + } + + @JsonProperty + public ConnectorTransactionHandle getTransactionHandle() + { + return transactionHandle; + } +} diff --git a/core/trino-main/src/main/java/io/trino/metadata/TableFunctionMetadata.java 
b/core/trino-main/src/main/java/io/trino/metadata/TableFunctionMetadata.java new file mode 100644 index 000000000000..a60ea5e68189 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/metadata/TableFunctionMetadata.java @@ -0,0 +1,41 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.metadata; + +import io.trino.connector.CatalogName; +import io.trino.spi.ptf.ConnectorTableFunction; + +import static java.util.Objects.requireNonNull; + +public class TableFunctionMetadata +{ + private final CatalogName catalogName; + private final ConnectorTableFunction function; + + public TableFunctionMetadata(CatalogName catalogName, ConnectorTableFunction function) + { + this.catalogName = requireNonNull(catalogName, "catalogName is null"); + this.function = requireNonNull(function, "function is null"); + } + + public CatalogName getCatalogName() + { + return catalogName; + } + + public ConnectorTableFunction getFunction() + { + return function; + } +} diff --git a/core/trino-main/src/main/java/io/trino/metadata/TableFunctionRegistry.java b/core/trino-main/src/main/java/io/trino/metadata/TableFunctionRegistry.java new file mode 100644 index 000000000000..08f9627753ac --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/metadata/TableFunctionRegistry.java @@ -0,0 +1,81 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.metadata; + +import com.google.common.collect.ImmutableMap; +import io.trino.Session; +import io.trino.connector.CatalogName; +import io.trino.spi.ptf.ConnectorTableFunction; +import io.trino.sql.tree.QualifiedName; + +import javax.annotation.concurrent.ThreadSafe; + +import java.util.Collection; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import static com.google.common.base.Preconditions.checkState; +import static io.trino.metadata.FunctionResolver.toPath; +import static java.util.Locale.ENGLISH; +import static java.util.Objects.requireNonNull; + +@ThreadSafe +public class TableFunctionRegistry +{ + // catalog name in the original case; schema and function name in lowercase + private final Map> tableFunctions = new ConcurrentHashMap<>(); + + public void addTableFunctions(CatalogName catalogName, Collection functions) + { + requireNonNull(catalogName, "catalogName is null"); + requireNonNull(functions, "functions is null"); + + ImmutableMap.Builder builder = ImmutableMap.builder(); + for (ConnectorTableFunction function : functions) { + builder.put( + new SchemaFunctionName( + function.getSchema().toLowerCase(ENGLISH), + function.getName().toLowerCase(ENGLISH)), + new TableFunctionMetadata(catalogName, function)); + } + checkState(tableFunctions.putIfAbsent(catalogName, builder.buildOrThrow()) == null, "Table functions already registered for catalog: " + catalogName); + } + + public void removeTableFunctions(CatalogName catalogName) + { + tableFunctions.remove(catalogName); + } + + /** + * 
Resolve table function with given qualified name. + * Table functions are resolved case-insensitive for consistency with existing scalar function resolution. + */ + public TableFunctionMetadata resolve(Session session, QualifiedName qualifiedName) + { + for (CatalogSchemaFunctionName name : toPath(session, qualifiedName)) { + CatalogName catalogName = new CatalogName(name.getCatalogName()); + Map catalogFunctions = tableFunctions.get(catalogName); + if (catalogFunctions != null) { + String lowercasedSchemaName = name.getSchemaFunctionName().getSchemaName().toLowerCase(ENGLISH); + String lowercasedFunctionName = name.getSchemaFunctionName().getFunctionName().toLowerCase(ENGLISH); + TableFunctionMetadata function = catalogFunctions.get(new SchemaFunctionName(lowercasedSchemaName, lowercasedFunctionName)); + if (function != null) { + return function; + } + } + } + + return null; + } +} diff --git a/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java b/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java index 65dbe73fbbfd..060d2117e1d8 100644 --- a/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java +++ b/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java @@ -86,6 +86,7 @@ import io.trino.metadata.StaticCatalogStoreConfig; import io.trino.metadata.SystemFunctionBundle; import io.trino.metadata.SystemSecurityMetadata; +import io.trino.metadata.TableFunctionRegistry; import io.trino.metadata.TableProceduresPropertyManager; import io.trino.metadata.TableProceduresRegistry; import io.trino.metadata.TablePropertyManager; @@ -391,6 +392,7 @@ protected void setup(Binder binder) newExporter(binder).export(TypeOperatorsCache.class).withGeneratedName(); binder.bind(ProcedureRegistry.class).in(Scopes.SINGLETON); binder.bind(TableProceduresRegistry.class).in(Scopes.SINGLETON); + binder.bind(TableFunctionRegistry.class).in(Scopes.SINGLETON); binder.bind(PlannerContext.class).in(Scopes.SINGLETON); // function diff 
--git a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java index b5064c84e367..8c630b74ced1 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java @@ -23,6 +23,7 @@ import com.google.common.collect.Multimap; import com.google.common.collect.Multiset; import com.google.common.collect.Streams; +import io.trino.connector.CatalogName; import io.trino.metadata.QualifiedObjectName; import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TableExecuteHandle; @@ -34,10 +35,13 @@ import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ColumnSchema; import io.trino.spi.connector.ConnectorTableMetadata; +import io.trino.spi.connector.ConnectorTransactionHandle; import io.trino.spi.eventlistener.ColumnDetail; import io.trino.spi.eventlistener.ColumnInfo; import io.trino.spi.eventlistener.RoutineInfo; import io.trino.spi.eventlistener.TableInfo; +import io.trino.spi.ptf.Argument; +import io.trino.spi.ptf.ConnectorTableFunctionHandle; import io.trino.spi.security.Identity; import io.trino.spi.type.Type; import io.trino.sql.analyzer.ExpressionAnalyzer.LabelPrefixedReference; @@ -68,6 +72,7 @@ import io.trino.sql.tree.Statement; import io.trino.sql.tree.SubqueryExpression; import io.trino.sql.tree.Table; +import io.trino.sql.tree.TableFunctionInvocation; import io.trino.sql.tree.Unnest; import io.trino.sql.tree.WindowFrame; import io.trino.sql.tree.WindowOperation; @@ -221,6 +226,8 @@ public class Analysis private Optional tableExecuteHandle = Optional.empty(); + private final Map, TableFunctionInvocationAnalysis> tableFunctionAnalyses = new LinkedHashMap<>(); + public Analysis(@Nullable Statement root, Map, Expression> parameters, QueryType queryType) { this.root = root; @@ -1153,6 +1160,16 @@ public Optional getTableExecuteHandle() return tableExecuteHandle; } + public void 
setTableFunctionAnalysis(TableFunctionInvocation node, TableFunctionInvocationAnalysis analysis) + { + tableFunctionAnalyses.put(NodeRef.of(node), analysis); + } + + public TableFunctionInvocationAnalysis getTableFunctionAnalysis(TableFunctionInvocation node) + { + return tableFunctionAnalyses.get(NodeRef.of(node)); + } + private boolean isInputTable(Table table) { return !(isUpdateTarget(table) || isInsertTarget(table)); @@ -1857,4 +1874,52 @@ public Optional getAtMost() return atMost; } } + + public static class TableFunctionInvocationAnalysis + { + private final CatalogName catalogName; + private final String functionName; + private final Map arguments; + private final ConnectorTableFunctionHandle connectorTableFunctionHandle; + private final ConnectorTransactionHandle transactionHandle; + + public TableFunctionInvocationAnalysis( + CatalogName catalogName, + String functionName, + Map arguments, + ConnectorTableFunctionHandle connectorTableFunctionHandle, + ConnectorTransactionHandle transactionHandle) + { + this.catalogName = requireNonNull(catalogName, "catalogName is null"); + this.functionName = requireNonNull(functionName, "functionName is null"); + this.arguments = requireNonNull(arguments, "arguments is null"); + this.connectorTableFunctionHandle = requireNonNull(connectorTableFunctionHandle, "connectorTableFunctionHandle is null"); + this.transactionHandle = requireNonNull(transactionHandle, "transactionHandle is null"); + } + + public CatalogName getCatalogName() + { + return catalogName; + } + + public String getFunctionName() + { + return functionName; + } + + public Map getArguments() + { + return arguments; + } + + public ConnectorTableFunctionHandle getConnectorTableFunctionHandle() + { + return connectorTableFunctionHandle; + } + + public ConnectorTransactionHandle getTransactionHandle() + { + return transactionHandle; + } + } } diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java 
b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java index f119639d8d36..443b8a11710a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java @@ -38,6 +38,8 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.SessionPropertyManager; import io.trino.metadata.TableExecuteHandle; +import io.trino.metadata.TableFunctionMetadata; +import io.trino.metadata.TableFunctionRegistry; import io.trino.metadata.TableHandle; import io.trino.metadata.TableLayout; import io.trino.metadata.TableMetadata; @@ -50,6 +52,7 @@ import io.trino.metadata.ViewDefinition; import io.trino.security.AccessControl; import io.trino.security.AllowAllAccessControl; +import io.trino.security.SecurityContext; import io.trino.security.ViewAccessControl; import io.trino.spi.TrinoException; import io.trino.spi.TrinoWarning; @@ -59,10 +62,21 @@ import io.trino.spi.connector.ColumnMetadata; import io.trino.spi.connector.ColumnSchema; import io.trino.spi.connector.ConnectorTableMetadata; +import io.trino.spi.connector.ConnectorTransactionHandle; import io.trino.spi.connector.PointerType; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.TableProcedureMetadata; import io.trino.spi.function.OperatorType; +import io.trino.spi.ptf.Argument; +import io.trino.spi.ptf.ArgumentSpecification; +import io.trino.spi.ptf.ConnectorTableFunction; +import io.trino.spi.ptf.Descriptor; +import io.trino.spi.ptf.DescriptorArgumentSpecification; +import io.trino.spi.ptf.ReturnTypeSpecification; +import io.trino.spi.ptf.ReturnTypeSpecification.DescribedTable; +import io.trino.spi.ptf.ScalarArgument; +import io.trino.spi.ptf.ScalarArgumentSpecification; +import io.trino.spi.ptf.TableArgumentSpecification; import io.trino.spi.security.AccessDeniedException; import io.trino.spi.security.GroupProvider; import io.trino.spi.security.Identity; @@ -84,6 
+98,7 @@ import io.trino.sql.analyzer.Analysis.ResolvedWindow; import io.trino.sql.analyzer.Analysis.SelectExpression; import io.trino.sql.analyzer.Analysis.SourceColumn; +import io.trino.sql.analyzer.Analysis.TableFunctionInvocationAnalysis; import io.trino.sql.analyzer.Analysis.UnnestAnalysis; import io.trino.sql.analyzer.PatternRecognitionAnalyzer.PatternRecognitionAnalysis; import io.trino.sql.analyzer.Scope.AsteriskedIdentifierChainBasis; @@ -114,6 +129,7 @@ import io.trino.sql.tree.Delete; import io.trino.sql.tree.Deny; import io.trino.sql.tree.DereferenceExpression; +import io.trino.sql.tree.DescriptorArgument; import io.trino.sql.tree.DropColumn; import io.trino.sql.tree.DropMaterializedView; import io.trino.sql.tree.DropSchema; @@ -190,6 +206,8 @@ import io.trino.sql.tree.SubscriptExpression; import io.trino.sql.tree.Table; import io.trino.sql.tree.TableExecute; +import io.trino.sql.tree.TableFunctionArgument; +import io.trino.sql.tree.TableFunctionInvocation; import io.trino.sql.tree.TableSubquery; import io.trino.sql.tree.TruncateTable; import io.trino.sql.tree.Union; @@ -207,6 +225,7 @@ import io.trino.sql.tree.WindowSpecification; import io.trino.sql.tree.With; import io.trino.sql.tree.WithQuery; +import io.trino.transaction.TransactionManager; import io.trino.type.TypeCoercion; import java.util.ArrayList; @@ -239,6 +258,7 @@ import static io.trino.metadata.MetadataUtil.createQualifiedObjectName; import static io.trino.metadata.MetadataUtil.getRequiredCatalogHandle; import static io.trino.spi.StandardErrorCode.AMBIGUOUS_NAME; +import static io.trino.spi.StandardErrorCode.AMBIGUOUS_RETURN_TYPE; import static io.trino.spi.StandardErrorCode.COLUMN_NOT_FOUND; import static io.trino.spi.StandardErrorCode.COLUMN_TYPE_UNKNOWN; import static io.trino.spi.StandardErrorCode.DUPLICATE_COLUMN_NAME; @@ -247,6 +267,7 @@ import static io.trino.spi.StandardErrorCode.DUPLICATE_WINDOW_NAME; import static io.trino.spi.StandardErrorCode.EXPRESSION_NOT_CONSTANT; import 
static io.trino.spi.StandardErrorCode.EXPRESSION_NOT_IN_DISTINCT; +import static io.trino.spi.StandardErrorCode.FUNCTION_NOT_FOUND; import static io.trino.spi.StandardErrorCode.FUNCTION_NOT_WINDOW; import static io.trino.spi.StandardErrorCode.INVALID_ARGUMENTS; import static io.trino.spi.StandardErrorCode.INVALID_COLUMN_REFERENCE; @@ -260,10 +281,12 @@ import static io.trino.spi.StandardErrorCode.INVALID_WINDOW_FRAME; import static io.trino.spi.StandardErrorCode.INVALID_WINDOW_REFERENCE; import static io.trino.spi.StandardErrorCode.MISMATCHED_COLUMN_ALIASES; +import static io.trino.spi.StandardErrorCode.MISSING_ARGUMENT; import static io.trino.spi.StandardErrorCode.MISSING_COLUMN_ALIASES; import static io.trino.spi.StandardErrorCode.MISSING_COLUMN_NAME; import static io.trino.spi.StandardErrorCode.MISSING_GROUP_BY; import static io.trino.spi.StandardErrorCode.MISSING_ORDER_BY; +import static io.trino.spi.StandardErrorCode.MISSING_RETURN_TYPE; import static io.trino.spi.StandardErrorCode.NESTED_RECURSIVE; import static io.trino.spi.StandardErrorCode.NESTED_ROW_PATTERN_RECOGNITION; import static io.trino.spi.StandardErrorCode.NESTED_WINDOW; @@ -279,6 +302,8 @@ import static io.trino.spi.StandardErrorCode.VIEW_IS_RECURSIVE; import static io.trino.spi.StandardErrorCode.VIEW_IS_STALE; import static io.trino.spi.connector.StandardWarningCode.REDUNDANT_ORDER_BY; +import static io.trino.spi.ptf.ReturnTypeSpecification.GenericTable.GENERIC_TABLE; +import static io.trino.spi.ptf.ReturnTypeSpecification.OnlyPassThrough.ONLY_PASS_THROUGH; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.DoubleType.DOUBLE; @@ -330,7 +355,9 @@ class StatementAnalyzer private final SqlParser sqlParser; private final GroupProvider groupProvider; private final AccessControl accessControl; + private final TransactionManager transactionManager; private final TableProceduresRegistry tableProceduresRegistry; + 
private final TableFunctionRegistry tableFunctionRegistry; private final SessionPropertyManager sessionPropertyManager; private final TablePropertyManager tablePropertyManager; private final AnalyzePropertyManager analyzePropertyManager; @@ -346,8 +373,10 @@ class StatementAnalyzer SqlParser sqlParser, GroupProvider groupProvider, AccessControl accessControl, + TransactionManager transactionManager, Session session, TableProceduresRegistry tableProceduresRegistry, + TableFunctionRegistry tableFunctionRegistry, SessionPropertyManager sessionPropertyManager, TablePropertyManager tablePropertyManager, AnalyzePropertyManager analyzePropertyManager, @@ -363,8 +392,10 @@ class StatementAnalyzer this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); this.groupProvider = requireNonNull(groupProvider, "groupProvider is null"); this.accessControl = requireNonNull(accessControl, "accessControl is null"); + this.transactionManager = requireNonNull(transactionManager, "transactionManager is null"); this.session = requireNonNull(session, "session is null"); this.tableProceduresRegistry = requireNonNull(tableProceduresRegistry, "tableProceduresRegistry is null"); + this.tableFunctionRegistry = requireNonNull(tableFunctionRegistry, "tableFunctionRegistry is null"); this.sessionPropertyManager = requireNonNull(sessionPropertyManager, "sessionPropertyManager is null"); this.tablePropertyManager = requireNonNull(tablePropertyManager, "tablePropertyManager is null"); this.analyzePropertyManager = requireNonNull(analyzePropertyManager, "analyzePropertyManager is null"); @@ -1437,6 +1468,231 @@ protected Scope visitLateral(Lateral node, Optional scope) return createAndAssignScope(node, scope, queryScope.getRelationType()); } + @Override + protected Scope visitTableFunctionInvocation(TableFunctionInvocation node, Optional scope) + { + TableFunctionMetadata tableFunctionMetadata = tableFunctionRegistry.resolve(session, node.getName()); + if (tableFunctionMetadata == null) { + 
throw semanticException(FUNCTION_NOT_FOUND, node, "Table function %s not registered", node.getName()); + } + + ConnectorTableFunction function = tableFunctionMetadata.getFunction(); + CatalogName catalogName = tableFunctionMetadata.getCatalogName(); + + QualifiedObjectName functionName = new QualifiedObjectName(catalogName.getCatalogName(), function.getSchema(), function.getName()); + accessControl.checkCanExecuteFunction(SecurityContext.of(session), functionName); + + Map passedArguments = analyzeArguments(node, function.getArguments(), node.getArguments()); + + // a call to getRequiredCatalogHandle() is necessary so that the catalog is recorded by the TransactionManager + ConnectorTransactionHandle transactionHandle = transactionManager.getConnectorTransaction( + session.getRequiredTransactionId(), + getRequiredCatalogHandle(metadata, session, node, catalogName.getCatalogName())); + ConnectorTableFunction.Analysis functionAnalysis = function.analyze(session.toConnectorSession(catalogName), transactionHandle, passedArguments); + analysis.setTableFunctionAnalysis(node, new TableFunctionInvocationAnalysis(catalogName, functionName.toString(), passedArguments, functionAnalysis.getHandle(), transactionHandle)); + + // TODO handle the DescriptorMapping descriptorsToTables mapping from the TableFunction.Analysis: + // This is a mapping of descriptor arguments to table arguments. It consists of two parts: + // - mapping by descriptor field: (arg name of descriptor argument, and position in the descriptor) to (arg name of table argument) + // - mapping by descriptor: (arg name of descriptor argument) to (arg name of table argument) + // 1. get the DescriptorField from the designated DescriptorArgument (or all fields for mapping by descriptor) + // 2. validate there is no DataType specified, + // 3. analyze the Identifier in the scope of the designated table (it is recorded, because args were already analyzed). Disable correlation. + // 4. 
at this point, the Identifier should be recorded as a column reference to the appropriate table + // 5. record the mapping NameAndPosition -> Identifier + // ... later translate Identifier to Symbol in Planner, and eventually translate it to channel before execution + if (!functionAnalysis.getDescriptorsToTables().isEmpty()) { + throw semanticException(NOT_SUPPORTED, node, "Table arguments are not yet supported for table functions"); + } + + // TODO process the copartitioning: + // 1. validate input table references + // 2. the copartitioned tables in each set must be partitioned, and have the same number of partitioning columns + // 3. the corresponding columns must be comparable + // 4. within a set, determine and record coercions of the corresponding columns to a common supertype + // Note that if a table is part of multiple copartitioning sets, it might require a different coercion for a column + // per each set. Additionally, there might be another coercion required by the Table Function logic. Also, since + // all partitioning columns are passed-through, we also need an un-coerced copy. + // See ExpressionAnalyzer.sortKeyCoercionsForFrameBoundCalculation for multiple coercions on a column. + if (!node.getCopartitioning().isEmpty()) { + throw semanticException(NOT_SUPPORTED, node, "COPARTITION clause is not yet supported for table functions"); + } + + // determine the result relation type. + // The result relation type of a table function consists of: + // 1. passed columns from input tables: + // - for tables with the "pass through columns" option, these are all columns of the table, + // - for tables without the "pass through columns" option, these are the partitioning columns of the table, if any. + // 2. columns created by the table function, called the proper columns. 
+ ReturnTypeSpecification returnTypeSpecification = function.getReturnTypeSpecification(); + Optional analyzedProperColumnsDescriptor = functionAnalysis.getReturnedType(); + Descriptor properColumnsDescriptor; + if (returnTypeSpecification == ONLY_PASS_THROUGH) { + // this option is only allowed if there are input tables + throw semanticException(NOT_SUPPORTED, node, "Returning only pass through columns is not yet supported for table functions"); + } + if (returnTypeSpecification == GENERIC_TABLE) { + properColumnsDescriptor = analyzedProperColumnsDescriptor + .orElseThrow(() -> semanticException(MISSING_RETURN_TYPE, node, "Cannot determine returned relation type for table function " + node.getName())); + } + else { + // returned type is statically declared at function declaration and cannot be overridden + if (analyzedProperColumnsDescriptor.isPresent()) { + throw semanticException(AMBIGUOUS_RETURN_TYPE, node, "Returned relation type for table function %s is ambiguous", node.getName()); + } + properColumnsDescriptor = ((DescribedTable) returnTypeSpecification).getDescriptor(); + } + + // currently we don't support input tables, so the output consists of proper columns only + List fields = properColumnsDescriptor.getFields().stream() + // per spec, field names are mandatory + .map(field -> Field.newUnqualified(field.getName(), field.getType().orElseThrow(() -> new IllegalStateException("missing returned type for proper field")))) + .collect(toImmutableList()); + + return createAndAssignScope(node, scope, fields); + } + + private Map analyzeArguments(Node node, List argumentSpecifications, List arguments) + { + Node errorLocation = node; + if (!arguments.isEmpty()) { + errorLocation = arguments.get(0); + } + + if (argumentSpecifications.size() < arguments.size()) { + throw semanticException(INVALID_ARGUMENTS, errorLocation, "Too many arguments. 
Expected at most %s arguments, got %s arguments", argumentSpecifications.size(), arguments.size()); + } + + if (arguments.isEmpty()) { + return ImmutableMap.of(); + } + + boolean argumentsPassedByName = arguments.stream().allMatch(argument -> argument.getName().isPresent()); + boolean argumentsPassedByPosition = arguments.stream().allMatch(argument -> argument.getName().isEmpty()); + if (!argumentsPassedByName && !argumentsPassedByPosition) { + throw semanticException(INVALID_ARGUMENTS, errorLocation, "All arguments must be passed by name or all must be passed positionally"); + } + + ImmutableMap.Builder passedArguments = ImmutableMap.builder(); + if (argumentsPassedByName) { + Map argumentSpecificationsByName = new HashMap<>(); + for (ArgumentSpecification argumentSpecification : argumentSpecifications) { + if (argumentSpecificationsByName.put(argumentSpecification.getName(), argumentSpecification) != null) { + throw new IllegalStateException("Duplicate argument specification for name: " + argumentSpecification.getName()); + } + } + Set uniqueArgumentNames = new HashSet<>(); + for (TableFunctionArgument argument : arguments) { + String argumentName = argument.getName().get().getCanonicalValue(); + if (!uniqueArgumentNames.add(argumentName)) { + throw semanticException(INVALID_FUNCTION_ARGUMENT, argument, "Duplicate argument name: %s", argumentName); + } + ArgumentSpecification argumentSpecification = argumentSpecificationsByName.remove(argumentName); + if (argumentSpecification == null) { + throw semanticException(INVALID_FUNCTION_ARGUMENT, argument, "Unexpected argument name: %s", argumentName); + } + passedArguments.put(argumentSpecification.getName(), analyzeArgument(argumentSpecification, argument)); + } + // apply defaults for not specified arguments + for (Map.Entry entry : argumentSpecificationsByName.entrySet()) { + ArgumentSpecification argumentSpecification = entry.getValue(); + passedArguments.put(argumentSpecification.getName(), 
analyzeDefault(argumentSpecification, errorLocation)); + } + } + else { + for (int i = 0; i < arguments.size(); i++) { + TableFunctionArgument argument = arguments.get(i); + ArgumentSpecification argumentSpecification = argumentSpecifications.get(i); // TODO args passed positionally - can one only pass some prefix of args? + passedArguments.put(argumentSpecification.getName(), analyzeArgument(argumentSpecification, argument)); + } + // apply defaults for not specified arguments + for (int i = arguments.size(); i < argumentSpecifications.size(); i++) { + ArgumentSpecification argumentSpecification = argumentSpecifications.get(i); + passedArguments.put(argumentSpecification.getName(), analyzeDefault(argumentSpecification, errorLocation)); + } + } + + return passedArguments.buildOrThrow(); + } + + private Argument analyzeArgument(ArgumentSpecification argumentSpecification, TableFunctionArgument argument) + { + String actualType; + if (argument.getValue() instanceof Relation) { + actualType = "table"; + } + else if (argument.getValue() instanceof DescriptorArgument) { + actualType = "descriptor"; + } + else if (argument.getValue() instanceof Expression) { + actualType = "expression"; + } + else { + throw semanticException(INVALID_FUNCTION_ARGUMENT, argument, "Unexpected table function argument type: %s", argument.getClass().getSimpleName()); + } + + if (argumentSpecification instanceof TableArgumentSpecification) { + if (!(argument.getValue() instanceof Relation)) { + if (argument.getValue() instanceof FunctionCall) { + // probably an attempt to pass a table function call, which is not supported, and was parsed as a function call + throw semanticException(NOT_SUPPORTED, argument, "Invalid table argument %s. Table functions are not allowed as table function arguments", argumentSpecification.getName()); + } + throw semanticException(INVALID_FUNCTION_ARGUMENT, argument, "Invalid argument %s. 
Expected table, got %s", argumentSpecification.getName(), actualType); + } + // TODO analyze the argument + // 1. process the Relation + // 2. partitioning and ordering must only apply to tables with set semantics + // 3. validate partitioning and ordering using `validateAndGetInputField()` + // 4. validate the prune when empty property vs argument specification (forbidden for row semantics; override? -> check spec) + // 5. return Argument + throw semanticException(NOT_SUPPORTED, argument, "Table arguments are not yet supported for table functions"); + } + if (argumentSpecification instanceof DescriptorArgumentSpecification) { + if (!(argument.getValue() instanceof DescriptorArgument)) { + if (argument.getValue() instanceof FunctionCall && ((FunctionCall) argument.getValue()).getName().hasSuffix(QualifiedName.of("descriptor"))) { // function name is always compared case-insensitive + // malformed descriptor which parsed as a function call + throw semanticException(INVALID_FUNCTION_ARGUMENT, argument, "Invalid descriptor argument %s. Descriptors should be formatted as 'DESCRIPTOR(name [type], ...)'", argumentSpecification.getName()); + } + throw semanticException(INVALID_FUNCTION_ARGUMENT, argument, "Invalid argument %s. Expected descriptor, got %s", argumentSpecification.getName(), actualType); + } + throw semanticException(NOT_SUPPORTED, argument, "Descriptor arguments are not yet supported for table functions"); + } + if (argumentSpecification instanceof ScalarArgumentSpecification) { + if (!(argument.getValue() instanceof Expression)) { + throw semanticException(INVALID_FUNCTION_ARGUMENT, argument, "Invalid argument %s. 
Expected expression, got %s", argumentSpecification.getName(), actualType); + } + Expression expression = (Expression) argument.getValue(); + // 'descriptor' as a function name is not allowed in this context + if (argument.getValue() instanceof FunctionCall && ((FunctionCall) argument.getValue()).getName().hasSuffix(QualifiedName.of("descriptor"))) { // function name is always compared case-insensitive + throw semanticException(INVALID_FUNCTION_ARGUMENT, argument, "'descriptor' function is not allowed as a table function argument"); + } + Type expectedArgumentType = ((ScalarArgumentSpecification) argumentSpecification).getType(); + // currently, only constant arguments are supported + Object constantValue = ExpressionInterpreter.evaluateConstantExpression(expression, expectedArgumentType, plannerContext, session, accessControl, analysis.getParameters()); + return new ScalarArgument(expectedArgumentType, constantValue); // TODO test coercion, test parameter + } + + throw new IllegalStateException("Unexpected argument specification: " + argumentSpecification.getClass().getSimpleName()); + } + + private Argument analyzeDefault(ArgumentSpecification argumentSpecification, Node errorLocation) + { + if (argumentSpecification.isRequired()) { + throw semanticException(MISSING_ARGUMENT, errorLocation, "Missing argument: " + argumentSpecification.getName()); + } + + checkArgument(!(argumentSpecification instanceof TableArgumentSpecification), "invalid table argument specification: default set"); + + if (argumentSpecification instanceof DescriptorArgumentSpecification) { + throw semanticException(NOT_SUPPORTED, errorLocation, "Descriptor arguments are not yet supported for table functions"); + } + if (argumentSpecification instanceof ScalarArgumentSpecification) { + return new ScalarArgument(((ScalarArgumentSpecification) argumentSpecification).getType(), argumentSpecification.getDefaultValue()); + } + + throw new IllegalStateException("Unexpected argument specification: " + 
argumentSpecification.getClass().getSimpleName()); + } + private Optional getMaterializedViewStorageTableName(MaterializedViewDefinition viewDefinition) { if (viewDefinition.getStorageTable().isEmpty()) { diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzerFactory.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzerFactory.java index 323c1b2ad4a6..eece35f09571 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzerFactory.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzerFactory.java @@ -18,6 +18,7 @@ import io.trino.execution.warnings.WarningCollector; import io.trino.metadata.AnalyzePropertyManager; import io.trino.metadata.SessionPropertyManager; +import io.trino.metadata.TableFunctionRegistry; import io.trino.metadata.TableProceduresPropertyManager; import io.trino.metadata.TableProceduresRegistry; import io.trino.metadata.TablePropertyManager; @@ -25,6 +26,8 @@ import io.trino.spi.security.GroupProvider; import io.trino.sql.PlannerContext; import io.trino.sql.parser.SqlParser; +import io.trino.transaction.NoOpTransactionManager; +import io.trino.transaction.TransactionManager; import javax.inject.Inject; @@ -35,8 +38,10 @@ public class StatementAnalyzerFactory private final PlannerContext plannerContext; private final SqlParser sqlParser; private final AccessControl accessControl; + private final TransactionManager transactionManager; private final GroupProvider groupProvider; private final TableProceduresRegistry tableProceduresRegistry; + private final TableFunctionRegistry tableFunctionRegistry; private final SessionPropertyManager sessionPropertyManager; private final TablePropertyManager tablePropertyManager; private final AnalyzePropertyManager analyzePropertyManager; @@ -47,8 +52,10 @@ public StatementAnalyzerFactory( PlannerContext plannerContext, SqlParser sqlParser, AccessControl accessControl, + TransactionManager transactionManager, 
GroupProvider groupProvider, TableProceduresRegistry tableProceduresRegistry, + TableFunctionRegistry tableFunctionRegistry, SessionPropertyManager sessionPropertyManager, TablePropertyManager tablePropertyManager, AnalyzePropertyManager analyzePropertyManager, @@ -57,8 +64,10 @@ public StatementAnalyzerFactory( this.plannerContext = requireNonNull(plannerContext, "plannerContext is null"); this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); this.accessControl = requireNonNull(accessControl, "accessControl is null"); + this.transactionManager = requireNonNull(transactionManager, "transactionManager is null"); this.groupProvider = requireNonNull(groupProvider, "groupProvider is null"); this.tableProceduresRegistry = requireNonNull(tableProceduresRegistry, "tableProceduresRegistry is null"); + this.tableFunctionRegistry = requireNonNull(tableFunctionRegistry, "tableFunctionRegistry is null"); this.sessionPropertyManager = requireNonNull(sessionPropertyManager, "sessionPropertyManager is null"); this.tablePropertyManager = requireNonNull(tablePropertyManager, "tablePropertyManager is null"); this.analyzePropertyManager = requireNonNull(analyzePropertyManager, "analyzePropertyManager is null"); @@ -71,8 +80,10 @@ public StatementAnalyzerFactory withSpecializedAccessControl(AccessControl acces plannerContext, sqlParser, accessControl, + transactionManager, groupProvider, tableProceduresRegistry, + tableFunctionRegistry, sessionPropertyManager, tablePropertyManager, analyzePropertyManager, @@ -92,8 +103,10 @@ public StatementAnalyzer createStatementAnalyzer( sqlParser, groupProvider, accessControl, + transactionManager, session, tableProceduresRegistry, + tableFunctionRegistry, sessionPropertyManager, tablePropertyManager, analyzePropertyManager, @@ -112,8 +125,10 @@ public static StatementAnalyzerFactory createTestingStatementAnalyzerFactory( plannerContext, new SqlParser(), accessControl, + new NoOpTransactionManager(), user -> ImmutableSet.of(), new 
TableProceduresRegistry(), + new TableFunctionRegistry(), new SessionPropertyManager(), tablePropertyManager, analyzePropertyManager, diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/DomainTranslator.java b/core/trino-main/src/main/java/io/trino/sql/planner/DomainTranslator.java index 8823a74140ca..277bd8d36c41 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/DomainTranslator.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/DomainTranslator.java @@ -25,6 +25,7 @@ import io.trino.metadata.OperatorNotFoundException; import io.trino.metadata.ResolvedFunction; import io.trino.metadata.SessionPropertyManager; +import io.trino.metadata.TableFunctionRegistry; import io.trino.metadata.TableProceduresPropertyManager; import io.trino.metadata.TableProceduresRegistry; import io.trino.metadata.TablePropertyManager; @@ -64,6 +65,7 @@ import io.trino.sql.tree.NullLiteral; import io.trino.sql.tree.StringLiteral; import io.trino.sql.tree.SymbolReference; +import io.trino.transaction.NoOpTransactionManager; import io.trino.type.LikeFunctions; import io.trino.type.TypeCoercion; @@ -310,8 +312,10 @@ public static ExtractionResult getExtractionResult(PlannerContext plannerContext plannerContext, new SqlParser(), new AllowAllAccessControl(), + new NoOpTransactionManager(), user -> ImmutableSet.of(), new TableProceduresRegistry(), + new TableFunctionRegistry(), new SessionPropertyManager(), new TablePropertyManager(), new AnalyzePropertyManager(), diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java index c9cec05ec5af..7a4f20886685 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java @@ -218,6 +218,7 @@ import io.trino.sql.planner.plan.TableDeleteNode; import io.trino.sql.planner.plan.TableExecuteNode; import 
io.trino.sql.planner.plan.TableFinishNode; +import io.trino.sql.planner.plan.TableFunctionNode; import io.trino.sql.planner.plan.TableScanNode; import io.trino.sql.planner.plan.TableWriterNode; import io.trino.sql.planner.plan.TableWriterNode.DeleteTarget; @@ -1645,6 +1646,12 @@ else if (inputSymbols.get(i).equals(matchNumberSymbol)) { return pageFunctionCompiler.compileProjection(rowExpression, Optional.empty()); } + @Override + public PhysicalOperation visitTableFunction(TableFunctionNode node, LocalExecutionPlanContext context) + { + throw new UnsupportedOperationException("execution by operator is not yet implemented for table function " + node.getName()); + } + @Override public PhysicalOperation visitTopN(TopNNode node, LocalExecutionPlanContext context) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java index 9e968402a424..89fc048639dd 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java @@ -212,6 +212,7 @@ import io.trino.sql.planner.iterative.rule.ReplaceRedundantJoinWithSource; import io.trino.sql.planner.iterative.rule.ReplaceWindowWithRowNumber; import io.trino.sql.planner.iterative.rule.RewriteSpatialPartitioningAggregation; +import io.trino.sql.planner.iterative.rule.RewriteTableFunctionToTableScan; import io.trino.sql.planner.iterative.rule.SimplifyCountOverConstant; import io.trino.sql.planner.iterative.rule.SimplifyExpressions; import io.trino.sql.planner.iterative.rule.SimplifyFilterPredicate; @@ -605,6 +606,7 @@ public PlanOptimizers( .add(new PushAggregationIntoTableScan(plannerContext, typeAnalyzer)) .add(new PushDistinctLimitIntoTableScan(plannerContext, typeAnalyzer)) .add(new PushTopNIntoTableScan(metadata)) + .add(new RewriteTableFunctionToTableScan(plannerContext)) .build(); IterativeOptimizer pushIntoTableScanOptimizer = new 
IterativeOptimizer( plannerContext, diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java index 93a0ff9990bd..29dec22e387a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java @@ -19,13 +19,16 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.ListMultimap; import io.trino.Session; +import io.trino.metadata.TableFunctionHandle; import io.trino.metadata.TableHandle; import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.ptf.NameAndPosition; import io.trino.spi.type.RowType; import io.trino.spi.type.Type; import io.trino.sql.ExpressionUtils; import io.trino.sql.PlannerContext; import io.trino.sql.analyzer.Analysis; +import io.trino.sql.analyzer.Analysis.TableFunctionInvocationAnalysis; import io.trino.sql.analyzer.Analysis.UnnestAnalysis; import io.trino.sql.analyzer.Field; import io.trino.sql.analyzer.RelationType; @@ -42,6 +45,8 @@ import io.trino.sql.planner.plan.PlanNode; import io.trino.sql.planner.plan.ProjectNode; import io.trino.sql.planner.plan.SampleNode; +import io.trino.sql.planner.plan.TableFunctionNode; +import io.trino.sql.planner.plan.TableFunctionNode.TableArgumentProperties; import io.trino.sql.planner.plan.TableScanNode; import io.trino.sql.planner.plan.UnionNode; import io.trino.sql.planner.plan.UnnestNode; @@ -85,6 +90,7 @@ import io.trino.sql.tree.SubqueryExpression; import io.trino.sql.tree.SubsetDefinition; import io.trino.sql.tree.Table; +import io.trino.sql.tree.TableFunctionInvocation; import io.trino.sql.tree.TableSubquery; import io.trino.sql.tree.Union; import io.trino.sql.tree.Unnest; @@ -319,6 +325,42 @@ private RelationPlan addColumnMasks(Table table, RelationPlan plan) return new RelationPlan(planBuilder.getRoot(), plan.getScope(), plan.getFieldMappings(), outerContext); } + 
@Override + protected RelationPlan visitTableFunctionInvocation(TableFunctionInvocation node, Void context) + { + TableFunctionInvocationAnalysis functionAnalysis = analysis.getTableFunctionAnalysis(node); + + // TODO handle input relations: + // 1. extract the input relations from node.getArguments() and plan them. Apply relation coercions if requested. + // 2. for each input relation, prepare the TableArgumentProperties record, consisting of: + // - row or set semantics (from the actualArgument) + // - prune when empty property (from the actualArgument) + // - pass through columns property (from the actualArgument) + // - optional Specification: ordering scheme and partitioning (from the node's argument) <- planned upon the source's RelationPlan (or combined RelationPlan from all sources) + List sources = ImmutableList.of(); + List inputRelationsProperties = ImmutableList.of(); + // TODO rewrite column references to Symbols upon the source's RelationPlan (or combined RelationPlan from all sources) + Map inputDescriptorMappings = ImmutableMap.of(); + + Scope scope = analysis.getScope(node); + // TODO pass columns from input relations, and make sure they have the right qualifier + List outputSymbols = scope.getRelationType().getAllFields().stream() + .map(symbolAllocator::newSymbol) + .collect(toImmutableList()); + + PlanNode root = new TableFunctionNode( + idAllocator.getNextId(), + functionAnalysis.getFunctionName(), + functionAnalysis.getArguments(), + outputSymbols, + sources.stream().map(RelationPlan::getRoot).collect(toImmutableList()), + inputRelationsProperties, + inputDescriptorMappings, + new TableFunctionHandle(functionAnalysis.getCatalogName(), functionAnalysis.getConnectorTableFunctionHandle(), functionAnalysis.getTransactionHandle())); + + return new RelationPlan(root, scope, outputSymbols, outerContext); + } + @Override protected RelationPlan visitAliasedRelation(AliasedRelation node, Void context) { diff --git 
a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveUnsupportedDynamicFilters.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveUnsupportedDynamicFilters.java index 8036942bda29..2af7fe34fc2a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveUnsupportedDynamicFilters.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveUnsupportedDynamicFilters.java @@ -20,6 +20,7 @@ import io.trino.metadata.AnalyzePropertyManager; import io.trino.metadata.OperatorNotFoundException; import io.trino.metadata.SessionPropertyManager; +import io.trino.metadata.TableFunctionRegistry; import io.trino.metadata.TableProceduresPropertyManager; import io.trino.metadata.TableProceduresRegistry; import io.trino.metadata.TablePropertyManager; @@ -50,6 +51,7 @@ import io.trino.sql.tree.LogicalExpression; import io.trino.sql.tree.NodeRef; import io.trino.sql.tree.SymbolReference; +import io.trino.transaction.NoOpTransactionManager; import io.trino.type.TypeCoercion; import java.util.HashSet; @@ -95,8 +97,10 @@ public RemoveUnsupportedDynamicFilters(PlannerContext plannerContext) plannerContext, new SqlParser(), new AllowAllAccessControl(), + new NoOpTransactionManager(), user -> ImmutableSet.of(), new TableProceduresRegistry(), + new TableFunctionRegistry(), new SessionPropertyManager(), new TablePropertyManager(), new AnalyzePropertyManager(), diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RewriteTableFunctionToTableScan.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RewriteTableFunctionToTableScan.java new file mode 100644 index 000000000000..0082c9f41953 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RewriteTableFunctionToTableScan.java @@ -0,0 +1,83 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the 
License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.planner.iterative.rule; + +import com.google.common.collect.ImmutableMap; +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.metadata.TableHandle; +import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.TableFunctionApplicationResult; +import io.trino.spi.predicate.TupleDomain; +import io.trino.sql.PlannerContext; +import io.trino.sql.planner.Symbol; +import io.trino.sql.planner.iterative.Rule; +import io.trino.sql.planner.plan.TableFunctionNode; +import io.trino.sql.planner.plan.TableScanNode; + +import java.util.List; +import java.util.Optional; + +import static com.google.common.base.Preconditions.checkState; +import static io.trino.matching.Pattern.empty; +import static io.trino.sql.planner.plan.Patterns.sources; +import static io.trino.sql.planner.plan.Patterns.tableFunction; +import static java.util.Objects.requireNonNull; + +public class RewriteTableFunctionToTableScan + implements Rule +{ + private static final Pattern PATTERN = tableFunction() + .with(empty(sources())); + + private final PlannerContext plannerContext; + + public RewriteTableFunctionToTableScan(PlannerContext plannerContext) + { + this.plannerContext = requireNonNull(plannerContext, "plannerContext is null"); + } + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Result apply(TableFunctionNode tableFunctionNode, Captures captures, Context context) + { + Optional> result = plannerContext.getMetadata().applyTableFunction(context.getSession(), 
tableFunctionNode.getHandle()); + + if (result.isEmpty()) { + return Result.empty(); + } + + List columnHandles = result.get().getColumnHandles(); + checkState(tableFunctionNode.getOutputSymbols().size() == columnHandles.size(), "returned table does not match the node's output"); + ImmutableMap.Builder assignments = ImmutableMap.builder(); + for (int i = 0; i < columnHandles.size(); i++) { + assignments.put(tableFunctionNode.getOutputSymbols().get(i), columnHandles.get(i)); + } + + return Result.ofPlanNode(new TableScanNode( + tableFunctionNode.getId(), + result.get().getTableHandle(), + tableFunctionNode.getOutputSymbols(), + assignments.buildOrThrow(), + TupleDomain.all(), + Optional.empty(), + false, + Optional.empty())); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java index 8ac2c80f2593..eb709af04022 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java @@ -68,6 +68,7 @@ import io.trino.sql.planner.plan.TableDeleteNode; import io.trino.sql.planner.plan.TableExecuteNode; import io.trino.sql.planner.plan.TableFinishNode; +import io.trino.sql.planner.plan.TableFunctionNode; import io.trino.sql.planner.plan.TableScanNode; import io.trino.sql.planner.plan.TableWriterNode; import io.trino.sql.planner.plan.TopNNode; @@ -358,6 +359,12 @@ public PlanWithProperties visitPatternRecognition(PatternRecognitionNode node, P return rebaseAndDeriveProperties(node, child); } + @Override + public PlanWithProperties visitTableFunction(TableFunctionNode node, PreferredProperties preferredProperties) + { + throw new UnsupportedOperationException("execution by operator is not yet implemented for table function " + node.getName()); + } + @Override public PlanWithProperties visitRowNumber(RowNumberNode node, 
PreferredProperties preferredProperties) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java index bc97d123799d..13bc34bef4b1 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java @@ -72,6 +72,7 @@ import io.trino.sql.planner.plan.TableDeleteNode; import io.trino.sql.planner.plan.TableExecuteNode; import io.trino.sql.planner.plan.TableFinishNode; +import io.trino.sql.planner.plan.TableFunctionNode; import io.trino.sql.planner.plan.TableScanNode; import io.trino.sql.planner.plan.TableWriterNode; import io.trino.sql.planner.plan.TopNNode; @@ -314,6 +315,28 @@ public PlanAndMappings visitPatternRecognition(PatternRecognitionNode node, Unal return new PlanAndMappings(rewrittenPatternRecognition, mapping); } + @Override + public PlanAndMappings visitTableFunction(TableFunctionNode node, UnaliasContext context) + { + // TODO rewrite sources, tableArgumentProperties, and inputDescriptorMappings when we add support for input tables + Map mapping = new HashMap<>(context.getCorrelationMapping()); + SymbolMapper mapper = symbolMapper(mapping); + + List newProperOutputs = mapper.map(node.getProperOutputs()); + + return new PlanAndMappings( + new TableFunctionNode( + node.getId(), + node.getName(), + node.getArguments(), + newProperOutputs, + node.getSources(), + node.getTableArgumentProperties(), + node.getInputDescriptorMappings(), + node.getHandle()), + mapping); + } + @Override public PlanAndMappings visitTableScan(TableScanNode node, UnaliasContext context) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/Patterns.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/Patterns.java index 1d455d975ff0..dcb71e51477a 100644 --- 
a/core/trino-main/src/main/java/io/trino/sql/planner/plan/Patterns.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/Patterns.java @@ -205,6 +205,11 @@ public static Pattern patternRecognition() return typeOf(PatternRecognitionNode.class); } + public static Pattern tableFunction() + { + return typeOf(TableFunctionNode.class); + } + public static Pattern rowNumber() { return typeOf(RowNumberNode.class); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanNode.java index c523d36cefaa..5437f8c1993f 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanNode.java @@ -67,6 +67,7 @@ @JsonSubTypes.Type(value = CorrelatedJoinNode.class, name = "correlatedJoin"), @JsonSubTypes.Type(value = StatisticsWriterNode.class, name = "statisticsWriterNode"), @JsonSubTypes.Type(value = PatternRecognitionNode.class, name = "patternRecognition"), + @JsonSubTypes.Type(value = TableFunctionNode.class, name = "tableFunction"), }) public abstract class PlanNode { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanVisitor.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanVisitor.java index 118d5189fd95..aad0e80b70f5 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanVisitor.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanVisitor.java @@ -238,4 +238,9 @@ public R visitPatternRecognition(PatternRecognitionNode node, C context) { return visitPlan(node, context); } + + public R visitTableFunction(TableFunctionNode node, C context) + { + return visitPlan(node, context); + } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionNode.java new file mode 100644 index 000000000000..11f925e73ab0 --- /dev/null 
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionNode.java @@ -0,0 +1,172 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.planner.plan; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import io.trino.metadata.TableFunctionHandle; +import io.trino.spi.ptf.Argument; +import io.trino.spi.ptf.NameAndPosition; +import io.trino.sql.planner.Symbol; +import io.trino.sql.planner.plan.WindowNode.Specification; + +import javax.annotation.concurrent.Immutable; + +import java.util.List; +import java.util.Map; + +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + +@Immutable +public class TableFunctionNode + extends PlanNode +{ + private final String name; + private final Map arguments; + private final List properOutputs; + private final List sources; + private final List tableArgumentProperties; + private final Map inputDescriptorMappings; + private final TableFunctionHandle handle; + + @JsonCreator + public TableFunctionNode( + @JsonProperty("id") PlanNodeId id, + @JsonProperty("name") String name, + @JsonProperty("arguments") Map arguments, + @JsonProperty("properOutputs") List properOutputs, + @JsonProperty("sources") List sources, + @JsonProperty("tableArgumentProperties") List tableArgumentProperties, + @JsonProperty("inputDescriptorMappings") Map inputDescriptorMappings, + 
@JsonProperty("handle") TableFunctionHandle handle) + { + super(id); + this.name = requireNonNull(name, "name is null"); + this.arguments = requireNonNull(arguments, "arguments is null"); + this.properOutputs = requireNonNull(properOutputs, "properOutputs is null"); + this.sources = requireNonNull(sources, "sources is null"); + this.tableArgumentProperties = requireNonNull(tableArgumentProperties, "tableArgumentProperties is null"); + this.inputDescriptorMappings = requireNonNull(inputDescriptorMappings, "inputDescriptorMappings is null"); + this.handle = requireNonNull(handle, "handle is null"); + } + + @JsonProperty + public String getName() + { + return name; + } + + @JsonProperty + public Map getArguments() + { + return arguments; + } + + @JsonProperty + public List getProperOutputs() + { + return properOutputs; + } + + @JsonProperty + public List getTableArgumentProperties() + { + return tableArgumentProperties; + } + + @JsonProperty + public Map getInputDescriptorMappings() + { + return inputDescriptorMappings; + } + + @JsonProperty + public TableFunctionHandle getHandle() + { + return handle; + } + + @JsonProperty + @Override + public List getSources() + { + return sources; + } + + @Override + public List getOutputSymbols() + { + // TODO add outputs from input relations + return properOutputs; + } + + @Override + public R accept(PlanVisitor visitor, C context) + { + return visitor.visitTableFunction(this, context); + } + + @Override + public PlanNode replaceChildren(List newSources) + { + checkArgument(sources.size() == newSources.size(), "wrong number of new children"); + return new TableFunctionNode(getId(), name, arguments, properOutputs, newSources, tableArgumentProperties, inputDescriptorMappings, handle); + } + + public static class TableArgumentProperties + { + private final boolean rowSemantics; + private final boolean pruneWhenEmpty; + private final boolean passThroughColumns; + private final Specification specification; + + @JsonCreator + public 
TableArgumentProperties( + @JsonProperty("rowSemantics") boolean rowSemantics, + @JsonProperty("pruneWhenEmpty") boolean pruneWhenEmpty, + @JsonProperty("passThroughColumns") boolean passThroughColumns, + @JsonProperty("specification") Specification specification) + { + this.rowSemantics = rowSemantics; + this.pruneWhenEmpty = pruneWhenEmpty; + this.passThroughColumns = passThroughColumns; + this.specification = requireNonNull(specification, "specification is null"); + } + + @JsonProperty + public boolean isRowSemantics() + { + return rowSemantics; + } + + @JsonProperty + public boolean isPruneWhenEmpty() + { + return pruneWhenEmpty; + } + + @JsonProperty + public boolean isPassThroughColumns() + { + return passThroughColumns; + } + + @JsonProperty + public Specification getSpecification() + { + return specification; + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java index df03b11d9ad6..ca8640294b9e 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java @@ -62,6 +62,7 @@ import io.trino.sql.planner.plan.TableDeleteNode; import io.trino.sql.planner.plan.TableExecuteNode; import io.trino.sql.planner.plan.TableFinishNode; +import io.trino.sql.planner.plan.TableFunctionNode; import io.trino.sql.planner.plan.TableScanNode; import io.trino.sql.planner.plan.TableWriterNode; import io.trino.sql.planner.plan.TopNNode; @@ -213,6 +214,13 @@ public Void visitPatternRecognition(PatternRecognitionNode node, Set bou return null; } + @Override + public Void visitTableFunction(TableFunctionNode node, Set context) + { + // TODO + return null; + } + @Override public Void visitWindow(WindowNode node, Set boundSymbols) { diff --git 
a/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java b/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java index c2a6756504a9..d548b3b8e583 100644 --- a/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java +++ b/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java @@ -96,6 +96,7 @@ import io.trino.metadata.Split; import io.trino.metadata.SystemFunctionBundle; import io.trino.metadata.SystemSecurityMetadata; +import io.trino.metadata.TableFunctionRegistry; import io.trino.metadata.TableHandle; import io.trino.metadata.TableProceduresPropertyManager; import io.trino.metadata.TableProceduresRegistry; @@ -382,8 +383,10 @@ private LocalQueryRunner( plannerContext, sqlParser, accessControl, + transactionManager, groupProvider, tableProceduresRegistry, + new TableFunctionRegistry(), sessionPropertyManager, tablePropertyManager, analyzePropertyManager, @@ -423,6 +426,7 @@ private LocalQueryRunner( typeManager, new ProcedureRegistry(), tableProceduresRegistry, + new TableFunctionRegistry(), sessionPropertyManager, schemaPropertyManager, columnPropertyManager, diff --git a/core/trino-main/src/test/java/io/trino/metadata/AbstractMockMetadata.java b/core/trino-main/src/test/java/io/trino/metadata/AbstractMockMetadata.java index b97eea2b1a66..5742af206575 100644 --- a/core/trino-main/src/test/java/io/trino/metadata/AbstractMockMetadata.java +++ b/core/trino-main/src/test/java/io/trino/metadata/AbstractMockMetadata.java @@ -47,6 +47,7 @@ import io.trino.spi.connector.SortItem; import io.trino.spi.connector.SystemTable; import io.trino.spi.connector.TableColumnsMetadata; +import io.trino.spi.connector.TableFunctionApplicationResult; import io.trino.spi.connector.TableScanRedirectApplicationResult; import io.trino.spi.connector.TopNApplicationResult; import io.trino.spi.expression.ConnectorExpression; @@ -582,6 +583,12 @@ public Optional> applyJoin( return Optional.empty(); } + @Override + public Optional> 
applyTableFunction(Session session, TableFunctionHandle handle) + { + return Optional.empty(); + } + // // Roles and Grants // diff --git a/core/trino-main/src/test/java/io/trino/metadata/CountingAccessMetadata.java b/core/trino-main/src/test/java/io/trino/metadata/CountingAccessMetadata.java index 5dec1b241319..f3ea117e7d87 100644 --- a/core/trino-main/src/test/java/io/trino/metadata/CountingAccessMetadata.java +++ b/core/trino-main/src/test/java/io/trino/metadata/CountingAccessMetadata.java @@ -44,6 +44,7 @@ import io.trino.spi.connector.SortItem; import io.trino.spi.connector.SystemTable; import io.trino.spi.connector.TableColumnsMetadata; +import io.trino.spi.connector.TableFunctionApplicationResult; import io.trino.spi.connector.TableScanRedirectApplicationResult; import io.trino.spi.connector.TopNApplicationResult; import io.trino.spi.expression.ConnectorExpression; @@ -588,6 +589,12 @@ public Optional> applyTopN(Session session, T return delegate.applyTopN(session, handle, topNCount, sortItems, assignments); } + @Override + public Optional> applyTableFunction(Session session, TableFunctionHandle handle) + { + return delegate.applyTableFunction(session, handle); + } + @Override public void validateScan(Session session, TableHandle table) { diff --git a/core/trino-parser/src/main/antlr4/io/trino/sql/parser/SqlBase.g4 b/core/trino-parser/src/main/antlr4/io/trino/sql/parser/SqlBase.g4 index eaea9c8b8d17..b8b526dce5fa 100644 --- a/core/trino-parser/src/main/antlr4/io/trino/sql/parser/SqlBase.g4 +++ b/core/trino-parser/src/main/antlr4/io/trino/sql/parser/SqlBase.g4 @@ -405,9 +405,44 @@ relationPrimary | '(' query ')' #subqueryRelation | UNNEST '(' expression (',' expression)* ')' (WITH ORDINALITY)? #unnest | LATERAL '(' query ')' #lateral + | TABLE '(' tableFunctionCall ')' #tableFunctionInvocation | '(' relation ')' #parenthesizedRelation ; +tableFunctionCall + : qualifiedName '(' (tableFunctionArgument (',' tableFunctionArgument)*)? 
+ (COPARTITION copartitionTables (',' copartitionTables)*)? ')' + ; + +tableFunctionArgument + : (identifier '=>')? (tableArgument | descriptorArgument | expression) // descriptor before expression to avoid parsing descriptor as a function call + ; + +tableArgument + : tableArgumentRelation + (PARTITION BY ('(' (expression (',' expression)*)? ')' | expression))? + (PRUNE WHEN EMPTY | KEEP WHEN EMPTY)? + (ORDER BY ('(' sortItem (',' sortItem)* ')' | sortItem))? + ; + +tableArgumentRelation + : TABLE '(' qualifiedName ')' (AS? identifier columnAliases?)? #tableArgumentTable + | TABLE '(' query ')' (AS? identifier columnAliases?)? #tableArgumentQuery + ; + +descriptorArgument + : DESCRIPTOR '(' descriptorField (',' descriptorField)* ')' + | CAST '(' NULL AS DESCRIPTOR ')' + ; + +descriptorField + : identifier type? + ; + +copartitionTables + : '(' qualifiedName ',' qualifiedName (',' qualifiedName)* ')' + ; + expression : booleanExpression ; @@ -717,19 +752,20 @@ nonReserved // IMPORTANT: this rule must only contain tokens. Nested rules are not supported. 
See SqlParser.exitNonReserved : ADD | ADMIN | AFTER | ALL | ANALYZE | ANY | ARRAY | ASC | AT | AUTHORIZATION | BERNOULLI | BOTH - | CALL | CASCADE | CATALOGS | COLUMN | COLUMNS | COMMENT | COMMIT | COMMITTED | COUNT | CURRENT - | DATA | DATE | DAY | DEFAULT | DEFINE | DEFINER | DESC | DISTRIBUTED | DOUBLE + | CALL | CASCADE | CATALOGS | COLUMN | COLUMNS | COMMENT | COMMIT | COMMITTED | COPARTITION | COUNT | CURRENT + | DATA | DATE | DAY | DEFAULT | DEFINE | DEFINER | DESC | DESCRIPTOR | DISTRIBUTED | DOUBLE | EMPTY | ERROR | EXCLUDING | EXPLAIN | FETCH | FILTER | FINAL | FIRST | FOLLOWING | FORMAT | FUNCTIONS | GRANT | DENY | GRANTED | GRANTS | GRAPHVIZ | GROUPS | HOUR | IF | IGNORE | INCLUDING | INITIAL | INPUT | INTERVAL | INVOKER | IO | ISOLATION | JSON + | KEEP | LAST | LATERAL | LEADING | LEVEL | LIMIT | LOCAL | LOGICAL | MAP | MATCH | MATCHED | MATCHES | MATCH_RECOGNIZE | MATERIALIZED | MEASURES | MERGE | MINUTE | MONTH | NEXT | NFC | NFD | NFKC | NFKD | NO | NONE | NULLIF | NULLS | OF | OFFSET | OMIT | ONE | ONLY | OPTION | ORDINALITY | OUTPUT | OVER | OVERFLOW - | PARTITION | PARTITIONS | PAST | PATH | PATTERN | PER | PERMUTE | POSITION | PRECEDING | PRECISION | PRIVILEGES | PROPERTIES + | PARTITION | PARTITIONS | PAST | PATH | PATTERN | PER | PERMUTE | POSITION | PRECEDING | PRECISION | PRIVILEGES | PROPERTIES | PRUNE | RANGE | READ | REFRESH | RENAME | REPEATABLE | REPLACE | RESET | RESPECT | RESTRICT | REVOKE | ROLE | ROLES | ROLLBACK | ROW | ROWS | RUNNING | SCHEMA | SCHEMAS | SECOND | SECURITY | SEEK | SERIALIZABLE | SESSION | SET | SETS | SHOW | SOME | START | STATS | SUBSET | SUBSTRING | SYSTEM @@ -770,6 +806,7 @@ COMMIT: 'COMMIT'; COMMITTED: 'COMMITTED'; CONSTRAINT: 'CONSTRAINT'; COUNT: 'COUNT'; +COPARTITION: 'COPARTITION'; CREATE: 'CREATE'; CROSS: 'CROSS'; CUBE: 'CUBE'; @@ -792,6 +829,7 @@ DELETE: 'DELETE'; DENY: 'DENY'; DESC: 'DESC'; DESCRIBE: 'DESCRIBE'; +DESCRIPTOR: 'DESCRIPTOR'; DEFINE: 'DEFINE'; DISTINCT: 'DISTINCT'; DISTRIBUTED: 
'DISTRIBUTED'; @@ -845,6 +883,7 @@ IS: 'IS'; ISOLATION: 'ISOLATION'; JOIN: 'JOIN'; JSON: 'JSON'; +KEEP: 'KEEP'; LAST: 'LAST'; LATERAL: 'LATERAL'; LEADING: 'LEADING'; @@ -907,6 +946,7 @@ PRECISION: 'PRECISION'; PREPARE: 'PREPARE'; PRIVILEGES: 'PRIVILEGES'; PROPERTIES: 'PROPERTIES'; +PRUNE: 'PRUNE'; RANGE: 'RANGE'; READ: 'READ'; RECURSIVE: 'RECURSIVE'; diff --git a/core/trino-parser/src/main/java/io/trino/sql/ExpressionFormatter.java b/core/trino-parser/src/main/java/io/trino/sql/ExpressionFormatter.java index 0ed4977a8061..a9624150b204 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/ExpressionFormatter.java +++ b/core/trino-parser/src/main/java/io/trino/sql/ExpressionFormatter.java @@ -904,7 +904,7 @@ public static String formatOrderBy(OrderBy orderBy) return "ORDER BY " + formatSortItems(orderBy.getSortItems()); } - private static String formatSortItems(List sortItems) + public static String formatSortItems(List sortItems) { return Joiner.on(", ").join(sortItems.stream() .map(sortItemFormatterFunction()) diff --git a/core/trino-parser/src/main/java/io/trino/sql/SqlFormatter.java b/core/trino-parser/src/main/java/io/trino/sql/SqlFormatter.java index c096ba474277..188983c51aed 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/SqlFormatter.java +++ b/core/trino-parser/src/main/java/io/trino/sql/SqlFormatter.java @@ -36,6 +36,7 @@ import io.trino.sql.tree.Deny; import io.trino.sql.tree.DescribeInput; import io.trino.sql.tree.DescribeOutput; +import io.trino.sql.tree.DescriptorArgument; import io.trino.sql.tree.DropColumn; import io.trino.sql.tree.DropMaterializedView; import io.trino.sql.tree.DropRole; @@ -120,7 +121,10 @@ import io.trino.sql.tree.SingleColumn; import io.trino.sql.tree.StartTransaction; import io.trino.sql.tree.Table; +import io.trino.sql.tree.TableArgument; import io.trino.sql.tree.TableExecute; +import io.trino.sql.tree.TableFunctionArgument; +import io.trino.sql.tree.TableFunctionInvocation; import io.trino.sql.tree.TableSubquery; 
import io.trino.sql.tree.TransactionAccessMode; import io.trino.sql.tree.TransactionMode; @@ -227,6 +231,114 @@ protected Void visitLateral(Lateral node, Integer indent) return null; } + @Override + protected Void visitTableFunctionInvocation(TableFunctionInvocation node, Integer indent) + { + append(indent, "TABLE("); + appendTableFunctionInvocation(node, indent + 1); + builder.append(")"); + return null; + } + + private void appendTableFunctionInvocation(TableFunctionInvocation node, Integer indent) + { + builder.append(formatName(node.getName())) + .append("(\n"); + appendTableFunctionArguments(node.getArguments(), indent + 1); + if (!node.getCopartitioning().isEmpty()) { + builder.append("\n"); + append(indent + 1, "COPARTITION "); + builder.append(node.getCopartitioning().stream() + .map(tableList -> tableList.stream() + .map(SqlFormatter::formatName) + .collect(Collectors.joining(", ", "(", ")"))) + .collect(Collectors.joining(", "))); + } + builder.append(")"); + } + + private void appendTableFunctionArguments(List arguments, int indent) + { + for (int i = 0; i < arguments.size(); i++) { + TableFunctionArgument argument = arguments.get(i); + if (argument.getName().isPresent()) { + append(indent, formatExpression(argument.getName().get())); + builder.append(" => "); + } + else { + append(indent, ""); + } + Node value = argument.getValue(); + if (value instanceof Expression) { + builder.append(formatExpression((Expression) value)); + } + else { + process(value, indent + 1); + } + if (i < arguments.size() - 1) { + builder.append(",\n"); + } + } + } + + @Override + protected Void visitTableArgument(TableArgument node, Integer indent) + { + Relation relation = node.getTable(); + Relation unaliased = relation instanceof AliasedRelation ? 
((AliasedRelation) relation).getRelation() : relation; + builder.append("TABLE("); + process(unaliased, indent); + builder.append(")"); + if (relation instanceof AliasedRelation) { + AliasedRelation aliasedRelation = (AliasedRelation) relation; + builder.append(" AS ") + .append(formatExpression(aliasedRelation.getAlias())); + appendAliasColumns(builder, aliasedRelation.getColumnNames()); + } + if (node.getPartitionBy().isPresent()) { + builder.append("\n"); + append(indent, "PARTITION BY ") + .append(node.getPartitionBy().get().stream() + .map(ExpressionFormatter::formatExpression) + .collect(joining(", "))); + } + if (node.isPruneWhenEmpty()) { + builder.append("\n"); + append(indent, "PRUNE WHEN EMPTY"); + } + else { + builder.append("\n"); + append(indent, "KEEP WHEN EMPTY"); + } + node.getOrderBy().ifPresent(orderBy -> { + builder.append("\n"); + append(indent, formatOrderBy(orderBy)); + }); + + return null; + } + + @Override + protected Void visitDescriptorArgument(DescriptorArgument node, Integer indent) + { + if (node.getDescriptor().isPresent()) { + builder.append(node.getDescriptor().get().getFields().stream() + .map(field -> { + String formattedField = formatExpression(field.getName()); + if (field.getType().isPresent()) { + formattedField = formattedField + " " + formatExpression(field.getType().get()); + } + return formattedField; + }) + .collect(Collectors.joining(", ", "DESCRIPTOR(", ")"))); + } + else { + builder.append("CAST (NULL AS DESCRIPTOR)"); + } + + return null; + } + @Override protected Void visitPrepare(Prepare node, Integer indent) { diff --git a/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java b/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java index 7d111bd3a3e4..01b246c2354c 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java +++ b/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java @@ -62,6 +62,8 @@ import io.trino.sql.tree.DereferenceExpression; import 
io.trino.sql.tree.DescribeInput; import io.trino.sql.tree.DescribeOutput; +import io.trino.sql.tree.Descriptor; +import io.trino.sql.tree.DescriptorField; import io.trino.sql.tree.DoubleLiteral; import io.trino.sql.tree.DropColumn; import io.trino.sql.tree.DropMaterializedView; @@ -208,8 +210,11 @@ import io.trino.sql.tree.SubscriptExpression; import io.trino.sql.tree.SubsetDefinition; import io.trino.sql.tree.Table; +import io.trino.sql.tree.TableArgument; import io.trino.sql.tree.TableElement; import io.trino.sql.tree.TableExecute; +import io.trino.sql.tree.TableFunctionArgument; +import io.trino.sql.tree.TableFunctionInvocation; import io.trino.sql.tree.TableSubquery; import io.trino.sql.tree.TimeLiteral; import io.trino.sql.tree.TimestampLiteral; @@ -256,6 +261,8 @@ import static io.trino.sql.parser.SqlBaseParser.TIMESTAMP; import static io.trino.sql.tree.AnchorPattern.Type.PARTITION_END; import static io.trino.sql.tree.AnchorPattern.Type.PARTITION_START; +import static io.trino.sql.tree.DescriptorArgument.descriptorArgument; +import static io.trino.sql.tree.DescriptorArgument.nullDescriptorArgument; import static io.trino.sql.tree.PatternRecognitionRelation.RowsPerMatch.ALL_OMIT_EMPTY; import static io.trino.sql.tree.PatternRecognitionRelation.RowsPerMatch.ALL_SHOW_EMPTY; import static io.trino.sql.tree.PatternRecognitionRelation.RowsPerMatch.ALL_WITH_UNMATCHED; @@ -1773,6 +1780,117 @@ public Node visitLateral(SqlBaseParser.LateralContext context) return new Lateral(getLocation(context), (Query) visit(context.query())); } + @Override + public Node visitTableFunctionInvocation(SqlBaseParser.TableFunctionInvocationContext context) + { + return visit(context.tableFunctionCall()); + } + + @Override + public Node visitTableFunctionCall(SqlBaseParser.TableFunctionCallContext context) + { + QualifiedName name = getQualifiedName(context.qualifiedName()); + List arguments = visit(context.tableFunctionArgument(), TableFunctionArgument.class); + List> copartitioning = 
ImmutableList.of(); + if (context.COPARTITION() != null) { + copartitioning = context.copartitionTables().stream() + .map(tablesList -> tablesList.qualifiedName().stream() + .map(this::getQualifiedName) + .collect(toImmutableList())) + .collect(toImmutableList()); + } + + return new TableFunctionInvocation(getLocation(context), name, arguments, copartitioning); + } + + @Override + public Node visitTableFunctionArgument(SqlBaseParser.TableFunctionArgumentContext context) + { + Optional name = visitIfPresent(context.identifier(), Identifier.class); + Node value; + if (context.tableArgument() != null) { + value = visit(context.tableArgument()); + } + else if (context.descriptorArgument() != null) { + value = visit(context.descriptorArgument()); + } + else { + value = visit(context.expression()); + } + + return new TableFunctionArgument(getLocation(context), name, value); + } + + @Override + public Node visitTableArgument(SqlBaseParser.TableArgumentContext context) + { + Relation table = (Relation) visit(context.tableArgumentRelation()); + + Optional> partitionBy = Optional.empty(); + if (context.PARTITION() != null) { + partitionBy = Optional.of(visit(context.expression(), Expression.class)); + } + + Optional orderBy = Optional.empty(); + if (context.ORDER() != null) { + orderBy = Optional.of(new OrderBy(visit(context.sortItem(), SortItem.class))); + } + + boolean pruneWhenEmpty = context.PRUNE() != null; + + return new TableArgument(getLocation(context), table, partitionBy, orderBy, pruneWhenEmpty); + } + + @Override + public Node visitTableArgumentTable(SqlBaseParser.TableArgumentTableContext context) + { + Relation relation = new Table(getLocation(context.TABLE()), getQualifiedName(context.qualifiedName())); + + if (context.identifier() != null) { + Identifier alias = (Identifier) visit(context.identifier()); + List columnNames = ImmutableList.of(); + if (context.columnAliases() != null) { + columnNames = visit(context.columnAliases().identifier(), 
Identifier.class); + } + relation = new AliasedRelation(getLocation(context.TABLE()), relation, alias, columnNames); + } + + return relation; + } + + @Override + public Node visitTableArgumentQuery(SqlBaseParser.TableArgumentQueryContext context) + { + Relation relation = new TableSubquery(getLocation(context.TABLE()), (Query) visit(context.query())); + + if (context.identifier() != null) { + Identifier alias = (Identifier) visit(context.identifier()); + List columnNames = ImmutableList.of(); + if (context.columnAliases() != null) { + columnNames = visit(context.columnAliases().identifier(), Identifier.class); + } + relation = new AliasedRelation(getLocation(context.TABLE()), relation, alias, columnNames); + } + + return relation; + } + + @Override + public Node visitDescriptorArgument(SqlBaseParser.DescriptorArgumentContext context) + { + if (context.NULL() != null) { + return nullDescriptorArgument(getLocation(context)); + } + List fields = visit(context.descriptorField(), DescriptorField.class); + return descriptorArgument(getLocation(context), new Descriptor(getLocation(context.DESCRIPTOR()), fields)); + } + + @Override + public Node visitDescriptorField(SqlBaseParser.DescriptorFieldContext context) + { + return new DescriptorField(getLocation(context), (Identifier) visit(context.identifier()), visitIfPresent(context.type(), DataType.class)); + } + @Override public Node visitParenthesizedRelation(SqlBaseParser.ParenthesizedRelationContext context) { diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java b/core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java index 9ac75d904145..38da35912c40 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java @@ -1086,4 +1086,34 @@ protected R visitQueryPeriod(QueryPeriod node, C context) { return visitNode(node, context); } + + protected R visitTableFunctionInvocation(TableFunctionInvocation 
node, C context) + { + return visitRelation(node, context); + } + + protected R visitTableFunctionArgument(TableFunctionArgument node, C context) + { + return visitNode(node, context); + } + + protected R visitTableArgument(TableArgument node, C context) + { + return visitNode(node, context); + } + + protected R visitDescriptorArgument(DescriptorArgument node, C context) + { + return visitNode(node, context); + } + + protected R visitDescriptor(Descriptor node, C context) + { + return visitNode(node, context); + } + + protected R visitDescriptorField(DescriptorField node, C context) + { + return visitNode(node, context); + } } diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/Descriptor.java b/core/trino-parser/src/main/java/io/trino/sql/tree/Descriptor.java new file mode 100644 index 000000000000..80d3d3ec1d54 --- /dev/null +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/Descriptor.java @@ -0,0 +1,85 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.sql.tree; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.stream.Collectors; + +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + +public class Descriptor + extends Node +{ + private final List fields; + + public Descriptor(NodeLocation location, List fields) + { + super(Optional.of(location)); + requireNonNull(fields, "fields is null"); + checkArgument(!fields.isEmpty(), "fields list is empty"); + this.fields = fields; + } + + public List getFields() + { + return fields; + } + + @Override + public R accept(AstVisitor visitor, C context) + { + return visitor.visitDescriptor(this, context); + } + + @Override + public List getChildren() + { + return fields; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + return Objects.equals(fields, ((Descriptor) o).fields); + } + + @Override + public int hashCode() + { + return Objects.hash(fields); + } + + @Override + public String toString() + { + return fields.stream() + .map(DescriptorField::toString) + .collect(Collectors.joining(", ", "DESCRIPTOR(", ")")); + } + + @Override + public boolean shallowEquals(Node o) + { + return sameClass(this, o); + } +} diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/DescriptorArgument.java b/core/trino-parser/src/main/java/io/trino/sql/tree/DescriptorArgument.java new file mode 100644 index 000000000000..c5f1f21a1b9a --- /dev/null +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/DescriptorArgument.java @@ -0,0 +1,92 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
package io.trino.sql.tree;

import com.google.common.collect.ImmutableList;

import java.util.List;
import java.util.Objects;
import java.util.Optional;

import static java.util.Objects.requireNonNull;

/**
 * A descriptor-typed argument of a table function invocation: either an
 * explicit {@code DESCRIPTOR(...)} value, or the "null descriptor" spelled
 * {@code CAST(NULL AS DESCRIPTOR)} in SQL. Instances are created only through
 * the static factory methods.
 */
public class DescriptorArgument
        extends Node
{
    // Empty when the argument is the explicit null descriptor.
    private final Optional<Descriptor> descriptor;

    public static DescriptorArgument descriptorArgument(NodeLocation location, Descriptor descriptor)
    {
        requireNonNull(descriptor, "descriptor is null");
        return new DescriptorArgument(location, Optional.of(descriptor));
    }

    public static DescriptorArgument nullDescriptorArgument(NodeLocation location)
    {
        return new DescriptorArgument(location, Optional.empty());
    }

    private DescriptorArgument(NodeLocation location, Optional<Descriptor> descriptor)
    {
        super(Optional.of(location));
        this.descriptor = requireNonNull(descriptor, "descriptor is null");
    }

    public Optional<Descriptor> getDescriptor()
    {
        return descriptor;
    }

    @Override
    public <R, C> R accept(AstVisitor<R, C> visitor, C context)
    {
        return visitor.visitDescriptorArgument(this, context);
    }

    @Override
    public List<? extends Node> getChildren()
    {
        return descriptor.map(ImmutableList::of).orElse(ImmutableList.of());
    }

    @Override
    public boolean equals(Object o)
    {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        return Objects.equals(descriptor, ((DescriptorArgument) o).descriptor);
    }

    @Override
    public int hashCode()
    {
        return Objects.hash(descriptor);
    }

    @Override
    public String toString()
    {
        // The null descriptor is rendered the way it must be written in SQL.
        return descriptor.map(Descriptor::toString).orElse("CAST(NULL AS DESCRIPTOR)");
    }

    @Override
    public boolean shallowEquals(Node o)
    {
        return sameClass(this, o);
    }
}
package io.trino.sql.tree;

import com.google.common.collect.ImmutableList;

import java.util.List;
import java.util.Objects;
import java.util.Optional;

import static java.util.Objects.requireNonNull;

/**
 * A single field of a {@code DESCRIPTOR(...)} expression: a column name and an
 * optional data type, e.g. {@code x integer} or just {@code x}.
 */
public class DescriptorField
        extends Node
{
    private final Identifier name;
    private final Optional<DataType> type;

    public DescriptorField(NodeLocation location, Identifier name, Optional<DataType> type)
    {
        super(Optional.of(location));
        this.name = requireNonNull(name, "name is null");
        this.type = requireNonNull(type, "type is null");
    }

    public Identifier getName()
    {
        return name;
    }

    public Optional<DataType> getType()
    {
        return type;
    }

    @Override
    public <R, C> R accept(AstVisitor<R, C> visitor, C context)
    {
        return visitor.visitDescriptorField(this, context);
    }

    @Override
    public List<? extends Node> getChildren()
    {
        // The name Identifier is treated as an attribute (see shallowEquals),
        // only the optional type is a child node.
        return type.map(ImmutableList::of)
                .orElse(ImmutableList.of());
    }

    @Override
    public boolean equals(Object o)
    {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        DescriptorField field = (DescriptorField) o;
        return Objects.equals(name, field.name) &&
                Objects.equals(type, field.type);
    }

    @Override
    public int hashCode()
    {
        return Objects.hash(name, type);
    }

    @Override
    public String toString()
    {
        return type.map(dataType -> name + " " + dataType).orElse(name.toString());
    }

    @Override
    public boolean shallowEquals(Node o)
    {
        if (!sameClass(this, o)) {
            return false;
        }

        return Objects.equals(name, ((DescriptorField) o).name);
    }
}
package io.trino.sql.tree;

import com.google.common.collect.ImmutableList;

import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;

import static io.trino.sql.ExpressionFormatter.formatSortItems;
import static java.util.Objects.requireNonNull;

/**
 * A table-typed argument of a table function invocation:
 * {@code TABLE(...)} with optional {@code PARTITION BY}, {@code ORDER BY} and
 * {@code PRUNE WHEN EMPTY} / {@code KEEP WHEN EMPTY} properties.
 */
public class TableArgument
        extends Node
{
    private final Relation table;
    // Distinguishes "PARTITION BY ()" (present, empty list) from no PARTITION BY clause (empty Optional):
    // it is allowed to partition by an empty list.
    private final Optional<List<Expression>> partitionBy;
    private final Optional<OrderBy> orderBy;
    private final boolean pruneWhenEmpty;

    public TableArgument(
            NodeLocation location,
            Relation table,
            Optional<List<Expression>> partitionBy,
            Optional<OrderBy> orderBy,
            boolean pruneWhenEmpty)
    {
        super(Optional.of(location));
        this.table = requireNonNull(table, "table is null");
        this.partitionBy = requireNonNull(partitionBy, "partitionBy is null");
        this.orderBy = requireNonNull(orderBy, "orderBy is null");
        this.pruneWhenEmpty = pruneWhenEmpty;
    }

    public Relation getTable()
    {
        return table;
    }

    public Optional<List<Expression>> getPartitionBy()
    {
        return partitionBy;
    }

    public Optional<OrderBy> getOrderBy()
    {
        return orderBy;
    }

    public boolean isPruneWhenEmpty()
    {
        return pruneWhenEmpty;
    }

    @Override
    public <R, C> R accept(AstVisitor<R, C> visitor, C context)
    {
        return visitor.visitTableArgument(this, context);
    }

    @Override
    public List<? extends Node> getChildren()
    {
        ImmutableList.Builder<Node> builder = ImmutableList.builder();
        builder.add(table);
        partitionBy.ifPresent(builder::addAll);
        orderBy.ifPresent(builder::add);

        return builder.build();
    }

    @Override
    public boolean equals(Object o)
    {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }

        TableArgument other = (TableArgument) o;
        return Objects.equals(table, other.table) &&
                Objects.equals(partitionBy, other.partitionBy) &&
                Objects.equals(orderBy, other.orderBy) &&
                pruneWhenEmpty == other.pruneWhenEmpty;
    }

    @Override
    public int hashCode()
    {
        return Objects.hash(table, partitionBy, orderBy, pruneWhenEmpty);
    }

    @Override
    public String toString()
    {
        StringBuilder builder = new StringBuilder();
        builder.append(table);
        partitionBy.ifPresent(partitioning -> builder.append(partitioning.stream()
                .map(Expression::toString)
                .collect(Collectors.joining(", ", " PARTITION BY (", ")"))));
        orderBy.ifPresent(ordering -> builder.append(" ORDER BY (")
                .append(formatSortItems(ordering.getSortItems()))
                .append(")"));

        return builder.toString();
    }

    @Override
    public boolean shallowEquals(Node o)
    {
        if (!sameClass(this, o)) {
            return false;
        }

        return pruneWhenEmpty == ((TableArgument) o).pruneWhenEmpty;
    }
}
package io.trino.sql.tree;

import com.google.common.collect.ImmutableList;

import java.util.List;
import java.util.Objects;
import java.util.Optional;

import static com.google.common.base.Preconditions.checkArgument;
import static java.util.Objects.requireNonNull;

/**
 * A single (optionally named) argument of a table function invocation,
 * passed either positionally or in the {@code name => value} convention.
 * The value is one of: a {@link TableArgument}, a {@link DescriptorArgument},
 * or a scalar {@link Expression}.
 */
public class TableFunctionArgument
        extends Node
{
    private final Optional<Identifier> name;
    private final Node value;

    public TableFunctionArgument(NodeLocation location, Optional<Identifier> name, Node value)
    {
        super(Optional.of(location));
        this.name = requireNonNull(name, "name is null");
        requireNonNull(value, "value is null");
        checkArgument(
                value instanceof TableArgument || value instanceof DescriptorArgument || value instanceof Expression,
                "table function argument value must be a TableArgument, DescriptorArgument or Expression, got: %s",
                value.getClass().getSimpleName());
        this.value = value;
    }

    public Optional<Identifier> getName()
    {
        return name;
    }

    public Node getValue()
    {
        return value;
    }

    @Override
    public <R, C> R accept(AstVisitor<R, C> visitor, C context)
    {
        return visitor.visitTableFunctionArgument(this, context);
    }

    @Override
    public List<? extends Node> getChildren()
    {
        return ImmutableList.of(value);
    }

    @Override
    public boolean equals(Object o)
    {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }

        TableFunctionArgument other = (TableFunctionArgument) o;
        return Objects.equals(name, other.name) &&
                Objects.equals(value, other.value);
    }

    @Override
    public int hashCode()
    {
        return Objects.hash(name, value);
    }

    @Override
    public String toString()
    {
        return name.map(identifier -> identifier + " => ").orElse("") + value;
    }

    @Override
    public boolean shallowEquals(Node o)
    {
        if (!sameClass(this, o)) {
            return false;
        }

        return Objects.equals(name, ((TableFunctionArgument) o).name);
    }
}
package io.trino.sql.tree;

import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;

import static java.util.Objects.requireNonNull;

/**
 * AST node for a polymorphic table function invocation in the FROM clause:
 * {@code TABLE(function_name(arg, ..., COPARTITION (t1, t2), ...))}.
 */
public class TableFunctionInvocation
        extends Relation
{
    private final QualifiedName name;
    private final List<TableFunctionArgument> arguments;
    // Each inner list is one COPARTITION clause: the names of the copartitioned table arguments.
    private final List<List<QualifiedName>> copartitioning;

    public TableFunctionInvocation(
            NodeLocation location,
            QualifiedName name,
            List<TableFunctionArgument> arguments,
            List<List<QualifiedName>> copartitioning)
    {
        super(Optional.of(location));
        this.name = requireNonNull(name, "name is null");
        // Defensive copies: AST nodes are immutable.
        this.arguments = List.copyOf(requireNonNull(arguments, "arguments is null"));
        this.copartitioning = List.copyOf(requireNonNull(copartitioning, "copartitioning is null"));
    }

    public QualifiedName getName()
    {
        return name;
    }

    public List<TableFunctionArgument> getArguments()
    {
        return arguments;
    }

    public List<List<QualifiedName>> getCopartitioning()
    {
        return copartitioning;
    }

    @Override
    public <R, C> R accept(AstVisitor<R, C> visitor, C context)
    {
        return visitor.visitTableFunctionInvocation(this, context);
    }

    @Override
    public List<? extends Node> getChildren()
    {
        return arguments;
    }

    @Override
    public boolean equals(Object o)
    {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }

        TableFunctionInvocation that = (TableFunctionInvocation) o;
        return Objects.equals(name, that.name) &&
                Objects.equals(arguments, that.arguments) &&
                Objects.equals(copartitioning, that.copartitioning);
    }

    @Override
    public int hashCode()
    {
        return Objects.hash(name, arguments, copartitioning);
    }

    @Override
    public String toString()
    {
        StringBuilder builder = new StringBuilder();
        builder.append(name)
                .append("(");
        builder.append(arguments.stream()
                .map(TableFunctionArgument::toString)
                .collect(Collectors.joining(", ")));
        if (!copartitioning.isEmpty()) {
            builder.append(" COPARTITION");
            builder.append(copartitioning.stream()
                    .map(list -> list.stream()
                            .map(QualifiedName::toString)
                            .collect(Collectors.joining(", ", "(", ")")))
                    .collect(Collectors.joining(", ")));
        }
        builder.append(")");

        return builder.toString();
    }

    @Override
    public boolean shallowEquals(Node o)
    {
        if (!sameClass(this, o)) {
            return false;
        }

        TableFunctionInvocation other = (TableFunctionInvocation) o;
        return Objects.equals(name, other.name) &&
                Objects.equals(copartitioning, other.copartitioning);
    }
}
io.trino.sql.tree.FrameBound; import io.trino.sql.tree.FunctionCall; import io.trino.sql.tree.FunctionCall.NullTreatment; +import io.trino.sql.tree.GenericDataType; import io.trino.sql.tree.GenericLiteral; import io.trino.sql.tree.Grant; import io.trino.sql.tree.GrantOnType; @@ -132,6 +135,7 @@ import io.trino.sql.tree.QuerySpecification; import io.trino.sql.tree.RangeQuantifier; import io.trino.sql.tree.RefreshMaterializedView; +import io.trino.sql.tree.Relation; import io.trino.sql.tree.RenameColumn; import io.trino.sql.tree.RenameMaterializedView; import io.trino.sql.tree.RenameSchema; @@ -174,7 +178,10 @@ import io.trino.sql.tree.SubscriptExpression; import io.trino.sql.tree.SubsetDefinition; import io.trino.sql.tree.Table; +import io.trino.sql.tree.TableArgument; import io.trino.sql.tree.TableExecute; +import io.trino.sql.tree.TableFunctionArgument; +import io.trino.sql.tree.TableFunctionInvocation; import io.trino.sql.tree.TableSubquery; import io.trino.sql.tree.TimeLiteral; import io.trino.sql.tree.TimestampLiteral; @@ -240,6 +247,8 @@ import static io.trino.sql.tree.ArithmeticUnaryExpression.positive; import static io.trino.sql.tree.ComparisonExpression.Operator.EQUAL; import static io.trino.sql.tree.DateTimeDataType.Type.TIMESTAMP; +import static io.trino.sql.tree.DescriptorArgument.descriptorArgument; +import static io.trino.sql.tree.DescriptorArgument.nullDescriptorArgument; import static io.trino.sql.tree.FrameBound.Type.CURRENT_ROW; import static io.trino.sql.tree.FrameBound.Type.FOLLOWING; import static io.trino.sql.tree.PatternSearchMode.Mode.SEEK; @@ -247,6 +256,7 @@ import static io.trino.sql.tree.ProcessingMode.Mode.RUNNING; import static io.trino.sql.tree.SetProperties.Type.MATERIALIZED_VIEW; import static io.trino.sql.tree.SkipTo.skipToNextRow; +import static io.trino.sql.tree.SortItem.NullOrdering.LAST; import static io.trino.sql.tree.SortItem.NullOrdering.UNDEFINED; import static io.trino.sql.tree.SortItem.Ordering.ASCENDING; import static 
io.trino.sql.tree.SortItem.Ordering.DESCENDING; @@ -3857,6 +3867,101 @@ public void testListagg() new BooleanLiteral("false")))); } + @Test + public void testTableFunctionInvocation() + { + assertThat(statement("SELECT * FROM TABLE(some_ptf(input => 1))")) + .isEqualTo(selectAllFrom(new TableFunctionInvocation( + location(1, 21), + qualifiedName(location(1, 21), "some_ptf"), + ImmutableList.of(new TableFunctionArgument( + location(1, 30), + Optional.of(new Identifier(location(1, 30), "input", false)), + new LongLiteral(location(1, 39), "1"))), + ImmutableList.of()))); + + assertThat(statement("SELECT * FROM TABLE(some_ptf(" + + " arg1 => TABLE(orders) AS ord(a, b, c) " + + " PARTITION BY a " + + " PRUNE WHEN EMPTY " + + " ORDER BY b ASC NULLS LAST, " + + " arg2 => CAST(NULL AS DESCRIPTOR), " + + " arg3 => DESCRIPTOR(x integer, y varchar), " + + " arg4 => 5, " + + " 'not-named argument' " + + " COPARTITION (ord, nation)))")) + .isEqualTo(selectAllFrom(new TableFunctionInvocation( + location(1, 21), + qualifiedName(location(1, 21), "some_ptf"), + ImmutableList.of( + new TableFunctionArgument( + location(1, 77), + Optional.of(new Identifier(location(1, 77), "arg1", false)), + new TableArgument( + location(1, 85), + new AliasedRelation( + location(1, 85), + new Table(location(1, 85), qualifiedName(location(1, 91), "orders")), + new Identifier(location(1, 102), "ord", false), + ImmutableList.of( + new Identifier(location(1, 106), "a", false), + new Identifier(location(1, 109), "b", false), + new Identifier(location(1, 112), "c", false))), + Optional.of(ImmutableList.of(new Identifier(location(1, 196), "a", false))), + Optional.of(new OrderBy(ImmutableList.of(new SortItem(location(1, 360), new Identifier(location(1, 360), "b", false), ASCENDING, LAST)))), + true)), + new TableFunctionArgument( + location(1, 425), + Optional.of(new Identifier(location(1, 425), "arg2", false)), + nullDescriptorArgument(location(1, 433))), + new TableFunctionArgument( + location(1, 506), + 
Optional.of(new Identifier(location(1, 506), "arg3", false)), + descriptorArgument( + location(1, 514), + new Descriptor(location(1, 514), ImmutableList.of( + new DescriptorField( + location(1, 525), + new Identifier(location(1, 525), "x", false), + Optional.of(new GenericDataType(location(1, 527), new Identifier(location(1, 527), "integer", false), ImmutableList.of()))), + new DescriptorField( + location(1, 536), + new Identifier(location(1, 536), "y", false), + Optional.of(new GenericDataType(location(1, 538), new Identifier(location(1, 538), "varchar", false), ImmutableList.of()))))))), + new TableFunctionArgument( + location(1, 595), + Optional.of(new Identifier(location(1, 595), "arg4", false)), + new LongLiteral(location(1, 603), "5")), + new TableFunctionArgument( + location(1, 653), + Optional.empty(), + new StringLiteral(location(1, 653), "not-named argument"))), + ImmutableList.of(ImmutableList.of( + qualifiedName(location(1, 734), "ord"), + qualifiedName(location(1, 739), "nation")))))); + } + + private static Query selectAllFrom(Relation relation) + { + return new Query( + location(1, 1), + Optional.empty(), + new QuerySpecification( + location(1, 1), + new Select(location(1, 1), false, ImmutableList.of(new AllColumns(location(1, 8), Optional.empty(), ImmutableList.of()))), + Optional.of(relation), + Optional.empty(), + Optional.empty(), + Optional.empty(), + ImmutableList.of(), + Optional.empty(), + Optional.empty(), + Optional.empty()), + Optional.empty(), + Optional.empty(), + Optional.empty()); + } + private static QualifiedName makeQualifiedName(String tableName) { List parts = Splitter.on('.').splitToList(tableName).stream() diff --git a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParserErrorHandling.java b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParserErrorHandling.java index 481de0d5f136..c1230b8adb5e 100644 --- a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParserErrorHandling.java +++ 
b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParserErrorHandling.java @@ -50,7 +50,7 @@ private static Stream statements() Arguments.of("select * from foo where @what", "line 1:25: mismatched input '@'. Expecting: "), Arguments.of("select * from 'oops", - "line 1:15: mismatched input '''. Expecting: '(', 'LATERAL', 'UNNEST', "), + "line 1:15: mismatched input '''. Expecting: '(', 'LATERAL', 'TABLE', 'UNNEST', "), Arguments.of("select *\nfrom x\nfrom", "line 3:1: mismatched input 'from'. Expecting: ',', '.', 'AS', 'CROSS', 'EXCEPT', 'FETCH', 'FOR', 'FULL', 'GROUP', 'HAVING', 'INNER', 'INTERSECT', 'JOIN', 'LEFT', " + "'LIMIT', 'MATCH_RECOGNIZE', 'NATURAL', 'OFFSET', 'ORDER', 'RIGHT', 'TABLESAMPLE', 'UNION', 'WHERE', 'WINDOW', , "), @@ -59,9 +59,9 @@ private static Stream statements() Arguments.of("select ", "line 1:8: mismatched input ''. Expecting: '*', 'ALL', 'DISTINCT', "), Arguments.of("select * from", - "line 1:14: mismatched input ''. Expecting: '(', 'LATERAL', 'UNNEST', "), + "line 1:14: mismatched input ''. Expecting: '(', 'LATERAL', 'TABLE', 'UNNEST', "), Arguments.of("select * from ", - "line 1:16: mismatched input ''. Expecting: '(', 'LATERAL', 'UNNEST', "), + "line 1:16: mismatched input ''. Expecting: '(', 'LATERAL', 'TABLE', 'UNNEST', "), Arguments.of("select * from `foo`", "line 1:15: backquoted identifiers are not supported; use double quotes to quote identifiers"), Arguments.of("select * from foo `bar`", @@ -115,7 +115,7 @@ private static Stream statements() Arguments.of("CREATE TABLE t (x bigint) COMMENT ", "line 1:35: mismatched input ''. Expecting: "), Arguments.of("SELECT * FROM ( ", - "line 1:17: mismatched input ''. Expecting: '(', 'LATERAL', 'UNNEST', , "), + "line 1:17: mismatched input ''. Expecting: '(', 'LATERAL', 'TABLE', 'UNNEST', , "), Arguments.of("SELECT CAST(a AS )", "line 1:18: mismatched input ')'. 
Expecting: "), Arguments.of("SELECT CAST(a AS decimal()", diff --git a/core/trino-spi/src/main/java/io/trino/spi/StandardErrorCode.java b/core/trino-spi/src/main/java/io/trino/spi/StandardErrorCode.java index 969786729f3c..845b96f9e346 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/StandardErrorCode.java +++ b/core/trino-spi/src/main/java/io/trino/spi/StandardErrorCode.java @@ -130,6 +130,9 @@ public enum StandardErrorCode MISSING_ROW_PATTERN(106, USER_ERROR), INVALID_WINDOW_MEASURE(107, USER_ERROR), STACK_OVERFLOW(108, USER_ERROR), + MISSING_RETURN_TYPE(109, USER_ERROR), + AMBIGUOUS_RETURN_TYPE(110, USER_ERROR), + MISSING_ARGUMENT(111, USER_ERROR), GENERIC_INTERNAL_ERROR(65536, INTERNAL_ERROR), TOO_MANY_REQUESTS_FAILED(65537, INTERNAL_ERROR), diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/Connector.java b/core/trino-spi/src/main/java/io/trino/spi/connector/Connector.java index b7463fc4fc40..91192179a66b 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/connector/Connector.java +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/Connector.java @@ -15,6 +15,7 @@ import io.trino.spi.eventlistener.EventListener; import io.trino.spi.procedure.Procedure; +import io.trino.spi.ptf.ConnectorTableFunction; import io.trino.spi.session.PropertyMetadata; import io.trino.spi.transaction.IsolationLevel; @@ -67,6 +68,7 @@ default ConnectorMetadata getMetadata(ConnectorSession session, ConnectorTransac /** * Guaranteed to be called at most once per transaction. The returned metadata will only be accessed * in a single threaded context. 
+ * * @deprecated use {@link #getMetadata(ConnectorSession, ConnectorTransactionHandle)} */ @Deprecated @@ -144,6 +146,14 @@ default Set getTableProcedures() return emptySet(); } + /** + * @return the set of table functions provided by this connector + */ + default Set getTableFunctions() + { + return emptySet(); + } + /** * @return the system properties for this connector */ diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMetadata.java b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMetadata.java index 29e453ae5638..a0030b64ae84 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMetadata.java +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMetadata.java @@ -19,6 +19,7 @@ import io.trino.spi.expression.Constant; import io.trino.spi.expression.Variable; import io.trino.spi.predicate.TupleDomain; +import io.trino.spi.ptf.ConnectorTableFunctionHandle; import io.trino.spi.security.GrantInfo; import io.trino.spi.security.Privilege; import io.trino.spi.security.RoleGrant; @@ -684,7 +685,6 @@ default void finishDelete(ConnectorSession session, ConnectorTableHandle tableHa * operation, in table column order. * @return a ConnectorTableHandle that will be passed to split generation, and to the * {@link #finishUpdate} method. - * * @deprecated Use {@link #beginUpdate(ConnectorSession, ConnectorTableHandle, List, RetryMode)} instead. */ @Deprecated @@ -1251,6 +1251,19 @@ default Optional> applyTopN( return Optional.empty(); } + /** + * Attempt to push down the table function invocation into the connector. + *

+ * Connectors can indicate whether they don't support table function invocation pushdown or that the action had no + * effect by returning {@link Optional#empty()}. Connectors should expect this method may be called multiple times. + *

+ * If the method returns a result, the returned table handle will be used in place of the table function invocation. + */ + default Optional> applyTableFunction(ConnectorSession session, ConnectorTableFunctionHandle handle) + { + return Optional.empty(); + } + /** * Allows the connector to reject the table scan produced by the planner. *

package io.trino.spi.connector;

import java.util.List;

import static java.util.Objects.requireNonNull;

/**
 * Result of {@code ConnectorMetadata.applyTableFunction}: the table handle that
 * replaces the table function invocation, and the column handles producing the
 * function's output columns, in output order.
 *
 * @param <T> the connector's table handle type
 */
public class TableFunctionApplicationResult<T>
{
    private final T tableHandle;
    private final List<ColumnHandle> columnHandles;

    public TableFunctionApplicationResult(T tableHandle, List<ColumnHandle> columnHandles)
    {
        this.tableHandle = requireNonNull(tableHandle, "tableHandle is null");
        // Defensive copy: the result is shared with the engine and must not change afterwards.
        this.columnHandles = List.copyOf(requireNonNull(columnHandles, "columnHandles is null"));
    }

    public T getTableHandle()
    {
        return tableHandle;
    }

    public List<ColumnHandle> getColumnHandles()
    {
        return columnHandles;
    }
}
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.ptf; + +import com.fasterxml.jackson.annotation.JsonSubTypes; +import com.fasterxml.jackson.annotation.JsonTypeInfo; +import io.trino.spi.expression.ConnectorExpression; + +/** + * This class represents the three types of arguments passed to a Table Function: + * scalar arguments, descriptor arguments, and table arguments. + *

+ * This representation should be considered experimental. Eventually, {@link ConnectorExpression} + * should be extended to include the Table Function arguments. + */ +@JsonTypeInfo( + use = JsonTypeInfo.Id.NAME, + property = "@type") +@JsonSubTypes({ + @JsonSubTypes.Type(value = DescriptorArgument.class, name = "descriptor"), + @JsonSubTypes.Type(value = ScalarArgument.class, name = "scalar"), + @JsonSubTypes.Type(value = TableArgument.class, name = "table"), +}) +public abstract class Argument +{ +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/ptf/ArgumentSpecification.java b/core/trino-spi/src/main/java/io/trino/spi/ptf/ArgumentSpecification.java new file mode 100644 index 000000000000..3f99872897d4 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/ptf/ArgumentSpecification.java @@ -0,0 +1,61 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.ptf; + +import javax.annotation.Nullable; + +import static io.trino.spi.ptf.ConnectorTableFunction.checkArgument; +import static io.trino.spi.ptf.ConnectorTableFunction.checkNotNullOrEmpty; + +/** + * Abstract class to capture the three supported argument types for a table function: + * - Table arguments + * - Descriptor arguments + * - SQL scalar arguments + *

+ * Each argument is named, and either passed positionally or in a `arg_name => value` convention. + *

+ * Default values are allowed for all arguments except Table arguments. + */ +public abstract class ArgumentSpecification +{ + private final String name; + private final boolean required; + + // native representation + private final Object defaultValue; + + public ArgumentSpecification(String name, boolean required, @Nullable Object defaultValue) + { + this.name = checkNotNullOrEmpty(name, "name"); + checkArgument(!required || defaultValue == null, "non-null default value for a required argument"); + this.required = required; + this.defaultValue = defaultValue; + } + + public String getName() + { + return name; + } + + public boolean isRequired() + { + return required; + } + + public Object getDefaultValue() + { + return defaultValue; + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/ptf/ConnectorTableFunction.java b/core/trino-spi/src/main/java/io/trino/spi/ptf/ConnectorTableFunction.java new file mode 100644 index 000000000000..d61622bd0cb4 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/ptf/ConnectorTableFunction.java @@ -0,0 +1,156 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.spi.ptf; + +import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.connector.ConnectorTransactionHandle; + +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static java.util.Objects.requireNonNull; + +public abstract class ConnectorTableFunction +{ + private final String schema; + private final String name; + private final List arguments; + private final ReturnTypeSpecification returnTypeSpecification; + + public ConnectorTableFunction(String schema, String name, List arguments, ReturnTypeSpecification returnTypeSpecification) + { + this.schema = checkNotNullOrEmpty(schema, "schema"); + this.name = checkNotNullOrEmpty(name, "name"); + requireNonNull(arguments, "arguments is null"); + Set argumentNames = new HashSet<>(); + for (ArgumentSpecification specification : arguments) { + if (!argumentNames.add(specification.getName())) { + throw new IllegalArgumentException("duplicate argument name: " + specification.getName()); + } + } + long tableArgumentsWithRowSemantics = arguments.stream() + .filter(specification -> specification instanceof TableArgumentSpecification) + .map(TableArgumentSpecification.class::cast) + .filter(TableArgumentSpecification::isRowSemantics) + .count(); + checkArgument(tableArgumentsWithRowSemantics <= 1, "more than one table argument with row semantics"); + this.arguments = List.copyOf(arguments); + this.returnTypeSpecification = requireNonNull(returnTypeSpecification, "returnTypeSpecification is null"); + } + + public String getSchema() + { + return schema; + } + + public String getName() + { + return name; + } + + public List getArguments() + { + return arguments; + } + + public ReturnTypeSpecification getReturnTypeSpecification() + { + return returnTypeSpecification; + } + + /** + * This method is called by the Analyzer. Its main purposes are to: + * 1. 
Determine the resulting relation type of the Table Function when the declared return type is GENERIC_TABLE. + * 2. Declare the dependencies between the input descriptors and the input tables. + * 3. Perform function-specific validation and pre-processing of the input arguments. + * As part of function-specific validation, the Table Function's author might want to: + * - check if the descriptors which reference input tables contain the correct number of column references + * - check if the referenced input columns have appropriate types to fit the function's logic // TODO return request for coercions to the Analyzer in the Analysis object + * - if there is a descriptor which describes the function's output, check if it matches the shape of the actual function's output + * - for table arguments, check the number and types of ordering columns + *

+ * The actual argument values and the pre-processing results can be stored in a ConnectorTableFunctionHandle + * object, which will be passed along with the Table Function invocation through subsequent phases of planning. + * + * @param arguments actual invocation arguments, mapped by argument names + */ + public Analysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + throw new UnsupportedOperationException("analyze method not implemented for Table Function " + name); + } + + /** + * The `analyze()` method should produce an object of this class, containing all the analysis results: + *

+ * The `returnedType` field is used to inform the Analyzer of the proper columns returned by the Table + * Function, that is, the columns produced by the function, as opposed to the columns passed from the + * input tables. The `returnedType` should only be set if the declared return type is GENERIC_TABLE. + *

+ * The `descriptorsToTables` field is used to inform the Analyzer of the semantics of descriptor arguments. + * Some descriptor arguments (or some of their fields) might be references to columns of the input tables. + * In such a case, the Analyzer must be informed of those dependencies. This allows passing the right values + * (input channels) to the Table Function during execution. It also allows pruning unused input columns + * during the optimization phase. + *

+ * The `handle` field can be used to carry all information necessary to execute the table function, + * gathered at analysis time. Typically, these are the values of the constant arguments, and results + * of pre-processing arguments. + */ + public static class Analysis + { + private final Optional returnedType; + private final DescriptorMapping descriptorsToTables; + private final ConnectorTableFunctionHandle handle; + + public Analysis(Optional returnedType, DescriptorMapping descriptorsToTables, ConnectorTableFunctionHandle handle) + { + this.returnedType = requireNonNull(returnedType, "returnedType is null"); + returnedType.ifPresent(descriptor -> checkArgument(descriptor.isTyped(), "field types not specified")); + this.descriptorsToTables = requireNonNull(descriptorsToTables, "descriptorsToTables is null"); + this.handle = requireNonNull(handle, "handle is null"); + } + + public Optional getReturnedType() + { + return returnedType; + } + + public DescriptorMapping getDescriptorsToTables() + { + return descriptorsToTables; + } + + public ConnectorTableFunctionHandle getHandle() + { + return handle; + } + } + + static String checkNotNullOrEmpty(String value, String name) + { + requireNonNull(value, name + " is null"); + checkArgument(!value.isEmpty(), name + " is empty"); + return value; + } + + static void checkArgument(boolean assertion, String message) + { + if (!assertion) { + throw new IllegalArgumentException(message); + } + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/ptf/ConnectorTableFunctionHandle.java b/core/trino-spi/src/main/java/io/trino/spi/ptf/ConnectorTableFunctionHandle.java new file mode 100644 index 000000000000..5c6aececd1fe --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/ptf/ConnectorTableFunctionHandle.java @@ -0,0 +1,21 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.ptf; + +/** + * An area to store all information necessary to execute the table function, gathered at analysis time + */ +public interface ConnectorTableFunctionHandle +{ +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/ptf/Descriptor.java b/core/trino-spi/src/main/java/io/trino/spi/ptf/Descriptor.java new file mode 100644 index 000000000000..f0ccd4971160 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/ptf/Descriptor.java @@ -0,0 +1,97 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.spi.ptf; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import io.trino.spi.type.Type; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; + +import static io.trino.spi.ptf.ConnectorTableFunction.checkArgument; +import static io.trino.spi.ptf.ConnectorTableFunction.checkNotNullOrEmpty; +import static java.util.Objects.requireNonNull; + +public class Descriptor +{ + private final List fields; + + @JsonCreator + public Descriptor(@JsonProperty("fields") List fields) + { + requireNonNull(fields, "fields is null"); + checkArgument(!fields.isEmpty(), "descriptor has no fields"); + this.fields = List.copyOf(fields); + } + + public static Descriptor descriptor(String... names) + { + List fields = Arrays.stream(names) + .map(name -> new Field(name, Optional.empty())) + .collect(Collectors.toList()); + return new Descriptor(fields); + } + + public static Descriptor descriptor(List names, List types) + { + requireNonNull(names, "names is null"); + requireNonNull(types, "types is null"); + checkArgument(names.size() == types.size(), "names and types lists do not match"); + List fields = new ArrayList<>(); + for (int i = 0; i < names.size(); i++) { + fields.add(new Field(names.get(i), Optional.of(types.get(i)))); + } + return new Descriptor(fields); + } + + @JsonProperty + public List getFields() + { + return fields; + } + + public boolean isTyped() + { + return fields.stream().allMatch(field -> field.type.isPresent()); + } + + public static class Field + { + private final String name; + private final Optional type; + + @JsonCreator + public Field(@JsonProperty("name") String name, @JsonProperty("type") Optional type) + { + this.name = checkNotNullOrEmpty(name, "name"); + this.type = requireNonNull(type, "type is null"); + } + + @JsonProperty + public String getName() + { + return name; + } + + 
@JsonProperty + public Optional getType() + { + return type; + } + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/ptf/DescriptorArgument.java b/core/trino-spi/src/main/java/io/trino/spi/ptf/DescriptorArgument.java new file mode 100644 index 000000000000..ea87f17f8e20 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/ptf/DescriptorArgument.java @@ -0,0 +1,53 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.ptf; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import io.trino.spi.expression.ConnectorExpression; + +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +/** + * This class represents the descriptor argument passed to a Table Function. + *

+ * This representation should be considered experimental. Eventually, {@link ConnectorExpression} + * should be extended to include this kind of argument. + */ +public class DescriptorArgument + extends Argument +{ + public static final DescriptorArgument NULL_DESCRIPTOR = new DescriptorArgument(Optional.empty()); + private final Optional descriptor; + + public static DescriptorArgument descriptorArgument(Descriptor descriptor) + { + requireNonNull(descriptor, "descriptor is null"); + return new DescriptorArgument(Optional.of(descriptor)); + } + + @JsonCreator + private DescriptorArgument(@JsonProperty("descriptor") Optional descriptor) + { + this.descriptor = requireNonNull(descriptor, "descriptor is null"); + } + + @JsonProperty + public Optional getDescriptor() + { + return descriptor; + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/ptf/DescriptorArgumentSpecification.java b/core/trino-spi/src/main/java/io/trino/spi/ptf/DescriptorArgumentSpecification.java new file mode 100644 index 000000000000..5a9f1a36c374 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/ptf/DescriptorArgumentSpecification.java @@ -0,0 +1,28 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.spi.ptf; + +public class DescriptorArgumentSpecification + extends ArgumentSpecification +{ + public DescriptorArgumentSpecification(String name) + { + super(name, true, null); + } + + public DescriptorArgumentSpecification(String name, Descriptor defaultValue) + { + super(name, false, defaultValue); + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/ptf/DescriptorMapping.java b/core/trino-spi/src/main/java/io/trino/spi/ptf/DescriptorMapping.java new file mode 100644 index 000000000000..f1566bfe59a6 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/ptf/DescriptorMapping.java @@ -0,0 +1,83 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.spi.ptf; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +import static io.trino.spi.ptf.ConnectorTableFunction.checkArgument; +import static io.trino.spi.ptf.ConnectorTableFunction.checkNotNullOrEmpty; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public class DescriptorMapping +{ + public static final DescriptorMapping EMPTY_MAPPING = new DescriptorMappingBuilder().build(); + + private final Map mappingByField; + private final Map mappingByDescriptor; + + private DescriptorMapping(Map mappingByField, Map mappingByDescriptor) + { + this.mappingByField = Map.copyOf(requireNonNull(mappingByField, "mappingByField is null")); + this.mappingByDescriptor = Map.copyOf(requireNonNull(mappingByDescriptor, "mappingByDescriptor is null")); + } + + public Map getMappingByField() + { + return mappingByField; + } + + public Map getMappingByDescriptor() + { + return mappingByDescriptor; + } + + public boolean isEmpty() + { + return mappingByField.isEmpty() && mappingByDescriptor.isEmpty(); + } + + public static class DescriptorMappingBuilder + { + private final Map mappingByField = new HashMap<>(); + private final Map mappingByDescriptor = new HashMap<>(); + private final Set descriptorsMappedByField = new HashSet<>(); + + public DescriptorMappingBuilder mapField(String descriptor, int field, String table) + { + checkNotNullOrEmpty(table, "table"); + checkArgument(!mappingByDescriptor.containsKey(descriptor), format("duplicate mapping for descriptor: %s, field: %s", descriptor, field)); + checkArgument(mappingByField.put(new NameAndPosition(descriptor, field), table) == null, format("duplicate mapping for descriptor: %s, field: %s", descriptor, field)); + descriptorsMappedByField.add(descriptor); + return this; + } + + public DescriptorMappingBuilder mapAllFields(String descriptor, String table) + { + checkNotNullOrEmpty(descriptor, "descriptor"); + 
checkNotNullOrEmpty(table, "table"); + checkArgument(!descriptorsMappedByField.contains(descriptor), "duplicate mapping for field of descriptor: " + descriptor); + checkArgument(mappingByDescriptor.put(descriptor, table) == null, "duplicate mapping for descriptor: " + descriptor); + return this; + } + + public DescriptorMapping build() + { + return new DescriptorMapping(mappingByField, mappingByDescriptor); + } + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/ptf/NameAndPosition.java b/core/trino-spi/src/main/java/io/trino/spi/ptf/NameAndPosition.java new file mode 100644 index 000000000000..29d7a8d4639b --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/ptf/NameAndPosition.java @@ -0,0 +1,67 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.ptf; + +import java.util.Objects; + +/** + * This class represents a descriptor field reference. + * `name` is the descriptor argument name, `position` is the zero-based field index. + *

+ * The specified field contains a column name, as passed by the Table Function caller. + * The column name is associated with an appropriate input table during the Analysis phase. + * The Table Function is supposed to refer to input data using `NameAndPosition`, + * and the engine should provide the requested column. + */ +public class NameAndPosition +{ + private final String name; + private final int position; + + public NameAndPosition(String name, int position) + { + this.name = ConnectorTableFunction.checkNotNullOrEmpty(name, "name"); + ConnectorTableFunction.checkArgument(position >= 0, "position in descriptor must not be negative"); + this.position = position; + } + + public String getName() + { + return name; + } + + public int getPosition() + { + return position; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + NameAndPosition that = (NameAndPosition) o; + return position == that.position && Objects.equals(name, that.name); + } + + @Override + public int hashCode() + { + return Objects.hash(name, position); + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/ptf/ReturnTypeSpecification.java b/core/trino-spi/src/main/java/io/trino/spi/ptf/ReturnTypeSpecification.java new file mode 100644 index 000000000000..224dfb5f07e7 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/ptf/ReturnTypeSpecification.java @@ -0,0 +1,70 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.ptf; + +import static io.trino.spi.ptf.ConnectorTableFunction.checkArgument; +import static java.util.Objects.requireNonNull; + +/** + * The return type declaration refers to the proper columns of the table function. + * These are the columns produced by the table function as opposed to the columns + * of input relations passed through by the table function. + */ +public abstract class ReturnTypeSpecification +{ + /** + * The proper columns of the table function are not known at function declaration time. + * They must be determined at query analysis time based on the actual call arguments. + */ + public static class GenericTable + extends ReturnTypeSpecification + { + public static final GenericTable GENERIC_TABLE = new GenericTable(); + + private GenericTable() {} + } + + /** + * The table function has no proper columns. + */ + public static class OnlyPassThrough + extends ReturnTypeSpecification + { + public static final OnlyPassThrough ONLY_PASS_THROUGH = new OnlyPassThrough(); + + private OnlyPassThrough() {} + } + + /** + * The proper columns of the table function are known at function declaration time. + * They do not depend on the actual call arguments. 
+ */ + public static class DescribedTable + extends ReturnTypeSpecification + { + private final Descriptor descriptor; + + public DescribedTable(Descriptor descriptor) + { + requireNonNull(descriptor, "descriptor is null"); + checkArgument(descriptor.isTyped(), "field types not specified"); + this.descriptor = descriptor; + } + + public Descriptor getDescriptor() + { + return descriptor; + } + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/ptf/ScalarArgument.java b/core/trino-spi/src/main/java/io/trino/spi/ptf/ScalarArgument.java new file mode 100644 index 000000000000..06ce9308ea08 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/ptf/ScalarArgument.java @@ -0,0 +1,58 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.ptf; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import io.trino.spi.expression.ConnectorExpression; +import io.trino.spi.type.Type; + +import static java.util.Objects.requireNonNull; + +/** + * This class represents the scalar argument passed to a Table Function. + *

+ * This representation should be considered experimental. Eventually, {@link ConnectorExpression} + * should be extended to include this kind of argument. + *

+ * Additionally, only constant values are currently supported. In the future, + * we will add support for different kinds of expressions. + */ +public class ScalarArgument + extends Argument +{ + private final Type type; + + // native representation + private final Object value; + + @JsonCreator + public ScalarArgument(@JsonProperty("type") Type type, @JsonProperty("value") Object value) + { + this.type = requireNonNull(type, "type is null"); + this.value = value; + } + + @JsonProperty + public Type getType() + { + return type; + } + + @JsonProperty + public Object getValue() + { + return value; + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/ptf/ScalarArgumentSpecification.java b/core/trino-spi/src/main/java/io/trino/spi/ptf/ScalarArgumentSpecification.java new file mode 100644 index 000000000000..ac581b6bfcdb --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/ptf/ScalarArgumentSpecification.java @@ -0,0 +1,44 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.spi.ptf; + +import io.trino.spi.type.Type; + +import static io.trino.spi.ptf.ConnectorTableFunction.checkArgument; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public class ScalarArgumentSpecification + extends ArgumentSpecification +{ + private final Type type; + + public ScalarArgumentSpecification(String name, Type type) + { + super(name, true, null); + this.type = requireNonNull(type, "type is null"); + } + + public ScalarArgumentSpecification(String name, Type type, Object defaultValue) + { + super(name, false, defaultValue); + this.type = requireNonNull(type, "type is null"); + checkArgument(type.getJavaType().equals(defaultValue.getClass()), format("default value %s does not match the declared type: %s", defaultValue, type)); + } + + public Type getType() + { + return type; + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/ptf/TableArgument.java b/core/trino-spi/src/main/java/io/trino/spi/ptf/TableArgument.java new file mode 100644 index 000000000000..7a6b2ddbaba7 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/ptf/TableArgument.java @@ -0,0 +1,176 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.spi.ptf; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import io.trino.spi.expression.ConnectorExpression; +import io.trino.spi.type.RowType; + +import java.util.List; +import java.util.Optional; + +import static io.trino.spi.ptf.ConnectorTableFunction.checkNotNullOrEmpty; +import static java.util.Objects.requireNonNull; + +/** + * This class represents the table argument passed to a Table Function. + *

+ * This representation should be considered experimental. Eventually, {@link ConnectorExpression} + * should be extended to include this kind of argument. + */ +public class TableArgument + extends Argument +{ + private final Optional name; + private final RowType rowType; + private final List partitionBy; + private final List orderBy; + private final boolean rowSemantics; + private final boolean pruneWhenEmpty; + private final boolean passThroughColumns; + + @JsonCreator + public TableArgument( + @JsonProperty("name") Optional name, + @JsonProperty("rowType") RowType rowType, + @JsonProperty("partitionBy") List partitionBy, + @JsonProperty("orderBy") List orderBy, + @JsonProperty("rowSemantics") boolean rowSemantics, + @JsonProperty("pruneWhenEmpty") boolean pruneWhenEmpty, + @JsonProperty("passThroughColumns") boolean passThroughColumns) + { + this.name = requireNonNull(name, "name is null"); + this.rowType = requireNonNull(rowType, "rowType is null"); + this.partitionBy = requireNonNull(partitionBy, "partitionBy is null"); + this.orderBy = requireNonNull(orderBy, "orderBy is null"); + this.rowSemantics = rowSemantics; + this.pruneWhenEmpty = pruneWhenEmpty; + this.passThroughColumns = passThroughColumns; + } + + @JsonProperty + public Optional getName() + { + return name; + } + + @JsonProperty + public RowType getRowType() + { + return rowType; + } + + @JsonProperty + public List getPartitionBy() + { + return partitionBy; + } + + @JsonProperty + public List getOrderBy() + { + return orderBy; + } + + @JsonProperty + public boolean isRowSemantics() + { + return rowSemantics; + } + + @JsonProperty + public boolean isPruneWhenEmpty() + { + return pruneWhenEmpty; + } + + @JsonProperty + public boolean isPassThroughColumns() + { + return passThroughColumns; + } + + public static class QualifiedName + { + private final String catalogName; + private final String schemaName; + private final String tableName; + + @JsonCreator + public QualifiedName( + 
@JsonProperty("catalogName") String catalogName, + @JsonProperty("schemaName") String schemaName, + @JsonProperty("tableName") String tableName) + { + this.catalogName = checkNotNullOrEmpty(catalogName, "catalogName"); + this.schemaName = checkNotNullOrEmpty(schemaName, "schemaName"); + this.tableName = checkNotNullOrEmpty(tableName, "tableName"); + } + + @JsonProperty + public String getCatalogName() + { + return catalogName; + } + + @JsonProperty + public String getSchemaName() + { + return schemaName; + } + + @JsonProperty + public String getTableName() + { + return tableName; + } + } + + public static class SortItem + { + private final String column; + private final boolean ascending; + private final boolean nullsLast; + + @JsonCreator + public SortItem( + @JsonProperty("column") String column, + @JsonProperty("ascending") boolean ascending, + @JsonProperty("nullsLast") boolean nullsLast) + { + this.column = checkNotNullOrEmpty(column, "ordering column"); + this.ascending = ascending; + this.nullsLast = nullsLast; + } + + @JsonProperty + public String getColumn() + { + return column; + } + + @JsonProperty + public boolean isAscending() + { + return ascending; + } + + @JsonProperty + public boolean isNullsLast() + { + return nullsLast; + } + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/ptf/TableArgumentSpecification.java b/core/trino-spi/src/main/java/io/trino/spi/ptf/TableArgumentSpecification.java new file mode 100644 index 000000000000..6bdf202fcec4 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/ptf/TableArgumentSpecification.java @@ -0,0 +1,52 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.ptf; + +public class TableArgumentSpecification + extends ArgumentSpecification +{ + private final boolean rowSemantics; + private final boolean pruneWhenEmpty; + private final boolean passThroughColumns; + + public TableArgumentSpecification(String name, boolean rowSemantics, boolean pruneWhenEmpty, boolean passThroughColumns) + { + super(name, true, null); + + this.rowSemantics = rowSemantics; + this.pruneWhenEmpty = pruneWhenEmpty; + this.passThroughColumns = passThroughColumns; + } + + public TableArgumentSpecification(String name) + { + // defaults + this(name, false, false, false); + } + + public boolean isRowSemantics() + { + return rowSemantics; + } + + public boolean isPruneWhenEmpty() + { + return pruneWhenEmpty; + } + + public boolean isPassThroughColumns() + { + return passThroughColumns; + } +} diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java index 85393f7ba9fa..8e20fa03f4df 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java @@ -56,10 +56,12 @@ import io.trino.spi.connector.SortItem; import io.trino.spi.connector.SystemTable; import io.trino.spi.connector.TableColumnsMetadata; +import io.trino.spi.connector.TableFunctionApplicationResult; 
import io.trino.spi.connector.TableScanRedirectApplicationResult; import io.trino.spi.connector.TopNApplicationResult; import io.trino.spi.expression.ConnectorExpression; import io.trino.spi.predicate.TupleDomain; +import io.trino.spi.ptf.ConnectorTableFunctionHandle; import io.trino.spi.security.GrantInfo; import io.trino.spi.security.Privilege; import io.trino.spi.security.RoleGrant; @@ -935,6 +937,14 @@ public Optional> applyTopN( } } + @Override + public Optional> applyTableFunction(ConnectorSession session, ConnectorTableFunctionHandle handle) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return delegate.applyTableFunction(session, handle); + } + } + @Override public void validateScan(ConnectorSession session, ConnectorTableHandle handle) { diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcConnector.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcConnector.java index f803c8042df0..616d955dfcd3 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcConnector.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcConnector.java @@ -35,22 +35,16 @@ import java.util.List; import java.util.Optional; import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; -import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.Sets.immutableEnumSet; import static io.trino.spi.connector.ConnectorCapabilities.NOT_NULL_COLUMN_CONSTRAINT; -import static io.trino.spi.transaction.IsolationLevel.READ_COMMITTED; -import static io.trino.spi.transaction.IsolationLevel.checkConnectorSupports; import static java.util.Objects.requireNonNull; public class JdbcConnector implements Connector { private final LifeCycleManager lifeCycleManager; - private final JdbcMetadataFactory 
jdbcMetadataFactory; private final ConnectorSplitManager jdbcSplitManager; private final ConnectorRecordSetProvider jdbcRecordSetProvider; private final ConnectorPageSinkProvider jdbcPageSinkProvider; @@ -58,23 +52,21 @@ public class JdbcConnector private final Set procedures; private final List> sessionProperties; private final List> tableProperties; - - private final ConcurrentMap transactions = new ConcurrentHashMap<>(); + private final JdbcTransactionManager transactionManager; @Inject public JdbcConnector( LifeCycleManager lifeCycleManager, - JdbcMetadataFactory jdbcMetadataFactory, ConnectorSplitManager jdbcSplitManager, ConnectorRecordSetProvider jdbcRecordSetProvider, ConnectorPageSinkProvider jdbcPageSinkProvider, Optional accessControl, Set procedures, Set sessionProperties, - Set tableProperties) + Set tableProperties, + JdbcTransactionManager transactionManager) { this.lifeCycleManager = requireNonNull(lifeCycleManager, "lifeCycleManager is null"); - this.jdbcMetadataFactory = requireNonNull(jdbcMetadataFactory, "jdbcMetadataFactory is null"); this.jdbcSplitManager = requireNonNull(jdbcSplitManager, "jdbcSplitManager is null"); this.jdbcRecordSetProvider = requireNonNull(jdbcRecordSetProvider, "jdbcRecordSetProvider is null"); this.jdbcPageSinkProvider = requireNonNull(jdbcPageSinkProvider, "jdbcPageSinkProvider is null"); @@ -86,37 +78,31 @@ public JdbcConnector( this.tableProperties = requireNonNull(tableProperties, "tableProperties is null").stream() .flatMap(tablePropertiesProvider -> tablePropertiesProvider.getTableProperties().stream()) .collect(toImmutableList()); + this.transactionManager = requireNonNull(transactionManager, "transactionManager is null"); } @Override public ConnectorTransactionHandle beginTransaction(IsolationLevel isolationLevel, boolean readOnly, boolean autoCommit) { - checkConnectorSupports(READ_COMMITTED, isolationLevel); - JdbcTransactionHandle transaction = new JdbcTransactionHandle(); - transactions.put(transaction, 
jdbcMetadataFactory.create(transaction)); - return transaction; + return transactionManager.beginTransaction(isolationLevel, readOnly, autoCommit); } @Override public ConnectorMetadata getMetadata(ConnectorSession session, ConnectorTransactionHandle transaction) { - JdbcMetadata metadata = transactions.get(transaction); - checkArgument(metadata != null, "no such transaction: %s", transaction); - return new ClassLoaderSafeConnectorMetadata(metadata, getClass().getClassLoader()); + return new ClassLoaderSafeConnectorMetadata(transactionManager.getMetadata(transaction), getClass().getClassLoader()); } @Override public void commit(ConnectorTransactionHandle transaction) { - checkArgument(transactions.remove(transaction) != null, "no such transaction: %s", transaction); + transactionManager.commit(transaction); } @Override public void rollback(ConnectorTransactionHandle transaction) { - JdbcMetadata metadata = transactions.remove(transaction); - checkArgument(metadata != null, "no such transaction: %s", transaction); - metadata.rollback(); + transactionManager.rollback(transaction); } @Override diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcModule.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcModule.java index ca3db5a8bf15..054efe37cc6a 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcModule.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcModule.java @@ -67,6 +67,7 @@ public void setup(Binder binder) newOptionalBinder(binder, ConnectorRecordSetProvider.class).setDefault().to(JdbcRecordSetProvider.class).in(Scopes.SINGLETON); newOptionalBinder(binder, ConnectorPageSinkProvider.class).setDefault().to(JdbcPageSinkProvider.class).in(Scopes.SINGLETON); + binder.bind(JdbcTransactionManager.class).in(Scopes.SINGLETON); binder.bind(JdbcConnector.class).in(Scopes.SINGLETON); configBinder(binder).bindConfig(JdbcMetadataConfig.class); 
configBinder(binder).bindConfig(JdbcWriteConfig.class); diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcTransactionManager.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcTransactionManager.java new file mode 100644 index 000000000000..4a334b0c5910 --- /dev/null +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcTransactionManager.java @@ -0,0 +1,66 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc; + +import io.trino.spi.connector.ConnectorTransactionHandle; +import io.trino.spi.transaction.IsolationLevel; + +import javax.inject.Inject; + +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; + +import static com.google.common.base.Preconditions.checkArgument; +import static io.trino.spi.transaction.IsolationLevel.READ_COMMITTED; +import static io.trino.spi.transaction.IsolationLevel.checkConnectorSupports; +import static java.util.Objects.requireNonNull; + +public class JdbcTransactionManager +{ + private final ConcurrentMap transactions = new ConcurrentHashMap<>(); + private final JdbcMetadataFactory metadataFactory; + + @Inject + public JdbcTransactionManager(JdbcMetadataFactory metadataFactory) + { + this.metadataFactory = requireNonNull(metadataFactory, "metadataFactory is null"); + } + + public ConnectorTransactionHandle beginTransaction(IsolationLevel isolationLevel, boolean readOnly, boolean autoCommit) + { + 
checkConnectorSupports(READ_COMMITTED, isolationLevel); + JdbcTransactionHandle transaction = new JdbcTransactionHandle(); + transactions.put(transaction, metadataFactory.create(transaction)); + return transaction; + } + + public JdbcMetadata getMetadata(ConnectorTransactionHandle transaction) + { + JdbcMetadata metadata = transactions.get(transaction); + checkArgument(metadata != null, "no such transaction: %s", transaction); + return metadata; + } + + public void commit(ConnectorTransactionHandle transaction) + { + checkArgument(transactions.remove(transaction) != null, "no such transaction: %s", transaction); + } + + public void rollback(ConnectorTransactionHandle transaction) + { + JdbcMetadata metadata = transactions.remove(transaction); + checkArgument(metadata != null, "no such transaction: %s", transaction); + metadata.rollback(); + } +}