Skip to content

Commit f17fdac

Browse files
chenjian2664electrum
authored andcommitted
Support partial update in Phoenix connector
1 parent ae2448a commit f17fdac

File tree

5 files changed

+394
-27
lines changed

5 files changed

+394
-27
lines changed

plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixMergeSink.java

Lines changed: 70 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
*/
1414
package io.trino.plugin.phoenix5;
1515

16+
import com.google.common.base.Suppliers;
1617
import com.google.common.collect.ImmutableList;
18+
import com.google.common.collect.ImmutableMap;
1719
import io.airlift.slice.Slice;
1820
import io.trino.plugin.jdbc.JdbcClient;
1921
import io.trino.plugin.jdbc.JdbcOutputTableHandle;
@@ -32,19 +34,27 @@
3234
import io.trino.spi.connector.ConnectorSession;
3335
import io.trino.spi.type.RowType;
3436
import io.trino.spi.type.Type;
37+
import org.apache.phoenix.util.SchemaUtil;
3538

3639
import java.sql.Connection;
3740
import java.sql.SQLException;
3841
import java.util.Collection;
42+
import java.util.HashMap;
3943
import java.util.List;
44+
import java.util.Map;
4045
import java.util.Optional;
46+
import java.util.Set;
4147
import java.util.concurrent.CompletableFuture;
48+
import java.util.function.Supplier;
4249
import java.util.stream.IntStream;
4350

4451
import static com.google.common.base.Preconditions.checkArgument;
52+
import static com.google.common.collect.ImmutableList.toImmutableList;
53+
import static com.google.common.collect.ImmutableSet.toImmutableSet;
4554
import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_ERROR;
4655
import static io.trino.plugin.phoenix5.PhoenixClient.ROWKEY;
4756
import static io.trino.plugin.phoenix5.PhoenixClient.ROWKEY_COLUMN_HANDLE;
57+
import static io.trino.spi.type.IntegerType.INTEGER;
4858
import static io.trino.spi.type.TinyintType.TINYINT;
4959
import static java.util.concurrent.CompletableFuture.completedFuture;
5060
import static org.apache.phoenix.util.SchemaUtil.getEscapedArgument;
@@ -56,9 +66,11 @@ public class PhoenixMergeSink
5666
private final int columnCount;
5767

5868
private final ConnectorPageSink insertSink;
59-
private final ConnectorPageSink updateSink;
69+
private final Map<Integer, Supplier<ConnectorPageSink>> updateSinkSuppliers;
6070
private final ConnectorPageSink deleteSink;
6171

