18
18
19
19
import com .alibaba .graphscope .common .antlr4 .ExprUniqueAliasInfer ;
20
20
import com .alibaba .graphscope .common .antlr4 .ExprVisitorResult ;
21
- import com .alibaba .graphscope .common .ir .rel .GraphLogicalAggregate ;
22
21
import com .alibaba .graphscope .common .ir .rel .GraphProcedureCall ;
23
22
import com .alibaba .graphscope .common .ir .rel .graph .GraphLogicalGetV ;
24
23
import com .alibaba .graphscope .common .ir .rel .graph .GraphLogicalPathExpand ;
25
24
import com .alibaba .graphscope .common .ir .rel .type .group .GraphAggCall ;
26
25
import com .alibaba .graphscope .common .ir .rex .RexTmpVariableConverter ;
27
- import com .alibaba .graphscope .common .ir .rex .RexVariableAliasCollector ;
28
26
import com .alibaba .graphscope .common .ir .tools .GraphBuilder ;
29
27
import com .alibaba .graphscope .common .ir .tools .config .GraphOpt ;
30
28
import com .alibaba .graphscope .grammar .CypherGSBaseVisitor ;
39
37
import org .apache .calcite .plan .RelOptUtil ;
40
38
import org .apache .calcite .plan .RelTraitSet ;
41
39
import org .apache .calcite .rel .RelNode ;
40
+ import org .apache .calcite .rel .type .RelDataType ;
42
41
import org .apache .calcite .rex .RexCall ;
43
42
import org .apache .calcite .rex .RexNode ;
44
43
import org .apache .calcite .rex .RexSubQuery ;
49
48
import java .util .ArrayList ;
50
49
import java .util .List ;
51
50
import java .util .Objects ;
51
+ import java .util .concurrent .atomic .AtomicReference ;
52
52
import java .util .stream .Collectors ;
53
53
54
54
public class GraphBuilderVisitor extends CypherGSBaseVisitor <GraphBuilder > {
@@ -271,49 +271,34 @@ public GraphBuilder visitOC_ProjectionBody(CypherGSParser.OC_ProjectionBodyConte
271
271
List <RexNode > keyExprs = new ArrayList <>();
272
272
List <String > keyAliases = new ArrayList <>();
273
273
List <RelBuilder .AggCall > aggCalls = new ArrayList <>();
274
- List <RexNode > extraExprs = new ArrayList <>();
275
- List <String > extraAliases = new ArrayList <>();
276
- if (isGroupPattern (ctx , keyExprs , keyAliases , aggCalls , extraExprs , extraAliases )) {
274
+ AtomicReference <ColumnOrder > columnManagerRef = new AtomicReference <>();
275
+ if (isGroupPattern (ctx , keyExprs , keyAliases , aggCalls , columnManagerRef )) {
277
276
RelBuilder .GroupKey groupKey ;
278
277
if (keyExprs .isEmpty ()) {
279
278
groupKey = builder .groupKey ();
280
279
} else {
281
280
groupKey = builder .groupKey (keyExprs , keyAliases );
282
281
}
283
282
builder .aggregate (groupKey , aggCalls );
284
- if (!extraExprs .isEmpty ()) {
283
+ RelDataType inputType = builder .peek ().getRowType ();
284
+ List <ColumnOrder .Field > originalFields =
285
+ inputType .getFieldList ().stream ()
286
+ .map (
287
+ k ->
288
+ new ColumnOrder .Field (
289
+ builder .variable (k .getName ()), k .getName ()))
290
+ .collect (Collectors .toList ());
291
+ List <ColumnOrder .Field > newFields = columnManagerRef .get ().getFields (inputType );
292
+ if (!originalFields .equals (newFields )) {
293
+ List <RexNode > extraExprs = new ArrayList <>();
294
+ List <@ Nullable String > extraAliases = new ArrayList <>();
285
295
RexTmpVariableConverter converter = new RexTmpVariableConverter (true , builder );
286
- extraExprs =
287
- extraExprs .stream ()
288
- .map (k -> k .accept (converter ))
289
- .collect (Collectors .toList ());
290
- List <RexNode > projectExprs = Lists .newArrayList ();
291
- List <String > projectAliases = Lists .newArrayList ();
292
- List <String > extraVarNames = Lists .newArrayList ();
293
- RexVariableAliasCollector <String > varNameCollector =
294
- new RexVariableAliasCollector <>(
295
- true ,
296
- v -> {
297
- String [] splits = v .getName ().split ("\\ ." );
298
- return splits [0 ];
299
- });
300
- extraExprs .forEach (k -> extraVarNames .addAll (k .accept (varNameCollector )));
301
- GraphLogicalAggregate aggregate = (GraphLogicalAggregate ) builder .peek ();
302
- aggregate
303
- .getRowType ()
304
- .getFieldList ()
305
- .forEach (
306
- field -> {
307
- if (!extraVarNames .contains (field .getName ())) {
308
- projectExprs .add (builder .variable (field .getName ()));
309
- projectAliases .add (field .getName ());
310
- }
311
- });
312
- for (int i = 0 ; i < extraExprs .size (); ++i ) {
313
- projectExprs .add (extraExprs .get (i ));
314
- projectAliases .add (extraAliases .get (i ));
315
- }
316
- builder .project (projectExprs , projectAliases , false );
296
+ newFields .forEach (
297
+ k -> {
298
+ extraExprs .add (k .getExpr ().accept (converter ));
299
+ extraAliases .add (k .getAlias ());
300
+ });
301
+ builder .project (extraExprs , extraAliases , false );
317
302
}
318
303
} else if (isDistinct ) {
319
304
builder .aggregate (builder .groupKey (keyExprs , keyAliases ));
@@ -334,22 +319,28 @@ private boolean isGroupPattern(
334
319
List <RexNode > keyExprs ,
335
320
List <String > keyAliases ,
336
321
List <RelBuilder .AggCall > aggCalls ,
337
- List < RexNode > extraExprs ,
338
- List <String > extraAliases ) {
322
+ AtomicReference < ColumnOrder > columnManagerRef ) {
323
+ List <ColumnOrder . FieldSupplier > fieldSuppliers = Lists . newArrayList ();
339
324
for (CypherGSParser .OC_ProjectionItemContext itemCtx :
340
325
ctx .oC_ProjectionItems ().oC_ProjectionItem ()) {
341
326
ExprVisitorResult item = expressionVisitor .visitOC_Expression (itemCtx .oC_Expression ());
342
327
String alias =
343
328
(itemCtx .AS () == null ) ? null : Utils .getAliasName (itemCtx .oC_Variable ());
344
329
if (item .getAggCalls ().isEmpty ()) {
330
+ int ordinal = keyExprs .size ();
331
+ fieldSuppliers .add (new ColumnOrder .FieldSupplier .Default (builder , () -> ordinal ));
345
332
keyExprs .add (item .getExpr ());
346
333
keyAliases .add (alias );
347
334
} else {
348
335
if (item .getExpr () instanceof RexCall ) {
349
- extraExprs .add (item . getExpr ());
350
- extraAliases . add ( alias );
336
+ fieldSuppliers .add (
337
+ ( RelDataType type ) -> new ColumnOrder . Field ( item . getExpr (), alias ) );
351
338
aggCalls .addAll (item .getAggCalls ());
352
339
} else if (item .getAggCalls ().size () == 1 ) { // count(a.name)
340
+ int ordinal = aggCalls .size ();
341
+ fieldSuppliers .add (
342
+ new ColumnOrder .FieldSupplier .Default (
343
+ builder , () -> keyExprs .size () + ordinal ));
353
344
GraphAggCall original = (GraphAggCall ) item .getAggCalls ().get (0 );
354
345
aggCalls .add (
355
346
new GraphAggCall (
@@ -363,6 +354,7 @@ private boolean isGroupPattern(
363
354
}
364
355
}
365
356
}
357
+ columnManagerRef .set (new ColumnOrder (fieldSuppliers ));
366
358
return !aggCalls .isEmpty ();
367
359
}
368
360
0 commit comments