diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java index 1fb4cc3c65..5caefbbf05 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java @@ -63,6 +63,7 @@ import org.opensearch.search.query.QuerySearchResult; import org.opensearch.search.query.ReduceableSearchResult; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Locale; @@ -71,6 +72,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.MAGIC_NUMBER_DELIMITER; @@ -764,6 +766,11 @@ public void testReduceAndRescore_whenMatchedDocsAndRescoreContextPresent_thenSuc QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + // mock to query to generate the parsed query + doAnswer(invocationOnMock -> { + final QueryBuilder queryBuilder = invocationOnMock.getArgument(0); + return new ParsedQuery(queryBuilder.toQuery(mockQueryShardContext), Collections.emptyMap()); + }).when(mockQueryShardContext).toQuery(any()); HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); HybridQuery hybridQueryWithTerm = new HybridQuery( @@ -804,6 +811,7 @@ public void testReduceAndRescore_whenMatchedDocsAndRescoreContextPresent_thenSuc RescorerBuilder rescorerBuilder = new QueryRescorerBuilder(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY2)); RescoreContext rescoreContext = rescorerBuilder.buildContext(mockQueryShardContext); + List rescoreContexts = List.of(rescoreContext); when(searchContext.rescore()).thenReturn(rescoreContexts); Weight rescoreWeight = mock(Weight.class); @@ -890,6 +898,11 @@ public void testRescoreWithConcurrentSegmentSearch_whenMatchedDocsAndRescore_the QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + // mock to query to generate the parsed query + doAnswer(invocationOnMock -> { + final QueryBuilder queryBuilder = invocationOnMock.getArgument(0); + return new ParsedQuery(queryBuilder.toQuery(mockQueryShardContext), Collections.emptyMap()); + }).when(mockQueryShardContext).toQuery(any()); HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); HybridQuery hybridQueryWithTerm = new HybridQuery(