Skip to content

Commit f41a761

Browse files
fix(interactive): Fix Aggregate Column Order Mismatch (#4364)
<!-- Thanks for your contribution! please review https://github.com/alibaba/GraphScope/blob/main/CONTRIBUTING.md before opening an issue. --> ## What do these changes do? as titled. <!-- Please give a short brief about these changes. --> ## Related issue number <!-- Are there any issues opened that will be resolved by merging this change? --> Fixes #4360 --------- Co-authored-by: BingqingLyu <[email protected]>
1 parent a9a865b commit f41a761

File tree

7 files changed

+187
-70
lines changed

7 files changed

+187
-70
lines changed

.github/workflows/gaia.yml

+6
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,12 @@ jobs:
5252
~/.cache/sccache
5353
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
5454

55+
- name: Install Rust
56+
uses: actions-rs/toolchain@v1
57+
with:
58+
toolchain: 1.81.0
59+
override: true
60+
5561
- name: Rust Format Check
5662
run: |
5763
cd ${GITHUB_WORKSPACE}/interactive_engine/executor && ./check_format.sh

interactive_engine/compiler/ir_experimental_ci.sh

+20-17
Original file line numberDiff line numberDiff line change
@@ -57,26 +57,29 @@ if [ $exit_code -ne 0 ]; then
5757
fi
5858
unset DISTRIBUTED_ENV
5959

60-
# Test4: run cypher movie tests on experimental store via ir-core
61-
cd ${base_dir}/../executor/ir/target/release && DATA_PATH=/tmp/gstest/movie_graph_exp_bin RUST_LOG=info ./start_rpc_server --config ${base_dir}/../executor/ir/integrated/config &
62-
sleep 5s
63-
# start compiler service
64-
cd ${base_dir} && make run graph.schema:=../executor/ir/core/resource/movie_schema.json &
65-
sleep 10s
66-
export ENGINE_TYPE=pegasus
67-
# run cypher movie tests
68-
cd ${base_dir} && make cypher_test
69-
exit_code=$?
70-
# clean service
71-
ps -ef | grep "com.alibaba.graphscope.GraphServer" | grep -v grep | awk '{print $2}' | xargs kill -9 || true
72-
# report test result
73-
if [ $exit_code -ne 0 ]; then
74-
echo "ir cypher movie integration test on experimental store fail"
75-
exit 1
76-
fi
60+
## Test4: run cypher movie tests on experimental store via ir-core
61+
#cd ${base_dir}/../executor/ir/target/release && DATA_PATH=/tmp/gstest/movie_graph_exp_bin RUST_LOG=info ./start_rpc_server --config ${base_dir}/../executor/ir/integrated/config &
62+
#sleep 5s
63+
## start compiler service
64+
#cd ${base_dir} && make run graph.schema:=../executor/ir/core/resource/movie_schema.json &
65+
#sleep 10s
66+
#export ENGINE_TYPE=pegasus
67+
## run cypher movie tests
68+
#cd ${base_dir} && make cypher_test
69+
#exit_code=$?
70+
## clean service
71+
#ps -ef | grep "com.alibaba.graphscope.GraphServer" | grep -v grep | awk '{print $2}' | xargs kill -9 || true
72+
## report test result
73+
#if [ $exit_code -ne 0 ]; then
74+
# echo "ir cypher movie integration test on experimental store fail"
75+
# exit 1
76+
#fi
7777

7878

7979
# Test5: run cypher movie tests on experimental store via calcite-based ir
80+
# start engine service and load movie graph
81+
cd ${base_dir}/../executor/ir/target/release && DATA_PATH=/tmp/gstest/movie_graph_exp_bin RUST_LOG=info ./start_rpc_server --config ${base_dir}/../executor/ir/integrated/config &
82+
sleep 5s
8083
# restart compiler service
8184
cd ${base_dir} && make run graph.schema:=../executor/ir/core/resource/movie_schema.json graph.planner.opt=CBO graph.statistics:=./src/test/resources/statistics/movie_statistics.json graph.physical.opt=proto graph.planner.rules=FilterIntoJoinRule,FilterMatchRule,ExtendIntersectRule,ExpandGetVFusionRule &
8285
sleep 10s
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
/*
2+
*
3+
* * Copyright 2020 Alibaba Group Holding Limited.
4+
* *
5+
* * Licensed under the Apache License, Version 2.0 (the "License");
6+
* * you may not use this file except in compliance with the License.
7+
* * You may obtain a copy of the License at
8+
* *
9+
* * http://www.apache.org/licenses/LICENSE-2.0
10+
* *
11+
* * Unless required by applicable law or agreed to in writing, software
12+
* * distributed under the License is distributed on an "AS IS" BASIS,
13+
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* * See the License for the specific language governing permissions and
15+
* * limitations under the License.
16+
*
17+
*/
18+
19+
package com.alibaba.graphscope.cypher.antlr4.visitor;
20+
21+
import com.alibaba.graphscope.common.ir.tools.GraphBuilder;
22+
import com.google.common.base.Objects;
23+
24+
import org.apache.calcite.rel.type.RelDataType;
25+
import org.apache.calcite.rex.RexNode;
26+
import org.checkerframework.checker.nullness.qual.Nullable;
27+
28+
import java.util.List;
29+
import java.util.function.Supplier;
30+
import java.util.stream.Collectors;
31+
32+
/**
33+
* ColumnOrder keeps fields as the same order with RETURN clause
34+
*/
35+
public class ColumnOrder {
36+
public static class Field {
37+
private final RexNode expr;
38+
private final String alias;
39+
40+
public Field(RexNode expr, String alias) {
41+
this.expr = expr;
42+
this.alias = alias;
43+
}
44+
45+
public RexNode getExpr() {
46+
return expr;
47+
}
48+
49+
public String getAlias() {
50+
return alias;
51+
}
52+
53+
@Override
54+
public boolean equals(Object o) {
55+
if (this == o) return true;
56+
if (o == null || getClass() != o.getClass()) return false;
57+
Field field = (Field) o;
58+
return Objects.equal(expr, field.expr) && Objects.equal(alias, field.alias);
59+
}
60+
61+
@Override
62+
public int hashCode() {
63+
return Objects.hashCode(expr, alias);
64+
}
65+
}
66+
67+
public interface FieldSupplier {
68+
Field get(RelDataType inputType);
69+
70+
class Default implements FieldSupplier {
71+
private final GraphBuilder builder;
72+
private final Supplier<Integer> ordinalSupplier;
73+
74+
public Default(GraphBuilder builder, Supplier<Integer> ordinalSupplier) {
75+
this.builder = builder;
76+
this.ordinalSupplier = ordinalSupplier;
77+
}
78+
79+
@Override
80+
public Field get(RelDataType inputType) {
81+
String aliasName = inputType.getFieldList().get(ordinalSupplier.get()).getName();
82+
return new Field(this.builder.variable(aliasName), aliasName);
83+
}
84+
}
85+
}
86+
87+
private final List<FieldSupplier> fieldSuppliers;
88+
89+
public ColumnOrder(List<FieldSupplier> fieldSuppliers) {
90+
this.fieldSuppliers = fieldSuppliers;
91+
}
92+
93+
public @Nullable List<Field> getFields(RelDataType inputType) {
94+
return this.fieldSuppliers.stream().map(k -> k.get(inputType)).collect(Collectors.toList());
95+
}
96+
}

interactive_engine/compiler/src/main/java/com/alibaba/graphscope/cypher/antlr4/visitor/GraphBuilderVisitor.java

+33-41
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,11 @@
1818

1919
import com.alibaba.graphscope.common.antlr4.ExprUniqueAliasInfer;
2020
import com.alibaba.graphscope.common.antlr4.ExprVisitorResult;
21-
import com.alibaba.graphscope.common.ir.rel.GraphLogicalAggregate;
2221
import com.alibaba.graphscope.common.ir.rel.GraphProcedureCall;
2322
import com.alibaba.graphscope.common.ir.rel.graph.GraphLogicalGetV;
2423
import com.alibaba.graphscope.common.ir.rel.graph.GraphLogicalPathExpand;
2524
import com.alibaba.graphscope.common.ir.rel.type.group.GraphAggCall;
2625
import com.alibaba.graphscope.common.ir.rex.RexTmpVariableConverter;
27-
import com.alibaba.graphscope.common.ir.rex.RexVariableAliasCollector;
2826
import com.alibaba.graphscope.common.ir.tools.GraphBuilder;
2927
import com.alibaba.graphscope.common.ir.tools.config.GraphOpt;
3028
import com.alibaba.graphscope.grammar.CypherGSBaseVisitor;
@@ -39,6 +37,7 @@
3937
import org.apache.calcite.plan.RelOptUtil;
4038
import org.apache.calcite.plan.RelTraitSet;
4139
import org.apache.calcite.rel.RelNode;
40+
import org.apache.calcite.rel.type.RelDataType;
4241
import org.apache.calcite.rex.RexCall;
4342
import org.apache.calcite.rex.RexNode;
4443
import org.apache.calcite.rex.RexSubQuery;
@@ -49,6 +48,7 @@
4948
import java.util.ArrayList;
5049
import java.util.List;
5150
import java.util.Objects;
51+
import java.util.concurrent.atomic.AtomicReference;
5252
import java.util.stream.Collectors;
5353

5454
public class GraphBuilderVisitor extends CypherGSBaseVisitor<GraphBuilder> {
@@ -271,49 +271,34 @@ public GraphBuilder visitOC_ProjectionBody(CypherGSParser.OC_ProjectionBodyConte
271271
List<RexNode> keyExprs = new ArrayList<>();
272272
List<String> keyAliases = new ArrayList<>();
273273
List<RelBuilder.AggCall> aggCalls = new ArrayList<>();
274-
List<RexNode> extraExprs = new ArrayList<>();
275-
List<String> extraAliases = new ArrayList<>();
276-
if (isGroupPattern(ctx, keyExprs, keyAliases, aggCalls, extraExprs, extraAliases)) {
274+
AtomicReference<ColumnOrder> columnManagerRef = new AtomicReference<>();
275+
if (isGroupPattern(ctx, keyExprs, keyAliases, aggCalls, columnManagerRef)) {
277276
RelBuilder.GroupKey groupKey;
278277
if (keyExprs.isEmpty()) {
279278
groupKey = builder.groupKey();
280279
} else {
281280
groupKey = builder.groupKey(keyExprs, keyAliases);
282281
}
283282
builder.aggregate(groupKey, aggCalls);
284-
if (!extraExprs.isEmpty()) {
283+
RelDataType inputType = builder.peek().getRowType();
284+
List<ColumnOrder.Field> originalFields =
285+
inputType.getFieldList().stream()
286+
.map(
287+
k ->
288+
new ColumnOrder.Field(
289+
builder.variable(k.getName()), k.getName()))
290+
.collect(Collectors.toList());
291+
List<ColumnOrder.Field> newFields = columnManagerRef.get().getFields(inputType);
292+
if (!originalFields.equals(newFields)) {
293+
List<RexNode> extraExprs = new ArrayList<>();
294+
List<@Nullable String> extraAliases = new ArrayList<>();
285295
RexTmpVariableConverter converter = new RexTmpVariableConverter(true, builder);
286-
extraExprs =
287-
extraExprs.stream()
288-
.map(k -> k.accept(converter))
289-
.collect(Collectors.toList());
290-
List<RexNode> projectExprs = Lists.newArrayList();
291-
List<String> projectAliases = Lists.newArrayList();
292-
List<String> extraVarNames = Lists.newArrayList();
293-
RexVariableAliasCollector<String> varNameCollector =
294-
new RexVariableAliasCollector<>(
295-
true,
296-
v -> {
297-
String[] splits = v.getName().split("\\.");
298-
return splits[0];
299-
});
300-
extraExprs.forEach(k -> extraVarNames.addAll(k.accept(varNameCollector)));
301-
GraphLogicalAggregate aggregate = (GraphLogicalAggregate) builder.peek();
302-
aggregate
303-
.getRowType()
304-
.getFieldList()
305-
.forEach(
306-
field -> {
307-
if (!extraVarNames.contains(field.getName())) {
308-
projectExprs.add(builder.variable(field.getName()));
309-
projectAliases.add(field.getName());
310-
}
311-
});
312-
for (int i = 0; i < extraExprs.size(); ++i) {
313-
projectExprs.add(extraExprs.get(i));
314-
projectAliases.add(extraAliases.get(i));
315-
}
316-
builder.project(projectExprs, projectAliases, false);
296+
newFields.forEach(
297+
k -> {
298+
extraExprs.add(k.getExpr().accept(converter));
299+
extraAliases.add(k.getAlias());
300+
});
301+
builder.project(extraExprs, extraAliases, false);
317302
}
318303
} else if (isDistinct) {
319304
builder.aggregate(builder.groupKey(keyExprs, keyAliases));
@@ -334,22 +319,28 @@ private boolean isGroupPattern(
334319
List<RexNode> keyExprs,
335320
List<String> keyAliases,
336321
List<RelBuilder.AggCall> aggCalls,
337-
List<RexNode> extraExprs,
338-
List<String> extraAliases) {
322+
AtomicReference<ColumnOrder> columnManagerRef) {
323+
List<ColumnOrder.FieldSupplier> fieldSuppliers = Lists.newArrayList();
339324
for (CypherGSParser.OC_ProjectionItemContext itemCtx :
340325
ctx.oC_ProjectionItems().oC_ProjectionItem()) {
341326
ExprVisitorResult item = expressionVisitor.visitOC_Expression(itemCtx.oC_Expression());
342327
String alias =
343328
(itemCtx.AS() == null) ? null : Utils.getAliasName(itemCtx.oC_Variable());
344329
if (item.getAggCalls().isEmpty()) {
330+
int ordinal = keyExprs.size();
331+
fieldSuppliers.add(new ColumnOrder.FieldSupplier.Default(builder, () -> ordinal));
345332
keyExprs.add(item.getExpr());
346333
keyAliases.add(alias);
347334
} else {
348335
if (item.getExpr() instanceof RexCall) {
349-
extraExprs.add(item.getExpr());
350-
extraAliases.add(alias);
336+
fieldSuppliers.add(
337+
(RelDataType type) -> new ColumnOrder.Field(item.getExpr(), alias));
351338
aggCalls.addAll(item.getAggCalls());
352339
} else if (item.getAggCalls().size() == 1) { // count(a.name)
340+
int ordinal = aggCalls.size();
341+
fieldSuppliers.add(
342+
new ColumnOrder.FieldSupplier.Default(
343+
builder, () -> keyExprs.size() + ordinal));
353344
GraphAggCall original = (GraphAggCall) item.getAggCalls().get(0);
354345
aggCalls.add(
355346
new GraphAggCall(
@@ -363,6 +354,7 @@ private boolean isGroupPattern(
363354
}
364355
}
365356
}
357+
columnManagerRef.set(new ColumnOrder(fieldSuppliers));
366358
return !aggCalls.isEmpty();
367359
}
368360

interactive_engine/compiler/src/test/java/com/alibaba/graphscope/common/ir/planner/cbo/BITest.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,8 @@ public void bi1_test() {
124124
+ " totalMessageCount)], isAppend=[false])\n"
125125
+ " GraphLogicalProject(totalMessageCount=[totalMessageCount], year=[year],"
126126
+ " isComment=[isComment], lengthCategory=[lengthCategory],"
127-
+ " messageCount=[messageCount], sumMessageLength=[sumMessageLength],"
128-
+ " averageMessageLength=[/(EXPR$2, EXPR$3)], isAppend=[false])\n"
127+
+ " messageCount=[messageCount], averageMessageLength=[/(EXPR$2, EXPR$3)],"
128+
+ " sumMessageLength=[sumMessageLength], isAppend=[false])\n"
129129
+ " GraphLogicalAggregate(keys=[{variables=[totalMessageCount, year, $f0,"
130130
+ " $f1], aliases=[totalMessageCount, year, isComment, lengthCategory]}],"
131131
+ " values=[[{operands=[message], aggFunction=COUNT, alias='messageCount',"

interactive_engine/compiler/src/test/java/com/alibaba/graphscope/common/ir/planner/cbo/LdbcTest.java

+12-10
Original file line numberDiff line numberDiff line change
@@ -511,30 +511,32 @@ public void ldbc7_test() {
511511
+ " messageContent=[message.content], messageImageFile=[message.imageFile],"
512512
+ " minutesLatency=[/(/(-(likeTime, message.creationDate), 1000), 60)],"
513513
+ " isNew=[isNew], isAppend=[false])\n"
514-
+ " GraphLogicalAggregate(keys=[{variables=[liker, person, isNew],"
514+
+ " GraphLogicalProject(liker=[liker], person=[person], message=[message],"
515+
+ " likeTime=[likeTime], isNew=[isNew], isAppend=[false])\n"
516+
+ " GraphLogicalAggregate(keys=[{variables=[liker, person, isNew],"
515517
+ " aliases=[liker, person, isNew]}], values=[[{operands=[message],"
516518
+ " aggFunction=FIRST_VALUE, alias='message', distinct=false},"
517519
+ " {operands=[likeTime], aggFunction=FIRST_VALUE, alias='likeTime',"
518520
+ " distinct=false}]])\n"
519-
+ " GraphLogicalSort(sort0=[likeTime], sort1=[message.id], dir0=[DESC],"
521+
+ " GraphLogicalSort(sort0=[likeTime], sort1=[message.id], dir0=[DESC],"
520522
+ " dir1=[ASC])\n"
521-
+ " GraphLogicalProject(liker=[liker], message=[message],"
523+
+ " GraphLogicalProject(liker=[liker], message=[message],"
522524
+ " likeTime=[like.creationDate], person=[person], isNew=[IS NULL(k)],"
523525
+ " isAppend=[false])\n"
524-
+ " MultiJoin(joinFilter=[=(liker, liker)], isFullOuterJoin=[false],"
526+
+ " MultiJoin(joinFilter=[=(liker, liker)], isFullOuterJoin=[false],"
525527
+ " joinTypes=[[INNER, INNER]], outerJoinConditions=[[NULL, NULL]],"
526528
+ " projFields=[[ALL, ALL]])\n"
527-
+ " GraphLogicalGetV(tableConfig=[{isAll=false, tables=[PERSON]}],"
529+
+ " GraphLogicalGetV(tableConfig=[{isAll=false, tables=[PERSON]}],"
528530
+ " alias=[liker], opt=[START])\n"
529-
+ " GraphLogicalExpand(tableConfig=[{isAll=false,"
531+
+ " GraphLogicalExpand(tableConfig=[{isAll=false,"
530532
+ " tables=[LIKES]}], alias=[like], startAlias=[message], opt=[IN])\n"
531-
+ " CommonTableScan(table=[[common#378747223]])\n"
532-
+ " GraphLogicalGetV(tableConfig=[{isAll=false, tables=[PERSON]}],"
533+
+ " CommonTableScan(table=[[common#378747223]])\n"
534+
+ " GraphLogicalGetV(tableConfig=[{isAll=false, tables=[PERSON]}],"
533535
+ " alias=[liker], opt=[OTHER])\n"
534-
+ " GraphLogicalExpand(tableConfig=[{isAll=false,"
536+
+ " GraphLogicalExpand(tableConfig=[{isAll=false,"
535537
+ " tables=[KNOWS]}], alias=[k], startAlias=[person], opt=[BOTH],"
536538
+ " optional=[true])\n"
537-
+ " CommonTableScan(table=[[common#378747223]])\n"
539+
+ " CommonTableScan(table=[[common#378747223]])\n"
538540
+ "common#378747223:\n"
539541
+ "GraphPhysicalExpand(tableConfig=[{isAll=false, tables=[HASCREATOR]}],"
540542
+ " alias=[message], startAlias=[person], opt=[IN], physicalOpt=[VERTEX])\n"

interactive_engine/compiler/src/test/java/com/alibaba/graphscope/cypher/antlr4/MatchTest.java

+18
Original file line numberDiff line numberDiff line change
@@ -731,4 +731,22 @@ public void special_label_name_test() {
731731
+ " alias=[n], opt=[VERTEX])",
732732
after.explain().trim());
733733
}
734+
735+
// the return column order should align with the query given
736+
@Test
737+
public void aggregate_column_order_test() {
738+
GraphBuilder builder =
739+
com.alibaba.graphscope.common.ir.Utils.mockGraphBuilder(optimizer, irMeta);
740+
RelNode node =
741+
Utils.eval("Match (n:person) Return count(n), n, sum(n.age)", builder).build();
742+
RelNode after = optimizer.optimize(node, new GraphIOProcessor(builder, irMeta));
743+
Assert.assertEquals(
744+
"GraphLogicalProject($f1=[$f1], n=[n], $f2=[$f2], isAppend=[false])\n"
745+
+ " GraphLogicalAggregate(keys=[{variables=[n], aliases=[n]}],"
746+
+ " values=[[{operands=[n], aggFunction=COUNT, alias='$f1', distinct=false},"
747+
+ " {operands=[n.age], aggFunction=SUM, alias='$f2', distinct=false}]])\n"
748+
+ " GraphLogicalSource(tableConfig=[{isAll=false, tables=[person]}],"
749+
+ " alias=[n], opt=[VERTEX])",
750+
after.explain().trim());
751+
}
734752
}

0 commit comments

Comments
 (0)