Skip to content

Commit

Permalink
Add '_name' field support to score functions and provide it back in e…
Browse files Browse the repository at this point in the history
…xplanation response (#2244)

* Add '_name' field support to score functions and provide it back in explanation response

Signed-off-by: Andriy Redko <[email protected]>

* Address code review comments

Signed-off-by: Andriy Redko <[email protected]>
  • Loading branch information
reta authored Mar 4, 2022
1 parent ae14259 commit 5f90227
Show file tree
Hide file tree
Showing 31 changed files with 1,127 additions and 79 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import org.opensearch.index.query.functionscore.FunctionScoreQueryBuilder.FilterFunctionBuilder;
import org.opensearch.index.query.functionscore.ScoreFunctionBuilders;
import org.opensearch.search.MultiValueMode;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.test.OpenSearchIntegTestCase;
import org.opensearch.test.VersionUtils;
Expand All @@ -77,7 +78,9 @@
import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertOrderedSearchHits;
import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertSearchHits;
import static org.hamcrest.Matchers.anyOf;
import static org.hamcrest.Matchers.arrayWithSize;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.lessThan;
Expand Down Expand Up @@ -616,6 +619,76 @@ public void testCombineModes() throws Exception {

}

public void testCombineModesExplain() throws Exception {
assertAcked(
prepareCreate("test").addMapping(
"type1",
jsonBuilder().startObject()
.startObject("type1")
.startObject("properties")
.startObject("test")
.field("type", "text")
.endObject()
.startObject("num")
.field("type", "double")
.endObject()
.endObject()
.endObject()
.endObject()
)
);

client().prepareIndex()
.setId("1")
.setIndex("test")
.setRefreshPolicy(IMMEDIATE)
.setSource(jsonBuilder().startObject().field("test", "value value").field("num", 1.0).endObject())
.get();

FunctionScoreQueryBuilder baseQuery = functionScoreQuery(
constantScoreQuery(termQuery("test", "value")).queryName("query1"),
ScoreFunctionBuilders.weightFactorFunction(2, "weight1")
);
// decay score should return 0.5 for this function and baseQuery should return 2.0f as it's score
ActionFuture<SearchResponse> response = client().search(
searchRequest().searchType(SearchType.QUERY_THEN_FETCH)
.source(
searchSource().explain(true)
.query(
functionScoreQuery(baseQuery, gaussDecayFunction("num", 0.0, 1.0, null, 0.5, "func2")).boostMode(
CombineFunction.MULTIPLY
)
)
)
);
SearchResponse sr = response.actionGet();
SearchHits sh = sr.getHits();
assertThat(sh.getTotalHits().value, equalTo((long) (1)));
assertThat(sh.getAt(0).getId(), equalTo("1"));
assertThat(sh.getAt(0).getExplanation().getDetails(), arrayWithSize(2));
assertThat(sh.getAt(0).getExplanation().getDetails()[0].getDetails(), arrayWithSize(2));
// "description": "ConstantScore(test:value) (_name: query1)"
assertThat(
sh.getAt(0).getExplanation().getDetails()[0].getDetails()[0].getDescription(),
equalTo("ConstantScore(test:value) (_name: query1)")
);
assertThat(sh.getAt(0).getExplanation().getDetails()[0].getDetails()[1].getDetails(), arrayWithSize(2));
assertThat(sh.getAt(0).getExplanation().getDetails()[0].getDetails()[1].getDetails()[0].getDetails(), arrayWithSize(2));
// "description": "constant score 1.0(_name: func1) - no function provided"
assertThat(
sh.getAt(0).getExplanation().getDetails()[0].getDetails()[1].getDetails()[0].getDetails()[0].getDescription(),
equalTo("constant score 1.0(_name: weight1) - no function provided")
);
// "description": "exp(-0.5*pow(MIN[Math.max(Math.abs(1.0(=doc value) - 0.0(=origin))) - 0.0(=offset), 0)],2.0)/0.7213475204444817,
// _name: func2)"
assertThat(sh.getAt(0).getExplanation().getDetails()[1].getDetails(), arrayWithSize(2));
assertThat(sh.getAt(0).getExplanation().getDetails()[1].getDetails()[0].getDetails(), arrayWithSize(1));
assertThat(
sh.getAt(0).getExplanation().getDetails()[1].getDetails()[0].getDetails()[0].getDescription(),
containsString("_name: func2")
);
}

public void testExceptionThrownIfScaleLE0() throws Exception {
assertAcked(
prepareCreate("test").addMapping(
Expand Down Expand Up @@ -1195,4 +1268,132 @@ public void testMultiFieldOptions() throws Exception {
sh = sr.getHits();
assertThat((double) (sh.getAt(0).getScore()), closeTo((sh.getAt(1).getScore()), 1.e-6d));
}

public void testDistanceScoreGeoLinGaussExplain() throws Exception {
assertAcked(
prepareCreate("test").addMapping(
"type1",
jsonBuilder().startObject()
.startObject("type1")
.startObject("properties")
.startObject("test")
.field("type", "text")
.endObject()
.startObject("loc")
.field("type", "geo_point")
.endObject()
.endObject()
.endObject()
.endObject()
)
);

List<IndexRequestBuilder> indexBuilders = new ArrayList<>();
indexBuilders.add(
client().prepareIndex()
.setId("1")
.setIndex("test")
.setSource(
jsonBuilder().startObject()
.field("test", "value")
.startObject("loc")
.field("lat", 10)
.field("lon", 20)
.endObject()
.endObject()
)
);
indexBuilders.add(
client().prepareIndex()
.setId("2")
.setIndex("test")
.setSource(
jsonBuilder().startObject()
.field("test", "value")
.startObject("loc")
.field("lat", 11)
.field("lon", 22)
.endObject()
.endObject()
)
);

indexRandom(true, indexBuilders);

// Test Gauss
List<Float> lonlat = new ArrayList<>();
lonlat.add(20f);
lonlat.add(11f);

final String queryName = "query1";
final String functionName = "func1";
ActionFuture<SearchResponse> response = client().search(
searchRequest().searchType(SearchType.QUERY_THEN_FETCH)
.source(
searchSource().explain(true)
.query(
functionScoreQuery(baseQuery.queryName(queryName), gaussDecayFunction("loc", lonlat, "1000km", functionName))
)
)
);
SearchResponse sr = response.actionGet();
SearchHits sh = sr.getHits();
assertThat(sh.getTotalHits().value, equalTo(2L));
assertThat(sh.getAt(0).getId(), equalTo("1"));
assertThat(sh.getAt(1).getId(), equalTo("2"));
assertExplain(queryName, functionName, sr);

response = client().search(
searchRequest().searchType(SearchType.QUERY_THEN_FETCH)
.source(
searchSource().explain(true)
.query(
functionScoreQuery(baseQuery.queryName(queryName), linearDecayFunction("loc", lonlat, "1000km", functionName))
)
)
);

sr = response.actionGet();
sh = sr.getHits();
assertThat(sh.getTotalHits().value, equalTo(2L));
assertThat(sh.getAt(0).getId(), equalTo("1"));
assertThat(sh.getAt(1).getId(), equalTo("2"));
assertExplain(queryName, functionName, sr);

response = client().search(
searchRequest().searchType(SearchType.QUERY_THEN_FETCH)
.source(
searchSource().explain(true)
.query(
functionScoreQuery(
baseQuery.queryName(queryName),
exponentialDecayFunction("loc", lonlat, "1000km", functionName)
)
)
)
);

sr = response.actionGet();
sh = sr.getHits();
assertThat(sh.getTotalHits().value, equalTo(2L));
assertThat(sh.getAt(0).getId(), equalTo("1"));
assertThat(sh.getAt(1).getId(), equalTo("2"));
assertExplain(queryName, functionName, sr);
}

private void assertExplain(final String queryName, final String functionName, SearchResponse sr) {
SearchHit firstHit = sr.getHits().getAt(0);
assertThat(firstHit.getExplanation().getDetails(), arrayWithSize(2));
// "description": "*:* (_name: query1)"
assertThat(firstHit.getExplanation().getDetails()[0].getDescription().toString(), containsString("_name: " + queryName));
assertThat(firstHit.getExplanation().getDetails()[1].getDetails(), arrayWithSize(2));
// "description": "random score function (seed: 12345678, field: _seq_no, _name: func1)"
assertThat(firstHit.getExplanation().getDetails()[1].getDetails()[0].getDetails(), arrayWithSize(1));
// "description": "exp(-0.5*pow(MIN of: [Math.max(arcDistance(10.999999972991645, 21.99999994598329(=doc value),11.0, 20.0(=origin))
// - 0.0(=offset), 0)],2.0)/7.213475204444817E11, _name: func1)"
assertThat(
firstHit.getExplanation().getDetails()[1].getDetails()[0].getDetails()[0].getDescription().toString(),
containsString("_name: " + functionName)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.search.SearchType;
import org.opensearch.common.lucene.search.function.CombineFunction;
import org.opensearch.common.lucene.search.function.Functions;
import org.opensearch.common.settings.Settings;
import org.opensearch.index.fielddata.ScriptDocValues;
import org.opensearch.plugins.Plugin;
Expand Down Expand Up @@ -72,6 +73,7 @@
import static org.opensearch.index.query.QueryBuilders.termQuery;
import static org.opensearch.index.query.functionscore.ScoreFunctionBuilders.scriptFunction;
import static org.opensearch.search.builder.SearchSourceBuilder.searchSource;
import static org.hamcrest.Matchers.arrayWithSize;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;

Expand Down Expand Up @@ -121,8 +123,17 @@ static class MyScript extends ScoreScript implements ExplainableScoreScript {

@Override
public Explanation explain(Explanation subQueryScore) throws IOException {
return explain(subQueryScore, null);
}

@Override
public Explanation explain(Explanation subQueryScore, String functionName) throws IOException {
Explanation scoreExp = Explanation.match(subQueryScore.getValue(), "_score: ", subQueryScore);
return Explanation.match((float) (execute(null)), "This script returned " + execute(null), scoreExp);
return Explanation.match(
(float) (execute(null)),
"This script" + Functions.nameOrEmptyFunc(functionName) + " returned " + execute(null),
scoreExp
);
}

@Override
Expand Down Expand Up @@ -174,4 +185,36 @@ public void testExplainScript() throws InterruptedException, IOException, Execut
idCounter--;
}
}

public void testExplainScriptWithName() throws InterruptedException, IOException, ExecutionException {
List<IndexRequestBuilder> indexRequests = new ArrayList<>();
indexRequests.add(
client().prepareIndex("test")
.setId(Integer.toString(1))
.setSource(jsonBuilder().startObject().field("number_field", 1).field("text", "text").endObject())
);
indexRandom(true, true, indexRequests);
client().admin().indices().prepareRefresh().get();
ensureYellow();
SearchResponse response = client().search(
searchRequest().searchType(SearchType.QUERY_THEN_FETCH)
.source(
searchSource().explain(true)
.query(
functionScoreQuery(
termQuery("text", "text"),
scriptFunction(new Script(ScriptType.INLINE, "test", "explainable_script", Collections.emptyMap()), "func1")
).boostMode(CombineFunction.REPLACE)
)
)
).actionGet();

OpenSearchAssertions.assertNoFailures(response);
SearchHits hits = response.getHits();
assertThat(hits.getTotalHits().value, equalTo(1L));
assertThat(hits.getHits()[0].getId(), equalTo("1"));
assertThat(hits.getHits()[0].getExplanation().getDetails(), arrayWithSize(2));
assertThat(hits.getHits()[0].getExplanation().getDetails()[0].getDescription(), containsString("_name: func1"));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,13 @@
import org.opensearch.action.search.SearchPhaseExecutionException;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.common.lucene.search.function.FieldValueFactorFunction;
import org.opensearch.search.SearchHit;
import org.opensearch.test.OpenSearchIntegTestCase;

import java.io.IOException;

import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.arrayWithSize;
import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder;
import static org.opensearch.index.query.QueryBuilders.functionScoreQuery;
import static org.opensearch.index.query.QueryBuilders.matchAllQuery;
Expand Down Expand Up @@ -163,4 +166,47 @@ public void testFieldValueFactor() throws IOException {
// locally, instead of just having failures
}
}

public void testFieldValueFactorExplain() throws IOException {
assertAcked(
prepareCreate("test").addMapping(
"type1",
jsonBuilder().startObject()
.startObject("type1")
.startObject("properties")
.startObject("test")
.field("type", randomFrom(new String[] { "short", "float", "long", "integer", "double" }))
.endObject()
.startObject("body")
.field("type", "text")
.endObject()
.endObject()
.endObject()
.endObject()
).get()
);

client().prepareIndex("test").setId("1").setSource("test", 5, "body", "foo").get();
client().prepareIndex("test").setId("2").setSource("test", 17, "body", "foo").get();
client().prepareIndex("test").setId("3").setSource("body", "bar").get();

refresh();

// document 2 scores higher because 17 > 5
final String functionName = "func1";
final String queryName = "query";
SearchResponse response = client().prepareSearch("test")
.setExplain(true)
.setQuery(
functionScoreQuery(simpleQueryStringQuery("foo").queryName(queryName), fieldValueFactorFunction("test", functionName))
)
.get();
assertOrderedSearchHits(response, "2", "1");
SearchHit firstHit = response.getHits().getAt(0);
assertThat(firstHit.getExplanation().getDetails(), arrayWithSize(2));
// "description": "sum of: (_name: query)"
assertThat(firstHit.getExplanation().getDetails()[0].getDescription(), containsString("_name: " + queryName));
// "description": "field value function(_name: func1): none(doc['test'].value * factor=1.0)"
assertThat(firstHit.getExplanation().getDetails()[1].toString(), containsString("_name: " + functionName));
}
}
Loading

0 comments on commit 5f90227

Please sign in to comment.