@@ -80,6 +80,9 @@ public class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction<S
     private static final Logger logger = LogManager.getLogger(SearchQueryThenFetchAsyncAction.class);

     private static final TransportVersion BATCHED_QUERY_PHASE_VERSION = TransportVersion.fromName("batched_query_phase_version");
+    private static final TransportVersion BATCHED_RESPONSE_MIGHT_INCLUDE_REDUCTION_FAILURE = TransportVersion.fromName(
+        "batched_response_might_include_reduction_failure"
+    );

     private final SearchProgressListener progressListener;

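The constant above gates the new behaviour on the remote peer's transport version; the hunks below check it both on the stream (in readFrom/writeTo) and on the connection or channel (in handleException/onShardDone). A minimal standalone sketch of that capability-gated dispatch, with a plain boolean standing in for the TransportVersion#supports check and illustrative method names that are not from this codebase:

// Minimal sketch of the capability-gated dispatch used below in handleException and
// onShardDone; the boolean stands in for
// connection.getTransportVersion().supports(BATCHED_RESPONSE_MIGHT_INCLUDE_REDUCTION_FAILURE).
class CapabilityDispatchSketch {
    void respond(boolean peerSupportsReductionFailure) {
        if (peerSupportsReductionFailure == false) {
            bwcRespond();                            // legacy wire format: merge result + top docs stats only
            return;
        }
        respondWithPossibleReductionFailure();       // new format: may carry a reduction failure instead
    }

    void bwcRespond() { /* pre-gate behaviour */ }

    void respondWithPossibleReductionFailure() { /* gate-aware behaviour */ }
}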
@@ -218,18 +221,27 @@ public static final class NodeQueryResponse extends TransportResponse {
         private final RefCounted refCounted = LeakTracker.wrap(new SimpleRefCounted());

         private final Object[] results;
+        private final Exception reductionFailure;
         private final SearchPhaseController.TopDocsStats topDocsStats;
         private final QueryPhaseResultConsumer.MergeResult mergeResult;

         NodeQueryResponse(StreamInput in) throws IOException {
             this.results = in.readArray(i -> i.readBoolean() ? new QuerySearchResult(i) : i.readException(), Object[]::new);
-            this.mergeResult = QueryPhaseResultConsumer.MergeResult.readFrom(in);
-            this.topDocsStats = SearchPhaseController.TopDocsStats.readFrom(in);
+            if (in.getTransportVersion().supports(BATCHED_RESPONSE_MIGHT_INCLUDE_REDUCTION_FAILURE) && in.readBoolean()) {
+                this.reductionFailure = in.readException();
+                this.mergeResult = null;
+                this.topDocsStats = null;
+            } else {
+                this.reductionFailure = null;
+                this.mergeResult = QueryPhaseResultConsumer.MergeResult.readFrom(in);
+                this.topDocsStats = SearchPhaseController.TopDocsStats.readFrom(in);
+            }
         }

         NodeQueryResponse(
-            QueryPhaseResultConsumer.MergeResult mergeResult,
             Object[] results,
+            Exception reductionFailure,
+            QueryPhaseResultConsumer.MergeResult mergeResult,
             SearchPhaseController.TopDocsStats topDocsStats
         ) {
             this.results = results;
@@ -238,6 +250,7 @@ public static final class NodeQueryResponse extends TransportResponse {
                     r.incRef();
                 }
             }
+            this.reductionFailure = reductionFailure;
             this.mergeResult = mergeResult;
             this.topDocsStats = topDocsStats;
             assert Arrays.stream(results).noneMatch(Objects::isNull) : Arrays.toString(results);
@@ -248,6 +261,10 @@ public Object[] getResults() {
             return results;
         }

+        Exception getReductionFailure() {
+            return reductionFailure;
+        }
+
         @Override
         public void writeTo(StreamOutput out) throws IOException {
             out.writeArray((o, v) -> {
@@ -260,8 +277,19 @@ public void writeTo(StreamOutput out) throws IOException {
                     ((QuerySearchResult) v).writeTo(o);
                 }
             }, results);
-            mergeResult.writeTo(out);
-            topDocsStats.writeTo(out);
+            if (out.getTransportVersion().supports(BATCHED_RESPONSE_MIGHT_INCLUDE_REDUCTION_FAILURE)) {
+                boolean hasReductionFailure = reductionFailure != null;
+                out.writeBoolean(hasReductionFailure);
+                if (hasReductionFailure) {
+                    out.writeException(reductionFailure);
+                } else {
+                    mergeResult.writeTo(out);
+                    topDocsStats.writeTo(out);
+                }
+            } else {
+                mergeResult.writeTo(out);
+                topDocsStats.writeTo(out);
+            }
         }

         @Override
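Taken together, the readFrom and writeTo changes above keep the wire format symmetric: after the per-shard results array, a peer on the new transport version writes a boolean flag, followed either by the reduction failure alone or by the merge result and top-docs stats as before; older peers never see the flag. A standalone sketch of that framing, assuming plain java.io streams and strings in place of StreamOutput/StreamInput, exceptions and merge payloads:

// Sketch of the optional-failure framing added above; all names here are illustrative.
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

class ReductionFailureFramingSketch {
    static void write(DataOutputStream out, boolean peerSupportsGate, String reductionFailure) throws IOException {
        // the per-shard results array would be written here in the real response
        if (peerSupportsGate) {
            boolean hasReductionFailure = reductionFailure != null;
            out.writeBoolean(hasReductionFailure);
            if (hasReductionFailure) {
                out.writeUTF(reductionFailure);      // failure replaces merge result + top docs stats
                return;
            }
        }
        out.writeUTF("mergeResult");                 // success payload, same order as the old format
        out.writeUTF("topDocsStats");
    }

    static String read(DataInputStream in, boolean peerSupportsGate) throws IOException {
        if (peerSupportsGate && in.readBoolean()) {
            return "reduction failure: " + in.readUTF();
        }
        return in.readUTF() + " + " + in.readUTF();  // mergeResult + topDocsStats
    }
}

Writing the flag first lets the reader decide what to deserialize next without any look-ahead.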
@@ -495,7 +523,12 @@ public Executor executor() {
                 @Override
                 public void handleResponse(NodeQueryResponse response) {
                     if (results instanceof QueryPhaseResultConsumer queryPhaseResultConsumer) {
-                        queryPhaseResultConsumer.addBatchedPartialResult(response.topDocsStats, response.mergeResult);
+                        Exception reductionFailure = response.getReductionFailure();
+                        if (reductionFailure != null) {
+                            queryPhaseResultConsumer.failure.compareAndSet(null, reductionFailure);
+                        } else {
+                            queryPhaseResultConsumer.addBatchedPartialResult(response.topDocsStats, response.mergeResult);
+                        }
                     }
                     for (int i = 0; i < response.results.length; i++) {
                         var s = request.shards.get(i);
@@ -515,6 +548,21 @@ public void handleResponse(NodeQueryResponse response) {

                 @Override
                 public void handleException(TransportException e) {
+                    if (connection.getTransportVersion().supports(BATCHED_RESPONSE_MIGHT_INCLUDE_REDUCTION_FAILURE) == false) {
+                        bwcHandleException(e);
+                        return;
+                    }
+                    Exception cause = (Exception) ExceptionsHelper.unwrapCause(e);
+                    logger.debug("handling node search exception coming from [" + nodeId + "]", cause);
+                    onNodeQueryFailure(e, request, routing);
+                }
+
+                /**
+                 * This code is strictly for _snapshot_ backwards compatibility. The feature flag
+                 * {@link SearchService#BATCHED_QUERY_PHASE_FEATURE_FLAG} was not turned on when the transport version
+                 * {@link SearchQueryThenFetchAsyncAction#BATCHED_RESPONSE_MIGHT_INCLUDE_REDUCTION_FAILURE} was introduced.
+                 */
+                private void bwcHandleException(TransportException e) {
                     Exception cause = (Exception) ExceptionsHelper.unwrapCause(e);
                     logger.debug("handling node search exception coming from [" + nodeId + "]", cause);
                     if (e instanceof SendRequestTransportException || cause instanceof TaskCancelledException) {
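On the coordinating node, the handleResponse change two hunks above records a reduction failure reported by a data node directly on the query phase result consumer instead of merging a partial result; compareAndSet keeps only the first failure seen across nodes. A small illustrative sketch of that first-failure-wins pattern, with hypothetical names not taken from the codebase:

// Sketch of the pattern behind queryPhaseResultConsumer.failure.compareAndSet(null, reductionFailure).
import java.util.concurrent.atomic.AtomicReference;

class FirstFailureWinsSketch {
    private final AtomicReference<Exception> failure = new AtomicReference<>();

    void onBatchedResponse(Exception reductionFailure, Runnable mergePartialResult) {
        if (reductionFailure != null) {
            // later failures from other nodes do not overwrite the first one recorded
            failure.compareAndSet(null, reductionFailure);
        } else {
            mergePartialResult.run();
        }
    }

    Exception firstFailureOrNull() {
        return failure.get();
    }
}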
@@ -786,11 +834,93 @@ void onShardDone() {
             if (countDown.countDown() == false) {
                 return;
             }
+            if (channel.getVersion().supports(BATCHED_RESPONSE_MIGHT_INCLUDE_REDUCTION_FAILURE) == false) {
+                bwcRespond();
+                return;
+            }
+            var channelListener = new ChannelActionListener<>(channel);
+            NodeQueryResponse nodeQueryResponse;
+            try (queryPhaseResultConsumer) {
+                Exception reductionFailure = queryPhaseResultConsumer.failure.get();
+                if (reductionFailure == null) {
+                    nodeQueryResponse = getSuccessfulResponse();
+                } else {
+                    nodeQueryResponse = getReductionFailureResponse(reductionFailure);
+                }
+            } catch (IOException e) {
+                releaseAllResultsContexts();
+                channelListener.onFailure(e);
+                return;
+            }
+            ActionListener.respondAndRelease(channelListener, nodeQueryResponse);
+        }
+
+        private NodeQueryResponse getSuccessfulResponse() throws IOException {
+            final QueryPhaseResultConsumer.MergeResult mergeResult;
+            try {
+                mergeResult = Objects.requireNonNullElse(
+                    queryPhaseResultConsumer.consumePartialMergeResultDataNode(),
+                    EMPTY_PARTIAL_MERGE_RESULT
+                );
+            } catch (Exception e) {
+                return getReductionFailureResponse(e);
+            }
+            // translate shard indices to those on the coordinator so that it can interpret the merge result without adjustments,
+            // also collect the set of indices that may be part of a subsequent fetch operation here so that we can release all other
+            // indices without a roundtrip to the coordinating node
+            final BitSet relevantShardIndices = new BitSet(searchRequest.shards.size());
+            if (mergeResult.reducedTopDocs() != null) {
+                for (ScoreDoc scoreDoc : mergeResult.reducedTopDocs().scoreDocs) {
+                    final int localIndex = scoreDoc.shardIndex;
+                    scoreDoc.shardIndex = searchRequest.shards.get(localIndex).shardIndex;
+                    relevantShardIndices.set(localIndex);
+                }
+            }
+            final Object[] results = new Object[queryPhaseResultConsumer.getNumShards()];
+            for (int i = 0; i < results.length; i++) {
+                var result = queryPhaseResultConsumer.results.get(i);
+                if (result == null) {
+                    results[i] = failures.get(i);
+                } else {
+                    // free context id and remove it from the result right away in case we don't need it anymore
+                    maybeFreeContext(result, relevantShardIndices, namedWriteableRegistry);
+                    results[i] = result;
+                }
+                assert results[i] != null;
+            }
+            return new NodeQueryResponse(results, null, mergeResult, queryPhaseResultConsumer.topDocsStats);
+        }
+
+        private NodeQueryResponse getReductionFailureResponse(Exception reductionFailure) throws IOException {
+            try {
+                final Object[] results = new Object[queryPhaseResultConsumer.getNumShards()];
+                for (int i = 0; i < results.length; i++) {
+                    var result = queryPhaseResultConsumer.results.get(i);
+                    if (result == null) {
+                        results[i] = failures.get(i);
+                    } else {
+                        results[i] = result;
+                    }
+                    assert results[i] != null;
+                }
+                return new NodeQueryResponse(results, reductionFailure, null, null);
+            } finally {
+                releaseAllResultsContexts();
+            }
+        }
+
+        /**
+         * This code is strictly for _snapshot_ backwards compatibility. The feature flag
+         * {@link SearchService#BATCHED_QUERY_PHASE_FEATURE_FLAG} was not turned on when the transport version
+         * {@link SearchQueryThenFetchAsyncAction#BATCHED_RESPONSE_MIGHT_INCLUDE_REDUCTION_FAILURE} was introduced.
+         */
+        void bwcRespond() {
             var channelListener = new ChannelActionListener<>(channel);
             try (queryPhaseResultConsumer) {
                 var failure = queryPhaseResultConsumer.failure.get();
                 if (failure != null) {
-                    handleMergeFailure(failure, channelListener, namedWriteableRegistry);
+                    releaseAllResultsContexts();
+                    channelListener.onFailure(failure);
                     return;
                 }
                 final QueryPhaseResultConsumer.MergeResult mergeResult;
@@ -800,7 +930,8 @@ void onShardDone() {
                         EMPTY_PARTIAL_MERGE_RESULT
                     );
                 } catch (Exception e) {
-                    handleMergeFailure(e, channelListener, namedWriteableRegistry);
+                    releaseAllResultsContexts();
+                    channelListener.onFailure(e);
                     return;
                 }
                 // translate shard indices to those on the coordinator so that it can interpret the merge result without adjustments,
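Both getSuccessfulResponse above and the pre-existing path continued below rewrite the shard indices in the reduced top docs from node-local positions to the coordinator's positions, and record which local shards may still be needed for the fetch phase so that every other reader context can be freed without a round trip. A standalone sketch of that bookkeeping, assuming plain int arrays in place of ScoreDoc entries and shard routing:

// Sketch of the shard-index translation and "relevant shards" bookkeeping; illustrative only.
import java.util.BitSet;

class ShardIndexTranslationSketch {
    /**
     * @param topDocShardIndices node-local shard indices referenced by the reduced top docs (rewritten in place)
     * @param coordinatorIndices coordinator-side shard index for each local shard position
     * @return the local shards that may still be fetched and must keep their reader contexts
     */
    static BitSet translate(int[] topDocShardIndices, int[] coordinatorIndices) {
        BitSet relevantShardIndices = new BitSet(coordinatorIndices.length);
        for (int i = 0; i < topDocShardIndices.length; i++) {
            int localIndex = topDocShardIndices[i];
            topDocShardIndices[i] = coordinatorIndices[localIndex];  // coordinator can use the result as-is
            relevantShardIndices.set(localIndex);                    // this shard may be fetched later
        }
        return relevantShardIndices;                                 // all other shards can be freed locally
    }
}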
@@ -839,16 +970,30 @@ && isPartOfPIT(searchRequest.searchRequest, q.getContextId(), namedWriteableRegi

                 ActionListener.respondAndRelease(
                     channelListener,
-                    new NodeQueryResponse(mergeResult, results, queryPhaseResultConsumer.topDocsStats)
+                    new NodeQueryResponse(results, null, mergeResult, queryPhaseResultConsumer.topDocsStats)
                 );
             }
         }

-        private void handleMergeFailure(
-            Exception e,
-            ChannelActionListener<TransportResponse> channelListener,
+        private void maybeFreeContext(
+            SearchPhaseResult result,
+            BitSet relevantShardIndices,
             NamedWriteableRegistry namedWriteableRegistry
         ) {
+            if (result instanceof QuerySearchResult q
+                && q.getContextId() != null
+                && relevantShardIndices.get(q.getShardIndex()) == false
+                && q.hasSuggestHits() == false
+                && q.getRankShardResult() == null
+                && searchRequest.searchRequest.scroll() == null
+                && isPartOfPIT(searchRequest.searchRequest, q.getContextId(), namedWriteableRegistry) == false) {
+                if (dependencies.searchService.freeReaderContext(q.getContextId())) {
+                    q.clearContextId();
+                }
+            }
+        }
+
+        private void releaseAllResultsContexts() {
             queryPhaseResultConsumer.getSuccessfulResults()
                 .forEach(
                     searchPhaseResult -> releaseLocalContext(
@@ -858,7 +1003,6 @@ private void handleMergeFailure(
                         namedWriteableRegistry
                     )
                 );
-            channelListener.onFailure(e);
         }

         void consumeResult(QuerySearchResult queryResult) {