72+
private final Map<Integer, Set<Integer>> updateCaseChannels;
73+
6274
public PhoenixMergeSink(
6375
ConnectorSession session,
6476
ConnectorMergeTableHandle mergeHandle,
@@ -73,7 +85,6 @@ public PhoenixMergeSink(
7385
this.columnCount = phoenixOutputTableHandle.getColumnNames().size();
7486

7587
this.insertSink = new JdbcPageSink(session, phoenixOutputTableHandle, phoenixClient, pageSinkId, remoteQueryModifier, JdbcClient::buildInsertSql);
76-
this.updateSink = createUpdateSink(session, phoenixOutputTableHandle, phoenixClient, pageSinkId, remoteQueryModifier);
7788

7889
ImmutableList.Builder<String> mergeRowIdFieldNamesBuilder = ImmutableList.builder();
7990
ImmutableList.Builder<Type> mergeRowIdFieldTypesBuilder = ImmutableList.builder();
@@ -84,6 +95,31 @@ public PhoenixMergeSink(
8495
mergeRowIdFieldTypesBuilder.add(field.getType());
8596
}
8697
List<String> mergeRowIdFieldNames = mergeRowIdFieldNamesBuilder.build();
98+
List<String> dataColumnNames = phoenixOutputTableHandle.getColumnNames().stream()
99+
.map(SchemaUtil::getEscapedArgument)
100+
.collect(toImmutableList());
101+
Set<Integer> mergeRowIdChannels = mergeRowIdFieldNames.stream()
102+
.map(dataColumnNames::indexOf)
103+
.collect(toImmutableSet());
104+
105+
Map<Integer, Set<Integer>> updateCaseChannels = new HashMap<>();
106+
for (Map.Entry<Integer, Set<Integer>> entry : phoenixMergeTableHandle.updateCaseColumns().entrySet()) {
107+
updateCaseChannels.put(entry.getKey(), entry.getValue());
108+
if (!hasRowKey) {
109+
checkArgument(!mergeRowIdChannels.isEmpty() && !mergeRowIdChannels.contains(-1), "No primary keys found");
110+
updateCaseChannels.get(entry.getKey()).addAll(mergeRowIdChannels);
111+
}
112+
}
113+
this.updateCaseChannels = ImmutableMap.copyOf(updateCaseChannels);
114+
115+
ImmutableMap.Builder<Integer, Supplier<ConnectorPageSink>> updateSinksBuilder = ImmutableMap.builder();
116+
for (Map.Entry<Integer, Set<Integer>> entry : this.updateCaseChannels.entrySet()) {
117+
int caseNumber = entry.getKey();
118+
Supplier<ConnectorPageSink> updateSupplier = Suppliers.memoize(() -> createUpdateSink(session, phoenixOutputTableHandle, phoenixClient, pageSinkId, remoteQueryModifier, entry.getValue()));
119+
updateSinksBuilder.put(caseNumber, updateSupplier);
120+
}
121+
this.updateSinkSuppliers = updateSinksBuilder.buildOrThrow();
122+
87123
this.deleteSink = createDeleteSink(session, mergeRowIdFieldTypesBuilder.build(), phoenixClient, phoenixMergeTableHandle, mergeRowIdFieldNames, pageSinkId, remoteQueryModifier, queryBuilder);
88124
}
89125

@@ -92,12 +128,17 @@ private static ConnectorPageSink createUpdateSink(
92128
PhoenixOutputTableHandle phoenixOutputTableHandle,
93129
PhoenixClient phoenixClient,
94130
ConnectorPageSinkId pageSinkId,
95-
RemoteQueryModifier remoteQueryModifier)
131+
RemoteQueryModifier remoteQueryModifier,
132+
Set<Integer> updateChannels)
96133
{
97134
ImmutableList.Builder<String> columnNamesBuilder = ImmutableList.builder();
98135
ImmutableList.Builder<Type> columnTypesBuilder = ImmutableList.builder();
99-
columnNamesBuilder.addAll(phoenixOutputTableHandle.getColumnNames());
100-
columnTypesBuilder.addAll(phoenixOutputTableHandle.getColumnTypes());
136+
for (int channel = 0; channel < phoenixOutputTableHandle.getColumnNames().size(); channel++) {
137+
if (updateChannels.contains(channel)) {
138+
columnNamesBuilder.add(phoenixOutputTableHandle.getColumnNames().get(channel));
139+
columnTypesBuilder.add(phoenixOutputTableHandle.getColumnTypes().get(channel));
140+
}
141+
}
101142
if (phoenixOutputTableHandle.rowkeyColumn().isPresent()) {
102143
columnNamesBuilder.add(ROWKEY);
103144
columnTypesBuilder.add(ROWKEY_COLUMN_HANDLE.getColumnType());
@@ -168,8 +209,10 @@ public void storeMergedRows(Page page)
168209
int insertPositionCount = 0;
169210
int[] deletePositions = new int[positionCount];
170211
int deletePositionCount = 0;
171-
int[] updatePositions = new int[positionCount];
172-
int updatePositionCount = 0;
212+
213+
Block updateCaseBlock = page.getBlock(columnCount + 1);
214+
Map<Integer, int[]> updatePositions = new HashMap<>();
215+
Map<Integer, Integer> updatePositionCounts = new HashMap<>();
173216

174217
for (int position = 0; position < positionCount; position++) {
175218
int operation = TINYINT.getByte(operationBlock, position);
@@ -183,8 +226,10 @@ public void storeMergedRows(Page page)
183226
deletePositionCount++;
184227
}
185228
case UPDATE_OPERATION_NUMBER -> {
186-
updatePositions[updatePositionCount] = position;
187-
updatePositionCount++;
229+
int caseNumber = INTEGER.getInt(updateCaseBlock, position);
230+
int updatePositionCount = updatePositionCounts.getOrDefault(caseNumber, 0);
231+
updatePositions.computeIfAbsent(caseNumber, _ -> new int[positionCount])[updatePositionCount] = position;
232+
updatePositionCounts.put(caseNumber, updatePositionCount + 1);
188233
}
189234
default -> throw new IllegalStateException("Unexpected value: " + operation);
190235
}
@@ -203,13 +248,21 @@ public void storeMergedRows(Page page)
203248
deleteSink.appendPage(new Page(deletePositionCount, deleteBlocks));
204249
}
205250

