Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

Commit

Permalink
Support CASE clause in new engine (#818)
Browse files Browse the repository at this point in the history
* Change grammar and UT

* Add case expression

* Pass jacoco in sql module

* Pass jacoco in core module and refactor when clause

* Fix broken IT and add javadoc

* Add comparison test for basic case

* Add comparison test for complex case

* Add doctest

* Add type check

* Prepare PR

* Prepare PR

* Address PR comments
  • Loading branch information
dai-chen authored and penghuo committed Dec 15, 2020
1 parent cdc9658 commit b618c63
Show file tree
Hide file tree
Showing 21 changed files with 833 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import com.amazon.opendistroforelasticsearch.sql.ast.expression.AggregateFunction;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.AllFields;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.And;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.Case;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.Compare;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.EqualTo;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.Field;
Expand All @@ -32,6 +33,7 @@
import com.amazon.opendistroforelasticsearch.sql.ast.expression.QualifiedName;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.UnresolvedAttribute;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.UnresolvedExpression;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.When;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.WindowFunction;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.Xor;
import com.amazon.opendistroforelasticsearch.sql.common.antlr.SyntaxCheckException;
Expand All @@ -42,9 +44,13 @@
import com.amazon.opendistroforelasticsearch.sql.expression.Expression;
import com.amazon.opendistroforelasticsearch.sql.expression.ReferenceExpression;
import com.amazon.opendistroforelasticsearch.sql.expression.aggregation.Aggregator;
import com.amazon.opendistroforelasticsearch.sql.expression.conditional.cases.CaseClause;
import com.amazon.opendistroforelasticsearch.sql.expression.conditional.cases.WhenClause;
import com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionName;
import com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionRepository;
import com.amazon.opendistroforelasticsearch.sql.expression.function.FunctionName;
import com.google.common.collect.ImmutableSet;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
Expand Down Expand Up @@ -163,6 +169,43 @@ public Expression visitCompare(Compare node, AnalysisContext context) {
repository.compile(functionName, Arrays.asList(left, right));
}

@Override
public Expression visitCase(Case node, AnalysisContext context) {
List<WhenClause> whens = new ArrayList<>();
for (When when : node.getWhenClauses()) {
if (node.getCaseValue() == null) {
whens.add((WhenClause) analyze(when, context));
} else {
// Merge case value and condition (compare value) into a single equal condition
whens.add((WhenClause) analyze(
new When(
new Function("=", Arrays.asList(node.getCaseValue(), when.getCondition())),
when.getResult()
), context));
}
}

Expression defaultResult = (node.getElseClause() == null)
? null : analyze(node.getElseClause(), context);
CaseClause caseClause = new CaseClause(whens, defaultResult);

// To make this simple, require all result type same regardless of implicit convert
// Make CaseClause return list so it can be used in error message in determined order
List<ExprType> resultTypes = caseClause.allResultTypes();
if (ImmutableSet.copyOf(resultTypes).size() > 1) {
throw new SemanticCheckException(
"All result types of CASE clause must be the same, but found " + resultTypes);
}
return caseClause;
}

@Override
public Expression visitWhen(When node, AnalysisContext context) {
return new WhenClause(
analyze(node.getCondition(), context),
analyze(node.getResult(), context));
}

@Override
public Expression visitField(Field node, AnalysisContext context) {
String attr = node.getField().toString();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import com.amazon.opendistroforelasticsearch.sql.ast.expression.And;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.Argument;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.AttributeList;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.Case;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.Compare;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.EqualTo;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.Field;
Expand All @@ -35,6 +36,7 @@
import com.amazon.opendistroforelasticsearch.sql.ast.expression.QualifiedName;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.UnresolvedArgument;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.UnresolvedAttribute;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.When;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.WindowFunction;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.Xor;
import com.amazon.opendistroforelasticsearch.sql.ast.tree.Aggregation;
Expand Down Expand Up @@ -210,6 +212,14 @@ public T visitInterval(Interval node, C context) {
return visitChildren(node, context);
}

public T visitCase(Case node, C context) {
return visitChildren(node, context);
}

public T visitWhen(When node, C context) {
return visitChildren(node, context);
}

public T visitUnresolvedArgument(UnresolvedArgument node, C context) {
return visitChildren(node, context);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.amazon.opendistroforelasticsearch.sql.ast.expression.Alias;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.And;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.Argument;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.Case;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.Compare;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.DataType;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.EqualTo;
Expand All @@ -35,6 +36,7 @@
import com.amazon.opendistroforelasticsearch.sql.ast.expression.UnresolvedArgument;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.UnresolvedAttribute;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.UnresolvedExpression;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.When;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.WindowFunction;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.Xor;
import com.amazon.opendistroforelasticsearch.sql.ast.tree.Aggregation;
Expand Down Expand Up @@ -187,6 +189,35 @@ public static Function function(String funcName, UnresolvedExpression... funcArg
return new Function(funcName, Arrays.asList(funcArgs));
}

/**
* CASE
* WHEN search_condition THEN result_expr
* [WHEN search_condition THEN result_expr] ...
* [ELSE result_expr]
* END
*/
public UnresolvedExpression caseWhen(UnresolvedExpression elseClause,
When... whenClauses) {
return caseWhen(null, elseClause, whenClauses);
}

/**
* CASE case_value_expr
* WHEN compare_expr THEN result_expr
* [WHEN compare_expr THEN result_expr] ...
* [ELSE result_expr]
* END
*/
public UnresolvedExpression caseWhen(UnresolvedExpression caseValueExpr,
UnresolvedExpression elseClause,
When... whenClauses) {
return new Case(caseValueExpr, Arrays.asList(whenClauses), elseClause);
}

public When when(UnresolvedExpression condition, UnresolvedExpression result) {
return new When(condition, result);
}

public UnresolvedExpression window(Function function,
List<UnresolvedExpression> partitionByList,
List<Pair<String, UnresolvedExpression>> sortList) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*
*/

package com.amazon.opendistroforelasticsearch.sql.ast.expression;

import com.amazon.opendistroforelasticsearch.sql.ast.AbstractNodeVisitor;
import com.amazon.opendistroforelasticsearch.sql.ast.Node;
import com.google.common.collect.ImmutableList;
import java.util.List;
import lombok.AllArgsConstructor;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.ToString;

/**
* AST node that represents CASE clause similar as Switch statement in programming language.
*/
@AllArgsConstructor
@EqualsAndHashCode(callSuper = false)
@Getter
@ToString
public class Case extends UnresolvedExpression {

/**
* Value to be compared by WHEN statements. Null in the case of CASE WHEN conditions.
*/
private final UnresolvedExpression caseValue;

/**
* Expression list that represents WHEN statements. Each is a mapping from condition
* to its result.
*/
private final List<When> whenClauses;

/**
* Expression that represents ELSE statement result.
*/
private final UnresolvedExpression elseClause;

@Override
public List<? extends Node> getChild() {
ImmutableList.Builder<Node> children = ImmutableList.builder();
if (caseValue != null) {
children.add(caseValue);
}
children.addAll(whenClauses);

if (elseClause != null) {
children.add(elseClause);
}
return children.build();
}

@Override
public <T, C> T accept(AbstractNodeVisitor<T, C> nodeVisitor, C context) {
return nodeVisitor.visitCase(this, context);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*
*/

package com.amazon.opendistroforelasticsearch.sql.ast.expression;

import com.amazon.opendistroforelasticsearch.sql.ast.AbstractNodeVisitor;
import com.amazon.opendistroforelasticsearch.sql.ast.Node;
import com.google.common.collect.ImmutableList;
import java.util.List;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.ToString;

/**
* AST node that represents WHEN clause.
*/
@EqualsAndHashCode(callSuper = false)
@Getter
@RequiredArgsConstructor
@ToString
public class When extends UnresolvedExpression {

/**
* WHEN condition, either a search condition or compare value if case value present.
*/
private final UnresolvedExpression condition;

/**
* Result to return if condition matched.
*/
private final UnresolvedExpression result;

@Override
public List<? extends Node> getChild() {
return ImmutableList.of(condition, result);
}

@Override
public <T, C> T accept(AbstractNodeVisitor<T, C> nodeVisitor, C context) {
return nodeVisitor.visitWhen(this, context);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import com.amazon.opendistroforelasticsearch.sql.data.type.ExprType;
import com.amazon.opendistroforelasticsearch.sql.expression.aggregation.Aggregator;
import com.amazon.opendistroforelasticsearch.sql.expression.aggregation.NamedAggregator;
import com.amazon.opendistroforelasticsearch.sql.expression.conditional.cases.CaseClause;
import com.amazon.opendistroforelasticsearch.sql.expression.conditional.cases.WhenClause;
import com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionName;
import com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionRepository;
import com.amazon.opendistroforelasticsearch.sql.expression.window.ranking.RankingWindowFunction;
Expand Down Expand Up @@ -508,6 +510,15 @@ public FunctionExpression isnotnull(Expression... expressions) {
return function(BuiltinFunctionName.IS_NOT_NULL, expressions);
}

public static Expression cases(Expression defaultResult,
WhenClause... whenClauses) {
return new CaseClause(Arrays.asList(whenClauses), defaultResult);
}

public static WhenClause when(Expression condition, Expression result) {
return new WhenClause(condition, result);
}

public FunctionExpression interval(Expression value, Expression unit) {
return (FunctionExpression) repository.compile(
BuiltinFunctionName.INTERVAL.getName(), Arrays.asList(value, unit));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

import com.amazon.opendistroforelasticsearch.sql.expression.aggregation.Aggregator;
import com.amazon.opendistroforelasticsearch.sql.expression.aggregation.NamedAggregator;
import com.amazon.opendistroforelasticsearch.sql.expression.conditional.cases.CaseClause;
import com.amazon.opendistroforelasticsearch.sql.expression.conditional.cases.WhenClause;
import com.amazon.opendistroforelasticsearch.sql.expression.function.FunctionImplementation;

/**
Expand Down Expand Up @@ -78,4 +80,13 @@ public T visitAggregator(Aggregator<?> node, C context) {
public T visitNamedAggregator(NamedAggregator node, C context) {
return visitChildren(node, context);
}

public T visitCase(CaseClause node, C context) {
return visitNode(node, context);
}

public T visitWhen(WhenClause node, C context) {
return visitNode(node, context);
}

}
Loading

0 comments on commit b618c63

Please sign in to comment.