1313 */
1414package io .trino .plugin .phoenix5 ;
1515
16+ import com .google .common .base .Suppliers ;
1617import com .google .common .collect .ImmutableList ;
18+ import com .google .common .collect .ImmutableMap ;
1719import io .airlift .slice .Slice ;
1820import io .trino .plugin .jdbc .JdbcClient ;
1921import io .trino .plugin .jdbc .JdbcOutputTableHandle ;
3234import io .trino .spi .connector .ConnectorSession ;
3335import io .trino .spi .type .RowType ;
3436import io .trino .spi .type .Type ;
37+ import org .apache .phoenix .util .SchemaUtil ;
3538
3639import java .sql .Connection ;
3740import java .sql .SQLException ;
3841import java .util .Collection ;
42+ import java .util .HashMap ;
3943import java .util .List ;
44+ import java .util .Map ;
4045import java .util .Optional ;
46+ import java .util .Set ;
4147import java .util .concurrent .CompletableFuture ;
48+ import java .util .function .Supplier ;
4249import java .util .stream .IntStream ;
4350
4451import static com .google .common .base .Preconditions .checkArgument ;
52+ import static com .google .common .collect .ImmutableList .toImmutableList ;
53+ import static com .google .common .collect .ImmutableSet .toImmutableSet ;
4554import static io .trino .plugin .jdbc .JdbcErrorCode .JDBC_ERROR ;
4655import static io .trino .plugin .phoenix5 .PhoenixClient .ROWKEY ;
4756import static io .trino .plugin .phoenix5 .PhoenixClient .ROWKEY_COLUMN_HANDLE ;
57+ import static io .trino .spi .type .IntegerType .INTEGER ;
4858import static io .trino .spi .type .TinyintType .TINYINT ;
4959import static java .util .concurrent .CompletableFuture .completedFuture ;
5060import static org .apache .phoenix .util .SchemaUtil .getEscapedArgument ;
@@ -56,9 +66,11 @@ public class PhoenixMergeSink
5666 private final int columnCount ;
5767
5868 private final ConnectorPageSink insertSink ;
59- private final ConnectorPageSink updateSink ;
69+ private final Map < Integer , Supplier < ConnectorPageSink >> updateSinkSuppliers ;
6070 private final ConnectorPageSink deleteSink ;
6171
72+ private final Map <Integer , Set <Integer >> updateCaseChannels ;
73+
6274 public PhoenixMergeSink (
6375 ConnectorSession session ,
6476 ConnectorMergeTableHandle mergeHandle ,
@@ -73,7 +85,6 @@ public PhoenixMergeSink(
7385 this .columnCount = phoenixOutputTableHandle .getColumnNames ().size ();
7486
7587 this .insertSink = new JdbcPageSink (session , phoenixOutputTableHandle , phoenixClient , pageSinkId , remoteQueryModifier , JdbcClient ::buildInsertSql );
76- this .updateSink = createUpdateSink (session , phoenixOutputTableHandle , phoenixClient , pageSinkId , remoteQueryModifier );
7788
7889 ImmutableList .Builder <String > mergeRowIdFieldNamesBuilder = ImmutableList .builder ();
7990 ImmutableList .Builder <Type > mergeRowIdFieldTypesBuilder = ImmutableList .builder ();
@@ -84,6 +95,31 @@ public PhoenixMergeSink(
8495 mergeRowIdFieldTypesBuilder .add (field .getType ());
8596 }
8697 List <String > mergeRowIdFieldNames = mergeRowIdFieldNamesBuilder .build ();
98+ List <String > dataColumnNames = phoenixOutputTableHandle .getColumnNames ().stream ()
99+ .map (SchemaUtil ::getEscapedArgument )
100+ .collect (toImmutableList ());
101+ Set <Integer > mergeRowIdChannels = mergeRowIdFieldNames .stream ()
102+ .map (dataColumnNames ::indexOf )
103+ .collect (toImmutableSet ());
104+
105+ Map <Integer , Set <Integer >> updateCaseChannels = new HashMap <>();
106+ for (Map .Entry <Integer , Set <Integer >> entry : phoenixMergeTableHandle .updateCaseColumns ().entrySet ()) {
107+ updateCaseChannels .put (entry .getKey (), entry .getValue ());
108+ if (!hasRowKey ) {
109+ checkArgument (!mergeRowIdChannels .isEmpty () && !mergeRowIdChannels .contains (-1 ), "No primary keys found" );
110+ updateCaseChannels .get (entry .getKey ()).addAll (mergeRowIdChannels );
111+ }
112+ }
113+ this .updateCaseChannels = ImmutableMap .copyOf (updateCaseChannels );
114+
115+ ImmutableMap .Builder <Integer , Supplier <ConnectorPageSink >> updateSinksBuilder = ImmutableMap .builder ();
116+ for (Map .Entry <Integer , Set <Integer >> entry : this .updateCaseChannels .entrySet ()) {
117+ int caseNumber = entry .getKey ();
118+ Supplier <ConnectorPageSink > updateSupplier = Suppliers .memoize (() -> createUpdateSink (session , phoenixOutputTableHandle , phoenixClient , pageSinkId , remoteQueryModifier , entry .getValue ()));
119+ updateSinksBuilder .put (caseNumber , updateSupplier );
120+ }
121+ this .updateSinkSuppliers = updateSinksBuilder .buildOrThrow ();
122+
87123 this .deleteSink = createDeleteSink (session , mergeRowIdFieldTypesBuilder .build (), phoenixClient , phoenixMergeTableHandle , mergeRowIdFieldNames , pageSinkId , remoteQueryModifier , queryBuilder );
88124 }
89125
@@ -92,12 +128,17 @@ private static ConnectorPageSink createUpdateSink(
92128 PhoenixOutputTableHandle phoenixOutputTableHandle ,
93129 PhoenixClient phoenixClient ,
94130 ConnectorPageSinkId pageSinkId ,
95- RemoteQueryModifier remoteQueryModifier )
131+ RemoteQueryModifier remoteQueryModifier ,
132+ Set <Integer > updateChannels )
96133 {
97134 ImmutableList .Builder <String > columnNamesBuilder = ImmutableList .builder ();
98135 ImmutableList .Builder <Type > columnTypesBuilder = ImmutableList .builder ();
99- columnNamesBuilder .addAll (phoenixOutputTableHandle .getColumnNames ());
100- columnTypesBuilder .addAll (phoenixOutputTableHandle .getColumnTypes ());
136+ for (int channel = 0 ; channel < phoenixOutputTableHandle .getColumnNames ().size (); channel ++) {
137+ if (updateChannels .contains (channel )) {
138+ columnNamesBuilder .add (phoenixOutputTableHandle .getColumnNames ().get (channel ));
139+ columnTypesBuilder .add (phoenixOutputTableHandle .getColumnTypes ().get (channel ));
140+ }
141+ }
101142 if (phoenixOutputTableHandle .rowkeyColumn ().isPresent ()) {
102143 columnNamesBuilder .add (ROWKEY );
103144 columnTypesBuilder .add (ROWKEY_COLUMN_HANDLE .getColumnType ());
@@ -168,8 +209,10 @@ public void storeMergedRows(Page page)
168209 int insertPositionCount = 0 ;
169210 int [] deletePositions = new int [positionCount ];
170211 int deletePositionCount = 0 ;
171- int [] updatePositions = new int [positionCount ];
172- int updatePositionCount = 0 ;
212+
213+ Block updateCaseBlock = page .getBlock (columnCount + 1 );
214+ Map <Integer , int []> updatePositions = new HashMap <>();
215+ Map <Integer , Integer > updatePositionCounts = new HashMap <>();
173216
174217 for (int position = 0 ; position < positionCount ; position ++) {
175218 int operation = TINYINT .getByte (operationBlock , position );
@@ -183,8 +226,10 @@ public void storeMergedRows(Page page)
183226 deletePositionCount ++;
184227 }
185228 case UPDATE_OPERATION_NUMBER -> {
186- updatePositions [updatePositionCount ] = position ;
187- updatePositionCount ++;
229+ int caseNumber = INTEGER .getInt (updateCaseBlock , position );
230+ int updatePositionCount = updatePositionCounts .getOrDefault (caseNumber , 0 );
231+ updatePositions .computeIfAbsent (caseNumber , _ -> new int [positionCount ])[updatePositionCount ] = position ;
232+ updatePositionCounts .put (caseNumber , updatePositionCount + 1 );
188233 }
189234 default -> throw new IllegalStateException ("Unexpected value: " + operation );
190235 }
@@ -203,13 +248,21 @@ public void storeMergedRows(Page page)
203248 deleteSink .appendPage (new Page (deletePositionCount , deleteBlocks ));
204249 }
205250
206- if ( updatePositionCount > 0 ) {
207- Page updatePage = dataPage . getPositions ( updatePositions , 0 , updatePositionCount );
208- if ( hasRowKey ) {
209- updatePage = updatePage . appendColumn ( rowIdFields . get ( 0 ). getPositions ( updatePositions , 0 , updatePositionCount ));
210- }
251+ for ( Map . Entry < Integer , Integer > entry : updatePositionCounts . entrySet () ) {
252+ int caseNumber = entry . getKey ( );
253+ int updatePositionCount = entry . getValue ();
254+ if ( updatePositionCount > 0 ) {
255+ checkArgument ( updatePositions . containsKey ( caseNumber ), "Unexpected case number %s" , caseNumber );
211256
212- updateSink .appendPage (updatePage );
257+ Page updatePage = dataPage
258+ .getColumns (updateCaseChannels .get (caseNumber ).stream ().mapToInt (Integer ::intValue ).sorted ().toArray ())
259+ .getPositions (updatePositions .get (caseNumber ), 0 , updatePositionCount );
260+ if (hasRowKey ) {
261+ updatePage = updatePage .appendColumn (rowIdFields .get (0 ).getPositions (updatePositions .get (caseNumber ), 0 , updatePositionCount ));
262+ }
263+
264+ updateSinkSuppliers .get (caseNumber ).get ().appendPage (updatePage );
265+ }
213266 }
214267 }
215268
@@ -218,7 +271,7 @@ public CompletableFuture<Collection<Slice>> finish()
218271 {
219272 insertSink .finish ();
220273 deleteSink .finish ();
221- updateSink . finish ( );
274+ updateSinkSuppliers . values (). stream (). map ( Supplier :: get ). forEach ( ConnectorPageSink :: finish );
222275 return completedFuture (ImmutableList .of ());
223276 }
224277
@@ -227,6 +280,6 @@ public void abort()
227280 {
228281 insertSink .abort ();
229282 deleteSink .abort ();
230- updateSink . abort ( );
283+ updateSinkSuppliers . values (). stream (). map ( Supplier :: get ). forEach ( ConnectorPageSink :: abort );
231284 }
232285}
0 commit comments