Skip to content

Commit

Permalink
Add stats for radial search
Browse files Browse the repository at this point in the history
Signed-off-by: Junqiu Lei <[email protected]>
  • Loading branch information
junqiu-lei committed May 8, 2024
1 parent 73d5425 commit c812547
Show file tree
Hide file tree
Showing 6 changed files with 157 additions and 23 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.14...2.x)
### Features
### Enhancements
* Add stats for radial search [#1684](https://github.com/opensearch-project/k-NN/pull/1684)
### Bug Fixes
### Infrastructure
### Documentation
Expand Down
23 changes: 21 additions & 2 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,6 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep
String currentFieldName = null;
boolean ignoreUnmapped = false;
XContentParser.Token token;
KNNCounter.KNN_QUERY_REQUESTS.increment();
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
if (token == XContentParser.Token.FIELD_NAME) {
currentFieldName = parser.currentName();
Expand Down Expand Up @@ -279,7 +278,6 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep
String tokenName = parser.currentName();
if (FILTER_FIELD.getPreferredName().equals(tokenName)) {
log.debug(String.format("Start parsing filter for field [%s]", fieldName));
KNNCounter.KNN_QUERY_WITH_FILTER_REQUESTS.increment();
filter = parseInnerQueryBuilder(parser);
} else {
throw new ParsingException(parser.getTokenLocation(), "[" + NAME + "] unknown token [" + token + "]");
Expand All @@ -300,6 +298,8 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep

validateSingleQueryType(k, maxDistance, minScore);

updateQueryStats(k, minScore, maxDistance, filter);

KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName, ObjectsToFloats(vector)).filter(filter)
.ignoreUnmapped(ignoreUnmapped)
.boost(boost)
Expand Down Expand Up @@ -566,4 +566,23 @@ private static void validateSingleQueryType(Integer k, Float distance, Float sco
throw new IllegalArgumentException(String.format("[%s] requires exactly one of k, distance or score to be set", NAME));
}
}

/**
 * Bumps the request counter that matches the kind of k-NN query that was parsed and, when a
 * filter accompanies the query, the corresponding "with filter" counter as well.
 *
 * The arguments are examined in the order {@code k}, {@code minScore}, {@code maxDistance};
 * only the first non-null one decides which counter pair is used. If all three are null,
 * no counter is touched.
 *
 * @param k           parsed k value, or null when the query did not specify one
 * @param minScore    parsed min score, or null when the query did not specify one
 * @param maxDistance parsed max distance, or null when the query did not specify one
 * @param filter      parsed filter query, or null when the query has no filter
 */
private static void updateQueryStats(Integer k, Float minScore, Float maxDistance, QueryBuilder filter) {
    final KNNCounter requestCounter;
    final KNNCounter filteredRequestCounter;

    if (k != null) {
        requestCounter = KNNCounter.KNN_QUERY_REQUESTS;
        filteredRequestCounter = KNNCounter.KNN_QUERY_WITH_FILTER_REQUESTS;
    } else if (minScore != null) {
        requestCounter = KNNCounter.MIN_SCORE_QUERY_REQUESTS;
        filteredRequestCounter = KNNCounter.MIN_SCORE_QUERY_WITH_FILTER_REQUESTS;
    } else if (maxDistance != null) {
        requestCounter = KNNCounter.MAX_DISTANCE_QUERY_REQUESTS;
        filteredRequestCounter = KNNCounter.MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS;
    } else {
        // None of the three query types is present: nothing to record.
        return;
    }

    requestCounter.increment();
    if (filter != null) {
        filteredRequestCounter.increment();
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@ public enum KNNCounter {
SCRIPT_QUERY_ERRORS("script_query_errors"),
TRAINING_REQUESTS("training_requests"),
TRAINING_ERRORS("training_errors"),
KNN_QUERY_WITH_FILTER_REQUESTS("knn_query_with_filter_requests");
KNN_QUERY_WITH_FILTER_REQUESTS("knn_query_with_filter_requests"),
MIN_SCORE_QUERY_REQUESTS("min_score_query_requests"),
MIN_SCORE_QUERY_WITH_FILTER_REQUESTS("min_score_query_with_filter_requests"),
MAX_DISTANCE_QUERY_REQUESTS("max_distance_query_requests"),
MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS("max_distance_query_with_filter_requests");

private String name;
private AtomicLong count;
Expand Down
28 changes: 26 additions & 2 deletions src/main/java/org/opensearch/knn/plugin/stats/KNNStats.java
Original file line number Diff line number Diff line change
Expand Up @@ -80,16 +80,18 @@ private Map<String, KNNStat<?>> getClusterOrNodeStats(Boolean getClusterStats) {

/**
 * Assembles the stat-name to {@link KNNStat} map by passing one shared builder to each
 * {@code add*Stats} helper in turn and then building the immutable result.
 *
 * @return the immutable map produced by the builder after all helpers have registered their stats
 */
private Map<String, KNNStat<?>> buildStatsMap() {
ImmutableMap.Builder<String, KNNStat<?>> builder = ImmutableMap.<String, KNNStat<?>>builder();
// NOTE(review): this listing looks like a diff whose +/- markers were stripped; addQueryStats is
// presumably the pre-rename call that addKNNQueryStats replaces, so only one of the two is expected
// to exist in the real file -- confirm against the repository.
addQueryStats(builder);
addKNNQueryStats(builder);
addNativeMemoryStats(builder);
addEngineStats(builder);
addScriptStats(builder);
addModelStats(builder);
addGraphStats(builder);
// The radial search counters (min score / max distance) are registered after the other groups.
addMinScoreQueryStats(builder);
addMaxDistanceQueryStats(builder);
return builder.build();
}

private void addQueryStats(ImmutableMap.Builder<String, KNNStat<?>> builder) {
private void addKNNQueryStats(ImmutableMap.Builder<String, KNNStat<?>> builder) {
builder.put(StatNames.KNN_QUERY_REQUESTS.getName(), new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.KNN_QUERY_REQUESTS)))
.put(
StatNames.KNN_QUERY_WITH_FILTER_REQUESTS.getName(),
Expand All @@ -98,6 +100,28 @@ private void addQueryStats(ImmutableMap.Builder<String, KNNStat<?>> builder) {

}

/**
 * Registers the two min-score radial search stats on the given builder: the plain request
 * count and the request-with-filter count. Each stat is keyed by its {@link StatNames} name
 * and backed by a {@link KNNCounterSupplier} over the matching {@link KNNCounter}.
 *
 * @param builder the stats map builder to add the entries to
 */
private void addMinScoreQueryStats(ImmutableMap.Builder<String, KNNStat<?>> builder) {
    // NOTE(review): the meaning of the leading "false" KNNStat argument is not visible here -- confirm.
    final String requestsKey = StatNames.MIN_SCORE_QUERY_REQUESTS.getName();
    builder.put(requestsKey, new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.MIN_SCORE_QUERY_REQUESTS)));

    final String filteredRequestsKey = StatNames.MIN_SCORE_QUERY_WITH_FILTER_REQUESTS.getName();
    builder.put(
        filteredRequestsKey,
        new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.MIN_SCORE_QUERY_WITH_FILTER_REQUESTS))
    );
}

/**
 * Registers the two max-distance radial search stats on the given builder: the plain request
 * count and the request-with-filter count. Each stat is keyed by its {@link StatNames} name
 * and backed by a {@link KNNCounterSupplier} over the matching {@link KNNCounter}.
 *
 * @param builder the stats map builder to add the entries to
 */
private void addMaxDistanceQueryStats(ImmutableMap.Builder<String, KNNStat<?>> builder) {
    // NOTE(review): the meaning of the leading "false" KNNStat argument is not visible here -- confirm.
    final String requestsKey = StatNames.MAX_DISTANCE_QUERY_REQUESTS.getName();
    builder.put(requestsKey, new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.MAX_DISTANCE_QUERY_REQUESTS)));

    final String filteredRequestsKey = StatNames.MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS.getName();
    builder.put(
        filteredRequestsKey,
        new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS))
    );
}

