Skip to content

Commit 64f3d6c

Browse files
Add Support for Field Star in Nested Function (#1773) (#1797)
* Add Support for Field Star in Nested Function. * Removing toString for NestedAllTupleFields. * Adding IT test for nested all fields in invalid clause of SQL statement. * Use utility function for checking is nested in NestedAnalyzer. * Formatting fixes. --------- (cherry picked from commit fa840e0) Signed-off-by: forestmvey <[email protected]> Co-authored-by: Forest Vey <[email protected]>
1 parent 2144ddd commit 64f3d6c

File tree

16 files changed

+577
-28
lines changed

16 files changed

+577
-28
lines changed

core/src/main/java/org/opensearch/sql/analysis/Analyzer.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import java.util.stream.Collectors;
3030
import org.apache.commons.lang3.tuple.ImmutablePair;
3131
import org.apache.commons.lang3.tuple.Pair;
32+
import org.apache.commons.math3.analysis.function.Exp;
3233
import org.opensearch.sql.DataSourceSchemaName;
3334
import org.opensearch.sql.analysis.symbol.Namespace;
3435
import org.opensearch.sql.analysis.symbol.Symbol;
@@ -469,8 +470,13 @@ public LogicalPlan visitSort(Sort node, AnalysisContext context) {
469470
node.getSortList().stream()
470471
.map(
471472
sortField -> {
472-
Expression expression = optimizer.optimize(
473-
expressionAnalyzer.analyze(sortField.getField(), context), context);
473+
var analyzed = expressionAnalyzer.analyze(sortField.getField(), context);
474+
if (analyzed == null) {
475+
throw new UnsupportedOperationException(
476+
String.format("Invalid use of expression %s", sortField.getField())
477+
);
478+
}
479+
Expression expression = optimizer.optimize(analyzed, context);
474480
return ImmutablePair.of(analyzeSortOption(sortField.getFieldArgs()), expression);
475481
})
476482
.collect(Collectors.toList());

core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,16 @@ public Expression visitFunction(Function node, AnalysisContext context) {
186186
FunctionName functionName = FunctionName.of(node.getFuncName());
187187
List<Expression> arguments =
188188
node.getFuncArgs().stream()
189-
.map(unresolvedExpression -> analyze(unresolvedExpression, context))
189+
.map(unresolvedExpression -> {
190+
var ret = analyze(unresolvedExpression, context);
191+
if (ret == null) {
192+
throw new UnsupportedOperationException(
193+
String.format("Invalid use of expression %s", unresolvedExpression)
194+
);
195+
} else {
196+
return ret;
197+
}
198+
})
190199
.collect(Collectors.toList());
191200
return (Expression) repository.compile(context.getFunctionProperties(),
192201
functionName, arguments);

core/src/main/java/org/opensearch/sql/analysis/NestedAnalyzer.java

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import org.opensearch.sql.ast.AbstractNodeVisitor;
1616
import org.opensearch.sql.ast.expression.Alias;
1717
import org.opensearch.sql.ast.expression.Function;
18+
import org.opensearch.sql.ast.expression.NestedAllTupleFields;
1819
import org.opensearch.sql.ast.expression.QualifiedName;
1920
import org.opensearch.sql.ast.expression.UnresolvedExpression;
2021
import org.opensearch.sql.expression.Expression;
@@ -45,6 +46,28 @@ public LogicalPlan visitAlias(Alias node, AnalysisContext context) {
4546
return node.getDelegated().accept(this, context);
4647
}
4748

49+
@Override
50+
public LogicalPlan visitNestedAllTupleFields(NestedAllTupleFields node, AnalysisContext context) {
51+
List<Map<String, ReferenceExpression>> args = new ArrayList<>();
52+
for (NamedExpression namedExpr : namedExpressions) {
53+
if (isNestedFunction(namedExpr.getDelegated())) {
54+
ReferenceExpression field =
55+
(ReferenceExpression) ((FunctionExpression) namedExpr.getDelegated())
56+
.getArguments().get(0);
57+
58+
// If path is same as NestedAllTupleFields path
59+
if (field.getAttr().substring(0, field.getAttr().lastIndexOf("."))
60+
.equalsIgnoreCase(node.getPath())) {
61+
args.add(Map.of(
62+
"field", field,
63+
"path", new ReferenceExpression(node.getPath(), STRING)));
64+
}
65+
}
66+
}
67+
68+
return mergeChildIfLogicalNested(args);
69+
}
70+
4871
@Override
4972
public LogicalPlan visitFunction(Function node, AnalysisContext context) {
5073
if (node.getFuncName().equalsIgnoreCase(BuiltinFunctionName.NESTED.name())) {
@@ -54,6 +77,8 @@ public LogicalPlan visitFunction(Function node, AnalysisContext context) {
5477
ReferenceExpression nestedField =
5578
(ReferenceExpression)expressionAnalyzer.analyze(expressions.get(0), context);
5679
Map<String, ReferenceExpression> args;
80+
81+
// Path parameter is supplied
5782
if (expressions.size() == 2) {
5883
args = Map.of(
5984
"field", nestedField,
@@ -65,16 +90,28 @@ public LogicalPlan visitFunction(Function node, AnalysisContext context) {
6590
"path", generatePath(nestedField.toString())
6691
);
6792
}
68-
if (child instanceof LogicalNested) {
69-
((LogicalNested)child).addFields(args);
70-
return child;
71-
} else {
72-
return new LogicalNested(child, new ArrayList<>(Arrays.asList(args)), namedExpressions);
73-
}
93+
94+
return mergeChildIfLogicalNested(new ArrayList<>(Arrays.asList(args)));
7495
}
7596
return null;
7697
}
7798

99+
/**
100+
* NestedAnalyzer visits all functions in SELECT clause, creates logical plans for each and
101+
* merges them. This is to avoid another merge rule in LogicalPlanOptimizer:create().
102+
* @param args field and path params to add to logical plan.
103+
* @return child of logical nested with added args, or new LogicalNested.
104+
*/
105+
private LogicalPlan mergeChildIfLogicalNested(List<Map<String, ReferenceExpression>> args) {
106+
if (child instanceof LogicalNested) {
107+
for (var arg : args) {
108+
((LogicalNested) child).addFields(arg);
109+
}
110+
return child;
111+
}
112+
return new LogicalNested(child, args, namedExpressions);
113+
}
114+
78115
/**
79116
* Validate each parameter used in nested function in SELECT clause. Any supplied parameter
80117
* for a nested function in a SELECT statement must be a valid qualified name, and the field

core/src/main/java/org/opensearch/sql/analysis/SelectExpressionAnalyzer.java

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,17 @@
1010
import java.util.Collections;
1111
import java.util.List;
1212
import java.util.Map;
13+
import java.util.regex.Pattern;
1314
import java.util.stream.Collectors;
1415
import lombok.RequiredArgsConstructor;
1516
import org.opensearch.sql.analysis.symbol.Namespace;
17+
import org.opensearch.sql.analysis.symbol.Symbol;
1618
import org.opensearch.sql.ast.AbstractNodeVisitor;
1719
import org.opensearch.sql.ast.expression.Alias;
1820
import org.opensearch.sql.ast.expression.AllFields;
1921
import org.opensearch.sql.ast.expression.Field;
22+
import org.opensearch.sql.ast.expression.Function;
23+
import org.opensearch.sql.ast.expression.NestedAllTupleFields;
2024
import org.opensearch.sql.ast.expression.QualifiedName;
2125
import org.opensearch.sql.ast.expression.UnresolvedExpression;
2226
import org.opensearch.sql.data.type.ExprType;
@@ -58,6 +62,11 @@ public List<NamedExpression> visitField(Field node, AnalysisContext context) {
5862

5963
@Override
6064
public List<NamedExpression> visitAlias(Alias node, AnalysisContext context) {
65+
// Expand all nested fields if used in SELECT clause
66+
if (node.getDelegated() instanceof NestedAllTupleFields) {
67+
return node.getDelegated().accept(this, context);
68+
}
69+
6170
Expression expr = referenceIfSymbolDefined(node, context);
6271
return Collections.singletonList(DSL.named(
6372
unqualifiedNameIfFieldOnly(node, context),
@@ -100,6 +109,29 @@ public List<NamedExpression> visitAllFields(AllFields node,
100109
new ReferenceExpression(entry.getKey(), entry.getValue()))).collect(Collectors.toList());
101110
}
102111

112+
@Override
113+
public List<NamedExpression> visitNestedAllTupleFields(NestedAllTupleFields node,
114+
AnalysisContext context) {
115+
TypeEnvironment environment = context.peek();
116+
Map<String, ExprType> lookupAllTupleFields =
117+
environment.lookupAllTupleFields(Namespace.FIELD_NAME);
118+
environment.resolve(new Symbol(Namespace.FIELD_NAME, node.getPath()));
119+
120+
// Match all fields with same path as used in nested function.
121+
Pattern p = Pattern.compile(node.getPath() + "\\.[^\\.]+$");
122+
return lookupAllTupleFields.entrySet().stream()
123+
.filter(field -> p.matcher(field.getKey()).find())
124+
.map(entry -> {
125+
Expression nestedFunc = new Function(
126+
"nested",
127+
List.of(
128+
new QualifiedName(List.of(entry.getKey().split("\\."))))
129+
).accept(expressionAnalyzer, context);
130+
return DSL.named("nested(" + entry.getKey() + ")", nestedFunc);
131+
})
132+
.collect(Collectors.toList());
133+
}
134+
103135
/**
104136
* Get unqualified name if select item is just a field. For example, suppose an index
105137
* named "accounts", return "age" for "SELECT accounts.age". But do nothing for expression

core/src/main/java/org/opensearch/sql/analysis/TypeEnvironment.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,17 @@ public Map<String, ExprType> lookupAllFields(Namespace namespace) {
8585
return result;
8686
}
8787

88+
/**
89+
* Resolve all fields in the current environment.
90+
* @param namespace a namespace
91+
* @return all symbols in the namespace
92+
*/
93+
public Map<String, ExprType> lookupAllTupleFields(Namespace namespace) {
94+
Map<String, ExprType> result = new LinkedHashMap<>();
95+
symbolTable.lookupAllTupleFields(namespace).forEach(result::putIfAbsent);
96+
return result;
97+
}
98+
8899
/**
89100
* Define symbol with the type.
90101
*

core/src/main/java/org/opensearch/sql/analysis/symbol/SymbolTable.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,21 @@ public Map<String, ExprType> lookupAllFields(Namespace namespace) {
128128
return results;
129129
}
130130

131+
/**
132+
* Look up all top level symbols in the namespace.
133+
*
134+
* @param namespace a namespace
135+
* @return all symbols in the namespace map
136+
*/
137+
public Map<String, ExprType> lookupAllTupleFields(Namespace namespace) {
138+
final LinkedHashMap<String, ExprType> allSymbols =
139+
orderedTable.getOrDefault(namespace, new LinkedHashMap<>());
140+
final LinkedHashMap<String, ExprType> result = new LinkedHashMap<>();
141+
allSymbols.entrySet().stream()
142+
.forEach(entry -> result.put(entry.getKey(), entry.getValue()));
143+
return result;
144+
}
145+
131146
/**
132147
* Check if namespace map in empty (none definition).
133148
*

core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.opensearch.sql.ast.expression.Let;
2626
import org.opensearch.sql.ast.expression.Literal;
2727
import org.opensearch.sql.ast.expression.Map;
28+
import org.opensearch.sql.ast.expression.NestedAllTupleFields;
2829
import org.opensearch.sql.ast.expression.Not;
2930
import org.opensearch.sql.ast.expression.Or;
3031
import org.opensearch.sql.ast.expression.QualifiedName;
@@ -238,6 +239,10 @@ public T visitAllFields(AllFields node, C context) {
238239
return visitChildren(node, context);
239240
}
240241

242+
public T visitNestedAllTupleFields(NestedAllTupleFields node, C context) {
243+
return visitChildren(node, context);
244+
}
245+
241246
public T visitInterval(Interval node, C context) {
242247
return visitChildren(node, context);
243248
}

core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import org.opensearch.sql.ast.expression.Let;
3131
import org.opensearch.sql.ast.expression.Literal;
3232
import org.opensearch.sql.ast.expression.Map;
33+
import org.opensearch.sql.ast.expression.NestedAllTupleFields;
3334
import org.opensearch.sql.ast.expression.Not;
3435
import org.opensearch.sql.ast.expression.Or;
3536
import org.opensearch.sql.ast.expression.ParseMethod;
@@ -377,6 +378,10 @@ public Alias alias(String name, UnresolvedExpression expr, String alias) {
377378
return new Alias(name, expr, alias);
378379
}
379380

381+
public NestedAllTupleFields nestedAllTupleFields(String path) {
382+
return new NestedAllTupleFields(path);
383+
}
384+
380385
public static List<UnresolvedExpression> exprList(UnresolvedExpression... exprList) {
381386
return Arrays.asList(exprList);
382387
}
core/src/main/java/org/opensearch/sql/ast/expression/NestedAllTupleFields.java

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
7+
package org.opensearch.sql.ast.expression;
8+
9+
import java.util.Collections;
10+
import java.util.List;
11+
import lombok.EqualsAndHashCode;
12+
import lombok.Getter;
13+
import lombok.RequiredArgsConstructor;
14+
import lombok.ToString;
15+
import org.opensearch.sql.ast.AbstractNodeVisitor;
16+
import org.opensearch.sql.ast.Node;
17+
18+
/**
19+
* Represents all tuple fields used in nested function.
20+
*/
21+
@RequiredArgsConstructor
22+
@EqualsAndHashCode(callSuper = false)
23+
public class NestedAllTupleFields extends UnresolvedExpression {
24+
@Getter
25+
private final String path;
26+
27+
@Override
28+
public List<? extends Node> getChild() {
29+
return Collections.emptyList();
30+
}
31+
32+
@Override
33+
public <R, C> R accept(AbstractNodeVisitor<R, C> nodeVisitor, C context) {
34+
return nodeVisitor.visitNestedAllTupleFields(this, context);
35+
}
36+
37+
@Override
38+
public String toString() {
39+
return String.format("nested(%s.*)", path);
40+
}
41+
}

0 commit comments

Comments
 (0)