Skip to content

Commit d03c176

Browse files
authored
Validate field and fields parameters in relevance search functions (opensearch-project#1067)
Change relevance functions that query fields to throw a SemanticCheckException when a field is queried that does not exist. Signed-off-by: forestmvey <[email protected]>
1 parent 94b6bec commit d03c176

File tree

18 files changed

+259
-153
lines changed

18 files changed

+259
-153
lines changed

core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,6 @@
55

66
package org.opensearch.sql.expression.function;
77

8-
import static org.opensearch.sql.data.type.ExprCoreType.STRING;
9-
import static org.opensearch.sql.data.type.ExprCoreType.STRUCT;
10-
11-
import com.google.common.collect.ImmutableMap;
128
import java.util.List;
139
import java.util.stream.Collectors;
1410
import lombok.experimental.UtilityClass;
@@ -48,46 +44,46 @@ public void register(BuiltinFunctionRepository repository) {
4844

4945
private static FunctionResolver match_bool_prefix() {
5046
FunctionName name = BuiltinFunctionName.MATCH_BOOL_PREFIX.getName();
51-
return new RelevanceFunctionResolver(name, STRING);
47+
return new RelevanceFunctionResolver(name);
5248
}
5349

5450
private static FunctionResolver match(BuiltinFunctionName match) {
5551
FunctionName funcName = match.getName();
56-
return new RelevanceFunctionResolver(funcName, STRING);
52+
return new RelevanceFunctionResolver(funcName);
5753
}
5854

5955
private static FunctionResolver match_phrase_prefix() {
6056
FunctionName funcName = BuiltinFunctionName.MATCH_PHRASE_PREFIX.getName();
61-
return new RelevanceFunctionResolver(funcName, STRING);
57+
return new RelevanceFunctionResolver(funcName);
6258
}
6359

6460
private static FunctionResolver match_phrase(BuiltinFunctionName matchPhrase) {
6561
FunctionName funcName = matchPhrase.getName();
66-
return new RelevanceFunctionResolver(funcName, STRING);
62+
return new RelevanceFunctionResolver(funcName);
6763
}
6864

6965
private static FunctionResolver multi_match(BuiltinFunctionName multiMatchName) {
70-
return new RelevanceFunctionResolver(multiMatchName.getName(), STRUCT);
66+
return new RelevanceFunctionResolver(multiMatchName.getName());
7167
}
7268

7369
private static FunctionResolver simple_query_string() {
7470
FunctionName funcName = BuiltinFunctionName.SIMPLE_QUERY_STRING.getName();
75-
return new RelevanceFunctionResolver(funcName, STRUCT);
71+
return new RelevanceFunctionResolver(funcName);
7672
}
7773

7874
private static FunctionResolver query() {
7975
FunctionName funcName = BuiltinFunctionName.QUERY.getName();
80-
return new RelevanceFunctionResolver(funcName, STRING);
76+
return new RelevanceFunctionResolver(funcName);
8177
}
8278

8379
private static FunctionResolver query_string() {
8480
FunctionName funcName = BuiltinFunctionName.QUERY_STRING.getName();
85-
return new RelevanceFunctionResolver(funcName, STRUCT);
81+
return new RelevanceFunctionResolver(funcName);
8682
}
8783

8884
private static FunctionResolver wildcard_query(BuiltinFunctionName wildcardQuery) {
8985
FunctionName funcName = wildcardQuery.getName();
90-
return new RelevanceFunctionResolver(funcName, STRING);
86+
return new RelevanceFunctionResolver(funcName);
9187
}
9288

9389
public static class OpenSearchFunction extends FunctionExpression {

core/src/main/java/org/opensearch/sql/expression/function/RelevanceFunctionResolver.java

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,24 +20,13 @@ public class RelevanceFunctionResolver
2020
@Getter
2121
private final FunctionName functionName;
2222

23-
@Getter
24-
private final ExprType declaredFirstParamType;
25-
2623
@Override
2724
public Pair<FunctionSignature, FunctionBuilder> resolve(FunctionSignature unresolvedSignature) {
2825
if (!unresolvedSignature.getFunctionName().equals(functionName)) {
2926
throw new SemanticCheckException(String.format("Expected '%s' but got '%s'",
3027
functionName.getFunctionName(), unresolvedSignature.getFunctionName().getFunctionName()));
3128
}
3229
List<ExprType> paramTypes = unresolvedSignature.getParamTypeList();
33-
ExprType providedFirstParamType = paramTypes.get(0);
34-
35-
// Check if the first parameter is of the specified type.
36-
if (!declaredFirstParamType.equals(providedFirstParamType)) {
37-
throw new SemanticCheckException(
38-
getWrongParameterErrorMessage(0, providedFirstParamType, declaredFirstParamType));
39-
}
40-
4130
// Check if all but the first parameter are of type STRING.
4231
for (int i = 1; i < paramTypes.size(); i++) {
4332
ExprType paramType = paramTypes.get(i);

core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -374,10 +374,10 @@ public void named_non_parse_expression() {
374374
void match_bool_prefix_expression() {
375375
assertAnalyzeEqual(
376376
DSL.match_bool_prefix(
377-
DSL.namedArgument("field", DSL.literal("fieldA")),
377+
DSL.namedArgument("field", DSL.literal("field_value1")),
378378
DSL.namedArgument("query", DSL.literal("sample query"))),
379379
AstDSL.function("match_bool_prefix",
380-
AstDSL.unresolvedArg("field", stringLiteral("fieldA")),
380+
AstDSL.unresolvedArg("field", stringLiteral("field_value1")),
381381
AstDSL.unresolvedArg("query", stringLiteral("sample query"))));
382382
}
383383

@@ -418,11 +418,11 @@ void multi_match_expression() {
418418
DSL.multi_match(
419419
DSL.namedArgument("fields", DSL.literal(
420420
new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of(
421-
"field", ExprValueUtils.floatValue(1.F)))))),
421+
"field_value1", ExprValueUtils.floatValue(1.F)))))),
422422
DSL.namedArgument("query", DSL.literal("sample query"))),
423423
AstDSL.function("multi_match",
424424
AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of(
425-
"field", 1.F))),
425+
"field_value1", 1.F))),
426426
AstDSL.unresolvedArg("query", stringLiteral("sample query"))));
427427
}
428428

