From 72d41606cae2d5cc386c6f5aff73a68dd42de491 Mon Sep 17 00:00:00 2001 From: Andy Qin Date: Fri, 24 Jan 2025 17:19:02 -0800 Subject: [PATCH 1/4] implements knn style --- .../neuralsearch/plugin/NeuralSearch.java | 32 +++ .../rest/RestNeuralStatsHandler.java | 105 +++++++++ .../neuralsearch/stats/NeuralStat.java | 19 ++ .../neuralsearch/stats/NeuralStats.java | 57 +++++ .../neuralsearch/stats/NeuralStatsInput.java | 214 ++++++++++++++++++ .../neuralsearch/stats/StatNames.java | 42 ++++ .../stats/names/NeuralClusterLevelStat.java | 19 ++ .../stats/names/NeuralNodeLevelStat.java | 20 ++ .../stats/names/NeuralStatLevel.java | 19 ++ .../transport/NeuralStatsAction.java | 29 +++ .../transport/NeuralStatsNodeRequest.java | 60 +++++ .../transport/NeuralStatsNodeResponse.java | 88 +++++++ .../transport/NeuralStatsRequest.java | 103 +++++++++ .../transport/NeuralStatsResponse.java | 96 ++++++++ .../transport/NeuralStatsTransportAction.java | 110 +++++++++ .../rest/RestNeuralStatsHandlerIT.java | 85 +++++++ 16 files changed, 1098 insertions(+) create mode 100644 src/main/java/org/opensearch/neuralsearch/rest/RestNeuralStatsHandler.java create mode 100644 src/main/java/org/opensearch/neuralsearch/stats/NeuralStat.java create mode 100644 src/main/java/org/opensearch/neuralsearch/stats/NeuralStats.java create mode 100644 src/main/java/org/opensearch/neuralsearch/stats/NeuralStatsInput.java create mode 100644 src/main/java/org/opensearch/neuralsearch/stats/StatNames.java create mode 100644 src/main/java/org/opensearch/neuralsearch/stats/names/NeuralClusterLevelStat.java create mode 100644 src/main/java/org/opensearch/neuralsearch/stats/names/NeuralNodeLevelStat.java create mode 100644 src/main/java/org/opensearch/neuralsearch/stats/names/NeuralStatLevel.java create mode 100644 src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsAction.java create mode 100644 src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsNodeRequest.java create mode 100644 
src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsNodeResponse.java create mode 100644 src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsRequest.java create mode 100644 src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsResponse.java create mode 100644 src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsTransportAction.java create mode 100644 src/test/java/org/opensearch/neuralsearch/rest/RestNeuralStatsHandlerIT.java diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index f7ac5d19f3..68aff564c0 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -14,12 +14,19 @@ import java.util.Optional; import java.util.function.Supplier; +import com.google.common.collect.ImmutableList; +import org.opensearch.action.ActionRequest; import org.opensearch.client.Client; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.node.DiscoveryNodes; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.IndexScopedSettings; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; +import org.opensearch.common.settings.SettingsFilter; import org.opensearch.common.util.FeatureFlags; +import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.env.Environment; @@ -55,7 +62,10 @@ import org.opensearch.neuralsearch.query.NeuralQueryBuilder; import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder; import org.opensearch.neuralsearch.query.ext.RerankSearchExtBuilder; +import 
org.opensearch.neuralsearch.rest.RestNeuralStatsHandler; import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher; +import org.opensearch.neuralsearch.transport.NeuralStatsAction; +import org.opensearch.neuralsearch.transport.NeuralStatsTransportAction; import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil; import org.opensearch.plugins.ActionPlugin; import org.opensearch.plugins.ExtensiblePlugin; @@ -64,6 +74,8 @@ import org.opensearch.plugins.SearchPipelinePlugin; import org.opensearch.plugins.SearchPlugin; import org.opensearch.repositories.RepositoriesService; +import org.opensearch.rest.RestController; +import org.opensearch.rest.RestHandler; import org.opensearch.script.ScriptService; import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; import org.opensearch.search.pipeline.SearchRequestProcessor; @@ -85,6 +97,7 @@ public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin, private final ScoreNormalizationFactory scoreNormalizationFactory = new ScoreNormalizationFactory(); private final ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory(); public static final String EXPLANATION_RESPONSE_KEY = "explanation_response"; + public static final String NEURAL_BASE_URI = "/_plugins/_neural"; @Override public Collection createComponents( @@ -117,6 +130,25 @@ public List> getQueries() { ); } + @Override + public List getRestHandlers( + Settings settings, + RestController restController, + ClusterSettings clusterSettings, + IndexScopedSettings indexScopedSettings, + SettingsFilter settingsFilter, + IndexNameExpressionResolver indexNameExpressionResolver, + Supplier nodesInCluster + ) { + RestNeuralStatsHandler restNeuralStatsHandler = new RestNeuralStatsHandler(); + return ImmutableList.of(restNeuralStatsHandler); + } + + @Override + public List> getActions() { + return Arrays.asList(new ActionHandler<>(NeuralStatsAction.INSTANCE, NeuralStatsTransportAction.class)); + } + @Override 
public List> getExecutorBuilders(Settings settings) { return List.of(HybridQueryExecutor.getExecutorBuilder(settings));
diff --git a/src/main/java/org/opensearch/neuralsearch/rest/RestNeuralStatsHandler.java b/src/main/java/org/opensearch/neuralsearch/rest/RestNeuralStatsHandler.java
new file mode 100644
index 0000000000..90e2662bf3
--- /dev/null
+++ b/src/main/java/org/opensearch/neuralsearch/rest/RestNeuralStatsHandler.java
@@ -0,0 +1,105 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.neuralsearch.rest;
+
+import com.google.common.collect.ImmutableList;
+import lombok.AllArgsConstructor;
+import lombok.extern.log4j.Log4j2;
+import org.apache.commons.lang.StringUtils;
+import org.opensearch.client.node.NodeClient;
+import org.opensearch.neuralsearch.transport.NeuralStatsAction;
+import org.opensearch.neuralsearch.transport.NeuralStatsRequest;
+import org.opensearch.rest.BaseRestHandler;
+import org.opensearch.rest.RestRequest;
+import org.opensearch.rest.action.RestActions;
+
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.TreeSet;
+
+import static org.opensearch.neuralsearch.plugin.NeuralSearch.NEURAL_BASE_URI;
+
+/**
+ * REST handler that serves the GET stats routes registered under {@code NEURAL_BASE_URI}
+ */
+@Log4j2
+@AllArgsConstructor
+public class RestNeuralStatsHandler extends BaseRestHandler {
+    private static final String NAME = "neural_stats_action";
+
+    @Override
+    public String getName() {
+        return NAME;
+    }
+
+    @Override
+    public List<Route> routes() {
+        return ImmutableList.of(
+            new Route(RestRequest.Method.GET, NEURAL_BASE_URI + "/{nodeId}/stats/"),
+            new Route(RestRequest.Method.GET, NEURAL_BASE_URI + "/{nodeId}/stats/{stat}"),
+            new Route(RestRequest.Method.GET, NEURAL_BASE_URI + "/stats/"),
+            new Route(RestRequest.Method.GET, NEURAL_BASE_URI + "/stats/{stat}")
+        );
+    }
+
+    @Override
+    protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) {
+        NeuralStatsRequest neuralStatsRequest = getRequest(request);
+
+        return channel -> client.execute(
+            NeuralStatsAction.INSTANCE,
+            neuralStatsRequest,
+            new RestActions.NodesResponseRestListener<>(channel)
+        );
+    }
+
+    /**
+     * Creates a NeuralStatsRequest from a RestRequest
+     *
+     * @param request Rest request
+     * @return NeuralStatsRequest
+     */
+    private NeuralStatsRequest getRequest(RestRequest request) {
+        // parse the nodes the user wants to query
+        String[] nodeIdsArr = null;
+        String nodesIdsStr = request.param("nodeId");
+        if (StringUtils.isNotEmpty(nodesIdsStr)) {
+            nodeIdsArr = nodesIdsStr.split(",");
+        }
+
+        NeuralStatsRequest neuralStatsRequest = new NeuralStatsRequest(nodeIdsArr);
+        neuralStatsRequest.timeout(request.param("timeout"));
+
+        // parse the stats the user wants to see
+        Set<String> statsSet = null;
+        String statsStr = request.param("stat");
+        if (StringUtils.isNotEmpty(statsStr)) {
+            statsSet = new HashSet<>(Arrays.asList(statsStr.split(",")));
+        }
+
+        if (statsSet == null) {
+            neuralStatsRequest.all();
+        } else if (statsSet.size() == 1 && statsSet.contains(NeuralStatsRequest.ALL_STATS_KEY)) {
+            neuralStatsRequest.all();
+        } else if (statsSet.contains(NeuralStatsRequest.ALL_STATS_KEY)) {
+            throw new IllegalArgumentException("Request " + request.path() + " contains _all and individual stats");
+        } else {
+            Set<String> invalidStats = new TreeSet<>();
+            for (String stat : statsSet) {
+                if (!neuralStatsRequest.addStat(stat)) {
+                    invalidStats.add(stat);
+                }
+            }
+
+            if (!invalidStats.isEmpty()) {
+                throw new IllegalArgumentException(unrecognized(request, invalidStats, neuralStatsRequest.getStatsToBeRetrieved(), "stat"));
+            }
+        }
+        log.debug("Neural stats to be retrieved: {}", neuralStatsRequest.getStatsToBeRetrieved());
+        return neuralStatsRequest;
+    }
+}
diff --git a/src/main/java/org/opensearch/neuralsearch/stats/NeuralStat.java b/src/main/java/org/opensearch/neuralsearch/stats/NeuralStat.java
new file mode 100644
index 0000000000..4acbee527e
--- /dev/null
+++ 
b/src/main/java/org/opensearch/neuralsearch/stats/NeuralStat.java
@@ -0,0 +1,34 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.neuralsearch.stats;
+
+import java.util.function.Supplier;
+
+/**
+ * A single stat whose current value is produced on demand by a supplier
+ *
+ * @param <T> type of the stat value
+ */
+public class NeuralStat<T> {
+    private final Supplier<T> supplier;
+
+    /**
+     * Constructor
+     *
+     * @param supplier supplier invoked each time the value is read
+     */
+    public NeuralStat(Supplier<T> supplier) {
+        this.supplier = supplier;
+    }
+
+    /**
+     * Get the stat value
+     *
+     * @return value returned by the supplier
+     */
+    public T getValue() {
+        return supplier.get();
+    }
+}
diff --git a/src/main/java/org/opensearch/neuralsearch/stats/NeuralStats.java b/src/main/java/org/opensearch/neuralsearch/stats/NeuralStats.java
new file mode 100644
index 0000000000..7e6d9593a9
--- /dev/null
+++ b/src/main/java/org/opensearch/neuralsearch/stats/NeuralStats.java
@@ -0,0 +1,62 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.neuralsearch.stats;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Holds the NeuralStat objects, keyed by stat name
+ */
+public class NeuralStats {
+    private final Map<String, NeuralStat<?>> neuralStats;
+
+    /**
+     * Constructor; registers a fixed string value for NEURAL_QUERY_COUNT and HYBRID_QUERY_COUNT
+     */
+    public NeuralStats() {
+        this.neuralStats = new HashMap<>();
+        neuralStats.put(StatNames.NEURAL_QUERY_COUNT.getName(), new NeuralStat<>(() -> "nqc"));
+        neuralStats.put(StatNames.HYBRID_QUERY_COUNT.getName(), new NeuralStat<>(() -> "hqc"));
+    }
+
+    /**
+     * Get the stats
+     *
+     * @return all the stats
+     */
+    public Map<String, NeuralStat<?>> getStats() {
+        return neuralStats;
+    }
+
+    /**
+     * Get a map of the stats that are kept at the node level
+     *
+     * @return Map of stats kept at the node level
+     */
+    public Map<String, NeuralStat<?>> getNodeStats() {
+        return getClusterOrNodeStats(false);
+    }
+
+    /**
+     * Get a map of the stats that are kept at the cluster level
+     *
+     * @return Map of stats kept at the cluster level
+     */
+    public Map<String, NeuralStat<?>> getClusterStats() {
+        return getClusterOrNodeStats(true);
+    }
+
+    /**
+     * Stats are not yet tagged as cluster or node level, so every stat is returned for both levels
+     *
+     * @param getClusterStats true for cluster level stats, false for node level stats; currently ignored
+     * @return all the stats
+     */
+    private Map<String, NeuralStat<?>> getClusterOrNodeStats(boolean getClusterStats) {
+        // TODO: filter by level once NeuralStat carries a cluster/node level flag
+        return neuralStats;
+    }
+}
diff --git a/src/main/java/org/opensearch/neuralsearch/stats/NeuralStatsInput.java b/src/main/java/org/opensearch/neuralsearch/stats/NeuralStatsInput.java
new file mode 100644
index 0000000000..cabe177e38
--- /dev/null
+++ b/src/main/java/org/opensearch/neuralsearch/stats/NeuralStatsInput.java
@@ -0,0 +1,277 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.neuralsearch.stats;
+
+import lombok.Builder;
+import lombok.Getter;
+import org.opensearch.core.common.io.stream.StreamInput;
+import org.opensearch.core.common.io.stream.StreamOutput;
+import org.opensearch.core.common.io.stream.Writeable;
+import org.opensearch.core.xcontent.ToXContent;
+import org.opensearch.core.xcontent.ToXContentObject;
+import org.opensearch.core.xcontent.XContentBuilder;
+import org.opensearch.core.xcontent.XContentParser;
+import org.opensearch.neuralsearch.stats.names.NeuralClusterLevelStat;
+import org.opensearch.neuralsearch.stats.names.NeuralNodeLevelStat;
+import org.opensearch.neuralsearch.stats.names.NeuralStatLevel;
+
+import java.io.IOException;
+import java.util.EnumSet;
+import java.util.HashSet;
+import java.util.Locale;
+import java.util.Set;
+import java.util.function.Function;
+
+import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
+
+/**
+ * Describes which stat levels, which stats and which nodes a neural stats request targets
+ */
+@Getter
+public class NeuralStatsInput implements ToXContentObject, Writeable {
+    public static final String TARGET_STAT_LEVEL = "target_stat_levels";
+    public static final String CLUSTER_LEVEL_STATS = "cluster_level_stats";
+    public static final String NODE_LEVEL_STATS = "node_level_stats";
+    public static final String NODE_IDS = "node_ids";
+
+    /**
+     * Determines levels of stats to retrieve. If empty, will not retrieve any
+     */
+    private final EnumSet<NeuralStatLevel> targetStatLevels;
+
+    /**
+     * Which cluster level stats will be retrieved.
+     */
+    private final EnumSet<NeuralClusterLevelStat> clusterLevelStats;
+
+    /**
+     * Which node level stats will be retrieved.
+     */
+    private final EnumSet<NeuralNodeLevelStat> nodeLevelStats;
+
+    /**
+     * Which node's stats will be retrieved.
+     */
+    private final Set<String> nodeIds;
+
+    /**
+     * Constructor used by the builder; the given collections are stored as-is
+     *
+     * @param targetStatLevels levels of stats to retrieve
+     * @param clusterLevelStats cluster level stats to retrieve
+     * @param nodeLevelStats node level stats to retrieve
+     * @param nodeIds ids of the nodes to retrieve stats from
+     */
+    @Builder
+    public NeuralStatsInput(
+        EnumSet<NeuralStatLevel> targetStatLevels,
+        EnumSet<NeuralClusterLevelStat> clusterLevelStats,
+        EnumSet<NeuralNodeLevelStat> nodeLevelStats,
+        Set<String> nodeIds
+    ) {
+        this.targetStatLevels = targetStatLevels;
+        this.clusterLevelStats = clusterLevelStats;
+        this.nodeLevelStats = nodeLevelStats;
+        this.nodeIds = nodeIds;
+    }
+
+    /**
+     * Constructor that starts with every selection empty
+     */
+    public NeuralStatsInput() {
+        this.targetStatLevels = EnumSet.noneOf(NeuralStatLevel.class);
+        this.clusterLevelStats = EnumSet.noneOf(NeuralClusterLevelStat.class);
+        this.nodeLevelStats = EnumSet.noneOf(NeuralNodeLevelStat.class);
+        this.nodeIds = new HashSet<>();
+    }
+
+    /**
+     * Constructor
+     *
+     * @param input stream to read from, in the layout produced by {@link #writeTo(StreamOutput)}
+     * @throws IOException in case of I/O errors
+     */
+    public NeuralStatsInput(StreamInput input) throws IOException {
+        targetStatLevels = input.readBoolean() ? input.readEnumSet(NeuralStatLevel.class) : EnumSet.noneOf(NeuralStatLevel.class);
+        clusterLevelStats = input.readBoolean() ? input.readEnumSet(NeuralClusterLevelStat.class) : EnumSet.noneOf(NeuralClusterLevelStat.class);
+        nodeLevelStats = input.readBoolean() ? input.readEnumSet(NeuralNodeLevelStat.class) : EnumSet.noneOf(NeuralNodeLevelStat.class);
+        nodeIds = input.readBoolean() ? new HashSet<>(input.readStringList()) : new HashSet<>();
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        writeEnumSet(out, targetStatLevels);
+        writeEnumSet(out, clusterLevelStats);
+        writeEnumSet(out, nodeLevelStats);
+        out.writeOptionalStringCollection(nodeIds);
+    }
+
+    /**
+     * Writes a presence flag followed by the set; a null or empty set is written as a lone {@code false}
+     */
+    private <E extends Enum<E>> void writeEnumSet(StreamOutput out, EnumSet<E> set) throws IOException {
+        if (set != null && !set.isEmpty()) {
+            out.writeBoolean(true);
+            out.writeEnumSet(set);
+        } else {
+            out.writeBoolean(false);
+        }
+    }
+
+    /**
+     * Parses a NeuralStatsInput from an object; unknown fields are skipped
+     *
+     * @param parser parser positioned on the START_OBJECT token
+     * @return parsed NeuralStatsInput
+     * @throws IOException in case of I/O errors
+     */
+    public static NeuralStatsInput parse(XContentParser parser) throws IOException {
+        EnumSet<NeuralStatLevel> targetStatLevels = EnumSet.noneOf(NeuralStatLevel.class);
+        EnumSet<NeuralClusterLevelStat> clusterLevelStats = EnumSet.noneOf(NeuralClusterLevelStat.class);
+        EnumSet<NeuralNodeLevelStat> nodeLevelStats = EnumSet.noneOf(NeuralNodeLevelStat.class);
+        Set<String> nodeIds = new HashSet<>();
+
+        ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
+        while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
+            String fieldName = parser.currentName();
+            parser.nextToken();
+
+            switch (fieldName) {
+                case TARGET_STAT_LEVEL:
+                    parseField(parser, targetStatLevels, input -> NeuralStatLevel.from(input.toUpperCase(Locale.ROOT)), NeuralStatLevel.class);
+                    break;
+                case CLUSTER_LEVEL_STATS:
+                    parseField(
+                        parser,
+                        clusterLevelStats,
+                        input -> NeuralClusterLevelStat.from(input.toUpperCase(Locale.ROOT)),
+                        NeuralClusterLevelStat.class
+                    );
+                    break;
+                case NODE_LEVEL_STATS:
+                    parseField(
+                        parser,
+                        nodeLevelStats,
+                        input -> NeuralNodeLevelStat.from(input.toUpperCase(Locale.ROOT)),
+                        NeuralNodeLevelStat.class
+                    );
+                    break;
+                case NODE_IDS:
+                    parseArrayField(parser, nodeIds);
+                    break;
+                default:
+                    parser.skipChildren();
+                    break;
+            }
+        }
+        return NeuralStatsInput
+            .builder()
+            .targetStatLevels(targetStatLevels)
+            .clusterLevelStats(clusterLevelStats)
+            .nodeLevelStats(nodeLevelStats)
+            .nodeIds(nodeIds)
+            .build();
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
+        builder.startObject();
+        if (targetStatLevels != null) {
+            builder.field(TARGET_STAT_LEVEL, targetStatLevels);
+        }
+        if (clusterLevelStats != null) {
+            builder.field(CLUSTER_LEVEL_STATS, clusterLevelStats);
+        }
+        if (nodeLevelStats != null) {
+            builder.field(NODE_LEVEL_STATS, nodeLevelStats);
+        }
+        if (nodeIds != null) {
+            builder.field(NODE_IDS, nodeIds);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    /**
+     * @return true when no specific cluster level stat was selected (null or empty selection)
+     */
+    public boolean retrieveAllClusterLevelStats() {
+        return clusterLevelStats == null || clusterLevelStats.isEmpty();
+    }
+
+    /**
+     * @return true when no specific node level stat was selected (null or empty selection)
+     */
+    public boolean retrieveAllNodeLevelStats() {
+        return nodeLevelStats == null || nodeLevelStats.isEmpty();
+    }
+
+    /**
+     * @return true when no specific node was selected (null or empty selection)
+     */
+    public boolean retrieveStatsOnAllNodes() {
+        return nodeIds == null || nodeIds.isEmpty();
+    }
+
+    /**
+     * Checks whether the given stat should be retrieved
+     *
+     * @param key a NeuralClusterLevelStat or NeuralNodeLevelStat
+     * @return true if the stat is selected or nothing is selected for its level; false for any other enum type
+     */
+    public boolean retrieveStat(Enum<?> key) {
+        if (key instanceof NeuralClusterLevelStat) {
+            return retrieveAllClusterLevelStats() || clusterLevelStats.contains(key);
+        }
+        if (key instanceof NeuralNodeLevelStat) {
+            return retrieveAllNodeLevelStats() || nodeLevelStats.contains(key);
+        }
+        return false;
+    }
+
+    /**
+     * @return true only when target stat levels were given and NODE is not among them
+     */
+    public boolean onlyRetrieveClusterLevelStats() {
+        if (targetStatLevels == null || targetStatLevels.isEmpty()) {
+            return false;
+        }
+        return !targetStatLevels.contains(NeuralStatLevel.NODE);
+    }
+
+    /**
+     * Reads an array of strings into the given set
+     *
+     * @param parser parser positioned on the START_ARRAY token
+     * @param set set to add the values to
+     * @throws IOException in case of I/O errors
+     */
+    public static void parseArrayField(XContentParser parser, Set<String> set) throws IOException {
+        parseField(parser, set, null, String.class);
+    }
+
+    /**
+     * Reads an array of text values into the given set
+     *
+     * @param parser parser positioned on the START_ARRAY token
+     * @param set set to add the converted values to
+     * @param function converts each text value; when null the raw text is added if it is an instance of clazz
+     * @param clazz element type checked against the raw text when function is null
+     * @param <T> element type of the set
+     * @throws IOException in case of I/O errors
+     */
+    public static <T> void parseField(XContentParser parser, Set<T> set, Function<String, T> function, Class<T> clazz) throws IOException {
+        ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
+        while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
+            String value = parser.text();
+            if (function != null) {
+                set.add(function.apply(value));
+            } else if (clazz.isInstance(value)) {
+                set.add(clazz.cast(value));
+            }
+        }
+    }
+}
diff --git a/src/main/java/org/opensearch/neuralsearch/stats/StatNames.java 
b/src/main/java/org/opensearch/neuralsearch/stats/StatNames.java
new file mode 100644
index 0000000000..d1cc9aa3be
--- /dev/null
+++ b/src/main/java/org/opensearch/neuralsearch/stats/StatNames.java
@@ -0,0 +1,45 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.neuralsearch.stats;
+
+import java.util.HashSet;
+import java.util.Set;
+
+/**
+ * Enumerates the stats and the name string each one is exposed under
+ */
+public enum StatNames {
+    NEURAL_QUERY_COUNT("neural_query_count"),
+    HYBRID_QUERY_COUNT("hybrid_query_count");
+
+    private final String name;
+
+    StatNames(String name) {
+        this.name = name;
+    }
+
+    /**
+     * Get stat name
+     *
+     * @return name
+     */
+    public String getName() {
+        return name;
+    }
+
+    /**
+     * Get all stat names
+     *
+     * @return set of all stat names
+     */
+    public static Set<String> getNames() {
+        Set<String> names = new HashSet<>();
+
+        for (StatNames statName : StatNames.values()) {
+            names.add(statName.getName());
+        }
+        return names;
+    }
+}
diff --git a/src/main/java/org/opensearch/neuralsearch/stats/names/NeuralClusterLevelStat.java b/src/main/java/org/opensearch/neuralsearch/stats/names/NeuralClusterLevelStat.java
new file mode 100644
index 0000000000..6b9c56800f
--- /dev/null
+++ b/src/main/java/org/opensearch/neuralsearch/stats/names/NeuralClusterLevelStat.java
@@ -0,0 +1,28 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.neuralsearch.stats.names;
+
+/**
+ * Names of the cluster level stats
+ */
+public enum NeuralClusterLevelStat {
+    FORCE_INFERENCE;
+
+    /**
+     * Resolve a cluster level stat from its enum constant name
+     *
+     * @param value exact (case-sensitive) name of the enum constant
+     * @return matching NeuralClusterLevelStat
+     * @throws IllegalArgumentException if value is null or does not name a constant
+     */
+    public static NeuralClusterLevelStat from(String value) {
+        try {
+            return NeuralClusterLevelStat.valueOf(value);
+        } catch (IllegalArgumentException | NullPointerException e) {
+            throw new IllegalArgumentException("No such neural cluster level stat", e);
+        }
+    }
+}
diff --git a/src/main/java/org/opensearch/neuralsearch/stats/names/NeuralNodeLevelStat.java b/src/main/java/org/opensearch/neuralsearch/stats/names/NeuralNodeLevelStat.java
new file mode 100644
index 0000000000..de42db3f8b
--- /dev/null
+++ 
b/src/main/java/org/opensearch/neuralsearch/stats/names/NeuralNodeLevelStat.java
@@ -0,0 +1,29 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.neuralsearch.stats.names;
+
+/**
+ * Names of the node level stats
+ */
+public enum NeuralNodeLevelStat {
+    NEURAL_QUERY_COUNT,
+    HYBRID_QUERY_COUNT;
+
+    /**
+     * Resolve a node level stat from its enum constant name
+     *
+     * @param value exact (case-sensitive) name of the enum constant
+     * @return matching NeuralNodeLevelStat
+     * @throws IllegalArgumentException if value is null or does not name a constant
+     */
+    public static NeuralNodeLevelStat from(String value) {
+        try {
+            return NeuralNodeLevelStat.valueOf(value);
+        } catch (IllegalArgumentException | NullPointerException e) {
+            throw new IllegalArgumentException("No such neural node level stat", e);
+        }
+    }
+}
diff --git a/src/main/java/org/opensearch/neuralsearch/stats/names/NeuralStatLevel.java b/src/main/java/org/opensearch/neuralsearch/stats/names/NeuralStatLevel.java
new file mode 100644
index 0000000000..e2cfa56a68
--- /dev/null
+++ b/src/main/java/org/opensearch/neuralsearch/stats/names/NeuralStatLevel.java
@@ -0,0 +1,29 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.neuralsearch.stats.names;
+
+/**
+ * Levels at which stats can be requested
+ */
+public enum NeuralStatLevel {
+    CLUSTER,
+    NODE;
+
+    /**
+     * Resolve a stat level from its enum constant name
+     *
+     * @param value exact (case-sensitive) name of the enum constant
+     * @return matching NeuralStatLevel
+     * @throws IllegalArgumentException if value is null or does not name a constant
+     */
+    public static NeuralStatLevel from(String value) {
+        try {
+            return NeuralStatLevel.valueOf(value);
+        } catch (IllegalArgumentException | NullPointerException e) {
+            throw new IllegalArgumentException("No such neural stat level", e);
+        }
+    }
+}
diff --git a/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsAction.java b/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsAction.java
new file mode 100644
index 0000000000..85683b871c
--- /dev/null
+++ b/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsAction.java
@@ -0,0 +1,29 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.neuralsearch.transport;
+
+import org.opensearch.action.ActionType;
+import org.opensearch.core.common.io.stream.Writeable;
+
+/**
+ * NeuralStatsAction class
+ */
+public class NeuralStatsAction extends ActionType<NeuralStatsResponse> {
+
+    public static final String NAME = "cluster:admin/neural_stats_action"; // TODO : figure this out
+    public static final NeuralStatsAction INSTANCE = new NeuralStatsAction();
+
+    /**
+     * Constructor
+     */
+    private NeuralStatsAction() {
+        super(NAME, NeuralStatsResponse::new);
+    }
+
+    @Override
+    public Writeable.Reader<NeuralStatsResponse> getResponseReader() {
+        return NeuralStatsResponse::new;
+    }
+}
diff --git a/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsNodeRequest.java b/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsNodeRequest.java
new file mode 100644
index 0000000000..bc52c71586
--- /dev/null
+++ b/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsNodeRequest.java
@@ -0,0 +1,60 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.neuralsearch.transport;
+
+import org.opensearch.core.common.io.stream.StreamInput;
+import org.opensearch.core.common.io.stream.StreamOutput;
+import org.opensearch.transport.TransportRequest;
+
+import java.io.IOException;
+
+/**
+ * NeuralStatsNodeRequest represents the request to an individual node
+ */
+public class NeuralStatsNodeRequest extends TransportRequest {
+    private NeuralStatsRequest request;
+
+    /**
+     * Constructor that leaves the wrapped request unset; writeTo dereferences it and fails on such an instance
+     */
+    public NeuralStatsNodeRequest() {
+        super();
+    }
+
+    /**
+     * Constructor
+     *
+     * @param in input stream
+     * @throws IOException in case of I/O errors
+     */
+    public NeuralStatsNodeRequest(StreamInput in) throws IOException {
+        super(in);
+        request = new NeuralStatsRequest(in);
+    }
+
+    /**
+     * Constructor
+     *
+     * @param request NeuralStatsRequest
+     */
+    public NeuralStatsNodeRequest(NeuralStatsRequest request) {
+        this.request = request;
+    }
+
+    /**
+     * Get NeuralStatsRequest
+     *
+     * @return NeuralStatsRequest for this node
+     */
+    public NeuralStatsRequest getNeuralStatsRequest() {
+        return request;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        super.writeTo(out);
+        request.writeTo(out);
+    }
+}
diff --git 
a/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsNodeResponse.java b/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsNodeResponse.java
new file mode 100644
index 0000000000..b9e65a242a
--- /dev/null
+++ b/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsNodeResponse.java
@@ -0,0 +1,88 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.neuralsearch.transport;
+
+import org.opensearch.action.support.nodes.BaseNodeResponse;
+import org.opensearch.cluster.node.DiscoveryNode;
+import org.opensearch.core.common.io.stream.StreamInput;
+import org.opensearch.core.common.io.stream.StreamOutput;
+import org.opensearch.core.xcontent.ToXContentFragment;
+import org.opensearch.core.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.Map;
+
+/**
+ * NeuralStatsNodeResponse represents the responses generated by an individual node
+ */
+public class NeuralStatsNodeResponse extends BaseNodeResponse implements ToXContentFragment {
+
+    private final Map<String, Object> statsMap;
+
+    /**
+     * Constructor
+     *
+     * @param in stream
+     * @throws IOException in case of I/O errors
+     */
+    public NeuralStatsNodeResponse(StreamInput in) throws IOException {
+        super(in);
+        this.statsMap = in.readMap(StreamInput::readString, StreamInput::readGenericValue);
+    }
+
+    /**
+     * Constructor
+     *
+     * @param node node
+     * @param statsToValues mapping of stat name to value
+     */
+    public NeuralStatsNodeResponse(DiscoveryNode node, Map<String, Object> statsToValues) {
+        super(node);
+        this.statsMap = statsToValues;
+    }
+
+    /**
+     * Creates a new NeuralStatsNodeResponse object and reads in the stats from an input stream
+     *
+     * @param in StreamInput to read from
+     * @return NeuralStatsNodeResponse object corresponding to the input stream
+     * @throws IOException throws an IO exception if the StreamInput cannot be read from
+     */
+    public static NeuralStatsNodeResponse readStats(StreamInput in) throws IOException {
+        return new NeuralStatsNodeResponse(in);
+    }
+
+    /**
+     * Get the map of stats
+     *
+     * @return map of stats
+     */
+    public Map<String, Object> getStatsMap() {
+        return statsMap;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        super.writeTo(out);
+        out.writeMap(statsMap, StreamOutput::writeString, StreamOutput::writeGenericValue);
+    }
+
+    /**
+     * Converts statsMap to xContent
+     *
+     * @param builder XContentBuilder
+     * @param params Params
+     * @return XContentBuilder
+     * @throws IOException thrown by builder for invalid field
+     */
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        for (Map.Entry<String, Object> stat : statsMap.entrySet()) {
+            builder.field(stat.getKey(), stat.getValue());
+        }
+
+        return builder;
+    }
+}
diff --git a/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsRequest.java b/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsRequest.java
new file mode 100644
index 0000000000..8f700d23ef
--- /dev/null
+++ b/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsRequest.java
@@ -0,0 +1,103 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.neuralsearch.transport;
+
+import org.opensearch.action.support.nodes.BaseNodesRequest;
+import org.opensearch.core.common.io.stream.StreamInput;
+import org.opensearch.core.common.io.stream.StreamOutput;
+import org.opensearch.neuralsearch.stats.StatNames;
+
+import java.io.IOException;
+import java.util.HashSet;
+import java.util.Set;
+
+/**
+ * NeuralStatsRequest gets node (cluster) level stats for neural search
+ * No stat is selected until all() or addStat() is called
+ */
+public class NeuralStatsRequest extends BaseNodesRequest<NeuralStatsRequest> {
+
+    /**
+     * Key indicating all stats should be retrieved
+     */
+    public static final String ALL_STATS_KEY = "_all";
+    private final Set<String> validStats;
+    private final Set<String> statsToBeRetrieved;
+
+    /**
+     * Empty constructor needed for NeuralStatsTransportAction
+     */
+    public NeuralStatsRequest() {
+        super((String[]) null);
+        validStats = StatNames.getNames();
+        statsToBeRetrieved = new HashSet<>();
+    }
+
+    /**
+     * Constructor
+     *
+     * @param in input stream
+     * @throws IOException in case of I/O errors
+     */
+    public NeuralStatsRequest(StreamInput in) throws IOException {
+        super(in);
+        validStats = in.readSet(StreamInput::readString);
+        statsToBeRetrieved = in.readSet(StreamInput::readString);
+    }
+
+    /**
+     * Constructor
+     *
+     * @param nodeIds NodeIDs from which to retrieve stats
+     */
+    public NeuralStatsRequest(String... nodeIds) {
+        super(nodeIds);
+        validStats = StatNames.getNames();
+        statsToBeRetrieved = new HashSet<>();
+    }
+
+    /**
+     * Add all stats to be retrieved
+     */
+    public void all() {
+        statsToBeRetrieved.addAll(validStats);
+    }
+
+    /**
+     * Remove all stats from retrieval set
+     */
+    public void clear() {
+        statsToBeRetrieved.clear();
+    }
+
+    /**
+     * Marks a stat for retrieval if it is a valid stat
+     * @param stat stat name
+     * @return true if the stat was marked for retrieval; false if it is not a valid stat
+     */
+    public boolean addStat(String stat) {
+        if (validStats.contains(stat)) {
+            statsToBeRetrieved.add(stat);
+            return true;
+        }
+        return false;
+    }
+
+    /**
+     * Get the set that tracks which stats should be retrieved
+     *
+     * @return the set that contains the stat names marked for retrieval
+     */
+    public Set<String> getStatsToBeRetrieved() {
+        return statsToBeRetrieved;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        super.writeTo(out);
+        out.writeStringCollection(validStats);
+        out.writeStringCollection(statsToBeRetrieved);
+    }
+}
diff --git a/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsResponse.java b/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsResponse.java
new file mode 100644
index 0000000000..3c44f9ad56
--- /dev/null
+++ 
b/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsResponse.java
@@ -0,0 +1,91 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.neuralsearch.transport;
+
+import org.opensearch.action.FailedNodeException;
+import org.opensearch.action.support.nodes.BaseNodesResponse;
+import org.opensearch.cluster.ClusterName;
+import org.opensearch.cluster.node.DiscoveryNode;
+import org.opensearch.core.common.io.stream.StreamInput;
+import org.opensearch.core.common.io.stream.StreamOutput;
+import org.opensearch.core.xcontent.ToXContentObject;
+import org.opensearch.core.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * NeuralStatsResponse consists of the aggregated responses from the nodes
+ */
+public class NeuralStatsResponse extends BaseNodesResponse<NeuralStatsNodeResponse> implements ToXContentObject {
+
+    private static final String NODES_KEY = "nodes";
+    private final Map<String, Object> clusterStats;
+
+    /**
+     * Constructor
+     *
+     * @param in StreamInput
+     * @throws IOException thrown when unable to read from stream
+     */
+    public NeuralStatsResponse(StreamInput in) throws IOException {
+        super(new ClusterName(in), in.readList(NeuralStatsNodeResponse::readStats), in.readList(FailedNodeException::new));
+        clusterStats = in.readMap();
+    }
+
+    /**
+     * Constructor
+     *
+     * @param clusterName name of cluster
+     * @param nodes List of NeuralStatsNodeResponses
+     * @param failures List of failures from nodes
+     * @param clusterStats Cluster level stats only obtained from a single node
+     */
+    public NeuralStatsResponse(
+        ClusterName clusterName,
+        List<NeuralStatsNodeResponse> nodes,
+        List<FailedNodeException> failures,
+        Map<String, Object> clusterStats
+    ) {
+        super(clusterName, nodes, failures);
+        this.clusterStats = clusterStats;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        super.writeTo(out);
+        out.writeMap(clusterStats);
+    }
+
+    @Override
+    public void writeNodesTo(StreamOutput out, List<NeuralStatsNodeResponse> nodes) throws IOException {
+        out.writeList(nodes);
+    }
+
+    @Override
+    public List<NeuralStatsNodeResponse> readNodesFrom(StreamInput in) throws IOException {
+        return in.readList(NeuralStatsNodeResponse::readStats);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        // Return cluster level stats
+        for (Map.Entry<String, Object> clusterStat : clusterStats.entrySet()) {
+            builder.field(clusterStat.getKey(), clusterStat.getValue());
+        }
+
+        // Return node level stats, one object per node keyed by node id
+        builder.startObject(NODES_KEY);
+        for (NeuralStatsNodeResponse neuralStats : getNodes()) {
+            DiscoveryNode node = neuralStats.getNode();
+            builder.startObject(node.getId());
+            neuralStats.toXContent(builder, params);
+            builder.endObject();
+        }
+        builder.endObject();
+        return builder;
+    }
+}
diff --git a/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsTransportAction.java b/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsTransportAction.java
new file mode 100644
index 0000000000..e792dba52e
--- /dev/null
+++ b/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsTransportAction.java
@@ -0,0 +1,110 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.neuralsearch.transport;
+
+import org.opensearch.action.FailedNodeException;
+import org.opensearch.action.support.ActionFilters;
+import org.opensearch.action.support.nodes.TransportNodesAction;
+import org.opensearch.cluster.service.ClusterService;
+import org.opensearch.common.inject.Inject;
+import org.opensearch.core.common.io.stream.StreamInput;
+import org.opensearch.neuralsearch.stats.NeuralStats;
+import org.opensearch.transport.TransportService;
+import org.opensearch.threadpool.ThreadPool;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import 
java.util.Set; + +/** + * NeuralStatsTransportAction contains the logic to extract the stats from the nodes + */ +public class NeuralStatsTransportAction extends TransportNodesAction< + NeuralStatsRequest, + NeuralStatsResponse, + NeuralStatsNodeRequest, + NeuralStatsNodeResponse> { + + private NeuralStats neuralStats; + + /** + * Constructor + * + * @param threadPool ThreadPool to use + * @param clusterService ClusterService + * @param transportService TransportService + * @param actionFilters Action Filters + * @param neuralStats NeuralStats object + */ + @Inject + public NeuralStatsTransportAction( + ThreadPool threadPool, + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters, + NeuralStats neuralStats + ) { + super( + NeuralStatsAction.NAME, + threadPool, + clusterService, + transportService, + actionFilters, + NeuralStatsRequest::new, + NeuralStatsNodeRequest::new, + ThreadPool.Names.MANAGEMENT, + NeuralStatsNodeResponse.class + ); + this.neuralStats = neuralStats; + } + + @Override + protected NeuralStatsResponse newResponse( + NeuralStatsRequest request, + List responses, + List failures + ) { + + Map clusterStats = new HashMap<>(); + Set statsToBeRetrieved = request.getStatsToBeRetrieved(); + + for (String statName : neuralStats.getClusterStats().keySet()) { + if (statsToBeRetrieved.contains(statName)) { + clusterStats.put(statName, neuralStats.getStats().get(statName).getValue()); + } + } + + return new NeuralStatsResponse(clusterService.getClusterName(), responses, failures, clusterStats); + } + + @Override + protected NeuralStatsNodeRequest newNodeRequest(NeuralStatsRequest request) { + return new NeuralStatsNodeRequest(request); + } + + @Override + protected NeuralStatsNodeResponse newNodeResponse(StreamInput in) throws IOException { + return new NeuralStatsNodeResponse(in); + } + + @Override + protected NeuralStatsNodeResponse nodeOperation(NeuralStatsNodeRequest request) { + return 
createNeuralStatsNodeResponse(request.getNeuralStatsRequest()); + } + + private NeuralStatsNodeResponse createNeuralStatsNodeResponse(NeuralStatsRequest neuralStatsRequest) { + Map statValues = new HashMap<>(); + Set statsToBeRetrieved = neuralStatsRequest.getStatsToBeRetrieved(); + + for (String statName : neuralStats.getNodeStats().keySet()) { + if (statsToBeRetrieved.contains(statName)) { + statValues.put(statName, neuralStats.getStats().get(statName).getValue()); + } + } + return new NeuralStatsNodeResponse(clusterService.localNode(), statValues); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/rest/RestNeuralStatsHandlerIT.java b/src/test/java/org/opensearch/neuralsearch/rest/RestNeuralStatsHandlerIT.java new file mode 100644 index 0000000000..bb1bf42181 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/rest/RestNeuralStatsHandlerIT.java @@ -0,0 +1,85 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.rest; + +import lombok.extern.log4j.Log4j2; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.junit.Before; +import org.opensearch.client.Response; +import org.opensearch.client.Request; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.neuralsearch.BaseNeuralSearchIT; +import org.opensearch.neuralsearch.plugin.NeuralSearch; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +@Log4j2 +public class RestNeuralStatsHandlerIT extends BaseNeuralSearchIT { + @Before + public void setUp() throws Exception { + super.setUp(); + updateClusterSettings(); + } + + public void test_happyCase() throws Exception { + Response response = executeNeuralStatRequest(new ArrayList<>(), new ArrayList<>()); + String responseBody = EntityUtils.toString(response.getEntity()); + Map clusterStats = 
parseStatsResponse(responseBody); + log.info(clusterStats.toString()); + assertEquals("Bratwurst", (String) clusterStats.get("Sushi")); + } + + protected Response executeNeuralStatRequest(List nodeIds, List stats) throws IOException { + String nodePrefix = ""; + if (!nodeIds.isEmpty()) { + nodePrefix = "/" + String.join(",", nodeIds); + } + + String statsSuffix = ""; + if (!stats.isEmpty()) { + statsSuffix = "/" + String.join(",", stats); + } + + Request request = new Request("GET", NeuralSearch.NEURAL_BASE_URI + nodePrefix + "/stats" + statsSuffix); + + Response response = client().performRequest(request); + assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + return response; + } + + protected Map parseStatsResponse(String responseBody) throws IOException { + Map responseMap = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), responseBody).map(); + return responseMap; + } + + protected Map parseClusterStatsResponse(String responseBody) throws IOException { + Map responseMap = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), responseBody).map(); + responseMap.remove("cluster_name"); + responseMap.remove("_nodes"); + responseMap.remove("nodes"); + return responseMap; + } + + protected List> parseNodeStatsResponse(String responseBody) throws IOException { + @SuppressWarnings("unchecked") + Map responseMap = (Map) createParser( + MediaTypeRegistry.getDefaultMediaType().xContent(), + responseBody + ).map().get("nodes"); + + @SuppressWarnings("unchecked") + List> nodeResponses = responseMap.keySet() + .stream() + .map(key -> (Map) responseMap.get(key)) + .collect(Collectors.toList()); + + return nodeResponses; + } +} From 2573ab1113363f8a36c90c2e741ca18c4f149fe0 Mon Sep 17 00:00:00 2001 From: Andy Qin Date: Mon, 27 Jan 2025 16:58:26 -0800 Subject: [PATCH 2/4] Pass and recieve all stats in TreeMap Signed-off-by: Andy Qin --- .../rest/RestNeuralStatsHandler.java | 12 ++--- 
.../neuralsearch/stats/NeuralStats.java | 41 +++-------------- .../neuralsearch/stats/NeuralStatsInput.java | 16 ++++--- .../neuralsearch/stats/StatNames.java | 42 ------------------ .../stats/names/NeuralClusterLevelStat.java | 4 +- .../stats/names/NeuralNodeLevelStat.java | 4 +- .../names/NeuralSearchProcessorLevelStat.java | 21 +++++++++ .../stats/names/NeuralStatLevel.java | 7 +-- .../transport/NeuralStatsNodeRequest.java | 11 +---- .../transport/NeuralStatsNodeResponse.java | 5 ++- .../transport/NeuralStatsRequest.java | 44 +------------------ .../transport/NeuralStatsResponse.java | 5 ++- .../transport/NeuralStatsTransportAction.java | 20 +++------ .../rest/RestNeuralStatsHandlerIT.java | 7 ++- 14 files changed, 71 insertions(+), 168 deletions(-) delete mode 100644 src/main/java/org/opensearch/neuralsearch/stats/StatNames.java create mode 100644 src/main/java/org/opensearch/neuralsearch/stats/names/NeuralSearchProcessorLevelStat.java diff --git a/src/main/java/org/opensearch/neuralsearch/rest/RestNeuralStatsHandler.java b/src/main/java/org/opensearch/neuralsearch/rest/RestNeuralStatsHandler.java index 90e2662bf3..dfb5d867af 100644 --- a/src/main/java/org/opensearch/neuralsearch/rest/RestNeuralStatsHandler.java +++ b/src/main/java/org/opensearch/neuralsearch/rest/RestNeuralStatsHandler.java @@ -45,6 +45,7 @@ public List routes() { @Override protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) { + // Read inputs and convert to BaseNodesRequest with correct info configured NeuralStatsRequest neuralStatsRequest = getRequest(request); return channel -> client.execute( @@ -79,17 +80,18 @@ private NeuralStatsRequest getRequest(RestRequest request) { } if (statsSet == null) { - neuralStatsRequest.all(); + } else if (statsSet.size() == 1 && statsSet.contains("_all")) { - neuralStatsRequest.all(); + } else if (statsSet.contains(NeuralStatsRequest.ALL_STATS_KEY)) { throw new IllegalArgumentException("Request " + request.path() + " 
contains _all and individual stats"); } else { Set invalidStats = new TreeSet<>(); for (String stat : statsSet) { - if (!neuralStatsRequest.addStat(stat)) { - invalidStats.add(stat); - } + // validate request contains valid stats + // if (!neuralStatsRequest.addStat(stat)) { + // invalidStats.add(stat); + // } } if (!invalidStats.isEmpty()) { diff --git a/src/main/java/org/opensearch/neuralsearch/stats/NeuralStats.java b/src/main/java/org/opensearch/neuralsearch/stats/NeuralStats.java index 7e6d9593a9..b10e6f0791 100644 --- a/src/main/java/org/opensearch/neuralsearch/stats/NeuralStats.java +++ b/src/main/java/org/opensearch/neuralsearch/stats/NeuralStats.java @@ -4,16 +4,19 @@ */ package org.opensearch.neuralsearch.stats; -import java.util.HashMap; import java.util.Map; +import java.util.concurrent.ConcurrentSkipListMap; public class NeuralStats { private final Map> neuralStats; public NeuralStats() { - this.neuralStats = new HashMap<>(); - neuralStats.put(StatNames.NEURAL_QUERY_COUNT.getName(), new NeuralStat<>(() -> "nqc")); - neuralStats.put(StatNames.HYBRID_QUERY_COUNT.getName(), new NeuralStat<>(() -> "hqc")); + this.neuralStats = new ConcurrentSkipListMap<>(); + neuralStats.put("ingest_processor.text_chunking.algorithm.delimiter.execution_count", new NeuralStat<>(() -> "10")); + neuralStats.put("Bratwurst", new NeuralStat<>(() -> "Sushi")); + neuralStats.put("ingest_processor.text_embedding.execution_count", new NeuralStat<>(() -> "3123")); + neuralStats.put("ingest_processor.text_chunking.execution_count", new NeuralStat<>(() -> "777")); + neuralStats.put("ingest_processor.text_chunking.algorithm.fixed_length.execution_count", new NeuralStat<>(() -> "32")); } /** @@ -24,34 +27,4 @@ public NeuralStats() { public Map> getStats() { return neuralStats; } - - /** - * Get a map of the stats that are kept at the node level - * - * @return Map of stats kept at the node level - */ - public Map> getNodeStats() { - return getClusterOrNodeStats(false); - } - - /** - * 
Get a map of the stats that are kept at the cluster level - * - * @return Map of stats kept at the cluster level - */ - public Map> getClusterStats() { - return getClusterOrNodeStats(true); - } - - private Map> getClusterOrNodeStats(Boolean getClusterStats) { - return neuralStats; - // Map> statsMap = new HashMap<>(); - // - // for (Map.Entry> entry : NeuralStat.entrySet()) { - // if (entry.getValue().isClusterLevel() == getClusterStats) { - // statsMap.put(entry.getKey(), entry.getValue()); - // } - // } - // return statsMap; - } } diff --git a/src/main/java/org/opensearch/neuralsearch/stats/NeuralStatsInput.java b/src/main/java/org/opensearch/neuralsearch/stats/NeuralStatsInput.java index cabe177e38..8517821064 100644 --- a/src/main/java/org/opensearch/neuralsearch/stats/NeuralStatsInput.java +++ b/src/main/java/org/opensearch/neuralsearch/stats/NeuralStatsInput.java @@ -2,7 +2,6 @@ * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.neuralsearch.stats; import lombok.Builder; @@ -75,7 +74,9 @@ public NeuralStatsInput() { public NeuralStatsInput(StreamInput input) throws IOException { targetStatLevels = input.readBoolean() ? input.readEnumSet(NeuralStatLevel.class) : EnumSet.noneOf(NeuralStatLevel.class); - clusterLevelStats = input.readBoolean() ? input.readEnumSet(NeuralClusterLevelStat.class) : EnumSet.noneOf(NeuralClusterLevelStat.class); + clusterLevelStats = input.readBoolean() + ? input.readEnumSet(NeuralClusterLevelStat.class) + : EnumSet.noneOf(NeuralClusterLevelStat.class); nodeLevelStats = input.readBoolean() ? input.readEnumSet(NeuralNodeLevelStat.class) : EnumSet.noneOf(NeuralNodeLevelStat.class); nodeIds = input.readBoolean() ? 
new HashSet<>(input.readStringList()) : new HashSet<>(); } @@ -110,7 +111,12 @@ public static NeuralStatsInput parse(XContentParser parser) throws IOException { switch (fieldName) { case TARGET_STAT_LEVEL: - parseField(parser, targetStatLevels, input -> NeuralStatLevel.from(input.toUpperCase(Locale.ROOT)), NeuralStatLevel.class); + parseField( + parser, + targetStatLevels, + input -> NeuralStatLevel.from(input.toUpperCase(Locale.ROOT)), + NeuralStatLevel.class + ); break; case CLUSTER_LEVEL_STATS: parseField( @@ -136,8 +142,7 @@ public static NeuralStatsInput parse(XContentParser parser) throws IOException { break; } } - return NeuralStatsInput - .builder() + return NeuralStatsInput.builder() .targetStatLevels(targetStatLevels) .clusterLevelStats(clusterLevelStats) .nodeLevelStats(nodeLevelStats) @@ -172,7 +177,6 @@ public boolean retrieveAllNodeLevelStats() { return nodeLevelStats == null || nodeLevelStats.size() == 0; } - public boolean retrieveStatsOnAllNodes() { return nodeIds == null || nodeIds.size() == 0; } diff --git a/src/main/java/org/opensearch/neuralsearch/stats/StatNames.java b/src/main/java/org/opensearch/neuralsearch/stats/StatNames.java deleted file mode 100644 index d1cc9aa3be..0000000000 --- a/src/main/java/org/opensearch/neuralsearch/stats/StatNames.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ -package org.opensearch.neuralsearch.stats; - -import java.util.HashSet; -import java.util.Set; - -public enum StatNames { - NEURAL_QUERY_COUNT("neural_query_count"), - HYBRID_QUERY_COUNT("hybrid_query_count"); - - private String name; - - StatNames(String name) { - this.name = name; - } - - /** - * Get stat name - * - * @return name - */ - public String getName() { - return name; - } - - /** - * Get all stat names - * - * @return set of all stat names - */ - public static Set getNames() { - Set names = new HashSet<>(); - - for (StatNames statName : StatNames.values()) { - 
names.add(statName.getName()); - } - return names; - } -} diff --git a/src/main/java/org/opensearch/neuralsearch/stats/names/NeuralClusterLevelStat.java b/src/main/java/org/opensearch/neuralsearch/stats/names/NeuralClusterLevelStat.java index 6b9c56800f..ef2495f950 100644 --- a/src/main/java/org/opensearch/neuralsearch/stats/names/NeuralClusterLevelStat.java +++ b/src/main/java/org/opensearch/neuralsearch/stats/names/NeuralClusterLevelStat.java @@ -2,11 +2,9 @@ * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.neuralsearch.stats.names; -public enum NeuralClusterLevelStat -{ +public enum NeuralClusterLevelStat { FORCE_INFERENCE; public static NeuralClusterLevelStat from(String value) { diff --git a/src/main/java/org/opensearch/neuralsearch/stats/names/NeuralNodeLevelStat.java b/src/main/java/org/opensearch/neuralsearch/stats/names/NeuralNodeLevelStat.java index de42db3f8b..3d4f79e6f3 100644 --- a/src/main/java/org/opensearch/neuralsearch/stats/names/NeuralNodeLevelStat.java +++ b/src/main/java/org/opensearch/neuralsearch/stats/names/NeuralNodeLevelStat.java @@ -2,11 +2,9 @@ * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.neuralsearch.stats.names; -public enum NeuralNodeLevelStat -{ +public enum NeuralNodeLevelStat { NEURAL_QUERY_COUNT, HYBRID_QUERY_COUNT; diff --git a/src/main/java/org/opensearch/neuralsearch/stats/names/NeuralSearchProcessorLevelStat.java b/src/main/java/org/opensearch/neuralsearch/stats/names/NeuralSearchProcessorLevelStat.java new file mode 100644 index 0000000000..e5ab80a983 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/stats/names/NeuralSearchProcessorLevelStat.java @@ -0,0 +1,21 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.stats.names; + +public enum NeuralSearchProcessorLevelStat { + TEXT_EMBEDDING_EXECUTION_COUNT, + 
TEXT_IMAGE_EMBEDDING_EXECUTION_COUNT, + TEXT_IMAGE_EMBEDDING_IMAGE_INFERENCE_COUNT, + TEXT_IMAGE_EMBEDDING_TEXT_INFERENCE_COUNT, + TEXT_CHUNKING_EXECUTION_COUNT; + + public static NeuralSearchProcessorLevelStat from(String value) { + try { + return NeuralSearchProcessorLevelStat.valueOf(value); + } catch (Exception e) { + throw new IllegalArgumentException("No such neural cluster level stat"); + } + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/stats/names/NeuralStatLevel.java b/src/main/java/org/opensearch/neuralsearch/stats/names/NeuralStatLevel.java index e2cfa56a68..da693cf342 100644 --- a/src/main/java/org/opensearch/neuralsearch/stats/names/NeuralStatLevel.java +++ b/src/main/java/org/opensearch/neuralsearch/stats/names/NeuralStatLevel.java @@ -2,12 +2,13 @@ * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.neuralsearch.stats.names; public enum NeuralStatLevel { CLUSTER, - NODE; + NODE, + SEARCH_PROCESSOR, + INGEST_PROCESSOR; public static NeuralStatLevel from(String value) { try { @@ -16,4 +17,4 @@ public static NeuralStatLevel from(String value) { throw new IllegalArgumentException("No such neural stat level"); } } -} \ No newline at end of file +} diff --git a/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsNodeRequest.java b/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsNodeRequest.java index bc52c71586..0ef5205627 100644 --- a/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsNodeRequest.java +++ b/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsNodeRequest.java @@ -4,6 +4,7 @@ */ package org.opensearch.neuralsearch.transport; +import lombok.Getter; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.transport.TransportRequest; @@ -14,6 +15,7 @@ * NeuralStatsNodeRequest represents the request to an individual node */ public class 
NeuralStatsNodeRequest extends TransportRequest { + @Getter private NeuralStatsRequest request; /** @@ -43,15 +45,6 @@ public NeuralStatsNodeRequest(NeuralStatsRequest request) { this.request = request; } - /** - * Get NeuralStatsRequest - * - * @return NeuralStatsRequest for this node - */ - public NeuralStatsRequest getNeuralStatsRequest() { - return request; - } - @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); diff --git a/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsNodeResponse.java b/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsNodeResponse.java index b9e65a242a..ecee3658a4 100644 --- a/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsNodeResponse.java +++ b/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsNodeResponse.java @@ -13,6 +13,7 @@ import java.io.IOException; import java.util.Map; +import java.util.TreeMap; /** * NeuralStatsNodeResponse represents the responses generated by an individual node @@ -29,7 +30,7 @@ public class NeuralStatsNodeResponse extends BaseNodeResponse implements ToXCont */ public NeuralStatsNodeResponse(StreamInput in) throws IOException { super(in); - this.statsMap = in.readMap(StreamInput::readString, StreamInput::readGenericValue); + this.statsMap = new TreeMap<>(in.readMap(StreamInput::readString, StreamInput::readGenericValue)); } /** @@ -40,7 +41,7 @@ public NeuralStatsNodeResponse(StreamInput in) throws IOException { */ public NeuralStatsNodeResponse(DiscoveryNode node, Map statsToValues) { super(node); - this.statsMap = statsToValues; + this.statsMap = new TreeMap<>(statsToValues); } /** diff --git a/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsRequest.java b/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsRequest.java index 8f700d23ef..2cb5e5b33c 100644 --- a/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsRequest.java +++ 
b/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsRequest.java @@ -4,10 +4,10 @@ */ package org.opensearch.neuralsearch.transport; +import lombok.Getter; import org.opensearch.action.support.nodes.BaseNodesRequest; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.neuralsearch.stats.StatNames; import java.io.IOException; import java.util.HashSet; @@ -23,7 +23,7 @@ public class NeuralStatsRequest extends BaseNodesRequest { * Key indicating all stats should be retrieved */ public static final String ALL_STATS_KEY = "_all"; - private final Set validStats; + @Getter private final Set statsToBeRetrieved; /** @@ -31,7 +31,6 @@ public class NeuralStatsRequest extends BaseNodesRequest { */ public NeuralStatsRequest() { super((String[]) null); - validStats = StatNames.getNames(); statsToBeRetrieved = new HashSet<>(); } @@ -43,7 +42,6 @@ public NeuralStatsRequest() { */ public NeuralStatsRequest(StreamInput in) throws IOException { super(in); - validStats = in.readSet(StreamInput::readString); statsToBeRetrieved = in.readSet(StreamInput::readString); } @@ -54,50 +52,12 @@ public NeuralStatsRequest(StreamInput in) throws IOException { */ public NeuralStatsRequest(String... 
nodeIds) { super(nodeIds); - validStats = StatNames.getNames(); statsToBeRetrieved = new HashSet<>(); } - /** - * Add all stats to be retrieved - */ - public void all() { - statsToBeRetrieved.addAll(validStats); - } - - /** - * Remove all stats from retrieval set - */ - public void clear() { - statsToBeRetrieved.clear(); - } - - /** - * Sets a stats retrieval status to true if it is a valid stat - * @param stat stat name - * @return true if the stats's retrieval status is successfully update; false otherwise - */ - public boolean addStat(String stat) { - if (validStats.contains(stat)) { - statsToBeRetrieved.add(stat); - return true; - } - return false; - } - - /** - * Get the set that tracks which stats should be retrieved - * - * @return the set that contains the stat names marked for retrieval - */ - public Set getStatsToBeRetrieved() { - return statsToBeRetrieved; - } - @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - out.writeStringCollection(validStats); out.writeStringCollection(statsToBeRetrieved); } } diff --git a/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsResponse.java b/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsResponse.java index 3c44f9ad56..cee4032ad1 100644 --- a/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsResponse.java +++ b/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsResponse.java @@ -16,6 +16,7 @@ import java.io.IOException; import java.util.List; import java.util.Map; +import java.util.TreeMap; /** * NeuralStatsResponse consists of the aggregated responses from the nodes @@ -33,7 +34,7 @@ public class NeuralStatsResponse extends BaseNodesResponse(in.readMap()); } /** @@ -51,7 +52,7 @@ public NeuralStatsResponse( Map clusterStats ) { super(clusterName, nodes, failures); - this.clusterStats = clusterStats; + this.clusterStats = new TreeMap<>(clusterStats); } @Override diff --git 
a/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsTransportAction.java b/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsTransportAction.java index e792dba52e..5430149830 100644 --- a/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsTransportAction.java +++ b/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsTransportAction.java @@ -18,7 +18,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Set; /** * NeuralStatsTransportAction contains the logic to extract the stats from the nodes @@ -70,13 +69,11 @@ protected NeuralStatsResponse newResponse( ) { Map clusterStats = new HashMap<>(); - Set statsToBeRetrieved = request.getStatsToBeRetrieved(); - for (String statName : neuralStats.getClusterStats().keySet()) { - if (statsToBeRetrieved.contains(statName)) { - clusterStats.put(statName, neuralStats.getStats().get(statName).getValue()); - } + for (String statName : neuralStats.getStats().keySet()) { + clusterStats.put(statName, neuralStats.getStats().get(statName).getValue()); } + System.out.println(clusterStats); return new NeuralStatsResponse(clusterService.getClusterName(), responses, failures, clusterStats); } @@ -93,17 +90,10 @@ protected NeuralStatsNodeResponse newNodeResponse(StreamInput in) throws IOExcep @Override protected NeuralStatsNodeResponse nodeOperation(NeuralStatsNodeRequest request) { - return createNeuralStatsNodeResponse(request.getNeuralStatsRequest()); - } - - private NeuralStatsNodeResponse createNeuralStatsNodeResponse(NeuralStatsRequest neuralStatsRequest) { Map statValues = new HashMap<>(); - Set statsToBeRetrieved = neuralStatsRequest.getStatsToBeRetrieved(); - for (String statName : neuralStats.getNodeStats().keySet()) { - if (statsToBeRetrieved.contains(statName)) { - statValues.put(statName, neuralStats.getStats().get(statName).getValue()); - } + for (String statName : neuralStats.getStats().keySet()) { + statValues.put(statName, 
neuralStats.getStats().get(statName).getValue()); } return new NeuralStatsNodeResponse(clusterService.localNode(), statValues); } diff --git a/src/test/java/org/opensearch/neuralsearch/rest/RestNeuralStatsHandlerIT.java b/src/test/java/org/opensearch/neuralsearch/rest/RestNeuralStatsHandlerIT.java index bb1bf42181..64040a4322 100644 --- a/src/test/java/org/opensearch/neuralsearch/rest/RestNeuralStatsHandlerIT.java +++ b/src/test/java/org/opensearch/neuralsearch/rest/RestNeuralStatsHandlerIT.java @@ -32,8 +32,11 @@ public void test_happyCase() throws Exception { Response response = executeNeuralStatRequest(new ArrayList<>(), new ArrayList<>()); String responseBody = EntityUtils.toString(response.getEntity()); Map clusterStats = parseStatsResponse(responseBody); - log.info(clusterStats.toString()); - assertEquals("Bratwurst", (String) clusterStats.get("Sushi")); + log.info("rest_it_pasta"); + for (Map.Entry entry : clusterStats.entrySet()) { + log.info(entry.toString(), entry.getValue()); + } + assertEquals("Sushi", (String) clusterStats.get("Bratwurst")); } protected Response executeNeuralStatRequest(List nodeIds, List stats) throws IOException { From 60a78df73b3ef552c2c8288b43a21fac33cddae2 Mon Sep 17 00:00:00 2001 From: Andy Qin Date: Thu, 6 Feb 2025 12:53:57 -0800 Subject: [PATCH 3/4] Refactor stats, add text embedding stats, add integ tests Signed-off-by: Andy Qin --- .../neuralsearch/plugin/NeuralSearch.java | 7 +- .../processor/TextEmbeddingProcessor.java | 4 + .../rest/RestNeuralStatsHandler.java | 65 +++++--- .../neuralsearch/stats/DerivedStats.java | 73 +++++++++ .../neuralsearch/stats/NeuralStat.java | 20 +++ .../neuralsearch/stats/NeuralStatBuilder.java | 53 +++++++ .../neuralsearch/stats/NeuralStats.java | 54 +++++-- .../neuralsearch/stats/NeuralStatsInput.java | 110 +------------ .../stats/names/NeuralClusterLevelStat.java | 17 -- .../stats/names/NeuralNodeLevelStat.java | 18 --- .../names/NeuralSearchProcessorLevelStat.java | 21 --- 
.../stats/names/NeuralStatLevel.java | 20 --- .../neuralsearch/stats/names/StatName.java | 48 ++++++ .../neuralsearch/stats/names/StatType.java | 11 ++ .../stats/suppliers/CounterSupplier.java | 41 +++++ .../stats/suppliers/DerivedSupplier.java | 31 ++++ .../transport/ClearNeuralStatsAction.java | 19 +++ .../ClearNeuralStatsNodeRequest.java | 53 +++++++ .../ClearNeuralStatsNodeResponse.java | 41 +++++ .../transport/ClearNeuralStatsRequest.java | 31 ++++ .../transport/ClearNeuralStatsResponse.java | 62 ++++++++ .../ClearNeuralStatsTransportAction.java | 80 ++++++++++ .../transport/NeuralStatsNodeResponse.java | 20 +-- .../transport/NeuralStatsRequest.java | 17 +- .../transport/NeuralStatsResponse.java | 44 +++++- .../transport/NeuralStatsTransportAction.java | 20 ++- .../rest/RestNeuralStatsHandlerIT.java | 147 ++++++++++++++++-- 27 files changed, 863 insertions(+), 264 deletions(-) create mode 100644 src/main/java/org/opensearch/neuralsearch/stats/DerivedStats.java create mode 100644 src/main/java/org/opensearch/neuralsearch/stats/NeuralStatBuilder.java delete mode 100644 src/main/java/org/opensearch/neuralsearch/stats/names/NeuralClusterLevelStat.java delete mode 100644 src/main/java/org/opensearch/neuralsearch/stats/names/NeuralNodeLevelStat.java delete mode 100644 src/main/java/org/opensearch/neuralsearch/stats/names/NeuralSearchProcessorLevelStat.java delete mode 100644 src/main/java/org/opensearch/neuralsearch/stats/names/NeuralStatLevel.java create mode 100644 src/main/java/org/opensearch/neuralsearch/stats/names/StatName.java create mode 100644 src/main/java/org/opensearch/neuralsearch/stats/names/StatType.java create mode 100644 src/main/java/org/opensearch/neuralsearch/stats/suppliers/CounterSupplier.java create mode 100644 src/main/java/org/opensearch/neuralsearch/stats/suppliers/DerivedSupplier.java create mode 100644 src/main/java/org/opensearch/neuralsearch/transport/ClearNeuralStatsAction.java create mode 100644 
src/main/java/org/opensearch/neuralsearch/transport/ClearNeuralStatsNodeRequest.java create mode 100644 src/main/java/org/opensearch/neuralsearch/transport/ClearNeuralStatsNodeResponse.java create mode 100644 src/main/java/org/opensearch/neuralsearch/transport/ClearNeuralStatsRequest.java create mode 100644 src/main/java/org/opensearch/neuralsearch/transport/ClearNeuralStatsResponse.java create mode 100644 src/main/java/org/opensearch/neuralsearch/transport/ClearNeuralStatsTransportAction.java diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index 68aff564c0..59a6dfc9eb 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -64,6 +64,8 @@ import org.opensearch.neuralsearch.query.ext.RerankSearchExtBuilder; import org.opensearch.neuralsearch.rest.RestNeuralStatsHandler; import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher; +import org.opensearch.neuralsearch.transport.ClearNeuralStatsAction; +import org.opensearch.neuralsearch.transport.ClearNeuralStatsTransportAction; import org.opensearch.neuralsearch.transport.NeuralStatsAction; import org.opensearch.neuralsearch.transport.NeuralStatsTransportAction; import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil; @@ -146,7 +148,10 @@ public List getRestHandlers( @Override public List> getActions() { - return Arrays.asList(new ActionHandler<>(NeuralStatsAction.INSTANCE, NeuralStatsTransportAction.class)); + return Arrays.asList( + new ActionHandler<>(NeuralStatsAction.INSTANCE, NeuralStatsTransportAction.class), + new ActionHandler<>(ClearNeuralStatsAction.INSTANCE, ClearNeuralStatsTransportAction.class) + ); } @Override diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java 
index c8f9f080d8..ff6f69bac6 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java @@ -16,6 +16,8 @@ import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import lombok.extern.log4j.Log4j2; +import org.opensearch.neuralsearch.stats.NeuralStats; +import org.opensearch.neuralsearch.stats.names.StatName; /** * This processor is used for user input data text embedding processing, model_id can be used to indicate which model user use, @@ -47,6 +49,8 @@ public void doExecute( List inferenceList, BiConsumer handler ) { + NeuralStats.record(StatName.TEXT_EMBEDDING_PROCESSOR_EXECUTIONS).increment(); + mlCommonsClientAccessor.inferenceSentences(this.modelId, inferenceList, ActionListener.wrap(vectors -> { setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors); handler.accept(ingestDocument, null); diff --git a/src/main/java/org/opensearch/neuralsearch/rest/RestNeuralStatsHandler.java b/src/main/java/org/opensearch/neuralsearch/rest/RestNeuralStatsHandler.java index dfb5d867af..df3eefb904 100644 --- a/src/main/java/org/opensearch/neuralsearch/rest/RestNeuralStatsHandler.java +++ b/src/main/java/org/opensearch/neuralsearch/rest/RestNeuralStatsHandler.java @@ -9,6 +9,9 @@ import lombok.extern.log4j.Log4j2; import org.apache.commons.lang.StringUtils; import org.opensearch.client.node.NodeClient; +import org.opensearch.neuralsearch.stats.NeuralStatsInput; +import org.opensearch.neuralsearch.transport.ClearNeuralStatsAction; +import org.opensearch.neuralsearch.transport.ClearNeuralStatsRequest; import org.opensearch.neuralsearch.transport.NeuralStatsAction; import org.opensearch.neuralsearch.transport.NeuralStatsRequest; import org.opensearch.rest.BaseRestHandler; @@ -19,7 +22,6 @@ import java.util.HashSet; import java.util.List; import java.util.Set; -import java.util.TreeSet; import static 
org.opensearch.neuralsearch.plugin.NeuralSearch.NEURAL_BASE_URI; @@ -27,6 +29,7 @@ @AllArgsConstructor public class RestNeuralStatsHandler extends BaseRestHandler { private static final String NAME = "neural_stats_action"; + public static final String CLEAR_PARAM = "_clear"; @Override public String getName() { @@ -45,14 +48,33 @@ public List routes() { @Override protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) { - // Read inputs and convert to BaseNodesRequest with correct info configured - NeuralStatsRequest neuralStatsRequest = getRequest(request); + if (request.param("stat", "").equals(CLEAR_PARAM)) { + // TODO : Hacky, possible collisions. Should be refactored into separate endpoint later + String[] nodeIdsArr = null; + String nodesIdsStr = request.param("nodeId"); + if (StringUtils.isNotEmpty(nodesIdsStr)) { + nodeIdsArr = nodesIdsStr.split(","); + } + + ClearNeuralStatsRequest clearNeuralStatsRequest = new ClearNeuralStatsRequest(nodeIdsArr); + clearNeuralStatsRequest.timeout(request.param("timeout")); + + return channel -> client.execute( + ClearNeuralStatsAction.INSTANCE, + clearNeuralStatsRequest, + new RestActions.NodesResponseRestListener<>(channel) + ); + } else { + // Read inputs and convert to BaseNodesRequest with correct info configured + NeuralStatsRequest neuralStatsRequest = getRequest(request); + + return channel -> client.execute( + NeuralStatsAction.INSTANCE, + neuralStatsRequest, + new RestActions.NodesResponseRestListener<>(channel) + ); + } - return channel -> client.execute( - NeuralStatsAction.INSTANCE, - neuralStatsRequest, - new RestActions.NodesResponseRestListener<>(channel) - ); } /** @@ -69,7 +91,7 @@ private NeuralStatsRequest getRequest(RestRequest request) { nodeIdsArr = nodesIdsStr.split(","); } - NeuralStatsRequest neuralStatsRequest = new NeuralStatsRequest(nodeIdsArr); + NeuralStatsRequest neuralStatsRequest = new NeuralStatsRequest(nodeIdsArr, new NeuralStatsInput()); 
neuralStatsRequest.timeout(request.param("timeout")); // parse the stats the customer wants to see @@ -86,22 +108,23 @@ private NeuralStatsRequest getRequest(RestRequest request) { } else if (statsSet.contains(NeuralStatsRequest.ALL_STATS_KEY)) { throw new IllegalArgumentException("Request " + request.path() + " contains _all and individual stats"); } else { - Set invalidStats = new TreeSet<>(); - for (String stat : statsSet) { - // validate request contains valid stats - // if (!neuralStatsRequest.addStat(stat)) { - // invalidStats.add(stat); - // } - } - - if (!invalidStats.isEmpty()) { - throw new IllegalArgumentException(unrecognized(request, invalidStats, neuralStatsRequest.getStatsToBeRetrieved(), "stat")); - } + // Validate NeuralStats input is valid + // Set invalidStats = new TreeSet<>(); + // for (String stat : statsSet) { + // // validate request contains valid stats + // // if (!neuralStatsRequest.addStat(stat)) { + // // invalidStats.add(stat); + // // } + // } + // + // if (!invalidStats.isEmpty()) { + // throw new IllegalArgumentException(unrecognized(request, invalidStats, neuralStatsRequest.getNeuralStatsInput(), "stat")); + // } } log.info("pasta"); log.info(statsSet); - log.info(neuralStatsRequest.getStatsToBeRetrieved()); + log.info(neuralStatsRequest.getNeuralStatsInput()); return neuralStatsRequest; } } diff --git a/src/main/java/org/opensearch/neuralsearch/stats/DerivedStats.java b/src/main/java/org/opensearch/neuralsearch/stats/DerivedStats.java new file mode 100644 index 0000000000..4dcc5857f8 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/stats/DerivedStats.java @@ -0,0 +1,73 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.stats; + +import org.opensearch.neuralsearch.stats.suppliers.DerivedSupplier; +import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import 
java.util.TreeMap; +import java.util.concurrent.ConcurrentSkipListMap; +import java.util.function.Function; + +public class DerivedStats { + private static final String AGG_KEY_PREFIX = "all_nodes."; + private static DerivedStats INSTANCE; + + public static DerivedStats instance() { + if (INSTANCE == null) { + INSTANCE = new DerivedStats(); + } + return INSTANCE; + } + + private final Map> derivedStatsMap; + private Map aggregatedNodeResponse; + + public DerivedStats() { + this.derivedStatsMap = new ConcurrentSkipListMap<>(); + register("derived.cluster_version", DerivedStats::clusterVersion); + } + + public Map aggregateNodesResponses(List> nodeResponses) { + Map summedMap = new HashMap<>(); + for (Map map : nodeResponses) { + for (Map.Entry entry : map.entrySet()) { + summedMap.merge(AGG_KEY_PREFIX + entry.getKey(), entry.getValue(), Long::sum); + } + } + return summedMap; + } + + public Map addDerivedStats(List> nodeResponses) { + // Reference to provide derived methods access to node Responses + this.aggregatedNodeResponse = aggregateNodesResponses(nodeResponses); + + Map computedDerivedStats = new TreeMap<>(); + for (Map.Entry> neuralStatEntry : derivedStatsMap.entrySet()) { + computedDerivedStats.put(neuralStatEntry.getKey(), neuralStatEntry.getValue().getValue()); + } + + computedDerivedStats.putAll(aggregatedNodeResponse); + // Reset reference to not store + this.aggregatedNodeResponse = null; + return computedDerivedStats; + } + + private void register(String statPath, Function derivedMethod) { + if (derivedStatsMap.containsKey(statPath)) { + // Validation error here + return; + } + NeuralStat neuralStat = new NeuralStat<>(new DerivedSupplier<>(this, derivedMethod)); + derivedStatsMap.put(statPath, neuralStat); + } + + private String clusterVersion() { + return NeuralSearchClusterUtil.instance().getClusterMinVersion().toString(); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/stats/NeuralStat.java 
b/src/main/java/org/opensearch/neuralsearch/stats/NeuralStat.java index 4acbee527e..e672b13bde 100644 --- a/src/main/java/org/opensearch/neuralsearch/stats/NeuralStat.java +++ b/src/main/java/org/opensearch/neuralsearch/stats/NeuralStat.java @@ -4,6 +4,8 @@ */ package org.opensearch.neuralsearch.stats; +import org.opensearch.neuralsearch.stats.suppliers.CounterSupplier; + import java.util.function.Supplier; public class NeuralStat { @@ -16,4 +18,22 @@ public NeuralStat(Supplier supplier) { public T getValue() { return supplier.get(); } + + /** + * Increments the supplier if it can be incremented + */ + public void increment() { + if (supplier instanceof CounterSupplier) { + ((CounterSupplier) supplier).increment(); + } + } + + /** + * Decrease the supplier if it can be decreased. + */ + public void decrement() { + if (supplier instanceof CounterSupplier) { + ((CounterSupplier) supplier).decrement(); + } + } } diff --git a/src/main/java/org/opensearch/neuralsearch/stats/NeuralStatBuilder.java b/src/main/java/org/opensearch/neuralsearch/stats/NeuralStatBuilder.java new file mode 100644 index 0000000000..950fe54614 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/stats/NeuralStatBuilder.java @@ -0,0 +1,53 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.stats; + +import org.opensearch.neuralsearch.stats.suppliers.CounterSupplier; + +import java.util.ArrayList; +import java.util.List; + +public class NeuralStatBuilder { + public static final String DELIMITER = "."; + private NeuralStats neuralStats; + private List path; + + NeuralStatBuilder(NeuralStats neuralStats) { + this.neuralStats = neuralStats; + this.path = new ArrayList<>(); + } + + public String getPathString() { + return String.join(DELIMITER, path); + } + + public NeuralStatBuilder ingestProcessor(String processor) { + // Add a validation check here + // Refactor to constants + path.add("ingest_processor"); +
path.add(processor); + return this; + } + + public NeuralStatBuilder searchProcessor(String processor) { + // Add a validation check here + // Refactor to constants + path.add("search_processor"); + path.add(processor); + return this; + } + + public NeuralStatBuilder metric(String metric) { + // should do some validation here + path.add(metric); + return this; + } + + public void increment() { + // Should do some extra validation here + // Some kind of "validate path" method that makes sure the path is valid? + neuralStats.getStats().computeIfAbsent(getPathString(), k -> new NeuralStat<>(new CounterSupplier())).increment(); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/stats/NeuralStats.java b/src/main/java/org/opensearch/neuralsearch/stats/NeuralStats.java index b10e6f0791..065ec2d0c3 100644 --- a/src/main/java/org/opensearch/neuralsearch/stats/NeuralStats.java +++ b/src/main/java/org/opensearch/neuralsearch/stats/NeuralStats.java @@ -4,19 +4,44 @@ */ package org.opensearch.neuralsearch.stats; +import com.google.common.annotations.VisibleForTesting; +import org.opensearch.neuralsearch.stats.names.StatName; +import org.opensearch.neuralsearch.stats.names.StatType; +import org.opensearch.neuralsearch.stats.suppliers.CounterSupplier; + +import java.util.EnumSet; import java.util.Map; import java.util.concurrent.ConcurrentSkipListMap; public class NeuralStats { - private final Map> neuralStats; + private Map> counterStatsMap; + + public static NeuralStats INSTANCE; + + public static NeuralStats instance() { + if (INSTANCE == null) { + INSTANCE = new NeuralStats(); + } + return INSTANCE; + } + + public static NeuralStatBuilder recordMetric() { + return new NeuralStatBuilder(instance()); + } + + public static NeuralStat record(StatName statName) { + return instance().getStats().computeIfAbsent(statName.getName(), k -> new NeuralStat<>(new CounterSupplier())); + } public NeuralStats() { - this.neuralStats = new ConcurrentSkipListMap<>(); -
neuralStats.put("ingest_processor.text_chunking.algorithm.delimiter.execution_count", new NeuralStat<>(() -> "10")); - neuralStats.put("Bratwurst", new NeuralStat<>(() -> "Sushi")); - neuralStats.put("ingest_processor.text_embedding.execution_count", new NeuralStat<>(() -> "3123")); - neuralStats.put("ingest_processor.text_chunking.execution_count", new NeuralStat<>(() -> "777")); - neuralStats.put("ingest_processor.text_chunking.algorithm.fixed_length.execution_count", new NeuralStat<>(() -> "32")); + this.counterStatsMap = new ConcurrentSkipListMap<>(); + + // Initialize event counter stats + for (StatName statName : EnumSet.allOf(StatName.class)) { + if (statName.getStatType() == StatType.COUNTER_EVENT) { + counterStatsMap.computeIfAbsent(statName.getName(), k -> new NeuralStat<>(new CounterSupplier())); + } + } } /** @@ -24,7 +49,18 @@ public NeuralStats() { * * @return all the stats */ - public Map> getStats() { - return neuralStats; + public Map> getStats() { + return counterStatsMap; + } + + @VisibleForTesting + public void resetStats() { + // Risk of memory leak? 
+ this.counterStatsMap = new ConcurrentSkipListMap<>(); + for (StatName statName : EnumSet.allOf(StatName.class)) { + if (statName.getStatType() == StatType.COUNTER_EVENT) { + counterStatsMap.computeIfAbsent(statName.getName(), k -> new NeuralStat<>(new CounterSupplier())); + } + } } } diff --git a/src/main/java/org/opensearch/neuralsearch/stats/NeuralStatsInput.java b/src/main/java/org/opensearch/neuralsearch/stats/NeuralStatsInput.java index 8517821064..c624e90d9c 100644 --- a/src/main/java/org/opensearch/neuralsearch/stats/NeuralStatsInput.java +++ b/src/main/java/org/opensearch/neuralsearch/stats/NeuralStatsInput.java @@ -13,14 +13,10 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.neuralsearch.stats.names.NeuralClusterLevelStat; -import org.opensearch.neuralsearch.stats.names.NeuralNodeLevelStat; -import org.opensearch.neuralsearch.stats.names.NeuralStatLevel; import java.io.IOException; import java.util.EnumSet; import java.util.HashSet; -import java.util.Locale; import java.util.Set; import java.util.function.Function; @@ -28,64 +24,28 @@ @Getter public class NeuralStatsInput implements ToXContentObject, Writeable { - public static final String TARGET_STAT_LEVEL = "target_stat_levels"; - public static final String CLUSTER_LEVEL_STATS = "cluster_level_stats"; - public static final String NODE_LEVEL_STATS = "node_level_stats"; public static final String NODE_IDS = "node_ids"; - /** - * Determines levels of stats to retrieve. If empty, will not retrieve any - */ - private EnumSet targetStatLevels; - /** - * Which cluster level stats will be retrieved. - */ - private EnumSet clusterLevelStats; - - /** - * Which node level stats will be retrieved. - */ - private EnumSet nodeLevelStats; - /** * Which node's stats will be retrieved. 
*/ private Set nodeIds; @Builder - public NeuralStatsInput( - EnumSet targetStatLevels, - EnumSet clusterLevelStats, - EnumSet nodeLevelStats, - Set nodeIds - ) { - this.targetStatLevels = targetStatLevels; - this.clusterLevelStats = clusterLevelStats; - this.nodeLevelStats = nodeLevelStats; + public NeuralStatsInput(Set nodeIds) { this.nodeIds = nodeIds; } public NeuralStatsInput() { - this.targetStatLevels = EnumSet.noneOf(NeuralStatLevel.class); - this.clusterLevelStats = EnumSet.noneOf(NeuralClusterLevelStat.class); - this.nodeLevelStats = EnumSet.noneOf(NeuralNodeLevelStat.class); this.nodeIds = new HashSet<>(); } public NeuralStatsInput(StreamInput input) throws IOException { - targetStatLevels = input.readBoolean() ? input.readEnumSet(NeuralStatLevel.class) : EnumSet.noneOf(NeuralStatLevel.class); - clusterLevelStats = input.readBoolean() - ? input.readEnumSet(NeuralClusterLevelStat.class) - : EnumSet.noneOf(NeuralClusterLevelStat.class); - nodeLevelStats = input.readBoolean() ? input.readEnumSet(NeuralNodeLevelStat.class) : EnumSet.noneOf(NeuralNodeLevelStat.class); nodeIds = input.readBoolean() ? 
new HashSet<>(input.readStringList()) : new HashSet<>(); } @Override public void writeTo(StreamOutput out) throws IOException { - writeEnumSet(out, targetStatLevels); - writeEnumSet(out, clusterLevelStats); - writeEnumSet(out, nodeLevelStats); out.writeOptionalStringCollection(nodeIds); } @@ -99,9 +59,6 @@ private void writeEnumSet(StreamOutput out, EnumSet set) throws IOException { } public static NeuralStatsInput parse(XContentParser parser) throws IOException { - EnumSet targetStatLevels = EnumSet.noneOf(NeuralStatLevel.class); - EnumSet clusterLevelStats = EnumSet.noneOf(NeuralClusterLevelStat.class); - EnumSet nodeLevelStats = EnumSet.noneOf(NeuralNodeLevelStat.class); Set nodeIds = new HashSet<>(); ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); @@ -110,30 +67,6 @@ public static NeuralStatsInput parse(XContentParser parser) throws IOException { parser.nextToken(); switch (fieldName) { - case TARGET_STAT_LEVEL: - parseField( - parser, - targetStatLevels, - input -> NeuralStatLevel.from(input.toUpperCase(Locale.ROOT)), - NeuralStatLevel.class - ); - break; - case CLUSTER_LEVEL_STATS: - parseField( - parser, - clusterLevelStats, - input -> NeuralClusterLevelStat.from(input.toUpperCase(Locale.ROOT)), - NeuralClusterLevelStat.class - ); - break; - case NODE_LEVEL_STATS: - parseField( - parser, - nodeLevelStats, - input -> NeuralNodeLevelStat.from(input.toUpperCase(Locale.ROOT)), - NeuralNodeLevelStat.class - ); - break; case NODE_IDS: parseArrayField(parser, nodeIds); break; @@ -142,26 +75,12 @@ public static NeuralStatsInput parse(XContentParser parser) throws IOException { break; } } - return NeuralStatsInput.builder() - .targetStatLevels(targetStatLevels) - .clusterLevelStats(clusterLevelStats) - .nodeLevelStats(nodeLevelStats) - .nodeIds(nodeIds) - .build(); + return NeuralStatsInput.builder().nodeIds(nodeIds).build(); } @Override public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws 
IOException { builder.startObject(); - if (targetStatLevels != null) { - builder.field(TARGET_STAT_LEVEL, targetStatLevels); - } - if (clusterLevelStats != null) { - builder.field(CLUSTER_LEVEL_STATS, clusterLevelStats); - } - if (nodeLevelStats != null) { - builder.field(NODE_LEVEL_STATS, nodeLevelStats); - } if (nodeIds != null) { builder.field(NODE_IDS, nodeIds); } @@ -169,35 +88,10 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par return builder; } - public boolean retrieveAllClusterLevelStats() { - return clusterLevelStats == null || clusterLevelStats.size() == 0; - } - - public boolean retrieveAllNodeLevelStats() { - return nodeLevelStats == null || nodeLevelStats.size() == 0; - } - public boolean retrieveStatsOnAllNodes() { return nodeIds == null || nodeIds.size() == 0; } - public boolean retrieveStat(Enum key) { - if (key instanceof NeuralClusterLevelStat) { - return retrieveAllClusterLevelStats() || clusterLevelStats.contains(key); - } - if (key instanceof NeuralNodeLevelStat) { - return retrieveAllNodeLevelStats() || nodeLevelStats.contains(key); - } - return false; - } - - public boolean onlyRetrieveClusterLevelStats() { - if (targetStatLevels == null || targetStatLevels.size() == 0) { - return false; - } - return !targetStatLevels.contains(NeuralStatLevel.NODE); - } - public static void parseArrayField(XContentParser parser, Set set) throws IOException { parseField(parser, set, null, String.class); } diff --git a/src/main/java/org/opensearch/neuralsearch/stats/names/NeuralClusterLevelStat.java b/src/main/java/org/opensearch/neuralsearch/stats/names/NeuralClusterLevelStat.java deleted file mode 100644 index ef2495f950..0000000000 --- a/src/main/java/org/opensearch/neuralsearch/stats/names/NeuralClusterLevelStat.java +++ /dev/null @@ -1,17 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ -package org.opensearch.neuralsearch.stats.names; - -public enum NeuralClusterLevelStat { 
- FORCE_INFERENCE; - - public static NeuralClusterLevelStat from(String value) { - try { - return NeuralClusterLevelStat.valueOf(value); - } catch (Exception e) { - throw new IllegalArgumentException("No such neural cluster level stat"); - } - } -} diff --git a/src/main/java/org/opensearch/neuralsearch/stats/names/NeuralNodeLevelStat.java b/src/main/java/org/opensearch/neuralsearch/stats/names/NeuralNodeLevelStat.java deleted file mode 100644 index 3d4f79e6f3..0000000000 --- a/src/main/java/org/opensearch/neuralsearch/stats/names/NeuralNodeLevelStat.java +++ /dev/null @@ -1,18 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ -package org.opensearch.neuralsearch.stats.names; - -public enum NeuralNodeLevelStat { - NEURAL_QUERY_COUNT, - HYBRID_QUERY_COUNT; - - public static NeuralNodeLevelStat from(String value) { - try { - return NeuralNodeLevelStat.valueOf(value); - } catch (Exception e) { - throw new IllegalArgumentException("No such neural node level stat"); - } - } -} diff --git a/src/main/java/org/opensearch/neuralsearch/stats/names/NeuralSearchProcessorLevelStat.java b/src/main/java/org/opensearch/neuralsearch/stats/names/NeuralSearchProcessorLevelStat.java deleted file mode 100644 index e5ab80a983..0000000000 --- a/src/main/java/org/opensearch/neuralsearch/stats/names/NeuralSearchProcessorLevelStat.java +++ /dev/null @@ -1,21 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ -package org.opensearch.neuralsearch.stats.names; - -public enum NeuralSearchProcessorLevelStat { - TEXT_EMBEDDING_EXECUTION_COUNT, - TEXT_IMAGE_EMBEDDING_EXECUTION_COUNT, - TEXT_IMAGE_EMBEDDING_IMAGE_INFERENCE_COUNT, - TEXT_IMAGE_EMBEDDING_TEXT_INFERENCE_COUNT, - TEXT_CHUNKING_EXECUTION_COUNT; - - public static NeuralSearchProcessorLevelStat from(String value) { - try { - return NeuralSearchProcessorLevelStat.valueOf(value); - } catch (Exception e) { - throw new IllegalArgumentException("No such 
neural cluster level stat"); - } - } -} diff --git a/src/main/java/org/opensearch/neuralsearch/stats/names/NeuralStatLevel.java b/src/main/java/org/opensearch/neuralsearch/stats/names/NeuralStatLevel.java deleted file mode 100644 index da693cf342..0000000000 --- a/src/main/java/org/opensearch/neuralsearch/stats/names/NeuralStatLevel.java +++ /dev/null @@ -1,20 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ -package org.opensearch.neuralsearch.stats.names; - -public enum NeuralStatLevel { - CLUSTER, - NODE, - SEARCH_PROCESSOR, - INGEST_PROCESSOR; - - public static NeuralStatLevel from(String value) { - try { - return NeuralStatLevel.valueOf(value); - } catch (Exception e) { - throw new IllegalArgumentException("No such neural stat level"); - } - } -} diff --git a/src/main/java/org/opensearch/neuralsearch/stats/names/StatName.java b/src/main/java/org/opensearch/neuralsearch/stats/names/StatName.java new file mode 100644 index 0000000000..1d37480550 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/stats/names/StatName.java @@ -0,0 +1,48 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.stats.names; + +import lombok.Getter; + +import java.util.HashSet; +import java.util.Set; + +public enum StatName { + // TODO : stat type not currently used + EVENT_STAT("example.counter", StatType.COUNTER_EVENT), + INFO_DERIVED_STAT("example.counter", StatType.INFO_DERIVED), + STAT_DERIVED_STAT("example.counter", StatType.STAT_DERIVED), + + TEXT_EMBEDDING_PROCESSOR_EXECUTIONS("ingest_processor.text_embedding.executions", StatType.COUNTER_EVENT); + + @Getter + private final String name; + @Getter + private final StatType statType; + + StatName(String name, StatType statType) { + this.name = name; + this.statType = statType; + } + + /** + * Get all stat names + * + * @return set of all stat names + */ + public static Set getNames() { + Set names = new 
HashSet<>(); + + for (StatName statName : StatName.values()) { + names.add(statName.getName()); + } + return names; + } + + @Override + public String toString() { + return getName(); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/stats/names/StatType.java b/src/main/java/org/opensearch/neuralsearch/stats/names/StatType.java new file mode 100644 index 0000000000..e63d4c3b88 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/stats/names/StatType.java @@ -0,0 +1,11 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.stats.names; + +public enum StatType { + INFO_DERIVED, + STAT_DERIVED, + COUNTER_EVENT; +} diff --git a/src/main/java/org/opensearch/neuralsearch/stats/suppliers/CounterSupplier.java b/src/main/java/org/opensearch/neuralsearch/stats/suppliers/CounterSupplier.java new file mode 100644 index 0000000000..0d3d81f2db --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/stats/suppliers/CounterSupplier.java @@ -0,0 +1,41 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.stats.suppliers; + +import java.util.concurrent.atomic.LongAdder; +import java.util.function.Supplier; + +/** + * CounterSupplier provides a stateful count as the value + */ +public class CounterSupplier implements Supplier { + private LongAdder counter; + + /** + * Constructor + */ + public CounterSupplier() { + this.counter = new LongAdder(); + } + + @Override + public Long get() { + return counter.longValue(); + } + + /** + * Increments the value of the counter by 1 + */ + public void increment() { + counter.increment(); + } + + /** + * Decrease the value of the counter by 1 + */ + public void decrement() { + counter.decrement(); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/stats/suppliers/DerivedSupplier.java b/src/main/java/org/opensearch/neuralsearch/stats/suppliers/DerivedSupplier.java new 
file mode 100644 index 0000000000..d89c863398 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/stats/suppliers/DerivedSupplier.java @@ -0,0 +1,31 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.stats.suppliers; + +import org.opensearch.neuralsearch.stats.DerivedStats; + +import java.util.function.Function; +import java.util.function.Supplier; + +/** + * DerivedSupplier provides a derived stat value + */ +public class DerivedSupplier implements Supplier { + private DerivedStats derivedStats; + private Function getter; + + /** + * Constructor + */ + public DerivedSupplier(DerivedStats derivedStats, Function getter) { + this.derivedStats = derivedStats; + this.getter = getter; + } + + @Override + public T get() { + return getter.apply(derivedStats); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/transport/ClearNeuralStatsAction.java b/src/main/java/org/opensearch/neuralsearch/transport/ClearNeuralStatsAction.java new file mode 100644 index 0000000000..187859b968 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/transport/ClearNeuralStatsAction.java @@ -0,0 +1,19 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.transport; + +import org.opensearch.action.ActionType; + +public class ClearNeuralStatsAction extends ActionType { + public static final ClearNeuralStatsAction INSTANCE = new ClearNeuralStatsAction(); + public static final String NAME = "cluster:admin/clear_neural_stats_action"; // TODO : figure this out + + /** + * Constructor + */ + private ClearNeuralStatsAction() { + super(NAME, ClearNeuralStatsResponse::new); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/transport/ClearNeuralStatsNodeRequest.java b/src/main/java/org/opensearch/neuralsearch/transport/ClearNeuralStatsNodeRequest.java new file mode 100644 index 0000000000..e118f32830 --- /dev/null +++
b/src/main/java/org/opensearch/neuralsearch/transport/ClearNeuralStatsNodeRequest.java @@ -0,0 +1,53 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.transport; + +import lombok.Getter; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.transport.TransportRequest; + +import java.io.IOException; + +/** + * ClearNeuralStatsNodeRequest represents the request to an individual node + */ +public class ClearNeuralStatsNodeRequest extends TransportRequest { + @Getter + private ClearNeuralStatsRequest request; + + /** + * Constructor + */ + public ClearNeuralStatsNodeRequest() { + super(); + } + + /** + * Constructor + * + * @param in input stream + * @throws IOException in case of I/O errors + */ + public ClearNeuralStatsNodeRequest(StreamInput in) throws IOException { + super(in); + request = new ClearNeuralStatsRequest(in); + } + + /** + * Constructor + * + * @param request ClearNeuralStatsRequest + */ + public ClearNeuralStatsNodeRequest(ClearNeuralStatsRequest request) { + this.request = request; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + request.writeTo(out); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/transport/ClearNeuralStatsNodeResponse.java b/src/main/java/org/opensearch/neuralsearch/transport/ClearNeuralStatsNodeResponse.java new file mode 100644 index 0000000000..d9f84d7fe3 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/transport/ClearNeuralStatsNodeResponse.java @@ -0,0 +1,41 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.transport; + +import org.opensearch.action.support.nodes.BaseNodeResponse; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.core.common.io.stream.StreamInput; +import 
org.opensearch.core.common.io.stream.StreamOutput; + +import java.io.IOException; + +/** + * ClearNeuralStatsNodeResponse represents the responses generated by an individual node + */ +public class ClearNeuralStatsNodeResponse extends BaseNodeResponse { + /** + * Constructor + * + * @param in stream + * @throws IOException in case of I/O errors + */ + public ClearNeuralStatsNodeResponse(StreamInput in) throws IOException { + super(in); + } + + /** + * Constructor + * + * @param node node + */ + public ClearNeuralStatsNodeResponse(DiscoveryNode node) { + super(node); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/transport/ClearNeuralStatsRequest.java b/src/main/java/org/opensearch/neuralsearch/transport/ClearNeuralStatsRequest.java new file mode 100644 index 0000000000..d8ce5eb9d4 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/transport/ClearNeuralStatsRequest.java @@ -0,0 +1,31 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.transport; + +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.action.support.nodes.BaseNodesRequest; +import org.opensearch.core.common.io.stream.StreamInput; + +import java.io.IOException; + +public class ClearNeuralStatsRequest extends BaseNodesRequest { + public ClearNeuralStatsRequest(StreamInput in) throws IOException { + super(in); + } + + /** + * Constructor + * + * @param nodeIds NodeIDs from which to retrieve stats + */ + public ClearNeuralStatsRequest(String[] nodeIds) { + super(nodeIds); + } + + @Override + public ActionRequestValidationException validate() { + return null; + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/transport/ClearNeuralStatsResponse.java b/src/main/java/org/opensearch/neuralsearch/transport/ClearNeuralStatsResponse.java new file mode 100644 index 
0000000000..0f0b0fd703 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/transport/ClearNeuralStatsResponse.java @@ -0,0 +1,62 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.transport; + +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.nodes.BaseNodesResponse; +import org.opensearch.cluster.ClusterName; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.List; + +public class ClearNeuralStatsResponse extends BaseNodesResponse implements ToXContentObject { + private static final String NODES_KEY = "nodes"; + + /** + * Constructor + * + * @param in StreamInput + * @throws IOException thrown when unable to read from stream + */ + public ClearNeuralStatsResponse(StreamInput in) throws IOException { + super(new ClusterName(in), in.readList(ClearNeuralStatsNodeResponse::new), in.readList(FailedNodeException::new)); + } + + /** + * Constructor + * + * @param clusterName name of cluster + * @param nodes List of NeuralStatsNodeResponses + * @param failures List of failures from nodes + */ + public ClearNeuralStatsResponse(ClusterName clusterName, List nodes, List failures) { + super(clusterName, nodes, failures); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + } + + @Override + public void writeNodesTo(StreamOutput out, List nodes) throws IOException { + out.writeList(nodes); + } + + @Override + public List readNodesFrom(StreamInput in) throws IOException { + return in.readList(ClearNeuralStatsNodeResponse::new); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + // TODO : response should go here + return 
builder; + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/transport/ClearNeuralStatsTransportAction.java b/src/main/java/org/opensearch/neuralsearch/transport/ClearNeuralStatsTransportAction.java new file mode 100644 index 0000000000..abc2b3dbad --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/transport/ClearNeuralStatsTransportAction.java @@ -0,0 +1,80 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.transport; + +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.nodes.TransportNodesAction; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.neuralsearch.stats.NeuralStats; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +import java.io.IOException; +import java.util.List; + +/** + * ClearNeuralStatsTransportAction contains the logic to clear all nodes stats + */ +public class ClearNeuralStatsTransportAction extends TransportNodesAction< + ClearNeuralStatsRequest, + ClearNeuralStatsResponse, + ClearNeuralStatsNodeRequest, + ClearNeuralStatsNodeResponse> { + /** + * Constructor + * + * @param threadPool ThreadPool to use + * @param clusterService ClusterService + * @param transportService TransportService + * @param actionFilters Action Filters + */ + @Inject + public ClearNeuralStatsTransportAction( + ThreadPool threadPool, + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters + ) { + super( + ClearNeuralStatsAction.NAME, + threadPool, + clusterService, + transportService, + actionFilters, + ClearNeuralStatsRequest::new, + ClearNeuralStatsNodeRequest::new, + ThreadPool.Names.MANAGEMENT, + ClearNeuralStatsNodeResponse.class + ); + } + + @Override + 
protected ClearNeuralStatsResponse newResponse( + ClearNeuralStatsRequest request, + List responses, + List failures + ) { + return new ClearNeuralStatsResponse(clusterService.getClusterName(), responses, failures); + } + + @Override + protected ClearNeuralStatsNodeRequest newNodeRequest(ClearNeuralStatsRequest request) { + return new ClearNeuralStatsNodeRequest(request); + } + + @Override + protected ClearNeuralStatsNodeResponse newNodeResponse(StreamInput in) throws IOException { + return new ClearNeuralStatsNodeResponse(in); + } + + @Override + protected ClearNeuralStatsNodeResponse nodeOperation(ClearNeuralStatsNodeRequest request) { + NeuralStats.instance().resetStats(); + return new ClearNeuralStatsNodeResponse(clusterService.localNode()); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsNodeResponse.java b/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsNodeResponse.java index ecee3658a4..b9afc87271 100644 --- a/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsNodeResponse.java +++ b/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsNodeResponse.java @@ -4,6 +4,7 @@ */ package org.opensearch.neuralsearch.transport; +import lombok.Getter; import org.opensearch.action.support.nodes.BaseNodeResponse; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.core.common.io.stream.StreamInput; @@ -20,7 +21,8 @@ */ public class NeuralStatsNodeResponse extends BaseNodeResponse implements ToXContentFragment { - private Map statsMap; + @Getter + private Map statsMap; /** * Constructor @@ -30,7 +32,7 @@ public class NeuralStatsNodeResponse extends BaseNodeResponse implements ToXCont */ public NeuralStatsNodeResponse(StreamInput in) throws IOException { super(in); - this.statsMap = new TreeMap<>(in.readMap(StreamInput::readString, StreamInput::readGenericValue)); + this.statsMap = new TreeMap<>(in.readMap(StreamInput::readString, StreamInput::readLong)); } /** @@ -39,7 +41,7 @@ 
public NeuralStatsNodeResponse(StreamInput in) throws IOException { * @param node node * @param statsToValues mapping of stat name to value */ - public NeuralStatsNodeResponse(DiscoveryNode node, Map statsToValues) { + public NeuralStatsNodeResponse(DiscoveryNode node, Map statsToValues) { super(node); this.statsMap = new TreeMap<>(statsToValues); } @@ -56,19 +58,10 @@ public static NeuralStatsNodeResponse readStats(StreamInput in) throws IOExcepti return neuralStats; } - /** - * Get the map of stats - * - * @return map of stats - */ - public Map getStatsMap() { - return statsMap; - } - @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - out.writeMap(statsMap, StreamOutput::writeString, StreamOutput::writeGenericValue); + out.writeMap(statsMap, StreamOutput::writeString, StreamOutput::writeLong); } /** @@ -83,7 +76,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws for (String stat : statsMap.keySet()) { builder.field(stat, statsMap.get(stat)); } - return builder; } } diff --git a/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsRequest.java b/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsRequest.java index 2cb5e5b33c..ecd40ee510 100644 --- a/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsRequest.java +++ b/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsRequest.java @@ -8,13 +8,12 @@ import org.opensearch.action.support.nodes.BaseNodesRequest; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.neuralsearch.stats.NeuralStatsInput; import java.io.IOException; -import java.util.HashSet; -import java.util.Set; /** - * NeuralStatsRequest gets node (cluster) level Stats for KNN + * NeuralStatsRequest gets node (cluster) level Stats for Neural * By default, all parameters will be true */ public class NeuralStatsRequest extends BaseNodesRequest { @@ 
-24,14 +23,14 @@ public class NeuralStatsRequest extends BaseNodesRequest { */ public static final String ALL_STATS_KEY = "_all"; @Getter - private final Set statsToBeRetrieved; + private final NeuralStatsInput neuralStatsInput; /** * Empty constructor needed for NeuralStatsTransportAction */ public NeuralStatsRequest() { super((String[]) null); - statsToBeRetrieved = new HashSet<>(); + this.neuralStatsInput = new NeuralStatsInput(); } /** @@ -42,7 +41,7 @@ public NeuralStatsRequest() { */ public NeuralStatsRequest(StreamInput in) throws IOException { super(in); - statsToBeRetrieved = in.readSet(StreamInput::readString); + this.neuralStatsInput = new NeuralStatsInput(in); } /** @@ -50,14 +49,14 @@ public NeuralStatsRequest(StreamInput in) throws IOException { * * @param nodeIds NodeIDs from which to retrieve stats */ - public NeuralStatsRequest(String... nodeIds) { + public NeuralStatsRequest(String[] nodeIds, NeuralStatsInput neuralStatsInput) { super(nodeIds); - statsToBeRetrieved = new HashSet<>(); + this.neuralStatsInput = neuralStatsInput; } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - out.writeStringCollection(statsToBeRetrieved); + neuralStatsInput.writeTo(out); } } diff --git a/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsResponse.java b/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsResponse.java index cee4032ad1..40193cd0f9 100644 --- a/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsResponse.java +++ b/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsResponse.java @@ -14,6 +14,7 @@ import org.opensearch.core.xcontent.XContentBuilder; import java.io.IOException; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.TreeMap; @@ -74,24 +75,51 @@ public List readNodesFrom(StreamInput in) throws IOExce @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { // Return 
cluster level stats - for (Map.Entry clusterStat : clusterStats.entrySet()) { - builder.field(clusterStat.getKey(), clusterStat.getValue()); - } + Map nestedClusterStats = convertFlatToNestedMap(new TreeMap<>(clusterStats)); + buildNestedMapXContent(builder, nestedClusterStats); // Return node level stats String nodeId; DiscoveryNode node; builder.startObject(NODES_KEY); - for (NeuralStatsNodeResponse neuralStats : getNodes()) { - node = neuralStats.getNode(); + for (NeuralStatsNodeResponse neuralStatsResponse : getNodes()) { + node = neuralStatsResponse.getNode(); nodeId = node.getId(); builder.startObject(nodeId); - neuralStats.toXContent(builder, params); + Map nestedMap = convertFlatToNestedMap(new TreeMap<>(neuralStatsResponse.getStatsMap())); + buildNestedMapXContent(builder, nestedMap); builder.endObject(); - System.out.println("Timothy"); - System.out.println(neuralStats.getStatsMap()); } builder.endObject(); return builder; } + + private void buildNestedMapXContent(XContentBuilder builder, Map map) throws IOException { + for (Map.Entry entry : map.entrySet()) { + if (entry.getValue() instanceof Map) { + builder.startObject(entry.getKey()); + buildNestedMapXContent(builder, (Map) entry.getValue()); + builder.endObject(); + } else { + builder.field(entry.getKey(), entry.getValue()); + } + } + } + + private Map convertFlatToNestedMap(Map map) { + Map nestedMap = new TreeMap<>(); + for (Map.Entry entry : map.entrySet()) { + putNested(nestedMap, entry.getKey(), entry.getValue()); + } + return nestedMap; + } + + private void putNested(Map map, String path, Object value) { + String[] parts = path.split("\\."); + Map current = map; + for (int i = 0; i < parts.length - 1; i++) { + current = (Map) current.computeIfAbsent(parts[i], k -> new HashMap()); + } + current.put(parts[parts.length - 1], value); + } } diff --git a/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsTransportAction.java 
b/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsTransportAction.java index 5430149830..7051dfc5d6 100644 --- a/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsTransportAction.java +++ b/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsTransportAction.java @@ -10,6 +10,7 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.neuralsearch.stats.DerivedStats; import org.opensearch.neuralsearch.stats.NeuralStats; import org.opensearch.transport.TransportService; import org.opensearch.threadpool.ThreadPool; @@ -58,7 +59,9 @@ public NeuralStatsTransportAction( ThreadPool.Names.MANAGEMENT, NeuralStatsNodeResponse.class ); - this.neuralStats = neuralStats; + // TODO : inject rather than singleton here + // this.neuralStats = neuralStats; + this.neuralStats = NeuralStats.instance(); } @Override @@ -70,9 +73,13 @@ protected NeuralStatsResponse newResponse( Map clusterStats = new HashMap<>(); - for (String statName : neuralStats.getStats().keySet()) { - clusterStats.put(statName, neuralStats.getStats().get(statName).getValue()); - } + clusterStats.put("cluster_level_stat_1", "Yay!"); + // for (String statName : neuralStats.getStats().keySet()) { + // clusterStats.put(statName, neuralStats.getStats().get(statName).getValue()); + // }' + DerivedStats derivedStats = DerivedStats.instance(); + clusterStats.putAll(derivedStats.addDerivedStats(responses.stream().map(NeuralStatsNodeResponse::getStatsMap).toList())); + System.out.println(clusterStats); return new NeuralStatsResponse(clusterService.getClusterName(), responses, failures, clusterStats); @@ -90,11 +97,14 @@ protected NeuralStatsNodeResponse newNodeResponse(StreamInput in) throws IOExcep @Override protected NeuralStatsNodeResponse nodeOperation(NeuralStatsNodeRequest request) { - Map statValues = new HashMap<>(); + // Reads from NeuralStats to node 
level stats on an individual node + Map statValues = new HashMap<>(); for (String statName : neuralStats.getStats().keySet()) { statValues.put(statName, neuralStats.getStats().get(statName).getValue()); } + System.out.println("Transport_Action node operation for (ta_node_pasta)" + clusterService.localNode()); + System.out.println(statValues); return new NeuralStatsNodeResponse(clusterService.localNode(), statValues); } } diff --git a/src/test/java/org/opensearch/neuralsearch/rest/RestNeuralStatsHandlerIT.java b/src/test/java/org/opensearch/neuralsearch/rest/RestNeuralStatsHandlerIT.java index 64040a4322..1293b39f1f 100644 --- a/src/test/java/org/opensearch/neuralsearch/rest/RestNeuralStatsHandlerIT.java +++ b/src/test/java/org/opensearch/neuralsearch/rest/RestNeuralStatsHandlerIT.java @@ -6,6 +6,7 @@ import lombok.extern.log4j.Log4j2; import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.junit.After; import org.junit.Before; import org.opensearch.client.Response; import org.opensearch.client.Request; @@ -13,30 +14,119 @@ import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.neuralsearch.BaseNeuralSearchIT; import org.opensearch.neuralsearch.plugin.NeuralSearch; +import org.opensearch.neuralsearch.stats.NeuralStats; import java.io.IOException; +import java.net.URISyntaxException; +import java.nio.file.Files; +import java.nio.file.Path; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.stream.Collectors; @Log4j2 public class RestNeuralStatsHandlerIT extends BaseNeuralSearchIT { + private static final String INDEX_NAME = "text_embedding_index"; + + private static final String INGEST_PIPELINE_NAME = "ingest-pipeline-1"; + private static final String SEARCH_PIPELINE_NAME = "search-pipeline-1"; + protected static final String QUERY_TEXT = "hello"; + protected static final String LEVEL_1_FIELD = "nested_passages"; + protected static final String LEVEL_2_FIELD = 
"level_2"; + protected static final String LEVEL_3_FIELD_TEXT = "level_3_text"; + protected static final String LEVEL_3_FIELD_CONTAINER = "level_3_container"; + protected static final String LEVEL_3_FIELD_EMBEDDING = "level_3_embedding"; + protected static final String TEXT_FIELD_VALUE_1 = "hello"; + protected static final String TEXT_FIELD_VALUE_2 = "clown"; + protected static final String TEXT_FIELD_VALUE_3 = "abc"; + private final String INGEST_DOC1 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc1.json").toURI())); + private final String INGEST_DOC2 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc2.json").toURI())); + private final String INGEST_DOC3 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc3.json").toURI())); + private final String INGEST_DOC4 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc4.json").toURI())); + private final String INGEST_DOC5 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc5.json").toURI())); + + private final String TITLE_KNN_FIELD = "title_knn"; + + public RestNeuralStatsHandlerIT() throws IOException, URISyntaxException {} + @Before public void setUp() throws Exception { super.setUp(); updateClusterSettings(); } - public void test_happyCase() throws Exception { - Response response = executeNeuralStatRequest(new ArrayList<>(), new ArrayList<>()); - String responseBody = EntityUtils.toString(response.getEntity()); - Map clusterStats = parseStatsResponse(responseBody); - log.info("rest_it_pasta"); - for (Map.Entry entry : clusterStats.entrySet()) { - log.info(entry.toString(), entry.getValue()); + @After + public void tearDown() throws Exception { + super.tearDown(); + + executeClearNeuralStatRequest(Collections.emptyList()); + } + + public void test_happyCase_textEmbedding() throws Exception { + NeuralStats.instance().resetStats(); + + String modelId = null; + try { + modelId = uploadTextEmbeddingModel(); + 
loadModel(modelId); + createPipelineProcessor(modelId, INGEST_PIPELINE_NAME, ProcessorType.TEXT_EMBEDDING); + createIndexWithPipeline(INDEX_NAME, "IndexMappings.json", INGEST_PIPELINE_NAME); + ingestDocument(INDEX_NAME, INGEST_DOC1); + ingestDocument(INDEX_NAME, INGEST_DOC2); + ingestDocument(INDEX_NAME, INGEST_DOC3); + assertEquals(3, getDocCount(INDEX_NAME)); + + Response response = executeNeuralStatRequest(new ArrayList<>(), new ArrayList<>()); + String responseBody = EntityUtils.toString(response.getEntity()); + List> nodesStats = parseNodeStatsResponse(responseBody); + + log.info(nodesStats); + assertEquals(3, getNestedValue(nodesStats.getFirst(), "ingest_processor.text_embedding.executions")); + + } finally { + wipeOfTestResources(INDEX_NAME, INGEST_PIPELINE_NAME, modelId, null); + } + } + + public void test_happyCase_clearNeuralStats() throws Exception { + NeuralStats.instance().resetStats(); + + String modelId = null; + try { + modelId = uploadTextEmbeddingModel(); + loadModel(modelId); + createPipelineProcessor(modelId, INGEST_PIPELINE_NAME, ProcessorType.TEXT_EMBEDDING); + createIndexWithPipeline(INDEX_NAME, "IndexMappings.json", INGEST_PIPELINE_NAME); + ingestDocument(INDEX_NAME, INGEST_DOC1); + ingestDocument(INDEX_NAME, INGEST_DOC2); + assertEquals(2, getDocCount(INDEX_NAME)); + + Response response = executeNeuralStatRequest(new ArrayList<>(), new ArrayList<>()); + String responseBody = EntityUtils.toString(response.getEntity()); + List> nodesStats = parseNodeStatsResponse(responseBody); + + log.info(nodesStats); + assertEquals(2, getNestedValue(nodesStats.getFirst(), "ingest_processor.text_embedding.executions")); + + executeClearNeuralStatRequest(Collections.emptyList()); + + response = executeNeuralStatRequest(new ArrayList<>(), new ArrayList<>()); + responseBody = EntityUtils.toString(response.getEntity()); + nodesStats = parseNodeStatsResponse(responseBody); + + log.info(nodesStats); + assertEquals(0, getNestedValue(nodesStats.getFirst(), 
"ingest_processor.text_embedding.executions")); + + } finally { + wipeOfTestResources(INDEX_NAME, INGEST_PIPELINE_NAME, modelId, null); } - assertEquals("Sushi", (String) clusterStats.get("Bratwurst")); + } + + protected String uploadTextEmbeddingModel() throws Exception { + String requestBody = Files.readString(Path.of(classLoader.getResource("processor/UploadModelRequestBody.json").toURI())); + return registerModelGroupAndUploadModel(requestBody); } protected Response executeNeuralStatRequest(List nodeIds, List stats) throws IOException { @@ -57,16 +147,23 @@ protected Response executeNeuralStatRequest(List nodeIds, List s return response; } - protected Map parseStatsResponse(String responseBody) throws IOException { - Map responseMap = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), responseBody).map(); - return responseMap; + protected Response executeClearNeuralStatRequest(List nodeIds) throws IOException { + String nodePrefix = ""; + if (!nodeIds.isEmpty()) { + nodePrefix = "/" + String.join(",", nodeIds); + } + + Request request = new Request("GET", NeuralSearch.NEURAL_BASE_URI + nodePrefix + "/stats/" + RestNeuralStatsHandler.CLEAR_PARAM); + + Response response = client().performRequest(request); + assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + return response; } - protected Map parseClusterStatsResponse(String responseBody) throws IOException { + protected Map parseStatsResponse(String responseBody) throws IOException { Map responseMap = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), responseBody).map(); responseMap.remove("cluster_name"); responseMap.remove("_nodes"); - responseMap.remove("nodes"); return responseMap; } @@ -85,4 +182,28 @@ protected List> parseNodeStatsResponse(String responseBody) return nodeResponses; } + + public Object getNestedValue(Map map, String path) { + String[] keys = path.split("\\."); + return getNestedValueHelper(map, keys, 0); + } + + private 
Object getNestedValueHelper(Map map, String[] keys, int depth) { + if (map == null) { + return null; + } + + Object value = map.get(keys[depth]); + + if (depth == keys.length - 1) { + return value; + } + + if (value instanceof Map) { + Map nestedMap = (Map) value; + return getNestedValueHelper(nestedMap, keys, depth + 1); + } + + return null; + } } From 50017a7fe4cff7decc4d6b5fb2810c9e17316c84 Mon Sep 17 00:00:00 2001 From: Andy Qin Date: Thu, 6 Feb 2025 13:04:39 -0800 Subject: [PATCH 4/4] Add NeuralQueryEnricher stats Signed-off-by: Andy Qin --- .../NeuralQueryEnricherProcessor.java | 4 ++ .../neuralsearch/stats/names/StatName.java | 3 +- .../rest/RestNeuralStatsHandlerIT.java | 39 +++++++++++++++++++ 3 files changed, 45 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NeuralQueryEnricherProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NeuralQueryEnricherProcessor.java index 3ee212ec7b..c15148d394 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NeuralQueryEnricherProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NeuralQueryEnricherProcessor.java @@ -13,6 +13,8 @@ import org.opensearch.index.query.QueryBuilder; import org.opensearch.ingest.ConfigurationUtils; import org.opensearch.neuralsearch.query.visitor.NeuralSearchQueryVisitor; +import org.opensearch.neuralsearch.stats.NeuralStats; +import org.opensearch.neuralsearch.stats.names.StatName; import org.opensearch.search.pipeline.AbstractProcessor; import org.opensearch.search.pipeline.Processor; import org.opensearch.search.pipeline.SearchRequestProcessor; @@ -71,6 +73,8 @@ public SearchRequest processRequest(SearchRequest searchRequest) { if (queryBuilder != null) { queryBuilder.visit(new NeuralSearchQueryVisitor(modelId, neuralFieldDefaultIdMap)); } + + NeuralStats.record(StatName.NEURAL_QUERY_ENRICHER_PROCESSOR_EXECUTIONS).increment(); return searchRequest; } diff --git 
a/src/main/java/org/opensearch/neuralsearch/stats/names/StatName.java b/src/main/java/org/opensearch/neuralsearch/stats/names/StatName.java index 1d37480550..ff264afcde 100644 --- a/src/main/java/org/opensearch/neuralsearch/stats/names/StatName.java +++ b/src/main/java/org/opensearch/neuralsearch/stats/names/StatName.java @@ -15,7 +15,8 @@ public enum StatName { INFO_DERIVED_STAT("example.counter", StatType.INFO_DERIVED), STAT_DERIVED_STAT("example.counter", StatType.STAT_DERIVED), - TEXT_EMBEDDING_PROCESSOR_EXECUTIONS("ingest_processor.text_embedding.executions", StatType.COUNTER_EVENT); + TEXT_EMBEDDING_PROCESSOR_EXECUTIONS("ingest_processor.text_embedding.executions", StatType.COUNTER_EVENT), + NEURAL_QUERY_ENRICHER_PROCESSOR_EXECUTIONS("search_processor.neural_query_enricher.executions", StatType.COUNTER_EVENT); @Getter private final String name; diff --git a/src/test/java/org/opensearch/neuralsearch/rest/RestNeuralStatsHandlerIT.java b/src/test/java/org/opensearch/neuralsearch/rest/RestNeuralStatsHandlerIT.java index 1293b39f1f..9d8a9370da 100644 --- a/src/test/java/org/opensearch/neuralsearch/rest/RestNeuralStatsHandlerIT.java +++ b/src/test/java/org/opensearch/neuralsearch/rest/RestNeuralStatsHandlerIT.java @@ -10,10 +10,12 @@ import org.junit.Before; import org.opensearch.client.Response; import org.opensearch.client.Request; +import org.opensearch.common.settings.Settings; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.neuralsearch.BaseNeuralSearchIT; import org.opensearch.neuralsearch.plugin.NeuralSearch; +import org.opensearch.neuralsearch.query.NeuralQueryBuilder; import org.opensearch.neuralsearch.stats.NeuralStats; import java.io.IOException; @@ -124,6 +126,43 @@ public void test_happyCase_clearNeuralStats() throws Exception { } } + public void test_happyCase_neuralQueryEnricher() throws Exception { + NeuralStats.instance().resetStats(); + + String modelId = null; + try { + 
modelId = prepareModel(); + createSearchRequestProcessor(modelId, SEARCH_PIPELINE_NAME); + createPipelineProcessor(modelId, INGEST_PIPELINE_NAME, ProcessorType.TEXT_EMBEDDING); + createIndexWithPipeline(INDEX_NAME, "IndexMappings.json", INGEST_PIPELINE_NAME); + + ingestDocument(INDEX_NAME, INGEST_DOC1); + ingestDocument(INDEX_NAME, INGEST_DOC2); + + updateIndexSettings(INDEX_NAME, Settings.builder().put("index.search.default_pipeline", SEARCH_PIPELINE_NAME)); + NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.builder() + .fieldName(TITLE_KNN_FIELD) + .queryText("Second") + .k(10) + .build(); + + Map response = search(INDEX_NAME, neuralQueryBuilder, 2); + log.info(response); + assertFalse(response.isEmpty()); + + // Stats request + Response statResponse = executeNeuralStatRequest(new ArrayList<>(), new ArrayList<>()); + String responseBody = EntityUtils.toString(statResponse.getEntity()); + Map nodeStats = parseNodeStatsResponse(responseBody).getFirst(); + + log.info(nodeStats); + assertEquals(2, getNestedValue(nodeStats, "ingest_processor.text_embedding.executions")); + assertEquals(1, getNestedValue(nodeStats, "search_processor.neural_query_enricher.executions")); + } finally { + wipeOfTestResources(INDEX_NAME, INGEST_PIPELINE_NAME, modelId, SEARCH_PIPELINE_NAME); + } + } + protected String uploadTextEmbeddingModel() throws Exception { String requestBody = Files.readString(Path.of(classLoader.getResource("processor/UploadModelRequestBody.json").toURI())); return registerModelGroupAndUploadModel(requestBody);