Commit

[BugFix] support SqlToScalarOperatorTranslator visiting lambda functions more than once (#19843) (#19902)

---------

Signed-off-by: Zhuhe Fang <[email protected]>
fzhedu authored Mar 27, 2023
1 parent 9239550 commit 6d57a55
Showing 4 changed files with 19 additions and 18 deletions.
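For context, the change moves the translation cache from the lambda expression to its arguments: LambdaFunctionExpr no longer caches the whole translated LambdaFunctionOperator; instead each LambdaArgument caches the ColumnRefOperator created for it. Translating the same lambda a second time therefore rebuilds the lambda operator against the current expression mapping, while the argument column-ref ids stay stable across visits. The following is a minimal, self-contained sketch of that caching pattern; ColumnRef, ColumnRefFactory, LambdaArg, and LambdaTranslationSketch are simplified stand-ins used for illustration, not the real StarRocks classes.

import java.util.concurrent.atomic.AtomicInteger;

// Simplified stand-in for a column reference with a unique id.
class ColumnRef {
    final int id;
    final String name;

    ColumnRef(int id, String name) {
        this.id = id;
        this.name = name;
    }
}

// Simplified stand-in for the column ref factory: every create() mints a new id.
class ColumnRefFactory {
    private final AtomicInteger nextId = new AtomicInteger(1);

    ColumnRef create(String name) {
        return new ColumnRef(nextId.getAndIncrement(), name);
    }
}

// Simplified stand-in for LambdaArgument: the translated column ref is cached
// on the argument itself, mirroring the transformedOp field added in this diff.
class LambdaArg {
    final String name;
    private ColumnRef transformed;

    LambdaArg(String name) {
        this.name = name;
    }

    ColumnRef transform(ColumnRefFactory factory) {
        if (transformed == null) {          // first visit: allocate a fresh column ref
            transformed = factory.create(name);
        }
        return transformed;                 // later visits: reuse the same ref and id
    }
}

public class LambdaTranslationSketch {
    public static void main(String[] args) {
        ColumnRefFactory factory = new ColumnRefFactory();
        LambdaArg x = new LambdaArg("x");

        ColumnRef first = x.transform(factory);
        ColumnRef second = x.transform(factory); // the enclosing expression is translated again

        // Both visits see the same id, so a lambda body rebuilt on the second
        // pass still lines up with the argument it references.
        System.out.println(first.id == second.id); // prints true
    }
}

The diff below follows the same shape: visitLambdaFunctionExpr now builds a fresh LambdaFunctionOperator on every visit, while visitLambdaArguments returns the ColumnRefOperator cached on the LambdaArgument node.
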
LambdaArgument.java
@@ -8,12 +8,23 @@
import com.starrocks.common.AnalysisException;
import com.starrocks.sql.common.ErrorType;
import com.starrocks.sql.common.StarRocksPlannerException;
import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator;
import com.starrocks.thrift.TExprNode;

public class LambdaArgument extends Expr {
private String name;
private boolean nullable;

ColumnRefOperator transformedOp = null;

public ColumnRefOperator getTransformed() {
return transformedOp;
}

public void setTransformed(ColumnRefOperator op) {
transformedOp = op;
}

public LambdaArgument(String name) {
this.name = name;
}
@@ -22,7 +33,7 @@ public LambdaArgument(LambdaArgument rhs) {
super(rhs);
name = rhs.getName();
}

public String getName() {
return name;
}

LambdaFunctionExpr.java
@@ -7,15 +7,13 @@
import com.starrocks.analysis.Expr;
import com.starrocks.analysis.SlotRef;
import com.starrocks.common.AnalysisException;
import com.starrocks.sql.optimizer.operator.scalar.LambdaFunctionOperator;
import com.starrocks.thrift.TExprNode;
import com.starrocks.thrift.TExprNodeType;

import java.util.List;
import java.util.Map;

public class LambdaFunctionExpr extends Expr {
private LambdaFunctionOperator transformedOp = null;
private int commonSubOperatorNum = 0;

public LambdaFunctionExpr(List<Expr> arguments) {
@@ -37,14 +35,6 @@ public LambdaFunctionExpr(LambdaFunctionExpr rhs) {
super(rhs);
}

public LambdaFunctionOperator getTransformed() {
return transformedOp;
}

public void setTransformed(LambdaFunctionOperator op) {
transformedOp = op;
}

@Override
protected void analyzeImpl(Analyzer analyzer) throws AnalysisException {
Preconditions.checkState(false, "unreachable");

SqlToScalarOperatorTranslator.java
@@ -325,10 +325,6 @@ public ScalarOperator visitArrowExpr(ArrowExpr node, Context context) {
@Override
public ScalarOperator visitLambdaFunctionExpr(LambdaFunctionExpr node,
Context context) {
// To avoid the ids of lambda arguments being different after each visit()
if (node.getTransformed() != null) {
return node.getTransformed();
}
Preconditions.checkArgument(node.getChildren().size() >= 2);
List<ColumnRefOperator> refs = Lists.newArrayList();
List<LambdaArgument> args = Lists.newArrayList();
@@ -341,13 +337,16 @@ public ScalarOperator visitLambdaFunctionExpr(LambdaFunctionExpr node,
expressionMapping = new ExpressionMapping(scope, refs, expressionMapping);
ScalarOperator lambda = visit(node.getChild(0), context.clone(node));
expressionMapping = old; // recover it
node.setTransformed(new LambdaFunctionOperator(refs, lambda, Type.FUNCTION));
return node.getTransformed();
return new LambdaFunctionOperator(refs, lambda, Type.FUNCTION);
}

@Override
public ScalarOperator visitLambdaArguments(LambdaArgument node, Context context) {
return columnRefFactory.create(node.getName(), node.getType(), node.isNullable(), true);
// To avoid the ids of lambda arguments being different after each visit()
if (node.getTransformed() == null) {
node.setTransformed(columnRefFactory.create(node.getName(), node.getType(), node.isNullable(), true));
}
return node.getTransformed();
}

@Override

Analyzer test for lambda functions (file name not shown in this view)
@@ -114,6 +114,7 @@ public void testLambdaFunction() {
analyzeSuccess("select array_map(x-> x + count(v1) over (partition by v1 order by v2),[111]) from tarray");
analyzeSuccess("select v1, v2, count(v1) over (partition by v1 order by v2) from tarray");
analyzeSuccess("select v1, v2, count(v1) over (partition by array_sum(array_map(x->x+1, [1])) order by v2) from tarray");
analyzeSuccess("with x2 as (select array_map((ss) -> ss * v1, v3) from tarray) select * from x2;");

analyzeFail("select array_map(x,y -> x + y, [], [])"); // should be (x,y)
analyzeFail("select array_map((x,y,z) -> x + y, [], [])");
