From a2ec0b722878368f040ae49d4c092c17dc63ca28 Mon Sep 17 00:00:00 2001 From: chenjian2664 Date: Fri, 8 Nov 2024 19:03:57 +0800 Subject: [PATCH 1/2] Refactor merge to support partial update in engine --- .../main/java/io/trino/metadata/Metadata.java | 5 ++- .../io/trino/metadata/MetadataManager.java | 9 +++-- ...hangeOnlyUpdatedColumnsMergeProcessor.java | 4 ++- .../DeleteAndInsertMergeProcessor.java | 16 ++++++--- .../java/io/trino/sql/analyzer/Analysis.java | 9 +++++ .../trino/sql/analyzer/StatementAnalyzer.java | 34 ++++++++++++++++--- .../io/trino/sql/planner/QueryPlanner.java | 10 ++++-- .../optimizations/BeginTableWrite.java | 8 +++-- .../sql/planner/plan/TableWriterNode.java | 12 ++++++- .../tracing/TracingConnectorMetadata.java | 9 +++++ .../io/trino/tracing/TracingMetadata.java | 5 +-- .../trino/metadata/AbstractMockMetadata.java | 3 +- .../TestDeleteAndInsertMergeProcessor.java | 6 ++-- .../iterative/rule/test/PlanBuilder.java | 4 ++- .../spi/connector/ConnectorMetadata.java | 12 +++++++ .../io/trino/spi/connector/MergePage.java | 8 ++--- .../ClassLoaderSafeConnectorMetadata.java | 8 +++++ .../plugin/bigquery/BigQueryMetadata.java | 4 +-- .../plugin/cassandra/CassandraMetadata.java | 2 +- .../plugin/deltalake/DeltaLakeMergeSink.java | 8 ++--- .../io/trino/plugin/kudu/KuduMetadata.java | 2 +- .../io/trino/plugin/kudu/KuduPageSink.java | 6 ++-- .../plugin/phoenix5/PhoenixMergeSink.java | 4 +-- .../plugin/phoenix5/PhoenixMetadata.java | 2 +- 24 files changed, 146 insertions(+), 44 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/metadata/Metadata.java b/core/trino-main/src/main/java/io/trino/metadata/Metadata.java index ad5efc3f1e00..dff0bf1d0f98 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/Metadata.java +++ b/core/trino-main/src/main/java/io/trino/metadata/Metadata.java @@ -13,6 +13,7 @@ */ package io.trino.metadata; +import com.google.common.collect.Multimap; import com.google.common.util.concurrent.ListenableFuture; import io.airlift.slice.Slice; import io.trino.Session; @@ -452,8 +453,10 @@ Optional finishRefreshMaterializedView( /** * Begin merge query + * + * @param updateCaseColumnHandles The merge update case number to the assignment target columns mapping */ - MergeHandle beginMerge(Session session, TableHandle tableHandle); + MergeHandle beginMerge(Session session, TableHandle tableHandle, Multimap updateCaseColumnHandles); /** * Finish merge query diff --git a/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java b/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java index bdb66117c70d..ff6d7d0e640b 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java +++ b/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java @@ -17,6 +17,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Multimap; import com.google.common.collect.Streams; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; @@ -1352,11 +1353,15 @@ public RowChangeParadigm getRowChangeParadigm(Session session, TableHandle table } @Override - public MergeHandle beginMerge(Session session, TableHandle tableHandle) + public MergeHandle beginMerge(Session session, TableHandle tableHandle, Multimap updateCaseColumns) { CatalogHandle catalogHandle = tableHandle.catalogHandle(); ConnectorMetadata metadata = getMetadataForWrite(session, 
catalogHandle); - ConnectorMergeTableHandle newHandle = metadata.beginMerge(session.toConnectorSession(catalogHandle), tableHandle.connectorHandle(), getRetryPolicy(session).getRetryMode()); + ConnectorMergeTableHandle newHandle = metadata.beginMerge( + session.toConnectorSession(catalogHandle), + tableHandle.connectorHandle(), + updateCaseColumns.asMap(), + getRetryPolicy(session).getRetryMode()); return new MergeHandle(tableHandle.withConnectorHandle(newHandle.getTableHandle()), newHandle); } diff --git a/core/trino-main/src/main/java/io/trino/operator/ChangeOnlyUpdatedColumnsMergeProcessor.java b/core/trino-main/src/main/java/io/trino/operator/ChangeOnlyUpdatedColumnsMergeProcessor.java index 4ce8d5ca5ec6..9313d9f204f2 100644 --- a/core/trino-main/src/main/java/io/trino/operator/ChangeOnlyUpdatedColumnsMergeProcessor.java +++ b/core/trino-main/src/main/java/io/trino/operator/ChangeOnlyUpdatedColumnsMergeProcessor.java @@ -67,12 +67,14 @@ public Page transformPage(Page inputPage) Block mergeRow = inputPage.getBlock(mergeRowChannel).getLoadedBlock(); List fields = getRowFieldsFromBlock(mergeRow); - List builder = new ArrayList<>(dataColumnChannels.size() + 3); + List builder = new ArrayList<>(dataColumnChannels.size() + 4); for (int channel : dataColumnChannels) { builder.add(fields.get(channel)); } Block operationChannelBlock = fields.get(fields.size() - 2); builder.add(operationChannelBlock); + Block caseNumberChannelBlock = fields.get(fields.size() - 1); + builder.add(caseNumberChannelBlock); builder.add(inputPage.getBlock(rowIdChannel)); builder.add(RunLengthEncodedBlock.create(INSERT_FROM_UPDATE_BLOCK, positionCount)); diff --git a/core/trino-main/src/main/java/io/trino/operator/DeleteAndInsertMergeProcessor.java b/core/trino-main/src/main/java/io/trino/operator/DeleteAndInsertMergeProcessor.java index 1a76b33a78a7..3dd4b85ab7f5 100644 --- a/core/trino-main/src/main/java/io/trino/operator/DeleteAndInsertMergeProcessor.java +++ b/core/trino-main/src/main/java/io/trino/operator/DeleteAndInsertMergeProcessor.java @@ -31,6 +31,7 @@ import static io.trino.spi.connector.ConnectorMergeSink.UPDATE_DELETE_OPERATION_NUMBER; import static io.trino.spi.connector.ConnectorMergeSink.UPDATE_INSERT_OPERATION_NUMBER; import static io.trino.spi.connector.ConnectorMergeSink.UPDATE_OPERATION_NUMBER; +import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.TinyintType.TINYINT; import static java.util.Objects.requireNonNull; @@ -124,6 +125,7 @@ public Page transformPage(Page inputPage) List pageTypes = ImmutableList.builder() .addAll(dataColumnTypes) .add(TINYINT) + .add(INTEGER) .add(rowIdType) .add(TINYINT) .build(); @@ -171,11 +173,14 @@ private void addDeleteRow(PageBuilder pageBuilder, Page originalPage, int positi // Add the operation column == deleted TINYINT.writeLong(pageBuilder.getBlockBuilder(dataColumnChannels.size()), causedByUpdate ? 
UPDATE_DELETE_OPERATION_NUMBER : DELETE_OPERATION_NUMBER);
 
+        // Add a dummy case number; delete and insert rows never read it, so write -1 to mark it as unused
+        INTEGER.writeLong(pageBuilder.getBlockBuilder(dataColumnChannels.size() + 1), -1);
+
         // Copy row ID column
-        rowIdType.appendTo(originalPage.getBlock(rowIdChannel), position, pageBuilder.getBlockBuilder(dataColumnChannels.size() + 1));
+        rowIdType.appendTo(originalPage.getBlock(rowIdChannel), position, pageBuilder.getBlockBuilder(dataColumnChannels.size() + 2));
 
         // Write 0, meaning this row is not an insert derived from an update
-        TINYINT.writeLong(pageBuilder.getBlockBuilder(dataColumnChannels.size() + 2), 0);
+        TINYINT.writeLong(pageBuilder.getBlockBuilder(dataColumnChannels.size() + 3), 0);
 
         pageBuilder.declarePosition();
     }
 
@@ -193,11 +198,14 @@ private void addInsertRow(PageBuilder pageBuilder, List<Block> fields, int position, boolean causedByUpdate)
         // Add the operation column == insert
         TINYINT.writeLong(pageBuilder.getBlockBuilder(dataColumnChannels.size()), causedByUpdate ? UPDATE_INSERT_OPERATION_NUMBER : INSERT_OPERATION_NUMBER);
 
+        // Add a dummy case number; delete and insert rows never read it
+        INTEGER.writeLong(pageBuilder.getBlockBuilder(dataColumnChannels.size() + 1), 0);
+
         // Add null row ID column
-        pageBuilder.getBlockBuilder(dataColumnChannels.size() + 1).appendNull();
+        pageBuilder.getBlockBuilder(dataColumnChannels.size() + 2).appendNull();
 
         // Write 1 if this row is an insert derived from an update, 0 otherwise
-        TINYINT.writeLong(pageBuilder.getBlockBuilder(dataColumnChannels.size() + 2), causedByUpdate ? 1 : 0);
+        TINYINT.writeLong(pageBuilder.getBlockBuilder(dataColumnChannels.size() + 3), causedByUpdate ? 1 : 0);
 
         pageBuilder.declarePosition();
     }
diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java
index 3a75ce307417..f2dd8027a4cc 100644
--- a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java
+++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java
@@ -1842,6 +1842,8 @@ public static class MergeAnalysis
         private final List<ColumnHandle> dataColumnHandles;
         private final List<ColumnHandle> redistributionColumnHandles;
         private final List<List<ColumnHandle>> mergeCaseColumnHandles;
+        // Mapping from merge case number to the columns assigned in that case
+        private final Multimap<Integer, ColumnHandle> updateCaseColumnHandles;
         private final Set<ColumnHandle> nonNullableColumnHandles;
         private final Map<ColumnHandle, Integer> columnHandleFieldNumbers;
         private final RowType mergeRowType;
@@ -1857,6 +1859,7 @@ public MergeAnalysis(
                 List<ColumnHandle> dataColumnHandles,
                 List<ColumnHandle> redistributionColumnHandles,
                 List<List<ColumnHandle>> mergeCaseColumnHandles,
+                Multimap<Integer, ColumnHandle> updateCaseColumnHandles,
                 Set<ColumnHandle> nonNullableColumnHandles,
                 Map<ColumnHandle, Integer> columnHandleFieldNumbers,
                 RowType mergeRowType,
@@ -1871,6 +1874,7 @@ public MergeAnalysis(
             this.dataColumnHandles = requireNonNull(dataColumnHandles, "dataColumnHandles is null");
             this.redistributionColumnHandles = requireNonNull(redistributionColumnHandles, "redistributionColumnHandles is null");
             this.mergeCaseColumnHandles = requireNonNull(mergeCaseColumnHandles, "mergeCaseColumnHandles is null");
+            this.updateCaseColumnHandles = requireNonNull(updateCaseColumnHandles, "updateCaseColumnHandles is null");
             this.nonNullableColumnHandles = requireNonNull(nonNullableColumnHandles, "nonNullableColumnHandles is null");
             this.columnHandleFieldNumbers = requireNonNull(columnHandleFieldNumbers, "columnHandleFieldNumbers is null");
             this.mergeRowType = requireNonNull(mergeRowType, "mergeRowType is null");
@@ -1906,6 +1910,11 @@ public List<List<ColumnHandle>> getMergeCaseColumnHandles()
             return mergeCaseColumnHandles;
         }
 
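+        // For example, a MERGE with two update cases,
+        //   WHEN MATCHED AND ... THEN UPDATE SET a = ...      (case 0)
+        //   WHEN MATCHED THEN UPDATE SET b = ..., c = ...     (case 1)
+        // yields {0 -> [handle(a)], 1 -> [handle(b), handle(c)]}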
+        public Multimap<Integer, ColumnHandle> getUpdateCaseColumnHandles()
+        {
+            return updateCaseColumnHandles;
+        }
+
         public Set<ColumnHandle> getNonNullableColumnHandles()
         {
             return nonNullableColumnHandles;
diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java
index 969e8d1a8487..392996e0b46b 100644
--- a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java
+++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java
@@ -851,7 +851,7 @@ protected Scope visitDelete(Delete node, Optional<Scope> scope)
             analyzeCheckConstraints(table, tableName, accessControlScope, tableSchema.tableSchema().getCheckConstraints());
             analysis.registerTable(table, Optional.of(handle), tableName, session.getIdentity().getUser(), accessControlScope, Optional.empty());
 
-            createMergeAnalysis(table, handle, tableSchema, tableScope, tableScope, ImmutableList.of());
+            createMergeAnalysis(table, handle, tableSchema, tableScope, tableScope, ImmutableList.of(), ImmutableMultimap.of());
 
             return createAndAssignScope(node, scope, Field.newUnqualified("rows", BIGINT));
         }
@@ -3482,7 +3482,10 @@ protected Scope visitUpdate(Update update, Optional<Scope> scope)
                             sourceColumnsByColumnName.getOrDefault(column.getName(), ImmutableSet.of())))
                     .collect(toImmutableList())));
 
-            createMergeAnalysis(table, handle, tableSchema, tableScope, tableScope, ImmutableList.of(updatedColumnHandles));
+            ImmutableMultimap.Builder<Integer, ColumnHandle> updateCaseColumnsBuilder = ImmutableMultimap.builder();
+            // UPDATE has a single update case, whose case number is always 0
+            updatedColumnHandles.forEach(columnHandle -> updateCaseColumnsBuilder.put(0, columnHandle));
+            createMergeAnalysis(table, handle, tableSchema, tableScope, tableScope, ImmutableList.of(updatedColumnHandles), updateCaseColumnsBuilder.build());
 
             return createAndAssignScope(update, scope, Field.newUnqualified("rows", BIGINT));
         }
@@ -3645,12 +3648,32 @@ else if (operation instanceof MergeInsert && caseColumnNames.isEmpty()) {
             analysis.setUpdateTarget(targetTableHandle.catalogHandle().getVersion(), tableName, Optional.of(table), Optional.of(updatedColumns));
 
             List<List<ColumnHandle>> mergeCaseColumnHandles = buildCaseColumnLists(merge, dataColumnSchemas, allColumnHandles);
-            createMergeAnalysis(table, targetTableHandle, tableSchema, targetTableScope, joinScope, mergeCaseColumnHandles);
+            checkArgument(
+                    mergeCaseColumnHandles.size() == merge.getMergeCases().size(),
+                    "mergeCaseColumnHandles size (%s) does not match merge cases size (%s)", mergeCaseColumnHandles.size(), merge.getMergeCases().size());
+            ImmutableMultimap.Builder<Integer, ColumnHandle> updateCaseColumnHandles = ImmutableMultimap.builder();
+            for (int caseCounter = 0; caseCounter < merge.getMergeCases().size(); caseCounter++) {
+                MergeCase mergeCase = merge.getMergeCases().get(caseCounter);
+                if (mergeCase instanceof MergeUpdate) {
+                    for (ColumnHandle columnHandle : mergeCaseColumnHandles.get(caseCounter)) {
+                        updateCaseColumnHandles.put(caseCounter, columnHandle);
+                    }
+                }
+            }
+
+            createMergeAnalysis(table, targetTableHandle, tableSchema, targetTableScope, joinScope, mergeCaseColumnHandles, updateCaseColumnHandles.build());
 
             return createAndAssignScope(merge, Optional.empty(), Field.newUnqualified("rows", BIGINT));
         }
 
-        private void createMergeAnalysis(Table table, TableHandle handle, TableSchema tableSchema, Scope tableScope, Scope joinScope, List<List<ColumnHandle>> updatedColumns)
+        private void createMergeAnalysis(
+                Table table,
+                TableHandle handle,
+                TableSchema tableSchema,
+                Scope tableScope,
+                Scope joinScope,
+ List> mergeCaseColumns, + Multimap updateCaseColumns) { Optional updateLayout = metadata.getUpdateLayout(session, handle); Map allColumnHandles = metadata.getColumnHandles(session, handle); @@ -3713,7 +3736,8 @@ private void createMergeAnalysis(Table table, TableHandle handle, TableSchema ta dataColumnSchemas, dataColumnHandles, redistributionColumnHandles, - updatedColumns, + mergeCaseColumns, + updateCaseColumns, nonNullableColumnHandles, columnHandleFieldNumbers, mergeRowType, diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java index d5da1becd9a8..8f0ac1a72192 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java @@ -553,6 +553,8 @@ public PlanNode plan(Delete node) List columnSymbols = columnSymbolsBuilder.build(); Symbol operationSymbol = symbolAllocator.newSymbol("operation", TINYINT); assignmentsBuilder.put(operationSymbol, new Constant(TINYINT, (long) DELETE_OPERATION_NUMBER)); + Symbol caseNumberSymbol = symbolAllocator.newSymbol("case_number", INTEGER); + assignmentsBuilder.put(caseNumberSymbol, new Constant(INTEGER, 0L)); Symbol projectedRowIdSymbol = symbolAllocator.newSymbol(rowIdSymbol.name(), rowIdType); assignmentsBuilder.put(projectedRowIdSymbol, rowIdSymbol.toSymbolReference()); assignmentsBuilder.put(symbolAllocator.newSymbol("insert_from_update", TINYINT), new Constant(TINYINT, 0L)); @@ -575,7 +577,8 @@ public PlanNode plan(Delete node) Optional.empty(), tableMetadata.table(), paradigmAndTypes, - findSourceTableHandles(projectNode)), + findSourceTableHandles(projectNode), + ImmutableListMultimap.of()), projectNode.getOutputSymbols(), partitioningScheme, outputs); @@ -943,7 +946,8 @@ private MergeWriterNode createMergePipeline(Table table, RelationPlan relationPl Optional.empty(), metadata.getTableName(session, handle).getSchemaTableName(), mergeParadigmAndTypes, - findSourceTableHandles(planNode)); + findSourceTableHandles(planNode), + mergeAnalysis.getUpdateCaseColumnHandles()); ImmutableList.Builder columnSymbolsBuilder = ImmutableList.builder(); for (ColumnHandle columnHandle : mergeAnalysis.getDataColumnHandles()) { @@ -958,11 +962,13 @@ private MergeWriterNode createMergePipeline(Table table, RelationPlan relationPl } Symbol operationSymbol = symbolAllocator.newSymbol("operation", TINYINT); + Symbol caseNumberSymbol = symbolAllocator.newSymbol("case_number", INTEGER); Symbol insertFromUpdateSymbol = symbolAllocator.newSymbol("insert_from_update", TINYINT); List projectedSymbols = ImmutableList.builder() .addAll(columnSymbols) .add(operationSymbol) + .add(caseNumberSymbol) .add(rowIdSymbol) .add(insertFromUpdateSymbol) .build(); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/BeginTableWrite.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/BeginTableWrite.java index 6f67734e479e..dc66f381e5c8 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/BeginTableWrite.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/BeginTableWrite.java @@ -212,7 +212,8 @@ public WriterTarget getWriterTarget(PlanNode node) mergeTarget.getMergeHandle(), mergeTarget.getSchemaTableName(), mergeTarget.getMergeParadigmAndTypes(), - findSourceTableHandles(node)); + findSourceTableHandles(node), + mergeTarget.getUpdateCaseColumnHandles()); } if (node instanceof ExchangeNode || node instanceof 
UnionNode) { @@ -247,13 +248,14 @@ private WriterTarget createWriterTarget(WriterTarget target, PlanNode planNode) findSourceTableHandles(planNode)); } if (target instanceof MergeTarget merge) { - MergeHandle mergeHandle = metadata.beginMerge(session, merge.getHandle()); + MergeHandle mergeHandle = metadata.beginMerge(session, merge.getHandle(), merge.getUpdateCaseColumnHandles()); return new MergeTarget( mergeHandle.tableHandle(), Optional.of(mergeHandle), merge.getSchemaTableName(), merge.getMergeParadigmAndTypes(), - findSourceTableHandles(planNode)); + findSourceTableHandles(planNode), + merge.getUpdateCaseColumnHandles()); } if (target instanceof TableWriterNode.RefreshMaterializedViewReference refreshMV) { return new TableWriterNode.RefreshMaterializedViewTarget( diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableWriterNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableWriterNode.java index ee25cf5f12f2..b9e7941c392f 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableWriterNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableWriterNode.java @@ -19,6 +19,7 @@ import com.fasterxml.jackson.annotation.JsonTypeInfo; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; +import com.google.common.collect.Multimap; import com.google.errorprone.annotations.Immutable; import io.trino.Session; import io.trino.metadata.InsertTableHandle; @@ -731,6 +732,7 @@ public static class MergeTarget private final SchemaTableName schemaTableName; private final MergeParadigmAndTypes mergeParadigmAndTypes; private final List sourceTableHandles; + private final Multimap updateCaseColumnHandles; @JsonCreator public MergeTarget( @@ -738,13 +740,15 @@ public MergeTarget( @JsonProperty("mergeHandle") Optional mergeHandle, @JsonProperty("schemaTableName") SchemaTableName schemaTableName, @JsonProperty("mergeParadigmAndTypes") MergeParadigmAndTypes mergeParadigmAndTypes, - @JsonProperty("sourceTableHandles") List sourceTableHandles) + @JsonProperty("sourceTableHandles") List sourceTableHandles, + @JsonProperty("updateCaseColumnHandles") Multimap updateCaseColumnHandles) { this.handle = requireNonNull(handle, "handle is null"); this.mergeHandle = requireNonNull(mergeHandle, "mergeHandle is null"); this.schemaTableName = requireNonNull(schemaTableName, "schemaTableName is null"); this.mergeParadigmAndTypes = requireNonNull(mergeParadigmAndTypes, "mergeElements is null"); this.sourceTableHandles = ImmutableList.copyOf(requireNonNull(sourceTableHandles, "sourceTableHandles is null")); + this.updateCaseColumnHandles = requireNonNull(updateCaseColumnHandles, "updateCaseColumnHandles is null"); } @JsonProperty @@ -800,6 +804,12 @@ public List getSourceTableHandles() { return sourceTableHandles; } + + @JsonProperty + public Multimap getUpdateCaseColumnHandles() + { + return updateCaseColumnHandles; + } } public static class MergeParadigmAndTypes diff --git a/core/trino-main/src/main/java/io/trino/tracing/TracingConnectorMetadata.java b/core/trino-main/src/main/java/io/trino/tracing/TracingConnectorMetadata.java index db53d565e6df..38920c16d8bb 100644 --- a/core/trino-main/src/main/java/io/trino/tracing/TracingConnectorMetadata.java +++ b/core/trino-main/src/main/java/io/trino/tracing/TracingConnectorMetadata.java @@ -770,6 +770,15 @@ public ConnectorMergeTableHandle beginMerge(ConnectorSession session, ConnectorT } } + @Override + public ConnectorMergeTableHandle beginMerge(ConnectorSession 
session, ConnectorTableHandle tableHandle, Map> updateCaseColumns, RetryMode retryMode) + { + Span span = startSpan("beginMerge", tableHandle); + try (var _ = scopedSpan(span)) { + return delegate.beginMerge(session, tableHandle, updateCaseColumns, retryMode); + } + } + @Override public void finishMerge(ConnectorSession session, ConnectorMergeTableHandle tableHandle, List sourceTableHandles, Collection fragments, Collection computedStatistics) { diff --git a/core/trino-main/src/main/java/io/trino/tracing/TracingMetadata.java b/core/trino-main/src/main/java/io/trino/tracing/TracingMetadata.java index 2441cee43568..4e4b58f4ca1a 100644 --- a/core/trino-main/src/main/java/io/trino/tracing/TracingMetadata.java +++ b/core/trino-main/src/main/java/io/trino/tracing/TracingMetadata.java @@ -14,6 +14,7 @@ package io.trino.tracing; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.Multimap; import com.google.common.util.concurrent.ListenableFuture; import com.google.inject.Inject; import io.airlift.slice.Slice; @@ -822,11 +823,11 @@ public Optional getUpdateLayout(Session session, TableHandle } @Override - public MergeHandle beginMerge(Session session, TableHandle tableHandle) + public MergeHandle beginMerge(Session session, TableHandle tableHandle, Multimap updateCaseColumns) { Span span = startSpan("beginMerge", tableHandle); try (var _ = scopedSpan(span)) { - return delegate.beginMerge(session, tableHandle); + return delegate.beginMerge(session, tableHandle, updateCaseColumns); } } diff --git a/core/trino-main/src/test/java/io/trino/metadata/AbstractMockMetadata.java b/core/trino-main/src/test/java/io/trino/metadata/AbstractMockMetadata.java index 6955c5dfccbf..c1952b4090da 100644 --- a/core/trino-main/src/test/java/io/trino/metadata/AbstractMockMetadata.java +++ b/core/trino-main/src/test/java/io/trino/metadata/AbstractMockMetadata.java @@ -17,6 +17,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Multimap; import com.google.common.util.concurrent.ListenableFuture; import io.airlift.slice.Slice; import io.trino.Session; @@ -549,7 +550,7 @@ public Optional getUpdateLayout(Session session, TableHandle } @Override - public MergeHandle beginMerge(Session session, TableHandle tableHandle) + public MergeHandle beginMerge(Session session, TableHandle tableHandle, Multimap updateCaseColumnHandles) { throw new UnsupportedOperationException(); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestDeleteAndInsertMergeProcessor.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestDeleteAndInsertMergeProcessor.java index ce28fa59dcc2..4c853965a762 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestDeleteAndInsertMergeProcessor.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestDeleteAndInsertMergeProcessor.java @@ -80,7 +80,7 @@ public void testSimpleDeletedRowMerge() assertThat((int) TINYINT.getByte(outputPage.getBlock(3), 0)).isEqualTo(DELETE_OPERATION_NUMBER); // Show that the row to be deleted is rowId 0, e.g. 
('Dave', 11, 'Devon')
-        SqlRow rowIdRow = ((RowBlock) outputPage.getBlock(4)).getRow(0);
+        SqlRow rowIdRow = ((RowBlock) outputPage.getBlock(5)).getRow(0);
         assertThat(BIGINT.getLong(rowIdRow.getRawFieldBlock(1), rowIdRow.getRawIndex())).isEqualTo(0);
     }
 
@@ -123,7 +123,7 @@ public void testUpdateAndDeletedMerge()
         Page outputPage = processor.transformPage(inputPage);
         assertThat(outputPage.getPositionCount()).isEqualTo(8);
 
-        RowBlock rowIdBlock = (RowBlock) outputPage.getBlock(4);
+        RowBlock rowIdBlock = (RowBlock) outputPage.getBlock(5);
         assertThat(rowIdBlock.getPositionCount()).isEqualTo(8);
 
         // Show that the first row has address "Arches"
         assertThat(getString(outputPage.getBlock(2), 1)).isEqualTo("Arches/Arches");
@@ -163,7 +163,7 @@ public void testAnotherMergeCase()
         Page outputPage = processor.transformPage(inputPage);
         assertThat(outputPage.getPositionCount()).isEqualTo(8);
 
-        RowBlock rowIdBlock = (RowBlock) outputPage.getBlock(4);
+        RowBlock rowIdBlock = (RowBlock) outputPage.getBlock(5);
         assertThat(rowIdBlock.getPositionCount()).isEqualTo(8);
 
         // Show that the first row has address "Arches/Arches"
         assertThat(getString(outputPage.getBlock(2), 1)).isEqualTo("Arches/Arches");
diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java
index 73d17730ff34..27204bba84a0 100644
--- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java
+++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java
@@ -15,6 +15,7 @@
 
 import com.google.common.base.Functions;
 import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableListMultimap;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.ListMultimap;
@@ -787,7 +788,8 @@ public MergeTarget mergeTarget(SchemaTableName schemaTableName, MergeParadigmAndTypes mergeParadigmAndTypes)
                 Optional.empty(),
                 schemaTableName,
                 mergeParadigmAndTypes,
-                List.of());
+                List.of(),
+                ImmutableListMultimap.of());
     }
 
     public ExchangeNode gatheringExchange(ExchangeNode.Scope scope, PlanNode child)
diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMetadata.java b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMetadata.java
index f83fe90b67ee..1e27c3807ba5 100644
--- a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMetadata.java
+++ b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMetadata.java
@@ -878,12 +878,24 @@ default Optional<ConnectorTableLayout> getUpdateLayout(ConnectorSession session)
     /**
      * Do whatever is necessary to start a MERGE query, returning the {@link ConnectorMergeTableHandle}
      * instance that will be passed to the PageSink, and to the {@link #finishMerge} method.
+     *
+     * @deprecated Use {@link #beginMerge(ConnectorSession, ConnectorTableHandle, Map, RetryMode)} instead
      */
+    @Deprecated
     default ConnectorMergeTableHandle beginMerge(ConnectorSession session, ConnectorTableHandle tableHandle, RetryMode retryMode)
     {
         throw new TrinoException(NOT_SUPPORTED, MODIFYING_ROWS_MESSAGE);
     }
 
+    /**
+     * Do whatever is necessary to start a MERGE query, returning the {@link ConnectorMergeTableHandle}
+     * instance that will be passed to the PageSink, and to the {@link #finishMerge} method.
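+     * <p>
+     * For illustration: given {@code MERGE ... WHEN MATCHED AND ... THEN UPDATE SET a = ...
+     * WHEN MATCHED THEN UPDATE SET b = ..., c = ...}, {@code updateCaseColumns} maps case 0
+     * to the handle of {@code a} and case 1 to the handles of {@code b} and {@code c}.
+     * Connectors that support partial updates can use this mapping to rewrite only the assigned columns.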
+ */ + default ConnectorMergeTableHandle beginMerge(ConnectorSession session, ConnectorTableHandle tableHandle, Map> updateCaseColumns, RetryMode retryMode) + { + return beginMerge(session, tableHandle, retryMode); + } + /** * Finish a merge query * diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/MergePage.java b/core/trino-spi/src/main/java/io/trino/spi/connector/MergePage.java index ad40f85f0a9f..4b9745289708 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/connector/MergePage.java +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/MergePage.java @@ -62,15 +62,15 @@ public static MergePage createDeleteAndInsertPages(Page inputPage, int dataColum { // see page description in ConnectorMergeSink int inputChannelCount = inputPage.getChannelCount(); - if (inputChannelCount != dataColumnCount + 2) { - throw new IllegalArgumentException(format("inputPage channelCount (%s) == dataColumns size (%s) + 2", inputChannelCount, dataColumnCount)); + if (inputChannelCount != dataColumnCount + 3) { + throw new IllegalArgumentException(format("inputPage channelCount (%s) == dataColumns size (%s) + 3", inputChannelCount, dataColumnCount)); } int positionCount = inputPage.getPositionCount(); if (positionCount <= 0) { throw new IllegalArgumentException("positionCount should be > 0, but is " + positionCount); } - Block operationBlock = inputPage.getBlock(inputChannelCount - 2); + Block operationBlock = inputPage.getBlock(dataColumnCount); int[] deletePositions = new int[positionCount]; int[] insertPositions = new int[positionCount]; @@ -99,7 +99,7 @@ public static MergePage createDeleteAndInsertPages(Page inputPage, int dataColum for (int i = 0; i < dataColumnCount; i++) { columns[i] = i; } - columns[dataColumnCount] = dataColumnCount + 1; // row ID channel + columns[dataColumnCount] = dataColumnCount + 2; // row ID channel deletePage = Optional.of(inputPage .getColumns(columns) .getPositions(deletePositions, 0, deletePositionCount)); diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java index 1ccbd5baa1d1..89b29ca6740a 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java @@ -1214,6 +1214,14 @@ public ConnectorMergeTableHandle beginMerge(ConnectorSession session, ConnectorT } } + @Override + public ConnectorMergeTableHandle beginMerge(ConnectorSession session, ConnectorTableHandle tableHandle, Map> updateCaseColumns, RetryMode retryMode) + { + try (ThreadContextClassLoader _ = new ThreadContextClassLoader(classLoader)) { + return delegate.beginMerge(session, tableHandle, updateCaseColumns, retryMode); + } + } + @Override public void finishMerge(ConnectorSession session, ConnectorMergeTableHandle mergeTableHandle, List sourceTableHandles, Collection fragments, Collection computedStatistics) { diff --git a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryMetadata.java b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryMetadata.java index feb08a62f5fb..10a1544eb977 100644 --- a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryMetadata.java +++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryMetadata.java @@ -818,10 +818,10 @@ public 
OptionalLong executeDelete(ConnectorSession session, ConnectorTableHandle } @Override - public ConnectorMergeTableHandle beginMerge(ConnectorSession session, ConnectorTableHandle tableHandle, RetryMode retryMode) + public ConnectorMergeTableHandle beginMerge(ConnectorSession session, ConnectorTableHandle tableHandle, Map> updateCaseColumns, RetryMode retryMode) { // TODO Fix BaseBigQueryFailureRecoveryTest when implementing this method - return ConnectorMetadata.super.beginMerge(session, tableHandle, retryMode); + return ConnectorMetadata.super.beginMerge(session, tableHandle, updateCaseColumns, retryMode); } @Override diff --git a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraMetadata.java b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraMetadata.java index e020a467f333..c287e2ccdba4 100644 --- a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraMetadata.java +++ b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraMetadata.java @@ -471,7 +471,7 @@ private static boolean isHiddenIdColumn(CassandraColumnHandle columnHandle) } @Override - public ConnectorMergeTableHandle beginMerge(ConnectorSession session, ConnectorTableHandle tableHandle, RetryMode retryMode) + public ConnectorMergeTableHandle beginMerge(ConnectorSession session, ConnectorTableHandle tableHandle, Map> updateCaseColumns, RetryMode retryMode) { throw new TrinoException(NOT_SUPPORTED, "Delete without primary key or partition key is not supported"); } diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeSink.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeSink.java index c12b97fe7f48..e8344b7e3950 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeSink.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeSink.java @@ -190,7 +190,7 @@ public DeltaLakeMergeSink( dataColumnsIndices[i] = i; dataAndRowIdColumnsIndices[i] = i; } - dataAndRowIdColumnsIndices[tableColumnCount] = tableColumnCount + 1; // row ID channel + dataAndRowIdColumnsIndices[tableColumnCount] = tableColumnCount + 2; // row ID channel } @Override @@ -252,15 +252,15 @@ private void processDeletion(Page deletions, String cdfOperation) private DeltaLakeMergePage createPages(Page inputPage, int dataColumnCount) { int inputChannelCount = inputPage.getChannelCount(); - if (inputChannelCount != dataColumnCount + 2) { - throw new IllegalArgumentException(format("inputPage channelCount (%s) == dataColumns size (%s) + 2", inputChannelCount, dataColumnCount)); + if (inputChannelCount != dataColumnCount + 3) { + throw new IllegalArgumentException(format("inputPage channelCount (%s) == dataColumns size (%s) + 3", inputChannelCount, dataColumnCount)); } int positionCount = inputPage.getPositionCount(); if (positionCount <= 0) { throw new IllegalArgumentException("positionCount should be > 0, but is " + positionCount); } - Block operationBlock = inputPage.getBlock(inputChannelCount - 2); + Block operationBlock = inputPage.getBlock(inputChannelCount - 3); int[] deletePositions = new int[positionCount]; int[] insertPositions = new int[positionCount]; int[] updateInsertPositions = new int[positionCount]; diff --git a/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/KuduMetadata.java b/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/KuduMetadata.java index d574a9bc9c3a..511a424be7b1 100755 --- 
a/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/KuduMetadata.java
+++ b/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/KuduMetadata.java
@@ -431,7 +431,7 @@ public ColumnHandle getMergeRowIdColumnHandle(ConnectorSession session, ConnectorTableHandle tableHandle)
     }
 
     @Override
-    public ConnectorMergeTableHandle beginMerge(ConnectorSession session, ConnectorTableHandle tableHandle, RetryMode retryMode)
+    public ConnectorMergeTableHandle beginMerge(ConnectorSession session, ConnectorTableHandle tableHandle, Map<Integer, Collection<ColumnHandle>> updateCaseColumns, RetryMode retryMode)
     {
         KuduTableHandle kuduTableHandle = (KuduTableHandle) tableHandle;
         KuduTable table = kuduTableHandle.getTable(clientSession);
diff --git a/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/KuduPageSink.java b/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/KuduPageSink.java
index c802041b74d3..5bb1fb342d32 100644
--- a/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/KuduPageSink.java
+++ b/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/KuduPageSink.java
@@ -207,11 +207,11 @@ else if (type instanceof DecimalType decimalType) {
     @Override
     public void storeMergedRows(Page page)
     {
-        // The last channel in the page is the rowId block, the next-to-last is the operation block
+        // The page layout is [data columns] [operation] [case number] [row id]: the last channel
+        // is the row id block, the next-to-last the case number block, and before that the operation block
         int columnCount = originalColumnTypes.size();
-        checkArgument(page.getChannelCount() == 2 + columnCount, "The page size should be 2 + columnCount (%s), but is %s", columnCount, page.getChannelCount());
+        checkArgument(page.getChannelCount() == 3 + columnCount, "The page size should be 3 + columnCount (%s), but is %s", columnCount, page.getChannelCount());
         Block operationBlock = page.getBlock(columnCount);
-        Block rowIds = page.getBlock(columnCount + 1);
+        Block rowIds = page.getBlock(columnCount + 2);
 
         Schema schema = table.getSchema();
         try (KuduOperationApplier operationApplier = KuduOperationApplier.fromKuduClientSession(session)) {
diff --git a/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixMergeSink.java b/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixMergeSink.java
index 3361c1b5f99a..933bee906e2f 100644
--- a/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixMergeSink.java
+++ b/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixMergeSink.java
@@ -157,7 +157,7 @@ private static SinkSqlProvider deleteSinkProvider(
     @Override
     public void storeMergedRows(Page page)
     {
-        checkArgument(page.getChannelCount() == 2 + columnCount, "The page size should be 2 + columnCount (%s), but is %s", columnCount, page.getChannelCount());
+        checkArgument(page.getChannelCount() == 3 + columnCount, "The page size should be 3 + columnCount (%s), but is %s", columnCount, page.getChannelCount());
 
         int positionCount = page.getPositionCount();
         Block operationBlock = page.getBlock(columnCount);
@@ -194,7 +194,7 @@ public void storeMergedRows(Page page)
             insertSink.appendPage(dataPage.getPositions(insertPositions, 0, insertPositionCount));
         }
 
-        List<Block> rowIdFields = RowBlock.getRowFieldsFromBlock(page.getBlock(columnCount + 1));
+        List<Block> rowIdFields = RowBlock.getRowFieldsFromBlock(page.getBlock(columnCount + 2));
         if (deletePositionCount > 0) {
             Block[] deleteBlocks = new Block[rowIdFields.size()];
             for (int field = 0; field < rowIdFields.size(); field++) {
diff --git a/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixMetadata.java 
b/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixMetadata.java index 8acdc1ca1a62..7bb6d700f78c 100644 --- a/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixMetadata.java +++ b/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixMetadata.java @@ -330,7 +330,7 @@ public JdbcColumnHandle getMergeRowIdColumnHandle(ConnectorSession session, Conn } @Override - public ConnectorMergeTableHandle beginMerge(ConnectorSession session, ConnectorTableHandle tableHandle, RetryMode retryMode) + public ConnectorMergeTableHandle beginMerge(ConnectorSession session, ConnectorTableHandle tableHandle, Map> updateColumnHandles, RetryMode retryMode) { JdbcTableHandle handle = (JdbcTableHandle) tableHandle; checkArgument(handle.isNamedRelation(), "Merge target must be named relation table"); From 8ef1bf07611bcdd7b4e50493ddb6cd1c91d113bd Mon Sep 17 00:00:00 2001 From: chenjian2664 Date: Wed, 13 Nov 2024 19:31:48 +0800 Subject: [PATCH 2/2] Support partial update in Phoenix connector --- .../plugin/phoenix5/PhoenixMergeSink.java | 87 +++++- .../phoenix5/PhoenixMergeTableHandle.java | 17 +- .../plugin/phoenix5/PhoenixMetadata.java | 15 +- .../phoenix5/TestPhoenixConnectorTest.java | 294 +++++++++++++++++- .../plugin/phoenix5/TestingPhoenixServer.java | 8 + 5 files changed, 394 insertions(+), 27 deletions(-) diff --git a/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixMergeSink.java b/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixMergeSink.java index 933bee906e2f..6e6b6e214b5b 100644 --- a/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixMergeSink.java +++ b/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixMergeSink.java @@ -13,7 +13,9 @@ */ package io.trino.plugin.phoenix5; +import com.google.common.base.Suppliers; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import io.airlift.slice.Slice; import io.trino.plugin.jdbc.JdbcClient; import io.trino.plugin.jdbc.JdbcOutputTableHandle; @@ -32,19 +34,27 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.type.RowType; import io.trino.spi.type.Type; +import org.apache.phoenix.util.SchemaUtil; import java.sql.Connection; import java.sql.SQLException; import java.util.Collection; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.concurrent.CompletableFuture; +import java.util.function.Supplier; import java.util.stream.IntStream; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableSet.toImmutableSet; import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_ERROR; import static io.trino.plugin.phoenix5.PhoenixClient.ROWKEY; import static io.trino.plugin.phoenix5.PhoenixClient.ROWKEY_COLUMN_HANDLE; +import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.TinyintType.TINYINT; import static java.util.concurrent.CompletableFuture.completedFuture; import static org.apache.phoenix.util.SchemaUtil.getEscapedArgument; @@ -56,9 +66,11 @@ public class PhoenixMergeSink private final int columnCount; private final ConnectorPageSink insertSink; - private final ConnectorPageSink updateSink; + private final Map> updateSinkSuppliers; private final ConnectorPageSink deleteSink; + private final Map> updateCaseChannels; + public PhoenixMergeSink( 
            ConnectorSession session,
             ConnectorMergeTableHandle mergeHandle,
@@ -73,7 +85,6 @@ public PhoenixMergeSink(
         this.columnCount = phoenixOutputTableHandle.getColumnNames().size();
 
         this.insertSink = new JdbcPageSink(session, phoenixOutputTableHandle, phoenixClient, pageSinkId, remoteQueryModifier, JdbcClient::buildInsertSql);
-        this.updateSink = createUpdateSink(session, phoenixOutputTableHandle, phoenixClient, pageSinkId, remoteQueryModifier);
 
         ImmutableList.Builder<String> mergeRowIdFieldNamesBuilder = ImmutableList.builder();
         ImmutableList.Builder<Type> mergeRowIdFieldTypesBuilder = ImmutableList.builder();
@@ -84,6 +95,31 @@ public PhoenixMergeSink(
             mergeRowIdFieldTypesBuilder.add(field.getType());
         }
         List<String> mergeRowIdFieldNames = mergeRowIdFieldNamesBuilder.build();
+        List<String> dataColumnNames = phoenixOutputTableHandle.getColumnNames().stream()
+                .map(SchemaUtil::getEscapedArgument)
+                .collect(toImmutableList());
+        Set<Integer> mergeRowIdChannels = mergeRowIdFieldNames.stream()
+                .map(dataColumnNames::indexOf)
+                .collect(toImmutableSet());
+
+        // Channels written by each update case; without an explicit row key the primary-key
+        // channels must be carried along so the update sink can address the row
+        Map<Integer, Set<Integer>> updateCaseChannels = new HashMap<>();
+        for (Map.Entry<Integer, Set<Integer>> entry : phoenixMergeTableHandle.updateCaseColumns().entrySet()) {
+            Set<Integer> channels = entry.getValue();
+            if (!hasRowKey) {
+                checkArgument(!mergeRowIdChannels.isEmpty() && !mergeRowIdChannels.contains(-1), "No primary keys found");
+                // the sets coming from the handle may be immutable, so concatenate instead of mutating them
+                channels = IntStream.concat(
+                                channels.stream().mapToInt(Integer::intValue),
+                                mergeRowIdChannels.stream().mapToInt(Integer::intValue))
+                        .boxed()
+                        .collect(toImmutableSet());
+            }
+            updateCaseChannels.put(entry.getKey(), channels);
+        }
+        this.updateCaseChannels = ImmutableMap.copyOf(updateCaseChannels);
+
+        ImmutableMap.Builder<Integer, Supplier<ConnectorPageSink>> updateSinksBuilder = ImmutableMap.builder();
+        for (Map.Entry<Integer, Set<Integer>> entry : this.updateCaseChannels.entrySet()) {
+            int caseNumber = entry.getKey();
+            Supplier<ConnectorPageSink> updateSupplier = Suppliers.memoize(() -> createUpdateSink(session, phoenixOutputTableHandle, phoenixClient, pageSinkId, remoteQueryModifier, entry.getValue()));
+            updateSinksBuilder.put(caseNumber, updateSupplier);
+        }
+        this.updateSinkSuppliers = updateSinksBuilder.buildOrThrow();
+
         this.deleteSink = createDeleteSink(session, mergeRowIdFieldTypesBuilder.build(), phoenixClient, phoenixMergeTableHandle, mergeRowIdFieldNames, pageSinkId, remoteQueryModifier, queryBuilder);
     }
 
@@ -92,12 +128,17 @@ private static ConnectorPageSink createUpdateSink(
             PhoenixOutputTableHandle phoenixOutputTableHandle,
             PhoenixClient phoenixClient,
             ConnectorPageSinkId pageSinkId,
-            RemoteQueryModifier remoteQueryModifier)
+            RemoteQueryModifier remoteQueryModifier,
+            Set<Integer> updateChannels)
     {
         ImmutableList.Builder<String> columnNamesBuilder = ImmutableList.builder();
         ImmutableList.Builder<Type> columnTypesBuilder = ImmutableList.builder();
-        columnNamesBuilder.addAll(phoenixOutputTableHandle.getColumnNames());
-        columnTypesBuilder.addAll(phoenixOutputTableHandle.getColumnTypes());
+        // Only the columns assigned by this update case are written back
+        for (int channel = 0; channel < phoenixOutputTableHandle.getColumnNames().size(); channel++) {
+            if (updateChannels.contains(channel)) {
+                columnNamesBuilder.add(phoenixOutputTableHandle.getColumnNames().get(channel));
+                columnTypesBuilder.add(phoenixOutputTableHandle.getColumnTypes().get(channel));
+            }
+        }
         if (phoenixOutputTableHandle.rowkeyColumn().isPresent()) {
             columnNamesBuilder.add(ROWKEY);
             columnTypesBuilder.add(ROWKEY_COLUMN_HANDLE.getColumnType());
@@ -168,8 +209,10 @@ public void storeMergedRows(Page page)
         int insertPositionCount = 0;
         int[] deletePositions = new int[positionCount];
         int deletePositionCount = 0;
-        int[] updatePositions = new int[positionCount];
-        int updatePositionCount = 0;
+
+        Block updateCaseBlock = page.getBlock(columnCount + 1);
+        Map<Integer, int[]> updatePositions = new HashMap<>();
+        Map<Integer, Integer> 
updatePositionCounts = new HashMap<>(); for (int position = 0; position < positionCount; position++) { int operation = TINYINT.getByte(operationBlock, position); @@ -183,8 +226,10 @@ public void storeMergedRows(Page page) deletePositionCount++; } case UPDATE_OPERATION_NUMBER -> { - updatePositions[updatePositionCount] = position; - updatePositionCount++; + int caseNumber = INTEGER.getInt(updateCaseBlock, position); + int updatePositionCount = updatePositionCounts.getOrDefault(caseNumber, 0); + updatePositions.computeIfAbsent(caseNumber, _ -> new int[positionCount])[updatePositionCount] = position; + updatePositionCounts.put(caseNumber, updatePositionCount + 1); } default -> throw new IllegalStateException("Unexpected value: " + operation); } @@ -203,13 +248,21 @@ public void storeMergedRows(Page page) deleteSink.appendPage(new Page(deletePositionCount, deleteBlocks)); } - if (updatePositionCount > 0) { - Page updatePage = dataPage.getPositions(updatePositions, 0, updatePositionCount); - if (hasRowKey) { - updatePage = updatePage.appendColumn(rowIdFields.get(0).getPositions(updatePositions, 0, updatePositionCount)); - } + for (Map.Entry entry : updatePositionCounts.entrySet()) { + int caseNumber = entry.getKey(); + int updatePositionCount = entry.getValue(); + if (updatePositionCount > 0) { + checkArgument(updatePositions.containsKey(caseNumber), "Unexpected case number %s", caseNumber); - updateSink.appendPage(updatePage); + Page updatePage = dataPage + .getColumns(updateCaseChannels.get(caseNumber).stream().mapToInt(Integer::intValue).sorted().toArray()) + .getPositions(updatePositions.get(caseNumber), 0, updatePositionCount); + if (hasRowKey) { + updatePage = updatePage.appendColumn(rowIdFields.get(0).getPositions(updatePositions.get(caseNumber), 0, updatePositionCount)); + } + + updateSinkSuppliers.get(caseNumber).get().appendPage(updatePage); + } } } @@ -218,7 +271,7 @@ public CompletableFuture> finish() { insertSink.finish(); deleteSink.finish(); - updateSink.finish(); + updateSinkSuppliers.values().stream().map(Supplier::get).forEach(ConnectorPageSink::finish); return completedFuture(ImmutableList.of()); } @@ -227,6 +280,6 @@ public void abort() { insertSink.abort(); deleteSink.abort(); - updateSink.abort(); + updateSinkSuppliers.values().stream().map(Supplier::get).forEach(ConnectorPageSink::abort); } } diff --git a/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixMergeTableHandle.java b/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixMergeTableHandle.java index 4959111e1e40..f0d5c6d256c0 100644 --- a/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixMergeTableHandle.java +++ b/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixMergeTableHandle.java @@ -21,13 +21,17 @@ import io.trino.spi.connector.ConnectorMergeTableHandle; import io.trino.spi.predicate.TupleDomain; +import java.util.Map; +import java.util.Set; + import static java.util.Objects.requireNonNull; public record PhoenixMergeTableHandle( JdbcTableHandle tableHandle, PhoenixOutputTableHandle phoenixOutputTableHandle, JdbcColumnHandle mergeRowIdColumnHandle, - TupleDomain primaryKeysDomain) + TupleDomain primaryKeysDomain, + Map> updateCaseColumns) implements ConnectorMergeTableHandle { @JsonCreator @@ -35,12 +39,14 @@ public PhoenixMergeTableHandle( @JsonProperty("tableHandle") JdbcTableHandle tableHandle, @JsonProperty("phoenixOutputTableHandle") PhoenixOutputTableHandle phoenixOutputTableHandle, @JsonProperty("mergeRowIdColumnHandle") 
JdbcColumnHandle mergeRowIdColumnHandle, - @JsonProperty("primaryKeysDomain") TupleDomain primaryKeysDomain) + @JsonProperty("primaryKeysDomain") TupleDomain primaryKeysDomain, + @JsonProperty("updateCaseColumns") Map> updateCaseColumns) { this.tableHandle = requireNonNull(tableHandle, "tableHandle is null"); this.phoenixOutputTableHandle = requireNonNull(phoenixOutputTableHandle, "phoenixOutputTableHandle is null"); this.mergeRowIdColumnHandle = requireNonNull(mergeRowIdColumnHandle, "mergeRowIdColumnHandle is null"); this.primaryKeysDomain = requireNonNull(primaryKeysDomain, "primaryKeysDomain is null"); + this.updateCaseColumns = requireNonNull(updateCaseColumns, "updateCaseColumns is null"); } @JsonProperty @@ -70,4 +76,11 @@ public TupleDomain primaryKeysDomain() { return primaryKeysDomain; } + + @Override + @JsonProperty + public Map> updateCaseColumns() + { + return updateCaseColumns; + } } diff --git a/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixMetadata.java b/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixMetadata.java index 7bb6d700f78c..f3fbb9d01d60 100644 --- a/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixMetadata.java +++ b/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixMetadata.java @@ -66,6 +66,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableSet.toImmutableSet; import static io.trino.plugin.jdbc.JdbcMetadata.getColumns; import static io.trino.plugin.phoenix5.MetadataUtil.getEscapedTableName; import static io.trino.plugin.phoenix5.MetadataUtil.toTrinoSchemaName; @@ -350,11 +351,23 @@ public ConnectorMergeTableHandle beginMerge(ConnectorSession session, ConnectorT primaryKeysDomainBuilder.put(columnHandle, dummy); } + ImmutableMap.Builder> updateColumnChannelsBuilder = ImmutableMap.builder(); + for (Map.Entry> entry : updateColumnHandles.entrySet()) { + int caseNumber = entry.getKey(); + Set updateColumnChannels = entry.getValue().stream() + .map(JdbcColumnHandle.class::cast) + .peek(column -> checkArgument(columns.contains(column), "update column %s not found in the target table", column)) + .map(columns::indexOf) + .collect(toImmutableSet()); + updateColumnChannelsBuilder.put(caseNumber, updateColumnChannels); + } + return new PhoenixMergeTableHandle( phoenixClient.updatedScanColumnTable(session, handle, handle.getColumns(), mergeRowIdColumnHandle), phoenixOutputTableHandle, mergeRowIdColumnHandle, - TupleDomain.withColumnDomains(primaryKeysDomainBuilder.buildOrThrow())); + TupleDomain.withColumnDomains(primaryKeysDomainBuilder.buildOrThrow()), + updateColumnChannelsBuilder.buildOrThrow()); } @Override diff --git a/plugin/trino-phoenix5/src/test/java/io/trino/plugin/phoenix5/TestPhoenixConnectorTest.java b/plugin/trino-phoenix5/src/test/java/io/trino/plugin/phoenix5/TestPhoenixConnectorTest.java index 892170623014..31ee9ff5a4f4 100644 --- a/plugin/trino-phoenix5/src/test/java/io/trino/plugin/phoenix5/TestPhoenixConnectorTest.java +++ b/plugin/trino-phoenix5/src/test/java/io/trino/plugin/phoenix5/TestPhoenixConnectorTest.java @@ -24,12 +24,22 @@ import io.trino.testing.TestingConnectorBehavior; import io.trino.testing.sql.SqlExecutor; import io.trino.testing.sql.TestTable; +import org.apache.hadoop.hbase.Cell; +import org.apache.hadoop.hbase.TableName; +import org.apache.hadoop.hbase.client.Get; 
+import org.apache.hadoop.hbase.client.Result; +import org.apache.hadoop.hbase.client.Table; +import org.apache.hadoop.hbase.util.Bytes; import org.junit.jupiter.api.Test; +import java.io.IOException; import java.sql.Connection; import java.sql.DriverManager; +import java.sql.PreparedStatement; +import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Statement; +import java.util.ArrayList; import java.util.List; import java.util.Optional; import java.util.OptionalInt; @@ -470,13 +480,6 @@ private void testMergeWithSpecifiedRowkeys(String rowkeyDefinition) assertUpdate("DROP TABLE " + targetTable); } - @Test - @Override - public void testUpdateRowConcurrently() - { - abort("Phoenix doesn't support concurrent update of different columns in a row"); - } - @Test public void testSchemaOperations() { @@ -796,6 +799,283 @@ public void testExecuteProcedure() } } + @Test + public void testMergeUpdatePartial() + throws IOException + { + String targetTableName = "test_merge_update_partial_target" + randomNameSuffix(); + String schema = getSession().getSchema().orElseThrow(); + onRemoteDatabase().execute("CREATE TABLE " + schema + "." + targetTableName + " (pk varchar primary key , col_a bigint, col_b bigint, col_c bigint)"); + onRemoteDatabase().execute("UPSERT INTO " + schema + "." + targetTableName + " VALUES ('p1', 2, 3, 4)"); + onRemoteDatabase().execute("UPSERT INTO " + schema + "." + targetTableName + " VALUES ('p2', 3, 5, 7)"); + onRemoteDatabase().execute("UPSERT INTO " + schema + "." + targetTableName + " VALUES ('p3', 4, 6, 7)"); + onRemoteDatabase().execute("UPSERT INTO " + schema + "." + targetTableName + " VALUES ('p4', 5, 8, 7)"); + assertQuery("SELECT * FROM " + targetTableName, "VALUES ('p1', 2, 3, 4), ('p2', 3, 5, 7), ('p3', 4, 6, 7), ('p4', 5, 8, 7)"); + + try (org.apache.hadoop.hbase.client.Connection connection = testingPhoenixServer.getConnection()) { + // record pk=p1,p2 all column modified times + List p1ModifiedTimesBefore = new ArrayList<>(); + List p2ModifiedTimesBefore = new ArrayList<>(); + for (String column : List.of("col_a", "col_b", "col_c")) { + p1ModifiedTimesBefore.add(readLatestColumnVersion(connection, schema, targetTableName, column, "p1")); + p2ModifiedTimesBefore.add(readLatestColumnVersion(connection, schema, targetTableName, column, "p2")); + } + + // update single column in sing row, pk=p1 + assertUpdate(format("MERGE INTO %s t USING (VALUES ('p1', 3, 4, 5)) AS s(pk, a, b, c) " + + " ON t.pk = s.pk " + + " WHEN MATCHED THEN UPDATE SET col_a = s.a " + + " WHEN NOT MATCHED THEN INSERT (pk, col_a, col_b, col_c) VALUES (t.pk, s.a, s.b, s.c)", targetTableName), 1); + + List p1ModifiedTimesAfter = new ArrayList<>(); + List p2ModifiedTimesAfter = new ArrayList<>(); + for (String column : List.of("col_a", "col_b", "col_c")) { + p1ModifiedTimesAfter.add(readLatestColumnVersion(connection, schema, targetTableName, column, "p1")); + p2ModifiedTimesAfter.add(readLatestColumnVersion(connection, schema, targetTableName, column, "p2")); + } + // pk=p1 col_a is updated + assertThat(p1ModifiedTimesAfter.get(0)).isGreaterThan(p1ModifiedTimesBefore.get(0)); + // pk=p2,p3 col_a is not updated + assertThat(p1ModifiedTimesAfter.get(1)).isEqualTo(p1ModifiedTimesBefore.get(1)); + assertThat(p1ModifiedTimesAfter.get(2)).isEqualTo(p1ModifiedTimesBefore.get(2)); + // col_b values are not updated + assertThat(p2ModifiedTimesAfter).isEqualTo(p2ModifiedTimesBefore); + + // record col_a, col_b modified times + List colAModifiedTimesBefore = new ArrayList<>(); + List 
colBModifiedTimesBefore = new ArrayList<>(); + for (String pk : List.of("p1", "p2", "p3", "p4")) { + colAModifiedTimesBefore.add(readLatestColumnVersion(connection, schema, targetTableName, "col_a", pk)); + colBModifiedTimesBefore.add(readLatestColumnVersion(connection, schema, targetTableName, "col_b", pk)); + } + + assertQuery("SELECT * FROM " + targetTableName, "VALUES ('p1', 3, 3, 4), ('p2', 3, 5, 7), ('p3', 4, 6, 7), ('p4', 5, 8, 7)"); + + // update single column(col_a) in multi rows + assertUpdate(format("MERGE INTO %s t USING (VALUES ('p1', 3, 4, 5), ('p2', 4, 4, 4), ('p3', 5, 5, 5)) AS s(pk, a, b, c) " + + " ON t.pk = s.pk " + + " WHEN MATCHED AND t.col_a = 4 THEN UPDATE SET col_a = s.a " + // pk=p3 + " WHEN MATCHED THEN UPDATE SET col_a = s.a + 100 " + // pk=p1, p2 + " WHEN NOT MATCHED THEN INSERT (pk, col_a, col_b, col_c) VALUES (t.pk, s.a, s.b, s.c)", targetTableName), 3); + + List colAModifiedTimesAfter = new ArrayList<>(); + List colBModifiedTimesAfter = new ArrayList<>(); + for (String pk : List.of("p1", "p2", "p3", "p4")) { + colAModifiedTimesAfter.add(readLatestColumnVersion(connection, schema, targetTableName, "col_a", pk)); + colBModifiedTimesAfter.add(readLatestColumnVersion(connection, schema, targetTableName, "col_b", pk)); + } + + // pk=p1,p2,p3 col_a are all modified + assertThat(colAModifiedTimesAfter.get(0)).isGreaterThan(colAModifiedTimesBefore.get(0)); + assertThat(colAModifiedTimesAfter.get(1)).isGreaterThan(colAModifiedTimesBefore.get(1)); + assertThat(colAModifiedTimesAfter.get(2)).isGreaterThan(colAModifiedTimesBefore.get(2)); + // pk=p4 col_a is not modified + assertThat(colAModifiedTimesAfter.get(3)).isEqualTo(colAModifiedTimesBefore.get(3)); + // col_b is not modified + assertThat(colBModifiedTimesAfter).isEqualTo(colBModifiedTimesBefore); + + assertQuery("SELECT * FROM " + targetTableName, "VALUES ('p1', 103, 3, 4), ('p2', 104, 5, 7), ('p3', 5, 6, 7), ('p4', 5, 8, 7)"); + + // using source table to test non-overlapping sets columns update + String sourceTableName = "test_merge_update_partial_source" + randomNameSuffix(); + onRemoteDatabase().execute("CREATE TABLE " + schema + "." + sourceTableName + " (pk varchar primary key , col_a bigint, col_b bigint, col_c bigint)"); + onRemoteDatabase().execute("UPSERT INTO " + schema + "." + sourceTableName + " VALUES ('p1', 1, 1, 1)"); + onRemoteDatabase().execute("UPSERT INTO " + schema + "." + sourceTableName + " VALUES ('p2', 2, 2, 2)"); + onRemoteDatabase().execute("UPSERT INTO " + schema + "." + sourceTableName + " VALUES ('p3', 3, 3, 3)"); + onRemoteDatabase().execute("UPSERT INTO " + schema + "." + sourceTableName + " VALUES ('p4', 4, 4, 4)"); + onRemoteDatabase().execute("UPSERT INTO " + schema + "." 
+
+            // update multiple rows with non-overlapping column sets, plus delete and insert cases;
+            // the updated cells are (p2, col_a), (p3, col_c), (p4, col_c)
+
+            // record the before-timestamps of cells that will be updated
+            long p2ColAModifiedTimeBefore = readLatestColumnVersion(connection, schema, targetTableName, "col_a", "p2");
+            long p3ColCModifiedTimeBefore = readLatestColumnVersion(connection, schema, targetTableName, "col_c", "p3");
+            long p4ColCModifiedTimeBefore = readLatestColumnVersion(connection, schema, targetTableName, "col_c", "p4");
+            // record the before-timestamps of cells that will not be updated
+            long p2ColBModifiedTimeBefore = readLatestColumnVersion(connection, schema, targetTableName, "col_b", "p2");
+            long p2ColCModifiedTimeBefore = readLatestColumnVersion(connection, schema, targetTableName, "col_c", "p2");
+            long p3ColAModifiedTimeBefore = readLatestColumnVersion(connection, schema, targetTableName, "col_a", "p3");
+            long p3ColBModifiedTimeBefore = readLatestColumnVersion(connection, schema, targetTableName, "col_b", "p3");
+
+            assertUpdate(format("MERGE INTO %s t USING %s s " +
+                    " ON t.pk = s.pk " +
+                    " WHEN MATCHED AND mod(t.col_a, 2) = 0 THEN UPDATE SET col_a = s.col_a + 1 " + // pk=p2
+                    " WHEN MATCHED AND t.col_a > 100 THEN DELETE " + // pk=p1
+                    " WHEN MATCHED THEN UPDATE SET col_c = t.col_c + s.col_c " + // pk=p3, p4
+                    " WHEN NOT MATCHED THEN INSERT (pk, col_a, col_b, col_c) VALUES (s.pk, s.col_a, s.col_b, s.col_c)", targetTableName, sourceTableName), 5);
+
+            // check the updated cells (p2, col_a), (p3, col_c), (p4, col_c)
+            assertThat(readLatestColumnVersion(connection, schema, targetTableName, "col_a", "p2")).isGreaterThan(p2ColAModifiedTimeBefore);
+            assertThat(readLatestColumnVersion(connection, schema, targetTableName, "col_c", "p3")).isGreaterThan(p3ColCModifiedTimeBefore);
+            assertThat(readLatestColumnVersion(connection, schema, targetTableName, "col_c", "p4")).isGreaterThan(p4ColCModifiedTimeBefore);
+
+            // check the untouched cells (p2, col_b), (p2, col_c), (p3, col_a), (p3, col_b)
+            assertThat(readLatestColumnVersion(connection, schema, targetTableName, "col_b", "p2")).isEqualTo(p2ColBModifiedTimeBefore);
+            assertThat(readLatestColumnVersion(connection, schema, targetTableName, "col_c", "p2")).isEqualTo(p2ColCModifiedTimeBefore);
+            assertThat(readLatestColumnVersion(connection, schema, targetTableName, "col_a", "p3")).isEqualTo(p3ColAModifiedTimeBefore);
+            assertThat(readLatestColumnVersion(connection, schema, targetTableName, "col_b", "p3")).isEqualTo(p3ColBModifiedTimeBefore);
+
+            assertQuery("SELECT * FROM " + targetTableName, "VALUES ('p2', 3, 5, 7), ('p3', 5, 6, 10), ('p4', 5, 8, 11), ('p5', 5, 5, 5)");
+
+            assertUpdate("DROP TABLE " + sourceTableName);
+        }
+
+        assertUpdate("DROP TABLE " + targetTableName);
+    }
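+
+    // Trino plans a plain UPDATE through the same merge machinery as MERGE (as a single
+    // update case), so the partial-update behavior verified above is expected to carry
+    // over to UPDATE statements as well.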
+ tableName + " VALUES ('p3', 4, 6, 7)"); + onRemoteDatabase().execute("UPSERT INTO " + schema + "." + tableName + " VALUES ('p4', 5, 8, 7)"); + assertQuery("SELECT * FROM " + tableName, "VALUES ('p1', 2, 3, 4), ('p2', 3, 5, 7), ('p3', 4, 6, 7), ('p4', 5, 8, 7)"); + try (org.apache.hadoop.hbase.client.Connection connection = testingPhoenixServer.getConnection()) { + // update single column single row + + // nothing changed + long colAP1ModifiedTimeBefore = readLatestColumnVersion(connection, schema, tableName, "col_a", "p1"); + assertThat(readLatestColumnVersion(connection, schema, tableName, "col_a", "p1")).isEqualTo(colAP1ModifiedTimeBefore); + + long colBP1ModifiedTimeBefore = readLatestColumnVersion(connection, schema, tableName, "col_b", "p1"); + long colCP1ModifiedTimeBefore = readLatestColumnVersion(connection, schema, tableName, "col_c", "p1"); + + // record p2 each column modified time + List p2ModifiedTimesBefore = new ArrayList<>(); + p2ModifiedTimesBefore.add(readLatestColumnVersion(connection, schema, tableName, "col_a", "p2")); + p2ModifiedTimesBefore.add(readLatestColumnVersion(connection, schema, tableName, "col_b", "p2")); + p2ModifiedTimesBefore.add(readLatestColumnVersion(connection, schema, tableName, "col_c", "p2")); + + // update col_a pk=p1 + assertUpdate("UPDATE " + tableName + " SET col_a = -1 WHERE pk = 'p1'", 1); + + // row with pk=p1 only col_a is updated + long colAP1ModifiedTimeAfter = readLatestColumnVersion(connection, schema, tableName, "col_a", "p1"); + assertThat(colAP1ModifiedTimeAfter).isGreaterThan(colAP1ModifiedTimeBefore); + + // row with pk=p1 col_a and col_b are not updated + assertThat(readLatestColumnVersion(connection, schema, tableName, "col_b", "p1")).isEqualTo(colBP1ModifiedTimeBefore); + assertThat(readLatestColumnVersion(connection, schema, tableName, "col_c", "p1")).isEqualTo(colCP1ModifiedTimeBefore); + + // row with pk=p2 nothing changed + List p2ModifiedTimesAfter = new ArrayList<>(); + p2ModifiedTimesAfter.add(readLatestColumnVersion(connection, schema, tableName, "col_a", "p2")); + p2ModifiedTimesAfter.add(readLatestColumnVersion(connection, schema, tableName, "col_b", "p2")); + p2ModifiedTimesAfter.add(readLatestColumnVersion(connection, schema, tableName, "col_c", "p2")); + assertThat(p2ModifiedTimesBefore).isEqualTo(p2ModifiedTimesAfter); + + assertQuery("SELECT * FROM " + tableName, "VALUES ('p1', -1, 3, 4), ('p2', 3, 5, 7), ('p3', 4, 6, 7), ('p4', 5, 8, 7)"); + + // update single column in multi rows + List colAModifiedTimesBefore = new ArrayList<>(); + List colBModifiedTimesBefore = new ArrayList<>(); + List colCModifiedTimesBefore = new ArrayList<>(); + for (String pk : List.of("p1", "p2", "p3", "p4")) { + colAModifiedTimesBefore.add(readLatestColumnVersion(connection, schema, tableName, "col_a", pk)); + colBModifiedTimesBefore.add(readLatestColumnVersion(connection, schema, tableName, "col_b", pk)); + colCModifiedTimesBefore.add(readLatestColumnVersion(connection, schema, tableName, "col_c", pk)); + } + // update all col_b + assertUpdate("UPDATE " + tableName + " SET col_b = col_b + col_c WHERE pk IS NOT NULL AND col_b > 0", 4); + List colAModifiedTimesAfter = new ArrayList<>(); + List colBModifiedTimesAfter = new ArrayList<>(); + List colCModifiedTimesAfter = new ArrayList<>(); + for (String pk : List.of("p1", "p2", "p3", "p4")) { + colAModifiedTimesAfter.add(readLatestColumnVersion(connection, schema, tableName, "col_a", pk)); + colBModifiedTimesAfter.add(readLatestColumnVersion(connection, schema, tableName, "col_b", pk)); 
+    private byte[] getActualQualifier(String tableName, String columnName)
+    {
+        String query = "SELECT COLUMN_QUALIFIER FROM SYSTEM.CATALOG WHERE TABLE_NAME = ? AND COLUMN_NAME = ?";
+        try (Connection connection = DriverManager.getConnection(testingPhoenixServer.getJdbcUrl());
+                PreparedStatement statement = connection.prepareStatement(query)) {
+            statement.setString(1, tableName);
+            statement.setString(2, columnName);
+            try (ResultSet rs = statement.executeQuery()) {
+                if (rs.next()) {
+                    return rs.getBytes("COLUMN_QUALIFIER");
+                }
+            }
+            throw new RuntimeException("Failed to get the actual qualifier of column " + columnName + " in table " + tableName);
+        }
+        catch (SQLException e) {
+            throw new RuntimeException(e);
+        }
+    }
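+
+    // Reads the timestamp of the newest cell version of a single (row, column) pair through
+    // the raw HBase API; "0" is Phoenix's default column family, and the qualifier comes
+    // from the SYSTEM.CATALOG lookup above.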
AND COLUMN_NAME = ?"; + try (Connection connection = DriverManager.getConnection(testingPhoenixServer.getJdbcUrl()); + PreparedStatement statement = connection.prepareStatement(query)) { + statement.setString(1, tableName); + statement.setString(2, columnName); + ResultSet rs = statement.executeQuery(); + if (rs.next()) { + return rs.getBytes("COLUMN_QUALIFIER"); + } + throw new RuntimeException("Failed to get actual qualifier"); + } + catch (SQLException e) { + throw new RuntimeException(e); + } + } + + private long readLatestColumnVersion(org.apache.hadoop.hbase.client.Connection connection, String schema, String tableName, String columnName, String rowkeyValue) + throws IOException + { + tableName = tableName.toUpperCase(ENGLISH); + columnName = columnName.toUpperCase(ENGLISH); + schema = schema.toUpperCase(ENGLISH); + + TableName name = TableName.valueOf(schema, tableName); + Table table = connection.getTable(name); + + byte[] rowKey = Bytes.toBytes(rowkeyValue); + byte[] columnFamily = Bytes.toBytes("0"); + byte[] column = getActualQualifier(tableName, columnName); + + Get getVersion = new Get(rowKey); + getVersion.addColumn(columnFamily, column); + // only read the latest version + getVersion.readVersions(1); + + Result result = table.get(getVersion); + Cell[] cells = result.rawCells(); + assertThat(cells).hasSize(1); + table.close(); + return cells[0].getTimestamp(); + } + @Override protected OptionalInt maxTableNameLength() { diff --git a/plugin/trino-phoenix5/src/test/java/io/trino/plugin/phoenix5/TestingPhoenixServer.java b/plugin/trino-phoenix5/src/test/java/io/trino/plugin/phoenix5/TestingPhoenixServer.java index 83f70efb5a6c..0fb1dd163b91 100644 --- a/plugin/trino-phoenix5/src/test/java/io/trino/plugin/phoenix5/TestingPhoenixServer.java +++ b/plugin/trino-phoenix5/src/test/java/io/trino/plugin/phoenix5/TestingPhoenixServer.java @@ -20,7 +20,9 @@ import org.apache.hadoop.hbase.HBaseTestingUtility; import org.apache.hadoop.hbase.MiniHBaseCluster; import org.apache.hadoop.hbase.StartMiniClusterOption; +import org.apache.hadoop.hbase.client.Connection; import org.apache.hadoop.hbase.zookeeper.MiniZooKeeperCluster; +import org.apache.phoenix.query.HBaseFactoryProvider; import java.io.IOException; import java.io.UncheckedIOException; @@ -88,6 +90,12 @@ private TestingPhoenixServer() } } + public Connection getConnection() + throws IOException + { + return HBaseFactoryProvider.getHConnectionFactory().createConnection(this.conf); + } + @Override public void close() {