Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add stats for radial search #1684

Merged
merged 5 commits into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
* Add KnnCircuitBreakerException and modify exception message [#1688](https://github.com/opensearch-project/k-NN/pull/1688)
* Add stats for radial search [#1684](https://github.com/opensearch-project/k-NN/pull/1684)
### Bug Fixes
* Block commas in model description [#1692](https://github.com/opensearch-project/k-NN/pull/1692)
### Infrastructure
Expand Down
57 changes: 57 additions & 0 deletions src/main/java/org/opensearch/knn/index/VectorQueryType.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index;

import lombok.Getter;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.plugin.stats.KNNCounter;

@Getter
public enum VectorQueryType {
K(KNNConstants.K) {
@Override
public KNNCounter getQueryStatCounter() {
return KNNCounter.KNN_QUERY_REQUESTS;
}

@Override
public KNNCounter getQueryWithFilterStatCounter() {
return KNNCounter.KNN_QUERY_WITH_FILTER_REQUESTS;
}
},
MIN_SCORE(KNNConstants.MIN_SCORE) {
@Override
public KNNCounter getQueryStatCounter() {
return KNNCounter.MIN_SCORE_QUERY_REQUESTS;
}

@Override
public KNNCounter getQueryWithFilterStatCounter() {
return KNNCounter.MIN_SCORE_QUERY_WITH_FILTER_REQUESTS;
}
},
MAX_DISTANCE(KNNConstants.MAX_DISTANCE) {
@Override
public KNNCounter getQueryStatCounter() {
return KNNCounter.MAX_DISTANCE_QUERY_REQUESTS;
}

@Override
public KNNCounter getQueryWithFilterStatCounter() {
return KNNCounter.MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS;
}
};

private final String queryTypeName;

VectorQueryType(String queryTypeName) {
this.queryTypeName = queryTypeName;
}

public abstract KNNCounter getQueryStatCounter();

public abstract KNNCounter getQueryWithFilterStatCounter();
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.VectorQueryType;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelUtil;
import org.opensearch.knn.plugin.stats.KNNCounter;
import org.apache.lucene.search.Query;
import org.opensearch.core.ParseField;
import org.opensearch.core.common.ParsingException;
Expand Down Expand Up @@ -242,7 +242,6 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep
String currentFieldName = null;
boolean ignoreUnmapped = false;
XContentParser.Token token;
KNNCounter.KNN_QUERY_REQUESTS.increment();
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
if (token == XContentParser.Token.FIELD_NAME) {
currentFieldName = parser.currentName();
Expand Down Expand Up @@ -279,7 +278,6 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep
String tokenName = parser.currentName();
if (FILTER_FIELD.getPreferredName().equals(tokenName)) {
log.debug(String.format("Start parsing filter for field [%s]", fieldName));
KNNCounter.KNN_QUERY_WITH_FILTER_REQUESTS.increment();
filter = parseInnerQueryBuilder(parser);
} else {
throw new ParsingException(parser.getTokenLocation(), "[" + NAME + "] unknown token [" + token + "]");
Expand All @@ -298,7 +296,11 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep
}
}

validateSingleQueryType(k, maxDistance, minScore);
VectorQueryType vectorQueryType = validateSingleQueryType(k, maxDistance, minScore);
junqiu-lei marked this conversation as resolved.
Show resolved Hide resolved
vectorQueryType.getQueryStatCounter().increment();
if (filter != null) {
vectorQueryType.getQueryWithFilterStatCounter().increment();
}

KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName, ObjectsToFloats(vector)).filter(filter)
.ignoreUnmapped(ignoreUnmapped)
Expand Down Expand Up @@ -549,21 +551,27 @@ public String getWriteableName() {
return NAME;
}

private static void validateSingleQueryType(Integer k, Float distance, Float score) {
private static VectorQueryType validateSingleQueryType(Integer k, Float distance, Float score) {
int countSetFields = 0;
VectorQueryType vectorQueryType = null;

if (k != null && k != 0) {
countSetFields++;
vectorQueryType = VectorQueryType.K;
}
if (distance != null) {
countSetFields++;
vectorQueryType = VectorQueryType.MAX_DISTANCE;
}
if (score != null) {
countSetFields++;
vectorQueryType = VectorQueryType.MIN_SCORE;
}

if (countSetFields != 1) {
throw new IllegalArgumentException(String.format("[%s] requires exactly one of k, distance or score to be set", NAME));
}

return vectorQueryType;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@ public enum KNNCounter {
SCRIPT_QUERY_ERRORS("script_query_errors"),
TRAINING_REQUESTS("training_requests"),
TRAINING_ERRORS("training_errors"),
KNN_QUERY_WITH_FILTER_REQUESTS("knn_query_with_filter_requests");
KNN_QUERY_WITH_FILTER_REQUESTS("knn_query_with_filter_requests"),
MIN_SCORE_QUERY_REQUESTS("min_score_query_requests"),
MIN_SCORE_QUERY_WITH_FILTER_REQUESTS("min_score_query_with_filter_requests"),
MAX_DISTANCE_QUERY_REQUESTS("max_distance_query_requests"),
MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS("max_distance_query_with_filter_requests");

private String name;
private AtomicLong count;
Expand Down
20 changes: 20 additions & 0 deletions src/main/java/org/opensearch/knn/plugin/stats/KNNStats.java
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,32 @@ private Map<String, KNNStat<?>> buildStatsMap() {
}

private void addQueryStats(ImmutableMap.Builder<String, KNNStat<?>> builder) {
// KNN Query Stats
builder.put(StatNames.KNN_QUERY_REQUESTS.getName(), new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.KNN_QUERY_REQUESTS)))
.put(
StatNames.KNN_QUERY_WITH_FILTER_REQUESTS.getName(),
new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.KNN_QUERY_WITH_FILTER_REQUESTS))
);

// Min Score Query Stats
builder.put(
StatNames.MIN_SCORE_QUERY_REQUESTS.getName(),
new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.MIN_SCORE_QUERY_REQUESTS))
)
.put(
StatNames.MIN_SCORE_QUERY_WITH_FILTER_REQUESTS.getName(),
new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.MIN_SCORE_QUERY_WITH_FILTER_REQUESTS))
);