206-
if (updatePositionCount > 0) {
207-
Page updatePage = dataPage.getPositions(updatePositions, 0, updatePositionCount);
208-
if (hasRowKey) {
209-
updatePage = updatePage.appendColumn(rowIdFields.get(0).getPositions(updatePositions, 0, updatePositionCount));
210-
}
251+
for (Map.Entry<Integer, Integer> entry : updatePositionCounts.entrySet()) {
252+
int caseNumber = entry.getKey();
253+
int updatePositionCount = entry.getValue();
254+
if (updatePositionCount > 0) {
255+
checkArgument(updatePositions.containsKey(caseNumber), "Unexpected case number %s", caseNumber);
211256

212-
updateSink.appendPage(updatePage);
257+
Page updatePage = dataPage
258+
.getColumns(updateCaseChannels.get(caseNumber).stream().mapToInt(Integer::intValue).sorted().toArray())
259+
.getPositions(updatePositions.get(caseNumber), 0, updatePositionCount);
260+
if (hasRowKey) {
261+
updatePage = updatePage.appendColumn(rowIdFields.get(0).getPositions(updatePositions.get(caseNumber), 0, updatePositionCount));
262+
}
263+
264+
updateSinkSuppliers.get(caseNumber).get().appendPage(updatePage);
265+
}
213266
}
214267
}
215268

@@ -218,7 +271,7 @@ public CompletableFuture<Collection<Slice>> finish()
218271
{
219272
insertSink.finish();
220273
deleteSink.finish();
221-
updateSink.finish();
274+
updateSinkSuppliers.values().stream().map(Supplier::get).forEach(ConnectorPageSink::finish);
222275
return completedFuture(ImmutableList.of());
223276
}
224277

@@ -227,6 +280,6 @@ public void abort()
227280
{
228281
insertSink.abort();
229282
deleteSink.abort();
230-
updateSink.abort();
283+
updateSinkSuppliers.values().stream().map(Supplier::get).forEach(ConnectorPageSink::abort);
231284
}
232285
}

plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixMergeTableHandle.java

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,26 +21,32 @@
2121
import io.trino.spi.connector.ConnectorMergeTableHandle;
2222
import io.trino.spi.predicate.TupleDomain;
2323

24+
import java.util.Map;
25+
import java.util.Set;
26+
2427
import static java.util.Objects.requireNonNull;
2528