private void addNativeMemoryStats(ImmutableMap.Builder<String, KNNStat<?>> builder) {
builder.put(StatNames.HIT_COUNT.getName(), new KNNStat<>(false, new KNNInnerCacheStatsSupplier(CacheStats::hitCount)))
.put(StatNames.MISS_COUNT.getName(), new KNNStat<>(false, new KNNInnerCacheStatsSupplier(CacheStats::missCount)))
Expand Down
6 changes: 5 additions & 1 deletion src/main/java/org/opensearch/knn/plugin/stats/StatNames.java
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@ public enum StatNames {
KNN_QUERY_WITH_FILTER_REQUESTS(KNNCounter.KNN_QUERY_WITH_FILTER_REQUESTS.getName()),
GRAPH_STATS("graph_stats"),
REFRESH("refresh"),
MERGE("merge");
MERGE("merge"),
MIN_SCORE_QUERY_REQUESTS(KNNCounter.MIN_SCORE_QUERY_REQUESTS.getName()),
MIN_SCORE_QUERY_WITH_FILTER_REQUESTS(KNNCounter.MIN_SCORE_QUERY_WITH_FILTER_REQUESTS.getName()),
MAX_DISTANCE_QUERY_REQUESTS(KNNCounter.MAX_DISTANCE_QUERY_REQUESTS.getName()),
MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS(KNNCounter.MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS.getName());

private String name;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.knn.plugin.action;

import lombok.SneakyThrows;
import org.apache.hc.core5.http.io.entity.EntityUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
Expand All @@ -30,27 +31,12 @@
import org.opensearch.core.rest.RestStatus;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.*;

import static org.opensearch.knn.TestUtils.KNN_VECTOR;
import static org.opensearch.knn.TestUtils.PROPERTIES;
import static org.opensearch.knn.TestUtils.VECTOR_TYPE;
import static org.opensearch.knn.common.KNNConstants.FAISS_NAME;
import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME;
import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;
import static org.opensearch.knn.common.KNNConstants.METHOD_IVF;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE;
import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME;
import static org.opensearch.knn.common.KNNConstants.NAME;
import static org.opensearch.knn.common.KNNConstants.NMSLIB_NAME;
import static org.opensearch.knn.common.KNNConstants.PARAMETERS;
import static org.opensearch.knn.common.KNNConstants.*;

/**
* Integration tests to check the correctness of RestKNNStatsHandler
Expand Down Expand Up @@ -432,6 +418,95 @@ public void testFieldsByEngineModelTraining() throws Exception {
assertTrue(faissField);
}

/**
 * Checks that every radial search flavor raises its own stat by exactly one per search:
 * min score, min score with filter, max distance and max distance with filter. For the
 * filtered variants both the "with filter" stat and the plain request stat must go up by one.
 */
public void testRadialSearchStats_thenSucceed() throws Exception {
    createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2, METHOD_HNSW, LUCENE_NAME));
    Float[] vector = { 6.0f, 6.0f };
    addKnnDoc(INDEX_NAME, "1", FIELD_NAME, vector);

    // First search: radial search by min score
    assertRadialSearchIncrementsStats(
        buildRadialSearchQuery(vector, MIN_SCORE, 0.95f, false),
        1,
        StatNames.MIN_SCORE_QUERY_REQUESTS
    );

    // Second search: radial search by min score with filter
    assertRadialSearchIncrementsStats(
        buildRadialSearchQuery(vector, MIN_SCORE, 0.95f, true),
        1,
        StatNames.MIN_SCORE_QUERY_WITH_FILTER_REQUESTS,
        StatNames.MIN_SCORE_QUERY_REQUESTS
    );

    // Third search: radial search by max distance
    assertRadialSearchIncrementsStats(
        buildRadialSearchQuery(vector, MAX_DISTANCE, 100f, false),
        0,
        StatNames.MAX_DISTANCE_QUERY_REQUESTS
    );

    // Fourth search: radial search by max distance with filter
    assertRadialSearchIncrementsStats(
        buildRadialSearchQuery(vector, MAX_DISTANCE, 100f, true),
        0,
        StatNames.MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS,
        StatNames.MAX_DISTANCE_QUERY_REQUESTS
    );
}