// Max Distance Query Stats
builder.put(
StatNames.MAX_DISTANCE_QUERY_REQUESTS.getName(),
new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.MAX_DISTANCE_QUERY_REQUESTS))
)
.put(
StatNames.MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS.getName(),
new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS))
);
}

private void addNativeMemoryStats(ImmutableMap.Builder<String, KNNStat<?>> builder) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@ public enum StatNames {
KNN_QUERY_WITH_FILTER_REQUESTS(KNNCounter.KNN_QUERY_WITH_FILTER_REQUESTS.getName()),
GRAPH_STATS("graph_stats"),
REFRESH("refresh"),
MERGE("merge");
MERGE("merge"),
MIN_SCORE_QUERY_REQUESTS(KNNCounter.MIN_SCORE_QUERY_REQUESTS.getName()),
MIN_SCORE_QUERY_WITH_FILTER_REQUESTS(KNNCounter.MIN_SCORE_QUERY_WITH_FILTER_REQUESTS.getName()),
MAX_DISTANCE_QUERY_REQUESTS(KNNCounter.MAX_DISTANCE_QUERY_REQUESTS.getName()),
MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS(KNNCounter.MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS.getName());

private String name;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.index;

import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.plugin.stats.KNNCounter;

public class VectorQueryTypeTests extends KNNTestCase {

public void testGetQueryStatCounter() {
assertEquals(KNNCounter.KNN_QUERY_REQUESTS, VectorQueryType.K.getQueryStatCounter());
assertEquals(KNNCounter.MIN_SCORE_QUERY_REQUESTS, VectorQueryType.MIN_SCORE.getQueryStatCounter());
assertEquals(KNNCounter.MAX_DISTANCE_QUERY_REQUESTS, VectorQueryType.MAX_DISTANCE.getQueryStatCounter());
}

public void testGetQueryWithFilterStatCounter() {
assertEquals(KNNCounter.KNN_QUERY_WITH_FILTER_REQUESTS, VectorQueryType.K.getQueryWithFilterStatCounter());
assertEquals(KNNCounter.MIN_SCORE_QUERY_WITH_FILTER_REQUESTS, VectorQueryType.MIN_SCORE.getQueryWithFilterStatCounter());
assertEquals(KNNCounter.MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS, VectorQueryType.MAX_DISTANCE.getQueryWithFilterStatCounter());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.knn.plugin.action;

import lombok.SneakyThrows;
import org.apache.hc.core5.http.io.entity.EntityUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
Expand All @@ -30,27 +31,12 @@
import org.opensearch.core.rest.RestStatus;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.*;

import static org.opensearch.knn.TestUtils.KNN_VECTOR;
import static org.opensearch.knn.TestUtils.PROPERTIES;
import static org.opensearch.knn.TestUtils.VECTOR_TYPE;
import static org.opensearch.knn.common.KNNConstants.FAISS_NAME;
import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME;
import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;
import static org.opensearch.knn.common.KNNConstants.METHOD_IVF;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE;
import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME;
import static org.opensearch.knn.common.KNNConstants.NAME;
import static org.opensearch.knn.common.KNNConstants.NMSLIB_NAME;
import static org.opensearch.knn.common.KNNConstants.PARAMETERS;
import static org.opensearch.knn.common.KNNConstants.*;

/**
* Integration tests to check the correctness of RestKNNStatsHandler
Expand Down Expand Up @@ -432,6 +418,95 @@ public void testFieldsByEngineModelTraining() throws Exception {
assertTrue(faissField);
}

public void testRadialSearchStats_thenSucceed() throws Exception {
createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2, METHOD_HNSW, LUCENE_NAME));
Float[] vector = { 6.0f, 6.0f };
addKnnDoc(INDEX_NAME, "1", FIELD_NAME, vector);

// First search: radial search by min score
XContentBuilder queryBuilderMinScore = XContentFactory.jsonBuilder().startObject().startObject("query");
queryBuilderMinScore.startObject("knn");
queryBuilderMinScore.startObject(FIELD_NAME);
queryBuilderMinScore.field("vector", vector);
queryBuilderMinScore.field(MIN_SCORE, 0.95f);
queryBuilderMinScore.endObject();
queryBuilderMinScore.endObject();
queryBuilderMinScore.endObject().endObject();

Integer minScoreStatBeforeMinScoreSearch = getStatCount(StatNames.MIN_SCORE_QUERY_REQUESTS.getName());
searchKNNIndex(INDEX_NAME, queryBuilderMinScore, 1);
Integer minScoreStatAfterMinScoreSearch = getStatCount(StatNames.MIN_SCORE_QUERY_REQUESTS.getName());

assertEquals(1, minScoreStatAfterMinScoreSearch - minScoreStatBeforeMinScoreSearch);

// Second search: radial search by min score with filter
XContentBuilder queryBuilderMinScoreWithFilter = XContentFactory.jsonBuilder().startObject().startObject("query");
queryBuilderMinScoreWithFilter.startObject("knn");
queryBuilderMinScoreWithFilter.startObject(FIELD_NAME);
queryBuilderMinScoreWithFilter.field("vector", vector);
queryBuilderMinScoreWithFilter.field(MIN_SCORE, 0.95f);
queryBuilderMinScoreWithFilter.field("filter", QueryBuilders.termQuery("_id", "1"));
queryBuilderMinScoreWithFilter.endObject();
queryBuilderMinScoreWithFilter.endObject();
queryBuilderMinScoreWithFilter.endObject().endObject();

Integer minScoreWithFilterStatBeforeMinScoreWithFilterSearch = getStatCount(
StatNames.MIN_SCORE_QUERY_WITH_FILTER_REQUESTS.getName()
);
Integer minScoreStatBeforeMinScoreWithFilterSearch = getStatCount(StatNames.MIN_SCORE_QUERY_REQUESTS.getName());
searchKNNIndex(INDEX_NAME, queryBuilderMinScoreWithFilter, 1);
Integer minScoreWithFilterStatAfterMinScoreWithFilterSearch = getStatCount(
StatNames.MIN_SCORE_QUERY_WITH_FILTER_REQUESTS.getName()
);
Integer minScoreStatAfterMinScoreWithFilterSearch = getStatCount(StatNames.MIN_SCORE_QUERY_REQUESTS.getName());

assertEquals(1, minScoreWithFilterStatAfterMinScoreWithFilterSearch - minScoreWithFilterStatBeforeMinScoreWithFilterSearch);
assertEquals(1, minScoreStatAfterMinScoreWithFilterSearch - minScoreStatBeforeMinScoreWithFilterSearch);

// Third search: radial search by max distance
XContentBuilder queryBuilderMaxDistance = XContentFactory.jsonBuilder().startObject().startObject("query");
queryBuilderMaxDistance.startObject("knn");
queryBuilderMaxDistance.startObject(FIELD_NAME);
queryBuilderMaxDistance.field("vector", vector);
queryBuilderMaxDistance.field(MAX_DISTANCE, 100f);
queryBuilderMaxDistance.endObject();
queryBuilderMaxDistance.endObject();
queryBuilderMaxDistance.endObject().endObject();

Integer maxDistanceStatBeforeMaxDistanceSearch = getStatCount(StatNames.MAX_DISTANCE_QUERY_REQUESTS.getName());
searchKNNIndex(INDEX_NAME, queryBuilderMaxDistance, 0);
Integer maxDistanceStatAfterMaxDistanceSearch = getStatCount(StatNames.MAX_DISTANCE_QUERY_REQUESTS.getName());

assertEquals(1, maxDistanceStatAfterMaxDistanceSearch - maxDistanceStatBeforeMaxDistanceSearch);

// Fourth search: radial search by max distance with filter
XContentBuilder queryBuilderMaxDistanceWithFilter = XContentFactory.jsonBuilder().startObject().startObject("query");
queryBuilderMaxDistanceWithFilter.startObject("knn");
queryBuilderMaxDistanceWithFilter.startObject(FIELD_NAME);
queryBuilderMaxDistanceWithFilter.field("vector", vector);
queryBuilderMaxDistanceWithFilter.field(MAX_DISTANCE, 100f);
queryBuilderMaxDistanceWithFilter.field("filter", QueryBuilders.termQuery("_id", "1"));
queryBuilderMaxDistanceWithFilter.endObject();
queryBuilderMaxDistanceWithFilter.endObject();
queryBuilderMaxDistanceWithFilter.endObject().endObject();

Integer maxDistanceWithFilterStatBeforeMaxDistanceWithFilterSearch = getStatCount(
StatNames.MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS.getName()
);
Integer maxDistanceStatBeforeMaxDistanceWithFilterSearch = getStatCount(StatNames.MAX_DISTANCE_QUERY_REQUESTS.getName());
searchKNNIndex(INDEX_NAME, queryBuilderMaxDistanceWithFilter, 0);
Integer maxDistanceWithFilterStatAfterMaxDistanceWithFilterSearch = getStatCount(
StatNames.MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS.getName()
);
Integer maxDistanceStatAfterMaxDistanceWithFilterSearch = getStatCount(StatNames.MAX_DISTANCE_QUERY_REQUESTS.getName());

assertEquals(
1,
maxDistanceWithFilterStatAfterMaxDistanceWithFilterSearch - maxDistanceWithFilterStatBeforeMaxDistanceWithFilterSearch
);
assertEquals(1, maxDistanceStatAfterMaxDistanceWithFilterSearch - maxDistanceStatBeforeMaxDistanceWithFilterSearch);
}

public void trainKnnModel(String modelId, String trainingIndexName, String trainingFieldName, int dimension, String description)
throws IOException {
XContentBuilder builder = XContentFactory.jsonBuilder()
Expand Down Expand Up @@ -487,4 +562,11 @@ protected Settings restClientSettings() {
return super.restClientSettings();
}
}

@SneakyThrows
private Integer getStatCount(String statName) {
Response response = getKnnStats(Collections.emptyList(), Collections.emptyList());
String responseBody = EntityUtils.toString(response.getEntity());
return (Integer) parseNodeStatsResponse(responseBody).get(0).get(statName);
}
}
Loading