Merged (changes from 2 commits)
server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java
@@ -19,18 +19,11 @@

package org.elasticsearch.action.search;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.function.IntFunction;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import com.carrotsearch.hppc.IntArrayList;
import com.carrotsearch.hppc.ObjectObjectHashMap;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.CollectionStatistics;
import org.apache.lucene.search.FieldDoc;
@@ -44,6 +37,8 @@
import org.apache.lucene.search.TotalHits.Relation;
import org.apache.lucene.search.grouping.CollapseTopFieldDocs;
import org.elasticsearch.common.collect.HppcMaps;
import org.elasticsearch.common.io.stream.DelayableWriteable;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.SearchHit;
@@ -67,16 +62,28 @@
import org.elasticsearch.search.suggest.Suggest.Suggestion;
import org.elasticsearch.search.suggest.completion.CompletionSuggestion;

import com.carrotsearch.hppc.IntArrayList;
import com.carrotsearch.hppc.ObjectObjectHashMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.function.IntFunction;
import java.util.function.Supplier;
import java.util.stream.Collectors;

public final class SearchPhaseController {
private static final Logger logger = LogManager.getLogger(SearchPhaseController.class);
private static final ScoreDoc[] EMPTY_DOCS = new ScoreDoc[0];

private final NamedWriteableRegistry namedWriteableRegistry;
private final Function<SearchRequest, InternalAggregation.ReduceContextBuilder> requestToAggReduceContextBuilder;

public SearchPhaseController(
public SearchPhaseController(NamedWriteableRegistry namedWriteableRegistry,
Function<SearchRequest, InternalAggregation.ReduceContextBuilder> requestToAggReduceContextBuilder) {
this.namedWriteableRegistry = namedWriteableRegistry;
this.requestToAggReduceContextBuilder = requestToAggReduceContextBuilder;
}

@@ -430,7 +437,8 @@ public ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResul
* @see QuerySearchResult#consumeProfileResult()
*/
private ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResult> queryResults,
List<Supplier<InternalAggregations>> bufferedAggs, List<TopDocs> bufferedTopDocs,
List<Supplier<InternalAggregations>> bufferedAggs,
List<TopDocs> bufferedTopDocs,
TopDocsStats topDocsStats, int numReducePhases, boolean isScrollRequest,
InternalAggregation.ReduceContextBuilder aggReduceContextBuilder,
boolean performFinalReduce) {
@@ -522,7 +530,7 @@ private ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResu
private InternalAggregations reduceAggs(
InternalAggregation.ReduceContextBuilder aggReduceContextBuilder,
boolean performFinalReduce,
List<Supplier<InternalAggregations>> aggregationsList
List<? extends Supplier<InternalAggregations>> aggregationsList
) {
/*
* Parse the aggregations, clearing the list as we go so bits backing
@@ -617,8 +625,9 @@ public InternalSearchResponse buildResponse(SearchHits hits) {
* iff the buffer is exhausted.
*/
static final class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhaseResult> {
private final NamedWriteableRegistry namedWriteableRegistry;
private final SearchShardTarget[] processedShards;
private final Supplier<InternalAggregations>[] aggsBuffer;
private final DelayableWriteable.Serialized<InternalAggregations>[] aggsBuffer;
private final TopDocs[] topDocsBuffer;
private final boolean hasAggs;
private final boolean hasTopDocs;
@@ -631,6 +640,8 @@ static final class QueryPhaseResultConsumer extends ArraySearchPhaseResults<Sear
private final int topNSize;
private final InternalAggregation.ReduceContextBuilder aggReduceContextBuilder;
private final boolean performFinalReduce;
private long aggsCurrentBufferSize;
private long aggsMaxBufferSize;

/**
* Creates a new {@link QueryPhaseResultConsumer}
@@ -641,12 +652,14 @@ static final class QueryPhaseResultConsumer extends ArraySearchPhaseResults<Sear
* @param bufferSize the size of the reduce buffer. if the buffer size is smaller than the number of expected results
* the buffer is used to incrementally reduce aggregation results before all shards responded.
*/
private QueryPhaseResultConsumer(SearchProgressListener progressListener, SearchPhaseController controller,
private QueryPhaseResultConsumer(NamedWriteableRegistry namedWriteableRegistry, SearchProgressListener progressListener,
SearchPhaseController controller,
int expectedResultSize, int bufferSize, boolean hasTopDocs, boolean hasAggs,
int trackTotalHitsUpTo, int topNSize,
InternalAggregation.ReduceContextBuilder aggReduceContextBuilder,
boolean performFinalReduce) {
super(expectedResultSize);
this.namedWriteableRegistry = namedWriteableRegistry;
if (expectedResultSize != 1 && bufferSize < 2) {
throw new IllegalArgumentException("buffer size must be >= 2 if there is more than one expected result");
}
@@ -661,7 +674,7 @@ private QueryPhaseResultConsumer(SearchProgressListener progressListener, Search
this.processedShards = new SearchShardTarget[expectedResultSize];
// no need to buffer anything if we have less expected results. in this case we don't consume any results ahead of time.
@SuppressWarnings("unchecked")
Supplier<InternalAggregations>[] aggsBuffer = new Supplier[hasAggs ? bufferSize : 0];
DelayableWriteable.Serialized<InternalAggregations>[] aggsBuffer = new DelayableWriteable.Serialized[hasAggs ? bufferSize : 0];
this.aggsBuffer = aggsBuffer;
this.topDocsBuffer = new TopDocs[hasTopDocs ? bufferSize : 0];
this.hasTopDocs = hasTopDocs;
@@ -684,15 +697,21 @@ public void consumeResult(SearchPhaseResult result) {
private synchronized void consumeInternal(QuerySearchResult querySearchResult) {
if (querySearchResult.isNull() == false) {
if (index == bufferSize) {
InternalAggregations reducedAggs = null;
if (hasAggs) {
List<InternalAggregations> aggs = new ArrayList<>(aggsBuffer.length);
for (int i = 0; i < aggsBuffer.length; i++) {
aggs.add(aggsBuffer[i].get());
aggsBuffer[i] = null; // null the buffer so it can be GCed now.
}
InternalAggregations reducedAggs = InternalAggregations.topLevelReduce(
aggs, aggReduceContextBuilder.forPartialReduction());
aggsBuffer[0] = () -> reducedAggs;
reducedAggs = InternalAggregations.topLevelReduce(aggs, aggReduceContextBuilder.forPartialReduction());
aggsBuffer[0] = DelayableWriteable.referencing(reducedAggs)
.asSerialized(InternalAggregations::new, namedWriteableRegistry);
long previousBufferSize = aggsCurrentBufferSize;
aggsMaxBufferSize = Math.max(aggsMaxBufferSize, aggsCurrentBufferSize);
aggsCurrentBufferSize = aggsBuffer[0].ramBytesUsed();
logger.trace("aggs partial reduction [{}->{}] max [{}]",
previousBufferSize, aggsCurrentBufferSize, aggsMaxBufferSize);

> Member Author: I think I should add the task id to the output. That'd help a bit with debugging, because setting the task manager to trace logging logs the query. Not that that is a good choice on a busy system, but it could be useful.
>
> I did look into returning this data in other ways, but I couldn't come up with the "right" way. And it is super useful to be able to see the partial reduction memory usage. It'd probably be useful in production, but I think it'll be super useful for me when I'm just hacking on things.

> Member Author: Now that we have SearchProgressListener, I think it'd make more sense to forward this information through that interface and have a default implementation that logs the reduction. That is a little more work, but it makes it so I can test that we make these calls in a sane way, which is nice.

> Member Author: Another option is to move this information into SearchTask so you can see it in the tasks API. That seems useful.

> Member Author: On the other hand, maybe this is a good start and adding it to the SearchTask would be a good change for a follow-up.

> Contributor: +1 to adding it to the SearchTask as a follow-up.
}
if (hasTopDocs) {
TopDocs reducedTopDocs = mergeTopDocs(Arrays.asList(topDocsBuffer),
@@ -705,12 +724,13 @@ private synchronized void consumeInternal(QuerySearchResult querySearchResult) {
index = 1;
if (hasAggs || hasTopDocs) {
progressListener.notifyPartialReduce(SearchProgressListener.buildSearchShards(processedShards),
topDocsStats.getTotalHits(), hasAggs ? aggsBuffer[0].get() : null, numReducePhases);
topDocsStats.getTotalHits(), reducedAggs, numReducePhases);
> Member Author: I talked to @jimczi and @javanna and this line is a release blocker. We're pretty OK merging it, but not releasing it, because async_search keeps a hard reference to the aggs passed to it. Async search has all kinds of trouble with aggs because it doesn't perform the final reduction until sync search would, but it does return aggs without the final reduction applied if you get the "progress" of the search. These aggs are going to be "funny": they'll be missing pipeline aggs, for instance, and scripted_metric will be borked in some way, as will a lot of other things. But you'll mostly get something.

> Member Author: Scratch that - async search does perform the final reduction when you fetch the result. You could still get weird results because things are missing, but they'll be a lot less weird than I was thinking.
}
}
final int i = index++;
if (hasAggs) {
aggsBuffer[i] = querySearchResult.consumeAggs();
aggsBuffer[i] = querySearchResult.consumeAggs().asSerialized(InternalAggregations::new, namedWriteableRegistry);
aggsCurrentBufferSize += aggsBuffer[i].ramBytesUsed();
}
if (hasTopDocs) {
final TopDocsAndMaxScore topDocs = querySearchResult.consumeTopDocs(); // can't be null
@@ -723,7 +743,7 @@ private synchronized void consumeInternal(QuerySearchResult querySearchResult) {
}

private synchronized List<Supplier<InternalAggregations>> getRemainingAggs() {
return hasAggs ? Arrays.asList(aggsBuffer).subList(0, index) : null;
return hasAggs ? Arrays.asList((Supplier<InternalAggregations>[]) aggsBuffer).subList(0, index) : null;
}

private synchronized List<TopDocs> getRemainingTopDocs() {
@@ -732,6 +752,8 @@ private synchronized List<TopDocs> getRemainingTopDocs() {

@Override
public ReducedQueryPhase reduce() {
aggsMaxBufferSize = Math.max(aggsMaxBufferSize, aggsCurrentBufferSize);
logger.trace("aggs final reduction [{}] max [{}]", aggsCurrentBufferSize, aggsMaxBufferSize);
ReducedQueryPhase reducePhase = controller.reducedQueryPhase(results.asList(), getRemainingAggs(), getRemainingTopDocs(),
topDocsStats, numReducePhases, false, aggReduceContextBuilder, performFinalReduce);
progressListener.notifyFinalReduce(SearchProgressListener.buildSearchShards(results.asList()),
@@ -766,8 +788,8 @@ ArraySearchPhaseResults<SearchPhaseResult> newSearchPhaseResults(SearchProgressL
if (request.getBatchedReduceSize() < numShards) {
int topNSize = getTopDocsSize(request);
// only use this if there are aggs and if there are more shards than we should reduce at once
return new QueryPhaseResultConsumer(listener, this, numShards, request.getBatchedReduceSize(), hasTopDocs, hasAggs,
trackTotalHitsUpTo, topNSize, aggReduceContextBuilder, request.isFinalReduce());
return new QueryPhaseResultConsumer(namedWriteableRegistry, listener, this, numShards, request.getBatchedReduceSize(),
hasTopDocs, hasAggs, trackTotalHitsUpTo, topNSize, aggReduceContextBuilder, request.isFinalReduce());
}
}
return new ArraySearchPhaseResults<SearchPhaseResult>(numShards) {
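The buffering scheme above is the core of the change: shard-level aggregation results are held in serialized form, and when the buffer fills, everything is collapsed into slot 0 by a partial reduction while the serialized byte counts are tracked for the trace logging discussed in the review thread. Here is a minimal generic sketch of the pattern; the names and the generic shape are illustrative assumptions, not the real implementation, which works on DelayableWriteable.Serialized<InternalAggregations>:

```java
import java.util.function.BinaryOperator;
import java.util.function.ToLongFunction;

final class PartialReduceBuffer<T> {
    private final Object[] buffer;
    private final BinaryOperator<T> partialReduce; // e.g. InternalAggregations::topLevelReduce
    private final ToLongFunction<T> sizeOf;        // e.g. Accountable::ramBytesUsed
    private int index;
    private long currentBytes;
    private long maxBytes;

    PartialReduceBuffer(int bufferSize, BinaryOperator<T> partialReduce, ToLongFunction<T> sizeOf) {
        this.buffer = new Object[bufferSize];
        this.partialReduce = partialReduce;
        this.sizeOf = sizeOf;
    }

    @SuppressWarnings("unchecked")
    synchronized void consume(T shardResult) {
        if (index == buffer.length) {
            // Buffer exhausted: collapse all buffered results into slot 0.
            T reduced = (T) buffer[0];
            for (int i = 1; i < buffer.length; i++) {
                reduced = partialReduce.apply(reduced, (T) buffer[i]);
                buffer[i] = null; // let consumed entries be GCed
            }
            buffer[0] = reduced;
            index = 1;
            maxBytes = Math.max(maxBytes, currentBytes);
            currentBytes = sizeOf.applyAsLong(reduced);
        }
        buffer[index++] = shardResult;
        currentBytes += sizeOf.applyAsLong(shardResult);
    }

    synchronized long maxBufferedBytes() {
        return Math.max(maxBytes, currentBytes);
    }
}
```

Each time the buffer fills, one partial reduction runs and frees all but one slot, so the final reduction never has to merge more than bufferSize entries at once.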
server/src/main/java/org/elasticsearch/common/io/stream/DelayableWriteable.java
@@ -19,12 +19,14 @@

package org.elasticsearch.common.io.stream;

import java.io.IOException;
import java.util.function.Supplier;

import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.Version;
import org.elasticsearch.common.bytes.BytesReference;

import java.io.IOException;
import java.util.function.Supplier;

/**
A holder for {@link Writeable}s that can delay reading the underlying
* {@linkplain Writeable} when it is read from a remote node.
@@ -43,12 +45,22 @@ public static <T extends Writeable> DelayableWriteable<T> referencing(T referenc
* when {@link Supplier#get()} is called.
*/
public static <T extends Writeable> DelayableWriteable<T> delayed(Writeable.Reader<T> reader, StreamInput in) throws IOException {
return new Delayed<>(reader, in);
return new Serialized<>(reader, in.getVersion(), in.namedWriteableRegistry(), in.readBytesReference());
}

private DelayableWriteable() {}

public abstract boolean isDelayed();
/**
* Returns a {@linkplain DelayableWriteable} that stores its contents
* in serialized form.
*/
public abstract Serialized<T> asSerialized(Writeable.Reader<T> reader, NamedWriteableRegistry registry);

/**
* {@code true} if the {@linkplain Writeable} is being stored in
* serialized form, {@code false} otherwise.
*/
abstract boolean isSerialized();

private static class Referencing<T extends Writeable> extends DelayableWriteable<T> {
private T reference;
@@ -59,11 +71,7 @@ private static class Referencing<T extends Writeable> extends DelayableWriteable

@Override
public void writeTo(StreamOutput out) throws IOException {
try (BytesStreamOutput buffer = new BytesStreamOutput()) {
buffer.setVersion(out.getVersion());
reference.writeTo(buffer);
out.writeBytesReference(buffer.bytes());
}
out.writeBytesReference(writeToBuffer(out.getVersion()).bytes());
}

@Override
@@ -72,27 +80,48 @@ public T get() {
}

@Override
public boolean isDelayed() {
public Serialized<T> asSerialized(Reader<T> reader, NamedWriteableRegistry registry) {
try {
return new Serialized<T>(reader, Version.CURRENT, registry, writeToBuffer(Version.CURRENT).bytes());
} catch (IOException e) {
throw new RuntimeException("unexpected error expanding aggregations", e);
}
}

@Override
boolean isSerialized() {
return false;
}

private BytesStreamOutput writeToBuffer(Version version) throws IOException {
try (BytesStreamOutput buffer = new BytesStreamOutput()) {
buffer.setVersion(version);
reference.writeTo(buffer);
return buffer;
}
}
}

private static class Delayed<T extends Writeable> extends DelayableWriteable<T> {
/**
* A {@link Writeable} stored in serialized form.
*/
public static class Serialized<T extends Writeable> extends DelayableWriteable<T> implements Accountable {
private final Writeable.Reader<T> reader;
private final Version remoteVersion;
private final BytesReference serialized;
private final Version serializedAtVersion;
private final NamedWriteableRegistry registry;
private final BytesReference serialized;

Delayed(Writeable.Reader<T> reader, StreamInput in) throws IOException {
Serialized(Writeable.Reader<T> reader, Version serializedAtVersion,
NamedWriteableRegistry registry, BytesReference serialized) throws IOException {
this.reader = reader;
remoteVersion = in.getVersion();
serialized = in.readBytesReference();
registry = in.namedWriteableRegistry();
this.serializedAtVersion = serializedAtVersion;
this.registry = registry;
this.serialized = serialized;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
if (out.getVersion() == remoteVersion) {
if (out.getVersion() == serializedAtVersion) {
/*
* If the version *does* line up we can just copy the bytes
* which is good because this is how shard request caching
@@ -116,7 +145,7 @@ public T get() {
try {
try (StreamInput in = registry == null ?
serialized.streamInput() : new NamedWriteableAwareStreamInput(serialized.streamInput(), registry)) {
in.setVersion(remoteVersion);
in.setVersion(serializedAtVersion);
return reader.read(in);
}
} catch (IOException e) {
@@ -125,8 +154,18 @@
}

@Override
public boolean isDelayed() {
public Serialized<T> asSerialized(Reader<T> reader, NamedWriteableRegistry registry) {
return this; // We're already serialized
}

@Override
boolean isSerialized() {
return true;
}

@Override
public long ramBytesUsed() {
return serialized.ramBytesUsed() + RamUsageEstimator.NUM_BYTES_OBJECT_REF * 3 + RamUsageEstimator.NUM_BYTES_OBJECT_HEADER;
}
}
}
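For context, here is a hedged round-trip sketch of the new Serialized surface using a toy Writeable. The Label class, the main wrapper, and the null registry are assumptions of this example only; the search code passes InternalAggregations::new and the node's NamedWriteableRegistry:

```java
import java.io.IOException;

import org.elasticsearch.common.io.stream.DelayableWriteable;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;

public class DelayableWriteableExample {
    /** Toy Writeable; purely illustrative, not part of the PR. */
    static final class Label implements Writeable {
        final String value;

        Label(String value) {
            this.value = value;
        }

        Label(StreamInput in) throws IOException {
            this.value = in.readString();
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            out.writeString(value);
        }
    }

    public static void main(String[] args) {
        // Start with a live object reference; nothing is serialized yet.
        DelayableWriteable<Label> inMemory = DelayableWriteable.referencing(new Label("example"));

        // Switch to serialized form. A null registry is fine here because
        // Label contains no NamedWriteables; aggregations need the real one.
        DelayableWriteable.Serialized<Label> asBytes = inMemory.asSerialized(Label::new, null);

        // Serialized implements Accountable, so the buffered bytes can be
        // measured -- this is what feeds the trace logging above.
        long bytes = asBytes.ramBytesUsed();

        // The value is deserialized on demand.
        Label roundTripped = asBytes.get();
        System.out.println(roundTripped.value + " occupies ~" + bytes + " serialized bytes");
    }
}
```

Serializing eagerly trades a little CPU for a flat, accountable memory footprint, which is what lets QueryPhaseResultConsumer report its buffer sizes at all.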
server/src/main/java/org/elasticsearch/node/Node.java (3 changes: 2 additions & 1 deletion)
@@ -570,7 +570,8 @@ protected Node(final Environment initialEnvironment,
b.bind(MetadataCreateIndexService.class).toInstance(metadataCreateIndexService);
b.bind(SearchService.class).toInstance(searchService);
b.bind(SearchTransportService.class).toInstance(searchTransportService);
b.bind(SearchPhaseController.class).toInstance(new SearchPhaseController(searchService::aggReduceContextBuilder));
b.bind(SearchPhaseController.class).toInstance(new SearchPhaseController(
namedWriteableRegistry, searchService::aggReduceContextBuilder));
b.bind(Transport.class).toInstance(transport);
b.bind(TransportService.class).toInstance(transportService);
b.bind(NetworkService.class).toInstance(networkService);
@@ -213,6 +213,6 @@ public void run() throws IOException {
}

private SearchPhaseController searchPhaseController() {
return new SearchPhaseController(request -> InternalAggregationTestCase.emptyReduceContextBuilder());
return new SearchPhaseController(writableRegistry(), request -> InternalAggregationTestCase.emptyReduceContextBuilder());
}
}