/**
 * Builds the request body {@code {"query": {"knn": {FIELD_NAME: {...}}}}} for a radial search.
 *
 * @param vector         query vector written under the "vector" key
 * @param thresholdField name of the radial threshold field (the callers pass MIN_SCORE or MAX_DISTANCE)
 * @param threshold      value written for {@code thresholdField}
 * @param withFilter     when true, a term filter on {@code _id = "1"} is added under the "filter" key
 * @return the builder holding the complete, closed request body
 * @throws IOException if the content builder cannot be created or written
 */
private XContentBuilder buildRadialSearchQuery(Float[] vector, String thresholdField, float threshold, boolean withFilter)
    throws IOException {
    XContentBuilder query = XContentFactory.jsonBuilder().startObject().startObject("query");
    query.startObject("knn");
    query.startObject(FIELD_NAME);
    query.field("vector", vector);
    query.field(thresholdField, threshold);
    if (withFilter) {
        query.field("filter", QueryBuilders.termQuery("_id", "1"));
    }
    query.endObject();
    query.endObject();
    query.endObject().endObject();
    return query;
}

/**
 * Runs one search and asserts that each of the given stats grew by exactly one across it.
 * The stats are read in the order given before the search and again in the same order after it.
 *
 * @param query         search request body to execute against INDEX_NAME
 * @param resultSize    value forwarded as the third argument of {@code searchKNNIndex}
 * @param expectedStats stats that must each increase by one because of this search
 */