@@ -432,12 +432,12 @@ void multi_match_expression_with_params() {
432432
DSL.multi_match(
433433
DSL.namedArgument("fields", DSL.literal(
434434
new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of(
435-
"field", ExprValueUtils.floatValue(1.F)))))),
435+
"field_value1", ExprValueUtils.floatValue(1.F)))))),
436436
DSL.namedArgument("query", DSL.literal("sample query")),
437437
DSL.namedArgument("analyzer", DSL.literal("keyword"))),
438438
AstDSL.function("multi_match",
439439
AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of(
440-
"field", 1.F))),
440+
"field_value1", 1.F))),
441441
AstDSL.unresolvedArg("query", stringLiteral("sample query")),
442442
AstDSL.unresolvedArg("analyzer", stringLiteral("keyword"))));
443443
}
@@ -448,12 +448,12 @@ void multi_match_expression_two_fields() {
448448
DSL.multi_match(
449449
DSL.namedArgument("fields", DSL.literal(
450450
new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of(
451-
"field1", ExprValueUtils.floatValue(1.F),
452-
"field2", ExprValueUtils.floatValue(.3F)))))),
451+
"field_value1", ExprValueUtils.floatValue(1.F),
452+
"field_value2", ExprValueUtils.floatValue(.3F)))))),
453453
DSL.namedArgument("query", DSL.literal("sample query"))),
454454
AstDSL.function("multi_match",
455455
AstDSL.unresolvedArg("fields", new RelevanceFieldList(ImmutableMap.of(
456-
"field1", 1.F, "field2", .3F))),
456+
"field_value1", 1.F, "field_value2", .3F))),
457457
AstDSL.unresolvedArg("query", stringLiteral("sample query"))));
458458
}
459459

@@ -463,11 +463,11 @@ void simple_query_string_expression() {
463463
DSL.simple_query_string(
464464
DSL.namedArgument("fields", DSL.literal(
465465
new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of(
466-
"field", ExprValueUtils.floatValue(1.F)))))),
466+
"field_value1", ExprValueUtils.floatValue(1.F)))))),
467467
DSL.namedArgument("query", DSL.literal("sample query"))),
468468
AstDSL.function("simple_query_string",
469469
AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of(
470-
"field", 1.F))),
470+
"field_value1", 1.F))),
471471
AstDSL.unresolvedArg("query", stringLiteral("sample query"))));
472472
}
473473

