Skip to content

Commit b1ab358

Browse files
committed
Add 'key' field to 'function_score' query function definition in explanation response (opensearch-project#1711)
Signed-off-by: Andriy Redko <[email protected]>
1 parent bb2d3af commit b1ab358

31 files changed

+1127
-79
lines changed

server/src/internalClusterTest/java/org/opensearch/search/functionscore/DecayFunctionScoreIT.java

+201
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
import org.opensearch.index.query.functionscore.FunctionScoreQueryBuilder.FilterFunctionBuilder;
5353
import org.opensearch.index.query.functionscore.ScoreFunctionBuilders;
5454
import org.opensearch.search.MultiValueMode;
55+
import org.opensearch.search.SearchHit;
5556
import org.opensearch.search.SearchHits;
5657
import org.opensearch.test.OpenSearchIntegTestCase;
5758
import org.opensearch.test.VersionUtils;
@@ -78,7 +79,9 @@
7879
import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertOrderedSearchHits;
7980
import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertSearchHits;
8081
import static org.hamcrest.Matchers.anyOf;
82+
import static org.hamcrest.Matchers.arrayWithSize;
8183
import static org.hamcrest.Matchers.closeTo;
84+
import static org.hamcrest.Matchers.containsString;
8285
import static org.hamcrest.Matchers.equalTo;
8386
import static org.hamcrest.Matchers.is;
8487
import static org.hamcrest.Matchers.lessThan;
@@ -627,6 +630,76 @@ public void testCombineModes() throws Exception {
627630

628631
}
629632

633+
public void testCombineModesExplain() throws Exception {
634+
assertAcked(
635+
prepareCreate("test").addMapping(
636+
"type1",
637+
jsonBuilder().startObject()
638+
.startObject("type1")
639+
.startObject("properties")
640+
.startObject("test")
641+
.field("type", "text")
642+
.endObject()
643+
.startObject("num")
644+
.field("type", "double")
645+
.endObject()
646+
.endObject()
647+
.endObject()
648+
.endObject()
649+
)
650+
);
651+
652+
client().prepareIndex()
653+
.setId("1")
654+
.setIndex("test")
655+
.setRefreshPolicy(IMMEDIATE)
656+
.setSource(jsonBuilder().startObject().field("test", "value value").field("num", 1.0).endObject())
657+
.get();
658+
659+
FunctionScoreQueryBuilder baseQuery = functionScoreQuery(
660+
constantScoreQuery(termQuery("test", "value")).queryName("query1"),
661+
ScoreFunctionBuilders.weightFactorFunction(2, "weight1")
662+
);
663+
// decay score should return 0.5 for this function and baseQuery should return 2.0f as it's score
664+
ActionFuture<SearchResponse> response = client().search(
665+
searchRequest().searchType(SearchType.QUERY_THEN_FETCH)
666+
.source(
667+
searchSource().explain(true)
668+
.query(
669+
functionScoreQuery(baseQuery, gaussDecayFunction("num", 0.0, 1.0, null, 0.5, "func2")).boostMode(
670+
CombineFunction.MULTIPLY
671+
)
672+
)
673+
)
674+
);
675+
SearchResponse sr = response.actionGet();
676+
SearchHits sh = sr.getHits();
677+
assertThat(sh.getTotalHits().value, equalTo((long) (1)));
678+
assertThat(sh.getAt(0).getId(), equalTo("1"));
679+
assertThat(sh.getAt(0).getExplanation().getDetails(), arrayWithSize(2));
680+
assertThat(sh.getAt(0).getExplanation().getDetails()[0].getDetails(), arrayWithSize(2));
681+
// "description": "ConstantScore(test:value) (_name: query1)"
682+
assertThat(
683+
sh.getAt(0).getExplanation().getDetails()[0].getDetails()[0].getDescription(),
684+
equalTo("ConstantScore(test:value) (_name: query1)")
685+
);
686+
assertThat(sh.getAt(0).getExplanation().getDetails()[0].getDetails()[1].getDetails(), arrayWithSize(2));
687+
assertThat(sh.getAt(0).getExplanation().getDetails()[0].getDetails()[1].getDetails()[0].getDetails(), arrayWithSize(2));
688+
// "description": "constant score 1.0(_name: func1) - no function provided"
689+
assertThat(
690+
sh.getAt(0).getExplanation().getDetails()[0].getDetails()[1].getDetails()[0].getDetails()[0].getDescription(),
691+
equalTo("constant score 1.0(_name: weight1) - no function provided")
692+
);
693+
// "description": "exp(-0.5*pow(MIN[Math.max(Math.abs(1.0(=doc value) - 0.0(=origin))) - 0.0(=offset), 0)],2.0)/0.7213475204444817,
694+
// _name: func2)"
695+
assertThat(sh.getAt(0).getExplanation().getDetails()[1].getDetails(), arrayWithSize(2));
696+
assertThat(sh.getAt(0).getExplanation().getDetails()[1].getDetails()[0].getDetails(), arrayWithSize(1));
697+
assertThat(
698+
sh.getAt(0).getExplanation().getDetails()[1].getDetails()[0].getDetails()[0].getDescription(),
699+
containsString("_name: func2")
700+
);
701+
}
702+
630703
public void testExceptionThrownIfScaleLE0() throws Exception {
631704
assertAcked(
632705
prepareCreate("test").addMapping(
@@ -1232,4 +1305,132 @@ public void testMultiFieldOptions() throws Exception {
12321305
sh = sr.getHits();
12331306
assertThat((double) (sh.getAt(0).getScore()), closeTo((sh.getAt(1).getScore()), 1.e-6d));
12341307
}
1308+
1309+
public void testDistanceScoreGeoLinGaussExplain() throws Exception {
1310+
assertAcked(
1311+
prepareCreate("test").addMapping(
1312+
"type1",
1313+
jsonBuilder().startObject()
1314+
.startObject("type1")
1315+
.startObject("properties")
1316+
.startObject("test")
1317+
.field("type", "text")
1318+
.endObject()
1319+
.startObject("loc")
1320+
.field("type", "geo_point")
1321+
.endObject()
1322+
.endObject()
1323+
.endObject()
1324+
.endObject()
1325+
)
1326+
);
1327+
1328+
List<IndexRequestBuilder> indexBuilders = new ArrayList<>();
1329+
indexBuilders.add(
1330+
client().prepareIndex()
1331+
.setId("1")
1332+
.setIndex("test")
1333+
.setSource(
1334+
jsonBuilder().startObject()
1335+
.field("test", "value")
1336+
.startObject("loc")
1337+
.field("lat", 10)
1338+
.field("lon", 20)
1339+
.endObject()
1340+
.endObject()
1341+
)
1342+
);
1343+
indexBuilders.add(
1344+
client().prepareIndex()
1345+
.setId("2")
1346+
.setIndex("test")
1347+
.setSource(
1348+
jsonBuilder().startObject()
1349+
.field("test", "value")
1350+
.startObject("loc")
1351+
.field("lat", 11)
1352+
.field("lon", 22)
1353+
.endObject()
1354+
.endObject()
1355+
)
1356+
);
1357+
1358+
indexRandom(true, indexBuilders);
1359+
1360+
// Test Gauss
1361+
List<Float> lonlat = new ArrayList<>();
1362+
lonlat.add(20f);
1363+
lonlat.add(11f);
1364+
1365+
final String queryName = "query1";
1366+
final String functionName = "func1";
1367+
ActionFuture<SearchResponse> response = client().search(
1368+
searchRequest().searchType(SearchType.QUERY_THEN_FETCH)
1369+
.source(
1370+
searchSource().explain(true)
1371+
.query(
1372+
functionScoreQuery(baseQuery.queryName(queryName), gaussDecayFunction("loc", lonlat, "1000km", functionName))
1373+
)
1374+
)
1375+
);
1376+
SearchResponse sr = response.actionGet();
1377+
SearchHits sh = sr.getHits();
1378+
assertThat(sh.getTotalHits().value, equalTo(2L));
1379+
assertThat(sh.getAt(0).getId(), equalTo("1"));
1380+
assertThat(sh.getAt(1).getId(), equalTo("2"));
1381+
assertExplain(queryName, functionName, sr);
1382+
1383+
response = client().search(
1384+
searchRequest().searchType(SearchType.QUERY_THEN_FETCH)
1385+
.source(
1386+
searchSource().explain(true)
1387+
.query(
1388+
functionScoreQuery(baseQuery.queryName(queryName), linearDecayFunction("loc", lonlat, "1000km", functionName))
1389+
)
1390+
)
1391+
);
1392+
1393+
sr = response.actionGet();
1394+
sh = sr.getHits();
1395+
assertThat(sh.getTotalHits().value, equalTo(2L));
1396+
assertThat(sh.getAt(0).getId(), equalTo("1"));
1397+
assertThat(sh.getAt(1).getId(), equalTo("2"));
1398+
assertExplain(queryName, functionName, sr);
1399+
1400+
response = client().search(
1401+
searchRequest().searchType(SearchType.QUERY_THEN_FETCH)
1402+
.source(
1403+
searchSource().explain(true)
1404+
.query(
1405+
functionScoreQuery(
1406+
baseQuery.queryName(queryName),
1407+
exponentialDecayFunction("loc", lonlat, "1000km", functionName)
1408+
)
1409+
)
1410+
)
1411+
);
1412+
1413+
sr = response.actionGet();
1414+
sh = sr.getHits();
1415+
assertThat(sh.getTotalHits().value, equalTo(2L));
1416+
assertThat(sh.getAt(0).getId(), equalTo("1"));
1417+
assertThat(sh.getAt(1).getId(), equalTo("2"));
1418+
assertExplain(queryName, functionName, sr);
1419+
}
1420+
1421+
private void assertExplain(final String queryName, final String functionName, SearchResponse sr) {
1422+
SearchHit firstHit = sr.getHits().getAt(0);
1423+
assertThat(firstHit.getExplanation().getDetails(), arrayWithSize(2));
1424+
// "description": "*:* (_name: query1)"
1425+
assertThat(firstHit.getExplanation().getDetails()[0].getDescription().toString(), containsString("_name: " + queryName));
1426+
assertThat(firstHit.getExplanation().getDetails()[1].getDetails(), arrayWithSize(2));
1427+
// "description": "random score function (seed: 12345678, field: _seq_no, _name: func1)"
1428+
assertThat(firstHit.getExplanation().getDetails()[1].getDetails()[0].getDetails(), arrayWithSize(1));
1429+
// "description": "exp(-0.5*pow(MIN of: [Math.max(arcDistance(10.999999972991645, 21.99999994598329(=doc value),11.0, 20.0(=origin))
1430+
// - 0.0(=offset), 0)],2.0)/7.213475204444817E11, _name: func1)"
1431+
assertThat(
1432+
firstHit.getExplanation().getDetails()[1].getDetails()[0].getDetails()[0].getDescription().toString(),
1433+
containsString("_name: " + functionName)
1434+
);
1435+
}
12351436
}

server/src/internalClusterTest/java/org/opensearch/search/functionscore/ExplainableScriptIT.java

+44-1
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
import org.opensearch.action.search.SearchResponse;
3939
import org.opensearch.action.search.SearchType;
4040
import org.opensearch.common.lucene.search.function.CombineFunction;
41+
import org.opensearch.common.lucene.search.function.Functions;
4142
import org.opensearch.common.settings.Settings;
4243
import org.opensearch.index.fielddata.ScriptDocValues;
4344
import org.opensearch.plugins.Plugin;
@@ -72,6 +73,7 @@
7273
import static org.opensearch.index.query.QueryBuilders.termQuery;
7374
import static org.opensearch.index.query.functionscore.ScoreFunctionBuilders.scriptFunction;
7475
import static org.opensearch.search.builder.SearchSourceBuilder.searchSource;
76+
import static org.hamcrest.Matchers.arrayWithSize;
7577
import static org.hamcrest.Matchers.containsString;
7678
import static org.hamcrest.Matchers.equalTo;
7779

@@ -121,8 +123,17 @@ static class MyScript extends ScoreScript implements ExplainableScoreScript {
121123

122124
@Override
123125
public Explanation explain(Explanation subQueryScore) throws IOException {
126+
return explain(subQueryScore, null);
127+
}
128+
129+
@Override
130+
public Explanation explain(Explanation subQueryScore, String functionName) throws IOException {
124131
Explanation scoreExp = Explanation.match(subQueryScore.getValue(), "_score: ", subQueryScore);
125-
return Explanation.match((float) (execute(null)), "This script returned " + execute(null), scoreExp);
132+
return Explanation.match(
133+
(float) (execute(null)),
134+
"This script" + Functions.nameOrEmptyFunc(functionName) + " returned " + execute(null),
135+
scoreExp
136+
);
126137
}
127138

128139
@Override
@@ -174,4 +185,36 @@ public void testExplainScript() throws InterruptedException, IOException, Execut
174185
idCounter--;
175186
}
176187
}
188+
189+
public void testExplainScriptWithName() throws InterruptedException, IOException, ExecutionException {
190+
List<IndexRequestBuilder> indexRequests = new ArrayList<>();
191+
indexRequests.add(
192+
client().prepareIndex("test", "type")
193+
.setId(Integer.toString(1))
194+
.setSource(jsonBuilder().startObject().field("number_field", 1).field("text", "text").endObject())
195+
);
196+
indexRandom(true, true, indexRequests);
197+
client().admin().indices().prepareRefresh().get();
198+
ensureYellow();
199+
SearchResponse response = client().search(
200+
searchRequest().searchType(SearchType.QUERY_THEN_FETCH)
201+
.source(
202+
searchSource().explain(true)
203+
.query(
204+
functionScoreQuery(
205+
termQuery("text", "text"),
206+
scriptFunction(new Script(ScriptType.INLINE, "test", "explainable_script", Collections.emptyMap()), "func1")
207+
).boostMode(CombineFunction.REPLACE)
208+
)
209+
)
210+
).actionGet();
211+
212+
OpenSearchAssertions.assertNoFailures(response);
213+
SearchHits hits = response.getHits();
214+
assertThat(hits.getTotalHits().value, equalTo(1L));
215+
assertThat(hits.getHits()[0].getId(), equalTo("1"));
216+
assertThat(hits.getHits()[0].getExplanation().getDetails(), arrayWithSize(2));
217+
assertThat(hits.getHits()[0].getExplanation().getDetails()[0].getDescription(), containsString("_name: func1"));
218+
}
219+
177220
}

server/src/internalClusterTest/java/org/opensearch/search/functionscore/FunctionScoreFieldValueIT.java

+46
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,13 @@
3535
import org.opensearch.action.search.SearchPhaseExecutionException;
3636
import org.opensearch.action.search.SearchResponse;
3737
import org.opensearch.common.lucene.search.function.FieldValueFactorFunction;
38+
import org.opensearch.search.SearchHit;
3839
import org.opensearch.test.OpenSearchIntegTestCase;
3940

4041
import java.io.IOException;
4142

43+
import static org.hamcrest.Matchers.containsString;
44+
import static org.hamcrest.Matchers.arrayWithSize;
4245
import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder;
4346
import static org.opensearch.index.query.QueryBuilders.functionScoreQuery;
4447
import static org.opensearch.index.query.QueryBuilders.matchAllQuery;
@@ -163,4 +166,47 @@ public void testFieldValueFactor() throws IOException {
163166
// locally, instead of just having failures
164167
}
165168
}
169+
170+
public void testFieldValueFactorExplain() throws IOException {
171+
assertAcked(
172+
prepareCreate("test").addMapping(
173+
"type1",
174+
jsonBuilder().startObject()
175+
.startObject("type1")
176+
.startObject("properties")
177+
.startObject("test")
178+
.field("type", randomFrom(new String[] { "short", "float", "long", "integer", "double" }))
179+
.endObject()
180+
.startObject("body")
181+
.field("type", "text")
182+
.endObject()
183+
.endObject()
184+
.endObject()
185+
.endObject()
186+
).get()
187+
);
188+
189+
client().prepareIndex("test", "type1", "1").setSource("test", 5, "body", "foo").get();
190+
client().prepareIndex("test", "type1", "2").setSource("test", 17, "body", "foo").get();
191+
client().prepareIndex("test", "type1", "3").setSource("body", "bar").get();
192+
193+
refresh();
194+
195+
// document 2 scores higher because 17 > 5
196+
final String functionName = "func1";
197+
final String queryName = "query";
198+
SearchResponse response = client().prepareSearch("test")
199+
.setExplain(true)
200+
.setQuery(
201+
functionScoreQuery(simpleQueryStringQuery("foo").queryName(queryName), fieldValueFactorFunction("test", functionName))
202+
)
203+
.get();
204+
assertOrderedSearchHits(response, "2", "1");
205+
SearchHit firstHit = response.getHits().getAt(0);
206+
assertThat(firstHit.getExplanation().getDetails(), arrayWithSize(2));
207+
// "description": "sum of: (_name: query)"
208+
assertThat(firstHit.getExplanation().getDetails()[0].getDescription(), containsString("_name: " + queryName));
209+
// "description": "field value function(_name: func1): none(doc['test'].value * factor=1.0)"
210+
assertThat(firstHit.getExplanation().getDetails()[1].toString(), containsString("_name: " + functionName));
211+
}
166212
}

0 commit comments

Comments
 (0)