private void assertRadialSearchIncrementsStats(XContentBuilder query, int resultSize, StatNames... expectedStats) throws Exception {
    Integer[] before = new Integer[expectedStats.length];
    for (int i = 0; i < expectedStats.length; i++) {
        before[i] = getStatCount(expectedStats[i].getName());
    }

    searchKNNIndex(INDEX_NAME, query, resultSize);

    Integer[] after = new Integer[expectedStats.length];
    for (int i = 0; i < expectedStats.length; i++) {
        after[i] = getStatCount(expectedStats[i].getName());
    }

    for (int i = 0; i < expectedStats.length; i++) {
        // Name the stat in the failure message so a broken counter can be identified directly.
        assertEquals("Unexpected increment for stat [" + expectedStats[i].getName() + "]", 1, after[i] - before[i]);
    }
}

public void trainKnnModel(String modelId, String trainingIndexName, String trainingFieldName, int dimension, String description)
throws IOException {
XContentBuilder builder = XContentFactory.jsonBuilder()
Expand Down Expand Up @@ -487,4 +562,11 @@ protected Settings restClientSettings() {
return super.restClientSettings();
}
}

/**
 * Fetches the k-NN stats with empty node and stat filter lists and returns the value stored
 * under {@code statName} in the first parsed node entry.
 *
 * @param statName key of the stat to read
 * @return the stat value cast to {@link Integer}
 */
@SneakyThrows
private Integer getStatCount(String statName) {
    final String statsJson = EntityUtils.toString(getKnnStats(Collections.emptyList(), Collections.emptyList()).getEntity());
    // NOTE(review): only the first node entry is consulted; on a multi-node cluster the counter
    // may have been incremented on a different node -- confirm the test cluster topology.
    final Object firstNodeValue = parseNodeStatsResponse(statsJson).get(0).get(statName);
    return (Integer) firstNodeValue;
}
}

0 comments on commit c812547

Please sign in to comment.