@@ -477,12 +477,12 @@ void simple_query_string_expression_with_params() {
477477
DSL.simple_query_string(
478478
DSL.namedArgument("fields", DSL.literal(
479479
new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of(
480-
"field", ExprValueUtils.floatValue(1.F)))))),
480+
"field_value1", ExprValueUtils.floatValue(1.F)))))),
481481
DSL.namedArgument("query", DSL.literal("sample query")),
482482
DSL.namedArgument("analyzer", DSL.literal("keyword"))),
483483
AstDSL.function("simple_query_string",
484484
AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of(
485-
"field", 1.F))),
485+
"field_value1", 1.F))),
486486
AstDSL.unresolvedArg("query", stringLiteral("sample query")),
487487
AstDSL.unresolvedArg("analyzer", stringLiteral("keyword"))));
488488
}
@@ -493,12 +493,12 @@ void simple_query_string_expression_two_fields() {
493493
DSL.simple_query_string(
494494
DSL.namedArgument("fields", DSL.literal(
495495
new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of(
496-
"field1", ExprValueUtils.floatValue(1.F),
497-
"field2", ExprValueUtils.floatValue(.3F)))))),
496+
"field_value1", ExprValueUtils.floatValue(1.F),
497+
"field_value2", ExprValueUtils.floatValue(.3F)))))),
498498
DSL.namedArgument("query", DSL.literal("sample query"))),
499499
AstDSL.function("simple_query_string",
500500
AstDSL.unresolvedArg("fields", new RelevanceFieldList(ImmutableMap.of(
501-
"field1", 1.F, "field2", .3F))),
501+
"field_value1", 1.F, "field_value2", .3F))),
502502
AstDSL.unresolvedArg("query", stringLiteral("sample query"))));
503503
}
504504

@@ -517,11 +517,11 @@ void query_string_expression() {
517517
DSL.query_string(
518518
DSL.namedArgument("fields", DSL.literal(
519519
new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of(
520-
"field", ExprValueUtils.floatValue(1.F)))))),
520+
"field_value1", ExprValueUtils.floatValue(1.F)))))),
521521
DSL.namedArgument("query", DSL.literal("query_value"))),
522522
AstDSL.function("query_string",
523523
AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of(
524-
"field", 1.F))),
524+
"field_value1", 1.F))),
525525
AstDSL.unresolvedArg("query", stringLiteral("query_value"))));
526526
}
527527

@@ -531,12 +531,12 @@ void query_string_expression_with_params() {
531531
DSL.query_string(
532532
DSL.namedArgument("fields", DSL.literal(
533533
new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of(
534-
"field", ExprValueUtils.floatValue(1.F)))))),
534+
"field_value1", ExprValueUtils.floatValue(1.F)))))),
535535
DSL.namedArgument("query", DSL.literal("query_value")),
536536
DSL.namedArgument("escape", DSL.literal("false"))),
537537
AstDSL.function("query_string",
538538
AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of(
539-
"field", 1.F))),
539+
"field_value1", 1.F))),
540540
AstDSL.unresolvedArg("query", stringLiteral("query_value")),
541541
AstDSL.unresolvedArg("escape", stringLiteral("false"))));
542542
}
@@ -547,12 +547,12 @@ void query_string_expression_two_fields() {
547547
DSL.query_string(
548548
DSL.namedArgument("fields", DSL.literal(
549549
new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of(
550-
"field1", ExprValueUtils.floatValue(1.F),
551-
"field2", ExprValueUtils.floatValue(.3F)))))),
550+
"field_value1", ExprValueUtils.floatValue(1.F),
551+
"field_value2", ExprValueUtils.floatValue(.3F)))))),
552552
DSL.namedArgument("query", DSL.literal("query_value"))),
553553
AstDSL.function("query_string",
554554
AstDSL.unresolvedArg("fields", new RelevanceFieldList(ImmutableMap.of(
555-
"field1", 1.F, "field2", .3F))),
555+
"field_value1", 1.F, "field_value2", .3F))),
556556
AstDSL.unresolvedArg("query", stringLiteral("query_value"))));
557557
}
558558

