|
52 | 52 | import org.opensearch.index.query.functionscore.FunctionScoreQueryBuilder.FilterFunctionBuilder;
|
53 | 53 | import org.opensearch.index.query.functionscore.ScoreFunctionBuilders;
|
54 | 54 | import org.opensearch.search.MultiValueMode;
|
| 55 | +import org.opensearch.search.SearchHit; |
55 | 56 | import org.opensearch.search.SearchHits;
|
56 | 57 | import org.opensearch.test.OpenSearchIntegTestCase;
|
57 | 58 | import org.opensearch.test.VersionUtils;
|
|
78 | 79 | import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertOrderedSearchHits;
|
79 | 80 | import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertSearchHits;
|
80 | 81 | import static org.hamcrest.Matchers.anyOf;
|
| 82 | +import static org.hamcrest.Matchers.arrayWithSize; |
81 | 83 | import static org.hamcrest.Matchers.closeTo;
|
| 84 | +import static org.hamcrest.Matchers.containsString; |
82 | 85 | import static org.hamcrest.Matchers.equalTo;
|
83 | 86 | import static org.hamcrest.Matchers.is;
|
84 | 87 | import static org.hamcrest.Matchers.lessThan;
|
@@ -627,6 +630,76 @@ public void testCombineModes() throws Exception {
|
627 | 630 |
|
628 | 631 | }
|
629 | 632 |
|
| 633 | + public void testCombineModesExplain() throws Exception { |
| 634 | + assertAcked( |
| 635 | + prepareCreate("test").addMapping( |
| 636 | + "type1", |
| 637 | + jsonBuilder().startObject() |
| 638 | + .startObject("type1") |
| 639 | + .startObject("properties") |
| 640 | + .startObject("test") |
| 641 | + .field("type", "text") |
| 642 | + .endObject() |
| 643 | + .startObject("num") |
| 644 | + .field("type", "double") |
| 645 | + .endObject() |
| 646 | + .endObject() |
| 647 | + .endObject() |
| 648 | + .endObject() |
| 649 | + ) |
| 650 | + ); |
| 651 | + |
| 652 | + client().prepareIndex() |
| 653 | + .setId("1") |
| 654 | + .setIndex("test") |
| 655 | + .setRefreshPolicy(IMMEDIATE) |
| 656 | + .setSource(jsonBuilder().startObject().field("test", "value value").field("num", 1.0).endObject()) |
| 657 | + .get(); |
| 658 | + |
| 659 | + FunctionScoreQueryBuilder baseQuery = functionScoreQuery( |
| 660 | + constantScoreQuery(termQuery("test", "value")).queryName("query1"), |
| 661 | + ScoreFunctionBuilders.weightFactorFunction(2, "weight1") |
| 662 | + ); |
| 663 | + // decay score should return 0.5 for this function and baseQuery should return 2.0f as it's score |
| 664 | + ActionFuture<SearchResponse> response = client().search( |
| 665 | + searchRequest().searchType(SearchType.QUERY_THEN_FETCH) |
| 666 | + .source( |
| 667 | + searchSource().explain(true) |
| 668 | + .query( |
| 669 | + functionScoreQuery(baseQuery, gaussDecayFunction("num", 0.0, 1.0, null, 0.5, "func2")).boostMode( |
| 670 | + CombineFunction.MULTIPLY |
| 671 | + ) |
| 672 | + ) |
| 673 | + ) |
| 674 | + ); |
| 675 | + SearchResponse sr = response.actionGet(); |
| 676 | + SearchHits sh = sr.getHits(); |
| 677 | + assertThat(sh.getTotalHits().value, equalTo((long) (1))); |
| 678 | + assertThat(sh.getAt(0).getId(), equalTo("1")); |
| 679 | + assertThat(sh.getAt(0).getExplanation().getDetails(), arrayWithSize(2)); |
| 680 | + assertThat(sh.getAt(0).getExplanation().getDetails()[0].getDetails(), arrayWithSize(2)); |
| 681 | + // "description": "ConstantScore(test:value) (_name: query1)" |
| 682 | + assertThat( |
| 683 | + sh.getAt(0).getExplanation().getDetails()[0].getDetails()[0].getDescription(), |
| 684 | + equalTo("ConstantScore(test:value) (_name: query1)") |
| 685 | + ); |
| 686 | + assertThat(sh.getAt(0).getExplanation().getDetails()[0].getDetails()[1].getDetails(), arrayWithSize(2)); |
| 687 | + assertThat(sh.getAt(0).getExplanation().getDetails()[0].getDetails()[1].getDetails()[0].getDetails(), arrayWithSize(2)); |
| 688 | + // "description": "constant score 1.0(_name: func1) - no function provided" |
| 689 | + assertThat( |
| 690 | + sh.getAt(0).getExplanation().getDetails()[0].getDetails()[1].getDetails()[0].getDetails()[0].getDescription(), |
| 691 | + equalTo("constant score 1.0(_name: weight1) - no function provided") |
| 692 | + ); |
| 693 | + // "description": "exp(-0.5*pow(MIN[Math.max(Math.abs(1.0(=doc value) - 0.0(=origin))) - 0.0(=offset), 0)],2.0)/0.7213475204444817, |
| 694 | + // _name: func2)" |
| 695 | + assertThat(sh.getAt(0).getExplanation().getDetails()[1].getDetails(), arrayWithSize(2)); |
| 696 | + assertThat(sh.getAt(0).getExplanation().getDetails()[1].getDetails()[0].getDetails(), arrayWithSize(1)); |
| 697 | + assertThat( |
| 698 | + sh.getAt(0).getExplanation().getDetails()[1].getDetails()[0].getDetails()[0].getDescription(), |
| 699 | + containsString("_name: func2") |
| 700 | + ); |
| 701 | + } |
| 702 | + |
630 | 703 | public void testExceptionThrownIfScaleLE0() throws Exception {
|
631 | 704 | assertAcked(
|
632 | 705 | prepareCreate("test").addMapping(
|
@@ -1232,4 +1305,132 @@ public void testMultiFieldOptions() throws Exception {
|
1232 | 1305 | sh = sr.getHits();
|
1233 | 1306 | assertThat((double) (sh.getAt(0).getScore()), closeTo((sh.getAt(1).getScore()), 1.e-6d));
|
1234 | 1307 | }
|
| 1308 | + |
| 1309 | + public void testDistanceScoreGeoLinGaussExplain() throws Exception { |
| 1310 | + assertAcked( |
| 1311 | + prepareCreate("test").addMapping( |
| 1312 | + "type1", |
| 1313 | + jsonBuilder().startObject() |
| 1314 | + .startObject("type1") |
| 1315 | + .startObject("properties") |
| 1316 | + .startObject("test") |
| 1317 | + .field("type", "text") |
| 1318 | + .endObject() |
| 1319 | + .startObject("loc") |
| 1320 | + .field("type", "geo_point") |
| 1321 | + .endObject() |
| 1322 | + .endObject() |
| 1323 | + .endObject() |
| 1324 | + .endObject() |
| 1325 | + ) |
| 1326 | + ); |
| 1327 | + |
| 1328 | + List<IndexRequestBuilder> indexBuilders = new ArrayList<>(); |
| 1329 | + indexBuilders.add( |
| 1330 | + client().prepareIndex() |
| 1331 | + .setId("1") |
| 1332 | + .setIndex("test") |
| 1333 | + .setSource( |
| 1334 | + jsonBuilder().startObject() |
| 1335 | + .field("test", "value") |
| 1336 | + .startObject("loc") |
| 1337 | + .field("lat", 10) |
| 1338 | + .field("lon", 20) |
| 1339 | + .endObject() |
| 1340 | + .endObject() |
| 1341 | + ) |
| 1342 | + ); |
| 1343 | + indexBuilders.add( |
| 1344 | + client().prepareIndex() |
| 1345 | + .setId("2") |
| 1346 | + .setIndex("test") |
| 1347 | + .setSource( |
| 1348 | + jsonBuilder().startObject() |
| 1349 | + .field("test", "value") |
| 1350 | + .startObject("loc") |
| 1351 | + .field("lat", 11) |
| 1352 | + .field("lon", 22) |
| 1353 | + .endObject() |
| 1354 | + .endObject() |
| 1355 | + ) |
| 1356 | + ); |
| 1357 | + |
| 1358 | + indexRandom(true, indexBuilders); |
| 1359 | + |
| 1360 | + // Test Gauss |
| 1361 | + List<Float> lonlat = new ArrayList<>(); |
| 1362 | + lonlat.add(20f); |
| 1363 | + lonlat.add(11f); |
| 1364 | + |
| 1365 | + final String queryName = "query1"; |
| 1366 | + final String functionName = "func1"; |
| 1367 | + ActionFuture<SearchResponse> response = client().search( |
| 1368 | + searchRequest().searchType(SearchType.QUERY_THEN_FETCH) |
| 1369 | + .source( |
| 1370 | + searchSource().explain(true) |
| 1371 | + .query( |
| 1372 | + functionScoreQuery(baseQuery.queryName(queryName), gaussDecayFunction("loc", lonlat, "1000km", functionName)) |
| 1373 | + ) |
| 1374 | + ) |
| 1375 | + ); |
| 1376 | + SearchResponse sr = response.actionGet(); |
| 1377 | + SearchHits sh = sr.getHits(); |
| 1378 | + assertThat(sh.getTotalHits().value, equalTo(2L)); |
| 1379 | + assertThat(sh.getAt(0).getId(), equalTo("1")); |
| 1380 | + assertThat(sh.getAt(1).getId(), equalTo("2")); |
| 1381 | + assertExplain(queryName, functionName, sr); |
| 1382 | + |
| 1383 | + response = client().search( |
| 1384 | + searchRequest().searchType(SearchType.QUERY_THEN_FETCH) |
| 1385 | + .source( |
| 1386 | + searchSource().explain(true) |
| 1387 | + .query( |
| 1388 | + functionScoreQuery(baseQuery.queryName(queryName), linearDecayFunction("loc", lonlat, "1000km", functionName)) |
| 1389 | + ) |
| 1390 | + ) |
| 1391 | + ); |
| 1392 | + |
| 1393 | + sr = response.actionGet(); |
| 1394 | + sh = sr.getHits(); |
| 1395 | + assertThat(sh.getTotalHits().value, equalTo(2L)); |
| 1396 | + assertThat(sh.getAt(0).getId(), equalTo("1")); |
| 1397 | + assertThat(sh.getAt(1).getId(), equalTo("2")); |
| 1398 | + assertExplain(queryName, functionName, sr); |
| 1399 | + |
| 1400 | + response = client().search( |
| 1401 | + searchRequest().searchType(SearchType.QUERY_THEN_FETCH) |
| 1402 | + .source( |
| 1403 | + searchSource().explain(true) |
| 1404 | + .query( |
| 1405 | + functionScoreQuery( |
| 1406 | + baseQuery.queryName(queryName), |
| 1407 | + exponentialDecayFunction("loc", lonlat, "1000km", functionName) |
| 1408 | + ) |
| 1409 | + ) |
| 1410 | + ) |
| 1411 | + ); |
| 1412 | + |
| 1413 | + sr = response.actionGet(); |
| 1414 | + sh = sr.getHits(); |
| 1415 | + assertThat(sh.getTotalHits().value, equalTo(2L)); |
| 1416 | + assertThat(sh.getAt(0).getId(), equalTo("1")); |
| 1417 | + assertThat(sh.getAt(1).getId(), equalTo("2")); |
| 1418 | + assertExplain(queryName, functionName, sr); |
| 1419 | + } |
| 1420 | + |
| 1421 | + private void assertExplain(final String queryName, final String functionName, SearchResponse sr) { |
| 1422 | + SearchHit firstHit = sr.getHits().getAt(0); |
| 1423 | + assertThat(firstHit.getExplanation().getDetails(), arrayWithSize(2)); |
| 1424 | + // "description": "*:* (_name: query1)" |
| 1425 | + assertThat(firstHit.getExplanation().getDetails()[0].getDescription().toString(), containsString("_name: " + queryName)); |
| 1426 | + assertThat(firstHit.getExplanation().getDetails()[1].getDetails(), arrayWithSize(2)); |
| 1427 | + // "description": "random score function (seed: 12345678, field: _seq_no, _name: func1)" |
| 1428 | + assertThat(firstHit.getExplanation().getDetails()[1].getDetails()[0].getDetails(), arrayWithSize(1)); |
| 1429 | + // "description": "exp(-0.5*pow(MIN of: [Math.max(arcDistance(10.999999972991645, 21.99999994598329(=doc value),11.0, 20.0(=origin)) |
| 1430 | + // - 0.0(=offset), 0)],2.0)/7.213475204444817E11, _name: func1)" |
| 1431 | + assertThat( |
| 1432 | + firstHit.getExplanation().getDetails()[1].getDetails()[0].getDetails()[0].getDescription().toString(), |
| 1433 | + containsString("_name: " + functionName) |
| 1434 | + ); |
| 1435 | + } |
1235 | 1436 | }
|
0 commit comments