diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index f7ac5d19f3..59a6dfc9eb 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -14,12 +14,19 @@ import java.util.Optional; import java.util.function.Supplier; +import com.google.common.collect.ImmutableList; +import org.opensearch.action.ActionRequest; import org.opensearch.client.Client; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.node.DiscoveryNodes; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.IndexScopedSettings; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; +import org.opensearch.common.settings.SettingsFilter; import org.opensearch.common.util.FeatureFlags; +import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.env.Environment; @@ -55,7 +62,12 @@ import org.opensearch.neuralsearch.query.NeuralQueryBuilder; import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder; import org.opensearch.neuralsearch.query.ext.RerankSearchExtBuilder; +import org.opensearch.neuralsearch.rest.RestNeuralStatsHandler; import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher; +import org.opensearch.neuralsearch.transport.ClearNeuralStatsAction; +import org.opensearch.neuralsearch.transport.ClearNeuralStatsTransportAction; +import org.opensearch.neuralsearch.transport.NeuralStatsAction; +import org.opensearch.neuralsearch.transport.NeuralStatsTransportAction; import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil; import org.opensearch.plugins.ActionPlugin; import org.opensearch.plugins.ExtensiblePlugin; @@ -64,6 +76,8 @@ import org.opensearch.plugins.SearchPipelinePlugin; import org.opensearch.plugins.SearchPlugin; import org.opensearch.repositories.RepositoriesService; +import org.opensearch.rest.RestController; +import org.opensearch.rest.RestHandler; import org.opensearch.script.ScriptService; import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; import org.opensearch.search.pipeline.SearchRequestProcessor; @@ -85,6 +99,7 @@ public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin, private final ScoreNormalizationFactory scoreNormalizationFactory = new ScoreNormalizationFactory(); private final ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory(); public static final String EXPLANATION_RESPONSE_KEY = "explanation_response"; + public static final String NEURAL_BASE_URI = "/_plugins/_neural"; @Override public Collection createComponents( @@ -117,6 +132,28 @@ public List> getQueries() { ); } + @Override + public List getRestHandlers( + Settings settings, + RestController restController, + ClusterSettings clusterSettings, + IndexScopedSettings indexScopedSettings, + SettingsFilter settingsFilter, + IndexNameExpressionResolver indexNameExpressionResolver, + Supplier nodesInCluster + ) { + RestNeuralStatsHandler restNeuralStatsHandler = new RestNeuralStatsHandler(); + return ImmutableList.of(restNeuralStatsHandler); + } + + @Override + public List> getActions() { + return Arrays.asList( + new ActionHandler<>(NeuralStatsAction.INSTANCE, NeuralStatsTransportAction.class), + new ActionHandler<>(ClearNeuralStatsAction.INSTANCE, ClearNeuralStatsTransportAction.class) + ); + } + @Override public List> getExecutorBuilders(Settings settings) { return List.of(HybridQueryExecutor.getExecutorBuilder(settings)); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NeuralQueryEnricherProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NeuralQueryEnricherProcessor.java index 3ee212ec7b..c15148d394 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NeuralQueryEnricherProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NeuralQueryEnricherProcessor.java @@ -13,6 +13,8 @@ import org.opensearch.index.query.QueryBuilder; import org.opensearch.ingest.ConfigurationUtils; import org.opensearch.neuralsearch.query.visitor.NeuralSearchQueryVisitor; +import org.opensearch.neuralsearch.stats.NeuralStats; +import org.opensearch.neuralsearch.stats.names.StatName; import org.opensearch.search.pipeline.AbstractProcessor; import org.opensearch.search.pipeline.Processor; import org.opensearch.search.pipeline.SearchRequestProcessor; @@ -71,6 +73,8 @@ public SearchRequest processRequest(SearchRequest searchRequest) { if (queryBuilder != null) { queryBuilder.visit(new NeuralSearchQueryVisitor(modelId, neuralFieldDefaultIdMap)); } + + NeuralStats.record(StatName.NEURAL_QUERY_ENRICHER_PROCESSOR_EXECUTIONS).increment(); return searchRequest; } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java index c8f9f080d8..ff6f69bac6 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java @@ -16,6 +16,8 @@ import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import lombok.extern.log4j.Log4j2; +import org.opensearch.neuralsearch.stats.NeuralStats; +import org.opensearch.neuralsearch.stats.names.StatName; /** * This processor is used for user input data text embedding processing, model_id can be used to indicate which model user use, @@ -47,6 +49,8 @@ public void doExecute( List inferenceList, BiConsumer handler ) { + NeuralStats.record(StatName.TEXT_EMBEDDING_PROCESSOR_EXECUTIONS).increment(); + mlCommonsClientAccessor.inferenceSentences(this.modelId, inferenceList, ActionListener.wrap(vectors -> { setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors); handler.accept(ingestDocument, null); diff --git a/src/main/java/org/opensearch/neuralsearch/rest/RestNeuralStatsHandler.java b/src/main/java/org/opensearch/neuralsearch/rest/RestNeuralStatsHandler.java new file mode 100644 index 0000000000..df3eefb904 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/rest/RestNeuralStatsHandler.java @@ -0,0 +1,130 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.rest; + +import com.google.common.collect.ImmutableList; +import lombok.AllArgsConstructor; +import lombok.extern.log4j.Log4j2; +import org.apache.commons.lang.StringUtils; +import org.opensearch.client.node.NodeClient; +import org.opensearch.neuralsearch.stats.NeuralStatsInput; +import org.opensearch.neuralsearch.transport.ClearNeuralStatsAction; +import org.opensearch.neuralsearch.transport.ClearNeuralStatsRequest; +import org.opensearch.neuralsearch.transport.NeuralStatsAction; +import org.opensearch.neuralsearch.transport.NeuralStatsRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestActions; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import static org.opensearch.neuralsearch.plugin.NeuralSearch.NEURAL_BASE_URI; + +@Log4j2 +@AllArgsConstructor +public class RestNeuralStatsHandler extends BaseRestHandler { + private static final String NAME = "neural_stats_action"; + public static final String CLEAR_PARAM = "_clear"; + + @Override + public String getName() { + return NAME; + } + + @Override + public List routes() { + return ImmutableList.of( + new Route(RestRequest.Method.GET, NEURAL_BASE_URI + "/{nodeId}/stats/"), + new Route(RestRequest.Method.GET, NEURAL_BASE_URI + "/{nodeId}/stats/{stat}"), + new Route(RestRequest.Method.GET, NEURAL_BASE_URI + "/stats/"), + new Route(RestRequest.Method.GET, NEURAL_BASE_URI + "/stats/{stat}") + ); + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) { + if (request.param("stat", "").equals(CLEAR_PARAM)) { + // TODO : Hacky, possible collisions. Should be refactored into separate endpoint later + String[] nodeIdsArr = null; + String nodesIdsStr = request.param("nodeId"); + if (StringUtils.isNotEmpty(nodesIdsStr)) { + nodeIdsArr = nodesIdsStr.split(","); + } + + ClearNeuralStatsRequest clearNeuralStatsRequest = new ClearNeuralStatsRequest(nodeIdsArr); + clearNeuralStatsRequest.timeout(request.param("timeout")); + + return channel -> client.execute( + ClearNeuralStatsAction.INSTANCE, + clearNeuralStatsRequest, + new RestActions.NodesResponseRestListener<>(channel) + ); + } else { + // Read inputs and convert to BaseNodesRequest with correct info configured + NeuralStatsRequest neuralStatsRequest = getRequest(request); + + return channel -> client.execute( + NeuralStatsAction.INSTANCE, + neuralStatsRequest, + new RestActions.NodesResponseRestListener<>(channel) + ); + } + + } + + /** + * Creates a NeuralStatsRequest from a RestRequest + * + * @param request Rest request + * @return NeuralStatsRequest + */ + private NeuralStatsRequest getRequest(RestRequest request) { + // parse the nodes the user wants to query + String[] nodeIdsArr = null; + String nodesIdsStr = request.param("nodeId"); + if (StringUtils.isNotEmpty(nodesIdsStr)) { + nodeIdsArr = nodesIdsStr.split(","); + } + + NeuralStatsRequest neuralStatsRequest = new NeuralStatsRequest(nodeIdsArr, new NeuralStatsInput()); + neuralStatsRequest.timeout(request.param("timeout")); + + // parse the stats the customer wants to see + Set statsSet = null; + String statsStr = request.param("stat"); + if (StringUtils.isNotEmpty(statsStr)) { + statsSet = new HashSet<>(Arrays.asList(statsStr.split(","))); + } + + if (statsSet == null) { + + } else if (statsSet.size() == 1 && statsSet.contains("_all")) { + + } else if (statsSet.contains(NeuralStatsRequest.ALL_STATS_KEY)) { + throw new IllegalArgumentException("Request " + request.path() + " contains _all and individual stats"); + } else { + // Validate NeuralStats input is valid + // Set invalidStats = new TreeSet<>(); + // for (String stat : statsSet) { + // // validate request contains valid stats + // // if (!neuralStatsRequest.addStat(stat)) { + // // invalidStats.add(stat); + // // } + // } + // + // if (!invalidStats.isEmpty()) { + // throw new IllegalArgumentException(unrecognized(request, invalidStats, neuralStatsRequest.getNeuralStatsInput(), "stat")); + // } + + } + log.info("pasta"); + log.info(statsSet); + log.info(neuralStatsRequest.getNeuralStatsInput()); + return neuralStatsRequest; + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/stats/DerivedStats.java b/src/main/java/org/opensearch/neuralsearch/stats/DerivedStats.java new file mode 100644 index 0000000000..4dcc5857f8 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/stats/DerivedStats.java @@ -0,0 +1,73 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.stats; + +import org.opensearch.neuralsearch.stats.suppliers.DerivedSupplier; +import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.TreeMap; +import java.util.concurrent.ConcurrentSkipListMap; +import java.util.function.Function; + +public class DerivedStats { + private static final String AGG_KEY_PREFIX = "all_nodes."; + private static DerivedStats INSTANCE; + + public static DerivedStats instance() { + if (INSTANCE == null) { + INSTANCE = new DerivedStats(); + } + return INSTANCE; + } + + private final Map> derivedStatsMap; + private Map aggregatedNodeResponse; + + public DerivedStats() { + this.derivedStatsMap = new ConcurrentSkipListMap<>(); + register("derived.cluster_version", DerivedStats::clusterVersion); + } + + public Map aggregateNodesResponses(List> nodeResponses) { + Map summedMap = new HashMap<>(); + for (Map map : nodeResponses) { + for (Map.Entry entry : map.entrySet()) { + summedMap.merge(AGG_KEY_PREFIX + entry.getKey(), entry.getValue(), Long::sum); + } + } + return summedMap; + } + + public Map addDerivedStats(List> nodeResponses) { + // Reference to provide derived methods access to node Responses + this.aggregatedNodeResponse = aggregateNodesResponses(nodeResponses); + + Map computedDerivedStats = new TreeMap<>(); + for (Map.Entry> neuralStatEntry : derivedStatsMap.entrySet()) { + computedDerivedStats.put(neuralStatEntry.getKey(), neuralStatEntry.getValue().getValue()); + } + + computedDerivedStats.putAll(aggregatedNodeResponse); + // Reset reference to not store + this.aggregatedNodeResponse = null; + return computedDerivedStats; + } + + private void register(String statPath, Function derivedMethod) { + if (derivedStatsMap.containsKey(statPath)) { + // Validation error here + return; + } + NeuralStat neuralStat = new NeuralStat<>(new DerivedSupplier<>(this, derivedMethod)); + derivedStatsMap.put(statPath, neuralStat); + } + + private String clusterVersion() { + return NeuralSearchClusterUtil.instance().getClusterMinVersion().toString(); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/stats/NeuralStat.java b/src/main/java/org/opensearch/neuralsearch/stats/NeuralStat.java new file mode 100644 index 0000000000..e672b13bde --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/stats/NeuralStat.java @@ -0,0 +1,39 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.stats; + +import org.opensearch.neuralsearch.stats.suppliers.CounterSupplier; + +import java.util.function.Supplier; + +public class NeuralStat { + private Supplier supplier; + + public NeuralStat(Supplier supplier) { + this.supplier = supplier; + } + + public T getValue() { + return supplier.get(); + } + + /** + * Increments the supplier if it can be incremented + */ + public void increment() { + if (supplier instanceof CounterSupplier) { + ((CounterSupplier) supplier).increment(); + } + } + + /** + * Decrease the supplier if it can be decreased. + */ + public void decrement() { + if (supplier instanceof CounterSupplier) { + ((CounterSupplier) supplier).decrement(); + } + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/stats/NeuralStatBuilder.java b/src/main/java/org/opensearch/neuralsearch/stats/NeuralStatBuilder.java new file mode 100644 index 0000000000..950fe54614 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/stats/NeuralStatBuilder.java @@ -0,0 +1,53 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.stats; + +import org.opensearch.neuralsearch.stats.suppliers.CounterSupplier; + +import java.util.ArrayList; +import java.util.List; + +public class NeuralStatBuilder { + public static final String DELIMITER = "."; + private NeuralStats neuralStats; + private List path; + + NeuralStatBuilder(NeuralStats neuralStats) { + this.neuralStats = neuralStats; + this.path = new ArrayList<>(); + } + + public String getPathString() { + return String.join(DELIMITER, path); + } + + public NeuralStatBuilder ingestProcessor(String processor) { + // Add a validation check here + // Refactor to constnats + path.add("ingest_processor"); + path.add(processor); + return this; + } + + public NeuralStatBuilder searchProcessor(String processor) { + // Add a validation check here + // Refactor to constnats + path.add("search_processor"); + path.add(processor); + return this; + } + + public NeuralStatBuilder metric(String metric) { + // should do some validation here + path.add(metric); + return this; + } + + public void increment() { + // Should do some extra validation here + // SOme kind of "validate path" method that makes sure the path is valid? + neuralStats.getStats().computeIfAbsent(getPathString(), k -> new NeuralStat<>(new CounterSupplier())).increment(); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/stats/NeuralStats.java b/src/main/java/org/opensearch/neuralsearch/stats/NeuralStats.java new file mode 100644 index 0000000000..065ec2d0c3 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/stats/NeuralStats.java @@ -0,0 +1,66 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.stats; + +import com.google.common.annotations.VisibleForTesting; +import org.opensearch.neuralsearch.stats.names.StatName; +import org.opensearch.neuralsearch.stats.names.StatType; +import org.opensearch.neuralsearch.stats.suppliers.CounterSupplier; + +import java.util.EnumSet; +import java.util.Map; +import java.util.concurrent.ConcurrentSkipListMap; + +public class NeuralStats { + private Map> counterStatsMap; + + public static NeuralStats INSTANCE; + + public static NeuralStats instance() { + if (INSTANCE == null) { + INSTANCE = new NeuralStats(); + } + return INSTANCE; + } + + public static NeuralStatBuilder recordMetric() { + return new NeuralStatBuilder(instance()); + } + + public static NeuralStat record(StatName statName) { + return instance().getStats().computeIfAbsent(statName.getName(), k -> new NeuralStat<>(new CounterSupplier())); + } + + public NeuralStats() { + this.counterStatsMap = new ConcurrentSkipListMap<>(); + + // Initialize event counter stats + for (StatName statName : EnumSet.allOf(StatName.class)) { + if (statName.getStatType() == StatType.COUNTER_EVENT) { + counterStatsMap.computeIfAbsent(statName.getName(), k -> new NeuralStat<>(new CounterSupplier())); + } + } + } + + /** + * Get the stats + * + * @return all the stats + */ + public Map> getStats() { + return counterStatsMap; + } + + @VisibleForTesting + public void resetStats() { + // Risk of memory leak? + this.counterStatsMap = new ConcurrentSkipListMap<>(); + for (StatName statName : EnumSet.allOf(StatName.class)) { + if (statName.getStatType() == StatType.COUNTER_EVENT) { + counterStatsMap.computeIfAbsent(statName.getName(), k -> new NeuralStat<>(new CounterSupplier())); + } + } + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/stats/NeuralStatsInput.java b/src/main/java/org/opensearch/neuralsearch/stats/NeuralStatsInput.java new file mode 100644 index 0000000000..c624e90d9c --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/stats/NeuralStatsInput.java @@ -0,0 +1,112 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.stats; + +import lombok.Builder; +import lombok.Getter; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import java.io.IOException; +import java.util.EnumSet; +import java.util.HashSet; +import java.util.Set; +import java.util.function.Function; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +@Getter +public class NeuralStatsInput implements ToXContentObject, Writeable { + public static final String NODE_IDS = "node_ids"; + + /** + * Which node's stats will be retrieved. + */ + private Set nodeIds; + + @Builder + public NeuralStatsInput(Set nodeIds) { + this.nodeIds = nodeIds; + } + + public NeuralStatsInput() { + this.nodeIds = new HashSet<>(); + } + + public NeuralStatsInput(StreamInput input) throws IOException { + nodeIds = input.readBoolean() ? new HashSet<>(input.readStringList()) : new HashSet<>(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalStringCollection(nodeIds); + } + + private void writeEnumSet(StreamOutput out, EnumSet set) throws IOException { + if (set != null && set.size() > 0) { + out.writeBoolean(true); + out.writeEnumSet(set); + } else { + out.writeBoolean(false); + } + } + + public static NeuralStatsInput parse(XContentParser parser) throws IOException { + Set nodeIds = new HashSet<>(); + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case NODE_IDS: + parseArrayField(parser, nodeIds); + break; + default: + parser.skipChildren(); + break; + } + } + return NeuralStatsInput.builder().nodeIds(nodeIds).build(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + if (nodeIds != null) { + builder.field(NODE_IDS, nodeIds); + } + builder.endObject(); + return builder; + } + + public boolean retrieveStatsOnAllNodes() { + return nodeIds == null || nodeIds.size() == 0; + } + + public static void parseArrayField(XContentParser parser, Set set) throws IOException { + parseField(parser, set, null, String.class); + } + + public static void parseField(XContentParser parser, Set set, Function function, Class clazz) throws IOException { + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + String value = parser.text(); + if (function != null) { + set.add(function.apply(value)); + } else { + if (clazz.isInstance(value)) { + set.add(clazz.cast(value)); + } + } + } + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/stats/names/StatName.java b/src/main/java/org/opensearch/neuralsearch/stats/names/StatName.java new file mode 100644 index 0000000000..ff264afcde --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/stats/names/StatName.java @@ -0,0 +1,49 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.stats.names; + +import lombok.Getter; + +import java.util.HashSet; +import java.util.Set; + +public enum StatName { + // TODO : stat type not currently used + EVENT_STAT("example.counter", StatType.COUNTER_EVENT), + INFO_DERIVED_STAT("example.counter", StatType.INFO_DERIVED), + STAT_DERIVED_STAT("example.counter", StatType.STAT_DERIVED), + + TEXT_EMBEDDING_PROCESSOR_EXECUTIONS("ingest_processor.text_embedding.executions", StatType.COUNTER_EVENT), + NEURAL_QUERY_ENRICHER_PROCESSOR_EXECUTIONS("search_processor.neural_query_enricher.executions", StatType.COUNTER_EVENT); + + @Getter + private final String name; + @Getter + private final StatType statType; + + StatName(String name, StatType statType) { + this.name = name; + this.statType = statType; + } + + /** + * Get all stat names + * + * @return set of all stat names + */ + public static Set getNames() { + Set names = new HashSet<>(); + + for (StatName statName : StatName.values()) { + names.add(statName.getName()); + } + return names; + } + + @Override + public String toString() { + return getName(); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/stats/names/StatType.java b/src/main/java/org/opensearch/neuralsearch/stats/names/StatType.java new file mode 100644 index 0000000000..e63d4c3b88 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/stats/names/StatType.java @@ -0,0 +1,11 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.stats.names; + +public enum StatType { + INFO_DERIVED, + STAT_DERIVED, + COUNTER_EVENT; +} diff --git a/src/main/java/org/opensearch/neuralsearch/stats/suppliers/CounterSupplier.java b/src/main/java/org/opensearch/neuralsearch/stats/suppliers/CounterSupplier.java new file mode 100644 index 0000000000..0d3d81f2db --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/stats/suppliers/CounterSupplier.java @@ -0,0 +1,41 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.stats.suppliers; + +import java.util.concurrent.atomic.LongAdder; +import java.util.function.Supplier; + +/** + * CounterSupplier provides a stateful count as the value + */ +public class CounterSupplier implements Supplier { + private LongAdder counter; + + /** + * Constructor + */ + public CounterSupplier() { + this.counter = new LongAdder(); + } + + @Override + public Long get() { + return counter.longValue(); + } + + /** + * Increments the value of the counter by 1 + */ + public void increment() { + counter.increment(); + } + + /** + * Decrease the value of the counter by 1 + */ + public void decrement() { + counter.decrement(); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/stats/suppliers/DerivedSupplier.java b/src/main/java/org/opensearch/neuralsearch/stats/suppliers/DerivedSupplier.java new file mode 100644 index 0000000000..d89c863398 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/stats/suppliers/DerivedSupplier.java @@ -0,0 +1,31 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.stats.suppliers; + +import org.opensearch.neuralsearch.stats.DerivedStats; + +import java.util.function.Function; +import java.util.function.Supplier; + +/** + * DerivedSupplier a derived stat value + */ +public class DerivedSupplier implements Supplier { + private DerivedStats derivedStats; + private Function getter; + + /** + * Constructor + */ + public DerivedSupplier(DerivedStats derivedStats, Function getter) { + this.derivedStats = derivedStats; + this.getter = getter; + } + + @Override + public T get() { + return getter.apply(derivedStats); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/transport/ClearNeuralStatsAction.java b/src/main/java/org/opensearch/neuralsearch/transport/ClearNeuralStatsAction.java new file mode 100644 index 0000000000..187859b968 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/transport/ClearNeuralStatsAction.java @@ -0,0 +1,19 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.transport; + +import org.opensearch.action.ActionType; + +public class ClearNeuralStatsAction extends ActionType { + public static final ClearNeuralStatsAction INSTANCE = new ClearNeuralStatsAction(); + public static final String NAME = "cluster:admin/clear_neural_stats_action"; // TODO : figure this out + + /** + * Constructor + */ + private ClearNeuralStatsAction() { + super(NAME, ClearNeuralStatsResponse::new); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/transport/ClearNeuralStatsNodeRequest.java b/src/main/java/org/opensearch/neuralsearch/transport/ClearNeuralStatsNodeRequest.java new file mode 100644 index 0000000000..e118f32830 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/transport/ClearNeuralStatsNodeRequest.java @@ -0,0 +1,53 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.transport; + +import lombok.Getter; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.transport.TransportRequest; + +import java.io.IOException; + +/** + * ClearNeuralStatsNodeRequest represents the request to an individual node + */ +public class ClearNeuralStatsNodeRequest extends TransportRequest { + @Getter + private ClearNeuralStatsRequest request; + + /** + * Constructor + */ + public ClearNeuralStatsNodeRequest() { + super(); + } + + /** + * Constructor + * + * @param in input stream + * @throws IOException in case of I/O errors + */ + public ClearNeuralStatsNodeRequest(StreamInput in) throws IOException { + super(in); + request = new ClearNeuralStatsRequest(in); + } + + /** + * Constructor + * + * @param request ClearNeuralStatsRequest + */ + public ClearNeuralStatsNodeRequest(ClearNeuralStatsRequest request) { + this.request = request; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + request.writeTo(out); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/transport/ClearNeuralStatsNodeResponse.java b/src/main/java/org/opensearch/neuralsearch/transport/ClearNeuralStatsNodeResponse.java new file mode 100644 index 0000000000..d9f84d7fe3 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/transport/ClearNeuralStatsNodeResponse.java @@ -0,0 +1,41 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.transport; + +import org.opensearch.action.support.nodes.BaseNodeResponse; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import java.io.IOException; + +/** + * ClearNeuralStatsNodeResponse represents the responses generated by an individual node + */ +public class ClearNeuralStatsNodeResponse extends BaseNodeResponse { + /** + * Constructor + * + * @param in stream + * @throws IOException in case of I/O errors + */ + public ClearNeuralStatsNodeResponse(StreamInput in) throws IOException { + super(in); + } + + /** + * Constructor + * + * @param node node + */ + public ClearNeuralStatsNodeResponse(DiscoveryNode node) { + super(node); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/transport/ClearNeuralStatsRequest.java b/src/main/java/org/opensearch/neuralsearch/transport/ClearNeuralStatsRequest.java new file mode 100644 index 0000000000..d8ce5eb9d4 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/transport/ClearNeuralStatsRequest.java @@ -0,0 +1,31 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.transport; + +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.action.support.nodes.BaseNodesRequest; +import org.opensearch.core.common.io.stream.StreamInput; + +import java.io.IOException; + +public class ClearNeuralStatsRequest extends BaseNodesRequest { + public ClearNeuralStatsRequest(StreamInput in) throws IOException { + super(in); + } + + /** + * Constructor + * + * @param nodeIds NodeIDs from which to retrieve stats + */ + public ClearNeuralStatsRequest(String[] nodeIds) { + super(nodeIds); + } + + @Override + public ActionRequestValidationException validate() { + return null; + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/transport/ClearNeuralStatsResponse.java b/src/main/java/org/opensearch/neuralsearch/transport/ClearNeuralStatsResponse.java new file mode 100644 index 0000000000..0f0b0fd703 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/transport/ClearNeuralStatsResponse.java @@ -0,0 +1,62 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.transport; + +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.nodes.BaseNodesResponse; +import org.opensearch.cluster.ClusterName; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.List; + +public class ClearNeuralStatsResponse extends BaseNodesResponse implements ToXContentObject { + private static final String NODES_KEY = "nodes"; + + /** + * Constructor + * + * @param in StreamInput + * @throws IOException thrown when unable to read from stream + */ + public ClearNeuralStatsResponse(StreamInput in) throws IOException { + super(new ClusterName(in), in.readList(ClearNeuralStatsNodeResponse::new), in.readList(FailedNodeException::new)); + } + + /** + * Constructor + * + * @param clusterName name of cluster + * @param nodes List of NeuralStatsNodeResponses + * @param failures List of failures from nodes + */ + public ClearNeuralStatsResponse(ClusterName clusterName, List nodes, List failures) { + super(clusterName, nodes, failures); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + } + + @Override + public void writeNodesTo(StreamOutput out, List nodes) throws IOException { + out.writeList(nodes); + } + + @Override + public List readNodesFrom(StreamInput in) throws IOException { + return in.readList(ClearNeuralStatsNodeResponse::new); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + // TODO : response should go here + return builder; + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/transport/ClearNeuralStatsTransportAction.java b/src/main/java/org/opensearch/neuralsearch/transport/ClearNeuralStatsTransportAction.java new file mode 100644 index 0000000000..abc2b3dbad --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/transport/ClearNeuralStatsTransportAction.java @@ -0,0 +1,80 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.transport; + +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.nodes.TransportNodesAction; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.neuralsearch.stats.NeuralStats; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +import java.io.IOException; +import java.util.List; + +/** + * ClearNeuralStatsTransportAction contains the logic to clear all nodes stats + */ +public class ClearNeuralStatsTransportAction extends TransportNodesAction< + ClearNeuralStatsRequest, + ClearNeuralStatsResponse, + ClearNeuralStatsNodeRequest, + ClearNeuralStatsNodeResponse> { + /** + * Constructor + * + * @param threadPool ThreadPool to use + * @param clusterService ClusterService + * @param transportService TransportService + * @param actionFilters Action Filters + */ + @Inject + public ClearNeuralStatsTransportAction( + ThreadPool threadPool, + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters + ) { + super( + ClearNeuralStatsAction.NAME, + threadPool, + clusterService, + transportService, + actionFilters, + ClearNeuralStatsRequest::new, + ClearNeuralStatsNodeRequest::new, + ThreadPool.Names.MANAGEMENT, + ClearNeuralStatsNodeResponse.class + ); + } + + @Override + protected ClearNeuralStatsResponse newResponse( + ClearNeuralStatsRequest request, + List responses, + List failures + ) { + return new ClearNeuralStatsResponse(clusterService.getClusterName(), responses, failures); + } + + @Override + protected ClearNeuralStatsNodeRequest newNodeRequest(ClearNeuralStatsRequest request) { + return new ClearNeuralStatsNodeRequest(request); + } + + @Override + protected ClearNeuralStatsNodeResponse newNodeResponse(StreamInput in) throws IOException { + return new ClearNeuralStatsNodeResponse(in); + } + + @Override + protected ClearNeuralStatsNodeResponse nodeOperation(ClearNeuralStatsNodeRequest request) { + NeuralStats.instance().resetStats(); + return new ClearNeuralStatsNodeResponse(clusterService.localNode()); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsAction.java b/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsAction.java new file mode 100644 index 0000000000..85683b871c --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsAction.java @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.core.common.io.stream.Writeable; + +/** + * NeuralStatsAction class + */ +public class NeuralStatsAction extends ActionType { + + public static final NeuralStatsAction INSTANCE = new NeuralStatsAction(); + public static final String NAME = "cluster:admin/neural_stats_action"; // TODO : figure this out + + /** + * Constructor + */ + private NeuralStatsAction() { + super(NAME, NeuralStatsResponse::new); + } + + @Override + public Writeable.Reader getResponseReader() { + return NeuralStatsResponse::new; + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsNodeRequest.java b/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsNodeRequest.java new file mode 100644 index 0000000000..0ef5205627 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsNodeRequest.java @@ -0,0 +1,53 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.transport; + +import lombok.Getter; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.transport.TransportRequest; + +import java.io.IOException; + +/** + * NeuralStatsNodeRequest represents the request to an individual node + */ +public class NeuralStatsNodeRequest extends TransportRequest { + @Getter + private NeuralStatsRequest request; + + /** + * Constructor + */ + public NeuralStatsNodeRequest() { + super(); + } + + /** + * Constructor + * + * @param in input stream + * @throws IOException in case of I/O errors + */ + public NeuralStatsNodeRequest(StreamInput in) throws IOException { + super(in); + request = new NeuralStatsRequest(in); + } + + /** + * Constructor + * + * @param request NeuralStatsRequest + */ + public NeuralStatsNodeRequest(NeuralStatsRequest request) { + this.request = request; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + request.writeTo(out); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsNodeResponse.java b/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsNodeResponse.java new file mode 100644 index 0000000000..b9afc87271 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsNodeResponse.java @@ -0,0 +1,81 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.transport; + +import lombok.Getter; +import org.opensearch.action.support.nodes.BaseNodeResponse; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentFragment; +import org.opensearch.core.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Map; +import java.util.TreeMap; + +/** + * NeuralStatsNodeResponse represents the responses generated by an individual node + */ +public class NeuralStatsNodeResponse extends BaseNodeResponse implements ToXContentFragment { + + @Getter + private Map statsMap; + + /** + * Constructor + * + * @param in stream + * @throws IOException in case of I/O errors + */ + public NeuralStatsNodeResponse(StreamInput in) throws IOException { + super(in); + this.statsMap = new TreeMap<>(in.readMap(StreamInput::readString, StreamInput::readLong)); + } + + /** + * Constructor + * + * @param node node + * @param statsToValues mapping of stat name to value + */ + public NeuralStatsNodeResponse(DiscoveryNode node, Map statsToValues) { + super(node); + this.statsMap = new TreeMap<>(statsToValues); + } + + /** + * Creates a new NeuralStatsNodeResponse object and reads in the stats from an input stream + * + * @param in StreamInput to read from + * @return NeuralStatsNodeResponse object corresponding to the input stream + * @throws IOException throws an IO exception if the StreamInput cannot be read from + */ + public static NeuralStatsNodeResponse readStats(StreamInput in) throws IOException { + NeuralStatsNodeResponse neuralStats = new NeuralStatsNodeResponse(in); + return neuralStats; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeMap(statsMap, StreamOutput::writeString, StreamOutput::writeLong); + } + + /** + * Converts statsMap to xContent + * + * @param builder XContentBuilder + * @param params Params + * @return XContentBuilder + * @throws IOException thrown by builder for invalid field + */ + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + for (String stat : statsMap.keySet()) { + builder.field(stat, statsMap.get(stat)); + } + return builder; + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsRequest.java b/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsRequest.java new file mode 100644 index 0000000000..ecd40ee510 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsRequest.java @@ -0,0 +1,62 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.transport; + +import lombok.Getter; +import org.opensearch.action.support.nodes.BaseNodesRequest; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.neuralsearch.stats.NeuralStatsInput; + +import java.io.IOException; + +/** + * NeuralStatsRequest gets node (cluster) level Stats for Neural + * By default, all parameters will be true + */ +public class NeuralStatsRequest extends BaseNodesRequest { + + /** + * Key indicating all stats should be retrieved + */ + public static final String ALL_STATS_KEY = "_all"; + @Getter + private final NeuralStatsInput neuralStatsInput; + + /** + * Empty constructor needed for NeuralStatsTransportAction + */ + public NeuralStatsRequest() { + super((String[]) null); + this.neuralStatsInput = new NeuralStatsInput(); + } + + /** + * Constructor + * + * @param in input stream + * @throws IOException in case of I/O errors + */ + public NeuralStatsRequest(StreamInput in) throws IOException { + super(in); + this.neuralStatsInput = new NeuralStatsInput(in); + } + + /** + * Constructor + * + * @param nodeIds NodeIDs from which to retrieve stats + */ + public NeuralStatsRequest(String[] nodeIds, NeuralStatsInput neuralStatsInput) { + super(nodeIds); + this.neuralStatsInput = neuralStatsInput; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + neuralStatsInput.writeTo(out); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsResponse.java b/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsResponse.java new file mode 100644 index 0000000000..40193cd0f9 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsResponse.java @@ -0,0 +1,125 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.transport; + +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.nodes.BaseNodesResponse; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.TreeMap; + +/** + * NeuralStatsResponse consists of the aggregated responses from the nodes + */ +public class NeuralStatsResponse extends BaseNodesResponse implements ToXContentObject { + + private static final String NODES_KEY = "nodes"; + private Map clusterStats; + + /** + * Constructor + * + * @param in StreamInput + * @throws IOException thrown when unable to read from stream + */ + public NeuralStatsResponse(StreamInput in) throws IOException { + super(new ClusterName(in), in.readList(NeuralStatsNodeResponse::readStats), in.readList(FailedNodeException::new)); + clusterStats = new TreeMap<>(in.readMap()); + } + + /** + * Constructor + * + * @param clusterName name of cluster + * @param nodes List of NeuralStatsNodeResponses + * @param failures List of failures from nodes + * @param clusterStats Cluster level stats only obtained from a single node + */ + public NeuralStatsResponse( + ClusterName clusterName, + List nodes, + List failures, + Map clusterStats + ) { + super(clusterName, nodes, failures); + this.clusterStats = new TreeMap<>(clusterStats); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeMap(clusterStats); + } + + @Override + public void writeNodesTo(StreamOutput out, List nodes) throws IOException { + out.writeList(nodes); + } + + @Override + public List readNodesFrom(StreamInput in) throws IOException { + return in.readList(NeuralStatsNodeResponse::readStats); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + // Return cluster level stats + Map nestedClusterStats = convertFlatToNestedMap(new TreeMap<>(clusterStats)); + buildNestedMapXContent(builder, nestedClusterStats); + + // Return node level stats + String nodeId; + DiscoveryNode node; + builder.startObject(NODES_KEY); + for (NeuralStatsNodeResponse neuralStatsResponse : getNodes()) { + node = neuralStatsResponse.getNode(); + nodeId = node.getId(); + builder.startObject(nodeId); + Map nestedMap = convertFlatToNestedMap(new TreeMap<>(neuralStatsResponse.getStatsMap())); + buildNestedMapXContent(builder, nestedMap); + builder.endObject(); + } + builder.endObject(); + return builder; + } + + private void buildNestedMapXContent(XContentBuilder builder, Map map) throws IOException { + for (Map.Entry entry : map.entrySet()) { + if (entry.getValue() instanceof Map) { + builder.startObject(entry.getKey()); + buildNestedMapXContent(builder, (Map) entry.getValue()); + builder.endObject(); + } else { + builder.field(entry.getKey(), entry.getValue()); + } + } + } + + private Map convertFlatToNestedMap(Map map) { + Map nestedMap = new TreeMap<>(); + for (Map.Entry entry : map.entrySet()) { + putNested(nestedMap, entry.getKey(), entry.getValue()); + } + return nestedMap; + } + + private void putNested(Map map, String path, Object value) { + String[] parts = path.split("\\."); + Map current = map; + for (int i = 0; i < parts.length - 1; i++) { + current = (Map) current.computeIfAbsent(parts[i], k -> new HashMap()); + } + current.put(parts[parts.length - 1], value); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsTransportAction.java b/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsTransportAction.java new file mode 100644 index 0000000000..7051dfc5d6 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsTransportAction.java @@ -0,0 +1,110 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.transport; + +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.nodes.TransportNodesAction; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.neuralsearch.stats.DerivedStats; +import org.opensearch.neuralsearch.stats.NeuralStats; +import org.opensearch.transport.TransportService; +import org.opensearch.threadpool.ThreadPool; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * NeuralStatsTransportAction contains the logic to extract the stats from the nodes + */ +public class NeuralStatsTransportAction extends TransportNodesAction< + NeuralStatsRequest, + NeuralStatsResponse, + NeuralStatsNodeRequest, + NeuralStatsNodeResponse> { + + private NeuralStats neuralStats; + + /** + * Constructor + * + * @param threadPool ThreadPool to use + * @param clusterService ClusterService + * @param transportService TransportService + * @param actionFilters Action Filters + * @param neuralStats NeuralStats object + */ + @Inject + public NeuralStatsTransportAction( + ThreadPool threadPool, + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters, + NeuralStats neuralStats + ) { + super( + NeuralStatsAction.NAME, + threadPool, + clusterService, + transportService, + actionFilters, + NeuralStatsRequest::new, + NeuralStatsNodeRequest::new, + ThreadPool.Names.MANAGEMENT, + NeuralStatsNodeResponse.class + ); + // TODO : inject rather than singleton here + // this.neuralStats = neuralStats; + this.neuralStats = NeuralStats.instance(); + } + + @Override + protected NeuralStatsResponse newResponse( + NeuralStatsRequest request, + List responses, + List failures + ) { + + Map clusterStats = new HashMap<>(); + + clusterStats.put("cluster_level_stat_1", "Yay!"); + // for (String statName : neuralStats.getStats().keySet()) { + // clusterStats.put(statName, neuralStats.getStats().get(statName).getValue()); + // }' + DerivedStats derivedStats = DerivedStats.instance(); + clusterStats.putAll(derivedStats.addDerivedStats(responses.stream().map(NeuralStatsNodeResponse::getStatsMap).toList())); + + System.out.println(clusterStats); + + return new NeuralStatsResponse(clusterService.getClusterName(), responses, failures, clusterStats); + } + + @Override + protected NeuralStatsNodeRequest newNodeRequest(NeuralStatsRequest request) { + return new NeuralStatsNodeRequest(request); + } + + @Override + protected NeuralStatsNodeResponse newNodeResponse(StreamInput in) throws IOException { + return new NeuralStatsNodeResponse(in); + } + + @Override + protected NeuralStatsNodeResponse nodeOperation(NeuralStatsNodeRequest request) { + // Reads from NeuralStats to node level stats on an individual node + Map statValues = new HashMap<>(); + + for (String statName : neuralStats.getStats().keySet()) { + statValues.put(statName, neuralStats.getStats().get(statName).getValue()); + } + System.out.println("Transport_Action node operation for (ta_node_pasta)" + clusterService.localNode()); + System.out.println(statValues); + return new NeuralStatsNodeResponse(clusterService.localNode(), statValues); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/rest/RestNeuralStatsHandlerIT.java b/src/test/java/org/opensearch/neuralsearch/rest/RestNeuralStatsHandlerIT.java new file mode 100644 index 0000000000..9d8a9370da --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/rest/RestNeuralStatsHandlerIT.java @@ -0,0 +1,248 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.rest; + +import lombok.extern.log4j.Log4j2; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.junit.After; +import org.junit.Before; +import org.opensearch.client.Response; +import org.opensearch.client.Request; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.neuralsearch.BaseNeuralSearchIT; +import org.opensearch.neuralsearch.plugin.NeuralSearch; +import org.opensearch.neuralsearch.query.NeuralQueryBuilder; +import org.opensearch.neuralsearch.stats.NeuralStats; + +import java.io.IOException; +import java.net.URISyntaxException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +@Log4j2 +public class RestNeuralStatsHandlerIT extends BaseNeuralSearchIT { + private static final String INDEX_NAME = "text_embedding_index"; + + private static final String INGEST_PIPELINE_NAME = "ingest-pipeline-1"; + private static final String SEARCH_PIPELINE_NAME = "search-pipeline-1"; + protected static final String QUERY_TEXT = "hello"; + protected static final String LEVEL_1_FIELD = "nested_passages"; + protected static final String LEVEL_2_FIELD = "level_2"; + protected static final String LEVEL_3_FIELD_TEXT = "level_3_text"; + protected static final String LEVEL_3_FIELD_CONTAINER = "level_3_container"; + protected static final String LEVEL_3_FIELD_EMBEDDING = "level_3_embedding"; + protected static final String TEXT_FIELD_VALUE_1 = "hello"; + protected static final String TEXT_FIELD_VALUE_2 = "clown"; + protected static final String TEXT_FIELD_VALUE_3 = "abc"; + private final String INGEST_DOC1 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc1.json").toURI())); + private final String INGEST_DOC2 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc2.json").toURI())); + private final String INGEST_DOC3 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc3.json").toURI())); + private final String INGEST_DOC4 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc4.json").toURI())); + private final String INGEST_DOC5 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc5.json").toURI())); + + private final String TITLE_KNN_FIELD = "title_knn"; + + public RestNeuralStatsHandlerIT() throws IOException, URISyntaxException {} + + @Before + public void setUp() throws Exception { + super.setUp(); + updateClusterSettings(); + } + + @After + public void tearDown() throws Exception { + super.tearDown(); + + executeClearNeuralStatRequest(Collections.emptyList()); + } + + public void test_happyCase_textEmbedding() throws Exception { + NeuralStats.instance().resetStats(); + + String modelId = null; + try { + modelId = uploadTextEmbeddingModel(); + loadModel(modelId); + createPipelineProcessor(modelId, INGEST_PIPELINE_NAME, ProcessorType.TEXT_EMBEDDING); + createIndexWithPipeline(INDEX_NAME, "IndexMappings.json", INGEST_PIPELINE_NAME); + ingestDocument(INDEX_NAME, INGEST_DOC1); + ingestDocument(INDEX_NAME, INGEST_DOC2); + ingestDocument(INDEX_NAME, INGEST_DOC3); + assertEquals(3, getDocCount(INDEX_NAME)); + + Response response = executeNeuralStatRequest(new ArrayList<>(), new ArrayList<>()); + String responseBody = EntityUtils.toString(response.getEntity()); + List> nodesStats = parseNodeStatsResponse(responseBody); + + log.info(nodesStats); + assertEquals(3, getNestedValue(nodesStats.getFirst(), "ingest_processor.text_embedding.executions")); + + } finally { + wipeOfTestResources(INDEX_NAME, INGEST_PIPELINE_NAME, modelId, null); + } + } + + public void test_happyCase_clearNeuralStats() throws Exception { + NeuralStats.instance().resetStats(); + + String modelId = null; + try { + modelId = uploadTextEmbeddingModel(); + loadModel(modelId); + createPipelineProcessor(modelId, INGEST_PIPELINE_NAME, ProcessorType.TEXT_EMBEDDING); + createIndexWithPipeline(INDEX_NAME, "IndexMappings.json", INGEST_PIPELINE_NAME); + ingestDocument(INDEX_NAME, INGEST_DOC1); + ingestDocument(INDEX_NAME, INGEST_DOC2); + assertEquals(2, getDocCount(INDEX_NAME)); + + Response response = executeNeuralStatRequest(new ArrayList<>(), new ArrayList<>()); + String responseBody = EntityUtils.toString(response.getEntity()); + List> nodesStats = parseNodeStatsResponse(responseBody); + + log.info(nodesStats); + assertEquals(2, getNestedValue(nodesStats.getFirst(), "ingest_processor.text_embedding.executions")); + + executeClearNeuralStatRequest(Collections.emptyList()); + + response = executeNeuralStatRequest(new ArrayList<>(), new ArrayList<>()); + responseBody = EntityUtils.toString(response.getEntity()); + nodesStats = parseNodeStatsResponse(responseBody); + + log.info(nodesStats); + assertEquals(0, getNestedValue(nodesStats.getFirst(), "ingest_processor.text_embedding.executions")); + + } finally { + wipeOfTestResources(INDEX_NAME, INGEST_PIPELINE_NAME, modelId, null); + } + } + + public void test_happyCase_neuralQueryEnricher() throws Exception { + NeuralStats.instance().resetStats(); + + String modelId = null; + try { + modelId = prepareModel(); + createSearchRequestProcessor(modelId, SEARCH_PIPELINE_NAME); + createPipelineProcessor(modelId, INGEST_PIPELINE_NAME, ProcessorType.TEXT_EMBEDDING); + createIndexWithPipeline(INDEX_NAME, "IndexMappings.json", INGEST_PIPELINE_NAME); + + ingestDocument(INDEX_NAME, INGEST_DOC1); + ingestDocument(INDEX_NAME, INGEST_DOC2); + + updateIndexSettings(INDEX_NAME, Settings.builder().put("index.search.default_pipeline", SEARCH_PIPELINE_NAME)); + NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.builder() + .fieldName(TITLE_KNN_FIELD) + .queryText("Second") + .k(10) + .build(); + + Map response = search(INDEX_NAME, neuralQueryBuilder, 2); + log.info(response); + assertFalse(response.isEmpty()); + + // Stats request + Response statResponse = executeNeuralStatRequest(new ArrayList<>(), new ArrayList<>()); + String responseBody = EntityUtils.toString(statResponse.getEntity()); + Map nodeStats = parseNodeStatsResponse(responseBody).getFirst(); + + log.info(nodeStats); + assertEquals(2, getNestedValue(nodeStats, "ingest_processor.text_embedding.executions")); + assertEquals(1, getNestedValue(nodeStats, "search_processor.neural_query_enricher.executions")); + } finally { + wipeOfTestResources(INDEX_NAME, INGEST_PIPELINE_NAME, modelId, SEARCH_PIPELINE_NAME); + } + } + + protected String uploadTextEmbeddingModel() throws Exception { + String requestBody = Files.readString(Path.of(classLoader.getResource("processor/UploadModelRequestBody.json").toURI())); + return registerModelGroupAndUploadModel(requestBody); + } + + protected Response executeNeuralStatRequest(List nodeIds, List stats) throws IOException { + String nodePrefix = ""; + if (!nodeIds.isEmpty()) { + nodePrefix = "/" + String.join(",", nodeIds); + } + + String statsSuffix = ""; + if (!stats.isEmpty()) { + statsSuffix = "/" + String.join(",", stats); + } + + Request request = new Request("GET", NeuralSearch.NEURAL_BASE_URI + nodePrefix + "/stats" + statsSuffix); + + Response response = client().performRequest(request); + assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + return response; + } + + protected Response executeClearNeuralStatRequest(List nodeIds) throws IOException { + String nodePrefix = ""; + if (!nodeIds.isEmpty()) { + nodePrefix = "/" + String.join(",", nodeIds); + } + + Request request = new Request("GET", NeuralSearch.NEURAL_BASE_URI + nodePrefix + "/stats/" + RestNeuralStatsHandler.CLEAR_PARAM); + + Response response = client().performRequest(request); + assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + return response; + } + + protected Map parseStatsResponse(String responseBody) throws IOException { + Map responseMap = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), responseBody).map(); + responseMap.remove("cluster_name"); + responseMap.remove("_nodes"); + return responseMap; + } + + protected List> parseNodeStatsResponse(String responseBody) throws IOException { + @SuppressWarnings("unchecked") + Map responseMap = (Map) createParser( + MediaTypeRegistry.getDefaultMediaType().xContent(), + responseBody + ).map().get("nodes"); + + @SuppressWarnings("unchecked") + List> nodeResponses = responseMap.keySet() + .stream() + .map(key -> (Map) responseMap.get(key)) + .collect(Collectors.toList()); + + return nodeResponses; + } + + public Object getNestedValue(Map map, String path) { + String[] keys = path.split("\\."); + return getNestedValueHelper(map, keys, 0); + } + + private Object getNestedValueHelper(Map map, String[] keys, int depth) { + if (map == null) { + return null; + } + + Object value = map.get(keys[depth]); + + if (depth == keys.length - 1) { + return value; + } + + if (value instanceof Map) { + Map nestedMap = (Map) value; + return getNestedValueHelper(nestedMap, keys, depth + 1); + } + + return null; + } +}