diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java index 9a4ef6966f500..6cd9af6407a8b 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java @@ -16,8 +16,6 @@ import org.apache.lucene.search.join.ToChildBlockJoinQuery; import org.apache.lucene.util.SetOnce; import org.elasticsearch.TransportVersion; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.client.internal.Client; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.lucene.search.Queries; @@ -31,7 +29,6 @@ import org.elasticsearch.index.query.MatchNoneQueryBuilder; import org.elasticsearch.index.query.NestedQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; -import org.elasticsearch.index.query.QueryRewriteAsyncAction; import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.index.query.ToChildBlockJoinQueryBuilder; @@ -443,18 +440,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext ctx) throws IOException { } if (queryVectorBuilder != null) { SetOnce toSet = new SetOnce<>(); - ctx.registerUniqueAsyncAction(new QueryVectorBuilderAsyncAction(queryVectorBuilder), v -> { - toSet.set(v); - if (v == null) { - throw new IllegalArgumentException( - format( - "[%s] with name [%s] returned null query_vector", - QUERY_VECTOR_BUILDER_FIELD.getPreferredName(), - queryVectorBuilder.getWriteableName() - ) - ); - } - }); + ctx.registerUniqueAsyncAction(new QueryVectorBuilderAsyncAction(queryVectorBuilder), toSet::set); return new KnnVectorQueryBuilder( fieldName, queryVector, @@ -679,27 +665,4 @@ public KnnVectorQueryBuilder 
setAutoPrefilteringEnabled(boolean isAutoPrefilteri
         this.isAutoPrefilteringEnabled = isAutoPrefilteringEnabled;
         return this;
     }
-
-    private static final class QueryVectorBuilderAsyncAction extends QueryRewriteAsyncAction {
-        private final QueryVectorBuilder queryVectorBuilder;
-
-        private QueryVectorBuilderAsyncAction(QueryVectorBuilder queryVectorBuilder) {
-            this.queryVectorBuilder = Objects.requireNonNull(queryVectorBuilder);
-        }
-
-        @Override
-        protected void execute(Client client, ActionListener listener) {
-            queryVectorBuilder.buildVector(client, listener);
-        }
-
-        @Override
-        public int doHashCode() {
-            return Objects.hash(queryVectorBuilder);
-        }
-
-        @Override
-        public boolean doEquals(QueryVectorBuilderAsyncAction other) {
-            return Objects.equals(queryVectorBuilder, other.queryVectorBuilder);
-        }
-    }
 }
diff --git a/server/src/main/java/org/elasticsearch/search/vectors/QueryVectorBuilderAsyncAction.java b/server/src/main/java/org/elasticsearch/search/vectors/QueryVectorBuilderAsyncAction.java
new file mode 100644
index 0000000000000..fa3e680076b46
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/search/vectors/QueryVectorBuilderAsyncAction.java
@@ -0,0 +1,69 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.search.vectors;
+
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.client.internal.Client;
+import org.elasticsearch.index.query.QueryRewriteAsyncAction;
+
+import java.util.Objects;
+
+import static org.elasticsearch.common.Strings.format;
+import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.QUERY_VECTOR_BUILDER_FIELD;
+
+/**
+ * Asynchronous query-rewrite action that asks a {@link QueryVectorBuilder} to build its query vector.
+ * <p>
+ * A {@code null} vector is never passed on as a result; it is rejected with an {@link IllegalArgumentException}.
+ * Two actions are equal when their builders are equal.
+ */
+public final class QueryVectorBuilderAsyncAction extends QueryRewriteAsyncAction<float[], QueryVectorBuilderAsyncAction> {
+    private final QueryVectorBuilder queryVectorBuilder;
+
+    /**
+     * @param queryVectorBuilder the builder that produces the query vector, must not be {@code null}
+     */
+    public QueryVectorBuilderAsyncAction(QueryVectorBuilder queryVectorBuilder) {
+        this.queryVectorBuilder = Objects.requireNonNull(queryVectorBuilder);
+    }
+
+    /**
+     * Builds the vector through {@code queryVectorBuilder.buildVector} and hands the result to {@code listener}.
+     *
+     * @param client   client passed through to the builder
+     * @param listener receives the non-null vector; wrapped with {@code delegateFailureAndWrap} before it reaches the builder
+     */
+    @Override
+    protected void execute(Client client, ActionListener<float[]> listener) {
+        queryVectorBuilder.buildVector(client, listener.delegateFailureAndWrap((l, v) -> {
+            if (v == null) {
+                // NOTE(review): relies on delegateFailureAndWrap routing this exception to the listener's onFailure.
+                throw new IllegalArgumentException(
+                    format(
+                        "[%s] with name [%s] returned null query_vector",
+                        QUERY_VECTOR_BUILDER_FIELD.getPreferredName(),
+                        queryVectorBuilder.getWriteableName()
+                    )
+                );
+            }
+            l.onResponse(v);
+        }));
+    }
+
+    @Override
+    public int doHashCode() {
+        return Objects.hash(queryVectorBuilder);
+    }
+
+    @Override
+    public boolean doEquals(QueryVectorBuilderAsyncAction other) {
+        return Objects.equals(queryVectorBuilder, other.queryVectorBuilder);
+    }
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/SparseInferenceRewriteAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/SparseInferenceRewriteAction.java
new file mode 100644
index 0000000000000..9942c3b405822
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/SparseInferenceRewriteAction.java
@@ -0,0 +1,114 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.search;
+
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.client.internal.Client;
+import org.elasticsearch.index.query.QueryRewriteAsyncAction;
+import org.elasticsearch.inference.InferenceResults;
+import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
+import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
+import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate;
+
+import java.util.List;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
+import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
+
+/**
+ * Asynchronous query-rewrite action that runs text-expansion inference for a single query string.
+ * <p>
+ * The action succeeds only when the inference response holds exactly one {@link TextExpansionResults};
+ * every other outcome is reported to the listener as a failure. Equality is based on the inference id and query text.
+ */
+public class SparseInferenceRewriteAction extends QueryRewriteAsyncAction<TextExpansionResults, SparseInferenceRewriteAction> {
+    private final String inferenceId;
+    private final String query;
+
+    /**
+     * @param inferenceId identifier passed to the inference request; also named in the type-mismatch error message
+     * @param query       text sent as the single inference input
+     */
+    SparseInferenceRewriteAction(String inferenceId, String query) {
+        this.inferenceId = inferenceId;
+        this.query = query;
+    }
+
+    /**
+     * Sends a high-priority {@link CoordinatedInferenceAction} request with the ML origin and converts its response.
+     *
+     * @param client           client used to execute the inference action
+     * @param responseListener receives the {@link TextExpansionResults}, or a failure when the response has no result,
+     *                         more than one result, a warning result, or a result of another type
+     */
+    @Override
+    protected void execute(Client client, ActionListener<TextExpansionResults> responseListener) {
+        // TODO: Move this class to `server` and update to use InferenceAction.Request
+        CoordinatedInferenceAction.Request inferRequest = CoordinatedInferenceAction.Request.forTextInput(
+            inferenceId,
+            List.of(query),
+            TextExpansionConfigUpdate.EMPTY_UPDATE,
+            false,
+            null
+        );
+
+        inferRequest.setHighPriority(true);
+        inferRequest.setPrefixType(TrainedModelPrefixStrings.PrefixType.SEARCH);
+
+        executeAsyncWithOrigin(
+            client,
+            ML_ORIGIN,
+            CoordinatedInferenceAction.INSTANCE,
+            inferRequest,
+            responseListener.delegateFailureAndWrap((listener, inferenceResponse) -> {
+                List<InferenceResults> inferenceResults = inferenceResponse.getInferenceResults();
+                if (inferenceResults.isEmpty()) {
+                    listener.onFailure(new IllegalStateException("inference response contain no results"));
+                    return;
+                }
+                if (inferenceResults.size() > 1) {
+                    listener.onFailure(new IllegalStateException("inference response should contain only one result"));
+                    return;
+                }
+
+                // Exactly one result remains at this point; fetch it once and dispatch on its runtime type.
+                InferenceResults result = inferenceResults.getFirst();
+                if (result instanceof TextExpansionResults textExpansionResults) {
+                    listener.onResponse(textExpansionResults);
+                } else if (result instanceof WarningInferenceResults warning) {
+                    listener.onFailure(new IllegalStateException(warning.getWarning()));
+                } else {
+                    listener.onFailure(
+                        new IllegalArgumentException(
+                            "expected a result of type ["
+                                + TextExpansionResults.NAME
+                                + "] received ["
+                                + result.getWriteableName()
+                                + "]. Is ["
+                                + inferenceId
+                                + "] a compatible model?"
+                        )
+                    );
+                }
+            })
+        );
+    }
+
+    @Override
+    public int doHashCode() {
+        return Objects.hash(inferenceId, query);
+    }
+
+    @Override
+    public boolean doEquals(SparseInferenceRewriteAction other) {
+        return Objects.equals(inferenceId, other.inferenceId) && Objects.equals(query, other.query);
+    }
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilder.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilder.java
index ebf422072c297..6f34dcfa71711 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilder.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilder.java
@@ -11,8 +11,6 @@
 import org.apache.lucene.search.Query;
 import org.apache.lucene.util.SetOnce;
 import org.elasticsearch.TransportVersion;
-import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.client.internal.Client;
 import org.elasticsearch.common.ParsingException;
 import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput; @@ -22,20 +20,14 @@ import org.elasticsearch.index.mapper.vectors.TokenPruningConfig; import org.elasticsearch.index.query.AbstractQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; -import org.elasticsearch.index.query.QueryRewriteAsyncAction; import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.SearchExecutionContext; -import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.inference.WeightedToken; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; -import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate; import java.io.IOException; import java.util.ArrayList; @@ -45,8 +37,6 @@ import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; -import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; -import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; public class SparseVectorQueryBuilder extends AbstractQueryBuilder { private static final MatchNoDocsQuery EMPTY_QUERY_VECTORS = new MatchNoDocsQuery("Empty query vectors"); @@ -283,78 +273,6 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) { return new SparseVectorQueryBuilder(this, textExpansionResultsSupplier); } - private static class SparseInferenceRewriteAction extends QueryRewriteAsyncAction { - - private final 
String inferenceId; - private final String query; - - SparseInferenceRewriteAction(String inferenceId, String query) { - this.inferenceId = inferenceId; - this.query = query; - } - - @Override - protected void execute(Client client, ActionListener responseListener) { - // TODO: Move this class to `server` and update to use InferenceAction.Request - CoordinatedInferenceAction.Request inferRequest = CoordinatedInferenceAction.Request.forTextInput( - inferenceId, - List.of(query), - TextExpansionConfigUpdate.EMPTY_UPDATE, - false, - null - ); - - inferRequest.setHighPriority(true); - inferRequest.setPrefixType(TrainedModelPrefixStrings.PrefixType.SEARCH); - - executeAsyncWithOrigin( - client, - ML_ORIGIN, - CoordinatedInferenceAction.INSTANCE, - inferRequest, - responseListener.delegateFailureAndWrap((listener, inferenceResponse) -> { - List inferenceResults = inferenceResponse.getInferenceResults(); - if (inferenceResults.isEmpty()) { - listener.onFailure(new IllegalStateException("inference response contain no results")); - return; - } - if (inferenceResults.size() > 1) { - listener.onFailure(new IllegalStateException("inference response should contain only one result")); - return; - } - - if (inferenceResults.getFirst() instanceof TextExpansionResults textExpansionResults) { - listener.onResponse(textExpansionResults); - } else if (inferenceResults.getFirst() instanceof WarningInferenceResults warning) { - listener.onFailure(new IllegalStateException(warning.getWarning())); - } else { - listener.onFailure( - new IllegalArgumentException( - "expected a result of type [" - + TextExpansionResults.NAME - + "] received [" - + inferenceResults.getFirst().getWriteableName() - + "]. Is [" - + inferenceId - + "] a compatible model?" 
- ) - ); - } - }) - ); - } - - @Override - public int doHashCode() { - return Objects.hash(inferenceId, query); - } - - @Override - public boolean doEquals(SparseInferenceRewriteAction other) { - return Objects.equals(inferenceId, other.inferenceId) && Objects.equals(query, other.query); - } - } - @Override protected boolean doEquals(SparseVectorQueryBuilder other) { return Objects.equals(fieldName, other.fieldName)