diff --git a/CHANGELOG.md b/CHANGELOG.md index 708ce86ffa..6795eae21e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.14...2.x) ### Features ### Enhancements +* Add stats for radial search [#1684](https://github.com/opensearch-project/k-NN/pull/1684) ### Bug Fixes ### Infrastructure ### Documentation diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index c7bd7eee25..4028d78ed3 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -242,7 +242,6 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep String currentFieldName = null; boolean ignoreUnmapped = false; XContentParser.Token token; - KNNCounter.KNN_QUERY_REQUESTS.increment(); while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { if (token == XContentParser.Token.FIELD_NAME) { currentFieldName = parser.currentName(); @@ -279,7 +278,6 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep String tokenName = parser.currentName(); if (FILTER_FIELD.getPreferredName().equals(tokenName)) { log.debug(String.format("Start parsing filter for field [%s]", fieldName)); - KNNCounter.KNN_QUERY_WITH_FILTER_REQUESTS.increment(); filter = parseInnerQueryBuilder(parser); } else { throw new ParsingException(parser.getTokenLocation(), "[" + NAME + "] unknown token [" + token + "]"); @@ -300,6 +298,8 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep validateSingleQueryType(k, maxDistance, minScore); + updateQueryStats(k, minScore, maxDistance, filter); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName, ObjectsToFloats(vector)).filter(filter) .ignoreUnmapped(ignoreUnmapped) .boost(boost) @@ -566,4 +566,23 @@ private static void validateSingleQueryType(Integer k, Float distance, Float sco throw new IllegalArgumentException(String.format("[%s] requires exactly one of k, distance or score to be set", NAME)); } } + + private static void updateQueryStats(Integer k, Float minScore, Float maxDistance, QueryBuilder filter) { + if (k != null) { + KNNCounter.KNN_QUERY_REQUESTS.increment(); + if (filter != null) { + KNNCounter.KNN_QUERY_WITH_FILTER_REQUESTS.increment(); + } + } else if (minScore != null) { + KNNCounter.MIN_SCORE_QUERY_REQUESTS.increment(); + if (filter != null) { + KNNCounter.MIN_SCORE_QUERY_WITH_FILTER_REQUESTS.increment(); + } + } else if (maxDistance != null) { + KNNCounter.MAX_DISTANCE_QUERY_REQUESTS.increment(); + if (filter != null) { + KNNCounter.MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS.increment(); + } + } + } } diff --git a/src/main/java/org/opensearch/knn/plugin/stats/KNNCounter.java b/src/main/java/org/opensearch/knn/plugin/stats/KNNCounter.java index ce04c90780..3bcc3399c8 100644 --- a/src/main/java/org/opensearch/knn/plugin/stats/KNNCounter.java +++ b/src/main/java/org/opensearch/knn/plugin/stats/KNNCounter.java @@ -22,7 +22,11 @@ public enum KNNCounter { SCRIPT_QUERY_ERRORS("script_query_errors"), TRAINING_REQUESTS("training_requests"), TRAINING_ERRORS("training_errors"), - KNN_QUERY_WITH_FILTER_REQUESTS("knn_query_with_filter_requests"); + KNN_QUERY_WITH_FILTER_REQUESTS("knn_query_with_filter_requests"), + MIN_SCORE_QUERY_REQUESTS("min_score_query_requests"), + MIN_SCORE_QUERY_WITH_FILTER_REQUESTS("min_score_query_with_filter_requests"), + MAX_DISTANCE_QUERY_REQUESTS("max_distance_query_requests"), + MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS("max_distance_query_with_filter_requests"); private String name; private AtomicLong count; diff --git a/src/main/java/org/opensearch/knn/plugin/stats/KNNStats.java b/src/main/java/org/opensearch/knn/plugin/stats/KNNStats.java index 07d1296527..2e458ca62b 100644 --- a/src/main/java/org/opensearch/knn/plugin/stats/KNNStats.java +++ b/src/main/java/org/opensearch/knn/plugin/stats/KNNStats.java @@ -80,16 +80,18 @@ private Map> getClusterOrNodeStats(Boolean getClusterStats) { private Map> buildStatsMap() { ImmutableMap.Builder> builder = ImmutableMap.>builder(); - addQueryStats(builder); + addKNNQueryStats(builder); addNativeMemoryStats(builder); addEngineStats(builder); addScriptStats(builder); addModelStats(builder); addGraphStats(builder); + addMinScoreQueryStats(builder); + addMaxDistanceQueryStats(builder); return builder.build(); } - private void addQueryStats(ImmutableMap.Builder> builder) { + private void addKNNQueryStats(ImmutableMap.Builder> builder) { builder.put(StatNames.KNN_QUERY_REQUESTS.getName(), new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.KNN_QUERY_REQUESTS))) .put( StatNames.KNN_QUERY_WITH_FILTER_REQUESTS.getName(), @@ -98,6 +100,28 @@ private void addQueryStats(ImmutableMap.Builder> builder) { } + private void addMinScoreQueryStats(ImmutableMap.Builder> builder) { + builder.put( + StatNames.MIN_SCORE_QUERY_REQUESTS.getName(), + new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.MIN_SCORE_QUERY_REQUESTS)) + ) + .put( + StatNames.MIN_SCORE_QUERY_WITH_FILTER_REQUESTS.getName(), + new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.MIN_SCORE_QUERY_WITH_FILTER_REQUESTS)) + ); + } + + private void addMaxDistanceQueryStats(ImmutableMap.Builder> builder) { + builder.put( + StatNames.MAX_DISTANCE_QUERY_REQUESTS.getName(), + new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.MAX_DISTANCE_QUERY_REQUESTS)) + ) + .put( + StatNames.MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS.getName(), + new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS)) + ); + } + private void addNativeMemoryStats(ImmutableMap.Builder> builder) { builder.put(StatNames.HIT_COUNT.getName(), new KNNStat<>(false, new KNNInnerCacheStatsSupplier(CacheStats::hitCount))) .put(StatNames.MISS_COUNT.getName(), new KNNStat<>(false, new KNNInnerCacheStatsSupplier(CacheStats::missCount))) diff --git a/src/main/java/org/opensearch/knn/plugin/stats/StatNames.java b/src/main/java/org/opensearch/knn/plugin/stats/StatNames.java index e9ed2b126e..e7f4fd4a2d 100644 --- a/src/main/java/org/opensearch/knn/plugin/stats/StatNames.java +++ b/src/main/java/org/opensearch/knn/plugin/stats/StatNames.java @@ -44,7 +44,11 @@ public enum StatNames { KNN_QUERY_WITH_FILTER_REQUESTS(KNNCounter.KNN_QUERY_WITH_FILTER_REQUESTS.getName()), GRAPH_STATS("graph_stats"), REFRESH("refresh"), - MERGE("merge"); + MERGE("merge"), + MIN_SCORE_QUERY_REQUESTS(KNNCounter.MIN_SCORE_QUERY_REQUESTS.getName()), + MIN_SCORE_QUERY_WITH_FILTER_REQUESTS(KNNCounter.MIN_SCORE_QUERY_WITH_FILTER_REQUESTS.getName()), + MAX_DISTANCE_QUERY_REQUESTS(KNNCounter.MAX_DISTANCE_QUERY_REQUESTS.getName()), + MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS(KNNCounter.MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS.getName()); private String name; diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java index 2e13c70490..d9949aaf24 100644 --- a/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java +++ b/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java @@ -5,6 +5,7 @@ package org.opensearch.knn.plugin.action; +import lombok.SneakyThrows; import org.apache.hc.core5.http.io.entity.EntityUtils; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -30,27 +31,12 @@ import org.opensearch.core.rest.RestStatus; import java.io.IOException; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import java.util.*; import static org.opensearch.knn.TestUtils.KNN_VECTOR; import static org.opensearch.knn.TestUtils.PROPERTIES; import static org.opensearch.knn.TestUtils.VECTOR_TYPE; -import static org.opensearch.knn.common.KNNConstants.FAISS_NAME; -import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; -import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME; -import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; -import static org.opensearch.knn.common.KNNConstants.METHOD_IVF; -import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST; -import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; -import static org.opensearch.knn.common.KNNConstants.MODEL_ID; -import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME; -import static org.opensearch.knn.common.KNNConstants.NAME; -import static org.opensearch.knn.common.KNNConstants.NMSLIB_NAME; -import static org.opensearch.knn.common.KNNConstants.PARAMETERS; +import static org.opensearch.knn.common.KNNConstants.*; /** * Integration tests to check the correctness of RestKNNStatsHandler @@ -432,6 +418,95 @@ public void testFieldsByEngineModelTraining() throws Exception { assertTrue(faissField); } + public void testRadialSearchStats_thenSucceed() throws Exception { + createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2, METHOD_HNSW, LUCENE_NAME)); + Float[] vector = { 6.0f, 6.0f }; + addKnnDoc(INDEX_NAME, "1", FIELD_NAME, vector); + + // First search: radial search by min score + XContentBuilder queryBuilderMinScore = XContentFactory.jsonBuilder().startObject().startObject("query"); + queryBuilderMinScore.startObject("knn"); + queryBuilderMinScore.startObject(FIELD_NAME); + queryBuilderMinScore.field("vector", vector); + queryBuilderMinScore.field(MIN_SCORE, 0.95f); + queryBuilderMinScore.endObject(); + queryBuilderMinScore.endObject(); + queryBuilderMinScore.endObject().endObject(); + + Integer minScoreStatBeforeMinScoreSearch = getStatCount(StatNames.MIN_SCORE_QUERY_REQUESTS.getName()); + searchKNNIndex(INDEX_NAME, queryBuilderMinScore, 1); + Integer minScoreStatAfterMinScoreSearch = getStatCount(StatNames.MIN_SCORE_QUERY_REQUESTS.getName()); + + assertEquals(1, minScoreStatAfterMinScoreSearch - minScoreStatBeforeMinScoreSearch); + + // Second search: radial search by min score with filter + XContentBuilder queryBuilderMinScoreWithFilter = XContentFactory.jsonBuilder().startObject().startObject("query"); + queryBuilderMinScoreWithFilter.startObject("knn"); + queryBuilderMinScoreWithFilter.startObject(FIELD_NAME); + queryBuilderMinScoreWithFilter.field("vector", vector); + queryBuilderMinScoreWithFilter.field(MIN_SCORE, 0.95f); + queryBuilderMinScoreWithFilter.field("filter", QueryBuilders.termQuery("_id", "1")); + queryBuilderMinScoreWithFilter.endObject(); + queryBuilderMinScoreWithFilter.endObject(); + queryBuilderMinScoreWithFilter.endObject().endObject(); + + Integer minScoreWithFilterStatBeforeMinScoreWithFilterSearch = getStatCount( + StatNames.MIN_SCORE_QUERY_WITH_FILTER_REQUESTS.getName() + ); + Integer minScoreStatBeforeMinScoreWithFilterSearch = getStatCount(StatNames.MIN_SCORE_QUERY_REQUESTS.getName()); + searchKNNIndex(INDEX_NAME, queryBuilderMinScoreWithFilter, 1); + Integer minScoreWithFilterStatAfterMinScoreWithFilterSearch = getStatCount( + StatNames.MIN_SCORE_QUERY_WITH_FILTER_REQUESTS.getName() + ); + Integer minScoreStatAfterMinScoreWithFilterSearch = getStatCount(StatNames.MIN_SCORE_QUERY_REQUESTS.getName()); + + assertEquals(1, minScoreWithFilterStatAfterMinScoreWithFilterSearch - minScoreWithFilterStatBeforeMinScoreWithFilterSearch); + assertEquals(1, minScoreStatAfterMinScoreWithFilterSearch - minScoreStatBeforeMinScoreWithFilterSearch); + + // Third search: radial search by max distance + XContentBuilder queryBuilderMaxDistance = XContentFactory.jsonBuilder().startObject().startObject("query"); + queryBuilderMaxDistance.startObject("knn"); + queryBuilderMaxDistance.startObject(FIELD_NAME); + queryBuilderMaxDistance.field("vector", vector); + queryBuilderMaxDistance.field(MAX_DISTANCE, 100f); + queryBuilderMaxDistance.endObject(); + queryBuilderMaxDistance.endObject(); + queryBuilderMaxDistance.endObject().endObject(); + + Integer maxDistanceStatBeforeMaxDistanceSearch = getStatCount(StatNames.MAX_DISTANCE_QUERY_REQUESTS.getName()); + searchKNNIndex(INDEX_NAME, queryBuilderMaxDistance, 0); + Integer maxDistanceStatAfterMaxDistanceSearch = getStatCount(StatNames.MAX_DISTANCE_QUERY_REQUESTS.getName()); + + assertEquals(1, maxDistanceStatAfterMaxDistanceSearch - maxDistanceStatBeforeMaxDistanceSearch); + + // Fourth search: radial search by max distance with filter + XContentBuilder queryBuilderMaxDistanceWithFilter = XContentFactory.jsonBuilder().startObject().startObject("query"); + queryBuilderMaxDistanceWithFilter.startObject("knn"); + queryBuilderMaxDistanceWithFilter.startObject(FIELD_NAME); + queryBuilderMaxDistanceWithFilter.field("vector", vector); + queryBuilderMaxDistanceWithFilter.field(MAX_DISTANCE, 100f); + queryBuilderMaxDistanceWithFilter.field("filter", QueryBuilders.termQuery("_id", "1")); + queryBuilderMaxDistanceWithFilter.endObject(); + queryBuilderMaxDistanceWithFilter.endObject(); + queryBuilderMaxDistanceWithFilter.endObject().endObject(); + + Integer maxDistanceWithFilterStatBeforeMaxDistanceWithFilterSearch = getStatCount( + StatNames.MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS.getName() + ); + Integer maxDistanceStatBeforeMaxDistanceWithFilterSearch = getStatCount(StatNames.MAX_DISTANCE_QUERY_REQUESTS.getName()); + searchKNNIndex(INDEX_NAME, queryBuilderMaxDistanceWithFilter, 0); + Integer maxDistanceWithFilterStatAfterMaxDistanceWithFilterSearch = getStatCount( + StatNames.MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS.getName() + ); + Integer maxDistanceStatAfterMaxDistanceWithFilterSearch = getStatCount(StatNames.MAX_DISTANCE_QUERY_REQUESTS.getName()); + + assertEquals( + 1, + maxDistanceWithFilterStatAfterMaxDistanceWithFilterSearch - maxDistanceWithFilterStatBeforeMaxDistanceWithFilterSearch + ); + assertEquals(1, maxDistanceStatAfterMaxDistanceWithFilterSearch - maxDistanceStatBeforeMaxDistanceWithFilterSearch); + } + public void trainKnnModel(String modelId, String trainingIndexName, String trainingFieldName, int dimension, String description) throws IOException { XContentBuilder builder = XContentFactory.jsonBuilder() @@ -487,4 +562,11 @@ protected Settings restClientSettings() { return super.restClientSettings(); } } + + @SneakyThrows + private Integer getStatCount(String statName) { + Response response = getKnnStats(Collections.emptyList(), Collections.emptyList()); + String responseBody = EntityUtils.toString(response.getEntity()); + return (Integer) parseNodeStatsResponse(responseBody).get(0).get(statName); + } }