Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 91 additions & 48 deletions core/src/main/java/org/opensearch/sql/analysis/Analyzer.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.lang3.tuple.ImmutablePair;
Expand Down Expand Up @@ -285,13 +286,6 @@ public LogicalPlan visitFilter(Filter node, AnalysisContext context) {
return new LogicalFilter(child, optimized);
}

/**
* Ensure NESTED function is not used in GROUP BY, and HAVING clauses. Fallback to legacy engine.
* Can remove when support is added for NESTED function in WHERE, GROUP BY, ORDER BY, and HAVING
* clauses.
*
* @param condition : Filter condition
*/
private void verifySupportsCondition(Expression condition) {
if (condition instanceof FunctionExpression) {
if (((FunctionExpression) condition)
Expand Down Expand Up @@ -387,53 +381,106 @@ public LogicalPlan visitRareTopN(RareTopN node, AnalysisContext context) {
public LogicalPlan visitProject(Project node, AnalysisContext context) {
LogicalPlan child = node.getChild().get(0).accept(this, context);

if (node.hasArgument()) {
Argument argument = node.getArgExprList().get(0);
Boolean exclude = (Boolean) argument.getValue().getValue();
if (exclude) {
TypeEnvironment curEnv = context.peek();
List<ReferenceExpression> referenceExpressions =
node.getProjectList().stream()
.map(expr -> (ReferenceExpression) expressionAnalyzer.analyze(expr, context))
.collect(Collectors.toList());
referenceExpressions.forEach(ref -> curEnv.remove(ref));
return new LogicalRemove(child, ImmutableSet.copyOf(referenceExpressions));
}
if (isExcludeMode(node)) {
return buildLogicalRemove(node, child, context);
}

// For each unresolved window function, analyze it by "insert" a window and sort operator
// between project and its child.
for (UnresolvedExpression expr : node.getProjectList()) {
WindowExpressionAnalyzer windowAnalyzer =
new WindowExpressionAnalyzer(expressionAnalyzer, child);
child = windowAnalyzer.analyze(expr, context);
}

for (UnresolvedExpression expr : node.getProjectList()) {
HighlightAnalyzer highlightAnalyzer = new HighlightAnalyzer(expressionAnalyzer, child);
child = highlightAnalyzer.analyze(expr, context);
}
child = processWindowExpressions(node.getProjectList(), child, context);
child = processHighlightExpressions(node.getProjectList(), child, context);

List<NamedExpression> namedExpressions =
selectExpressionAnalyzer.analyze(
node.getProjectList(),
context,
new ExpressionReferenceOptimizer(expressionAnalyzer.getRepository(), child));

for (UnresolvedExpression expr : node.getProjectList()) {
NestedAnalyzer nestedAnalyzer =
new NestedAnalyzer(namedExpressions, expressionAnalyzer, child);
child = nestedAnalyzer.analyze(expr, context);
}
resolveFieldExpressions(node.getProjectList(), child, context);

child = processNestedAnalysis(node.getProjectList(), namedExpressions, child, context);

// new context
context.push();
TypeEnvironment newEnv = context.peek();
namedExpressions.forEach(
expr ->
newEnv.define(new Symbol(Namespace.FIELD_NAME, expr.getNameOrAlias()), expr.type()));
List<NamedExpression> namedParseExpressions = context.getNamedParseExpressions();
return new LogicalProject(child, namedExpressions, namedParseExpressions);

return new LogicalProject(child, namedExpressions, context.getNamedParseExpressions());
}

private boolean isExcludeMode(Project node) {
if (!node.hasArgument()) {
return false;
}
try {
Argument argument = node.getArgExprList().get(0);
Object value = argument.getValue().getValue();
return Boolean.TRUE.equals(value);
} catch (IndexOutOfBoundsException | NullPointerException e) {
return false;
}
}

private LogicalRemove buildLogicalRemove(
Project node, LogicalPlan child, AnalysisContext context) {
TypeEnvironment curEnv = context.peek();
List<ReferenceExpression> referenceExpressions =
collectExclusionFields(node.getProjectList(), context);

Set<String> allFields = curEnv.lookupAllFields(Namespace.FIELD_NAME).keySet();
Set<String> fieldsToExclude =
referenceExpressions.stream().map(ReferenceExpression::getAttr).collect(Collectors.toSet());

if (allFields.equals(fieldsToExclude)) {
throw new IllegalArgumentException(
"Invalid field exclusion: operation would exclude all fields from the result set");
}

referenceExpressions.forEach(curEnv::remove);
return new LogicalRemove(child, ImmutableSet.copyOf(referenceExpressions));
}

private LogicalPlan processWindowExpressions(
List<UnresolvedExpression> projectList, LogicalPlan child, AnalysisContext context) {
for (UnresolvedExpression expr : projectList) {
child = new WindowExpressionAnalyzer(expressionAnalyzer, child).analyze(expr, context);
}
return child;
}

private LogicalPlan processHighlightExpressions(
List<UnresolvedExpression> projectList, LogicalPlan child, AnalysisContext context) {
for (UnresolvedExpression expr : projectList) {
child = new HighlightAnalyzer(expressionAnalyzer, child).analyze(expr, context);
}
return child;
}

private List<NamedExpression> resolveFieldExpressions(
List<UnresolvedExpression> projectList, LogicalPlan child, AnalysisContext context) {
return selectExpressionAnalyzer.analyze(
projectList,
context,
new ExpressionReferenceOptimizer(expressionAnalyzer.getRepository(), child));
}

private LogicalPlan processNestedAnalysis(
List<UnresolvedExpression> projectList,
List<NamedExpression> namedExpressions,
LogicalPlan child,
AnalysisContext context) {
for (UnresolvedExpression expr : projectList) {
child =
new NestedAnalyzer(namedExpressions, expressionAnalyzer, child).analyze(expr, context);
}
return child;
}

private List<ReferenceExpression> collectExclusionFields(
List<UnresolvedExpression> projectList, AnalysisContext context) {
List<NamedExpression> namedExpressions =
projectList.stream()
.map(expr -> expressionAnalyzer.analyze(expr, context))
.map(DSL::named)
.collect(Collectors.toList());

return namedExpressions.stream()
.map(field -> (ReferenceExpression) field.getDelegated())
.collect(Collectors.toList());
}

/** Build {@link LogicalEval}. */
Expand Down Expand Up @@ -745,10 +792,6 @@ private LogicalSort buildSort(
return new LogicalSort(child, count, sortList);
}

/**
* The first argument is always "asc", others are optional. Given nullFirst argument, use its
* value. Otherwise just use DEFAULT_ASC/DESC.
*/
private SortOption analyzeSortOption(List<Argument> fieldArgs) {
Boolean asc = (Boolean) fieldArgs.get(0).getValue().getValue();
Optional<Argument> nullFirst =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@
import org.opensearch.sql.calcite.utils.JoinAndLookupUtils;
import org.opensearch.sql.calcite.utils.PlanUtils;
import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils;
import org.opensearch.sql.calcite.utils.WildcardUtils;
import org.opensearch.sql.common.patterns.PatternUtils;
import org.opensearch.sql.common.utils.StringUtils;
import org.opensearch.sql.exception.CalciteUnsupportedException;
Expand Down Expand Up @@ -192,31 +193,116 @@ private boolean containsSubqueryExpression(Node expr) {
@Override
public RelNode visitProject(Project node, CalcitePlanContext context) {
visitChildren(node, context);
List<RexNode> projectList;
if (node.getProjectList().size() == 1
&& node.getProjectList().get(0) instanceof AllFields) {
AllFields allFields = (AllFields) node.getProjectList().get(0);
tryToRemoveNestedFields(context);
tryToRemoveMetaFields(context, allFields instanceof AllFieldsExcludeMeta);
return context.relBuilder.peek();
} else {
projectList =
node.getProjectList().stream()
.map(expr -> rexVisitor.analyze(expr, context))
.collect(Collectors.toList());

if (isSingleAllFieldsProject(node)) {
return handleAllFieldsProject(node, context);
}

List<String> currentFields = context.relBuilder.peek().getRowType().getFieldNames();
List<RexNode> expandedFields =
expandProjectFields(node.getProjectList(), currentFields, context);

if (node.isExcluded()) {
context.relBuilder.projectExcept(projectList);
validateExclusion(expandedFields, currentFields);
context.relBuilder.projectExcept(expandedFields);
} else {
// Only set when not resolving subquery and it's not projectExcept.
if (!context.isResolvingSubquery()) {
context.setProjectVisited(true);
}
context.relBuilder.project(projectList);
context.relBuilder.project(expandedFields);
}
return context.relBuilder.peek();
}

private boolean isSingleAllFieldsProject(Project node) {
return node.getProjectList().size() == 1
&& node.getProjectList().get(0) instanceof AllFields;
}

private RelNode handleAllFieldsProject(Project node, CalcitePlanContext context) {
if (node.isExcluded()) {
throw new IllegalArgumentException(
"Invalid field exclusion: operation would exclude all fields from the result set");
}
AllFields allFields = (AllFields) node.getProjectList().get(0);
tryToRemoveNestedFields(context);
tryToRemoveMetaFields(context, allFields instanceof AllFieldsExcludeMeta);
return context.relBuilder.peek();
}

private List<RexNode> expandProjectFields(
List<UnresolvedExpression> projectList,
List<String> currentFields,
CalcitePlanContext context) {
List<RexNode> expandedFields = new ArrayList<>();
Set<String> addedFields = new HashSet<>();

for (UnresolvedExpression expr : projectList) {
if (expr instanceof Field) {
Field field = (Field) expr;
String fieldName = field.getField().toString();
if (WildcardUtils.containsWildcard(fieldName)) {
List<String> matchingFields =
WildcardUtils.expandWildcardPattern(fieldName, currentFields).stream()
.filter(f -> !isMetadataField(f))
.filter(addedFields::add)
.collect(Collectors.toList());
if (matchingFields.isEmpty()) {
continue;
}
matchingFields.forEach(f -> expandedFields.add(context.relBuilder.field(f)));
} else if (addedFields.add(fieldName)) {
expandedFields.add(rexVisitor.analyze(field, context));
}
} else if (expr instanceof AllFields) {
currentFields.stream()
.filter(field -> !isMetadataField(field))
.filter(addedFields::add)
.forEach(field -> expandedFields.add(context.relBuilder.field(field)));
} else {
throw new IllegalStateException(
"Unexpected expression type in project list: " + expr.getClass().getSimpleName());
}
}

if (expandedFields.isEmpty()) {
validateWildcardPatterns(projectList, currentFields);
}

return expandedFields;
}

private void validateExclusion(List<RexNode> fieldsToExclude, List<String> currentFields) {
Set<String> nonMetaFields =
currentFields.stream().filter(field -> !isMetadataField(field)).collect(Collectors.toSet());

if (fieldsToExclude.size() >= nonMetaFields.size()) {
throw new IllegalArgumentException(
"Invalid field exclusion: operation would exclude all fields from the result set");
}
}

private void validateWildcardPatterns(
List<UnresolvedExpression> projectList, List<String> currentFields) {
String firstWildcardPattern =
projectList.stream()
.filter(expr -> expr instanceof Field)
.map(expr -> (Field) expr)
.filter(field -> WildcardUtils.containsWildcard(field.getField().toString()))
.map(field -> field.getField().toString())
.findFirst()
.orElse(null);

if (firstWildcardPattern != null) {
throw new IllegalArgumentException(
String.format("wildcard pattern [%s] matches no fields", firstWildcardPattern));
}
}

private boolean isMetadataField(String fieldName) {
return OpenSearchConstants.METADATAFIELD_TYPE_MAP.containsKey(fieldName);
}

/** See logic in {@link org.opensearch.sql.analysis.symbol.SymbolTable#lookupAllFields} */
private static void tryToRemoveNestedFields(CalcitePlanContext context) {
Set<String> allFields = new HashSet<>(context.relBuilder.peek().getRowType().getFieldNames());
Expand Down Expand Up @@ -503,7 +589,6 @@ public RelNode visitPatterns(Patterns node, CalcitePlanContext context) {
@Override
public RelNode visitEval(Eval node, CalcitePlanContext context) {
visitChildren(node, context);
List<String> originalFieldNames = context.relBuilder.peek().getRowType().getFieldNames();
node.getExpressionList()
.forEach(
expr -> {
Expand Down
Loading
Loading