Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,19 @@
import java.util.Optional;
import java.util.function.Supplier;

import com.google.common.collect.ImmutableList;
import org.opensearch.action.ActionRequest;
import org.opensearch.client.Client;
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
import org.opensearch.cluster.node.DiscoveryNodes;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.IndexScopedSettings;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.settings.SettingsFilter;
import org.opensearch.common.util.FeatureFlags;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.env.Environment;
Expand Down Expand Up @@ -55,7 +62,12 @@
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder;
import org.opensearch.neuralsearch.query.ext.RerankSearchExtBuilder;
import org.opensearch.neuralsearch.rest.RestNeuralStatsHandler;
import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher;
import org.opensearch.neuralsearch.transport.ClearNeuralStatsAction;
import org.opensearch.neuralsearch.transport.ClearNeuralStatsTransportAction;
import org.opensearch.neuralsearch.transport.NeuralStatsAction;
import org.opensearch.neuralsearch.transport.NeuralStatsTransportAction;
import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;
import org.opensearch.plugins.ActionPlugin;
import org.opensearch.plugins.ExtensiblePlugin;
Expand All @@ -64,6 +76,8 @@
import org.opensearch.plugins.SearchPipelinePlugin;
import org.opensearch.plugins.SearchPlugin;
import org.opensearch.repositories.RepositoriesService;
import org.opensearch.rest.RestController;
import org.opensearch.rest.RestHandler;
import org.opensearch.script.ScriptService;
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;
import org.opensearch.search.pipeline.SearchRequestProcessor;
Expand All @@ -85,6 +99,7 @@ public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin,
private final ScoreNormalizationFactory scoreNormalizationFactory = new ScoreNormalizationFactory();
private final ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory();
public static final String EXPLANATION_RESPONSE_KEY = "explanation_response";
public static final String NEURAL_BASE_URI = "/_plugins/_neural";

@Override
public Collection<Object> createComponents(
Expand Down Expand Up @@ -117,6 +132,28 @@ public List<QuerySpec<?>> getQueries() {
);
}

@Override
public List<RestHandler> getRestHandlers(
Settings settings,
RestController restController,
ClusterSettings clusterSettings,
IndexScopedSettings indexScopedSettings,
SettingsFilter settingsFilter,
IndexNameExpressionResolver indexNameExpressionResolver,
Supplier<DiscoveryNodes> nodesInCluster
) {
RestNeuralStatsHandler restNeuralStatsHandler = new RestNeuralStatsHandler();
return ImmutableList.of(restNeuralStatsHandler);
}

@Override
public List<ActionHandler<? extends ActionRequest, ? extends ActionResponse>> getActions() {
return Arrays.asList(
new ActionHandler<>(NeuralStatsAction.INSTANCE, NeuralStatsTransportAction.class),
new ActionHandler<>(ClearNeuralStatsAction.INSTANCE, ClearNeuralStatsTransportAction.class)
);
}

