diff --git a/core/trino-main/src/main/java/io/trino/connector/system/SystemPageSourceProvider.java b/core/trino-main/src/main/java/io/trino/connector/system/SystemPageSourceProvider.java index a5385c2d349b..952f8486e8ad 100644 --- a/core/trino-main/src/main/java/io/trino/connector/system/SystemPageSourceProvider.java +++ b/core/trino-main/src/main/java/io/trino/connector/system/SystemPageSourceProvider.java @@ -14,6 +14,8 @@ package io.trino.connector.system; import com.google.common.collect.ImmutableList; +import io.trino.plugin.base.MappedPageSource; +import io.trino.plugin.base.MappedRecordSet; import io.trino.spi.TrinoException; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ColumnMetadata; @@ -32,8 +34,6 @@ import io.trino.spi.connector.SystemTable; import io.trino.spi.predicate.TupleDomain; import io.trino.spi.type.Type; -import io.trino.split.MappedPageSource; -import io.trino.split.MappedRecordSet; import java.util.HashMap; import java.util.List; diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java index 4876e1cd5cf2..8ae8096de884 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java @@ -149,6 +149,7 @@ import io.trino.operator.window.pattern.PhysicalValueAccessor; import io.trino.operator.window.pattern.PhysicalValuePointer; import io.trino.operator.window.pattern.SetEvaluator.SetEvaluatorSupplier; +import io.trino.plugin.base.MappedRecordSet; import io.trino.spi.Page; import io.trino.spi.PageBuilder; import io.trino.spi.TrinoException; @@ -172,7 +173,6 @@ import io.trino.spiller.PartitioningSpillerFactory; import io.trino.spiller.SingleStreamSpillerFactory; import io.trino.spiller.SpillerFactory; -import io.trino.split.MappedRecordSet; import io.trino.split.PageSinkManager; import io.trino.split.PageSourceProvider; import io.trino.sql.DynamicFilters; diff --git a/docs/src/main/sphinx/connector/sqlserver.rst b/docs/src/main/sphinx/connector/sqlserver.rst index 3ff776295ac5..4749e7f05849 100644 --- a/docs/src/main/sphinx/connector/sqlserver.rst +++ b/docs/src/main/sphinx/connector/sqlserver.rst @@ -367,6 +367,46 @@ nations by population:: ) ); +.. _sqlserver-procedure-function: + +``procedure(varchar) -> table`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The ``procedure`` function allows you to run stored procedures on the underlying +database directly. + +.. note:: + + The ``procedure`` function does not support running StoredProcedures that return multiple statements, + use a non-select statement, use output parameters, or use conditional statements. + + +The follow example runs the stored procedure ``employee_sp`` in the ``example`` catalog and the +``example_schema`` in the underlying SQL Server database:: + + SELECT + * + FROM + TABLE( + example.system.procedure( + schema => 'example_schema', + procedure => 'employee_sp' + ) + ); + +If the stored procedure ``employee_sp`` requires any input, it can be specified via ``inputs`` +argument as an ``array``:: + + SELECT + * + FROM + TABLE( + example.system.procedure( + schema => 'example_schema', + procedure => 'employee_sp', + inputs => ARRAY['0'] + ) + ); Performance ----------- diff --git a/core/trino-main/src/main/java/io/trino/split/MappedPageSource.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/MappedPageSource.java similarity index 98% rename from core/trino-main/src/main/java/io/trino/split/MappedPageSource.java rename to lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/MappedPageSource.java index 1d021b8687d5..958f6f2a9f9d 100644 --- a/core/trino-main/src/main/java/io/trino/split/MappedPageSource.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/MappedPageSource.java @@ -11,7 +11,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.trino.split; +package io.trino.plugin.base; import com.google.common.primitives.Ints; import io.trino.spi.Page; diff --git a/core/trino-main/src/main/java/io/trino/split/MappedRecordSet.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/MappedRecordSet.java similarity index 99% rename from core/trino-main/src/main/java/io/trino/split/MappedRecordSet.java rename to lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/MappedRecordSet.java index 1ed519e21142..0ccca34389d1 100644 --- a/core/trino-main/src/main/java/io/trino/split/MappedRecordSet.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/MappedRecordSet.java @@ -11,7 +11,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.trino.split; +package io.trino.plugin.base; import com.google.common.primitives.Ints; import io.airlift.slice.Slice; diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/BaseJdbcClient.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/BaseJdbcClient.java index 3b1ee8815dc6..421b197f1b54 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/BaseJdbcClient.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/BaseJdbcClient.java @@ -22,6 +22,7 @@ import io.airlift.log.Logger; import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.plugin.jdbc.mapping.IdentifierMapping; +import io.trino.plugin.jdbc.ptf.Procedure.ProcedureInformation; import io.trino.spi.TrinoException; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ColumnMetadata; @@ -45,6 +46,7 @@ import javax.annotation.Nullable; import java.io.IOException; +import java.sql.CallableStatement; import java.sql.Connection; import java.sql.DatabaseMetaData; import java.sql.PreparedStatement; @@ -245,44 +247,55 @@ public Optional getTableHandle(ConnectorSession session, Schema @Override public JdbcTableHandle getTableHandle(ConnectorSession session, PreparedQuery preparedQuery) { - ImmutableList.Builder columns = ImmutableList.builder(); try (Connection connection = connectionFactory.openConnection(session); PreparedStatement preparedStatement = queryBuilder.prepareStatement(this, session, connection, preparedQuery)) { ResultSetMetaData metadata = preparedStatement.getMetaData(); if (metadata == null) { throw new UnsupportedOperationException("Query not supported: ResultSetMetaData not available for query: " + preparedQuery.getQuery()); } - for (int column = 1; column <= metadata.getColumnCount(); column++) { - // Use getColumnLabel method because query pass-through table function may contain column aliases - String name = metadata.getColumnLabel(column); - JdbcTypeHandle jdbcTypeHandle = new JdbcTypeHandle( - metadata.getColumnType(column), - Optional.ofNullable(metadata.getColumnTypeName(column)), - Optional.of(metadata.getPrecision(column)), - Optional.of(metadata.getScale(column)), - Optional.empty(), // TODO support arrays - Optional.of(metadata.isCaseSensitive(column) ? CASE_SENSITIVE : CASE_INSENSITIVE)); - Type type = toColumnMapping(session, connection, jdbcTypeHandle) - .orElseThrow(() -> new UnsupportedOperationException(format("Unsupported type: %s of column: %s", jdbcTypeHandle, name))) - .getType(); - columns.add(new JdbcColumnHandle(name, jdbcTypeHandle, type)); - } + return new JdbcTableHandle( + new JdbcQueryRelationHandle(preparedQuery), + TupleDomain.all(), + ImmutableList.of(), + Optional.empty(), + OptionalLong.empty(), + Optional.of(getColumns(session, connection, metadata)), + // The query is opaque, so we don't know referenced tables + Optional.empty(), + 0, + Optional.empty()); } catch (SQLException e) { throw new TrinoException(JDBC_ERROR, "Failed to get table handle for prepared query. " + firstNonNull(e.getMessage(), e), e); } + } + + @Override + public JdbcProcedureHandle getProcedureHandle(ConnectorSession session, ProcedureInformation procedureInformation) + { + throw new TrinoException(NOT_SUPPORTED, "Procedure is not supported"); + } - return new JdbcTableHandle( - new JdbcQueryRelationHandle(preparedQuery), - TupleDomain.all(), - ImmutableList.of(), - Optional.empty(), - OptionalLong.empty(), - Optional.of(columns.build()), - // The query is opaque, so we don't know referenced tables - Optional.empty(), - 0, - Optional.empty()); + protected List getColumns(ConnectorSession session, Connection connection, ResultSetMetaData metadata) + throws SQLException + { + ImmutableList.Builder columns = ImmutableList.builder(); + for (int column = 1; column <= metadata.getColumnCount(); column++) { + // Use getColumnLabel method because query pass-through table function may contain column aliases + String name = metadata.getColumnLabel(column); + JdbcTypeHandle jdbcTypeHandle = new JdbcTypeHandle( + metadata.getColumnType(column), + Optional.ofNullable(metadata.getColumnTypeName(column)), + Optional.of(metadata.getPrecision(column)), + Optional.of(metadata.getScale(column)), + Optional.empty(), // TODO support arrays + Optional.of(metadata.isCaseSensitive(column) ? CASE_SENSITIVE : CASE_INSENSITIVE)); + Type type = toColumnMapping(session, connection, jdbcTypeHandle) + .orElseThrow(() -> new UnsupportedOperationException(format("Unsupported type: %s of column: %s", jdbcTypeHandle, name))) + .getType(); + columns.add(new JdbcColumnHandle(name, jdbcTypeHandle, type)); + } + return columns.build(); } @Override @@ -418,11 +431,30 @@ public ConnectorSplitSource getSplits(ConnectorSession session, JdbcTableHandle return new FixedSplitSource(new JdbcSplit(Optional.empty())); } + @Override + public ConnectorSplitSource getSplits(ConnectorSession session, JdbcProcedureHandle procedureHandle) + { + return new FixedSplitSource(new JdbcSplit(Optional.empty())); + } + @Override public Connection getConnection(ConnectorSession session, JdbcSplit split, JdbcTableHandle tableHandle) throws SQLException { verify(tableHandle.getAuthorization().isEmpty(), "Unexpected authorization is required for table: %s".formatted(tableHandle)); + return getConnection(session); + } + + @Override + public Connection getConnection(ConnectorSession session, JdbcSplit split, JdbcProcedureHandle procedureHandle) + throws SQLException + { + return getConnection(session); + } + + private Connection getConnection(ConnectorSession session) + throws SQLException + { Connection connection = connectionFactory.openConnection(session); try { connection.setReadOnly(true); @@ -459,6 +491,13 @@ public PreparedStatement buildSql(ConnectorSession session, Connection connectio return queryBuilder.prepareStatement(this, session, connection, preparedQuery); } + @Override + public CallableStatement buildProcedure(ConnectorSession session, Connection connection, JdbcSplit split, JdbcProcedureHandle procedureHandle) + throws SQLException + { + return queryBuilder.callProcedure(this, session, connection, procedureHandle.getProcedureQuery()); + } + protected PreparedQuery prepareQuery( ConnectorSession session, Connection connection, diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/BaseJdbcConnectorTableHandle.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/BaseJdbcConnectorTableHandle.java new file mode 100644 index 000000000000..f8f41743999f --- /dev/null +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/BaseJdbcConnectorTableHandle.java @@ -0,0 +1,25 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc; + +import io.trino.spi.connector.ConnectorTableHandle; + +import java.util.List; +import java.util.Optional; + +public abstract class BaseJdbcConnectorTableHandle + implements ConnectorTableHandle +{ + public abstract Optional> getColumns(); +} diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/CachingJdbcClient.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/CachingJdbcClient.java index 3dc3a6ab087d..81166b898d42 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/CachingJdbcClient.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/CachingJdbcClient.java @@ -23,6 +23,7 @@ import io.trino.collect.cache.EvictableCacheBuilder; import io.trino.plugin.base.session.SessionPropertiesProvider; import io.trino.plugin.jdbc.IdentityCacheMapping.IdentityCacheKey; +import io.trino.plugin.jdbc.ptf.Procedure.ProcedureInformation; import io.trino.spi.TrinoException; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.connector.ColumnHandle; @@ -45,6 +46,7 @@ import javax.inject.Inject; +import java.sql.CallableStatement; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.ResultSet; @@ -83,6 +85,7 @@ public class CachingJdbcClient private final Cache> tableNamesCache; private final Cache> tableHandlesByNameCache; private final Cache tableHandlesByQueryCache; + private final Cache procedureHandlesByQueryCache; private final Cache> columnsCache; private final Cache statisticsCache; @@ -148,6 +151,7 @@ public CachingJdbcClient( tableNamesCache = buildCache(cacheSize, tableNamesCachingTtl); tableHandlesByNameCache = buildCache(cacheSize, metadataCachingTtl); tableHandlesByQueryCache = buildCache(cacheSize, metadataCachingTtl); + procedureHandlesByQueryCache = buildCache(cacheSize, metadataCachingTtl); columnsCache = buildCache(cacheSize, metadataCachingTtl); statisticsCache = buildCache(cacheSize, metadataCachingTtl); } @@ -235,6 +239,12 @@ public ConnectorSplitSource getSplits(ConnectorSession session, JdbcTableHandle return delegate.getSplits(session, tableHandle); } + @Override + public ConnectorSplitSource getSplits(ConnectorSession session, JdbcProcedureHandle procedureHandle) + { + return delegate.getSplits(session, procedureHandle); + } + @Override public Connection getConnection(ConnectorSession session, JdbcSplit split, JdbcTableHandle tableHandle) throws SQLException @@ -242,6 +252,13 @@ public Connection getConnection(ConnectorSession session, JdbcSplit split, JdbcT return delegate.getConnection(session, split, tableHandle); } + @Override + public Connection getConnection(ConnectorSession session, JdbcSplit split, JdbcProcedureHandle procedureHandle) + throws SQLException + { + return delegate.getConnection(session, split, procedureHandle); + } + @Override public void abortReadConnection(Connection connection, ResultSet resultSet) throws SQLException @@ -267,6 +284,13 @@ public PreparedStatement buildSql(ConnectorSession session, Connection connectio return delegate.buildSql(session, connection, split, table, columns); } + @Override + public CallableStatement buildProcedure(ConnectorSession session, Connection connection, JdbcSplit split, JdbcProcedureHandle procedureHandle) + throws SQLException + { + return delegate.buildProcedure(session, connection, split, procedureHandle); + } + @Override public Optional implementJoin( ConnectorSession session, @@ -327,6 +351,13 @@ public JdbcTableHandle getTableHandle(ConnectorSession session, PreparedQuery pr return get(tableHandlesByQueryCache, key, () -> delegate.getTableHandle(session, preparedQuery)); } + @Override + public JdbcProcedureHandle getProcedureHandle(ConnectorSession session, ProcedureInformation procedureInformation) + { + ProcedureHandlesByQueryCacheKey key = new ProcedureHandlesByQueryCacheKey(getIdentityKey(session), procedureInformation); + return get(procedureHandlesByQueryCache, key, () -> delegate.getProcedureHandle(session, procedureInformation)); + } + @Override public void commitCreateTable(ConnectorSession session, JdbcOutputTableHandle handle, Set pageSinkIds) { @@ -637,6 +668,12 @@ CacheStats getTableHandlesByQueryCacheStats() return tableHandlesByQueryCache.stats(); } + @VisibleForTesting + CacheStats getProcedureHandlesByQueryCacheStats() + { + return procedureHandlesByQueryCache.stats(); + } + @VisibleForTesting CacheStats getColumnsCacheStats() { @@ -687,6 +724,15 @@ private record TableHandlesByQueryCacheKey(IdentityCacheKey identity, PreparedQu } } + private record ProcedureHandlesByQueryCacheKey(IdentityCacheKey identity, ProcedureInformation procedureInformation) + { + private ProcedureHandlesByQueryCacheKey + { + requireNonNull(identity, "identity is null"); + requireNonNull(procedureInformation, "procedureInformation is null"); + } + } + private record TableNamesCacheKey(IdentityCacheKey identity, Optional schemaName) { private TableNamesCacheKey diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java index 41496a7b2e26..9c625adbcfc2 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java @@ -18,6 +18,8 @@ import com.google.common.collect.ImmutableSet; import io.airlift.slice.Slice; import io.trino.plugin.jdbc.PredicatePushdownController.DomainPushdownResult; +import io.trino.plugin.jdbc.ptf.Procedure.ProcedureFunctionHandle; +import io.trino.plugin.jdbc.ptf.Procedure.ProcedureInformation; import io.trino.plugin.jdbc.ptf.Query.QueryFunctionHandle; import io.trino.spi.TrinoException; import io.trino.spi.connector.AggregateFunction; @@ -141,6 +143,12 @@ public JdbcTableHandle getTableHandle(ConnectorSession session, PreparedQuery pr return jdbcClient.getTableHandle(session, preparedQuery); } + @Override + public JdbcProcedureHandle getProcedureHandle(ConnectorSession session, ProcedureInformation procedureInformation) + { + return jdbcClient.getProcedureHandle(session, procedureInformation); + } + @Override public Optional getSystemTable(ConnectorSession session, SchemaTableName tableName) { @@ -150,6 +158,10 @@ public Optional getSystemTable(ConnectorSession session, SchemaTabl @Override public Optional> applyFilter(ConnectorSession session, ConnectorTableHandle table, Constraint constraint) { + if (table instanceof JdbcProcedureHandle) { + return Optional.empty(); + } + JdbcTableHandle handle = (JdbcTableHandle) table; if (handle.getSortOrder().isPresent() && handle.getLimit().isPresent()) { handle = flushAttributesAsQuery(session, handle); @@ -259,6 +271,10 @@ public Optional> applyProjecti List projections, Map assignments) { + if (table instanceof JdbcProcedureHandle) { + return Optional.empty(); + } + JdbcTableHandle handle = (JdbcTableHandle) table; List newColumns = assignments.values().stream() @@ -309,6 +325,10 @@ public Optional> applyAggrega Map assignments, List> groupingSets) { + if (table instanceof JdbcProcedureHandle) { + return Optional.empty(); + } + if (!isAggregationPushdownEnabled(session)) { return Optional.empty(); } @@ -412,6 +432,10 @@ public Optional> applyJoin( Map rightAssignments, JoinStatistics statistics) { + if (left instanceof JdbcProcedureHandle || right instanceof JdbcProcedureHandle) { + return Optional.empty(); + } + if (!isJoinPushdownEnabled(session)) { return Optional.empty(); } @@ -521,6 +545,10 @@ private static PreparedQuery asPreparedQuery(JdbcTableHandle tableHandle) @Override public Optional> applyLimit(ConnectorSession session, ConnectorTableHandle table, long limit) { + if (table instanceof JdbcProcedureHandle) { + return Optional.empty(); + } + JdbcTableHandle handle = (JdbcTableHandle) table; if (limit > Integer.MAX_VALUE) { @@ -558,6 +586,10 @@ public Optional> applyTopN( List sortItems, Map assignments) { + if (table instanceof JdbcProcedureHandle) { + return Optional.empty(); + } + if (!isTopNPushdownEnabled(session)) { return Optional.empty(); } @@ -603,11 +635,17 @@ public Optional> applyTopN( @Override public Optional> applyTableFunction(ConnectorSession session, ConnectorTableFunctionHandle handle) { - if (!(handle instanceof QueryFunctionHandle)) { - return Optional.empty(); + if (handle instanceof QueryFunctionHandle queryFunctionHandle) { + return Optional.of(getTableFunctionApplicationResult(session, queryFunctionHandle.getTableHandle())); + } + if (handle instanceof ProcedureFunctionHandle procedureFunctionHandle) { + return Optional.of(getTableFunctionApplicationResult(session, procedureFunctionHandle.getTableHandle())); } + return Optional.empty(); + } - ConnectorTableHandle tableHandle = ((QueryFunctionHandle) handle).getTableHandle(); + private TableFunctionApplicationResult getTableFunctionApplicationResult(ConnectorSession session, ConnectorTableHandle tableHandle) + { ConnectorTableSchema tableSchema = getTableSchema(session, tableHandle); Map columnHandlesByName = getColumnHandles(session, tableHandle); List columnHandles = tableSchema.getColumns().stream() @@ -615,12 +653,16 @@ public Optional> applyTable .map(columnHandlesByName::get) .collect(toImmutableList()); - return Optional.of(new TableFunctionApplicationResult<>(tableHandle, columnHandles)); + return new TableFunctionApplicationResult<>(tableHandle, columnHandles); } @Override public Optional applyTableScanRedirect(ConnectorSession session, ConnectorTableHandle table) { + if (table instanceof JdbcProcedureHandle) { + return Optional.empty(); + } + JdbcTableHandle tableHandle = (JdbcTableHandle) table; return jdbcClient.getTableScanRedirection(session, tableHandle); } @@ -628,6 +670,14 @@ public Optional applyTableScanRedirect(Conne @Override public ConnectorTableSchema getTableSchema(ConnectorSession session, ConnectorTableHandle table) { + if (table instanceof JdbcProcedureHandle procedureHandle) { + return new ConnectorTableSchema( + getSchemaTableNameForProcedureHandle(), + procedureHandle.getColumns().orElseThrow().stream() + .map(JdbcColumnHandle::getColumnSchema) + .collect(toImmutableList())); + } + JdbcTableHandle handle = (JdbcTableHandle) table; return new ConnectorTableSchema( @@ -640,6 +690,14 @@ public ConnectorTableSchema getTableSchema(ConnectorSession session, ConnectorTa @Override public ConnectorTableMetadata getTableMetadata(ConnectorSession session, ConnectorTableHandle table) { + if (table instanceof JdbcProcedureHandle procedureHandle) { + return new ConnectorTableMetadata( + getSchemaTableNameForProcedureHandle(), + procedureHandle.getColumns().orElseThrow().stream() + .map(JdbcColumnHandle::getColumnMetadata) + .collect(toImmutableList())); + } + JdbcTableHandle handle = (JdbcTableHandle) table; return new ConnectorTableMetadata( @@ -659,6 +717,12 @@ public static SchemaTableName getSchemaTableName(JdbcTableHandle handle) : new SchemaTableName("_generated", "_generated_query"); } + private static SchemaTableName getSchemaTableNameForProcedureHandle() + { + // TODO (https://github.com/trinodb/trino/issues/6694) SchemaTableName should not be required for synthetic JdbcProcedureHandle + return new SchemaTableName("_generated", "_generated_procedure"); + } + public static Optional getTableComment(JdbcTableHandle handle) { return handle.isNamedRelation() ? handle.getRequiredNamedRelation().getComment() : Optional.empty(); @@ -673,6 +737,11 @@ public List listTables(ConnectorSession session, Optional getColumnHandles(ConnectorSession session, ConnectorTableHandle tableHandle) { + if (tableHandle instanceof JdbcProcedureHandle procedureHandle) { + return procedureHandle.getColumns().orElseThrow().stream() + .collect(toImmutableMap(columnHandle -> columnHandle.getColumnMetadata().getName(), identity())); + } + return jdbcClient.getColumns(session, (JdbcTableHandle) tableHandle).stream() .collect(toImmutableMap(columnHandle -> columnHandle.getColumnMetadata().getName(), identity())); } @@ -824,6 +893,7 @@ public Optional finishInsert(ConnectorSession session, @Override public ColumnHandle getMergeRowIdColumnHandle(ConnectorSession session, ConnectorTableHandle tableHandle) { + verify(!(tableHandle instanceof JdbcProcedureHandle), "Not a table reference: %s", tableHandle); // The column is used for row-level merge, which is not supported, but it's required during analysis anyway. return new JdbcColumnHandle( "$merge_row_id", @@ -834,6 +904,7 @@ public ColumnHandle getMergeRowIdColumnHandle(ConnectorSession session, Connecto @Override public Optional applyDelete(ConnectorSession session, ConnectorTableHandle handle) { + verify(!(handle instanceof JdbcProcedureHandle), "Not a table reference: %s", handle); return Optional.of(handle); } @@ -920,6 +991,9 @@ public void setTableProperties(ConnectorSession session, ConnectorTableHandle ta @Override public TableStatistics getTableStatistics(ConnectorSession session, ConnectorTableHandle tableHandle) { + if (tableHandle instanceof JdbcProcedureHandle) { + return TableStatistics.empty(); + } JdbcTableHandle handle = (JdbcTableHandle) tableHandle; return jdbcClient.getTableStatistics(session, handle); } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultQueryBuilder.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultQueryBuilder.java index c6b382fe886f..ebbeedaa1c7a 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultQueryBuilder.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultQueryBuilder.java @@ -19,7 +19,9 @@ import com.google.inject.Inject; import io.airlift.log.Logger; import io.airlift.slice.Slice; +import io.trino.plugin.jdbc.JdbcProcedureHandle.ProcedureQuery; import io.trino.plugin.jdbc.logging.RemoteQueryModifier; +import io.trino.plugin.jdbc.ptf.Procedure.ProcedureInformation; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.JoinType; @@ -29,6 +31,7 @@ import io.trino.spi.predicate.ValueSet; import io.trino.spi.type.Type; +import java.sql.CallableStatement; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.SQLException; @@ -208,6 +211,19 @@ else if (javaType == Slice.class) { return statement; } + @Override + public ProcedureQuery createProcedureQuery(JdbcClient client, ConnectorSession session, Connection connection, ProcedureInformation procedureInformation) + { + throw new UnsupportedOperationException(); + } + + @Override + public CallableStatement callProcedure(JdbcClient client, ConnectorSession session, Connection connection, ProcedureQuery procedureQuery) + throws SQLException + { + return connection.prepareCall(procedureQuery.query()); + } + protected String formatJoinCondition(JdbcClient client, String leftRelationAlias, String rightRelationAlias, JdbcJoinCondition condition) { return format( diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ForwardingJdbcClient.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ForwardingJdbcClient.java index 531353a15dd4..3859c322d5c4 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ForwardingJdbcClient.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ForwardingJdbcClient.java @@ -13,6 +13,7 @@ */ package io.trino.plugin.jdbc; +import io.trino.plugin.jdbc.ptf.Procedure.ProcedureInformation; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ColumnMetadata; @@ -29,6 +30,7 @@ import io.trino.spi.statistics.TableStatistics; import io.trino.spi.type.Type; +import java.sql.CallableStatement; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.ResultSet; @@ -90,6 +92,12 @@ public JdbcTableHandle getTableHandle(ConnectorSession session, PreparedQuery pr return delegate().getTableHandle(session, preparedQuery); } + @Override + public JdbcProcedureHandle getProcedureHandle(ConnectorSession session, ProcedureInformation procedureInformation) + { + return delegate().getProcedureHandle(session, procedureInformation); + } + @Override public List getColumns(ConnectorSession session, JdbcTableHandle tableHandle) { @@ -138,6 +146,12 @@ public ConnectorSplitSource getSplits(ConnectorSession session, JdbcTableHandle return delegate().getSplits(session, layoutHandle); } + @Override + public ConnectorSplitSource getSplits(ConnectorSession session, JdbcProcedureHandle procedureHandle) + { + return delegate().getSplits(session, procedureHandle); + } + @Override public Connection getConnection(ConnectorSession session, JdbcSplit split, JdbcTableHandle tableHandle) throws SQLException @@ -145,6 +159,13 @@ public Connection getConnection(ConnectorSession session, JdbcSplit split, JdbcT return delegate().getConnection(session, split, tableHandle); } + @Override + public Connection getConnection(ConnectorSession session, JdbcSplit split, JdbcProcedureHandle procedureHandle) + throws SQLException + { + return delegate().getConnection(session, split, procedureHandle); + } + @Override public void abortReadConnection(Connection connection, ResultSet resultSet) throws SQLException @@ -170,6 +191,13 @@ public PreparedStatement buildSql(ConnectorSession session, Connection connectio return delegate().buildSql(session, connection, split, tableHandle, columnHandles); } + @Override + public CallableStatement buildProcedure(ConnectorSession session, Connection connection, JdbcSplit split, JdbcProcedureHandle procedureHandle) + throws SQLException + { + return delegate().buildProcedure(session, connection, split, procedureHandle); + } + @Override public Optional implementJoin( ConnectorSession session, diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcClient.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcClient.java index d5e683d8c4bc..11cfd6df5532 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcClient.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcClient.java @@ -13,6 +13,7 @@ */ package io.trino.plugin.jdbc; +import io.trino.plugin.jdbc.ptf.Procedure.ProcedureInformation; import io.trino.spi.TrinoException; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.connector.ColumnHandle; @@ -30,6 +31,7 @@ import io.trino.spi.statistics.TableStatistics; import io.trino.spi.type.Type; +import java.sql.CallableStatement; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.ResultSet; @@ -57,6 +59,8 @@ default boolean schemaExists(ConnectorSession session, String schema) JdbcTableHandle getTableHandle(ConnectorSession session, PreparedQuery preparedQuery); + JdbcProcedureHandle getProcedureHandle(ConnectorSession session, ProcedureInformation procedureInformation); + List getColumns(ConnectorSession session, JdbcTableHandle tableHandle); Optional toColumnMapping(ConnectorSession session, Connection connection, JdbcTypeHandle typeHandle); @@ -85,9 +89,14 @@ default Optional convertPredicate(ConnectorSession session, ConnectorExp ConnectorSplitSource getSplits(ConnectorSession session, JdbcTableHandle tableHandle); + ConnectorSplitSource getSplits(ConnectorSession session, JdbcProcedureHandle procedureHandle); + Connection getConnection(ConnectorSession session, JdbcSplit split, JdbcTableHandle tableHandle) throws SQLException; + Connection getConnection(ConnectorSession session, JdbcSplit split, JdbcProcedureHandle procedureHandle) + throws SQLException; + default void abortReadConnection(Connection connection, ResultSet resultSet) throws SQLException { @@ -104,6 +113,9 @@ PreparedQuery prepareQuery( PreparedStatement buildSql(ConnectorSession session, Connection connection, JdbcSplit split, JdbcTableHandle table, List columns) throws SQLException; + CallableStatement buildProcedure(ConnectorSession session, Connection connection, JdbcSplit split, JdbcProcedureHandle procedureHandle) + throws SQLException; + Optional implementJoin( ConnectorSession session, JoinType joinType, diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcDynamicFilteringSplitManager.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcDynamicFilteringSplitManager.java index 39f1d0a69208..39e514bd2f80 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcDynamicFilteringSplitManager.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcDynamicFilteringSplitManager.java @@ -69,6 +69,11 @@ public ConnectorSplitSource getSplits( DynamicFilter dynamicFilter, Constraint constraint) { + // JdbcProcedureHandle doesn't support any pushdown operation, so we rely on delegateSplitManager + if (table instanceof JdbcProcedureHandle) { + return delegateSplitManager.getSplits(transaction, session, table, dynamicFilter, constraint); + } + JdbcTableHandle tableHandle = (JdbcTableHandle) table; // pushing DF through limit could reduce query performance boolean hasLimit = tableHandle.getLimit().isPresent(); diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcMetadata.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcMetadata.java index 10104eceec50..9fea5958b586 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcMetadata.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcMetadata.java @@ -13,6 +13,7 @@ */ package io.trino.plugin.jdbc; +import io.trino.plugin.jdbc.ptf.Procedure.ProcedureInformation; import io.trino.spi.connector.ConnectorMetadata; import io.trino.spi.connector.ConnectorSession; @@ -21,5 +22,7 @@ public interface JdbcMetadata { JdbcTableHandle getTableHandle(ConnectorSession session, PreparedQuery preparedQuery); + JdbcProcedureHandle getProcedureHandle(ConnectorSession session, ProcedureInformation procedureQuery); + void rollback(); } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcProcedureHandle.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcProcedureHandle.java new file mode 100644 index 000000000000..a2fda6bdcb59 --- /dev/null +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcProcedureHandle.java @@ -0,0 +1,65 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.List; +import java.util.Optional; + +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public class JdbcProcedureHandle + extends BaseJdbcConnectorTableHandle +{ + private final ProcedureQuery procedure; + private final List columns; + + @JsonCreator + public JdbcProcedureHandle(@JsonProperty ProcedureQuery procedureQuery, @JsonProperty List columns) + { + this.procedure = requireNonNull(procedureQuery, "procedureQuery is null"); + this.columns = requireNonNull(columns, "columns is null"); + } + + @JsonProperty + public ProcedureQuery getProcedureQuery() + { + return procedure; + } + + @Override + @JsonProperty + public Optional> getColumns() + { + return Optional.of(columns); + } + + @Override + public String toString() + { + return format("Procedure[%s], Columns=%s", procedure, columns); + } + + public record ProcedureQuery(@JsonProperty String query) + { + @JsonCreator + public ProcedureQuery + { + requireNonNull(query, "query is null"); + } + } +} diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcRecordCursor.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcRecordCursor.java index ee47a25b4d65..a785cbedf3c0 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcRecordCursor.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcRecordCursor.java @@ -63,7 +63,7 @@ public class JdbcRecordCursor private ResultSet resultSet; private boolean closed; - public JdbcRecordCursor(JdbcClient jdbcClient, ExecutorService executor, ConnectorSession session, JdbcSplit split, JdbcTableHandle table, List columnHandles) + public JdbcRecordCursor(JdbcClient jdbcClient, ExecutorService executor, ConnectorSession session, JdbcSplit split, BaseJdbcConnectorTableHandle table, List columnHandles) { this.jdbcClient = requireNonNull(jdbcClient, "jdbcClient is null"); this.executor = requireNonNull(executor, "executor is null"); @@ -78,7 +78,12 @@ public JdbcRecordCursor(JdbcClient jdbcClient, ExecutorService executor, Connect objectReadFunctions = new ObjectReadFunction[columnHandles.size()]; try { - connection = jdbcClient.getConnection(session, split, table); + if (table instanceof JdbcProcedureHandle procedureHandle) { + connection = jdbcClient.getConnection(session, split, procedureHandle); + } + else { + connection = jdbcClient.getConnection(session, split, (JdbcTableHandle) table); + } for (int i = 0; i < this.columnHandles.length; i++) { JdbcColumnHandle columnHandle = columnHandles.get(i); @@ -109,7 +114,12 @@ else if (javaType == Slice.class) { } } - statement = jdbcClient.buildSql(session, connection, split, table, columnHandles); + if (table instanceof JdbcProcedureHandle procedureHandle) { + statement = jdbcClient.buildProcedure(session, connection, split, procedureHandle); + } + else { + statement = jdbcClient.buildSql(session, connection, split, (JdbcTableHandle) table, columnHandles); + } } catch (SQLException | RuntimeException e) { throw handleSqlException(e); diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcRecordSet.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcRecordSet.java index 5bb17274463b..9d3d3269055a 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcRecordSet.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcRecordSet.java @@ -29,13 +29,13 @@ public class JdbcRecordSet { private final JdbcClient jdbcClient; private final ExecutorService executor; - private final JdbcTableHandle table; + private final BaseJdbcConnectorTableHandle table; private final List columnHandles; private final List columnTypes; private final JdbcSplit split; private final ConnectorSession session; - public JdbcRecordSet(JdbcClient jdbcClient, ExecutorService executor, ConnectorSession session, JdbcSplit split, JdbcTableHandle table, List columnHandles) + public JdbcRecordSet(JdbcClient jdbcClient, ExecutorService executor, ConnectorSession session, JdbcSplit split, BaseJdbcConnectorTableHandle table, List columnHandles) { this.jdbcClient = requireNonNull(jdbcClient, "jdbcClient is null"); this.executor = requireNonNull(executor, "executor is null"); diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcRecordSetProvider.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcRecordSetProvider.java index 6d87a2bd2272..b79957aad799 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcRecordSetProvider.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcRecordSetProvider.java @@ -15,6 +15,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import io.trino.plugin.base.MappedRecordSet; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorRecordSetProvider; import io.trino.spi.connector.ConnectorSession; @@ -26,10 +27,15 @@ import javax.inject.Inject; import java.util.List; +import java.util.Map; import java.util.concurrent.ExecutorService; +import java.util.stream.IntStream; import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; import static java.util.Objects.requireNonNull; +import static java.util.function.UnaryOperator.identity; public class JdbcRecordSetProvider implements ConnectorRecordSetProvider @@ -48,7 +54,7 @@ public JdbcRecordSetProvider(JdbcClient jdbcClient, @ForRecordCursor ExecutorSer public RecordSet getRecordSet(ConnectorTransactionHandle transaction, ConnectorSession session, ConnectorSplit split, ConnectorTableHandle table, List columns) { JdbcSplit jdbcSplit = (JdbcSplit) split; - JdbcTableHandle jdbcTable = (JdbcTableHandle) table; + BaseJdbcConnectorTableHandle jdbcTable = (BaseJdbcConnectorTableHandle) table; // In the current API, the columns (and order) needed by the engine are provided via an argument to this method. Make sure we can // satisfy the requirements using columns which were recorded in the table handle. @@ -57,17 +63,38 @@ public RecordSet getRecordSet(ConnectorTransactionHandle transaction, ConnectorS jdbcTable.getColumns() .ifPresent(tableColumns -> verify(ImmutableSet.copyOf(tableColumns).containsAll(columns))); - ImmutableList.Builder handles = ImmutableList.builderWithExpectedSize(columns.size()); - for (ColumnHandle handle : columns) { - handles.add((JdbcColumnHandle) handle); + if (jdbcTable instanceof JdbcTableHandle jdbcTableHandle) { + ImmutableList.Builder handles = ImmutableList.builderWithExpectedSize(columns.size()); + for (ColumnHandle handle : columns) { + handles.add((JdbcColumnHandle) handle); + } + + return new JdbcRecordSet( + jdbcClient, + executor, + session, + jdbcSplit, + jdbcTableHandle.intersectedWithConstraint(jdbcSplit.getDynamicFilter().transformKeys(ColumnHandle.class::cast)), + handles.build()); } + JdbcProcedureHandle procedureHandle = (JdbcProcedureHandle) jdbcTable; + List sourceColumns = procedureHandle.getColumns().orElseThrow(); + + Map columnIndexMap = IntStream.range(0, sourceColumns.size()) + .boxed() + .collect(toImmutableMap(sourceColumns::get, identity())); - return new JdbcRecordSet( - jdbcClient, - executor, - session, - jdbcSplit, - jdbcTable.intersectedWithConstraint(jdbcSplit.getDynamicFilter().transformKeys(ColumnHandle.class::cast)), - handles.build()); + return new MappedRecordSet( + new JdbcRecordSet( + jdbcClient, + executor, + session, + jdbcSplit, + procedureHandle, + sourceColumns), + columns.stream() + .map(JdbcColumnHandle.class::cast) + .map(columnIndexMap::get) + .collect(toImmutableList())); } } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcSplitManager.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcSplitManager.java index 2eea16588d8d..cbb6082116b5 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcSplitManager.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcSplitManager.java @@ -45,6 +45,10 @@ public ConnectorSplitSource getSplits( DynamicFilter dynamicFilter, Constraint constraint) { + if (table instanceof JdbcProcedureHandle procedureHandle) { + return jdbcClient.getSplits(session, procedureHandle); + } + JdbcTableHandle tableHandle = (JdbcTableHandle) table; ConnectorSplitSource jdbcSplitSource = jdbcClient.getSplits(session, tableHandle); if (dynamicFilteringEnabled(session)) { diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcTableHandle.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcTableHandle.java index 44112d9eb81a..5ee04fe49034 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcTableHandle.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcTableHandle.java @@ -19,7 +19,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import io.trino.spi.connector.ColumnHandle; -import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.predicate.TupleDomain; @@ -34,7 +33,7 @@ import static java.util.Objects.requireNonNull; public final class JdbcTableHandle - implements ConnectorTableHandle + extends BaseJdbcConnectorTableHandle { private final JdbcRelationHandle relationHandle; @@ -149,6 +148,7 @@ public OptionalLong getLimit() return limit; } + @Override @JsonProperty public Optional> getColumns() { diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/QueryBuilder.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/QueryBuilder.java index d934c1d9387e..790c40d7aca9 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/QueryBuilder.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/QueryBuilder.java @@ -13,11 +13,14 @@ */ package io.trino.plugin.jdbc; +import io.trino.plugin.jdbc.JdbcProcedureHandle.ProcedureQuery; +import io.trino.plugin.jdbc.ptf.Procedure.ProcedureInformation; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.JoinType; import io.trino.spi.predicate.TupleDomain; +import java.sql.CallableStatement; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.SQLException; @@ -63,4 +66,18 @@ PreparedStatement prepareStatement( Connection connection, PreparedQuery preparedQuery) throws SQLException; + + ProcedureQuery createProcedureQuery( + JdbcClient client, + ConnectorSession session, + Connection connection, + ProcedureInformation procedureInformation) + throws SQLException; + + CallableStatement callProcedure( + JdbcClient client, + ConnectorSession session, + Connection connection, + ProcedureQuery procedureQuery) + throws SQLException; } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/JdbcClientStats.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/JdbcClientStats.java index 799ba74e9d71..1ed0947a3de2 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/JdbcClientStats.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/JdbcClientStats.java @@ -25,6 +25,7 @@ public final class JdbcClientStats private final JdbcApiStats buildInsertSql = new JdbcApiStats(); private final JdbcApiStats prepareQuery = new JdbcApiStats(); private final JdbcApiStats buildSql = new JdbcApiStats(); + private final JdbcApiStats buildProcedure = new JdbcApiStats(); private final JdbcApiStats implementJoin = new JdbcApiStats(); private final JdbcApiStats commitCreateTable = new JdbcApiStats(); private final JdbcApiStats createSchema = new JdbcApiStats(); @@ -40,11 +41,14 @@ public final class JdbcClientStats private final JdbcApiStats getColumns = new JdbcApiStats(); private final JdbcApiStats getConnectionWithHandle = new JdbcApiStats(); private final JdbcApiStats getConnectionWithSplit = new JdbcApiStats(); + private final JdbcApiStats getConnectionWithProcedure = new JdbcApiStats(); private final JdbcApiStats getPreparedStatement = new JdbcApiStats(); private final JdbcApiStats getSchemaNames = new JdbcApiStats(); private final JdbcApiStats getSplits = new JdbcApiStats(); + private final JdbcApiStats getSplitsForProcedure = new JdbcApiStats(); private final JdbcApiStats getTableHandle = new JdbcApiStats(); private final JdbcApiStats getTableHandleForQuery = new JdbcApiStats(); + private final JdbcApiStats getProcedureHandle = new JdbcApiStats(); private final JdbcApiStats getTableNames = new JdbcApiStats(); private final JdbcApiStats getTableStatistics = new JdbcApiStats(); private final JdbcApiStats renameColumn = new JdbcApiStats(); @@ -111,6 +115,13 @@ public JdbcApiStats getBuildSql() return buildSql; } + @Managed + @Nested + public JdbcApiStats getBuildProcedure() + { + return buildProcedure; + } + @Managed @Nested public JdbcApiStats getImplementJoin() @@ -216,6 +227,13 @@ public JdbcApiStats getGetConnectionWithSplit() return getConnectionWithSplit; } + @Managed + @Nested + public JdbcApiStats getGetConnectionWithProcedure() + { + return getConnectionWithProcedure; + } + @Managed @Nested public JdbcApiStats getGetPreparedStatement() @@ -237,6 +255,13 @@ public JdbcApiStats getGetSplits() return getSplits; } + @Managed + @Nested + public JdbcApiStats getGetSplitsForProcedure() + { + return getSplitsForProcedure; + } + @Managed @Nested public JdbcApiStats getGetTableHandle() @@ -251,6 +276,13 @@ public JdbcApiStats getGetTableHandleForQuery() return getTableHandleForQuery; } + @Managed + @Nested + public JdbcApiStats getGetProcedureHandle() + { + return getProcedureHandle; + } + @Managed @Nested public JdbcApiStats getGetTableNames() diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/StatisticsAwareJdbcClient.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/StatisticsAwareJdbcClient.java index 0e26d23bd741..5c3a83b7354f 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/StatisticsAwareJdbcClient.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/StatisticsAwareJdbcClient.java @@ -19,6 +19,7 @@ import io.trino.plugin.jdbc.JdbcExpression; import io.trino.plugin.jdbc.JdbcJoinCondition; import io.trino.plugin.jdbc.JdbcOutputTableHandle; +import io.trino.plugin.jdbc.JdbcProcedureHandle; import io.trino.plugin.jdbc.JdbcSortItem; import io.trino.plugin.jdbc.JdbcSplit; import io.trino.plugin.jdbc.JdbcTableHandle; @@ -27,6 +28,7 @@ import io.trino.plugin.jdbc.RemoteTableName; import io.trino.plugin.jdbc.WriteFunction; import io.trino.plugin.jdbc.WriteMapping; +import io.trino.plugin.jdbc.ptf.Procedure.ProcedureInformation; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ColumnMetadata; @@ -45,6 +47,7 @@ import org.weakref.jmx.Flatten; import org.weakref.jmx.Managed; +import java.sql.CallableStatement; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.ResultSet; @@ -110,6 +113,12 @@ public JdbcTableHandle getTableHandle(ConnectorSession session, PreparedQuery pr return stats.getGetTableHandleForQuery().wrap(() -> delegate().getTableHandle(session, preparedQuery)); } + @Override + public JdbcProcedureHandle getProcedureHandle(ConnectorSession session, ProcedureInformation procedureInformation) + { + return stats.getGetProcedureHandle().wrap(() -> delegate().getProcedureHandle(session, procedureInformation)); + } + @Override public List getColumns(ConnectorSession session, JdbcTableHandle tableHandle) { @@ -158,6 +167,12 @@ public ConnectorSplitSource getSplits(ConnectorSession session, JdbcTableHandle return stats.getGetSplits().wrap(() -> delegate().getSplits(session, layoutHandle)); } + @Override + public ConnectorSplitSource getSplits(ConnectorSession session, JdbcProcedureHandle procedureHandle) + { + return stats.getGetSplitsForProcedure().wrap(() -> delegate().getSplits(session, procedureHandle)); + } + @Override public Connection getConnection(ConnectorSession session, JdbcSplit split, JdbcTableHandle tableHandle) throws SQLException @@ -165,6 +180,13 @@ public Connection getConnection(ConnectorSession session, JdbcSplit split, JdbcT return stats.getGetConnectionWithSplit().wrap(() -> delegate().getConnection(session, split, tableHandle)); } + @Override + public Connection getConnection(ConnectorSession session, JdbcSplit split, JdbcProcedureHandle procedureHandle) + throws SQLException + { + return stats.getGetConnectionWithProcedure().wrap(() -> delegate().getConnection(session, split, procedureHandle)); + } + @Override public void abortReadConnection(Connection connection, ResultSet resultSet) throws SQLException @@ -190,6 +212,13 @@ public PreparedStatement buildSql(ConnectorSession session, Connection connectio return stats.getBuildSql().wrap(() -> delegate().buildSql(session, connection, split, tableHandle, columnHandles)); } + @Override + public CallableStatement buildProcedure(ConnectorSession session, Connection connection, JdbcSplit split, JdbcProcedureHandle procedureHandle) + throws SQLException + { + return stats.getBuildProcedure().wrap(() -> delegate().buildProcedure(session, connection, split, procedureHandle)); + } + @Override public Optional implementJoin(ConnectorSession session, JoinType joinType, diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ptf/Procedure.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ptf/Procedure.java new file mode 100644 index 000000000000..37c403d7245a --- /dev/null +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ptf/Procedure.java @@ -0,0 +1,158 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc.ptf; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import io.airlift.slice.Slice; +import io.trino.plugin.base.classloader.ClassLoaderSafeConnectorTableFunction; +import io.trino.plugin.jdbc.JdbcColumnHandle; +import io.trino.plugin.jdbc.JdbcMetadata; +import io.trino.plugin.jdbc.JdbcProcedureHandle; +import io.trino.plugin.jdbc.JdbcTransactionManager; +import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.connector.ConnectorTableHandle; +import io.trino.spi.connector.ConnectorTransactionHandle; +import io.trino.spi.predicate.NullableValue; +import io.trino.spi.ptf.AbstractConnectorTableFunction; +import io.trino.spi.ptf.Argument; +import io.trino.spi.ptf.ConnectorTableFunction; +import io.trino.spi.ptf.ConnectorTableFunctionHandle; +import io.trino.spi.ptf.Descriptor; +import io.trino.spi.ptf.Descriptor.Field; +import io.trino.spi.ptf.ScalarArgument; +import io.trino.spi.ptf.ScalarArgumentSpecification; +import io.trino.spi.ptf.TableFunctionAnalysis; +import io.trino.spi.type.ArrayType; + +import javax.inject.Inject; +import javax.inject.Provider; + +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.trino.spi.ptf.ReturnTypeSpecification.GenericTable.GENERIC_TABLE; +import static io.trino.spi.type.VarcharType.VARCHAR; +import static java.util.Objects.requireNonNull; + +public class Procedure + implements Provider +{ + public static final String SCHEMA_NAME = "system"; + public static final String NAME = "procedure"; + + private final JdbcTransactionManager transactionManager; + + @Inject + public Procedure(JdbcTransactionManager transactionManager) + { + this.transactionManager = requireNonNull(transactionManager, "transactionManager is null"); + } + + @Override + public ConnectorTableFunction get() + { + return new ClassLoaderSafeConnectorTableFunction(new ProcedureFunction(transactionManager), getClass().getClassLoader()); + } + + public static class ProcedureFunction + extends AbstractConnectorTableFunction + { + private final JdbcTransactionManager transactionManager; + + public ProcedureFunction(JdbcTransactionManager transactionManager) + { + super( + SCHEMA_NAME, + NAME, + List.of( + ScalarArgumentSpecification.builder() + .name("SCHEMA") + .type(VARCHAR) + .build(), + ScalarArgumentSpecification.builder() + .name("PROCEDURE") + .type(VARCHAR) + .build(), + ScalarArgumentSpecification.builder() + .name("INPUTS") + .type(new ArrayType(VARCHAR)) + .defaultValue(null) + .build()), + GENERIC_TABLE); + this.transactionManager = requireNonNull(transactionManager, "transactionManager is null"); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + String schema = ((Slice) ((ScalarArgument) arguments.get("SCHEMA")).getValue()).toStringUtf8(); + String procedure = ((Slice) ((ScalarArgument) arguments.get("PROCEDURE")).getValue()).toStringUtf8(); + ScalarArgument inputArgument = ((ScalarArgument) arguments.get("INPUTS")); + NullableValue inputs = inputArgument.getNullableValue(); + List inputArguments = ImmutableList.of(); + + if (!inputs.isNull()) { + inputArguments = ((List) inputArgument.getType().getObjectValue(session, inputs.asBlock(), 0)).stream() + .map(String.class::cast) + .collect(toImmutableList()); + } + + JdbcMetadata metadata = transactionManager.getMetadata(transaction); + JdbcProcedureHandle tableHandle = metadata.getProcedureHandle(session, new ProcedureInformation(schema, procedure, inputArguments)); + List columns = tableHandle.getColumns().orElseThrow(() -> new IllegalStateException("Handle doesn't have columns info")); + Descriptor returnedType = new Descriptor(columns.stream() + .map(column -> new Field(column.getColumnName(), Optional.of(column.getColumnType()))) + .collect(toImmutableList())); + + ProcedureFunctionHandle handle = new ProcedureFunctionHandle(tableHandle); + + return TableFunctionAnalysis.builder() + .returnedType(returnedType) + .handle(handle) + .build(); + } + } + + public static class ProcedureFunctionHandle + implements ConnectorTableFunctionHandle + { + private final JdbcProcedureHandle tableHandle; + + @JsonCreator + public ProcedureFunctionHandle(@JsonProperty("tableHandle") JdbcProcedureHandle tableHandle) + { + this.tableHandle = requireNonNull(tableHandle, "tableHandle is null"); + } + + @JsonProperty + public ConnectorTableHandle getTableHandle() + { + return tableHandle; + } + } + + public record ProcedureInformation(String schemaName, String procedureName, List inputArguments) + { + public ProcedureInformation + { + requireNonNull(schemaName, "schemaName is null"); + requireNonNull(procedureName, "procedureName is null"); + inputArguments = ImmutableList.copyOf(requireNonNull(inputArguments, "inputArguments is null")); + } + } +} diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestCachingJdbcClient.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestCachingJdbcClient.java index 97a143f645f6..960bf440b040 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestCachingJdbcClient.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestCachingJdbcClient.java @@ -20,6 +20,7 @@ import io.airlift.units.Duration; import io.trino.plugin.base.session.SessionPropertiesProvider; import io.trino.plugin.jdbc.credential.ExtraCredentialConfig; +import io.trino.plugin.jdbc.ptf.Procedure.ProcedureInformation; import io.trino.spi.connector.ColumnMetadata; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableMetadata; @@ -35,6 +36,10 @@ import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -306,6 +311,39 @@ public void testTableHandleOfQueryCached() dropTable(phantomTable); } + @Test + public void testProcedureHandleCached() + throws Exception + { + SchemaTableName phantomTable = new SchemaTableName(schema, "phantom_table"); + + createTable(phantomTable); + createProcedure("test_procedure"); + + ProcedureInformation query = new ProcedureInformation( + schema, + "test_procedure", + ImmutableList.of("'" + phantomTable + "'")); + JdbcProcedureHandle cachedProcedure = assertProcedureHandleByQueryCache(cachingJdbcClient) + .misses(1) + .loads(1) + .calling(() -> cachingJdbcClient.getProcedureHandle(SESSION, query)); + assertThat(cachedProcedure.getColumns().orElseThrow()) + .hasSize(0); + + dropProcedure("test_procedure"); + + assertThatThrownBy(() -> jdbcClient.getProcedureHandle(SESSION, query)) + .hasMessageContaining("Failed to get table handle for procedure query"); + + assertProcedureHandleByQueryCache(cachingJdbcClient) + .hits(1) + .afterRunning(() -> assertThat(cachingJdbcClient.getProcedureHandle(SESSION, query)) + .isEqualTo(cachedProcedure)); + + dropTable(phantomTable); + } + @Test public void testTableHandleInvalidatedOnColumnsModifications() { @@ -369,6 +407,29 @@ private JdbcTableHandle createTable(SchemaTableName phantomTable) return jdbcClient.getTableHandle(SESSION, phantomTable).orElseThrow(); } + private void createProcedure(String procedureName) + throws SQLException + { + try (Statement statement = database.getConnection().createStatement()) { + statement.execute("CREATE ALIAS %s.%s FOR \"io.trino.plugin.jdbc.TestCachingJdbcClient.generateData\"".formatted(schema, procedureName)); + } + } + + private void dropProcedure(String procedureName) + throws SQLException + { + try (Statement statement = database.getConnection().createStatement()) { + statement.execute("DROP ALIAS %s.%s".formatted(schema, procedureName)); + } + } + + // Used by H2 for executing Stored Procedure + public static ResultSet generateData(Connection connection, String table) + throws SQLException + { + return connection.createStatement().executeQuery("SELECT * FROM " + table); + } + private void dropTable(JdbcTableHandle tableHandle) { jdbcClient.dropTable(SESSION, tableHandle); @@ -1044,6 +1105,11 @@ private static SingleJdbcCacheStatsAssertions assertTableHandleByQueryCache(Cach return assertCacheStats(client, CachingJdbcCache.TABLE_HANDLES_BY_QUERY_CACHE); } + private static SingleJdbcCacheStatsAssertions assertProcedureHandleByQueryCache(CachingJdbcClient client) + { + return assertCacheStats(client, CachingJdbcCache.PROCEDURE_HANDLES_BY_QUERY_CACHE); + } + private static SingleJdbcCacheStatsAssertions assertColumnCacheStats(CachingJdbcClient client) { return assertCacheStats(client, CachingJdbcCache.COLUMNS_CACHE); @@ -1178,6 +1244,7 @@ enum CachingJdbcCache TABLE_NAMES_CACHE(CachingJdbcClient::getTableNamesCacheStats), TABLE_HANDLES_BY_NAME_CACHE(CachingJdbcClient::getTableHandlesByNameCacheStats), TABLE_HANDLES_BY_QUERY_CACHE(CachingJdbcClient::getTableHandlesByQueryCacheStats), + PROCEDURE_HANDLES_BY_QUERY_CACHE(CachingJdbcClient::getProcedureHandlesByQueryCacheStats), COLUMNS_CACHE(CachingJdbcClient::getColumnsCacheStats), STATISTICS_CACHE(CachingJdbcClient::getStatisticsCacheStats), /**/; diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestProcedure.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestProcedure.java new file mode 100644 index 000000000000..8c211e14600e --- /dev/null +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestProcedure.java @@ -0,0 +1,45 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc; + +import io.trino.testing.sql.SqlExecutor; +import io.trino.testing.sql.TemporaryRelation; + +import static java.util.Objects.requireNonNull; + +public class TestProcedure + implements TemporaryRelation +{ + protected final SqlExecutor sqlExecutor; + protected final String name; + + public TestProcedure(SqlExecutor sqlExecutor, String name, String createProcedureTemplate) + { + this.sqlExecutor = requireNonNull(sqlExecutor, "sqlExecutor is null"); + this.name = requireNonNull(name, "name is null"); + sqlExecutor.execute(createProcedureTemplate.formatted(name)); + } + + @Override + public String getName() + { + return name; + } + + @Override + public void close() + { + sqlExecutor.execute("DROP PROCEDURE " + name); + } +} diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestingH2JdbcClient.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestingH2JdbcClient.java index 37d2e962b693..0d1add80c8fa 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestingH2JdbcClient.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestingH2JdbcClient.java @@ -18,12 +18,14 @@ import dev.failsafe.RetryPolicy; import io.airlift.log.Logger; import io.trino.plugin.base.aggregation.AggregateFunctionRewriter; +import io.trino.plugin.jdbc.JdbcProcedureHandle.ProcedureQuery; import io.trino.plugin.jdbc.aggregation.ImplementCountAll; import io.trino.plugin.jdbc.expression.JdbcConnectorExpressionRewriterBuilder; import io.trino.plugin.jdbc.expression.RewriteVariable; import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.plugin.jdbc.mapping.DefaultIdentifierMapping; import io.trino.plugin.jdbc.mapping.IdentifierMapping; +import io.trino.plugin.jdbc.ptf.Procedure.ProcedureInformation; import io.trino.spi.TrinoException; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.connector.ColumnHandle; @@ -34,14 +36,19 @@ import io.trino.spi.type.Type; import io.trino.spi.type.VarcharType; +import java.sql.CallableStatement; import java.sql.Connection; import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.SQLException; import java.sql.Types; import java.util.Collection; import java.util.List; import java.util.Map; import java.util.Optional; +import static com.google.common.base.MoreObjects.firstNonNull; +import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_ERROR; import static io.trino.plugin.jdbc.StandardColumnMappings.bigintColumnMapping; import static io.trino.plugin.jdbc.StandardColumnMappings.bigintWriteFunction; import static io.trino.plugin.jdbc.StandardColumnMappings.booleanColumnMapping; @@ -93,7 +100,7 @@ public TestingH2JdbcClient(BaseJdbcConfig config, ConnectionFactory connectionFa public TestingH2JdbcClient(BaseJdbcConfig config, ConnectionFactory connectionFactory, IdentifierMapping identifierMapping) { - super(config, "\"", connectionFactory, new DefaultQueryBuilder(RemoteQueryModifier.NONE), identifierMapping, RemoteQueryModifier.NONE); + super(config, "\"", connectionFactory, new TestingH2QueryBuilder(RemoteQueryModifier.NONE), identifierMapping, RemoteQueryModifier.NONE); } @Override @@ -250,4 +257,29 @@ protected void renameTable(ConnectorSession session, String catalogName, String } super.renameTable(session, catalogName, schemaName, tableName, newTable); } + + @Override + public JdbcProcedureHandle getProcedureHandle(ConnectorSession session, ProcedureInformation procedureInformation) + + { + try (Connection connection = connectionFactory.openConnection(session)) { + ProcedureQuery procedureQuery = queryBuilder.createProcedureQuery(this, session, connection, procedureInformation); + + try (CallableStatement statement = queryBuilder.callProcedure(this, session, connection, procedureQuery); + ResultSet resultSet = statement.executeQuery()) { + ResultSetMetaData metadata = resultSet.getMetaData(); + if (metadata == null) { + throw new TrinoException(NOT_SUPPORTED, "Procedure not supported: ResultSetMetaData not available for query: " + procedureQuery.query()); + } + JdbcProcedureHandle procedureHandle = new JdbcProcedureHandle(procedureQuery, getColumns(session, connection, metadata)); + if (statement.getMoreResults()) { + throw new TrinoException(NOT_SUPPORTED, "Procedure has multiple ResultSets for query: " + procedureQuery.query()); + } + return procedureHandle; + } + } + catch (SQLException e) { + throw new TrinoException(JDBC_ERROR, "Failed to get table handle for procedure query. " + firstNonNull(e.getMessage(), e), e); + } + } } diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestingH2QueryBuilder.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestingH2QueryBuilder.java new file mode 100644 index 000000000000..f77da0c519a8 --- /dev/null +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestingH2QueryBuilder.java @@ -0,0 +1,45 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc; + +import com.google.common.base.Joiner; +import io.trino.plugin.jdbc.JdbcProcedureHandle.ProcedureQuery; +import io.trino.plugin.jdbc.logging.RemoteQueryModifier; +import io.trino.plugin.jdbc.ptf.Procedure; +import io.trino.spi.connector.ConnectorSession; + +import javax.inject.Inject; + +import java.sql.Connection; + +public class TestingH2QueryBuilder + extends DefaultQueryBuilder +{ + @Inject + public TestingH2QueryBuilder(RemoteQueryModifier queryModifier) + { + super(queryModifier); + } + + @Override + public ProcedureQuery createProcedureQuery(JdbcClient client, ConnectorSession session, Connection connection, Procedure.ProcedureInformation procedureInformation) + { + return new JdbcProcedureHandle.ProcedureQuery( + "CALL %s.%s (%s)" + .formatted( + procedureInformation.schemaName(), + procedureInformation.procedureName(), + Joiner.on(", ").join(procedureInformation.inputArguments()))); + } +} diff --git a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java index 4d666570add2..1ea38cf30f0c 100644 --- a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java +++ b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java @@ -38,6 +38,8 @@ import io.trino.plugin.jdbc.JdbcExpression; import io.trino.plugin.jdbc.JdbcJoinCondition; import io.trino.plugin.jdbc.JdbcOutputTableHandle; +import io.trino.plugin.jdbc.JdbcProcedureHandle; +import io.trino.plugin.jdbc.JdbcProcedureHandle.ProcedureQuery; import io.trino.plugin.jdbc.JdbcSortItem; import io.trino.plugin.jdbc.JdbcStatisticsConfig; import io.trino.plugin.jdbc.JdbcTableHandle; @@ -61,6 +63,7 @@ import io.trino.plugin.jdbc.expression.RewriteIn; import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.plugin.jdbc.mapping.IdentifierMapping; +import io.trino.plugin.jdbc.ptf.Procedure.ProcedureInformation; import io.trino.spi.TrinoException; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.connector.ColumnHandle; @@ -99,6 +102,7 @@ import java.sql.ResultSet; import java.sql.ResultSetMetaData; import java.sql.SQLException; +import java.sql.Statement; import java.sql.Types; import java.time.Instant; import java.time.LocalDate; @@ -160,6 +164,7 @@ import static io.trino.plugin.sqlserver.SqlServerSessionProperties.isBulkCopyForWriteLockDestinationTable; import static io.trino.plugin.sqlserver.SqlServerTableProperties.DATA_COMPRESSION; import static io.trino.plugin.sqlserver.SqlServerTableProperties.getDataCompression; +import static io.trino.spi.StandardErrorCode.INVALID_ARGUMENTS; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; @@ -658,6 +663,43 @@ public TableStatistics getTableStatistics(ConnectorSession session, JdbcTableHan } } + @Override + public JdbcProcedureHandle getProcedureHandle(ConnectorSession session, ProcedureInformation procedureInformation) + + { + if (procedureInformation.inputArguments().stream().anyMatch(argument -> argument.contains(";"))) { + throw new TrinoException(INVALID_ARGUMENTS, "InputArgument doesn't support special characters like ';'"); + } + + try (Connection connection = connectionFactory.openConnection(session)) { + ProcedureQuery procedureQuery = queryBuilder.createProcedureQuery(this, session, connection, procedureInformation); + + try (Statement statement = connection.createStatement(); + // When FMTONLY is ON , a rowset is returned with the column names for the query + ResultSet resultSet = statement.executeQuery("set fmtonly on %s \nset fmtonly off".formatted(procedureQuery.query()))) { + ResultSetMetaData metadata = resultSet.getMetaData(); + if (metadata == null) { + throw new TrinoException(NOT_SUPPORTED, "Procedure not supported: ResultSetMetaData not available for query: " + procedureQuery.query()); + } + JdbcProcedureHandle procedureHandle = new JdbcProcedureHandle(procedureQuery, getColumns(session, connection, metadata)); + if (statement.getMoreResults()) { + throw new TrinoException(NOT_SUPPORTED, "Procedure has multiple ResultSets for query: " + procedureQuery.query()); + } + // dm_sql_referenced_entities stored procedure provides information about table and columns being updated by a procedure. + try (ResultSet doesStoredProceduresModifiesData = statement.executeQuery("SELECT 1 FROM sys.dm_sql_referenced_entities('%s.%s', 'OBJECT') WHERE is_updated = 1" + .formatted(procedureInformation.schemaName(), procedureInformation.procedureName()))) { + if (doesStoredProceduresModifiesData.next()) { + throw new TrinoException(NOT_SUPPORTED, "Procedure not supported: Procedure updates or inserts data: " + procedureQuery.query()); + } + } + return procedureHandle; + } + } + catch (SQLException e) { + throw new TrinoException(JDBC_ERROR, "Failed to get table handle for procedure query. " + firstNonNull(e.getMessage(), e), e); + } + } + private TableStatistics readTableStatistics(ConnectorSession session, JdbcTableHandle table) throws SQLException { diff --git a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClientModule.java b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClientModule.java index 2cc513a21a0e..eefc9d26c41d 100644 --- a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClientModule.java +++ b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClientModule.java @@ -28,7 +28,9 @@ import io.trino.plugin.jdbc.JdbcJoinPushdownSupportModule; import io.trino.plugin.jdbc.JdbcStatisticsConfig; import io.trino.plugin.jdbc.MaxDomainCompactionThreshold; +import io.trino.plugin.jdbc.QueryBuilder; import io.trino.plugin.jdbc.credential.CredentialProvider; +import io.trino.plugin.jdbc.ptf.Procedure; import io.trino.plugin.jdbc.ptf.Query; import io.trino.spi.ptf.ConnectorTableFunction; @@ -50,8 +52,10 @@ protected void setup(Binder binder) binder.bind(JdbcClient.class).annotatedWith(ForBaseJdbc.class).to(SqlServerClient.class).in(Scopes.SINGLETON); bindTablePropertiesProvider(binder, SqlServerTableProperties.class); bindSessionPropertiesProvider(binder, SqlServerSessionProperties.class); + newOptionalBinder(binder, QueryBuilder.class).setBinding().to(SqlServerQueryBuilder.class).in(Scopes.SINGLETON); newOptionalBinder(binder, Key.get(int.class, MaxDomainCompactionThreshold.class)).setBinding().toInstance(SQL_SERVER_MAX_LIST_EXPRESSIONS); newSetBinder(binder, ConnectorTableFunction.class).addBinding().toProvider(Query.class).in(Scopes.SINGLETON); + newSetBinder(binder, ConnectorTableFunction.class).addBinding().toProvider(Procedure.class).in(Scopes.SINGLETON); install(new JdbcJoinPushdownSupportModule()); } diff --git a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerQueryBuilder.java b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerQueryBuilder.java new file mode 100644 index 000000000000..721fffb61dc1 --- /dev/null +++ b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerQueryBuilder.java @@ -0,0 +1,47 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.sqlserver; + +import com.google.common.base.Joiner; +import io.trino.plugin.jdbc.DefaultQueryBuilder; +import io.trino.plugin.jdbc.JdbcClient; +import io.trino.plugin.jdbc.JdbcProcedureHandle.ProcedureQuery; +import io.trino.plugin.jdbc.logging.RemoteQueryModifier; +import io.trino.plugin.jdbc.ptf.Procedure.ProcedureInformation; +import io.trino.spi.connector.ConnectorSession; + +import javax.inject.Inject; + +import java.sql.Connection; + +public class SqlServerQueryBuilder + extends DefaultQueryBuilder +{ + @Inject + public SqlServerQueryBuilder(RemoteQueryModifier queryModifier) + { + super(queryModifier); + } + + @Override + public ProcedureQuery createProcedureQuery(JdbcClient client, ConnectorSession session, Connection connection, ProcedureInformation procedureInformation) + { + return new ProcedureQuery( + "EXECUTE %s.%s %s" + .formatted( + procedureInformation.schemaName(), + procedureInformation.procedureName(), + Joiner.on(", ").join(procedureInformation.inputArguments()))); + } +} diff --git a/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/BaseSqlServerConnectorTest.java b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/BaseSqlServerConnectorTest.java index df6f1ce1b7c3..82b0bd3030bb 100644 --- a/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/BaseSqlServerConnectorTest.java +++ b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/BaseSqlServerConnectorTest.java @@ -18,7 +18,9 @@ import io.trino.Session; import io.trino.plugin.jdbc.BaseJdbcConnectorTest; import io.trino.plugin.jdbc.JdbcTableHandle; +import io.trino.plugin.jdbc.TestProcedure; import io.trino.spi.predicate.TupleDomain; +import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.FilterNode; import io.trino.testing.TestingConnectorBehavior; import io.trino.testing.sql.TestTable; @@ -662,6 +664,328 @@ protected void verifyColumnNameLengthFailurePermissible(Throwable e) assertThat(e).hasMessageMatching("Column name must be shorter than or equal to '128' characters but got '129': '.*'"); } + @Test + public void testSelectFromProcedureFunction() + { + try (TestProcedure testProcedure = createTestingProcedure("SELECT * FROM nation WHERE nationkey = 1")) { + assertQuery( + format("SELECT name FROM TABLE(system.procedure(schema => '%s', procedure => '%s')) ".formatted(getSession().getSchema().orElseThrow(), testProcedure.getName()), getSession().getSchema().orElseThrow()), + "VALUES 'ARGENTINA'"); + } + } + + @Test + public void testSelectFromProcedureFunctionWithInputParameter() + { + try (TestProcedure testProcedure = createTestingProcedure( + "@nationkey bigint, @name varchar(30)", + "SELECT * FROM nation WHERE nationkey = @nationkey AND name = @name")) { + assertQuery( + "SELECT nationkey, name FROM TABLE(system.procedure(schema => '%s', procedure => '%s', inputs => ARRAY['0', 'ALGERIA'])) ".formatted(getSession().getSchema().orElseThrow(), testProcedure.getName()), + "VALUES (0, 'ALGERIA')"); + } + } + + @Test + public void testSelectFromProcedureFunctionWithOutputParameter() + { + try (TestProcedure testProcedure = createTestingProcedure("@row_count bigint OUTPUT", "SELECT * FROM nation; SELECT @row_count = @@ROWCOUNT")) { + assertQueryFails( + "SELECT name FROM TABLE(system.procedure(schema => '%s', procedure => '%s')) ".formatted(getSession().getSchema().orElseThrow(), testProcedure.getName()), + "Failed to get table handle for procedure query\\. Procedure or function '.*' expects parameter '@row_count', which was not supplied\\."); + } + } + + @Test + public void testFilterPushdownRestrictedForProcedureFunction() + { + try (TestProcedure testProcedure = createTestingProcedure("SELECT * FROM nation")) { + assertThat(query("SELECT name FROM TABLE(system.procedure(schema => '%s', procedure => '%s')) WHERE nationkey = 0".formatted(getSession().getSchema().orElseThrow(), testProcedure.getName()))) + .isNotFullyPushedDown(FilterNode.class) + .skippingTypesCheck() + .matches("VALUES 'ALGERIA'"); + } + } + + @Test + public void testAggregationPushdownRestrictedForProcedureFunction() + { + try (TestProcedure testProcedure = createTestingProcedure("SELECT * FROM nation")) { + assertThat(query( + "SELECT COUNT(*) FROM TABLE(system.procedure(schema => '%s', procedure => '%s'))" + .formatted(getSession().getSchema().orElseThrow(), testProcedure.getName()))) + .isNotFullyPushedDown(AggregationNode.class) + .matches("VALUES BIGINT '25'"); + } + } + + @Test + public void testJoinPushdownRestrictedForProcedureFunction() + { + try (TestProcedure testProcedure = createTestingProcedure("SELECT * FROM nation")) { + assertThat(query( + joinPushdownEnabled(getSession()), + "SELECT nationkey FROM TABLE(system.procedure(schema => '%s', procedure => '%s')) INNER JOIN nation USING (nationkey) ORDER BY 1 LIMIT 1" + .formatted(getSession().getSchema().orElseThrow(), testProcedure.getName()))) + .joinIsNotFullyPushedDown() + .matches("VALUES BIGINT '0'"); + } + } + + @Test + public void testProcedureWithSingleIfStatement() + { + try (TestProcedure testProcedure = createTestingProcedure( + "@id INTEGER", + """ + IF @id > 50 + SELECT 1 as first_column; + """)) { + assertQuery( + format("SELECT first_column FROM TABLE(system.procedure(schema => '%s', procedure => '%s', inputs => ARRAY['100'])) ".formatted(getSession().getSchema().orElseThrow(), testProcedure.getName()), getSession().getSchema().orElseThrow()), + "VALUES 1"); + + assertQueryFails( + "SELECT first_column FROM TABLE(system.procedure(schema => '%s', procedure => '%s', inputs => ARRAY['10'])) ".formatted(getSession().getSchema().orElseThrow(), testProcedure.getName()), + "The statement did not return a result set."); + } + } + + @Test + public void testProcedureWithIfElseStatement() + { + try (TestProcedure testProcedure = createTestingProcedure( + "@id INTEGER", + """ + IF @id > 50 + SELECT 1 as first_column; + ELSE + SELECT '2' as second_column; + """)) { + assertQueryFails( + "SELECT * FROM TABLE(system.procedure(schema => '%s', procedure => '%s', inputs => ARRAY['100'])) " + .formatted(getSession().getSchema().orElseThrow(), testProcedure.getName()), + "Procedure has multiple ResultSets for query: .*"); + } + } + + @Test + public void testProcedureWithSingleIfElse() + { + try (TestProcedure testProcedure = createTestingProcedure("SELECT 1 as first_row; SELECT 2 as second_row")) { + assertQueryFails( + "SELECT * FROM TABLE(system.procedure(schema => '%s', procedure => '%s')) " + .formatted(getSession().getSchema().orElseThrow(), testProcedure.getName()), + "Procedure has multiple ResultSets for query: .*"); + } + } + + @Test + public void testProcedureWithMultipleResultSet() + { + try (TestProcedure testProcedure = createTestingProcedure("SELECT 1 as first_row; SELECT 2 as second_row")) { + assertQueryFails( + "SELECT * FROM TABLE(system.procedure(schema => '%s', procedure => '%s')) " + .formatted(getSession().getSchema().orElseThrow(), testProcedure.getName()), + "Procedure has multiple ResultSets for query: .*"); + } + } + + @Test + public void testProcedureWithCreateOperation() + { + String tableName = "table_to_create" + randomNameSuffix(); + try (TestProcedure testProcedure = createTestingProcedure("CREATE TABLE %s (id BIGINT)".formatted(tableName))) { + assertQueryFails( + "SELECT * FROM TABLE(system.procedure(schema => '%s', procedure => '%s'))" + .formatted(getSession().getSchema().orElseThrow(), testProcedure.getName()), + "Failed to get table handle for procedure query. The statement did not return a result set."); + assertQueryReturnsEmptyResult("SHOW TABLES LIKE '%s'".formatted(tableName)); + } + } + + @Test + public void testProcedureWithDropOperation() + { + try (TestTable table = new TestTable(onRemoteDatabase(), "table_to_drop", "(id BIGINT)")) { + try (TestProcedure testProcedure = createTestingProcedure("DROP TABLE " + table.getName())) { + assertQueryFails( + "SELECT * FROM TABLE(system.procedure(schema => '%s', procedure => '%s'))" + .formatted(getSession().getSchema().orElseThrow(), testProcedure.getName()), + "Failed to get table handle for procedure query. The statement did not return a result set."); + assertQuery("SHOW TABLES LIKE '%s'".formatted(table.getName()), "VALUES '%s'".formatted(table.getName())); + } + } + } + + @Test + public void testProcedureWithInsertOperation() + { + try (TestTable table = new TestTable(onRemoteDatabase(), "table_to_insert", "(id BIGINT)"); + TestTable tableForInsertedRows = new TestTable(onRemoteDatabase(), "table_to_capture_inserted_rows", "(id BIGINT)")) { + try (TestProcedure testProcedure = createTestingProcedure("INSERT INTO %s VALUES (1)".formatted(table.getName()))) { + assertQueryFails( + "SELECT * FROM TABLE(system.procedure(schema => '%s', procedure => '%s'))" + .formatted(getSession().getSchema().orElseThrow(), testProcedure.getName()), + "Failed to get table handle for procedure query. The statement did not return a result set."); + assertQueryReturnsEmptyResult("SELECT * FROM " + table.getName()); + } + + try (TestProcedure testProcedure = createTestingProcedure("INSERT %s OUTPUT INSERTED.* VALUES (1)".formatted(table.getName()))) { + assertQueryFails( + "SELECT * FROM TABLE(system.procedure(schema => '%s', procedure => '%s'))" + .formatted(getSession().getSchema().orElseThrow(), testProcedure.getName()), + "Procedure not supported: Procedure updates or inserts data.*"); + assertQueryReturnsEmptyResult("SELECT * FROM " + table.getName()); + } + + try (TestProcedure testProcedure = createTestingProcedure("INSERT %s OUTPUT INSERTED.*INTO %s VALUES (1)".formatted(table.getName(), tableForInsertedRows.getName()))) { + assertQueryFails( + "SELECT * FROM TABLE(system.procedure(schema => '%s', procedure => '%s'))" + .formatted(getSession().getSchema().orElseThrow(), testProcedure.getName()), + "Failed to get table handle for procedure query. The statement did not return a result set."); + assertQueryReturnsEmptyResult("SELECT * FROM " + table.getName()); + assertQueryReturnsEmptyResult("SELECT * FROM " + tableForInsertedRows.getName()); + } + } + } + + @Test + public void testProcedureWithDeleteOperation() + { + try (TestTable table = new TestTable(onRemoteDatabase(), "table_to_delete", "(id BIGINT)", ImmutableList.of("1", "2", "3")); + TestTable tableForDeletedRows = new TestTable(onRemoteDatabase(), "table_to_capture_deleted_rows", "(id BIGINT)")) { + try (TestProcedure testProcedure = createTestingProcedure("DELETE %s".formatted(table.getName()))) { + assertQueryFails( + "SELECT * FROM TABLE(system.procedure(schema => '%s', procedure => '%s'))" + .formatted(getSession().getSchema().orElseThrow(), testProcedure.getName()), + "Failed to get table handle for procedure query. The statement did not return a result set."); + assertQuery("SELECT * FROM " + table.getName(), "VALUES (1), (2), (3)"); + } + + try (TestProcedure testProcedure = createTestingProcedure("DELETE %s OUTPUT DELETED.*".formatted(table.getName()))) { + assertQueryFails( + "SELECT * FROM TABLE(system.procedure(schema => '%s', procedure => '%s'))" + .formatted(getSession().getSchema().orElseThrow(), testProcedure.getName()), + "Procedure not supported: Procedure updates or inserts data.*"); + assertQuery("SELECT * FROM " + table.getName(), "VALUES (1), (2), (3)"); + } + + try (TestProcedure testProcedure = createTestingProcedure("DELETE %s OUTPUT DELETED.* INTO %s".formatted(table.getName(), tableForDeletedRows.getName()))) { + assertQueryFails( + "SELECT * FROM TABLE(system.procedure(schema => '%s', procedure => '%s'))" + .formatted(getSession().getSchema().orElseThrow(), testProcedure.getName()), + "Failed to get table handle for procedure query. The statement did not return a result set."); + assertQuery("SELECT * FROM " + table.getName(), "VALUES (1), (2), (3)"); + assertQueryReturnsEmptyResult("SELECT * FROM " + tableForDeletedRows.getName()); + } + } + } + + @Test + public void testProcedureWithUpdateOperation() + { + try (TestTable table = new TestTable(onRemoteDatabase(), "table_to_update", "(id BIGINT)", ImmutableList.of("1", "2", "3")); + TestTable tableForDeletedRows = new TestTable(onRemoteDatabase(), "table_to_capture_update", "(inserted BIGINT, deleted BIGINT)")) { + try (TestProcedure testProcedure = createTestingProcedure("UPDATE %s SET id = 4".formatted(table.getName()))) { + assertQueryFails( + "SELECT * FROM TABLE(system.procedure(schema => '%s', procedure => '%s'))" + .formatted(getSession().getSchema().orElseThrow(), testProcedure.getName()), + "Failed to get table handle for procedure query. The statement did not return a result set."); + assertQuery("SELECT * FROM " + table.getName(), "VALUES (1), (2), (3)"); + } + + try (TestProcedure testProcedure = createTestingProcedure("UPDATE %s SET id = 4 OUTPUT DELETED.*, INSERTED.*".formatted(table.getName()))) { + assertQueryFails( + "SELECT * FROM TABLE(system.procedure(schema => '%s', procedure => '%s'))" + .formatted(getSession().getSchema().orElseThrow(), testProcedure.getName()), + "Procedure not supported: Procedure updates or inserts data.*"); + assertQuery("SELECT * FROM " + table.getName(), "VALUES (1), (2), (3)"); + } + + try (TestProcedure testProcedure = createTestingProcedure("UPDATE %s SET id = 4 OUTPUT DELETED.id as inserted, INSERTED.id as deleted INTO %s".formatted(table.getName(), tableForDeletedRows.getName()))) { + assertQueryFails( + "SELECT * FROM TABLE(system.procedure(schema => '%s', procedure => '%s'))" + .formatted(getSession().getSchema().orElseThrow(), testProcedure.getName()), + "Failed to get table handle for procedure query. The statement did not return a result set."); + assertQuery("SELECT * FROM " + table.getName(), "VALUES (1), (2), (3)"); + assertQueryReturnsEmptyResult("SELECT * FROM " + tableForDeletedRows.getName()); + } + } + } + + @Test + public void testProcedureWithMergeOperation() + { + try (TestTable sourceTable = new TestTable(onRemoteDatabase(), "source_table", "(id BIGINT)", ImmutableList.of("1", "2", "3")); + TestTable targetTable = new TestTable(onRemoteDatabase(), "destination_table", "(id BIGINT)", ImmutableList.of("3", "4", "5")); + TestTable tableForChangedRows = new TestTable(onRemoteDatabase(), "table_to_capture_change", "(inserted BIGINT, deleted BIGINT)")) { + String mergeQuery = """ + MERGE %s AS TARGET USING %s AS SOURCE + ON (TARGET.id = SOURCE.id) + WHEN NOT MATCHED BY TARGET + THEN INSERT(id) VALUES(SOURCE.id) + WHEN NOT MATCHED BY SOURCE + THEN DELETE + """.formatted(targetTable.getName(), sourceTable.getName()); + try (TestProcedure testProcedure = createTestingProcedure(mergeQuery + ";")) { + assertQueryFails( + "SELECT * FROM TABLE(system.procedure(schema => '%s', procedure => '%s'))" + .formatted(getSession().getSchema().orElseThrow(), testProcedure.getName()), + "Failed to get table handle for procedure query. The statement did not return a result set."); + assertQuery("SELECT * FROM " + targetTable.getName(), "VALUES (3), (4), (5)"); + } + + try (TestProcedure testProcedure = createTestingProcedure(mergeQuery + "OUTPUT DELETED.*, INSERTED.* ;")) { + assertQueryFails( + "SELECT * FROM TABLE(system.procedure(schema => '%s', procedure => '%s'))" + .formatted(getSession().getSchema().orElseThrow(), testProcedure.getName()), + "Procedure not supported: Procedure updates or inserts data.*"); + assertQuery("SELECT * FROM " + targetTable.getName(), "VALUES (3), (4), (5)"); + } + + try (TestProcedure testProcedure = createTestingProcedure(mergeQuery + "OUTPUT DELETED.id as inserted, INSERTED.id as deleted INTO %s ;".formatted(tableForChangedRows.getName()))) { + assertQueryFails( + "SELECT * FROM TABLE(system.procedure(schema => '%s', procedure => '%s'))" + .formatted(getSession().getSchema().orElseThrow(), testProcedure.getName()), + "Failed to get table handle for procedure query. The statement did not return a result set."); + assertQuery("SELECT * FROM " + targetTable.getName(), "VALUES (3), (4), (5)"); + assertQueryReturnsEmptyResult("SELECT * FROM " + tableForChangedRows.getName()); + } + } + } + + @Test + public void testCustomSqlStatementPassedAsArguments() + { + try (TestProcedure testProcedure = createTestingProcedure("SELECT 1 as colA ")) { + assertQueryFails( + "SELECT * FROM TABLE(system.procedure(schema => '%s', procedure => '%s', inputs => ARRAY['; DROP TABLE nation']))" + .formatted(getSession().getSchema().orElseThrow(), testProcedure.getName()), + "InputArgument doesn't support special characters like ';'"); + assertQuery("SHOW TABLES LIKE 'nation'", "VALUES 'nation'"); + } + } + + private TestProcedure createTestingProcedure(String baseQuery) + { + return createTestingProcedure("", baseQuery); + } + + private TestProcedure createTestingProcedure(String inputArguments, String baseQuery) + { + String procedureName = "procedure" + randomNameSuffix(); + return new TestProcedure( + onRemoteDatabase(), + procedureName, + """ + CREATE PROCEDURE %s.%s %s + AS BEGIN + %s + END + """.formatted(getSession().getSchema().orElseThrow(), procedureName, inputArguments, baseQuery)); + } + private String getLongInClause(int start, int length) { String longValues = range(start, start + length) diff --git a/plugin/trino-thrift-testing-server/pom.xml b/plugin/trino-thrift-testing-server/pom.xml index 9f56abc34abc..ff0c5fc89a5c 100644 --- a/plugin/trino-thrift-testing-server/pom.xml +++ b/plugin/trino-thrift-testing-server/pom.xml @@ -21,7 +21,7 @@ io.trino - trino-main + trino-plugin-toolkit diff --git a/plugin/trino-thrift-testing-server/src/main/java/io/trino/plugin/thrift/server/ThriftIndexedTpchService.java b/plugin/trino-thrift-testing-server/src/main/java/io/trino/plugin/thrift/server/ThriftIndexedTpchService.java index db5dbd66ce13..8f1a7daa0a79 100644 --- a/plugin/trino-thrift-testing-server/src/main/java/io/trino/plugin/thrift/server/ThriftIndexedTpchService.java +++ b/plugin/trino-thrift-testing-server/src/main/java/io/trino/plugin/thrift/server/ThriftIndexedTpchService.java @@ -15,6 +15,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import io.trino.plugin.base.MappedRecordSet; import io.trino.plugin.thrift.api.TrinoThriftBlock; import io.trino.plugin.thrift.api.TrinoThriftId; import io.trino.plugin.thrift.api.TrinoThriftNullableToken; @@ -27,7 +28,6 @@ import io.trino.spi.connector.RecordPageSource; import io.trino.spi.connector.RecordSet; import io.trino.spi.type.Type; -import io.trino.split.MappedRecordSet; import io.trino.testing.tpch.TpchIndexedData; import io.trino.testing.tpch.TpchIndexedData.IndexedTable; import io.trino.testing.tpch.TpchScaledTable; diff --git a/testing/trino-testing/src/main/java/io/trino/testing/tpch/TpchIndexProvider.java b/testing/trino-testing/src/main/java/io/trino/testing/tpch/TpchIndexProvider.java index 47054349c976..fc9f18046e84 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/tpch/TpchIndexProvider.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/tpch/TpchIndexProvider.java @@ -14,6 +14,7 @@ package io.trino.testing.tpch; import com.google.common.collect.ImmutableList; +import io.trino.plugin.base.MappedRecordSet; import io.trino.plugin.tpch.TpchColumnHandle; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorIndex; @@ -25,7 +26,6 @@ import io.trino.spi.predicate.NullableValue; import io.trino.spi.predicate.TupleDomain; import io.trino.spi.type.Type; -import io.trino.split.MappedRecordSet; import java.util.ArrayList; import java.util.List;