@@ -588,7 +588,7 @@ void wildcard_query_expression_all_params() {
588588
public void match_phrase_prefix_all_params() {
589589
assertAnalyzeEqual(
590590
DSL.match_phrase_prefix(
591-
DSL.namedArgument("field", "test"),
591+
DSL.namedArgument("field", "field_value1"),
592592
DSL.namedArgument("query", "search query"),
593593
DSL.namedArgument("slop", "3"),
594594
DSL.namedArgument("boost", "1.5"),
@@ -597,7 +597,7 @@ public void match_phrase_prefix_all_params() {
597597
DSL.namedArgument("zero_terms_query", "NONE")
598598
),
599599
AstDSL.function("match_phrase_prefix",
600-
unresolvedArg("field", stringLiteral("test")),
600+
unresolvedArg("field", stringLiteral("field_value1")),
601601
unresolvedArg("query", stringLiteral("search query")),
602602
unresolvedArg("slop", stringLiteral("3")),
603603
unresolvedArg("boost", stringLiteral("1.5")),

core/src/test/java/org/opensearch/sql/config/TestConfig.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ public class TestConfig {
5757
.put("struct_value", ExprCoreType.STRUCT)
5858
.put("array_value", ExprCoreType.ARRAY)
5959
.put("timestamp_value", ExprCoreType.TIMESTAMP)
60+
.put("field_value1", ExprCoreType.STRING)
61+
.put("field_value2", ExprCoreType.STRING)
6062
.build();
6163

6264
@Bean

core/src/test/java/org/opensearch/sql/expression/function/RelevanceFunctionResolverTest.java

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class RelevanceFunctionResolverTest {
2424

2525
@BeforeEach
2626
void setUp() {
27-
resolver = new RelevanceFunctionResolver(sampleFuncName, STRING);
27+
resolver = new RelevanceFunctionResolver(sampleFuncName);
2828
}
2929

3030
@Test
@@ -44,15 +44,6 @@ void resolve_invalid_name_test() {
4444
exception.getMessage());
4545
}
4646

47-
@Test
48-
void resolve_invalid_first_param_type_test() {
49-
var sig = new FunctionSignature(sampleFuncName, List.of(INTEGER));
50-
Exception exception = assertThrows(SemanticCheckException.class,
51-
() -> resolver.resolve(sig));
52-
assertEquals("Expected type STRING instead of INTEGER for parameter #1",
53-
exception.getMessage());
54-
}
55-
5647
@Test
5748
void resolve_invalid_third_param_type_test() {
5849
var sig = new FunctionSignature(sampleFuncName, List.of(STRING, STRING, INTEGER, STRING));

integ-test/src/test/java/org/opensearch/sql/sql/MatchIT.java

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import org.json.JSONObject;
1616
import org.junit.Test;
1717
import org.opensearch.sql.legacy.SQLIntegTestCase;
18+
import org.opensearch.sql.legacy.utils.StringUtils;
1819

1920
public class MatchIT extends SQLIntegTestCase {
2021
@Override
@@ -36,6 +37,42 @@ public void match_in_having() throws IOException {
3637
verifyDataRows(result, rows("Bates"));
3738
}
3839

40+
@Test
41+
public void missing_field_test() {
42+
String query = StringUtils.format("SELECT * FROM %s WHERE match(invalid, 'Bates')", TEST_INDEX_ACCOUNT);
43+
final RuntimeException exception =
44+
expectThrows(RuntimeException.class, () -> executeJdbcRequest(query));
45+
46+
assertTrue(exception.getMessage()
47+
.contains("can't resolve Symbol(namespace=FIELD_NAME, name=invalid) in type env"));
48+
49+
assertTrue(exception.getMessage().contains("SemanticCheckException"));
50+
}
51+
52+
@Test
53+
public void missing_quoted_field_test() {
54+
String query = StringUtils.format("SELECT * FROM %s WHERE match('invalid', 'Bates')", TEST_INDEX_ACCOUNT);
55+
final RuntimeException exception =
56+
expectThrows(RuntimeException.class, () -> executeJdbcRequest(query));
57+
58+
assertTrue(exception.getMessage()
59+
.contains("can't resolve Symbol(namespace=FIELD_NAME, name=invalid) in type env"));
60+
61+
assertTrue(exception.getMessage().contains("SemanticCheckException"));
62+
}
63+
64+
@Test
65+
public void missing_backtick_field_test() {
66+
String query = StringUtils.format("SELECT * FROM %s WHERE match(`invalid`, 'Bates')", TEST_INDEX_ACCOUNT);
67+
final RuntimeException exception =
68+
expectThrows(RuntimeException.class, () -> executeJdbcRequest(query));
69+
70+
assertTrue(exception.getMessage()
71+
.contains("can't resolve Symbol(namespace=FIELD_NAME, name=invalid) in type env"));
72+
73+
assertTrue(exception.getMessage().contains("SemanticCheckException"));
74+
}
75+
3976
@Test
4077
public void matchquery_in_where() throws IOException {
4178
JSONObject result = executeJdbcRequest("SELECT firstname FROM " + TEST_INDEX_ACCOUNT + " WHERE matchquery(lastname, 'Bates')");

opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/SingleFieldQuery.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import org.opensearch.index.query.QueryBuilder;
1111
import org.opensearch.sql.exception.SemanticCheckException;
1212
import org.opensearch.sql.expression.NamedArgumentExpression;
13+
import org.opensearch.sql.expression.ReferenceExpression;
1314

1415
/**
1516
* Base class to represent builder class for relevance queries like match_query, match_bool_prefix,
@@ -36,7 +37,7 @@ protected T createQueryBuilder(List<NamedArgumentExpression> arguments) {
3637
.orElseThrow(() -> new SemanticCheckException("'query' parameter is missing"));
3738

3839
return createBuilder(
39-
field.getValue().valueOf().stringValue(),
40+
((ReferenceExpression)field.getValue()).getAttr(),
4041
query.getValue().valueOf().stringValue());
4142
}
4243

0 commit comments

Comments
 (0)