@Override
public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settings) {
return List.of(HybridQueryExecutor.getExecutorBuilder(settings));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.neuralsearch.query.visitor.NeuralSearchQueryVisitor;
import org.opensearch.neuralsearch.stats.NeuralStats;
import org.opensearch.neuralsearch.stats.names.StatName;
import org.opensearch.search.pipeline.AbstractProcessor;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchRequestProcessor;
Expand Down Expand Up @@ -71,6 +73,8 @@ public SearchRequest processRequest(SearchRequest searchRequest) {
if (queryBuilder != null) {
queryBuilder.visit(new NeuralSearchQueryVisitor(modelId, neuralFieldDefaultIdMap));
}

NeuralStats.record(StatName.NEURAL_QUERY_ENRICHER_PROCESSOR_EXECUTIONS).increment();
return searchRequest;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;

import lombok.extern.log4j.Log4j2;
import org.opensearch.neuralsearch.stats.NeuralStats;
import org.opensearch.neuralsearch.stats.names.StatName;

/**
* This processor is used for user input data text embedding processing, model_id can be used to indicate which model user use,
Expand Down Expand Up @@ -47,6 +49,8 @@ public void doExecute(
List<String> inferenceList,
BiConsumer<IngestDocument, Exception> handler
) {
NeuralStats.record(StatName.TEXT_EMBEDDING_PROCESSOR_EXECUTIONS).increment();

mlCommonsClientAccessor.inferenceSentences(this.modelId, inferenceList, ActionListener.wrap(vectors -> {
setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors);
handler.accept(ingestDocument, null);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.rest;

import com.google.common.collect.ImmutableList;
import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.apache.commons.lang.StringUtils;
import org.opensearch.client.node.NodeClient;
import org.opensearch.neuralsearch.stats.NeuralStatsInput;
import org.opensearch.neuralsearch.transport.ClearNeuralStatsAction;
import org.opensearch.neuralsearch.transport.ClearNeuralStatsRequest;
import org.opensearch.neuralsearch.transport.NeuralStatsAction;
import org.opensearch.neuralsearch.transport.NeuralStatsRequest;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.action.RestActions;

import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import static org.opensearch.neuralsearch.plugin.NeuralSearch.NEURAL_BASE_URI;

@Log4j2
@AllArgsConstructor
public class RestNeuralStatsHandler extends BaseRestHandler {
private static final String NAME = "neural_stats_action";
public static final String CLEAR_PARAM = "_clear";

@Override
public String getName() {
return NAME;
}

@Override
public List<Route> routes() {
return ImmutableList.of(
new Route(RestRequest.Method.GET, NEURAL_BASE_URI + "/{nodeId}/stats/"),
new Route(RestRequest.Method.GET, NEURAL_BASE_URI + "/{nodeId}/stats/{stat}"),
new Route(RestRequest.Method.GET, NEURAL_BASE_URI + "/stats/"),
new Route(RestRequest.Method.GET, NEURAL_BASE_URI + "/stats/{stat}")
);
}

@Override
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) {
if (request.param("stat", "").equals(CLEAR_PARAM)) {
// TODO : Hacky, possible collisions. Should be refactored into separate endpoint later
String[] nodeIdsArr = null;
String nodesIdsStr = request.param("nodeId");
if (StringUtils.isNotEmpty(nodesIdsStr)) {
nodeIdsArr = nodesIdsStr.split(",");
}

ClearNeuralStatsRequest clearNeuralStatsRequest = new ClearNeuralStatsRequest(nodeIdsArr);
clearNeuralStatsRequest.timeout(request.param("timeout"));

return channel -> client.execute(
ClearNeuralStatsAction.INSTANCE,
clearNeuralStatsRequest,
new RestActions.NodesResponseRestListener<>(channel)
);
} else {
// Read inputs and convert to BaseNodesRequest with correct info configured
NeuralStatsRequest neuralStatsRequest = getRequest(request);

return channel -> client.execute(
NeuralStatsAction.INSTANCE,
neuralStatsRequest,
new RestActions.NodesResponseRestListener<>(channel)
);
}

}

/**
* Creates a NeuralStatsRequest from a RestRequest
*
* @param request Rest request
* @return NeuralStatsRequest
*/
private NeuralStatsRequest getRequest(RestRequest request) {
// parse the nodes the user wants to query
String[] nodeIdsArr = null;
String nodesIdsStr = request.param("nodeId");
if (StringUtils.isNotEmpty(nodesIdsStr)) {
nodeIdsArr = nodesIdsStr.split(",");
}

NeuralStatsRequest neuralStatsRequest = new NeuralStatsRequest(nodeIdsArr, new NeuralStatsInput());
neuralStatsRequest.timeout(request.param("timeout"));

// parse the stats the customer wants to see
Set<String> statsSet = null;
String statsStr = request.param("stat");
if (StringUtils.isNotEmpty(statsStr)) {
statsSet = new HashSet<>(Arrays.asList(statsStr.split(",")));
}

if (statsSet == null) {

} else if (statsSet.size() == 1 && statsSet.contains("_all")) {

} else if (statsSet.contains(NeuralStatsRequest.ALL_STATS_KEY)) {
throw new IllegalArgumentException("Request " + request.path() + " contains _all and individual stats");
} else {
// Validate NeuralStats input is valid
// Set<String> invalidStats = new TreeSet<>();
// for (String stat : statsSet) {
// // validate request contains valid stats
// // if (!neuralStatsRequest.addStat(stat)) {
// // invalidStats.add(stat);
// // }
// }
//
// if (!invalidStats.isEmpty()) {
// throw new IllegalArgumentException(unrecognized(request, invalidStats, neuralStatsRequest.getNeuralStatsInput(), "stat"));
// }

}
log.info("pasta");
log.info(statsSet);
log.info(neuralStatsRequest.getNeuralStatsInput());
return neuralStatsRequest;
}
}
73 changes: 73 additions & 0 deletions src/main/java/org/opensearch/neuralsearch/stats/DerivedStats.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.stats;

import org.opensearch.neuralsearch.stats.suppliers.DerivedSupplier;
import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentSkipListMap;
import java.util.function.Function;

public class DerivedStats {
private static final String AGG_KEY_PREFIX = "all_nodes.";
private static DerivedStats INSTANCE;

public static DerivedStats instance() {
if (INSTANCE == null) {
INSTANCE = new DerivedStats();
}
return INSTANCE;
}

private final Map<String, NeuralStat<?>> derivedStatsMap;
private Map<String, Long> aggregatedNodeResponse;

public DerivedStats() {
this.derivedStatsMap = new ConcurrentSkipListMap<>();
register("derived.cluster_version", DerivedStats::clusterVersion);
}

public Map<String, Long> aggregateNodesResponses(List<Map<String, Long>> nodeResponses) {
Map<String, Long> summedMap = new HashMap<>();
for (Map<String, Long> map : nodeResponses) {
for (Map.Entry<String, Long> entry : map.entrySet()) {
summedMap.merge(AGG_KEY_PREFIX + entry.getKey(), entry.getValue(), Long::sum);
}
}
return summedMap;
}

public Map<String, Object> addDerivedStats(List<Map<String, Long>> nodeResponses) {
// Reference to provide derived methods access to node Responses
this.aggregatedNodeResponse = aggregateNodesResponses(nodeResponses);

Map<String, Object> computedDerivedStats = new TreeMap<>();
for (Map.Entry<String, NeuralStat<?>> neuralStatEntry : derivedStatsMap.entrySet()) {
computedDerivedStats.put(neuralStatEntry.getKey(), neuralStatEntry.getValue().getValue());
}

computedDerivedStats.putAll(aggregatedNodeResponse);
// Reset reference to not store
this.aggregatedNodeResponse = null;
return computedDerivedStats;
}

private void register(String statPath, Function<DerivedStats, ?> derivedMethod) {
if (derivedStatsMap.containsKey(statPath)) {
// Validation error here
return;
}
NeuralStat<?> neuralStat = new NeuralStat<>(new DerivedSupplier<>(this, derivedMethod));
derivedStatsMap.put(statPath, neuralStat);
}

private String clusterVersion() {
return NeuralSearchClusterUtil.instance().getClusterMinVersion().toString();
}
}
39 changes: 39 additions & 0 deletions src/main/java/org/opensearch/neuralsearch/stats/NeuralStat.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.stats;

import org.opensearch.neuralsearch.stats.suppliers.CounterSupplier;

import java.util.function.Supplier;

public class NeuralStat<T> {
private Supplier<T> supplier;

public NeuralStat(Supplier<T> supplier) {
this.supplier = supplier;
}

public T getValue() {
return supplier.get();
}

/**
* Increments the supplier if it can be incremented
*/
public void increment() {
if (supplier instanceof CounterSupplier) {
((CounterSupplier) supplier).increment();
}
}

/**
* Decrease the supplier if it can be decreased.
*/
public void decrement() {
if (supplier instanceof CounterSupplier) {
((CounterSupplier) supplier).decrement();
}
}
}
Loading