Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/src/main/sphinx/connector/ignite.md
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ statements, the connector supports the following features:

- {doc}`/sql/insert`
- {doc}`/sql/update`
- {doc}`/sql/merge`
- {doc}`/sql/delete`
- {doc}`/sql/create-table`
- {doc}`/sql/create-table-as`
Expand Down
1 change: 1 addition & 0 deletions docs/src/main/sphinx/connector/oracle.md
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,7 @@ supports the following statements:

- {doc}`/sql/insert`
- {doc}`/sql/update`
- {doc}`/sql/merge`
- {doc}`/sql/delete`
- {doc}`/sql/truncate`
- {doc}`/sql/create-table`
Expand Down
1 change: 1 addition & 0 deletions docs/src/main/sphinx/connector/phoenix.md
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@ statements, the connector supports the following features:

- {doc}`/sql/insert`
- {doc}`/sql/delete`
- {doc}`/sql/update`
- {doc}`/sql/merge`
- {doc}`/sql/create-table`
- {doc}`/sql/create-table-as`
Expand Down
1 change: 1 addition & 0 deletions docs/src/main/sphinx/connector/postgresql.md
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,7 @@ statements, the connector supports the following features:

- {doc}`/sql/insert`
- {doc}`/sql/update`
- {doc}`/sql/merge`
- {doc}`/sql/delete`
- {doc}`/sql/truncate`
- {ref}`sql-schema-table-management`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.plugin.jdbc;

