diff --git a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java index 9054f80e93d..dc12bdab73f 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java @@ -292,6 +292,12 @@ public LogicalPlan visitProject(Project node, AnalysisContext context) { child = windowAnalyzer.analyze(expr, context); } + for (UnresolvedExpression expr : node.getProjectList()) { + HighlightAnalyzer highlightAnalyzer = new HighlightAnalyzer(expressionAnalyzer, child); + child = highlightAnalyzer.analyze(expr, context); + + } + List namedExpressions = selectExpressionAnalyzer.analyze(node.getProjectList(), context, new ExpressionReferenceOptimizer(expressionAnalyzer.getRepository(), child)); diff --git a/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java b/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java index 91b964d0749..670da5c85ce 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java @@ -11,9 +11,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; -import java.util.LinkedHashMap; import java.util.List; -import java.util.Map; import java.util.Optional; import java.util.stream.Collectors; import lombok.Getter; @@ -29,6 +27,7 @@ import org.opensearch.sql.ast.expression.EqualTo; import org.opensearch.sql.ast.expression.Field; import org.opensearch.sql.ast.expression.Function; +import org.opensearch.sql.ast.expression.HighlightFunction; import org.opensearch.sql.ast.expression.In; import org.opensearch.sql.ast.expression.Interval; import org.opensearch.sql.ast.expression.Literal; @@ -44,12 +43,12 @@ import org.opensearch.sql.ast.expression.WindowFunction; import org.opensearch.sql.ast.expression.Xor; import org.opensearch.sql.common.antlr.SyntaxCheckException; -import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValueUtils; import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.exception.SemanticCheckException; import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.HighlightExpression; import org.opensearch.sql.expression.LiteralExpression; import org.opensearch.sql.expression.NamedArgumentExpression; import org.opensearch.sql.expression.NamedExpression; @@ -191,6 +190,12 @@ public Expression visitWindowFunction(WindowFunction node, AnalysisContext conte return expr; } + @Override + public Expression visitHighlight(HighlightFunction node, AnalysisContext context) { + Expression expr = node.getHighlightField().accept(this, context); + return new HighlightExpression(expr); + } + @Override public Expression visitIn(In node, AnalysisContext context) { return visitIn(node.getField(), node.getValueList(), context); diff --git a/core/src/main/java/org/opensearch/sql/analysis/HighlightAnalyzer.java b/core/src/main/java/org/opensearch/sql/analysis/HighlightAnalyzer.java new file mode 100644 index 00000000000..06a601327ce --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/analysis/HighlightAnalyzer.java @@ -0,0 +1,41 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.analysis; + +import lombok.RequiredArgsConstructor; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.expression.Alias; +import org.opensearch.sql.ast.expression.HighlightFunction; +import org.opensearch.sql.ast.expression.UnresolvedExpression; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.planner.logical.LogicalHighlight; +import org.opensearch.sql.planner.logical.LogicalPlan; + +/** + * Analyze the highlight in the {@link AnalysisContext} to construct the {@link + * LogicalPlan}. + */ +@RequiredArgsConstructor +public class HighlightAnalyzer extends AbstractNodeVisitor { + private final ExpressionAnalyzer expressionAnalyzer; + private final LogicalPlan child; + + public LogicalPlan analyze(UnresolvedExpression projectItem, AnalysisContext context) { + LogicalPlan highlight = projectItem.accept(this, context); + return (highlight == null) ? child : highlight; + } + + @Override + public LogicalPlan visitAlias(Alias node, AnalysisContext context) { + if (!(node.getDelegated() instanceof HighlightFunction)) { + return null; + } + + HighlightFunction unresolved = (HighlightFunction) node.getDelegated(); + Expression field = expressionAnalyzer.analyze(unresolved.getHighlightField(), context); + return new LogicalHighlight(child, field); + } +} diff --git a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index 37163821e32..17321bc4739 100644 --- a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -18,6 +18,7 @@ import org.opensearch.sql.ast.expression.EqualTo; import org.opensearch.sql.ast.expression.Field; import org.opensearch.sql.ast.expression.Function; +import org.opensearch.sql.ast.expression.HighlightFunction; import org.opensearch.sql.ast.expression.In; import org.opensearch.sql.ast.expression.Interval; import org.opensearch.sql.ast.expression.Let; @@ -254,4 +255,8 @@ public T visitKmeans(Kmeans node, C context) { public T visitAD(AD node, C context) { return visitChildren(node, context); } + + public T visitHighlight(HighlightFunction node, C context) { + return visitChildren(node, context); + } } diff --git a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java index 4056347d44f..510482c6455 100644 --- a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java +++ b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java @@ -22,6 +22,7 @@ import org.opensearch.sql.ast.expression.EqualTo; import org.opensearch.sql.ast.expression.Field; import org.opensearch.sql.ast.expression.Function; +import org.opensearch.sql.ast.expression.HighlightFunction; import org.opensearch.sql.ast.expression.In; import org.opensearch.sql.ast.expression.Interval; import org.opensearch.sql.ast.expression.Let; @@ -261,6 +262,10 @@ public When when(UnresolvedExpression condition, UnresolvedExpression result) { return new When(condition, result); } + public UnresolvedExpression highlight(UnresolvedExpression fieldName) { + return new HighlightFunction(fieldName); + } + public UnresolvedExpression window(UnresolvedExpression function, List partitionByList, List> sortList) { diff --git a/core/src/main/java/org/opensearch/sql/ast/expression/HighlightFunction.java b/core/src/main/java/org/opensearch/sql/ast/expression/HighlightFunction.java new file mode 100644 index 00000000000..5f1bb652d95 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/ast/expression/HighlightFunction.java @@ -0,0 +1,34 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + +import java.util.List; +import lombok.AllArgsConstructor; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; + +/** + * Expression node of Highlight function. + */ +@AllArgsConstructor +@EqualsAndHashCode(callSuper = false) +@Getter +@ToString +public class HighlightFunction extends UnresolvedExpression { + private final UnresolvedExpression highlightField; + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitHighlight(this, context); + } + + @Override + public List getChild() { + return List.of(highlightField); + } +} diff --git a/core/src/main/java/org/opensearch/sql/expression/ExpressionNodeVisitor.java b/core/src/main/java/org/opensearch/sql/expression/ExpressionNodeVisitor.java index b05b0924a8e..d53371dd58f 100644 --- a/core/src/main/java/org/opensearch/sql/expression/ExpressionNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/expression/ExpressionNodeVisitor.java @@ -55,6 +55,10 @@ public T visitNamed(NamedExpression node, C context) { return node.getDelegated().accept(this, context); } + public T visitHighlight(HighlightExpression node, C context) { + return visitNode(node, context); + } + public T visitReference(ReferenceExpression node, C context) { return visitNode(node, context); } diff --git a/core/src/main/java/org/opensearch/sql/expression/HighlightExpression.java b/core/src/main/java/org/opensearch/sql/expression/HighlightExpression.java new file mode 100644 index 00000000000..97456961114 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/expression/HighlightExpression.java @@ -0,0 +1,57 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression; + +import java.util.List; +import lombok.Getter; +import org.opensearch.sql.common.utils.StringUtils; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.type.ExprCoreType; +import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.expression.env.Environment; +import org.opensearch.sql.expression.function.BuiltinFunctionName; + +/** + * Highlight Expression. + */ +@Getter +public class HighlightExpression extends FunctionExpression { + private final Expression highlightField; + + /** + * HighlightExpression Constructor. + * @param highlightField : Highlight field for expression. + */ + public HighlightExpression(Expression highlightField) { + super(BuiltinFunctionName.HIGHLIGHT.getName(), List.of(highlightField)); + this.highlightField = highlightField; + } + + /** + * Return collection value matching highlight field. + * @param valueEnv : Dataset to parse value from. + * @return : collection value of highlight fields. + */ + @Override + public ExprValue valueOf(Environment valueEnv) { + String refName = "_highlight" + "." + StringUtils.unquoteText(getHighlightField().toString()); + return valueEnv.resolve(DSL.ref(refName, ExprCoreType.STRING)); + } + + /** + * Get type for HighlightExpression. + * @return : String type. + */ + @Override + public ExprType type() { + return ExprCoreType.ARRAY; + } + + @Override + public T accept(ExpressionNodeVisitor visitor, C context) { + return visitor.visitHighlight(this, context); + } +} diff --git a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index 5eb9e715b83..cd343284530 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -192,6 +192,7 @@ public enum BuiltinFunctionName { MATCHPHRASE(FunctionName.of("matchphrase")), QUERY_STRING(FunctionName.of("query_string")), MATCH_BOOL_PREFIX(FunctionName.of("match_bool_prefix")), + HIGHLIGHT(FunctionName.of("highlight")), MATCH_PHRASE_PREFIX(FunctionName.of("match_phrase_prefix")), /** * Legacy Relevance Function. diff --git a/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java b/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java index dff3ec1f200..c3e5cc55947 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java @@ -8,7 +8,6 @@ import static org.opensearch.sql.data.type.ExprCoreType.STRING; import static org.opensearch.sql.data.type.ExprCoreType.STRUCT; -import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import java.util.ArrayList; import java.util.Collections; @@ -16,11 +15,13 @@ import java.util.Map; import java.util.stream.Collectors; import lombok.experimental.UtilityClass; +import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.FunctionExpression; +import org.opensearch.sql.expression.HighlightExpression; import org.opensearch.sql.expression.NamedArgumentExpression; import org.opensearch.sql.expression.env.Environment; @@ -50,6 +51,14 @@ public void register(BuiltinFunctionRepository repository) { repository.register(match_phrase(BuiltinFunctionName.MATCH_PHRASE)); repository.register(match_phrase(BuiltinFunctionName.MATCHPHRASE)); repository.register(match_phrase_prefix()); + repository.register(highlight()); + } + + private static FunctionResolver highlight() { + FunctionName functionName = BuiltinFunctionName.HIGHLIGHT.getName(); + FunctionSignature functionSignature = new FunctionSignature(functionName, List.of(STRING)); + FunctionBuilder functionBuilder = arguments -> new HighlightExpression(arguments.get(0)); + return new FunctionResolver(functionName, ImmutableMap.of(functionSignature, functionBuilder)); } private static FunctionResolver match_bool_prefix() { diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalHighlight.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalHighlight.java new file mode 100644 index 00000000000..986a5454860 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalHighlight.java @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.planner.logical; + +import java.util.Collections; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.ToString; +import org.opensearch.sql.expression.Expression; + +@EqualsAndHashCode(callSuper = true) +@Getter +@ToString +public class LogicalHighlight extends LogicalPlan { + private final Expression highlightField; + + public LogicalHighlight(LogicalPlan childPlan, Expression field) { + super(Collections.singletonList(childPlan)); + highlightField = field; + } + + @Override + public R accept(LogicalPlanNodeVisitor visitor, C context) { + return visitor.visitHighlight(this, context); + } +} diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java index d5fef43c0db..4185d55c559 100644 --- a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java @@ -62,6 +62,10 @@ public LogicalPlan window(LogicalPlan input, return new LogicalWindow(input, windowFunction, windowDefinition); } + public LogicalPlan highlight(LogicalPlan input, Expression field) { + return new LogicalHighlight(input, field); + } + public static LogicalPlan remove(LogicalPlan input, ReferenceExpression... fields) { return new LogicalRemove(input, ImmutableSet.copyOf(fields)); } diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java index 5163e44edb2..df23b9cd20d 100644 --- a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java @@ -26,6 +26,10 @@ public R visitFilter(LogicalFilter plan, C context) { return visitNode(plan, context); } + public R visitHighlight(LogicalHighlight plan, C context) { + return visitNode(plan, context); + } + public R visitAggregation(LogicalAggregation plan, C context) { return visitNode(plan, context); } diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanDSL.java b/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanDSL.java index 938a4c532c8..e6e59990c82 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanDSL.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanDSL.java @@ -105,5 +105,4 @@ public ValuesOperator values(List... values) { public static LimitOperator limit(PhysicalPlan input, Integer limit, Integer offset) { return new LimitOperator(input, limit, offset); } - } diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java b/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java index 87582df3bbf..646aae8220a 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java @@ -79,6 +79,4 @@ public R visitMLCommons(PhysicalPlan node, C context) { public R visitAD(PhysicalPlan node, C context) { return visitNode(node, context); } - - } diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java index 797c5ab11ad..b9de96b30a3 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java @@ -22,6 +22,7 @@ import static org.opensearch.sql.ast.dsl.AstDSL.qualifiedName; import static org.opensearch.sql.ast.dsl.AstDSL.relation; import static org.opensearch.sql.ast.dsl.AstDSL.span; +import static org.opensearch.sql.ast.dsl.AstDSL.stringLiteral; import static org.opensearch.sql.ast.tree.Sort.NullOrder; import static org.opensearch.sql.ast.tree.Sort.SortOption; import static org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_ASC; @@ -45,6 +46,7 @@ import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.ast.expression.Argument; import org.opensearch.sql.ast.expression.DataType; +import org.opensearch.sql.ast.expression.HighlightFunction; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.SpanUnit; import org.opensearch.sql.ast.tree.AD; @@ -52,6 +54,7 @@ import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.exception.SemanticCheckException; import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.HighlightExpression; import org.opensearch.sql.expression.config.ExpressionConfig; import org.opensearch.sql.expression.window.WindowDefinition; import org.opensearch.sql.planner.logical.LogicalAD; @@ -231,6 +234,22 @@ public void project_source() { AstDSL.alias("double_value", AstDSL.field("double_value")))); } + @Test + public void project_highlight() { + assertAnalyzeEqual( + LogicalPlanDSL.project( + LogicalPlanDSL.highlight(LogicalPlanDSL.relation("schema"), + DSL.literal("fieldA")), + DSL.named("highlight(fieldA)", new HighlightExpression(DSL.literal("fieldA"))) + ), + AstDSL.projectWithArg( + AstDSL.relation("schema"), + AstDSL.defaultFieldsArgs(), + AstDSL.alias("highlight(fieldA)", new HighlightFunction(AstDSL.stringLiteral("fieldA"))) + ) + ); + } + @Test public void remove_source() { assertAnalyzeEqual( diff --git a/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java index dd0e199cf4c..72db4025522 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java @@ -10,7 +10,6 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.opensearch.sql.ast.dsl.AstDSL.field; -import static org.opensearch.sql.ast.dsl.AstDSL.floatLiteral; import static org.opensearch.sql.ast.dsl.AstDSL.function; import static org.opensearch.sql.ast.dsl.AstDSL.intLiteral; import static org.opensearch.sql.ast.dsl.AstDSL.qualifiedName; @@ -35,6 +34,7 @@ import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.ast.expression.AllFields; import org.opensearch.sql.ast.expression.DataType; +import org.opensearch.sql.ast.expression.HighlightFunction; import org.opensearch.sql.ast.expression.RelevanceFieldList; import org.opensearch.sql.ast.expression.SpanUnit; import org.opensearch.sql.ast.expression.UnresolvedExpression; @@ -45,6 +45,7 @@ import org.opensearch.sql.exception.SemanticCheckException; import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.HighlightExpression; import org.opensearch.sql.expression.config.ExpressionConfig; import org.opensearch.sql.expression.window.aggregation.AggregateWindowFunction; import org.springframework.context.annotation.Configuration; @@ -536,6 +537,12 @@ public void match_phrase_prefix_all_params() { ); } + @Test + void highlight() { + assertAnalyzeEqual(new HighlightExpression(DSL.literal("fieldA")), + new HighlightFunction(stringLiteral("fieldA"))); + } + protected Expression analyze(UnresolvedExpression unresolvedExpression) { return expressionAnalyzer.analyze(unresolvedExpression, analysisContext); } diff --git a/core/src/test/java/org/opensearch/sql/analysis/NamedExpressionAnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/NamedExpressionAnalyzerTest.java index 738e81bfd60..2293d125aa2 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/NamedExpressionAnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/NamedExpressionAnalyzerTest.java @@ -12,6 +12,11 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.ast.expression.Alias; +import org.opensearch.sql.ast.expression.HighlightFunction; +import org.opensearch.sql.ast.expression.QualifiedName; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.LiteralExpression; import org.opensearch.sql.expression.NamedExpression; import org.opensearch.sql.expression.config.ExpressionConfig; import org.springframework.context.annotation.Configuration; @@ -32,4 +37,15 @@ void visit_named_select_item() { NamedExpression analyze = analyzer.analyze(alias, analysisContext); assertEquals("integer_value", analyze.getNameOrAlias()); } + + @Test + void visit_highlight() { + Alias alias = AstDSL.alias("highlight(fieldA)", + new HighlightFunction(AstDSL.stringLiteral("fieldA"))); + NamedExpressionAnalyzer analyzer = + new NamedExpressionAnalyzer(expressionAnalyzer); + + NamedExpression analyze = analyzer.analyze(alias, analysisContext); + assertEquals("highlight(fieldA)", analyze.getNameOrAlias()); + } } diff --git a/core/src/test/java/org/opensearch/sql/expression/ExpressionNodeVisitorTest.java b/core/src/test/java/org/opensearch/sql/expression/ExpressionNodeVisitorTest.java index caf11064ae8..b0b2bc5b2bf 100644 --- a/core/src/test/java/org/opensearch/sql/expression/ExpressionNodeVisitorTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/ExpressionNodeVisitorTest.java @@ -34,6 +34,7 @@ class ExpressionNodeVisitorTest { @Test void should_return_null_by_default() { ExpressionNodeVisitor visitor = new ExpressionNodeVisitor(){}; + assertNull(new HighlightExpression(DSL.literal("Title")).accept(visitor, null)); assertNull(literal(10).accept(visitor, null)); assertNull(ref("name", STRING).accept(visitor, null)); assertNull(named("bool", literal(true)).accept(visitor, null)); diff --git a/core/src/test/java/org/opensearch/sql/expression/HighlightExpressionTest.java b/core/src/test/java/org/opensearch/sql/expression/HighlightExpressionTest.java new file mode 100644 index 00000000000..c6e2dccf692 --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/expression/HighlightExpressionTest.java @@ -0,0 +1,67 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.opensearch.sql.data.type.ExprCoreType.ARRAY; +import static org.opensearch.sql.data.type.ExprCoreType.STRUCT; + +import com.google.common.collect.ImmutableMap; +import com.google.errorprone.annotations.DoNotCall; +import org.junit.jupiter.api.Test; +import org.opensearch.sql.data.model.ExprTupleValue; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.model.ExprValueUtils; +import org.opensearch.sql.expression.env.Environment; + +public class HighlightExpressionTest extends ExpressionTestBase { + + @Test + public void single_highlight_test() { + Environment hlTuple = ExprValueUtils.tupleValue( + ImmutableMap.of("_highlight.Title", "result value")).bindingTuples(); + HighlightExpression expr = new HighlightExpression(DSL.literal("Title")); + ExprValue resultVal = expr.valueOf(hlTuple); + + assertEquals(expr.type(), ARRAY); + assertEquals("result value", resultVal.stringValue()); + } + + @Test + public void missing_highlight_test() { + Environment hlTuple = ExprValueUtils.tupleValue( + ImmutableMap.of("_highlight.Title", "result value")).bindingTuples(); + HighlightExpression expr = new HighlightExpression(DSL.literal("invalid")); + ExprValue resultVal = expr.valueOf(hlTuple); + + assertTrue(resultVal.isMissing()); + } + + /** + * Enable me when '*' is supported in highlight. + */ + @DoNotCall + public void highlight_all_test() { + ImmutableMap.Builder builder = new ImmutableMap.Builder<>(); + var hlBuilder = ImmutableMap.builder(); + hlBuilder.put("Title", ExprValueUtils.stringValue("correct result value")); + hlBuilder.put("Body", ExprValueUtils.stringValue("incorrect result value")); + builder.put("_highlight", ExprTupleValue.fromExprValueMap(hlBuilder.build())); + + HighlightExpression hlExpr = new HighlightExpression(DSL.literal("*")); + ExprValue resultVal = hlExpr.valueOf( + ExprTupleValue.fromExprValueMap(builder.build()).bindingTuples()); + assertEquals(ARRAY, resultVal.type()); + for (var field : resultVal.tupleValue().entrySet()) { + assertTrue(field.toString().contains(hlExpr.getHighlightField().toString())); + } + assertTrue(resultVal.tupleValue().containsValue( + ExprValueUtils.stringValue("\"correct result value\""))); + assertTrue(resultVal.tupleValue().containsValue( + ExprValueUtils.stringValue("\"correct result value\""))); + } +} diff --git a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java index a2455a9a3d5..e899351d4f9 100644 --- a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java @@ -19,13 +19,14 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; -import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.ast.tree.Sort.SortOption; +import org.opensearch.sql.data.model.ExprValueUtils; import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.LiteralExpression; import org.opensearch.sql.expression.ReferenceExpression; import org.opensearch.sql.expression.aggregation.Aggregator; import org.opensearch.sql.expression.window.WindowDefinition; @@ -113,6 +114,11 @@ public void testAbstractPlanNodeVisitorShouldReturnNull() { assertNull(rareTopN.accept(new LogicalPlanNodeVisitor() { }, null)); + LogicalPlan highlight = new LogicalHighlight(filter, + new LiteralExpression(ExprValueUtils.stringValue("fieldA"))); + assertNull(highlight.accept(new LogicalPlanNodeVisitor() { + }, null)); + LogicalPlan mlCommons = new LogicalMLCommons(LogicalPlanDSL.relation("schema"), "kmeans", ImmutableMap.builder() diff --git a/docs/user/dql/functions.rst b/docs/user/dql/functions.rst index 240bf32dfb4..f1d4d987a37 100644 --- a/docs/user/dql/functions.rst +++ b/docs/user/dql/functions.rst @@ -2479,3 +2479,29 @@ Another example to show how to set custom values for the optional parameters:: |------+--------------------------+----------------------| | 1 | The House at Pooh Corner | Alan Alexander Milne | +------+--------------------------+----------------------+ + + +HIGHLIGHT +------------ + +Description +>>>>>>>>>>> + +``highlight(field_expression)`` + +The highlight function maps to the highlight function used in search engine to return highlight fields for the given search. +The syntax allows to specify the field in double quotes or single quotes or without any wrap. +Please refer to examples below: + +| ``highlight(title)`` + +Example searching for field Tags:: + + os> select highlight(title) from books where query_string(['title'], 'Pooh House'); + fetched rows / total rows = 2/2 + +----------------------------------------------+ + | highlight(title) | + |----------------------------------------------| + | [The House at Pooh Corner] | + | [Winnie-the-Pooh] | + +----------------------------------------------+ diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/HighlightFunctionIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/HighlightFunctionIT.java new file mode 100644 index 00000000000..422f71968fb --- /dev/null +++ b/integ-test/src/test/java/org/opensearch/sql/sql/HighlightFunctionIT.java @@ -0,0 +1,88 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.sql; + +import static org.opensearch.sql.util.MatcherUtils.schema; +import static org.opensearch.sql.util.MatcherUtils.verifySchema; + +import com.google.errorprone.annotations.DoNotCall; +import org.json.JSONObject; +import org.junit.Test; +import org.opensearch.sql.legacy.SQLIntegTestCase; +import org.opensearch.sql.legacy.TestsConstants; + +public class HighlightFunctionIT extends SQLIntegTestCase { + + @Override + protected void init() throws Exception { + loadIndex(Index.BEER); + } + + @Test + public void single_highlight_test() { + String query = "SELECT Tags, highlight('Tags') FROM %s WHERE match(Tags, 'yeast') LIMIT 1"; + JSONObject response = executeJdbcRequest(String.format(query, TestsConstants.TEST_INDEX_BEER)); + verifySchema(response, schema("Tags", null, "text"), + schema("highlight('Tags')", null, "nested")); + assertEquals(1, response.getInt("total")); + } + + @Test + public void accepts_unquoted_test() { + String query = "SELECT Tags, highlight(Tags) FROM %s WHERE match(Tags, 'yeast') LIMIT 1"; + JSONObject response = executeJdbcRequest(String.format(query, TestsConstants.TEST_INDEX_BEER)); + verifySchema(response, schema("Tags", null, "text"), + schema("highlight(Tags)", null, "nested")); + assertEquals(1, response.getInt("total")); + } + + @Test + public void multiple_highlight_test() { + String query = "SELECT highlight(Title), highlight(Body) FROM %s WHERE MULTI_MATCH([Title, Body], 'hops') LIMIT 1"; + JSONObject response = executeJdbcRequest(String.format(query, TestsConstants.TEST_INDEX_BEER)); + verifySchema(response, schema("highlight(Title)", null, "nested"), + schema("highlight(Body)", null, "nested")); + assertEquals(1, response.getInt("total")); + } + + // Enable me when * is supported + @DoNotCall + public void wildcard_highlight_test() { + String query = "SELECT highlight('*itle') FROM %s WHERE MULTI_MATCH([Title, Body], 'hops') LIMIT 1"; + JSONObject response = executeJdbcRequest(String.format(query, TestsConstants.TEST_INDEX_BEER)); + verifySchema(response, schema("highlight('*itle')", null, "nested")); + assertEquals(1, response.getInt("total")); + } + + // Enable me when * is supported + @DoNotCall + public void wildcard_multi_field_highlight_test() { + String query = "SELECT highlight('T*') FROM %s WHERE MULTI_MATCH([Title, Tags], 'hops') LIMIT 1"; + JSONObject response = executeJdbcRequest(String.format(query, TestsConstants.TEST_INDEX_BEER)); + verifySchema(response, schema("highlight('T*')", null, "nested")); + var resultMap = response.getJSONArray("datarows").getJSONArray(0).getJSONObject(0); + assertEquals(1, response.getInt("total")); + assertTrue(resultMap.has("highlight(\"T*\").Title")); + assertTrue(resultMap.has("highlight(\"T*\").Tags")); + } + + // Enable me when * is supported + @DoNotCall + public void highlight_all_test() { + String query = "SELECT highlight('*') FROM %s WHERE MULTI_MATCH([Title, Body], 'hops') LIMIT 1"; + JSONObject response = executeJdbcRequest(String.format(query, TestsConstants.TEST_INDEX_BEER)); + verifySchema(response, schema("highlight('*')", null, "nested")); + assertEquals(1, response.getInt("total")); + } + + @Test + public void highlight_no_limit_test() { + String query = "SELECT highlight(Body) FROM %s WHERE MATCH(Body, 'hops')"; + JSONObject response = executeJdbcRequest(String.format(query, TestsConstants.TEST_INDEX_BEER)); + verifySchema(response, schema("highlight(Body)", null, "nested")); + assertEquals(2, response.getInt("total")); + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/OpenSearchResponse.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/OpenSearchResponse.java index 7dc77d7d295..aadd73efdde 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/OpenSearchResponse.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/OpenSearchResponse.java @@ -10,6 +10,7 @@ import java.util.Arrays; import java.util.Iterator; import java.util.Map; +import java.util.stream.Collectors; import lombok.EqualsAndHashCode; import lombok.ToString; import org.opensearch.action.search.SearchResponse; @@ -17,6 +18,7 @@ import org.opensearch.search.aggregations.Aggregations; import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.model.ExprValueUtils; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; /** @@ -91,7 +93,23 @@ public Iterator iterator() { }).iterator(); } else { return Arrays.stream(hits.getHits()) - .map(hit -> (ExprValue) exprValueFactory.construct(hit.getSourceAsString())).iterator(); + .map(hit -> { + ExprValue docData = exprValueFactory.construct(hit.getSourceAsString()); + if (hit.getHighlightFields().isEmpty()) { + return docData; + } else { + ImmutableMap.Builder builder = new ImmutableMap.Builder<>(); + builder.putAll(docData.tupleValue()); + var hlBuilder = ImmutableMap.builder(); + for (var es : hit.getHighlightFields().entrySet()) { + hlBuilder.put(es.getKey(), ExprValueUtils.collectionValue( + Arrays.stream(es.getValue().fragments()).map( + t -> (t.toString())).collect(Collectors.toList()))); + } + builder.put("_highlight", ExprTupleValue.fromExprValueMap(hlBuilder.build())); + return ExprTupleValue.fromExprValueMap(builder.build()); + } + }).iterator(); } } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java index 49301cbf536..c028f283a27 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java @@ -33,6 +33,7 @@ import org.opensearch.sql.opensearch.storage.serialization.DefaultExpressionSerializer; import org.opensearch.sql.planner.DefaultImplementor; import org.opensearch.sql.planner.logical.LogicalAD; +import org.opensearch.sql.planner.logical.LogicalHighlight; import org.opensearch.sql.planner.logical.LogicalMLCommons; import org.opensearch.sql.planner.logical.LogicalPlan; import org.opensearch.sql.planner.logical.LogicalRelation; @@ -187,5 +188,11 @@ public PhysicalPlan visitAD(LogicalAD node, OpenSearchIndexScan context) { return new ADOperator(visitChild(node, context), node.getArguments(), client.getNodeClient()); } + + @Override + public PhysicalPlan visitHighlight(LogicalHighlight node, OpenSearchIndexScan context) { + context.pushDownHighlight(node.getHighlightField().toString()); + return visitChild(node, context); + } } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexScan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexScan.java index c35a5ba9dbf..6e88f3de891 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexScan.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexScan.java @@ -25,8 +25,10 @@ import org.opensearch.index.query.QueryBuilders; import org.opensearch.search.aggregations.AggregationBuilder; import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.fetch.subphase.highlight.HighlightBuilder; import org.opensearch.search.sort.SortBuilder; import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.expression.ReferenceExpression; @@ -158,6 +160,21 @@ public void pushDownLimit(Integer limit, Integer offset) { sourceBuilder.from(offset).size(limit); } + /** + * Add highlight to DSL requests. + * @param field name of the field to highlight + */ + public void pushDownHighlight(String field) { + SearchSourceBuilder sourceBuilder = request.getSourceBuilder(); + if (sourceBuilder.highlighter() != null) { + sourceBuilder.highlighter().field(StringUtils.unquoteText(field)); + } else { + HighlightBuilder highlightBuilder = + new HighlightBuilder().field(StringUtils.unquoteText(field)); + sourceBuilder.highlighter(highlightBuilder); + } + } + /** * Push down project list to DSL requets. */ diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java index 5bffa1cfa86..ee981a4abca 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java @@ -36,7 +36,6 @@ import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.client.node.NodeClient; -import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.tree.RareTopN.CommandType; diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchResponseTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchResponseTest.java index fa25b4f4086..0a60503415d 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchResponseTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchResponseTest.java @@ -17,18 +17,24 @@ import com.google.common.collect.ImmutableMap; import java.util.Arrays; +import java.util.Map; +import java.util.stream.Collectors; import org.apache.lucene.search.TotalHits; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.action.search.SearchResponse; +import org.opensearch.common.bytes.BytesArray; +import org.opensearch.common.text.Text; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.aggregations.Aggregations; +import org.opensearch.search.fetch.subphase.highlight.HighlightField; import org.opensearch.sql.data.model.ExprIntegerValue; import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.model.ExprValueUtils; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; @@ -148,4 +154,35 @@ void aggregation_iterator() { i++; } } + + @Test + void highlight_iterator() { + SearchHit searchHit = new SearchHit(1); + searchHit.sourceRef( + new BytesArray("{\"name\":\"John\"}")); + Map highlightMap = Map.of("highlights", + new HighlightField("Title", new Text[] {new Text("field")})); + searchHit.highlightFields(Map.of("highlights", new HighlightField("Title", + new Text[] {new Text("field")}))); + ExprValue resultTuple = ExprValueUtils.tupleValue(searchHit.getSourceAsMap()); + + when(searchResponse.getHits()) + .thenReturn( + new SearchHits( + new SearchHit[]{searchHit1}, + new TotalHits(1L, TotalHits.Relation.EQUAL_TO), + 1.0F)); + + when(searchHit1.getHighlightFields()).thenReturn(highlightMap); + when(factory.construct(any())).thenReturn(resultTuple); + + for (ExprValue resultHit : new OpenSearchResponse(searchResponse, factory)) { + var expected = ExprValueUtils.collectionValue( + Arrays.stream(searchHit.getHighlightFields().get("highlights").getFragments()) + .map(t -> (t.toString())).collect(Collectors.toList())); + var result = resultHit.tupleValue().get( + "_highlight").tupleValue().get("highlights"); + assertTrue(expected.equals(result)); + } + } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchDefaultImplementorTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchDefaultImplementorTest.java index 0770ea3938c..c83172955c1 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchDefaultImplementorTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchDefaultImplementorTest.java @@ -9,6 +9,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.verify; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.relation; import org.junit.jupiter.api.Test; @@ -19,6 +21,7 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.opensearch.client.OpenSearchClient; import org.opensearch.sql.planner.logical.LogicalAD; +import org.opensearch.sql.planner.logical.LogicalHighlight; import org.opensearch.sql.planner.logical.LogicalMLCommons; import org.opensearch.sql.planner.logical.LogicalPlan; @@ -67,4 +70,16 @@ public void visitAD() { new OpenSearchIndex.OpenSearchDefaultImplementor(indexScan, client); assertNotNull(implementor.visitAD(node, indexScan)); } + + @Test + public void visitHighlight() { + LogicalHighlight node = Mockito.mock(LogicalHighlight.class, + Answers.RETURNS_DEEP_STUBS); + Mockito.when(node.getChild().get(0)).thenReturn(Mockito.mock(LogicalPlan.class)); + OpenSearchIndex.OpenSearchDefaultImplementor implementor = + new OpenSearchIndex.OpenSearchDefaultImplementor(indexScan, client); + + implementor.visitHighlight(node, indexScan); + verify(indexScan).pushDownHighlight(any()); + } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexScanTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexScanTest.java index 429c639da98..41769914d95 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexScanTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexScanTest.java @@ -31,6 +31,7 @@ import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.search.SearchHit; +import org.opensearch.search.fetch.subphase.highlight.HighlightBuilder; import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; @@ -110,6 +111,16 @@ void pushDownFilters() { .filter(QueryBuilders.rangeQuery("balance").gte(10000))); } + @Test + void pushDownHighlight() { + assertThat() + .pushDown(QueryBuilders.termQuery("name", "John")) + .pushDownHighlight("Title") + .pushDownHighlight("Body") + .shouldQueryHighlight(QueryBuilders.termQuery("name", "John"), + new HighlightBuilder().field("Title").field("Body")); + } + private PushDownAssertion assertThat() { return new PushDownAssertion(client, exprValueFactory, settings); } @@ -135,6 +146,22 @@ PushDownAssertion pushDown(QueryBuilder query) { return this; } + PushDownAssertion pushDownHighlight(String query) { + indexScan.pushDownHighlight(query); + return this; + } + + PushDownAssertion shouldQueryHighlight(QueryBuilder query, HighlightBuilder highlight) { + OpenSearchRequest request = new OpenSearchQueryRequest("test", 200, factory); + request.getSourceBuilder() + .query(query) + .highlighter(highlight) + .sort(DOC_FIELD_NAME, ASC); + when(client.search(request)).thenReturn(response); + indexScan.open(); + return this; + } + PushDownAssertion shouldQuery(QueryBuilder expected) { OpenSearchRequest request = new OpenSearchQueryRequest("test", 200, factory); request.getSourceBuilder() diff --git a/sql/src/main/antlr/OpenSearchSQLLexer.g4 b/sql/src/main/antlr/OpenSearchSQLLexer.g4 index 8b74e31aacc..9fca2942cff 100644 --- a/sql/src/main/antlr/OpenSearchSQLLexer.g4 +++ b/sql/src/main/antlr/OpenSearchSQLLexer.g4 @@ -346,6 +346,7 @@ TIE_BREAKER: 'TIE_BREAKER'; TIME_ZONE: 'TIME_ZONE'; TYPE: 'TYPE'; ZERO_TERMS_QUERY: 'ZERO_TERMS_QUERY'; +HIGHLIGHT: 'HIGHLIGHT'; // RELEVANCE FUNCTIONS MATCH_BOOL_PREFIX: 'MATCH_BOOL_PREFIX'; diff --git a/sql/src/main/antlr/OpenSearchSQLParser.g4 b/sql/src/main/antlr/OpenSearchSQLParser.g4 index d75893b6961..a9316b55a4d 100644 --- a/sql/src/main/antlr/OpenSearchSQLParser.g4 +++ b/sql/src/main/antlr/OpenSearchSQLParser.g4 @@ -300,6 +300,11 @@ functionCall | aggregateFunction #aggregateFunctionCall | aggregateFunction (orderByClause)? filterClause #filteredAggregationFunctionCall | relevanceFunction #relevanceFunctionCall + | highlightFunction #highlightFunctionCall + ; + +highlightFunction + : HIGHLIGHT LR_BRACKET relevanceField RR_BRACKET ; scalarFunctionName diff --git a/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java b/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java index 3c686b5e8e3..453162e3356 100644 --- a/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java +++ b/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java @@ -63,6 +63,7 @@ import org.opensearch.sql.ast.expression.Cast; import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.Function; +import org.opensearch.sql.ast.expression.HighlightFunction; import org.opensearch.sql.ast.expression.Interval; import org.opensearch.sql.ast.expression.IntervalUnit; import org.opensearch.sql.ast.expression.Literal; @@ -130,6 +131,12 @@ public UnresolvedExpression visitScalarFunctionCall(ScalarFunctionCallContext ct return visitFunction(ctx.scalarFunctionName().getText(), ctx.functionArgs()); } + @Override + public UnresolvedExpression visitHighlightFunctionCall( + OpenSearchSQLParser.HighlightFunctionCallContext ctx) { + return new HighlightFunction(visit(ctx.highlightFunction().relevanceField())); + } + @Override public UnresolvedExpression visitTableFilter(TableFilterContext ctx) { return new Function( diff --git a/sql/src/test/java/org/opensearch/sql/sql/antlr/HighlightTest.java b/sql/src/test/java/org/opensearch/sql/sql/antlr/HighlightTest.java new file mode 100644 index 00000000000..6826a37c0b5 --- /dev/null +++ b/sql/src/test/java/org/opensearch/sql/sql/antlr/HighlightTest.java @@ -0,0 +1,45 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.sql.antlr; + +import org.junit.jupiter.api.Test; + +public class HighlightTest extends SQLParserTest { + @Test + void single_field_test() { + acceptQuery("SELECT HIGHLIGHT(Tags) FROM Index WHERE MATCH(Tags, 'Time')"); + } + + @Test + void multiple_highlights_test() { + acceptQuery("SELECT HIGHLIGHT(Tags), HIGHLIGHT(Body) FROM Index " + + "WHERE MULTI_MATCH([Tags, Body], 'Time')"); + } + + @Test + void wildcard_test() { + acceptQuery("SELECT HIGHLIGHT('T*') FROM Index " + + "WHERE MULTI_MATCH([Tags, Body], 'Time')"); + } + + @Test + void highlight_all_test() { + + acceptQuery("SELECT HIGHLIGHT('*') FROM Index WHERE MULTI_MATCH([Tags, Body], 'Time')"); + } + + @Test + void multiple_parameters_failure_test() { + rejectQuery("SELECT HIGHLIGHT(Tags1, Tags2) FROM Index " + + "WHERE MULTI_MATCH([Tags, Body], 'Time')"); + } + + @Test + void no_parameters_failure_test() { + rejectQuery("SELECT HIGHLIGHT() FROM Index " + + "WHERE MULTI_MATCH([Tags, Body], 'Time')"); + } +} diff --git a/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java b/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java index d576389595a..8bf38b14a67 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java @@ -18,6 +18,7 @@ import static org.opensearch.sql.ast.dsl.AstDSL.field; import static org.opensearch.sql.ast.dsl.AstDSL.filter; import static org.opensearch.sql.ast.dsl.AstDSL.function; +import static org.opensearch.sql.ast.dsl.AstDSL.highlight; import static org.opensearch.sql.ast.dsl.AstDSL.intLiteral; import static org.opensearch.sql.ast.dsl.AstDSL.limit; import static org.opensearch.sql.ast.dsl.AstDSL.project; @@ -33,6 +34,7 @@ import com.google.common.collect.ImmutableList; import org.antlr.v4.runtime.tree.ParseTree; import org.junit.jupiter.api.Test; +import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.ast.expression.AllFields; import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.common.antlr.SyntaxCheckException; @@ -667,6 +669,24 @@ public void can_build_limit_clause_with_offset() { buildAST("SELECT name FROM test LIMIT 5, 10")); } + @Test + public void can_build_qualified_name_highlight() { + assertEquals( + project(relation("test"), + alias("highlight(fieldA)", highlight(AstDSL.qualifiedName("fieldA")))), + buildAST("SELECT highlight(fieldA) FROM test") + ); + } + + @Test + public void can_build_string_literal_highlight() { + assertEquals( + project(relation("test"), + alias("highlight(\"fieldA\")", highlight(AstDSL.stringLiteral("fieldA")))), + buildAST("SELECT highlight(\"fieldA\") FROM test") + ); + } + private UnresolvedPlan buildAST(String query) { ParseTree parseTree = parser.parse(query); return parseTree.accept(new AstBuilder(query)); diff --git a/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java b/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java index 9ce3dd6714d..ef881275e5b 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java @@ -15,6 +15,7 @@ import static org.opensearch.sql.ast.dsl.AstDSL.doubleLiteral; import static org.opensearch.sql.ast.dsl.AstDSL.floatLiteral; import static org.opensearch.sql.ast.dsl.AstDSL.function; +import static org.opensearch.sql.ast.dsl.AstDSL.highlight; import static org.opensearch.sql.ast.dsl.AstDSL.intLiteral; import static org.opensearch.sql.ast.dsl.AstDSL.intervalLiteral; import static org.opensearch.sql.ast.dsl.AstDSL.longLiteral; @@ -308,6 +309,22 @@ public void canBuildWindowFunctionWithNullOrderSpecified() { buildExprAst("DENSE_RANK() OVER (ORDER BY age ASC NULLS LAST)")); } + @Test + public void canBuildStringLiteralHighlightFunction() { + assertEquals( + highlight(AstDSL.stringLiteral("fieldA")), + buildExprAst("highlight(\"fieldA\")") + ); + } + + @Test + public void canBuildQualifiedNameHighlightFunction() { + assertEquals( + highlight(AstDSL.qualifiedName("fieldA")), + buildExprAst("highlight(fieldA)") + ); + } + @Test public void canBuildWindowFunctionWithoutOrderBy() { assertEquals(