diff --git a/server/src/main/resources/transport/definitions/referable/get_inference_fields_action_as_indices_action.csv b/server/src/main/resources/transport/definitions/referable/get_inference_fields_action_as_indices_action.csv new file mode 100644 index 0000000000000..a8a5fda7760c2 --- /dev/null +++ b/server/src/main/resources/transport/definitions/referable/get_inference_fields_action_as_indices_action.csv @@ -0,0 +1 @@ +9260000,9250001 diff --git a/server/src/main/resources/transport/upper_bounds/9.3.csv b/server/src/main/resources/transport/upper_bounds/9.3.csv index 94a47e0878c87..ea5804219d2b8 100644 --- a/server/src/main/resources/transport/upper_bounds/9.3.csv +++ b/server/src/main/resources/transport/upper_bounds/9.3.csv @@ -1 +1 @@ -initial_9.3.0,9250000 +get_inference_fields_action_as_indices_action,9250001 diff --git a/server/src/main/resources/transport/upper_bounds/9.4.csv b/server/src/main/resources/transport/upper_bounds/9.4.csv new file mode 100644 index 0000000000000..d3954517fe8c3 --- /dev/null +++ b/server/src/main/resources/transport/upper_bounds/9.4.csv @@ -0,0 +1 @@ +get_inference_fields_action_as_indices_action,9260000 diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceFieldsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceFieldsInternalAction.java similarity index 83% rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceFieldsAction.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceFieldsInternalAction.java index ddd1cbce9aa6d..ca1dcfd14a95a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceFieldsAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceFieldsInternalAction.java @@ -12,6 +12,7 @@ import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionResponse; import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.IndicesRequest; import org.elasticsearch.action.RemoteClusterActionType; import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.cluster.metadata.IndexMetadata; @@ -24,11 +25,11 @@ import org.elasticsearch.inference.InferenceResults; import java.io.IOException; +import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.Set; import static org.elasticsearch.action.ValidateActions.addValidationError; @@ -42,20 +43,27 @@ * fields can be gathered more directly using {@link IndexMetadata#getMatchingInferenceFields}. *

*/ -public class GetInferenceFieldsAction extends ActionType { - public static final GetInferenceFieldsAction INSTANCE = new GetInferenceFieldsAction(); +public class GetInferenceFieldsInternalAction extends ActionType { + public static final GetInferenceFieldsInternalAction INSTANCE = new GetInferenceFieldsInternalAction(); public static final RemoteClusterActionType REMOTE_TYPE = new RemoteClusterActionType<>(INSTANCE.name(), Response::new); + // This is a defunct transport version for when GetInferenceFieldsInternalAction was an internal cluster action. This was the case only + // for 9.3.0 development builds, so there should be no real-world deployments using GetInferenceFieldsInternalAction as an internal + // cluster action. Therefore, this transport version can be disregarded. public static final TransportVersion GET_INFERENCE_FIELDS_ACTION_TV = TransportVersion.fromName("get_inference_fields_action"); - public static final String NAME = "cluster:internal/xpack/inference/fields/get"; + public static final TransportVersion GET_INFERENCE_FIELDS_ACTION_AS_INDICES_ACTION_TV = TransportVersion.fromName( + "get_inference_fields_action_as_indices_action" + ); - public GetInferenceFieldsAction() { + public static final String NAME = "indices:admin/inference/fields/get"; + + public GetInferenceFieldsInternalAction() { super(NAME); } - public static class Request extends ActionRequest { - private final Set indices; + public static class Request extends ActionRequest implements IndicesRequest.Replaceable { + private String[] indices; private final Map fields; private final boolean resolveWildcards; private final boolean useDefaultFields; @@ -63,10 +71,10 @@ public static class Request extends ActionRequest { private final IndicesOptions indicesOptions; /** - * An overload of {@link #Request(Set, Map, boolean, boolean, String, IndicesOptions)} that uses {@link IndicesOptions#DEFAULT} + * An overload of {@link #Request(String[], Map, boolean, boolean, String, IndicesOptions)} that uses {@link IndicesOptions#DEFAULT} */ public Request( - Set indices, + String[] indices, Map fields, boolean resolveWildcards, boolean useDefaultFields, @@ -97,7 +105,7 @@ public Request( * @param indicesOptions The {@link IndicesOptions} to use when resolving indices. */ public Request( - Set indices, + String[] indices, Map fields, boolean resolveWildcards, boolean useDefaultFields, @@ -114,7 +122,7 @@ public Request( public Request(StreamInput in) throws IOException { super(in); - this.indices = in.readCollectionAsSet(StreamInput::readString); + this.indices = in.readStringArray(); this.fields = in.readMap(StreamInput::readFloat); this.resolveWildcards = in.readBoolean(); this.useDefaultFields = in.readBoolean(); @@ -125,7 +133,7 @@ public Request(StreamInput in) throws IOException { @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - out.writeStringCollection(indices); + out.writeStringArray(indices); out.writeMap(fields, StreamOutput::writeFloat); out.writeBoolean(resolveWildcards); out.writeBoolean(useDefaultFields); @@ -156,11 +164,18 @@ public ActionRequestValidationException validate() { return validationException; } - public Set getIndices() { - return Collections.unmodifiableSet(indices); + @Override + public Request indices(String... indices) { + this.indices = indices; + return this; + } + + @Override + public String[] indices() { + return indices; } - public Map getFields() { + public Map fields() { return Collections.unmodifiableMap(fields); } @@ -172,20 +187,26 @@ public boolean useDefaultFields() { return useDefaultFields; } - public String getQuery() { + public String query() { return query; } - public IndicesOptions getIndicesOptions() { + @Override + public IndicesOptions indicesOptions() { return indicesOptions; } + @Override + public boolean includeDataStreams() { + return true; + } + @Override public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; Request request = (Request) o; - return Objects.equals(indices, request.indices) + return Arrays.equals(indices, request.indices) && Objects.equals(fields, request.fields) && resolveWildcards == request.resolveWildcards && useDefaultFields == request.useDefaultFields @@ -195,7 +216,7 @@ public boolean equals(Object o) { @Override public int hashCode() { - return Objects.hash(indices, fields, resolveWildcards, useDefaultFields, query, indicesOptions); + return Objects.hash(Arrays.hashCode(indices), fields, resolveWildcards, useDefaultFields, query, indicesOptions); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/privilege/IndexPrivilege.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/privilege/IndexPrivilege.java index e1ae19acc4bb1..6a221d33041b5 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/privilege/IndexPrivilege.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/privilege/IndexPrivilege.java @@ -40,6 +40,7 @@ import org.elasticsearch.xpack.core.ccr.action.PutFollowAction; import org.elasticsearch.xpack.core.ccr.action.UnfollowAction; import org.elasticsearch.xpack.core.ilm.action.ExplainLifecycleAction; +import org.elasticsearch.xpack.core.inference.action.GetInferenceFieldsInternalAction; import org.elasticsearch.xpack.core.rollup.action.GetRollupIndexCapsAction; import org.elasticsearch.xpack.core.security.support.Automatons; import org.elasticsearch.xpack.core.transform.action.GetCheckpointAction; @@ -84,7 +85,8 @@ public final class IndexPrivilege extends Privilege { private static final Automaton READ_AUTOMATON = patterns( "indices:data/read/*", ResolveIndexAction.NAME, - TransportResolveClusterAction.NAME + TransportResolveClusterAction.NAME, + GetInferenceFieldsInternalAction.NAME // cross-cluster inference for semantic search ); private static final Automaton READ_FAILURE_STORE_AUTOMATON = patterns("indices:data/read/*", ResolveIndexAction.NAME); private static final Automaton READ_CROSS_CLUSTER_AUTOMATON = patterns( diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/GetInferenceFieldsCrossClusterIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/GetInferenceFieldsCrossClusterIT.java index fbb274c632c7a..df9ccdf0ae1f0 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/GetInferenceFieldsCrossClusterIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/GetInferenceFieldsCrossClusterIT.java @@ -18,7 +18,7 @@ import org.elasticsearch.test.AbstractMultiClustersTestCase; import org.elasticsearch.transport.RemoteClusterService; import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.inference.action.GetInferenceFieldsAction; +import org.elasticsearch.xpack.core.inference.action.GetInferenceFieldsInternalAction; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import org.elasticsearch.xpack.inference.FakeMlPlugin; import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; @@ -47,8 +47,6 @@ public class GetInferenceFieldsCrossClusterIT extends AbstractMultiClustersTestC private static final String INFERENCE_ID = "test-inference-id"; private static final Map INFERENCE_ENDPOINT_SERVICE_SETTINGS = Map.of("model", "my_model", "api_key", "my_api_key"); - private boolean clustersConfigured = false; - @Override protected List remoteClusterAlias() { return List.of(REMOTE_CLUSTER); @@ -71,23 +69,20 @@ protected Collection> nodePlugins(String clusterAlias) { @Before public void configureClusters() throws Exception { - if (clustersConfigured == false) { - setupTwoClusters(); - clustersConfigured = true; - } + setupTwoClusters(); } public void testRemoteIndex() { - Consumer assertFailedRequest = r -> { + Consumer assertFailedRequest = r -> { IllegalArgumentException e = assertThrows( IllegalArgumentException.class, - () -> client().execute(GetInferenceFieldsAction.INSTANCE, r).actionGet(TEST_REQUEST_TIMEOUT) + () -> client().execute(GetInferenceFieldsInternalAction.INSTANCE, r).actionGet(TEST_REQUEST_TIMEOUT) ); - assertThat(e.getMessage(), containsString("GetInferenceFieldsAction does not support remote indices")); + assertThat(e.getMessage(), containsString("GetInferenceFieldsInternalAction does not support remote indices")); }; - var concreteIndexRequest = new GetInferenceFieldsAction.Request( - Set.of(REMOTE_CLUSTER + ":test-index"), + var concreteIndexRequest = new GetInferenceFieldsInternalAction.Request( + new String[] { REMOTE_CLUSTER + ":test-index" }, Map.of(), false, false, @@ -95,10 +90,22 @@ public void testRemoteIndex() { ); assertFailedRequest.accept(concreteIndexRequest); - var wildcardIndexRequest = new GetInferenceFieldsAction.Request(Set.of(REMOTE_CLUSTER + ":*"), Map.of(), false, false, "foo"); + var wildcardIndexRequest = new GetInferenceFieldsInternalAction.Request( + new String[] { REMOTE_CLUSTER + ":*" }, + Map.of(), + false, + false, + "foo" + ); assertFailedRequest.accept(wildcardIndexRequest); - var wildcardClusterAndIndexRequest = new GetInferenceFieldsAction.Request(Set.of("*:*"), Map.of(), false, false, "foo"); + var wildcardClusterAndIndexRequest = new GetInferenceFieldsInternalAction.Request( + new String[] { "*:*" }, + Map.of(), + false, + false, + "foo" + ); assertFailedRequest.accept(wildcardClusterAndIndexRequest); } @@ -109,15 +116,15 @@ public void testRemoteClusterAction() { RemoteClusterService.DisconnectedStrategy.RECONNECT_IF_DISCONNECTED ); - var request = new GetInferenceFieldsAction.Request( - Set.of(INDEX_NAME), + var request = new GetInferenceFieldsInternalAction.Request( + new String[] { INDEX_NAME }, generateDefaultWeightFieldMap(Set.of(INFERENCE_FIELD)), false, false, "foo" ); - PlainActionFuture future = new PlainActionFuture<>(); - remoteClusterClient.execute(GetInferenceFieldsAction.REMOTE_TYPE, request, future); + PlainActionFuture future = new PlainActionFuture<>(); + remoteClusterClient.execute(GetInferenceFieldsInternalAction.REMOTE_TYPE, request, future); var response = future.actionGet(TEST_REQUEST_TIMEOUT); assertInferenceFieldsMap( diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/GetInferenceFieldsIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/GetInferenceFieldsIT.java index 26dd84509bb7b..8b15fecb33983 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/GetInferenceFieldsIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/GetInferenceFieldsIT.java @@ -9,6 +9,7 @@ import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.support.IndicesOptions; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.Nullable; import org.elasticsearch.index.IndexNotFoundException; @@ -20,14 +21,13 @@ import org.elasticsearch.test.ESIntegTestCase; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; -import org.elasticsearch.xpack.core.inference.action.GetInferenceFieldsAction; +import org.elasticsearch.xpack.core.inference.action.GetInferenceFieldsInternalAction; import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import org.elasticsearch.xpack.inference.FakeMlPlugin; import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin; -import org.junit.Before; import java.io.IOException; import java.util.Collection; @@ -51,7 +51,7 @@ import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.notNullValue; -@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.SUITE) +@ESIntegTestCase.SuiteScopeTestCase public class GetInferenceFieldsIT extends ESIntegTestCase { private static final Map SPARSE_EMBEDDING_SERVICE_SETTINGS = Map.of("model", "my_model", "api_key", "my_api_key"); private static final Map TEXT_EMBEDDING_SERVICE_SETTINGS = Map.of( @@ -70,7 +70,7 @@ public class GetInferenceFieldsIT extends ESIntegTestCase { private static final String INDEX_1 = "index-1"; private static final String INDEX_2 = "index-2"; - private static final Set ALL_INDICES = Set.of(INDEX_1, INDEX_2); + private static final String[] ALL_INDICES = new String[] { INDEX_1, INDEX_2 }; private static final String INDEX_ALIAS = "index-alias"; private static final String INFERENCE_FIELD_1 = "inference-field-1"; @@ -105,8 +105,6 @@ public class GetInferenceFieldsIT extends ESIntegTestCase { MlDenseEmbeddingResults.class ); - private boolean clusterConfigured = false; - @Override protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { return Settings.builder().put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial").build(); @@ -117,13 +115,10 @@ protected Collection> nodePlugins() { return List.of(LocalStateInferencePlugin.class, TestInferenceServicePlugin.class, FakeMlPlugin.class); } - @Before - public void setUpCluster() throws Exception { - if (clusterConfigured == false) { - createInferenceEndpoints(); - createTestIndices(); - clusterConfigured = true; - } + @Override + protected void setupSuiteScopeCluster() throws Exception { + createInferenceEndpoints(); + createTestIndices(); } public void testNullQuery() { @@ -140,7 +135,7 @@ public void testBlankQuery() { public void testFieldWeight() { assertSuccessfulRequest( - new GetInferenceFieldsAction.Request( + new GetInferenceFieldsInternalAction.Request( ALL_INDICES, Map.of(INFERENCE_FIELD_1, 2.0f, "inference-*", 1.5f, TEXT_FIELD_1, 1.75f), false, @@ -159,7 +154,7 @@ public void testFieldWeight() { public void testNoInferenceFields() { assertSuccessfulRequest( - new GetInferenceFieldsAction.Request( + new GetInferenceFieldsInternalAction.Request( ALL_INDICES, generateDefaultWeightFieldMap(Set.of(TEXT_FIELD_1, TEXT_FIELD_2)), false, @@ -173,13 +168,13 @@ public void testNoInferenceFields() { public void testResolveFieldWildcards() { assertSuccessfulRequest( - new GetInferenceFieldsAction.Request(ALL_INDICES, generateDefaultWeightFieldMap(Set.of("*")), true, false, "foo"), + new GetInferenceFieldsInternalAction.Request(ALL_INDICES, generateDefaultWeightFieldMap(Set.of("*")), true, false, "foo"), Map.of(INDEX_1, INDEX_1_EXPECTED_INFERENCE_FIELDS, INDEX_2, INDEX_2_EXPECTED_INFERENCE_FIELDS), ALL_EXPECTED_INFERENCE_RESULTS ); assertSuccessfulRequest( - new GetInferenceFieldsAction.Request( + new GetInferenceFieldsInternalAction.Request( ALL_INDICES, Map.of("*-field-1", 2.0f, "*-1", 1.75f, "inference-*-3", 2.0f), true, @@ -204,39 +199,53 @@ public void testResolveFieldWildcards() { public void testUseDefaultFields() { assertSuccessfulRequest( - new GetInferenceFieldsAction.Request(Set.of(INDEX_1), Map.of(), true, true, "foo"), + new GetInferenceFieldsInternalAction.Request(new String[] { INDEX_1 }, Map.of(), true, true, "foo"), Map.of(INDEX_1, Set.of(new InferenceFieldWithTestMetadata(INFERENCE_FIELD_1, SPARSE_EMBEDDING_INFERENCE_ID, 5.0f))), filterExpectedInferenceResults(ALL_EXPECTED_INFERENCE_RESULTS, Set.of(SPARSE_EMBEDDING_INFERENCE_ID)) ); assertSuccessfulRequest( - new GetInferenceFieldsAction.Request(Set.of(INDEX_2), Map.of(), true, true, "foo"), + new GetInferenceFieldsInternalAction.Request(new String[] { INDEX_2 }, Map.of(), true, true, "foo"), Map.of(INDEX_2, INDEX_2_EXPECTED_INFERENCE_FIELDS), ALL_EXPECTED_INFERENCE_RESULTS ); } public void testMissingIndexName() { - Set indicesWithIndex1 = Set.of(INDEX_1, "missing-index"); + String[] indicesWithIndex1 = new String[] { INDEX_1, "missing-index" }; assertFailedRequest( - new GetInferenceFieldsAction.Request(indicesWithIndex1, ALL_FIELDS, false, false, "foo"), + new GetInferenceFieldsInternalAction.Request(indicesWithIndex1, ALL_FIELDS, false, false, "foo"), IndexNotFoundException.class, e -> assertThat(e.getMessage(), containsString("no such index [missing-index]")) ); assertSuccessfulRequest( - new GetInferenceFieldsAction.Request(indicesWithIndex1, ALL_FIELDS, false, false, "foo", IndicesOptions.LENIENT_EXPAND_OPEN), + new GetInferenceFieldsInternalAction.Request( + indicesWithIndex1, + ALL_FIELDS, + false, + false, + "foo", + IndicesOptions.LENIENT_EXPAND_OPEN + ), Map.of(INDEX_1, INDEX_1_EXPECTED_INFERENCE_FIELDS), ALL_EXPECTED_INFERENCE_RESULTS ); - Set indicesWithoutIndex1 = Set.of("missing-index"); + String[] indicesWithoutIndex1 = new String[] { "missing-index" }; assertFailedRequest( - new GetInferenceFieldsAction.Request(indicesWithoutIndex1, ALL_FIELDS, false, false, "foo"), + new GetInferenceFieldsInternalAction.Request(indicesWithoutIndex1, ALL_FIELDS, false, false, "foo"), IndexNotFoundException.class, e -> assertThat(e.getMessage(), containsString("no such index [missing-index]")) ); assertSuccessfulRequest( - new GetInferenceFieldsAction.Request(indicesWithoutIndex1, ALL_FIELDS, false, false, "foo", IndicesOptions.LENIENT_EXPAND_OPEN), + new GetInferenceFieldsInternalAction.Request( + indicesWithoutIndex1, + ALL_FIELDS, + false, + false, + "foo", + IndicesOptions.LENIENT_EXPAND_OPEN + ), Map.of(), Map.of() ); @@ -244,13 +253,25 @@ public void testMissingIndexName() { public void testMissingFieldName() { assertSuccessfulRequest( - new GetInferenceFieldsAction.Request(ALL_INDICES, generateDefaultWeightFieldMap(Set.of("missing-field")), false, false, "foo"), + new GetInferenceFieldsInternalAction.Request( + ALL_INDICES, + generateDefaultWeightFieldMap(Set.of("missing-field")), + false, + false, + "foo" + ), Map.of(INDEX_1, Set.of(), INDEX_2, Set.of()), Map.of() ); assertSuccessfulRequest( - new GetInferenceFieldsAction.Request(ALL_INDICES, generateDefaultWeightFieldMap(Set.of("missing-*")), true, false, "foo"), + new GetInferenceFieldsInternalAction.Request( + ALL_INDICES, + generateDefaultWeightFieldMap(Set.of("missing-*")), + true, + false, + "foo" + ), Map.of(INDEX_1, Set.of(), INDEX_2, Set.of()), Map.of() ); @@ -259,14 +280,21 @@ public void testMissingFieldName() { public void testNoIndices() { // By default, an empty index set will be interpreted as _all assertSuccessfulRequest( - new GetInferenceFieldsAction.Request(Set.of(), ALL_FIELDS, false, false, "foo"), + new GetInferenceFieldsInternalAction.Request(Strings.EMPTY_ARRAY, ALL_FIELDS, false, false, "foo"), Map.of(INDEX_1, INDEX_1_EXPECTED_INFERENCE_FIELDS, INDEX_2, INDEX_2_EXPECTED_INFERENCE_FIELDS), ALL_EXPECTED_INFERENCE_RESULTS ); // We can provide an IndicesOptions that changes this behavior to interpret an empty index set as no indices assertSuccessfulRequest( - new GetInferenceFieldsAction.Request(Set.of(), ALL_FIELDS, false, false, "foo", IndicesOptions.STRICT_NO_EXPAND_FORBID_CLOSED), + new GetInferenceFieldsInternalAction.Request( + Strings.EMPTY_ARRAY, + ALL_FIELDS, + false, + false, + "foo", + IndicesOptions.STRICT_NO_EXPAND_FORBID_CLOSED + ), Map.of(), Map.of() ); @@ -275,15 +303,15 @@ public void testNoIndices() { public void testAllIndices() { // By default, _all expands to all indices assertSuccessfulRequest( - new GetInferenceFieldsAction.Request(Set.of("_all"), ALL_FIELDS, false, false, "foo"), + new GetInferenceFieldsInternalAction.Request(new String[] { "_all" }, ALL_FIELDS, false, false, "foo"), Map.of(INDEX_1, INDEX_1_EXPECTED_INFERENCE_FIELDS, INDEX_2, INDEX_2_EXPECTED_INFERENCE_FIELDS), ALL_EXPECTED_INFERENCE_RESULTS ); // We can provide an IndicesOptions that changes this behavior to interpret it as no indices assertSuccessfulRequest( - new GetInferenceFieldsAction.Request( - Set.of("_all"), + new GetInferenceFieldsInternalAction.Request( + new String[] { "_all" }, ALL_FIELDS, false, false, @@ -297,7 +325,7 @@ public void testAllIndices() { public void testIndexAlias() { assertSuccessfulRequest( - new GetInferenceFieldsAction.Request(Set.of(INDEX_ALIAS), ALL_FIELDS, false, false, "foo"), + new GetInferenceFieldsInternalAction.Request(new String[] { INDEX_ALIAS }, ALL_FIELDS, false, false, "foo"), Map.of(INDEX_1, INDEX_1_EXPECTED_INFERENCE_FIELDS, INDEX_2, INDEX_2_EXPECTED_INFERENCE_FIELDS), ALL_EXPECTED_INFERENCE_RESULTS ); @@ -305,20 +333,20 @@ public void testIndexAlias() { public void testResolveIndexWildcards() { assertSuccessfulRequest( - new GetInferenceFieldsAction.Request(Set.of("index-*"), ALL_FIELDS, false, false, "foo"), + new GetInferenceFieldsInternalAction.Request(new String[] { "index-*" }, ALL_FIELDS, false, false, "foo"), Map.of(INDEX_1, INDEX_1_EXPECTED_INFERENCE_FIELDS, INDEX_2, INDEX_2_EXPECTED_INFERENCE_FIELDS), ALL_EXPECTED_INFERENCE_RESULTS ); assertSuccessfulRequest( - new GetInferenceFieldsAction.Request(Set.of("*-1"), ALL_FIELDS, false, false, "foo"), + new GetInferenceFieldsInternalAction.Request(new String[] { "*-1" }, ALL_FIELDS, false, false, "foo"), Map.of(INDEX_1, INDEX_1_EXPECTED_INFERENCE_FIELDS), ALL_EXPECTED_INFERENCE_RESULTS ); assertFailedRequest( - new GetInferenceFieldsAction.Request( - Set.of("index-*"), + new GetInferenceFieldsInternalAction.Request( + new String[] { "index-*" }, ALL_FIELDS, false, false, @@ -332,7 +360,7 @@ public void testResolveIndexWildcards() { public void testNoFields() { assertSuccessfulRequest( - new GetInferenceFieldsAction.Request(ALL_INDICES, Map.of(), false, false, "foo"), + new GetInferenceFieldsInternalAction.Request(ALL_INDICES, Map.of(), false, false, "foo"), Map.of(INDEX_1, Set.of(), INDEX_2, Set.of()), Map.of() ); @@ -344,17 +372,17 @@ public void testInvalidRequest() { ); assertFailedRequest( - new GetInferenceFieldsAction.Request(null, Map.of(), false, false, null), + new GetInferenceFieldsInternalAction.Request(null, Map.of(), false, false, null), ActionRequestValidationException.class, e -> validator.accept(e, List.of("indices must not be null")) ); assertFailedRequest( - new GetInferenceFieldsAction.Request(Set.of(), null, false, false, null), + new GetInferenceFieldsInternalAction.Request(Strings.EMPTY_ARRAY, null, false, false, null), ActionRequestValidationException.class, e -> validator.accept(e, List.of("fields must not be null")) ); assertFailedRequest( - new GetInferenceFieldsAction.Request(null, null, false, false, null), + new GetInferenceFieldsInternalAction.Request(null, null, false, false, null), ActionRequestValidationException.class, e -> validator.accept(e, List.of("indices must not be null", "fields must not be null")) ); @@ -362,7 +390,7 @@ public void testInvalidRequest() { Map fields = new HashMap<>(); fields.put(INFERENCE_FIELD_1, null); assertFailedRequest( - new GetInferenceFieldsAction.Request(Set.of(), fields, false, false, null), + new GetInferenceFieldsInternalAction.Request(Strings.EMPTY_ARRAY, fields, false, false, null), ActionRequestValidationException.class, e -> validator.accept(e, List.of("weight for field [" + INFERENCE_FIELD_1 + "] must not be null")) ); @@ -370,7 +398,7 @@ public void testInvalidRequest() { private void explicitIndicesAndFieldsTestCase(String query) { assertSuccessfulRequest( - new GetInferenceFieldsAction.Request(ALL_INDICES, ALL_FIELDS, false, false, query), + new GetInferenceFieldsInternalAction.Request(ALL_INDICES, ALL_FIELDS, false, false, query), Map.of(INDEX_1, INDEX_1_EXPECTED_INFERENCE_FIELDS, INDEX_2, INDEX_2_EXPECTED_INFERENCE_FIELDS), query == null ? Map.of() : ALL_EXPECTED_INFERENCE_RESULTS ); @@ -380,7 +408,7 @@ private void explicitIndicesAndFieldsTestCase(String query) { Set.of(SPARSE_EMBEDDING_INFERENCE_ID) ); assertSuccessfulRequest( - new GetInferenceFieldsAction.Request( + new GetInferenceFieldsInternalAction.Request( ALL_INDICES, generateDefaultWeightFieldMap(Set.of(INFERENCE_FIELD_3)), false, @@ -397,8 +425,8 @@ private void explicitIndicesAndFieldsTestCase(String query) { ); assertSuccessfulRequest( - new GetInferenceFieldsAction.Request( - Set.of(INDEX_1), + new GetInferenceFieldsInternalAction.Request( + new String[] { INDEX_1 }, generateDefaultWeightFieldMap(Set.of(INFERENCE_FIELD_3)), false, false, @@ -409,7 +437,7 @@ private void explicitIndicesAndFieldsTestCase(String query) { ); assertSuccessfulRequest( - new GetInferenceFieldsAction.Request(ALL_INDICES, generateDefaultWeightFieldMap(Set.of("*")), false, false, query), + new GetInferenceFieldsInternalAction.Request(ALL_INDICES, generateDefaultWeightFieldMap(Set.of("*")), false, false, query), Map.of(INDEX_1, Set.of(), INDEX_2, Set.of()), Map.of() ); @@ -471,12 +499,12 @@ private void addTextField(String fieldName, XContentBuilder mapping) throws IOEx mapping.endObject(); } - private static GetInferenceFieldsAction.Response executeRequest(GetInferenceFieldsAction.Request request) { - return client().execute(GetInferenceFieldsAction.INSTANCE, request).actionGet(TEST_REQUEST_TIMEOUT); + private static GetInferenceFieldsInternalAction.Response executeRequest(GetInferenceFieldsInternalAction.Request request) { + return client().execute(GetInferenceFieldsInternalAction.INSTANCE, request).actionGet(TEST_REQUEST_TIMEOUT); } private static void assertSuccessfulRequest( - GetInferenceFieldsAction.Request request, + GetInferenceFieldsInternalAction.Request request, Map> expectedInferenceFields, Map> expectedInferenceResults ) { @@ -486,7 +514,7 @@ private static void assertSuccessfulRequest( } private static void assertFailedRequest( - GetInferenceFieldsAction.Request request, + GetInferenceFieldsInternalAction.Request request, Class expectedException, Consumer exceptionValidator ) { @@ -495,13 +523,13 @@ private static void assertFailedRequest( } static void assertInferenceFieldsMap( - Map> inferenceFieldsMap, + Map> inferenceFieldsMap, Map> expectedInferenceFields ) { assertThat(inferenceFieldsMap.size(), equalTo(expectedInferenceFields.size())); for (var entry : inferenceFieldsMap.entrySet()) { String indexName = entry.getKey(); - List indexInferenceFields = entry.getValue(); + List indexInferenceFields = entry.getValue(); Set expectedIndexInferenceFields = expectedInferenceFields.get(indexName); assertThat(expectedIndexInferenceFields, notNullValue()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index b7075fa6dc7f3..d8f8e84a38030 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -67,7 +67,7 @@ import org.elasticsearch.xpack.core.inference.action.EmbeddingAction; import org.elasticsearch.xpack.core.inference.action.GetCCMConfigurationAction; import org.elasticsearch.xpack.core.inference.action.GetInferenceDiagnosticsAction; -import org.elasticsearch.xpack.core.inference.action.GetInferenceFieldsAction; +import org.elasticsearch.xpack.core.inference.action.GetInferenceFieldsInternalAction; import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; import org.elasticsearch.xpack.core.inference.action.GetInferenceServicesAction; import org.elasticsearch.xpack.core.inference.action.GetRerankerWindowSizeAction; @@ -84,7 +84,7 @@ import org.elasticsearch.xpack.inference.action.TransportEmbeddingAction; import org.elasticsearch.xpack.inference.action.TransportGetCCMConfigurationAction; import org.elasticsearch.xpack.inference.action.TransportGetInferenceDiagnosticsAction; -import org.elasticsearch.xpack.inference.action.TransportGetInferenceFieldsAction; +import org.elasticsearch.xpack.inference.action.TransportGetInferenceFieldsInternalAction; import org.elasticsearch.xpack.inference.action.TransportGetInferenceModelAction; import org.elasticsearch.xpack.inference.action.TransportGetInferenceServicesAction; import org.elasticsearch.xpack.inference.action.TransportGetRerankerWindowSizeAction; @@ -289,7 +289,7 @@ public List getActions() { new ActionHandler(DeleteCCMConfigurationAction.INSTANCE, TransportDeleteCCMConfigurationAction.class), new ActionHandler(CCMCache.ClearCCMCacheAction.INSTANCE, CCMCache.ClearCCMCacheAction.class), new ActionHandler(AuthorizationTaskExecutor.Action.INSTANCE, AuthorizationTaskExecutor.Action.class), - new ActionHandler(GetInferenceFieldsAction.INSTANCE, TransportGetInferenceFieldsAction.class), + new ActionHandler(GetInferenceFieldsInternalAction.INSTANCE, TransportGetInferenceFieldsInternalAction.class), new ActionHandler(EmbeddingAction.INSTANCE, TransportEmbeddingAction.class) ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceFieldsAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceFieldsInternalAction.java similarity index 79% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceFieldsAction.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceFieldsInternalAction.java index 57fd31fab754d..b791cf77f73c0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceFieldsAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceFieldsInternalAction.java @@ -31,7 +31,7 @@ import org.elasticsearch.tasks.Task; import org.elasticsearch.transport.RemoteClusterAware; import org.elasticsearch.transport.TransportService; -import org.elasticsearch.xpack.core.inference.action.GetInferenceFieldsAction; +import org.elasticsearch.xpack.core.inference.action.GetInferenceFieldsInternalAction; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; @@ -45,9 +45,9 @@ import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; -public class TransportGetInferenceFieldsAction extends HandledTransportAction< - GetInferenceFieldsAction.Request, - GetInferenceFieldsAction.Response> { +public class TransportGetInferenceFieldsInternalAction extends HandledTransportAction< + GetInferenceFieldsInternalAction.Request, + GetInferenceFieldsInternalAction.Response> { private final TransportService transportService; private final ClusterService clusterService; @@ -56,7 +56,7 @@ public class TransportGetInferenceFieldsAction extends HandledTransportAction< private final Client client; @Inject - public TransportGetInferenceFieldsAction( + public TransportGetInferenceFieldsInternalAction( TransportService transportService, ActionFilters actionFilters, ClusterService clusterService, @@ -65,10 +65,10 @@ public TransportGetInferenceFieldsAction( Client client ) { super( - GetInferenceFieldsAction.NAME, + GetInferenceFieldsInternalAction.NAME, transportService, actionFilters, - GetInferenceFieldsAction.Request::new, + GetInferenceFieldsInternalAction.Request::new, EsExecutors.DIRECT_EXECUTOR_SERVICE ); this.transportService = transportService; @@ -81,37 +81,33 @@ public TransportGetInferenceFieldsAction( @Override protected void doExecute( Task task, - GetInferenceFieldsAction.Request request, - ActionListener listener + GetInferenceFieldsInternalAction.Request request, + ActionListener listener ) { - final Set indices = request.getIndices(); - final Map fields = request.getFields(); + final String[] indices = request.indices(); + final Map fields = request.fields(); final boolean resolveWildcards = request.resolveWildcards(); final boolean useDefaultFields = request.useDefaultFields(); - final String query = request.getQuery(); - final IndicesOptions indicesOptions = request.getIndicesOptions(); + final String query = request.query(); + final IndicesOptions indicesOptions = request.indicesOptions(); try { Map groupedIndices = transportService.getRemoteClusterService() - .groupIndices(indicesOptions, indices.toArray(new String[0]), true); + .groupIndices(indicesOptions, indices, true); OriginalIndices localIndices = groupedIndices.remove(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY); if (groupedIndices.isEmpty() == false) { - throw new IllegalArgumentException("GetInferenceFieldsAction does not support remote indices"); + throw new IllegalArgumentException("GetInferenceFieldsInternalAction does not support remote indices"); } ProjectState projectState = projectResolver.getProjectState(clusterService.state()); String[] concreteLocalIndices = indexNameExpressionResolver.concreteIndexNames(projectState.metadata(), localIndices); - Map> inferenceFieldsMap = new HashMap<>( + Map> inferenceFieldsMap = new HashMap<>( concreteLocalIndices.length ); Arrays.stream(concreteLocalIndices).forEach(index -> { - List inferenceFieldMetadataList = getInferenceFieldMetadata( - index, - fields, - resolveWildcards, - useDefaultFields - ); + List inferenceFieldMetadataList = + getInferenceFieldMetadata(index, fields, resolveWildcards, useDefaultFields); inferenceFieldsMap.put(index, inferenceFieldMetadataList); }); @@ -124,14 +120,14 @@ protected void doExecute( getInferenceResults(query, inferenceIds, inferenceFieldsMap, listener); } else { - listener.onResponse(new GetInferenceFieldsAction.Response(inferenceFieldsMap, Map.of())); + listener.onResponse(new GetInferenceFieldsInternalAction.Response(inferenceFieldsMap, Map.of())); } } catch (Exception e) { listener.onFailure(e); } } - private List getInferenceFieldMetadata( + private List getInferenceFieldMetadata( String index, Map fields, boolean resolveWildcards, @@ -149,18 +145,18 @@ private List getInferen ); return matchingInferenceFieldMap.entrySet() .stream() - .map(e -> new GetInferenceFieldsAction.ExtendedInferenceFieldMetadata(e.getKey(), e.getValue())) + .map(e -> new GetInferenceFieldsInternalAction.ExtendedInferenceFieldMetadata(e.getKey(), e.getValue())) .toList(); } private void getInferenceResults( String query, Set inferenceIds, - Map> inferenceFieldsMap, - ActionListener listener + Map> inferenceFieldsMap, + ActionListener listener ) { if (inferenceIds.isEmpty()) { - listener.onResponse(new GetInferenceFieldsAction.Response(inferenceFieldsMap, Map.of())); + listener.onResponse(new GetInferenceFieldsInternalAction.Response(inferenceFieldsMap, Map.of())); return; } @@ -170,7 +166,10 @@ private void getInferenceResults( Map inferenceResultsMap = new HashMap<>(inferenceIds.size()); c.forEach(t -> inferenceResultsMap.put(t.v1(), t.v2())); - GetInferenceFieldsAction.Response response = new GetInferenceFieldsAction.Response(inferenceFieldsMap, inferenceResultsMap); + GetInferenceFieldsInternalAction.Response response = new GetInferenceFieldsInternalAction.Response( + inferenceFieldsMap, + inferenceResultsMap + ); l.onResponse(response); }) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InferenceQueryUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InferenceQueryUtils.java index a6fbc708911b0..61ab6e4423c3a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InferenceQueryUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InferenceQueryUtils.java @@ -33,7 +33,7 @@ import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.xpack.core.inference.action.GetInferenceFieldsAction; +import org.elasticsearch.xpack.core.inference.action.GetInferenceFieldsInternalAction; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; @@ -56,7 +56,7 @@ import static org.elasticsearch.index.IndexSettings.DEFAULT_FIELD_SETTING; import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; -import static org.elasticsearch.xpack.core.inference.action.GetInferenceFieldsAction.GET_INFERENCE_FIELDS_ACTION_TV; +import static org.elasticsearch.xpack.core.inference.action.GetInferenceFieldsInternalAction.GET_INFERENCE_FIELDS_ACTION_AS_INDICES_ACTION_TV; public final class InferenceQueryUtils { /** @@ -163,7 +163,8 @@ public static void getInferenceInfo( } SetOnce localInferenceInfoSupplier = new SetOnce<>(); - SetOnce>> remoteInferenceInfoSupplier = new SetOnce<>(); + SetOnce>> remoteInferenceInfoSupplier = + new SetOnce<>(); try (var refs = new RefCountingListener(inferenceInfoListener.delegateFailureAndWrap((l, v) -> { l.onResponse(combineLocalAndRemoteInferenceInfo(localInferenceInfoSupplier.get(), remoteInferenceInfoSupplier.get())); @@ -172,8 +173,9 @@ public static void getInferenceInfo( getLocalInferenceInfo(queryRewriteContext, inferenceInfoRequest, localInferenceInfoListener); if (resolvedIndices.getRemoteClusterIndices().isEmpty() == false && queryRewriteContext.isCcsMinimizeRoundTrips() == false) { - ActionListener>> remoteInferenceInfoListener = refs - .acquire(remoteInferenceInfoSupplier::set); + ActionListener< + Map>> remoteInferenceInfoListener = refs + .acquire(remoteInferenceInfoSupplier::set); if (inferenceInfoRequest.alwaysSkipRemotes() == false) { getRemoteInferenceInfo(queryRewriteContext, inferenceInfoRequest, remoteInferenceInfoListener); @@ -210,7 +212,7 @@ public static void ccsMinimizeRoundTripsFalseSupportCheck( String queryName ) { ResolvedIndices resolvedIndices = queryRewriteContext.getResolvedIndices(); - if (inferenceInfo.minTransportVersion().supports(GET_INFERENCE_FIELDS_ACTION_TV) == false + if (inferenceInfo.minTransportVersion().supports(GET_INFERENCE_FIELDS_ACTION_AS_INDICES_ACTION_TV) == false && inferenceInfo.inferenceFieldCount() > 0 && resolvedIndices.getRemoteClusterIndices().isEmpty() == false && queryRewriteContext.isCcsMinimizeRoundTrips() == false) { @@ -220,7 +222,7 @@ public static void ccsMinimizeRoundTripsFalseSupportCheck( + queryName + " query cross-cluster search when" + " [ccs_minimize_roundtrips] is false. Please update all clusters to at least " - + GET_INFERENCE_FIELDS_ACTION_TV.toReleaseVersion() + + GET_INFERENCE_FIELDS_ACTION_AS_INDICES_ACTION_TV.toReleaseVersion() ); } } @@ -297,10 +299,10 @@ private static void getLocalInferenceInfo( private static void getRemoteInferenceInfo( QueryRewriteContext queryRewriteContext, InferenceInfoRequest inferenceInfoRequest, - ActionListener>> remoteInferenceInfoListener + ActionListener>> remoteInferenceInfoListener ) { var remoteIndices = queryRewriteContext.getResolvedIndices().getRemoteClusterIndices(); - GroupedActionListener>> gal = + GroupedActionListener>> gal = createRemoteInferenceInfoGroupedActionListener(remoteIndices.size(), remoteInferenceInfoListener); // When an inference ID override is set, inference is only performed on the local cluster. Set the query to null in this case @@ -310,8 +312,8 @@ private static void getRemoteInferenceInfo( String clusterAlias = entry.getKey(); OriginalIndices originalIndices = entry.getValue(); - GetInferenceFieldsAction.Request request = new GetInferenceFieldsAction.Request( - Set.of(originalIndices.indices()), + GetInferenceFieldsInternalAction.Request request = new GetInferenceFieldsInternalAction.Request( + originalIndices.indices(), inferenceInfoRequest.fields(), inferenceInfoRequest.resolveWildcards(), inferenceInfoRequest.useDefaultFields(), @@ -325,10 +327,10 @@ private static void getRemoteInferenceInfo( private static void getRemoteTransportVersion( QueryRewriteContext queryRewriteContext, - ActionListener>> remoteInferenceInfoListener + ActionListener>> remoteInferenceInfoListener ) { var remoteIndices = queryRewriteContext.getResolvedIndices().getRemoteClusterIndices(); - GroupedActionListener>> gal = + GroupedActionListener>> gal = createRemoteInferenceInfoGroupedActionListener(remoteIndices.size(), remoteInferenceInfoListener); for (var entry : remoteIndices.entrySet()) { @@ -339,7 +341,7 @@ private static void getRemoteTransportVersion( private static InferenceInfo combineLocalAndRemoteInferenceInfo( InferenceInfo localInferenceInfo, - Map> remoteInferenceInfo + Map> remoteInferenceInfo ) { int totalInferenceFieldCount = localInferenceInfo.inferenceFieldCount; int totalIndexCount = localInferenceInfo.indexCount; @@ -461,14 +463,16 @@ private static GroupedActionListener>> + GroupedActionListener>> createRemoteInferenceInfoGroupedActionListener( int remoteClusterCount, - ActionListener>> remoteInferenceInfoListener + ActionListener>> remoteInferenceInfoListener ) { return new GroupedActionListener<>(remoteClusterCount, remoteInferenceInfoListener.delegateFailureAndWrap((l, c) -> { - Map> remoteInferenceInfoMap = new HashMap<>(c.size()); + Map> remoteInferenceInfoMap = new HashMap<>( + c.size() + ); c.forEach(remoteInferenceInfoMap::putAll); l.onResponse(remoteInferenceInfoMap); })); @@ -658,12 +662,12 @@ private void executeInferenceRequest( } private static final class RemoteInferenceInfoAsyncAction extends QueryRewriteRemoteAsyncAction< - Map>, + Map>, RemoteInferenceInfoAsyncAction> { - private final GetInferenceFieldsAction.Request request; + private final GetInferenceFieldsInternalAction.Request request; - private RemoteInferenceInfoAsyncAction(String clusterAlias, GetInferenceFieldsAction.Request request) { + private RemoteInferenceInfoAsyncAction(String clusterAlias, GetInferenceFieldsInternalAction.Request request) { super(clusterAlias); this.request = request; } @@ -672,27 +676,33 @@ private RemoteInferenceInfoAsyncAction(String clusterAlias, GetInferenceFieldsAc protected void execute( RemoteClusterClient client, ThreadContext threadContext, - ActionListener>> listener + ActionListener>> listener ) { final String clusterAlias = getClusterAlias(); - executeAsyncWithOrigin(threadContext, ML_ORIGIN, request, listener, (req, l1) -> { - client.getConnection(req, l1.delegateFailureAndWrap((l2, c) -> { - TransportVersion transportVersion = c.getTransportVersion(); - if (transportVersion.supports(GET_INFERENCE_FIELDS_ACTION_TV) == false) { - // Assume that no remote inference fields are queried. We must do this because we cannot throw an error here - // without breaking BwC for interception-eligible queries (ex: match/knn/sparse_vector) that don't need to be - // intercepted. We track the transport version in the response so that more thorough error checking can be - // performed with the complete output of getInferenceInfo (i.e. complete local and remote inference info). - l2.onResponse( - Map.of(clusterAlias, Tuple.tuple(new GetInferenceFieldsAction.Response(Map.of(), Map.of()), transportVersion)) - ); - } else { - client.execute(GetInferenceFieldsAction.REMOTE_TYPE, req, l2.delegateFailureAndWrap((l3, resp) -> { - l3.onResponse(Map.of(clusterAlias, Tuple.tuple(resp, transportVersion))); - })); - } - })); - }); + client.getConnection(request, listener.delegateFailureAndWrap((l1, connection) -> { + TransportVersion transportVersion = connection.getTransportVersion(); + if (transportVersion.supports(GET_INFERENCE_FIELDS_ACTION_AS_INDICES_ACTION_TV) == false) { + // Assume that no remote inference fields are queried. We must do this because we cannot throw an error here + // without breaking BwC for interception-eligible queries (ex: match/knn/sparse_vector) that don't need to be + // intercepted. We track the transport version in the response so that more thorough error checking can be + // performed with the complete output of getInferenceInfo (i.e. complete local and remote inference info). + l1.onResponse( + Map.of( + clusterAlias, + Tuple.tuple(new GetInferenceFieldsInternalAction.Response(Map.of(), Map.of()), transportVersion) + ) + ); + } else { + client.execute( + connection, + GetInferenceFieldsInternalAction.REMOTE_TYPE, + request, + l1.delegateFailureAndWrap((l2, resp) -> { + l2.onResponse(Map.of(clusterAlias, Tuple.tuple(resp, transportVersion))); + }) + ); + } + })); } @Override @@ -707,7 +717,7 @@ public boolean doEquals(RemoteInferenceInfoAsyncAction other) { } private static final class RemoteTransportVersionAsyncAction extends QueryRewriteRemoteAsyncAction< - Map>, + Map>, RemoteTransportVersionAsyncAction> { private RemoteTransportVersionAsyncAction(String clusterAlias) { @@ -718,12 +728,15 @@ private RemoteTransportVersionAsyncAction(String clusterAlias) { protected void execute( RemoteClusterClient client, ThreadContext threadContext, - ActionListener>> listener + ActionListener>> listener ) { final String clusterAlias = getClusterAlias(); client.getConnection(null, listener.delegateFailureAndWrap((l, c) -> { l.onResponse( - Map.of(clusterAlias, Tuple.tuple(new GetInferenceFieldsAction.Response(Map.of(), Map.of()), c.getTransportVersion())) + Map.of( + clusterAlias, + Tuple.tuple(new GetInferenceFieldsInternalAction.Response(Map.of(), Map.of()), c.getTransportVersion()) + ) ); })); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/GetInferenceFieldsActionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/GetInferenceFieldsInternalActionRequestTests.java similarity index 56% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/GetInferenceFieldsActionRequestTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/GetInferenceFieldsInternalActionRequestTests.java index df23d26381814..df84474b4319f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/GetInferenceFieldsActionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/GetInferenceFieldsInternalActionRequestTests.java @@ -13,26 +13,26 @@ import org.elasticsearch.core.Tuple; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.TransportVersionUtils; -import org.elasticsearch.xpack.core.inference.action.GetInferenceFieldsAction; +import org.elasticsearch.xpack.core.inference.action.GetInferenceFieldsInternalAction; import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; import java.io.IOException; import java.util.Collection; import java.util.Map; -import java.util.Set; -import static org.elasticsearch.xpack.core.inference.action.GetInferenceFieldsAction.GET_INFERENCE_FIELDS_ACTION_TV; +import static org.elasticsearch.xpack.core.inference.action.GetInferenceFieldsInternalAction.GET_INFERENCE_FIELDS_ACTION_AS_INDICES_ACTION_TV; -public class GetInferenceFieldsActionRequestTests extends AbstractBWCWireSerializationTestCase { +public class GetInferenceFieldsInternalActionRequestTests extends AbstractBWCWireSerializationTestCase< + GetInferenceFieldsInternalAction.Request> { @Override - protected Writeable.Reader instanceReader() { - return GetInferenceFieldsAction.Request::new; + protected Writeable.Reader instanceReader() { + return GetInferenceFieldsInternalAction.Request::new; } @Override - protected GetInferenceFieldsAction.Request createTestInstance() { - return new GetInferenceFieldsAction.Request( - randomIndentifierSet(), + protected GetInferenceFieldsInternalAction.Request createTestInstance() { + return new GetInferenceFieldsInternalAction.Request( + randomIndices(), randomFields(), randomBoolean(), randomBoolean(), @@ -42,57 +42,58 @@ protected GetInferenceFieldsAction.Request createTestInstance() { } @Override - protected GetInferenceFieldsAction.Request mutateInstance(GetInferenceFieldsAction.Request instance) throws IOException { + protected GetInferenceFieldsInternalAction.Request mutateInstance(GetInferenceFieldsInternalAction.Request instance) + throws IOException { return switch (between(0, 5)) { - case 0 -> new GetInferenceFieldsAction.Request( - randomValueOtherThan(instance.getIndices(), GetInferenceFieldsActionRequestTests::randomIndentifierSet), - instance.getFields(), + case 0 -> new GetInferenceFieldsInternalAction.Request( + randomArrayOtherThan(instance.indices(), GetInferenceFieldsInternalActionRequestTests::randomIndices), + instance.fields(), instance.resolveWildcards(), instance.useDefaultFields(), - instance.getQuery(), - instance.getIndicesOptions() + instance.query(), + instance.indicesOptions() ); - case 1 -> new GetInferenceFieldsAction.Request( - instance.getIndices(), - randomValueOtherThan(instance.getFields(), GetInferenceFieldsActionRequestTests::randomFields), + case 1 -> new GetInferenceFieldsInternalAction.Request( + instance.indices(), + randomValueOtherThan(instance.fields(), GetInferenceFieldsInternalActionRequestTests::randomFields), instance.resolveWildcards(), instance.useDefaultFields(), - instance.getQuery(), - instance.getIndicesOptions() + instance.query(), + instance.indicesOptions() ); - case 2 -> new GetInferenceFieldsAction.Request( - instance.getIndices(), - instance.getFields(), + case 2 -> new GetInferenceFieldsInternalAction.Request( + instance.indices(), + instance.fields(), randomValueOtherThan(instance.resolveWildcards(), ESTestCase::randomBoolean), instance.useDefaultFields(), - instance.getQuery(), - instance.getIndicesOptions() + instance.query(), + instance.indicesOptions() ); - case 3 -> new GetInferenceFieldsAction.Request( - instance.getIndices(), - instance.getFields(), + case 3 -> new GetInferenceFieldsInternalAction.Request( + instance.indices(), + instance.fields(), instance.resolveWildcards(), randomValueOtherThan(instance.useDefaultFields(), ESTestCase::randomBoolean), - instance.getQuery(), - instance.getIndicesOptions() + instance.query(), + instance.indicesOptions() ); - case 4 -> new GetInferenceFieldsAction.Request( - instance.getIndices(), - instance.getFields(), + case 4 -> new GetInferenceFieldsInternalAction.Request( + instance.indices(), + instance.fields(), instance.resolveWildcards(), instance.useDefaultFields(), - randomValueOtherThan(instance.getQuery(), GetInferenceFieldsActionRequestTests::randomQuery), - instance.getIndicesOptions() + randomValueOtherThan(instance.query(), GetInferenceFieldsInternalActionRequestTests::randomQuery), + instance.indicesOptions() ); - case 5 -> new GetInferenceFieldsAction.Request( - instance.getIndices(), - instance.getFields(), + case 5 -> new GetInferenceFieldsInternalAction.Request( + instance.indices(), + instance.fields(), instance.resolveWildcards(), instance.useDefaultFields(), - instance.getQuery(), - randomValueOtherThan(instance.getIndicesOptions(), () -> { + instance.query(), + randomValueOtherThan(instance.indicesOptions(), () -> { IndicesOptions newOptions = randomIndicesOptions(); - while (instance.getIndicesOptions() == IndicesOptions.DEFAULT && newOptions == null) { + while (instance.indicesOptions() == IndicesOptions.DEFAULT && newOptions == null) { newOptions = randomIndicesOptions(); } return newOptions; @@ -104,20 +105,23 @@ protected GetInferenceFieldsAction.Request mutateInstance(GetInferenceFieldsActi @Override protected Collection bwcVersions() { - TransportVersion minVersion = TransportVersion.max(TransportVersion.minimumCompatible(), GET_INFERENCE_FIELDS_ACTION_TV); + TransportVersion minVersion = TransportVersion.max( + TransportVersion.minimumCompatible(), + GET_INFERENCE_FIELDS_ACTION_AS_INDICES_ACTION_TV + ); return TransportVersionUtils.allReleasedVersions().tailSet(minVersion, true); } @Override - protected GetInferenceFieldsAction.Request mutateInstanceForVersion( - GetInferenceFieldsAction.Request instance, + protected GetInferenceFieldsInternalAction.Request mutateInstanceForVersion( + GetInferenceFieldsInternalAction.Request instance, TransportVersion version ) { return instance; } - private static Set randomIndentifierSet() { - return randomSet(0, 5, ESTestCase::randomIdentifier); + private static String[] randomIndices() { + return randomArray(0, 5, String[]::new, ESTestCase::randomIdentifier); } private static Map randomFields() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/GetInferenceFieldsActionResponseTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/GetInferenceFieldsInternalActionResponseTests.java similarity index 63% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/GetInferenceFieldsActionResponseTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/GetInferenceFieldsInternalActionResponseTests.java index 9b2d4d86ef880..bbbea2caf96de 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/GetInferenceFieldsActionResponseTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/GetInferenceFieldsInternalActionResponseTests.java @@ -14,7 +14,7 @@ import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.TransportVersionUtils; -import org.elasticsearch.xpack.core.inference.action.GetInferenceFieldsAction; +import org.elasticsearch.xpack.core.inference.action.GetInferenceFieldsInternalAction; import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResultsTests; @@ -30,9 +30,10 @@ import java.util.Map; import static org.elasticsearch.cluster.metadata.InferenceFieldMetadataTests.generateRandomChunkingSettings; -import static org.elasticsearch.xpack.core.inference.action.GetInferenceFieldsAction.GET_INFERENCE_FIELDS_ACTION_TV; +import static org.elasticsearch.xpack.core.inference.action.GetInferenceFieldsInternalAction.GET_INFERENCE_FIELDS_ACTION_AS_INDICES_ACTION_TV; -public class GetInferenceFieldsActionResponseTests extends AbstractBWCWireSerializationTestCase { +public class GetInferenceFieldsInternalActionResponseTests extends AbstractBWCWireSerializationTestCase< + GetInferenceFieldsInternalAction.Response> { @Override protected NamedWriteableRegistry getNamedWriteableRegistry() { @@ -40,36 +41,43 @@ protected NamedWriteableRegistry getNamedWriteableRegistry() { } @Override - protected Writeable.Reader instanceReader() { - return GetInferenceFieldsAction.Response::new; + protected Writeable.Reader instanceReader() { + return GetInferenceFieldsInternalAction.Response::new; } @Override - protected GetInferenceFieldsAction.Response createTestInstance() { - return new GetInferenceFieldsAction.Response(randomInferenceFieldsMap(), randomInferenceResultsMap()); + protected GetInferenceFieldsInternalAction.Response createTestInstance() { + return new GetInferenceFieldsInternalAction.Response(randomInferenceFieldsMap(), randomInferenceResultsMap()); } @Override - protected GetInferenceFieldsAction.Response mutateInstance(GetInferenceFieldsAction.Response instance) throws IOException { + protected GetInferenceFieldsInternalAction.Response mutateInstance(GetInferenceFieldsInternalAction.Response instance) + throws IOException { return switch (between(0, 1)) { - case 0 -> new GetInferenceFieldsAction.Response( - randomValueOtherThan(instance.getInferenceFieldsMap(), GetInferenceFieldsActionResponseTests::randomInferenceFieldsMap), + case 0 -> new GetInferenceFieldsInternalAction.Response( + randomValueOtherThan( + instance.getInferenceFieldsMap(), + GetInferenceFieldsInternalActionResponseTests::randomInferenceFieldsMap + ), instance.getInferenceResultsMap() ); - case 1 -> new GetInferenceFieldsAction.Response( + case 1 -> new GetInferenceFieldsInternalAction.Response( instance.getInferenceFieldsMap(), - randomValueOtherThan(instance.getInferenceResultsMap(), GetInferenceFieldsActionResponseTests::randomInferenceResultsMap) + randomValueOtherThan( + instance.getInferenceResultsMap(), + GetInferenceFieldsInternalActionResponseTests::randomInferenceResultsMap + ) ); default -> throw new AssertionError("Invalid value"); }; } - private static Map> randomInferenceFieldsMap() { - Map> map = new HashMap<>(); + private static Map> randomInferenceFieldsMap() { + Map> map = new HashMap<>(); int numIndices = randomIntBetween(0, 5); for (int i = 0; i < numIndices; i++) { String indexName = randomIdentifier(); - List fields = new ArrayList<>(); + List fields = new ArrayList<>(); int numFields = randomIntBetween(0, 5); for (int j = 0; j < numFields; j++) { fields.add(randomeExtendedInferenceFieldMetadata()); @@ -81,20 +89,23 @@ private static Map bwcVersions() { - TransportVersion minVersion = TransportVersion.max(TransportVersion.minimumCompatible(), GET_INFERENCE_FIELDS_ACTION_TV); + TransportVersion minVersion = TransportVersion.max( + TransportVersion.minimumCompatible(), + GET_INFERENCE_FIELDS_ACTION_AS_INDICES_ACTION_TV + ); return TransportVersionUtils.allReleasedVersions().tailSet(minVersion, true); } @Override - protected GetInferenceFieldsAction.Response mutateInstanceForVersion( - GetInferenceFieldsAction.Response instance, + protected GetInferenceFieldsInternalAction.Response mutateInstanceForVersion( + GetInferenceFieldsInternalAction.Response instance, TransportVersion version ) { return instance; } - private static GetInferenceFieldsAction.ExtendedInferenceFieldMetadata randomeExtendedInferenceFieldMetadata() { - return new GetInferenceFieldsAction.ExtendedInferenceFieldMetadata(randomInferenceFieldMetadata(), randomFloat()); + private static GetInferenceFieldsInternalAction.ExtendedInferenceFieldMetadata randomeExtendedInferenceFieldMetadata() { + return new GetInferenceFieldsInternalAction.ExtendedInferenceFieldMetadata(randomInferenceFieldMetadata(), randomFloat()); } private static InferenceFieldMetadata randomInferenceFieldMetadata() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/AbstractInterceptedInferenceQueryBuilderTestCase.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/AbstractInterceptedInferenceQueryBuilderTestCase.java index 2c758523374db..1d9f83aada4cf 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/AbstractInterceptedInferenceQueryBuilderTestCase.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/AbstractInterceptedInferenceQueryBuilderTestCase.java @@ -69,7 +69,7 @@ import java.util.function.Supplier; import java.util.stream.Collectors; -import static org.elasticsearch.xpack.core.inference.action.GetInferenceFieldsAction.GET_INFERENCE_FIELDS_ACTION_TV; +import static org.elasticsearch.xpack.core.inference.action.GetInferenceFieldsInternalAction.GET_INFERENCE_FIELDS_ACTION_AS_INDICES_ACTION_TV; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD; import static org.elasticsearch.xpack.inference.queries.InterceptedInferenceQueryBuilder.INFERENCE_RESULTS_MAP_WITH_CLUSTER_ALIAS; import static org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder.SEMANTIC_SEARCH_CCS_SUPPORT; @@ -237,7 +237,7 @@ protected void ccsSerializationWithMinimizeRoundTripsFalseTestCase(TaskType task new MockInferenceRemoteClusterClient.RemoteClusterConfig( remoteInferenceEndpoints, remoteIndexConfigs, - TransportVersionUtils.getPreviousVersion(GET_INFERENCE_FIELDS_ACTION_TV) + TransportVersionUtils.getPreviousVersion(GET_INFERENCE_FIELDS_ACTION_AS_INDICES_ACTION_TV) ); QueryRewriteContext preCcsRemoteClusterContext = createQueryRewriteContext( @@ -254,7 +254,7 @@ protected void ccsSerializationWithMinimizeRoundTripsFalseTestCase(TaskType task + queryName + " query cross-cluster search when" + " [ccs_minimize_roundtrips] is false. Please update all clusters to at least " - + GET_INFERENCE_FIELDS_ACTION_TV.toReleaseVersion() + + GET_INFERENCE_FIELDS_ACTION_AS_INDICES_ACTION_TV.toReleaseVersion() ), null ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/MockInferenceRemoteClusterClient.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/MockInferenceRemoteClusterClient.java index 0fb2b699abd60..3a8a40f28245c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/MockInferenceRemoteClusterClient.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/MockInferenceRemoteClusterClient.java @@ -19,7 +19,7 @@ import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.transport.Transport; import org.elasticsearch.transport.TransportResponse; -import org.elasticsearch.xpack.core.inference.action.GetInferenceFieldsAction; +import org.elasticsearch.xpack.core.inference.action.GetInferenceFieldsInternalAction; import org.mockito.Mockito; import java.util.HashMap; @@ -28,7 +28,7 @@ import java.util.Set; import java.util.stream.Collectors; -import static org.elasticsearch.xpack.core.inference.action.GetInferenceFieldsAction.GET_INFERENCE_FIELDS_ACTION_TV; +import static org.elasticsearch.xpack.core.inference.action.GetInferenceFieldsInternalAction.GET_INFERENCE_FIELDS_ACTION_AS_INDICES_ACTION_TV; import static org.mockito.Mockito.when; public class MockInferenceRemoteClusterClient implements RemoteClusterClient { @@ -57,27 +57,30 @@ public void Request request, ActionListener listener ) { - if (action.equals(GetInferenceFieldsAction.REMOTE_TYPE) - && request instanceof GetInferenceFieldsAction.Request getInferenceFieldsRequest) { + if (action.equals(GetInferenceFieldsInternalAction.REMOTE_TYPE) + && request instanceof GetInferenceFieldsInternalAction.Request getInferenceFieldsRequest) { @SuppressWarnings("unchecked") - ActionListener actionListener = (ActionListener) listener; + ActionListener actionListener = (ActionListener< + GetInferenceFieldsInternalAction.Response>) listener; - if (connection.getTransportVersion().supports(GET_INFERENCE_FIELDS_ACTION_TV) == false) { - actionListener.onFailure(new IllegalStateException("Mock remote cluster does not support GetInferenceFieldsAction")); + if (connection.getTransportVersion().supports(GET_INFERENCE_FIELDS_ACTION_AS_INDICES_ACTION_TV) == false) { + actionListener.onFailure( + new IllegalStateException("Mock remote cluster does not support GetInferenceFieldsInternalAction") + ); return; } - final Set indices = getInferenceFieldsRequest.getIndices(); - final Map fields = getInferenceFieldsRequest.getFields(); + final String[] indices = getInferenceFieldsRequest.indices(); + final Map fields = getInferenceFieldsRequest.fields(); final boolean resolveWildcards = getInferenceFieldsRequest.resolveWildcards(); final boolean useDefaultFields = getInferenceFieldsRequest.useDefaultFields(); - final String query = getInferenceFieldsRequest.getQuery(); + final String query = getInferenceFieldsRequest.query(); try { var inferenceFieldsMap = getInferenceFieldsMap(indices, fields, resolveWildcards, useDefaultFields); var inferenceResultsMap = getInferenceResultsMap(inferenceFieldsMap, query); - actionListener.onResponse(new GetInferenceFieldsAction.Response(inferenceFieldsMap, inferenceResultsMap)); + actionListener.onResponse(new GetInferenceFieldsInternalAction.Response(inferenceFieldsMap, inferenceResultsMap)); } catch (Exception e) { actionListener.onFailure(e); } @@ -91,8 +94,8 @@ public void getConnection(Request request, Actio listener.onResponse(connection); } - private Map> getInferenceFieldsMap( - Set indices, + private Map> getInferenceFieldsMap( + String[] indices, Map fields, boolean resolveWildcards, boolean useDefaultFields @@ -102,7 +105,9 @@ private Map> inferenceFieldsMap = new HashMap<>(indices.size()); + Map> inferenceFieldsMap = new HashMap<>( + indices.length + ); for (String index : indices) { var inferenceFieldsMetadataMap = clusterInferenceFieldsMap.get(index); if (inferenceFieldsMetadataMap == null) { @@ -114,9 +119,9 @@ private Map eifm = matchingInferenceFieldsMap.entrySet() + List eifm = matchingInferenceFieldsMap.entrySet() .stream() - .map(e -> new GetInferenceFieldsAction.ExtendedInferenceFieldMetadata(e.getKey(), e.getValue())) + .map(e -> new GetInferenceFieldsInternalAction.ExtendedInferenceFieldMetadata(e.getKey(), e.getValue())) .toList(); inferenceFieldsMap.put(index, eifm); @@ -126,7 +131,7 @@ private Map getInferenceResultsMap( - Map> inferenceFieldsMap, + Map> inferenceFieldsMap, @Nullable String query ) { if (query == null || query.isBlank()) { diff --git a/x-pack/plugin/security/qa/multi-cluster/build.gradle b/x-pack/plugin/security/qa/multi-cluster/build.gradle index 6cbabbabe00ce..4089ce36d315a 100644 --- a/x-pack/plugin/security/qa/multi-cluster/build.gradle +++ b/x-pack/plugin/security/qa/multi-cluster/build.gradle @@ -27,6 +27,11 @@ dependencies { clusterModules project(':x-pack:plugin:ml') clusterModules project(xpackModule('ilm')) clusterModules(project(":modules:ingest-common")) + // dependencies for inference tests + clusterModules project(':x-pack:plugin:ilm') + clusterModules project(':x-pack:plugin:ml') + clusterModules project(':x-pack:plugin:inference') + clusterPlugins project(':x-pack:plugin:inference:qa:test-service-plugin') } tasks.named("javaRestTest") { diff --git a/x-pack/plugin/security/qa/multi-cluster/src/javaRestTest/java/org/elasticsearch/xpack/remotecluster/RemoteClusterSecurityCCSCrossClusterInferenceIT.java b/x-pack/plugin/security/qa/multi-cluster/src/javaRestTest/java/org/elasticsearch/xpack/remotecluster/RemoteClusterSecurityCCSCrossClusterInferenceIT.java new file mode 100644 index 0000000000000..bfadcc0554012 --- /dev/null +++ b/x-pack/plugin/security/qa/multi-cluster/src/javaRestTest/java/org/elasticsearch/xpack/remotecluster/RemoteClusterSecurityCCSCrossClusterInferenceIT.java @@ -0,0 +1,212 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.remotecluster; + +import org.elasticsearch.client.Request; +import org.elasticsearch.client.Response; +import org.elasticsearch.core.Strings; +import org.elasticsearch.test.cluster.ElasticsearchCluster; +import org.elasticsearch.test.cluster.util.resource.Resource; +import org.elasticsearch.test.rest.ObjectPath; +import org.junit.After; +import org.junit.ClassRule; +import org.junit.rules.RuleChain; +import org.junit.rules.TestRule; + +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.atomic.AtomicReference; + +import static org.hamcrest.CoreMatchers.equalTo; + +public class RemoteClusterSecurityCCSCrossClusterInferenceIT extends AbstractRemoteClusterSecurityTestCase { + + private static final AtomicReference> API_KEY_MAP_REF = new AtomicReference<>(); + + static { + fulfillingCluster = ElasticsearchCluster.local() + .name("fulfilling-cluster") + .apply(commonClusterConfig) + // Enable RCS 2.0 (remote cluster server) + .setting("remote_cluster_server.enabled", "true") + .setting("remote_cluster.port", "0") + .setting("xpack.security.remote_cluster_server.ssl.enabled", "true") + .setting("xpack.security.remote_cluster_server.ssl.key", "remote-cluster.key") + .setting("xpack.security.remote_cluster_server.ssl.certificate", "remote-cluster.crt") + .keystore("xpack.security.remote_cluster_server.ssl.secure_key_passphrase", "remote-cluster-password") + .module("x-pack-ilm") + .module("x-pack-ml") + .module("x-pack-inference") + .plugin("inference-service-test") + .build(); + + queryCluster = ElasticsearchCluster.local() + .name("query-cluster") + .apply(commonClusterConfig) + .setting("xpack.security.remote_cluster_client.ssl.enabled", "true") + .setting("xpack.security.remote_cluster_client.ssl.certificate_authorities", "remote-cluster-ca.crt") + .module("x-pack-ilm") + .module("x-pack-ml") + .module("x-pack-inference") + .plugin("inference-service-test") + .keystore("cluster.remote.my_remote_cluster.credentials", () -> { + if (API_KEY_MAP_REF.get() == null) { + final Map apiKeyMap = createCrossClusterAccessApiKey(""" + { + "search": [ + { + "names": ["*"] + } + ] + }"""); + API_KEY_MAP_REF.set(apiKeyMap); + } + return (String) API_KEY_MAP_REF.get().get("encoded"); + }) + .rolesFile(Resource.fromClasspath("roles.yml")) + .user(REMOTE_SEARCH_USER, PASS.toString(), "remote_search_cross_cluster_inference", false) + .build(); + } + + @ClassRule + public static TestRule clusterRule = RuleChain.outerRule(fulfillingCluster).around(queryCluster); + + private final Set indices = new HashSet<>(); + private final Set inferenceIds = new HashSet<>(); + + @After + public void cleanUpIndicesAndInferenceEndpoints() throws Exception { + for (String index : indices) { + deleteIndexOnFulfillingCluster(index); + } + + for (String inferenceId : inferenceIds) { + deleteInferenceEndpointOnFulfillingCluster(inferenceId, true); + } + } + + public void testCrossClusterInference() throws Exception { + configureRemoteCluster(); + + // Create an inference endpoint on the remote cluster + final String inferenceId = createTextEmbeddingInferenceEndpointOnFulfillingCluster(); + + // Create an index on the remote cluster with a semantic_text field that uses the inference endpoint + createSemanticTextIndexOnFulfillingCluster("test-index", inferenceId); + indexDocOnFulfillingCluster("test-index", "1", "test document for cross cluster search"); + + // Execute a basic match query with ccs_minimize_roundtrips=false. + // This will be intercepted by the inference plugin and rewritten to a semantic query on the content field, which will trigger + // a cross-cluster inference action. + Response response = performSearchRequest(List.of("my_remote_cluster:test-index"), "test"); + ObjectPath objectPath = assertOKAndCreateObjectPath(response); + assertThat(objectPath.evaluate("hits.total.value"), equalTo(1)); + assertThat(objectPath.evaluate("hits.hits.0._id"), equalTo("1")); + } + + public void testQueryUnauthorizedIndex() throws Exception { + configureRemoteCluster(); + + // Create an inference endpoint on the remote cluster + final String inferenceId1 = createTextEmbeddingInferenceEndpointOnFulfillingCluster(); + final String inferenceId2 = createTextEmbeddingInferenceEndpointOnFulfillingCluster(); + + // Create two indices on the remote cluster with a semantic_text field that uses the inference endpoint. + // The user is only authorized to query one of these indices. + createSemanticTextIndexOnFulfillingCluster("test-index", inferenceId1); + createSemanticTextIndexOnFulfillingCluster("unauthorized-index", inferenceId2); + + indexDocOnFulfillingCluster("test-index", "1", "test document for cross cluster search"); + indexDocOnFulfillingCluster("unauthorized-index", "2", "another test document for cross cluster"); + + // Execute a query and verify that results are returned only for the authorized index + Response response = performSearchRequest(List.of("my_remote_cluster:*-index"), "test"); + ObjectPath objectPath = assertOKAndCreateObjectPath(response); + assertThat(objectPath.evaluate("hits.total.value"), equalTo(1)); + assertThat(objectPath.evaluate("hits.hits.0._id"), equalTo("1")); + } + + private String createTextEmbeddingInferenceEndpointOnFulfillingCluster() throws Exception { + final String inferenceId = randomIdentifier(); + Request createInferenceEndpointRequest = new Request("PUT", "/_inference/text_embedding/" + inferenceId); + createInferenceEndpointRequest.setJsonEntity(""" + { + "service": "text_embedding_test_service", + "service_settings": { + "model": "my_model", + "dimensions": 256, + "similarity": "cosine", + "api_key": "abc64" + } + } + """); + performRequestAgainstFulfillingCluster(createInferenceEndpointRequest); + + inferenceIds.add(inferenceId); + return inferenceId; + } + + private void createSemanticTextIndexOnFulfillingCluster(String indexName, String inferenceId) throws Exception { + Request createIndexRequest = new Request("PUT", "/" + indexName); + createIndexRequest.setJsonEntity(Strings.format(""" + { + "mappings": { + "properties": { + "content": { + "type": "semantic_text", + "inference_id": "%s" + } + } + } + } + """, inferenceId)); + performRequestAgainstFulfillingCluster(createIndexRequest); + indices.add(indexName); + } + + private static void indexDocOnFulfillingCluster(String indexName, String id, String content) throws Exception { + Request indexDocRequest = new Request("POST", Strings.format("/%s/_doc/%s?refresh=true", indexName, id)); + indexDocRequest.setJsonEntity(Strings.format(""" + { + "content": "%s" + } + """, content)); + performRequestAgainstFulfillingCluster(indexDocRequest); + } + + private static Response performSearchRequest(List indexNames, String query) throws Exception { + Request searchRequest = new Request("GET", Strings.format("/%s/_search", String.join(",", indexNames))); + searchRequest.addParameter("ccs_minimize_roundtrips", "false"); + searchRequest.setJsonEntity(Strings.format(""" + { + "query": { + "match": { + "content": "%s" + } + } + } + """, query)); + searchRequest.setOptions( + searchRequest.getOptions().toBuilder().addHeader("Authorization", headerFromRandomAuthMethod(REMOTE_SEARCH_USER, PASS)) + ); + + return client().performRequest(searchRequest); + } + + private static void deleteIndexOnFulfillingCluster(String indexName) throws Exception { + Request deleteIndexRequest = new Request("DELETE", "/" + indexName); + performRequestAgainstFulfillingCluster(deleteIndexRequest); + } + + private static void deleteInferenceEndpointOnFulfillingCluster(String inferenceId, boolean force) throws Exception { + Request deleteInferenceEndpointRequest = new Request("DELETE", "/_inference/" + inferenceId + "?force=" + force); + performRequestAgainstFulfillingCluster(deleteInferenceEndpointRequest); + } +} diff --git a/x-pack/plugin/security/qa/multi-cluster/src/javaRestTest/resources/roles.yml b/x-pack/plugin/security/qa/multi-cluster/src/javaRestTest/resources/roles.yml index c09f9dc620a4c..41912c599a3b0 100644 --- a/x-pack/plugin/security/qa/multi-cluster/src/javaRestTest/resources/roles.yml +++ b/x-pack/plugin/security/qa/multi-cluster/src/javaRestTest/resources/roles.yml @@ -1,3 +1,14 @@ +remote_search_cross_cluster_inference: + cluster: + - monitor + indices: + - names: ['*'] + privileges: ['read', 'view_index_metadata'] + remote_indices: + - names: ['test-*'] + privileges: ['read', 'read_cross_cluster', 'view_index_metadata'] + clusters: ['*'] + read_remote_shared_logs: remote_indices: - names: [ 'shared-logs' ] diff --git a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java index aa0695d626724..d119b3f5cdf91 100644 --- a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java +++ b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java @@ -335,7 +335,6 @@ public class Constants { "cluster:internal/xpack/inference/clear_inference_endpoint_cache", "cluster:internal/xpack/inference/create_endpoints", "cluster:internal/xpack/inference/embedding", - "cluster:internal/xpack/inference/fields/get", "cluster:internal/xpack/inference/rerankwindowsize/get", "cluster:internal/xpack/inference/unified", "cluster:internal/xpack/inference/update_authorization_task", @@ -547,6 +546,7 @@ public class Constants { "indices:admin/index_template/put", "indices:admin/index_template/simulate", "indices:admin/index_template/simulate_index", + "indices:admin/inference/fields/get", "indices:admin/mapping/auto_put", "indices:admin/mapping/put", "indices:admin/mappings/fields/get",