diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5fc16369ca..caa7a1965b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 ### Features
 ### Enhancements
 - Set neural-search plugin 3.0.0 baseline JDK version to JDK-21 ([#838](https://github.com/opensearch-project/neural-search/pull/838))
+- Support different embedding types in model's response ([#1007](https://github.com/opensearch-project/neural-search/pull/1007))
 ### Bug Fixes
 ### Infrastructure
 - [3.0] Update neural-search for OpenSearch 3.0 compatibility ([#1141](https://github.com/opensearch-project/neural-search/pull/1141))
diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/rolling/BatchIngestionIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/rolling/BatchIngestionIT.java
index a3cb156316..d397fc6fbf 100644
--- a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/rolling/BatchIngestionIT.java
+++ b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/rolling/BatchIngestionIT.java
@@ -31,8 +31,12 @@ public void testBatchIngestion_SparseEncodingProcessor_E2EFlow() throws Exceptio
         loadModel(sparseModelId);
         MLModelState oldModelState = getModelState(sparseModelId);
         logger.info("Model state in OLD phase: {}", oldModelState);
-        if (oldModelState != MLModelState.LOADED) {
-            logger.error("Model {} is not in LOADED state in OLD phase. Current state: {}", sparseModelId, oldModelState);
+        if (oldModelState != MLModelState.LOADED && oldModelState != MLModelState.DEPLOYED) {
+            logger.error(
+                "Model {} is not in LOADED or DEPLOYED state in OLD phase. Current state: {}",
+                sparseModelId,
+                oldModelState
+            );
             waitForModelToLoad(sparseModelId);
         }
         createPipelineForSparseEncodingProcessor(sparseModelId, SPARSE_PIPELINE, 2);
@@ -52,8 +56,12 @@ public void testBatchIngestion_SparseEncodingProcessor_E2EFlow() throws Exceptio
         loadModel(sparseModelId);
         MLModelState mixedModelState = getModelState(sparseModelId);
         logger.info("Model state in MIXED phase: {}", mixedModelState);
-        if (mixedModelState != MLModelState.LOADED) {
-            logger.error("Model {} is not in LOADED state in MIXED phase. Current state: {}", sparseModelId, mixedModelState);
+        if (mixedModelState != MLModelState.LOADED && mixedModelState != MLModelState.DEPLOYED) {
+            logger.error(
+                "Model {} is not in LOADED or DEPLOYED state in MIXED phase. Current state: {}",
+                sparseModelId,
+                mixedModelState
+            );
             waitForModelToLoad(sparseModelId);
         }
         logger.info("Pipeline state in MIXED phase: {}", getIngestionPipeline(SPARSE_PIPELINE));
diff --git a/src/main/java/org/opensearch/neuralsearch/common/VectorUtil.java b/src/main/java/org/opensearch/neuralsearch/common/VectorUtil.java
index bfbb2e6d9c..5e5f5cd333 100644
--- a/src/main/java/org/opensearch/neuralsearch/common/VectorUtil.java
+++ b/src/main/java/org/opensearch/neuralsearch/common/VectorUtil.java
@@ -21,10 +21,10 @@ public class VectorUtil {
      * @param vectorAsList {@link List} of {@link Float}'s representing the vector
      * @return array of floats produced from input list
      */
-    public static float[] vectorAsListToArray(List<Float> vectorAsList) {
+    public static float[] vectorAsListToArray(List<Number> vectorAsList) {
         float[] vector = new float[vectorAsList.size()];
         for (int i = 0; i < vectorAsList.size(); i++) {
-            vector[i] = vectorAsList.get(i);
+            vector[i] = vectorAsList.get(i).floatValue();
         }
         return vector;
     }
diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java
index 9b088be6ed..06cdb66905 100644
--- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java
+++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java
@@ -55,7 +55,7 @@ public class MLCommonsClientAccessor {
     public void inferenceSentence(
         @NonNull final String modelId,
         @NonNull final String inputText,
-        @NonNull final ActionListener<List<Float>> listener
+        @NonNull final ActionListener<List<Number>> listener
     ) {
         inferenceSentences(
@@ -87,7 +87,7 @@ public void inferenceSentence(
      */
     public void inferenceSentences(
         @NonNull final TextInferenceRequest inferenceRequest,
-        @NonNull final ActionListener<List<List<Float>>> listener
+        @NonNull final ActionListener<List<List<Number>>> listener
     ) {
         retryableInferenceSentencesWithVectorResult(inferenceRequest, 0, listener);
     }
@@ -107,7 +107,7 @@ public void inferenceSentencesWithMapResult(
      * @param inferenceRequest {@link InferenceRequest}
      * @param listener {@link ActionListener} which will be called when prediction is completed or errored out.
      */
-    public void inferenceSentencesMap(@NonNull MapInferenceRequest inferenceRequest, @NonNull final ActionListener<List<Float>> listener) {
+    public void inferenceSentencesMap(@NonNull MapInferenceRequest inferenceRequest, @NonNull final ActionListener<List<Number>> listener) {
         retryableInferenceSentencesWithSingleVectorResult(inferenceRequest, 0, listener);
     }
@@ -148,11 +148,11 @@ private void retryableInferenceSentencesWithMapResult(
     private void retryableInferenceSentencesWithVectorResult(
         final TextInferenceRequest inferenceRequest,
         final int retryTime,
-        final ActionListener<List<List<Float>>> listener
+        final ActionListener<List<List<Number>>> listener
     ) {
         MLInput mlInput = createMLTextInput(inferenceRequest.getTargetResponseFilters(), inferenceRequest.getInputTexts());
         mlClient.predict(inferenceRequest.getModelId(), mlInput, ActionListener.wrap(mlOutput -> {
-            final List<List<Float>> vector = buildVectorFromResponse(mlOutput);
+            final List<List<Number>> vector = buildVectorFromResponse(mlOutput);
             listener.onResponse(vector);
         },
             e -> RetryUtil.handleRetryOrFailure(
@@ -171,7 +171,9 @@ private void retryableInferenceSimilarityWithVectorResult(
     ) {
         MLInput mlInput = createMLTextPairsInput(inferenceRequest.getQueryText(), inferenceRequest.getInputTexts());
         mlClient.predict(inferenceRequest.getModelId(), mlInput, ActionListener.wrap(mlOutput -> {
-            final List<Float> scores = buildVectorFromResponse(mlOutput).stream().map(v -> v.get(0)).collect(Collectors.toList());
+            final List<Float> scores = buildVectorFromResponse(mlOutput).stream()
+                .map(v -> v.getFirst().floatValue())
+                .collect(Collectors.toList());
             listener.onResponse(scores);
         },
             e -> RetryUtil.handleRetryOrFailure(
@@ -194,14 +196,14 @@ private MLInput createMLTextPairsInput(final String query, final List in
         return new MLInput(FunctionName.TEXT_SIMILARITY, null, inputDataset);
     }

-    private List<List<Float>> buildVectorFromResponse(MLOutput mlOutput) {
-        final List<List<Float>> vector = new ArrayList<>();
+    private <T extends Number> List<List<T>> buildVectorFromResponse(MLOutput mlOutput) {
+        final List<List<T>> vector = new ArrayList<>();
         final ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput;
         final List<ModelTensors> tensorOutputList = modelTensorOutput.getMlModelOutputs();
         for (final ModelTensors tensors : tensorOutputList) {
             final List<ModelTensor> tensorsList = tensors.getMlModelTensors();
             for (final ModelTensor tensor : tensorsList) {
-                vector.add(Arrays.stream(tensor.getData()).map(value -> (Float) value).collect(Collectors.toList()));
+                vector.add(Arrays.stream(tensor.getData()).map(value -> (T) value).collect(Collectors.toList()));
             }
         }
         return vector;
@@ -225,19 +227,19 @@ private List> buildVectorFromResponse(MLOutput mlOutput) {
         return resultMaps;
     }

-    private List<Float> buildSingleVectorFromResponse(final MLOutput mlOutput) {
-        final List<List<Float>> vector = buildVectorFromResponse(mlOutput);
+    private List<Number> buildSingleVectorFromResponse(final MLOutput mlOutput) {
+        final List<List<Number>> vector = buildVectorFromResponse(mlOutput);
         return vector.isEmpty() ? new ArrayList<>() : vector.get(0);
     }

     private void retryableInferenceSentencesWithSingleVectorResult(
         final MapInferenceRequest inferenceRequest,
         final int retryTime,
-        final ActionListener<List<Float>> listener
+        final ActionListener<List<Number>> listener
     ) {
         MLInput mlInput = createMLMultimodalInput(inferenceRequest.getTargetResponseFilters(), inferenceRequest.getInputObjects());
         mlClient.predict(inferenceRequest.getModelId(), mlInput, ActionListener.wrap(mlOutput -> {
-            final List<Float> vector = buildSingleVectorFromResponse(mlOutput);
+            final List<Number> vector = buildSingleVectorFromResponse(mlOutput);
             log.debug("Inference Response for input sentence is : {} ", vector);
             listener.onResponse(vector);
         },
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java
index 72a2496cc7..d675f6f046 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java
@@ -127,7 +127,7 @@ public void execute(final IngestDocument ingestDocument, final BiConsumer<IngestDocument, Exception> handler) {
-    private void setVectorFieldsToDocument(final IngestDocument ingestDocument, final List<Float> vectors) {
+    private void setVectorFieldsToDocument(final IngestDocument ingestDocument, final List<Number> vectors) {
         Objects.requireNonNull(vectors, "embedding failed, inference returns null result!");
         log.debug("Text embedding result fetched, starting build vector output!");
         Map<String, Object> textEmbeddingResult = buildTextEmbeddingResult(this.embedding, vectors);
@@ -167,7 +167,7 @@ Map buildMapWithKnnKeyAndOriginalValue(final IngestDocument inge
     @SuppressWarnings({ "unchecked" })
     @VisibleForTesting
-    Map<String, Object> buildTextEmbeddingResult(final String knnKey, List<Float> modelTensorList) {
+    Map<String, Object> buildTextEmbeddingResult(final String knnKey, List<Number> modelTensorList) {
         Map<String, Object> result = new LinkedHashMap<>();
         result.put(knnKey, modelTensorList);
         return result;
diff --git a/src/test/java/org/opensearch/neuralsearch/common/VectorUtilTests.java b/src/test/java/org/opensearch/neuralsearch/common/VectorUtilTests.java
index a06e8f84dd..4ebb7858fc 100644
--- a/src/test/java/org/opensearch/neuralsearch/common/VectorUtilTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/common/VectorUtilTests.java
@@ -12,15 +12,15 @@ public class VectorUtilTests extends OpenSearchTestCase {

     public void testVectorAsListToArray() {
-        List<Float> vectorAsList_withThreeElements = List.of(1.3f, 2.5f, 3.5f);
+        List<Number> vectorAsList_withThreeElements = List.of(1.3f, 2.5f, 3.5f);
         float[] vectorAsArray_withThreeElements = VectorUtil.vectorAsListToArray(vectorAsList_withThreeElements);

         assertEquals(vectorAsList_withThreeElements.size(), vectorAsArray_withThreeElements.length);
         for (int i = 0; i < vectorAsList_withThreeElements.size(); i++) {
-            assertEquals(vectorAsList_withThreeElements.get(i), vectorAsArray_withThreeElements[i], 0.0f);
+            assertEquals(vectorAsList_withThreeElements.get(i).floatValue(), vectorAsArray_withThreeElements[i], 0.0f);
         }

-        List<Float> vectorAsList_withNoElements = Collections.emptyList();
+        List<Number> vectorAsList_withNoElements = Collections.emptyList();
         float[] vectorAsArray_withNoElements = VectorUtil.vectorAsListToArray(vectorAsList_withNoElements);
         assertEquals(0, vectorAsArray_withNoElements.length);
     }
diff --git a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java
index 5d32b3ded6..3fea202d0c 100644
--- a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java
@@ -36,10 +36,13 @@ public class MLCommonsClientAccessorTests extends OpenSearchTestCase {

     @Mock
-    private ActionListener<List<List<Float>>> resultListener;
+    private ActionListener<List<List<Number>>> resultListener;

     @Mock
-    private ActionListener<List<Float>> singleSentenceResultListener;
+    private ActionListener<List<Number>> singleSentenceResultListener;
+
+    @Mock
+    private ActionListener<List<Float>> similarityResultListener;

     @Mock
     private MachineLearningNodeClient client;
@@ -53,7 +56,7 @@ public void setup() {
     }

     public void testInferenceSentence_whenValidInput_thenSuccess() {
-        final List<Float> vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY));
+        final List<Number> vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY));
         Mockito.doAnswer(invocation -> {
             final ActionListener actionListener = invocation.getArgument(2);
             actionListener.onResponse(createModelTensorOutput(TestCommonConstants.PREDICT_VECTOR_ARRAY));
@@ -69,7 +72,7 @@ public void testInferenceSentence_whenValidInput_thenSuccess() {
     }

     public void testInferenceSentences_whenValidInputThenSuccess() {
-        final List<List<Float>> vectorList = new ArrayList<>();
+        final List<List<Number>> vectorList = new ArrayList<>();
         vectorList.add(Arrays.asList(TestCommonConstants.PREDICT_VECTOR_ARRAY));
         Mockito.doAnswer(invocation -> {
             final ActionListener actionListener = invocation.getArgument(2);
@@ -85,7 +88,7 @@ public void testInferenceSentences_whenValidInputThenSuccess() {
     }

     public void testInferenceSentences_whenResultFromClient_thenEmptyVectorList() {
-        final List<List<Float>> vectorList = new ArrayList<>();
+        final List<List<Number>> vectorList = new ArrayList<>();
         vectorList.add(Collections.emptyList());
         Mockito.doAnswer(invocation -> {
             final ActionListener actionListener = invocation.getArgument(2);
@@ -127,17 +130,17 @@ public void testInferenceSimilarity_whenNodeNotConnectedException_ThenRetry() {
             return null;
         }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

-        accessor.inferenceSimilarity(TestCommonConstants.SIMILARITY_INFERENCE_REQUEST, singleSentenceResultListener);
+        accessor.inferenceSimilarity(TestCommonConstants.SIMILARITY_INFERENCE_REQUEST, similarityResultListener);

         // Verify client.predict is called 4 times (1 initial + 3 retries)
         Mockito.verify(client, times(4))
             .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

         // Verify failure is propagated to the listener after all retries
-        Mockito.verify(singleSentenceResultListener).onFailure(nodeNodeConnectedException);
+        Mockito.verify(similarityResultListener).onFailure(nodeNodeConnectedException);

         // Ensure no additional interactions with the listener
-        Mockito.verifyNoMoreInteractions(singleSentenceResultListener);
+        Mockito.verifyNoMoreInteractions(similarityResultListener);
     }

     public void testInferenceSentences_whenExceptionFromMLClient_thenRetry_thenFailure() {
@@ -288,7 +291,7 @@ public void testInferenceSentencesWithMapResult_whenNotRetryableException_thenFa
     }

     public void testInferenceMultimodal_whenValidInput_thenSuccess() {
-        final List<Float> vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY));
+        final List<Number> vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY));
         Mockito.doAnswer(invocation -> {
             final ActionListener actionListener = invocation.getArgument(2);
             actionListener.onResponse(createModelTensorOutput(TestCommonConstants.PREDICT_VECTOR_ARRAY));
@@ -353,12 +356,12 @@ public void testInferenceSimilarity_whenValidInput_thenSuccess() {
             return null;
         }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

-        accessor.inferenceSimilarity(TestCommonConstants.SIMILARITY_INFERENCE_REQUEST, singleSentenceResultListener);
+        accessor.inferenceSimilarity(TestCommonConstants.SIMILARITY_INFERENCE_REQUEST, similarityResultListener);

         Mockito.verify(client)
             .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
-        Mockito.verify(singleSentenceResultListener).onResponse(vector);
-        Mockito.verifyNoMoreInteractions(singleSentenceResultListener);
+        Mockito.verify(similarityResultListener).onResponse(vector);
+        Mockito.verifyNoMoreInteractions(similarityResultListener);
     }

     public void testInferencesSimilarity_whenExceptionFromMLClient_ThenFail() {
@@ -369,12 +372,12 @@ public void testInferencesSimilarity_whenExceptionFromMLClient_ThenFail() {
             return null;
         }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

-        accessor.inferenceSimilarity(TestCommonConstants.SIMILARITY_INFERENCE_REQUEST, singleSentenceResultListener);
+        accessor.inferenceSimilarity(TestCommonConstants.SIMILARITY_INFERENCE_REQUEST, similarityResultListener);

         Mockito.verify(client)
             .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
-        Mockito.verify(singleSentenceResultListener).onFailure(exception);
-        Mockito.verifyNoMoreInteractions(singleSentenceResultListener);
+        Mockito.verify(similarityResultListener).onFailure(exception);
+        Mockito.verifyNoMoreInteractions(similarityResultListener);
     }

     private ModelTensorOutput createModelTensorOutput(final Float[] output) {
diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java
index 8b8bf66a4a..68ba11f9bd 100644
--- a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java
@@ -771,10 +771,10 @@ public void testRewrite_whenVectorSupplierNull_thenSetVectorSupplier() {
             .modelId(MODEL_ID)
             .k(K)
             .build();
-        List<Float> expectedVector = Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f);
+        List<Number> expectedVector = Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f);
         MLCommonsClientAccessor mlCommonsClientAccessor = mock(MLCommonsClientAccessor.class);
         doAnswer(invocation -> {
-            ActionListener<List<Float>> listener = invocation.getArgument(1);
+            ActionListener<List<Number>> listener = invocation.getArgument(1);
             listener.onResponse(expectedVector);
             return null;
         }).when(mlCommonsClientAccessor)
@@ -810,10 +810,10 @@ public void testRewrite_whenVectorSupplierNullAndQueryTextAndImageTextSet_thenSe
             .modelId(MODEL_ID)
             .k(K)
             .build();
-        List<Float> expectedVector = Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f);
+        List<Number> expectedVector = Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f);
         MLCommonsClientAccessor mlCommonsClientAccessor = mock(MLCommonsClientAccessor.class);
         doAnswer(invocation -> {
-            ActionListener<List<Float>> listener = invocation.getArgument(1);
+            ActionListener<List<Number>> listener = invocation.getArgument(1);
             listener.onResponse(expectedVector);
             return null;
         }).when(mlCommonsClientAccessor)
diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java
index 72619aba51..6045599c8e 100644
--- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java
+++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java
@@ -312,7 +312,7 @@ protected float[] runInference(final String modelId, final String queryText) {
         List output = (List) result.get("output");
         assertEquals(1, output.size());
         Map map = (Map) output.get(0);
-        List<Float> data = ((List<Double>) map.get("data")).stream().map(Double::floatValue).collect(Collectors.toList());
+        List<Number> data = ((List<Double>) map.get("data")).stream().map(Double::floatValue).collect(Collectors.toList());
         return vectorAsListToArray(data);
     }
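
Note on the change set above: the common thread is that embeddings returned from the model response are now carried as java.lang.Number rather than Float, and are narrowed to primitive float (via Number.floatValue()) only at the point where a float[] is actually required. The snippet below is a minimal standalone sketch of that conversion pattern; it is not code from the plugin, and the class and method names are invented for illustration.

import java.util.List;

// Standalone sketch (assumed names, not part of the PR): convert a Number-typed
// embedding to float[] the same way the widened VectorUtil.vectorAsListToArray() does.
public class EmbeddingTypeSketch {

    // Accepts any boxed numeric element type and narrows element-wise via floatValue().
    static float[] toFloatArray(List<? extends Number> embedding) {
        float[] vector = new float[embedding.size()];
        for (int i = 0; i < embedding.size(); i++) {
            vector[i] = embedding.get(i).floatValue();
        }
        return vector;
    }

    public static void main(String[] args) {
        // A remote connector may deserialize the response as doubles, a local model as floats;
        // both now flow through the same conversion path.
        float[] fromDoubles = toFloatArray(List.of(0.12d, 0.98d, 0.45d));
        float[] fromFloats = toFloatArray(List.of(0.12f, 0.98f, 0.45f));
        System.out.println(fromDoubles.length + " == " + fromFloats.length);
    }
}

Widening to Number (or to a Number-bounded type parameter, as buildVectorFromResponse does) lets one code path accept Float, Double, or Integer tensor values, whichever numeric type a given model backend happens to return.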