Skip to content

Commit

Permalink
Add stats for radial search (#1684)
Browse files Browse the repository at this point in the history
Signed-off-by: Junqiu Lei <[email protected]>
  • Loading branch information
junqiu-lei authored May 14, 2024
1 parent c315862 commit 9a52b2b
Show file tree
Hide file tree
Showing 8 changed files with 230 additions and 24 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
* Add KnnCircuitBreakerException and modify exception message [#1688](https://github.com/opensearch-project/k-NN/pull/1688)
* Add stats for radial search [#1684](https://github.com/opensearch-project/k-NN/pull/1684)
### Bug Fixes
* Block commas in model description [#1692](https://github.com/opensearch-project/k-NN/pull/1692)
### Infrastructure
Expand Down
57 changes: 57 additions & 0 deletions src/main/java/org/opensearch/knn/index/VectorQueryType.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index;

import lombok.Getter;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.plugin.stats.KNNCounter;

/**
 * Types of vector query. Each constant carries the name of the query parameter it is
 * identified by and the two stat counters associated with it: one for all requests of
 * this type and one for requests of this type that also carry a filter.
 */
@Getter
public enum VectorQueryType {
    /** Query identified by {@link KNNConstants#K}. */
    K(KNNConstants.K, KNNCounter.KNN_QUERY_REQUESTS, KNNCounter.KNN_QUERY_WITH_FILTER_REQUESTS),

    /** Query identified by {@link KNNConstants#MIN_SCORE}. */
    MIN_SCORE(KNNConstants.MIN_SCORE, KNNCounter.MIN_SCORE_QUERY_REQUESTS, KNNCounter.MIN_SCORE_QUERY_WITH_FILTER_REQUESTS),

    /** Query identified by {@link KNNConstants#MAX_DISTANCE}. */
    MAX_DISTANCE(
        KNNConstants.MAX_DISTANCE,
        KNNCounter.MAX_DISTANCE_QUERY_REQUESTS,
        KNNCounter.MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS
    );

    private final String queryTypeName;
    private final KNNCounter queryStatCounter;
    private final KNNCounter queryWithFilterStatCounter;

    VectorQueryType(String queryTypeName, KNNCounter queryStatCounter, KNNCounter queryWithFilterStatCounter) {
        this.queryTypeName = queryTypeName;
        this.queryStatCounter = queryStatCounter;
        this.queryWithFilterStatCounter = queryWithFilterStatCounter;
    }

    /**
     * @return the counter associated with every request of this query type
     */
    public KNNCounter getQueryStatCounter() {
        return queryStatCounter;
    }

    /**
     * @return the counter associated with requests of this query type that include a filter
     */
    public KNNCounter getQueryWithFilterStatCounter() {
        return queryWithFilterStatCounter;
    }
}
18 changes: 13 additions & 5 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.VectorQueryType;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelUtil;
import org.opensearch.knn.plugin.stats.KNNCounter;
import org.apache.lucene.search.Query;
import org.opensearch.core.ParseField;
import org.opensearch.core.common.ParsingException;
Expand Down Expand Up @@ -242,7 +242,6 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep
String currentFieldName = null;
boolean ignoreUnmapped = false;
XContentParser.Token token;
KNNCounter.KNN_QUERY_REQUESTS.increment();
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
if (token == XContentParser.Token.FIELD_NAME) {
currentFieldName = parser.currentName();
Expand Down Expand Up @@ -279,7 +278,6 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep
String tokenName = parser.currentName();
if (FILTER_FIELD.getPreferredName().equals(tokenName)) {
log.debug(String.format("Start parsing filter for field [%s]", fieldName));
KNNCounter.KNN_QUERY_WITH_FILTER_REQUESTS.increment();
filter = parseInnerQueryBuilder(parser);
} else {
throw new ParsingException(parser.getTokenLocation(), "[" + NAME + "] unknown token [" + token + "]");
Expand All @@ -298,7 +296,11 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep
}
}

validateSingleQueryType(k, maxDistance, minScore);
VectorQueryType vectorQueryType = validateSingleQueryType(k, maxDistance, minScore);
vectorQueryType.getQueryStatCounter().increment();
if (filter != null) {
vectorQueryType.getQueryWithFilterStatCounter().increment();
}

KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName, ObjectsToFloats(vector)).filter(filter)
.ignoreUnmapped(ignoreUnmapped)
Expand Down Expand Up @@ -549,21 +551,27 @@ public String getWriteableName() {
return NAME;
}

private static void validateSingleQueryType(Integer k, Float distance, Float score) {
/**
 * Checks that exactly one of the three query parameters is set and reports which one.
 * {@code k} counts as set only when it is non-null and non-zero; {@code distance} and
 * {@code score} count as set when they are non-null.
 *
 * @param k        requested neighbor count, or null/0 when not provided
 * @param distance max distance threshold, or null when not provided
 * @param score    min score threshold, or null when not provided
 * @return the query type matching the single parameter that is set
 * @throws IllegalArgumentException if none, or more than one, of the parameters is set
 */
private static VectorQueryType validateSingleQueryType(Integer k, Float distance, Float score) {
    final boolean hasK = k != null && k != 0;
    final boolean hasDistance = distance != null;
    final boolean hasScore = score != null;

    final int provided = (hasK ? 1 : 0) + (hasDistance ? 1 : 0) + (hasScore ? 1 : 0);
    if (provided != 1) {
        throw new IllegalArgumentException(String.format("[%s] requires exactly one of k, distance or score to be set", NAME));
    }

    // Exactly one flag is true at this point, so the order of these checks is irrelevant.
    if (hasScore) {
        return VectorQueryType.MIN_SCORE;
    }
    if (hasDistance) {
        return VectorQueryType.MAX_DISTANCE;
    }
    return VectorQueryType.K;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@ public enum KNNCounter {
SCRIPT_QUERY_ERRORS("script_query_errors"),
TRAINING_REQUESTS("training_requests"),
TRAINING_ERRORS("training_errors"),
KNN_QUERY_WITH_FILTER_REQUESTS("knn_query_with_filter_requests");
KNN_QUERY_WITH_FILTER_REQUESTS("knn_query_with_filter_requests"),
MIN_SCORE_QUERY_REQUESTS("min_score_query_requests"),
MIN_SCORE_QUERY_WITH_FILTER_REQUESTS("min_score_query_with_filter_requests"),
MAX_DISTANCE_QUERY_REQUESTS("max_distance_query_requests"),
MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS("max_distance_query_with_filter_requests");

private String name;
private AtomicLong count;
Expand Down
20 changes: 20 additions & 0 deletions src/main/java/org/opensearch/knn/plugin/stats/KNNStats.java
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,32 @@ private Map<String, KNNStat<?>> buildStatsMap() {
}

/**
 * Adds the query request counters (k-NN, min score and max distance, each with and
 * without filter) to the stats map under construction.
 *
 * @param builder map builder that receives one entry per query counter
 */
private void addQueryStats(ImmutableMap.Builder<String, KNNStat<?>> builder) {
    // KNN Query Stats
    putCounterStat(builder, StatNames.KNN_QUERY_REQUESTS, KNNCounter.KNN_QUERY_REQUESTS);
    putCounterStat(builder, StatNames.KNN_QUERY_WITH_FILTER_REQUESTS, KNNCounter.KNN_QUERY_WITH_FILTER_REQUESTS);

    // Min Score Query Stats
    putCounterStat(builder, StatNames.MIN_SCORE_QUERY_REQUESTS, KNNCounter.MIN_SCORE_QUERY_REQUESTS);
    putCounterStat(builder, StatNames.MIN_SCORE_QUERY_WITH_FILTER_REQUESTS, KNNCounter.MIN_SCORE_QUERY_WITH_FILTER_REQUESTS);

    // Max Distance Query Stats
    putCounterStat(builder, StatNames.MAX_DISTANCE_QUERY_REQUESTS, KNNCounter.MAX_DISTANCE_QUERY_REQUESTS);
    putCounterStat(builder, StatNames.MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS, KNNCounter.MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS);
}

/**
 * Puts a single counter-backed stat into the builder, keyed by the stat's name.
 * The stat is created with {@code false} as its first constructor argument
 * (NOTE(review): presumably marks the stat as node-level rather than cluster-level — confirm).
 */
private static void putCounterStat(ImmutableMap.Builder<String, KNNStat<?>> builder, StatNames statName, KNNCounter counter) {
    builder.put(statName.getName(), new KNNStat<>(false, new KNNCounterSupplier(counter)));
}

private void addNativeMemoryStats(ImmutableMap.Builder<String, KNNStat<?>> builder) {
Expand Down
6 changes: 5 additions & 1 deletion src/main/java/org/opensearch/knn/plugin/stats/StatNames.java
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@ public enum StatNames {
KNN_QUERY_WITH_FILTER_REQUESTS(KNNCounter.KNN_QUERY_WITH_FILTER_REQUESTS.getName()),
GRAPH_STATS("graph_stats"),
REFRESH("refresh"),
MERGE("merge");
MERGE("merge"),
MIN_SCORE_QUERY_REQUESTS(KNNCounter.MIN_SCORE_QUERY_REQUESTS.getName()),
MIN_SCORE_QUERY_WITH_FILTER_REQUESTS(KNNCounter.MIN_SCORE_QUERY_WITH_FILTER_REQUESTS.getName()),
MAX_DISTANCE_QUERY_REQUESTS(KNNCounter.MAX_DISTANCE_QUERY_REQUESTS.getName()),
MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS(KNNCounter.MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS.getName());

private String name;

Expand Down
30 changes: 30 additions & 0 deletions src/test/java/org/opensearch/knn/index/VectorQueryTypeTests.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.index;

import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.plugin.stats.KNNCounter;

/**
 * Unit tests for the counter lookups exposed by {@link VectorQueryType}.
 */
public class VectorQueryTypeTests extends KNNTestCase {

    /** Every query type must report its own plain request counter. */
    public void testGetQueryStatCounter() {
        final KNNCounter counterForK = VectorQueryType.K.getQueryStatCounter();
        final KNNCounter counterForMinScore = VectorQueryType.MIN_SCORE.getQueryStatCounter();
        final KNNCounter counterForMaxDistance = VectorQueryType.MAX_DISTANCE.getQueryStatCounter();

        assertEquals(KNNCounter.KNN_QUERY_REQUESTS, counterForK);
        assertEquals(KNNCounter.MIN_SCORE_QUERY_REQUESTS, counterForMinScore);
        assertEquals(KNNCounter.MAX_DISTANCE_QUERY_REQUESTS, counterForMaxDistance);
    }

    /** Every query type must report its own with-filter request counter. */
    public void testGetQueryWithFilterStatCounter() {
        final KNNCounter filteredCounterForK = VectorQueryType.K.getQueryWithFilterStatCounter();
        final KNNCounter filteredCounterForMinScore = VectorQueryType.MIN_SCORE.getQueryWithFilterStatCounter();
        final KNNCounter filteredCounterForMaxDistance = VectorQueryType.MAX_DISTANCE.getQueryWithFilterStatCounter();

        assertEquals(KNNCounter.KNN_QUERY_WITH_FILTER_REQUESTS, filteredCounterForK);
        assertEquals(KNNCounter.MIN_SCORE_QUERY_WITH_FILTER_REQUESTS, filteredCounterForMinScore);
        assertEquals(KNNCounter.MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS, filteredCounterForMaxDistance);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.knn.plugin.action;

import lombok.SneakyThrows;
import org.apache.hc.core5.http.io.entity.EntityUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
Expand All @@ -30,27 +31,12 @@
import org.opensearch.core.rest.RestStatus;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.*;

import static org.opensearch.knn.TestUtils.KNN_VECTOR;
import static org.opensearch.knn.TestUtils.PROPERTIES;
import static org.opensearch.knn.TestUtils.VECTOR_TYPE;
import static org.opensearch.knn.common.KNNConstants.FAISS_NAME;
import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME;
import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;
import static org.opensearch.knn.common.KNNConstants.METHOD_IVF;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE;
import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME;
import static org.opensearch.knn.common.KNNConstants.NAME;
import static org.opensearch.knn.common.KNNConstants.NMSLIB_NAME;
import static org.opensearch.knn.common.KNNConstants.PARAMETERS;
import static org.opensearch.knn.common.KNNConstants.*;

/**
* Integration tests to check the correctness of RestKNNStatsHandler
Expand Down Expand Up @@ -432,6 +418,95 @@ public void testFieldsByEngineModelTraining() throws Exception {
assertTrue(faissField);
}

/**
 * Runs four radial searches (min score, min score with filter, max distance, max distance
 * with filter) against a single-document index and checks that each search raises the
 * matching stat by exactly one. For the filtered searches both the with-filter stat and
 * the plain stat of that query type are expected to increase by one.
 */
public void testRadialSearchStats_thenSucceed() throws Exception {
    createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2, METHOD_HNSW, LUCENE_NAME));
    Float[] vector = { 6.0f, 6.0f };
    addKnnDoc(INDEX_NAME, "1", FIELD_NAME, vector);

    final String minScoreStat = StatNames.MIN_SCORE_QUERY_REQUESTS.getName();
    final String minScoreFilteredStat = StatNames.MIN_SCORE_QUERY_WITH_FILTER_REQUESTS.getName();
    final String maxDistanceStat = StatNames.MAX_DISTANCE_QUERY_REQUESTS.getName();
    final String maxDistanceFilteredStat = StatNames.MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS.getName();

    // First search: radial search by min score
    XContentBuilder minScoreQuery = buildRadialSearchQuery(vector, MIN_SCORE, 0.95f, false);

    Integer minScoreBefore = getStatCount(minScoreStat);
    searchKNNIndex(INDEX_NAME, minScoreQuery, 1);
    Integer minScoreAfter = getStatCount(minScoreStat);

    assertEquals(1, minScoreAfter - minScoreBefore);

    // Second search: radial search by min score with filter
    XContentBuilder minScoreFilteredQuery = buildRadialSearchQuery(vector, MIN_SCORE, 0.95f, true);

    Integer minScoreFilteredBefore = getStatCount(minScoreFilteredStat);
    Integer minScorePlainBefore = getStatCount(minScoreStat);
    searchKNNIndex(INDEX_NAME, minScoreFilteredQuery, 1);
    Integer minScoreFilteredAfter = getStatCount(minScoreFilteredStat);
    Integer minScorePlainAfter = getStatCount(minScoreStat);

    assertEquals(1, minScoreFilteredAfter - minScoreFilteredBefore);
    assertEquals(1, minScorePlainAfter - minScorePlainBefore);

    // Third search: radial search by max distance
    XContentBuilder maxDistanceQuery = buildRadialSearchQuery(vector, MAX_DISTANCE, 100f, false);

    Integer maxDistanceBefore = getStatCount(maxDistanceStat);
    searchKNNIndex(INDEX_NAME, maxDistanceQuery, 0);
    Integer maxDistanceAfter = getStatCount(maxDistanceStat);

    assertEquals(1, maxDistanceAfter - maxDistanceBefore);

    // Fourth search: radial search by max distance with filter
    XContentBuilder maxDistanceFilteredQuery = buildRadialSearchQuery(vector, MAX_DISTANCE, 100f, true);

    Integer maxDistanceFilteredBefore = getStatCount(maxDistanceFilteredStat);
    Integer maxDistancePlainBefore = getStatCount(maxDistanceStat);
    searchKNNIndex(INDEX_NAME, maxDistanceFilteredQuery, 0);
    Integer maxDistanceFilteredAfter = getStatCount(maxDistanceFilteredStat);
    Integer maxDistancePlainAfter = getStatCount(maxDistanceStat);

    assertEquals(1, maxDistanceFilteredAfter - maxDistanceFilteredBefore);
    assertEquals(1, maxDistancePlainAfter - maxDistancePlainBefore);
}

/**
 * Builds the request body {@code {"query": {"knn": {FIELD_NAME: {"vector": ..., <thresholdField>: ...}}}}},
 * optionally adding a {@code "filter"} term query on {@code _id} = {@code "1"}.
 *
 * @param vector         query vector written under the "vector" key
 * @param thresholdField name of the radial threshold field to set
 * @param threshold      value written under {@code thresholdField}
 * @param withFilter     whether to append the "filter" entry after the threshold
 * @return the fully closed request body builder
 */
private XContentBuilder buildRadialSearchQuery(Float[] vector, String thresholdField, float threshold, boolean withFilter)
    throws IOException {
    XContentBuilder query = XContentFactory.jsonBuilder().startObject().startObject("query");
    query.startObject("knn").startObject(FIELD_NAME);
    query.field("vector", vector);
    query.field(thresholdField, threshold);
    if (withFilter) {
        query.field("filter", QueryBuilders.termQuery("_id", "1"));
    }
    // Close, in order: the field object, "knn", "query", and the root object.
    query.endObject().endObject();
    query.endObject().endObject();
    return query;
}

public void trainKnnModel(String modelId, String trainingIndexName, String trainingFieldName, int dimension, String description)
throws IOException {
XContentBuilder builder = XContentFactory.jsonBuilder()
Expand Down Expand Up @@ -487,4 +562,11 @@ protected Settings restClientSettings() {
return super.restClientSettings();
}
}

/**
 * Requests the k-NN stats (passing empty lists for both arguments of {@code getKnnStats})
 * and reads one stat from the first entry of the parsed response.
 *
 * @param statName key of the stat to read
 * @return the stat value cast to {@code Integer}; null if the key is absent from that entry
 */
@SneakyThrows
private Integer getStatCount(String statName) {
    final Response statsResponse = getKnnStats(Collections.emptyList(), Collections.emptyList());
    final String statsJson = EntityUtils.toString(statsResponse.getEntity());
    final Object rawValue = parseNodeStatsResponse(statsJson).get(0).get(statName);
    return (Integer) rawValue;
}
}

0 comments on commit 9a52b2b

Please sign in to comment.