import com.google.common.base.Joiner;
import com.google.common.base.VerifyException;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
Expand All @@ -31,6 +32,7 @@
import io.trino.spi.connector.ColumnMetadata;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.connector.ConnectorSplitSource;
import io.trino.spi.connector.ConnectorTableHandle;
import io.trino.spi.connector.ConnectorTableMetadata;
import io.trino.spi.connector.FixedSplitSource;
import io.trino.spi.connector.JoinStatistics;
Expand All @@ -43,6 +45,7 @@
import io.trino.spi.statistics.TableStatistics;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.CharType;
import io.trino.spi.type.RowType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;
import jakarta.annotation.Nullable;
Expand Down Expand Up @@ -78,6 +81,7 @@
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.collect.Iterables.getOnlyElement;
import static com.google.common.collect.Iterators.tryFind;
import static io.trino.plugin.base.TemporaryTables.generateTemporaryTableName;
import static io.trino.plugin.jdbc.CaseSensitivity.CASE_INSENSITIVE;
import static io.trino.plugin.jdbc.CaseSensitivity.CASE_SENSITIVE;
Expand All @@ -102,6 +106,7 @@
public abstract class BaseJdbcClient
implements JdbcClient
{
public static final String MERGE_ROW_ID = "$merge_row_id";
private static final Logger log = Logger.get(BaseJdbcClient.class);

static final Type TRINO_PAGE_SINK_ID_COLUMN_TYPE = BigintType.BIGINT;
Expand Down Expand Up @@ -929,6 +934,21 @@ private RemoteTableName constructPageSinkIdsTable(ConnectorSession session, Conn

@Override
public void finishInsertTable(ConnectorSession session, JdbcOutputTableHandle handle, Set<Long> pageSinkIds)
{
RemoteTableName targetTable = new RemoteTableName(
Optional.ofNullable(handle.getCatalogName()),
Optional.ofNullable(handle.getSchemaName()),
handle.getTableName());
String columns = handle.getColumnNames().stream()
.map(this::quoted)
.collect(joining(", "));

// last args will be handled in the finish operation
String insertSql = "INSERT INTO %s (%s)".formatted(postProcessInsertTableNameClause(session, quoted(targetTable)), columns) + " %s";
finishOperation(session, handle, pageSinkIds, insertSql);
}

private void finishOperation(ConnectorSession session, JdbcOutputTableHandle handle, Set<Long> pageSinkIds, String operateSql)
{
if (isNonTransactionalInsert(session)) {
checkState(handle.getTemporaryTableName().isEmpty(), "Unexpected use of temporary table when non transactional inserts are enabled");
Expand All @@ -939,10 +959,6 @@ public void finishInsertTable(ConnectorSession session, JdbcOutputTableHandle ha
Optional.ofNullable(handle.getCatalogName()),
Optional.ofNullable(handle.getSchemaName()),
handle.getTemporaryTableName().orElseThrow());
RemoteTableName targetTable = new RemoteTableName(
Optional.ofNullable(handle.getCatalogName()),
Optional.ofNullable(handle.getSchemaName()),
handle.getTableName());

// We conditionally create more than the one table, so keep a list of the tables that need to be dropped.
Closer closer = Closer.create();
Expand All @@ -953,23 +969,18 @@ public void finishInsertTable(ConnectorSession session, JdbcOutputTableHandle ha
String columns = handle.getColumnNames().stream()
.map(this::quoted)
.collect(joining(", "));

String insertSql = format("INSERT INTO %s (%s) SELECT %s FROM %s temp_table",
postProcessInsertTableNameClause(session, quoted(targetTable)),
columns,
columns,
quoted(temporaryTable));
String tempTableData = "SELECT %s FROM %s temp_table".formatted(columns, quoted(temporaryTable));

if (handle.getPageSinkIdColumnName().isPresent()) {
RemoteTableName pageSinkTable = constructPageSinkIdsTable(session, connection, handle, pageSinkIds, closer);

insertSql += format(" WHERE EXISTS (SELECT 1 FROM %s page_sink_table WHERE page_sink_table.%s = temp_table.%s)",
tempTableData += format(" WHERE EXISTS (SELECT 1 FROM %s page_sink_table WHERE page_sink_table.%s = temp_table.%s)",
quoted(pageSinkTable),
handle.getPageSinkIdColumnName().get(),
handle.getPageSinkIdColumnName().get());
}

execute(session, connection, insertSql);
execute(session, connection, operateSql.formatted(tempTableData));
}
catch (SQLException e) {
throw new TrinoException(JDBC_ERROR, e);
Expand All @@ -984,6 +995,37 @@ public void finishInsertTable(ConnectorSession session, JdbcOutputTableHandle ha
}
}

@Override
public JdbcOutputTableHandle beginDeleteTableForMerge(ConnectorSession session, JdbcTableHandle tableHandle)
{
verify(shouldUseFaultTolerantExecution(session));
return beginInsertTable(session, tableHandle, getPrimaryKeys(session, tableHandle));
}

protected String getConjunctsBetweenTargetAndTemporaryTable(JdbcOutputTableHandle handle)
{
StringBuilder conjuncts = new StringBuilder();
String conjunct = "merge_target.%s = temp.%$1s";
for (String column : handle.getColumnNames()) {
conjuncts.append(conjunct.formatted(column));
}
return conjuncts.toString();
}

@Override
public void finishDeleteTableForMerge(ConnectorSession session, JdbcOutputTableHandle handle, Set<Long> pageSinkIds)
{
verify(shouldUseFaultTolerantExecution(session));
RemoteTableName targetTable = new RemoteTableName(
Optional.ofNullable(handle.getCatalogName()),
Optional.ofNullable(handle.getSchemaName()),
handle.getTableName());
String deleteCondition = "WHERE EXISTS (SELECT 1 FROM (%s) temp WHERE " + getConjunctsBetweenTargetAndTemporaryTable(handle) + ")";

String deleteSql = "DELETE FROM %s merge_target ".formatted(postProcessInsertTableNameClause(session, quoted(targetTable))) + deleteCondition;
finishOperation(session, handle, pageSinkIds, deleteSql);
}

protected String postProcessInsertTableNameClause(ConnectorSession session, String tableName)
{
return tableName;
Expand Down Expand Up @@ -1140,6 +1182,33 @@ public String buildInsertSql(JdbcOutputTableHandle handle, List<WriteFunction> c
hasPageSinkIdColumn ? ", ?" : "");
}

@Override
public String buildMergeRowIdConjuncts(ConnectorSession session, List<String> mergeRowIdFieldNames, List<Type> mergeRowIdFieldTypes)
{
List<WriteFunction> mergeRowIdColumnWriters = mergeRowIdFieldTypes.stream()
.map(type -> {
WriteMapping writeMapping = toWriteMapping(session, type);
WriteFunction writeFunction = writeMapping.getWriteFunction();
verify(
type.getJavaType() == writeFunction.getJavaType(),
"Trino type %s is not compatible with write function %s accepting %s",
type,
writeFunction,
writeFunction.getJavaType());
return writeMapping;
})
.map(WriteMapping::getWriteFunction)
.collect(toImmutableList());
verify(!mergeRowIdColumnWriters.isEmpty() && mergeRowIdFieldNames.size() == mergeRowIdColumnWriters.size());

ImmutableList.Builder<String> conjunctsBuilder = ImmutableList.builder();
for (int i = 0; i < mergeRowIdFieldNames.size(); i++) {
conjunctsBuilder.add(quoted(mergeRowIdFieldNames.get(i)) + " = " + mergeRowIdColumnWriters.get(i).getBindExpression());
}

return Joiner.on(" AND ").join(conjunctsBuilder.build());
}

@Override
public Connection getConnection(ConnectorSession session, JdbcOutputTableHandle handle)
throws SQLException
Expand Down Expand Up @@ -1541,7 +1610,7 @@ static TopNFunction sqlStandard(Function<String, String> quote)
}
}

private static ColumnMetadata getPageSinkIdColumn(List<String> otherColumnNames)
protected static ColumnMetadata getPageSinkIdColumn(List<String> otherColumnNames)
{
// While it's unlikely this column name will collide with client table columns,
// guarantee it will not by appending a deterministic suffix to it.
Expand All @@ -1559,4 +1628,59 @@ public RemoteIdentifiers getRemoteIdentifiers(Connection connection)
{
return jdbcRemoteIdentifiersFactory.createJdbcRemoteIdentifies(connection);
}

@Override
public JdbcTableHandle updatedScanColumnsForMerge(ConnectorSession session, ConnectorTableHandle table, Optional<List<JdbcColumnHandle>> originalColumns, JdbcColumnHandle mergeRowIdColumnHandle)
{
JdbcTableHandle tableHandle = (JdbcTableHandle) table;
if (originalColumns.isEmpty()) {
return tableHandle;
}
List<JdbcColumnHandle> scanColumnHandles = originalColumns.get();
checkArgument(!scanColumnHandles.isEmpty(), "Scan columns should not empty");
checkArgument(tryFind(scanColumnHandles.iterator(), column -> MERGE_ROW_ID.equalsIgnoreCase(column.getColumnName())).isPresent(), "Merge row id column must exist in original columns");

return new JdbcTableHandle(
tableHandle.getRelationHandle(),
tableHandle.getConstraint(),
tableHandle.getConstraintExpressions(),
tableHandle.getSortOrder(),
tableHandle.getLimit(),
Optional.of(getUpdatedScanColumnHandles(session, tableHandle, scanColumnHandles, mergeRowIdColumnHandle)),
tableHandle.getOtherReferencedTables(),
tableHandle.getNextSyntheticColumnId(),
tableHandle.getAuthorization(),
tableHandle.getUpdateAssignments());
}

protected List<JdbcColumnHandle> getUpdatedScanColumnHandles(ConnectorSession session, JdbcTableHandle tableHandle, List<JdbcColumnHandle> scanColumnHandles, JdbcColumnHandle mergeRowIdColumnHandle)
{
RowType columnType = (RowType) mergeRowIdColumnHandle.getColumnType();
List<JdbcColumnHandle> primaryKeyColumnHandles = getPrimaryKeys(session, tableHandle);
Set<String> mergeRowIdFieldNames = columnType.getFields().stream()
.map(RowType.Field::getName)
.filter(Optional::isPresent)
.map(Optional::get)
.collect(toImmutableSet());
Set<String> primaryKeyColumnNames = primaryKeyColumnHandles.stream()
.map(JdbcColumnHandle::getColumnName)
.collect(toImmutableSet());
checkArgument(mergeRowIdFieldNames.containsAll(primaryKeyColumnNames), "Merge row id fields should contains all primary keys");

ImmutableList.Builder<JdbcColumnHandle> columnHandleBuilder = ImmutableList.builder();
scanColumnHandles.stream()
.filter(jdbcColumnHandle -> !MERGE_ROW_ID.equalsIgnoreCase(jdbcColumnHandle.getColumnName()))
.forEach(columnHandleBuilder::add);

// Add merge row id fields
for (JdbcColumnHandle columnHandle : primaryKeyColumnHandles) {
String columnName = columnHandle.getColumnName();

if (!tryFind(scanColumnHandles.iterator(), column -> column.getColumnName().equalsIgnoreCase(columnName)).isPresent()) {
columnHandleBuilder.add(columnHandle);
}
}

return columnHandleBuilder.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import io.trino.spi.connector.ColumnMetadata;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.connector.ConnectorSplitSource;
import io.trino.spi.connector.ConnectorTableHandle;
import io.trino.spi.connector.ConnectorTableMetadata;
import io.trino.spi.connector.JoinStatistics;
import io.trino.spi.connector.JoinType;
Expand Down Expand Up @@ -179,6 +180,12 @@ public List<JdbcColumnHandle> getColumns(ConnectorSession session, JdbcTableHand
return get(columnsCache, key, () -> delegate.getColumns(session, tableHandle));
}

@Override
public List<JdbcColumnHandle> getPrimaryKeys(ConnectorSession session, JdbcTableHandle tableHandle)
{
return delegate.getPrimaryKeys(session, tableHandle);
}

@Override
public Optional<ColumnMapping> toColumnMapping(ConnectorSession session, Connection connection, JdbcTypeHandle typeHandle)
{
Expand Down Expand Up @@ -366,6 +373,19 @@ public void finishInsertTable(ConnectorSession session, JdbcOutputTableHandle ha
onDataChanged(new SchemaTableName(handle.getSchemaName(), handle.getTableName()));
}

@Override
public JdbcOutputTableHandle beginDeleteTableForMerge(ConnectorSession session, JdbcTableHandle tableHandle)
{
return delegate.beginDeleteTableForMerge(session, tableHandle);
}

@Override
public void finishDeleteTableForMerge(ConnectorSession session, JdbcOutputTableHandle handle, Set<Long> pageSinkIds)
{
delegate.finishDeleteTableForMerge(session, handle, pageSinkIds);
onDataChanged(new SchemaTableName(handle.getSchemaName(), handle.getTableName()));
}

@Override
public void dropTable(ConnectorSession session, JdbcTableHandle jdbcTableHandle)
{
Expand All @@ -391,6 +411,12 @@ public String buildInsertSql(JdbcOutputTableHandle handle, List<WriteFunction> c
return delegate.buildInsertSql(handle, columnWriters);
}

@Override
public String buildMergeRowIdConjuncts(ConnectorSession session, List<String> mergeRowIdFieldNames, List<Type> mergeRowIdFieldTypes)
{
return delegate.buildMergeRowIdConjuncts(session, mergeRowIdFieldNames, mergeRowIdFieldTypes);
}

@Override
public Connection getConnection(ConnectorSession session, JdbcOutputTableHandle handle)
throws SQLException
Expand Down Expand Up @@ -605,6 +631,12 @@ public void truncateTable(ConnectorSession session, JdbcTableHandle handle)
onDataChanged(handle.getRequiredNamedRelation().getSchemaTableName());
}

@Override
public JdbcTableHandle updatedScanColumnsForMerge(ConnectorSession session, ConnectorTableHandle table, Optional<List<JdbcColumnHandle>> originalColumns, JdbcColumnHandle mergeRowIdColumnHandle)
{
return delegate.updatedScanColumnsForMerge(session, table, originalColumns, mergeRowIdColumnHandle);
}

@Managed
public void flushCache()
{
Expand Down
Loading