2629
public record PhoenixMergeTableHandle(
2730
JdbcTableHandle tableHandle,
2831
PhoenixOutputTableHandle phoenixOutputTableHandle,
2932
JdbcColumnHandle mergeRowIdColumnHandle,
30-
TupleDomain<ColumnHandle> primaryKeysDomain)
33+
TupleDomain<ColumnHandle> primaryKeysDomain,
34+
Map<Integer, Set<Integer>> updateCaseColumns)
3135
implements ConnectorMergeTableHandle
3236
{
3337
@JsonCreator
3438
public PhoenixMergeTableHandle(
3539
@JsonProperty("tableHandle") JdbcTableHandle tableHandle,
3640
@JsonProperty("phoenixOutputTableHandle") PhoenixOutputTableHandle phoenixOutputTableHandle,
3741
@JsonProperty("mergeRowIdColumnHandle") JdbcColumnHandle mergeRowIdColumnHandle,
38-
@JsonProperty("primaryKeysDomain") TupleDomain<ColumnHandle> primaryKeysDomain)
42+
@JsonProperty("primaryKeysDomain") TupleDomain<ColumnHandle> primaryKeysDomain,
43+
@JsonProperty("updateCaseColumns") Map<Integer, Set<Integer>> updateCaseColumns)
3944
{
4045
this.tableHandle = requireNonNull(tableHandle, "tableHandle is null");
4146
this.phoenixOutputTableHandle = requireNonNull(phoenixOutputTableHandle, "phoenixOutputTableHandle is null");
4247
this.mergeRowIdColumnHandle = requireNonNull(mergeRowIdColumnHandle, "mergeRowIdColumnHandle is null");
4348
this.primaryKeysDomain = requireNonNull(primaryKeysDomain, "primaryKeysDomain is null");
49+
this.updateCaseColumns = requireNonNull(updateCaseColumns, "updateCaseColumns is null");
4450
}
4551

4652
@JsonProperty
@@ -70,4 +76,11 @@ public TupleDomain<ColumnHandle> primaryKeysDomain()
7076
{
7177
return primaryKeysDomain;
7278
}
79+
80+
@Override
81+
@JsonProperty
82+
public Map<Integer, Set<Integer>> updateCaseColumns()
83+
{
84+
return updateCaseColumns;
85+
}
7386
}

plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixMetadata.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
import static com.google.common.base.Preconditions.checkArgument;
6767
import static com.google.common.base.Verify.verify;
6868
import static com.google.common.collect.ImmutableList.toImmutableList;
69+
import static com.google.common.collect.ImmutableSet.toImmutableSet;
6970
import static io.trino.plugin.jdbc.JdbcMetadata.getColumns;
7071
import static io.trino.plugin.phoenix5.MetadataUtil.getEscapedTableName;
7172
import static io.trino.plugin.phoenix5.MetadataUtil.toTrinoSchemaName;
@@ -350,11 +351,23 @@ public ConnectorMergeTableHandle beginMerge(ConnectorSession session, ConnectorT
350351
primaryKeysDomainBuilder.put(columnHandle, dummy);
351352
}
352353

354+
ImmutableMap.Builder<Integer, Set<Integer>> updateColumnChannelsBuilder = ImmutableMap.builder();
355+
for (Map.Entry<Integer, Collection<ColumnHandle>> entry : updateColumnHandles.entrySet()) {
356+
int caseNumber = entry.getKey();
357+
Set<Integer> updateColumnChannels = entry.getValue().stream()
358+
.map(JdbcColumnHandle.class::cast)
359+
.peek(column -> checkArgument(columns.contains(column), "update column %s not found in the target table", column))
360+
.map(columns::indexOf)
361+
.collect(toImmutableSet());
362+
updateColumnChannelsBuilder.put(caseNumber, updateColumnChannels);
363+
}
364+
353365
return new PhoenixMergeTableHandle(
354366
phoenixClient.updatedScanColumnTable(session, handle, handle.getColumns(), mergeRowIdColumnHandle),
355367
phoenixOutputTableHandle,
356368
mergeRowIdColumnHandle,
357-
TupleDomain.withColumnDomains(primaryKeysDomainBuilder.buildOrThrow()));
369+
TupleDomain.withColumnDomains(primaryKeysDomainBuilder.buildOrThrow()),
370+
updateColumnChannelsBuilder.buildOrThrow());
358371
}
359372

360373
@Override

0 commit comments

Comments
 (0)