diff --git a/core/trino-main/src/main/java/io/trino/operator/table/ExcludeColumns.java b/core/trino-main/src/main/java/io/trino/operator/table/ExcludeColumns.java index 755447922987..49f93f60533e 100644 --- a/core/trino-main/src/main/java/io/trino/operator/table/ExcludeColumns.java +++ b/core/trino-main/src/main/java/io/trino/operator/table/ExcludeColumns.java @@ -17,6 +17,7 @@ import com.google.common.collect.Sets; import io.trino.plugin.base.classloader.ClassLoaderSafeConnectorTableFunction; import io.trino.spi.TrinoException; +import io.trino.spi.connector.ConnectorAccessControl; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTransactionHandle; import io.trino.spi.ptf.AbstractConnectorTableFunction; @@ -86,7 +87,11 @@ public ExcludeColumnsFunction() } @Override - public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + public TableFunctionAnalysis analyze( + ConnectorSession session, + ConnectorTransactionHandle transaction, + Map arguments, + ConnectorAccessControl accessControl) { DescriptorArgument excludedColumns = (DescriptorArgument) arguments.get(DESCRIPTOR_ARGUMENT_NAME); if (excludedColumns.equals(NULL_DESCRIPTOR)) { diff --git a/core/trino-main/src/main/java/io/trino/operator/table/Sequence.java b/core/trino-main/src/main/java/io/trino/operator/table/Sequence.java index 5d3eaed6765c..bd89f5d1a739 100644 --- a/core/trino-main/src/main/java/io/trino/operator/table/Sequence.java +++ b/core/trino-main/src/main/java/io/trino/operator/table/Sequence.java @@ -23,6 +23,7 @@ import io.trino.spi.PageBuilder; import io.trino.spi.TrinoException; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.connector.ConnectorAccessControl; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorSplit; import io.trino.spi.connector.ConnectorSplitSource; @@ -100,7 +101,11 @@ public SequenceFunction() } @Override - public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + public TableFunctionAnalysis analyze( + ConnectorSession session, + ConnectorTransactionHandle transaction, + Map arguments, + ConnectorAccessControl accessControl) { Object startValue = ((ScalarArgument) arguments.get(START_ARGUMENT_NAME)).getValue(); if (startValue == null) { diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java index ce31502695e8..eba24c00da20 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java @@ -54,6 +54,7 @@ import io.trino.metadata.ViewDefinition; import io.trino.security.AccessControl; import io.trino.security.AllowAllAccessControl; +import io.trino.security.InjectedConnectorAccessControl; import io.trino.security.SecurityContext; import io.trino.security.ViewAccessControl; import io.trino.spi.TrinoException; @@ -1567,7 +1568,11 @@ protected Scope visitTableFunctionInvocation(TableFunctionInvocation node, Optio ArgumentsAnalysis argumentsAnalysis = analyzeArguments(function.getArguments(), node.getArguments(), scope, errorLocation); ConnectorTransactionHandle transactionHandle = transactionManager.getConnectorTransaction(session.getRequiredTransactionId(), catalogHandle); - TableFunctionAnalysis functionAnalysis = function.analyze(session.toConnectorSession(catalogHandle), transactionHandle, argumentsAnalysis.getPassedArguments()); + TableFunctionAnalysis functionAnalysis = function.analyze( + session.toConnectorSession(catalogHandle), + transactionHandle, + argumentsAnalysis.getPassedArguments(), + new InjectedConnectorAccessControl(accessControl, session.toSecurityContext(), catalogHandle.getCatalogName())); List> copartitioningLists = analyzeCopartitioning(node.getCopartitioning(), argumentsAnalysis.getTableArgumentAnalyses()); diff --git a/core/trino-main/src/test/java/io/trino/connector/TestingTableFunctions.java b/core/trino-main/src/test/java/io/trino/connector/TestingTableFunctions.java index 4917ddfc79e1..8f236c0d493f 100644 --- a/core/trino-main/src/test/java/io/trino/connector/TestingTableFunctions.java +++ b/core/trino-main/src/test/java/io/trino/connector/TestingTableFunctions.java @@ -16,12 +16,14 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import io.airlift.slice.Slice; import io.trino.spi.HostAddress; import io.trino.spi.Page; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.RunLengthEncodedBlock; +import io.trino.spi.connector.ConnectorAccessControl; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorSplit; import io.trino.spi.connector.ConnectorSplitSource; @@ -119,13 +121,19 @@ public SimpleTableFunction() } @Override - public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + public TableFunctionAnalysis analyze( + ConnectorSession session, + ConnectorTransactionHandle transaction, + Map arguments, + ConnectorAccessControl accessControl) { ScalarArgument argument = (ScalarArgument) arguments.get("COLUMN"); String columnName = ((Slice) argument.getValue()).toStringUtf8(); + String schema = getSchema(); + return TableFunctionAnalysis.builder() - .handle(new SimpleTableFunctionHandle(getSchema(), TABLE_NAME, columnName)) + .handle(new SimpleTableFunctionHandle(schema, TABLE_NAME, columnName)) .returnedType(new Descriptor(ImmutableList.of(new Descriptor.Field(columnName, Optional.of(BOOLEAN))))) .build(); } @@ -134,6 +142,7 @@ public static class SimpleTableFunctionHandle implements ConnectorTableFunctionHandle { private final MockConnectorTableHandle tableHandle; + private final String columnName; public SimpleTableFunctionHandle(String schema, String table, String column) { @@ -141,12 +150,43 @@ public SimpleTableFunctionHandle(String schema, String table, String column) new SchemaTableName(schema, table), TupleDomain.all(), Optional.of(ImmutableList.of(new MockConnectorColumnHandle(column, BOOLEAN)))); + this.columnName = requireNonNull(column, "column is null"); } public MockConnectorTableHandle getTableHandle() { return tableHandle; } + + public String getColumnName() + { + return columnName; + } + } + } + + /** + * A table function returning a table with single empty column of type BOOLEAN. + * The argument `COLUMN` is the column name. + * The argument `IGNORED` is ignored. + * Both arguments are optional. + * Performs access control checks + */ + public static class SimpleTableFunctionWithAccessControl + extends SimpleTableFunction + { + @Override + public TableFunctionAnalysis analyze( + ConnectorSession session, + ConnectorTransactionHandle transaction, + Map arguments, + ConnectorAccessControl accessControl) + { + TableFunctionAnalysis analyzeResult = super.analyze(session, transaction, arguments, accessControl); + SimpleTableFunction.SimpleTableFunctionHandle handle = (SimpleTableFunction.SimpleTableFunctionHandle) analyzeResult.getHandle(); + accessControl.checkCanSelectFromColumns(null, handle.getTableHandle().getTableName(), ImmutableSet.of(handle.getColumnName())); + + return analyzeResult; } } @@ -172,7 +212,11 @@ public TwoScalarArgumentsFunction() } @Override - public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + public TableFunctionAnalysis analyze( + ConnectorSession session, + ConnectorTransactionHandle transaction, + Map arguments, + ConnectorAccessControl accessControl) { return ANALYSIS; } @@ -197,7 +241,11 @@ public TableArgumentFunction() } @Override - public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + public TableFunctionAnalysis analyze( + ConnectorSession session, + ConnectorTransactionHandle transaction, + Map arguments, + ConnectorAccessControl accessControl) { return TableFunctionAnalysis.builder() .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) @@ -226,7 +274,11 @@ public TableArgumentRowSemanticsFunction() } @Override - public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + public TableFunctionAnalysis analyze( + ConnectorSession session, + ConnectorTransactionHandle transaction, + Map arguments, + ConnectorAccessControl accessControl) { return TableFunctionAnalysis.builder() .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) @@ -253,7 +305,11 @@ public DescriptorArgumentFunction() } @Override - public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + public TableFunctionAnalysis analyze( + ConnectorSession session, + ConnectorTransactionHandle transaction, + Map arguments, + ConnectorAccessControl accessControl) { return ANALYSIS; } @@ -282,7 +338,11 @@ public TwoTableArgumentsFunction() } @Override - public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + public TableFunctionAnalysis analyze( + ConnectorSession session, + ConnectorTransactionHandle transaction, + Map arguments, + ConnectorAccessControl accessControl) { return TableFunctionAnalysis.builder() .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) @@ -311,7 +371,11 @@ public OnlyPassThroughFunction() } @Override - public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + public TableFunctionAnalysis analyze( + ConnectorSession session, + ConnectorTransactionHandle transaction, + Map arguments, + ConnectorAccessControl accessControl) { return NO_DESCRIPTOR_ANALYSIS; } @@ -332,7 +396,11 @@ public MonomorphicStaticReturnTypeFunction() } @Override - public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + public TableFunctionAnalysis analyze( + ConnectorSession session, + ConnectorTransactionHandle transaction, + Map arguments, + ConnectorAccessControl accessControl) { return TableFunctionAnalysis.builder() .handle(HANDLE) @@ -358,7 +426,11 @@ public PolymorphicStaticReturnTypeFunction() } @Override - public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + public TableFunctionAnalysis analyze( + ConnectorSession session, + ConnectorTransactionHandle transaction, + Map arguments, + ConnectorAccessControl accessControl) { return NO_DESCRIPTOR_ANALYSIS; } @@ -383,7 +455,11 @@ public PassThroughFunction() } @Override - public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + public TableFunctionAnalysis analyze( + ConnectorSession session, + ConnectorTransactionHandle transaction, + Map arguments, + ConnectorAccessControl accessControl) { return NO_DESCRIPTOR_ANALYSIS; } @@ -425,7 +501,11 @@ public DifferentArgumentTypesFunction() } @Override - public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + public TableFunctionAnalysis analyze( + ConnectorSession session, + ConnectorTransactionHandle transaction, + Map arguments, + ConnectorAccessControl accessControl) { return TableFunctionAnalysis.builder() .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) @@ -456,7 +536,11 @@ public RequiredColumnsFunction() } @Override - public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + public TableFunctionAnalysis analyze( + ConnectorSession session, + ConnectorTransactionHandle transaction, + Map arguments, + ConnectorAccessControl accessControl) { return TableFunctionAnalysis.builder() .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) @@ -506,7 +590,11 @@ public IdentityFunction() } @Override - public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + public TableFunctionAnalysis analyze( + ConnectorSession session, + ConnectorTransactionHandle transaction, + Map arguments, + ConnectorAccessControl accessControl) { List inputColumns = ((TableArgument) arguments.get("INPUT")).getRowType().getFields(); Descriptor returnedType = new Descriptor(inputColumns.stream() @@ -556,7 +644,11 @@ public IdentityPassThroughFunction() } @Override - public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + public TableFunctionAnalysis analyze( + ConnectorSession session, + ConnectorTransactionHandle transaction, + Map arguments, + ConnectorAccessControl accessControl) { return TableFunctionAnalysis.builder() .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) @@ -621,7 +713,11 @@ public RepeatFunction() } @Override - public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + public TableFunctionAnalysis analyze( + ConnectorSession session, + ConnectorTransactionHandle transaction, + Map arguments, + ConnectorAccessControl accessControl) { ScalarArgument count = (ScalarArgument) arguments.get("N"); requireNonNull(count.getValue(), "count value for function repeat() is null"); @@ -737,7 +833,11 @@ public EmptyOutputFunction() } @Override - public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + public TableFunctionAnalysis analyze( + ConnectorSession session, + ConnectorTransactionHandle transaction, + Map arguments, + ConnectorAccessControl accessControl) { return TableFunctionAnalysis.builder() .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) @@ -791,7 +891,11 @@ public EmptyOutputWithPassThroughFunction() } @Override - public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + public TableFunctionAnalysis analyze( + ConnectorSession session, + ConnectorTransactionHandle transaction, + Map arguments, + ConnectorAccessControl accessControl) { return TableFunctionAnalysis.builder() .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) @@ -860,7 +964,11 @@ public TestInputsFunction() } @Override - public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + public TableFunctionAnalysis analyze( + ConnectorSession session, + ConnectorTransactionHandle transaction, + Map arguments, + ConnectorAccessControl accessControl) { return TableFunctionAnalysis.builder() .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) @@ -919,7 +1027,11 @@ public PassThroughInputFunction() } @Override - public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + public TableFunctionAnalysis analyze( + ConnectorSession session, + ConnectorTransactionHandle transaction, + Map arguments, + ConnectorAccessControl accessControl) { return TableFunctionAnalysis.builder() .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) @@ -1015,7 +1127,11 @@ public TestInputFunction() } @Override - public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + public TableFunctionAnalysis analyze( + ConnectorSession session, + ConnectorTransactionHandle transaction, + Map arguments, + ConnectorAccessControl accessControl) { return TableFunctionAnalysis.builder() .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) @@ -1075,7 +1191,11 @@ public TestSingleInputRowSemanticsFunction() } @Override - public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + public TableFunctionAnalysis analyze( + ConnectorSession session, + ConnectorTransactionHandle transaction, + Map arguments, + ConnectorAccessControl accessControl) { return TableFunctionAnalysis.builder() .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) @@ -1127,7 +1247,11 @@ public ConstantFunction() } @Override - public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + public TableFunctionAnalysis analyze( + ConnectorSession session, + ConnectorTransactionHandle transaction, + Map arguments, + ConnectorAccessControl accessControl) { ScalarArgument count = (ScalarArgument) arguments.get("N"); requireNonNull(count.getValue(), "count value for function repeat() is null"); @@ -1286,7 +1410,11 @@ public EmptySourceFunction() } @Override - public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + public TableFunctionAnalysis analyze( + ConnectorSession session, + ConnectorTransactionHandle transaction, + Map arguments, + ConnectorAccessControl accessControl) { return TableFunctionAnalysis.builder() .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) diff --git a/core/trino-spi/src/main/java/io/trino/spi/ptf/AbstractConnectorTableFunction.java b/core/trino-spi/src/main/java/io/trino/spi/ptf/AbstractConnectorTableFunction.java index e23ecadd7219..2b21e3b2150a 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/ptf/AbstractConnectorTableFunction.java +++ b/core/trino-spi/src/main/java/io/trino/spi/ptf/AbstractConnectorTableFunction.java @@ -14,6 +14,7 @@ package io.trino.spi.ptf; import io.trino.spi.Experimental; +import io.trino.spi.connector.ConnectorAccessControl; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTransactionHandle; @@ -64,5 +65,5 @@ public ReturnTypeSpecification getReturnTypeSpecification() } @Override - public abstract TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments); + public abstract TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments, ConnectorAccessControl accessControl); } diff --git a/core/trino-spi/src/main/java/io/trino/spi/ptf/ConnectorTableFunction.java b/core/trino-spi/src/main/java/io/trino/spi/ptf/ConnectorTableFunction.java index a5a0d5af1946..d0a2659fbede 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/ptf/ConnectorTableFunction.java +++ b/core/trino-spi/src/main/java/io/trino/spi/ptf/ConnectorTableFunction.java @@ -14,6 +14,7 @@ package io.trino.spi.ptf; import io.trino.spi.Experimental; +import io.trino.spi.connector.ConnectorAccessControl; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTransactionHandle; @@ -47,5 +48,5 @@ public interface ConnectorTableFunction * * @param arguments actual invocation arguments, mapped by argument names */ - TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments); + TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments, ConnectorAccessControl accessControl); } diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorTableFunction.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorTableFunction.java index 83945b1b2504..0e99690ac932 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorTableFunction.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorTableFunction.java @@ -14,6 +14,7 @@ package io.trino.plugin.base.classloader; import io.trino.spi.classloader.ThreadContextClassLoader; +import io.trino.spi.connector.ConnectorAccessControl; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTransactionHandle; import io.trino.spi.ptf.Argument; @@ -72,10 +73,13 @@ public ReturnTypeSpecification getReturnTypeSpecification() } @Override - public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + public TableFunctionAnalysis analyze(ConnectorSession session, + ConnectorTransactionHandle transaction, + Map arguments, + ConnectorAccessControl accessControl) { try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { - return delegate.analyze(session, transaction, arguments); + return delegate.analyze(session, transaction, arguments, accessControl); } } } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ptf/Procedure.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ptf/Procedure.java index fc16067066a6..80cbaf76a5a1 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ptf/Procedure.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ptf/Procedure.java @@ -22,6 +22,7 @@ import io.trino.plugin.jdbc.JdbcProcedureHandle; import io.trino.plugin.jdbc.JdbcProcedureHandle.ProcedureQuery; import io.trino.plugin.jdbc.JdbcTransactionManager; +import io.trino.spi.connector.ConnectorAccessControl; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTransactionHandle; @@ -88,7 +89,11 @@ public ProcedureFunction(JdbcTransactionManager transactionManager) } @Override - public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + public TableFunctionAnalysis analyze( + ConnectorSession session, + ConnectorTransactionHandle transaction, + Map arguments, + ConnectorAccessControl accessControl) { ScalarArgument argument = (ScalarArgument) getOnlyElement(arguments.values()); String procedureQuery = ((Slice) argument.getValue()).toStringUtf8(); diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ptf/Query.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ptf/Query.java index e2f0dd306f2f..6a806abf1f14 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ptf/Query.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ptf/Query.java @@ -23,6 +23,7 @@ import io.trino.plugin.jdbc.JdbcTableHandle; import io.trino.plugin.jdbc.JdbcTransactionManager; import io.trino.plugin.jdbc.PreparedQuery; +import io.trino.spi.connector.ConnectorAccessControl; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTransactionHandle; @@ -88,7 +89,11 @@ public QueryFunction(JdbcTransactionManager transactionManager) } @Override - public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + public TableFunctionAnalysis analyze( + ConnectorSession session, + ConnectorTransactionHandle transaction, + Map arguments, + ConnectorAccessControl accessControl) { ScalarArgument argument = (ScalarArgument) getOnlyElement(arguments.values()); String query = ((Slice) argument.getValue()).toStringUtf8(); diff --git a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/ptf/Query.java b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/ptf/Query.java index bfd8ef34e5aa..43c492c09529 100644 --- a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/ptf/Query.java +++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/ptf/Query.java @@ -24,6 +24,7 @@ import io.trino.plugin.bigquery.BigQueryColumnHandle; import io.trino.plugin.bigquery.BigQueryQueryRelationHandle; import io.trino.plugin.bigquery.BigQueryTableHandle; +import io.trino.spi.connector.ConnectorAccessControl; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTransactionHandle; @@ -91,7 +92,11 @@ public QueryFunction(BigQueryClientFactory clientFactory) } @Override - public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + public TableFunctionAnalysis analyze( + ConnectorSession session, + ConnectorTransactionHandle transaction, + Map arguments, + ConnectorAccessControl accessControl) { ScalarArgument argument = (ScalarArgument) getOnlyElement(arguments.values()); String query = ((Slice) argument.getValue()).toStringUtf8(); diff --git a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/ptf/Query.java b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/ptf/Query.java index a675312ac117..ef465786a2f9 100644 --- a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/ptf/Query.java +++ b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/ptf/Query.java @@ -23,6 +23,7 @@ import io.trino.plugin.cassandra.CassandraQueryRelationHandle; import io.trino.plugin.cassandra.CassandraTableHandle; import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.ConnectorAccessControl; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTransactionHandle; import io.trino.spi.ptf.AbstractConnectorTableFunction; @@ -88,7 +89,11 @@ public QueryFunction(CassandraMetadata cassandraMetadata) } @Override - public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + public TableFunctionAnalysis analyze( + ConnectorSession session, + ConnectorTransactionHandle transaction, + Map arguments, + ConnectorAccessControl accessControl) { ScalarArgument argument = (ScalarArgument) getOnlyElement(arguments.values()); String query = ((Slice) argument.getValue()).toStringUtf8(); diff --git a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ptf/RawQuery.java b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ptf/RawQuery.java index 80547cc2aaf9..045cbf3a4061 100644 --- a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ptf/RawQuery.java +++ b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ptf/RawQuery.java @@ -21,6 +21,7 @@ import io.trino.plugin.elasticsearch.ElasticsearchTableHandle; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ColumnSchema; +import io.trino.spi.connector.ConnectorAccessControl; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableSchema; @@ -96,7 +97,11 @@ public RawQueryFunction(ElasticsearchMetadata metadata) } @Override - public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + public TableFunctionAnalysis analyze( + ConnectorSession session, + ConnectorTransactionHandle transaction, + Map arguments, + ConnectorAccessControl accessControl) { String schema = ((Slice) ((ScalarArgument) arguments.get("SCHEMA")).getValue()).toStringUtf8(); String index = ((Slice) ((ScalarArgument) arguments.get("INDEX")).getValue()).toStringUtf8(); diff --git a/plugin/trino-google-sheets/src/main/java/io/trino/plugin/google/sheets/ptf/Sheet.java b/plugin/trino-google-sheets/src/main/java/io/trino/plugin/google/sheets/ptf/Sheet.java index 1fa098db8060..6a942efd2fd8 100644 --- a/plugin/trino-google-sheets/src/main/java/io/trino/plugin/google/sheets/ptf/Sheet.java +++ b/plugin/trino-google-sheets/src/main/java/io/trino/plugin/google/sheets/ptf/Sheet.java @@ -24,6 +24,7 @@ import io.trino.plugin.google.sheets.SheetsMetadata; import io.trino.plugin.google.sheets.SheetsSheetTableHandle; import io.trino.spi.TrinoException; +import io.trino.spi.connector.ConnectorAccessControl; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTransactionHandle; @@ -98,7 +99,11 @@ public SheetFunction(SheetsMetadata metadata) } @Override - public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + public TableFunctionAnalysis analyze( + ConnectorSession session, + ConnectorTransactionHandle transaction, + Map arguments, + ConnectorAccessControl accessControl) { String sheetId = ((Slice) ((ScalarArgument) arguments.get(ID_ARGUMENT)).getValue()).toStringUtf8(); validateSheetId(sheetId); diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/ptf/Query.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/ptf/Query.java index 13ca76862742..bd4b5974cd12 100644 --- a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/ptf/Query.java +++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/ptf/Query.java @@ -25,6 +25,7 @@ import io.trino.spi.TrinoException; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ColumnSchema; +import io.trino.spi.connector.ConnectorAccessControl; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableSchema; @@ -108,7 +109,11 @@ public QueryFunction(MongoMetadata metadata, MongoSession mongoSession) } @Override - public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + public TableFunctionAnalysis analyze( + ConnectorSession session, + ConnectorTransactionHandle transaction, + Map arguments, + ConnectorAccessControl accessControl) { String database = ((Slice) ((ScalarArgument) arguments.get("DATABASE")).getValue()).toStringUtf8(); String collection = ((Slice) ((ScalarArgument) arguments.get("COLLECTION")).getValue()).toStringUtf8(); diff --git a/testing/trino-tests/src/test/java/io/trino/tests/TestTableFunctionInvocation.java b/testing/trino-tests/src/test/java/io/trino/tests/TestTableFunctionInvocation.java index 84153d48471c..897e8505303c 100644 --- a/testing/trino-tests/src/test/java/io/trino/tests/TestTableFunctionInvocation.java +++ b/testing/trino-tests/src/test/java/io/trino/tests/TestTableFunctionInvocation.java @@ -35,8 +35,8 @@ import io.trino.connector.TestingTableFunctions.RepeatFunction; import io.trino.connector.TestingTableFunctions.RepeatFunction.RepeatFunctionHandle; import io.trino.connector.TestingTableFunctions.RepeatFunction.RepeatFunctionProcessorProvider; -import io.trino.connector.TestingTableFunctions.SimpleTableFunction; import io.trino.connector.TestingTableFunctions.SimpleTableFunction.SimpleTableFunctionHandle; +import io.trino.connector.TestingTableFunctions.SimpleTableFunctionWithAccessControl; import io.trino.connector.TestingTableFunctions.TestInputFunction; import io.trino.connector.TestingTableFunctions.TestInputFunction.TestInputProcessorProvider; import io.trino.connector.TestingTableFunctions.TestInputsFunction; @@ -61,6 +61,8 @@ import static io.trino.connector.MockConnector.MockConnectorSplit.MOCK_CONNECTOR_SPLIT; import static io.trino.connector.TestingTableFunctions.ConstantFunction.getConstantFunctionSplitSource; +import static io.trino.testing.TestingAccessControlManager.TestingPrivilegeType.SELECT_COLUMN; +import static io.trino.testing.TestingAccessControlManager.privilege; import static io.trino.testing.TestingSession.testSessionBuilder; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -89,7 +91,7 @@ public void setUp() queryRunner.installPlugin(new MockConnectorPlugin(MockConnectorFactory.builder() .withTableFunctions(ImmutableSet.of( - new SimpleTableFunction(), + new SimpleTableFunctionWithAccessControl(), new IdentityFunction(), new IdentityPassThroughFunction(), new RepeatFunction(), @@ -164,6 +166,20 @@ public void testPrimitiveDefaultArgument() .matches("SELECT true WHERE false"); } + @Test + public void testAccessControl() + { + assertAccessDenied( + "SELECT boolean_column FROM TABLE(system.simple_table_function(column => 'boolean_column', ignored => 1))", + "Cannot select from columns .*", + privilege("simple_table.boolean_column", SELECT_COLUMN)); + + assertAccessDenied( + "SELECT boolean_column FROM TABLE(system.simple_table_function(column => 'boolean_column', ignored => 1))", + "Cannot select from columns .*", + privilege("simple_table", SELECT_COLUMN)); + } + @Test public void testNoArgumentsPassed() {