Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@
import io.airlift.slice.Slice;
import io.airlift.slice.SliceInput;
import io.airlift.slice.SliceOutput;
import io.airlift.slice.Slices;
import io.airlift.slice.XxHash64;
import io.prestosql.spi.Page;
import io.prestosql.spi.block.Block;
import io.prestosql.spi.block.BlockEncodingSerde;

import java.util.Iterator;
import java.util.List;

import static io.prestosql.block.BlockSerdeUtil.readBlock;
import static io.prestosql.block.BlockSerdeUtil.writeBlock;
Expand All @@ -32,6 +35,13 @@ public final class PagesSerdeUtil
{
private PagesSerdeUtil() {}

/**
* Special checksum value used to verify configuration consistency across nodes (all nodes need to have data integrity configured the same way).
*
* @implNote It's not just 0, so that hypothetical zero-ed out data is not treated as valid payload with no checksum.
*/
public static final long NO_CHECKSUM = 0x0123456789abcdefL;
Comment thread
findepi marked this conversation as resolved.
Outdated

static void writeRawPage(Page page, SliceOutput output, BlockEncodingSerde serde)
{
output.writeInt(page.getChannelCount());
Expand All @@ -53,13 +63,24 @@ static Page readRawPage(int positionCount, SliceInput input, BlockEncodingSerde

public static void writeSerializedPage(SliceOutput output, SerializedPage page)
{
// Every new field being written here must be added in updateChecksum() too.
output.writeInt(page.getPositionCount());
output.writeByte(page.getPageCodecMarkers());
output.writeInt(page.getUncompressedSizeInBytes());
output.writeInt(page.getSizeInBytes());
output.writeBytes(page.getSlice());
}

private static void updateChecksum(XxHash64 hash, SerializedPage page)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: aren't we typically ordering methods that usage is first and declaration follows?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updateChecksum is placed right under writeSerializedPage, because they need to stay in sync.
i added a code comment

{
hash.update(Slices.wrappedIntArray(
page.getPositionCount(),
page.getPageCodecMarkers(),
page.getUncompressedSizeInBytes(),
page.getSizeInBytes()));
hash.update(page.getSlice());
}

private static SerializedPage readSerializedPage(SliceInput sliceInput)
{
int positionCount = sliceInput.readInt();
Expand All @@ -82,6 +103,20 @@ public static long writeSerializedPages(SliceOutput sliceOutput, Iterable<Serial
return size;
}

public static long calculateChecksum(List<SerializedPage> pages)
{
XxHash64 hash = new XxHash64();
for (SerializedPage page : pages) {
updateChecksum(hash, page);
}
long checksum = hash.hash();
// Since NO_CHECKSUM is assigned a special meaning, it is not a valid checksum.
if (checksum == NO_CHECKSUM) {
return checksum + 1;
}
return checksum;
}

public static long writePages(PagesSerde serde, SliceOutput sliceOutput, Page... pages)
{
return writePages(serde, sliceOutput, asList(pages).iterator());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import io.prestosql.memory.context.LocalMemoryContext;
import io.prestosql.operator.HttpPageBufferClient.ClientCallback;
import io.prestosql.operator.WorkProcessor.ProcessState;
import io.prestosql.sql.analyzer.FeaturesConfig.DataIntegrityVerification;

import javax.annotation.Nullable;
import javax.annotation.concurrent.GuardedBy;
Expand Down Expand Up @@ -57,6 +58,8 @@ public class ExchangeClient
{
private static final SerializedPage NO_MORE_PAGES = new SerializedPage(EMPTY_SLICE, PageCodecMarker.MarkerSet.empty(), 0, 0);

private final String selfAddress;
private final DataIntegrityVerification dataIntegrityVerification;
private final long bufferCapacity;
private final DataSize maxResponseSize;
private final int concurrentRequestMultiplier;
Expand Down Expand Up @@ -97,6 +100,8 @@ public class ExchangeClient
// ExchangeClientStatus.mergeWith assumes all clients have the same bufferCapacity.
// Please change that method accordingly when this assumption becomes not true.
public ExchangeClient(
String selfAddress,
DataIntegrityVerification dataIntegrityVerification,
DataSize bufferCapacity,
DataSize maxResponseSize,
int concurrentRequestMultiplier,
Expand All @@ -107,6 +112,8 @@ public ExchangeClient(
LocalMemoryContext systemMemoryContext,
Executor pageBufferClientCallbackExecutor)
{
this.selfAddress = requireNonNull(selfAddress, "selfAddress is null");
this.dataIntegrityVerification = requireNonNull(dataIntegrityVerification, "dataIntegrityVerification is null");
this.bufferCapacity = bufferCapacity.toBytes();
this.maxResponseSize = maxResponseSize;
this.concurrentRequestMultiplier = concurrentRequestMultiplier;
Expand Down Expand Up @@ -156,7 +163,9 @@ public synchronized void addLocation(URI location)
checkState(!noMoreLocations, "No more locations already set");

HttpPageBufferClient client = new HttpPageBufferClient(
selfAddress,
httpClient,
dataIntegrityVerification,
maxResponseSize,
maxErrorDuration,
acknowledgePages,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@

import io.airlift.concurrent.ThreadPoolExecutorMBean;
import io.airlift.http.client.HttpClient;
import io.airlift.node.NodeInfo;
import io.airlift.units.DataSize;
import io.airlift.units.Duration;
import io.prestosql.memory.context.LocalMemoryContext;
import io.prestosql.sql.analyzer.FeaturesConfig;
import io.prestosql.sql.analyzer.FeaturesConfig.DataIntegrityVerification;
import org.weakref.jmx.Managed;
import org.weakref.jmx.Nested;

Expand All @@ -36,6 +39,8 @@
public class ExchangeClientFactory
implements ExchangeClientSupplier
{
private final NodeInfo nodeInfo;
private final DataIntegrityVerification dataIntegrityVerification;
private final DataSize maxBufferedBytes;
private final int concurrentRequestMultiplier;
private final Duration maxErrorDuration;
Expand All @@ -48,11 +53,15 @@ public class ExchangeClientFactory

@Inject
public ExchangeClientFactory(
NodeInfo nodeInfo,
FeaturesConfig featuresConfig,
ExchangeClientConfig config,
@ForExchange HttpClient httpClient,
@ForExchange ScheduledExecutorService scheduler)
{
this(
nodeInfo,
featuresConfig.getExchangeDataIntegrityVerification(),
config.getMaxBufferSize(),
config.getMaxResponseSize(),
config.getConcurrentRequestMultiplier(),
Expand All @@ -64,6 +73,8 @@ public ExchangeClientFactory(
}

public ExchangeClientFactory(
NodeInfo nodeInfo,
DataIntegrityVerification dataIntegrityVerification,
DataSize maxBufferedBytes,
DataSize maxResponseSize,
int concurrentRequestMultiplier,
Expand All @@ -73,6 +84,8 @@ public ExchangeClientFactory(
HttpClient httpClient,
ScheduledExecutorService scheduler)
{
this.nodeInfo = requireNonNull(nodeInfo, "nodeInfo is null");
this.dataIntegrityVerification = requireNonNull(dataIntegrityVerification, "dataIntegrityVerification is null");
this.maxBufferedBytes = requireNonNull(maxBufferedBytes, "maxBufferedBytes is null");
this.concurrentRequestMultiplier = concurrentRequestMultiplier;
this.maxErrorDuration = requireNonNull(maxErrorDuration, "maxErrorDuration is null");
Expand Down Expand Up @@ -112,6 +125,8 @@ public ThreadPoolExecutorMBean getExecutor()
public ExchangeClient get(LocalMemoryContext systemMemoryContext)
{
return new ExchangeClient(
nodeInfo.getExternalAddress(),
dataIntegrityVerification,
maxBufferedBytes,
maxResponseSize,
concurrentRequestMultiplier,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import io.prestosql.execution.buffer.SerializedPage;
import io.prestosql.server.remotetask.Backoff;
import io.prestosql.spi.PrestoException;
import io.prestosql.sql.analyzer.FeaturesConfig.DataIntegrityVerification;
import org.joda.time.DateTime;

import javax.annotation.Nullable;
Expand Down Expand Up @@ -70,10 +71,14 @@
import static io.prestosql.client.PrestoHeaders.PRESTO_PAGE_NEXT_TOKEN;
import static io.prestosql.client.PrestoHeaders.PRESTO_PAGE_TOKEN;
import static io.prestosql.client.PrestoHeaders.PRESTO_TASK_INSTANCE_ID;
import static io.prestosql.execution.buffer.PagesSerdeUtil.NO_CHECKSUM;
import static io.prestosql.execution.buffer.PagesSerdeUtil.calculateChecksum;
import static io.prestosql.execution.buffer.PagesSerdeUtil.readSerializedPages;
import static io.prestosql.operator.HttpPageBufferClient.PagesResponse.createEmptyPagesResponse;
import static io.prestosql.operator.HttpPageBufferClient.PagesResponse.createPagesResponse;
import static io.prestosql.server.PagesResponseWriter.SERIALIZED_PAGES_MAGIC;
import static io.prestosql.spi.HostAddress.fromUri;
import static io.prestosql.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
import static io.prestosql.spi.StandardErrorCode.REMOTE_BUFFER_CLOSE_FAILED;
import static io.prestosql.spi.StandardErrorCode.REMOTE_TASK_MISMATCH;
import static io.prestosql.util.Failures.REMOTE_TASK_MISMATCH_ERROR;
Expand Down Expand Up @@ -109,7 +114,9 @@ public interface ClientCallback
void clientFailed(HttpPageBufferClient client, Throwable cause);
}

private final String selfAddress;
private final HttpClient httpClient;
private final DataIntegrityVerification dataIntegrityVerification;
private final DataSize maxResponseSize;
private final boolean acknowledgePages;
private final URI location;
Expand Down Expand Up @@ -145,7 +152,9 @@ public interface ClientCallback
private final Executor pageBufferClientCallbackExecutor;

public HttpPageBufferClient(
String selfAddress,
HttpClient httpClient,
DataIntegrityVerification dataIntegrityVerification,
DataSize maxResponseSize,
Duration maxErrorDuration,
boolean acknowledgePages,
Expand All @@ -154,11 +163,24 @@ public HttpPageBufferClient(
ScheduledExecutorService scheduler,
Executor pageBufferClientCallbackExecutor)
{
this(httpClient, maxResponseSize, maxErrorDuration, acknowledgePages, location, clientCallback, scheduler, Ticker.systemTicker(), pageBufferClientCallbackExecutor);
this(
selfAddress,
httpClient,
dataIntegrityVerification,
maxResponseSize,
maxErrorDuration,
acknowledgePages,
location,
clientCallback,
scheduler,
Ticker.systemTicker(),
pageBufferClientCallbackExecutor);
}

public HttpPageBufferClient(
String selfAddress,
HttpClient httpClient,
DataIntegrityVerification dataIntegrityVerification,
DataSize maxResponseSize,
Duration maxErrorDuration,
boolean acknowledgePages,
Expand All @@ -168,7 +190,9 @@ public HttpPageBufferClient(
Ticker ticker,
Executor pageBufferClientCallbackExecutor)
{
this.selfAddress = requireNonNull(selfAddress, "selfAddress is null");
this.httpClient = requireNonNull(httpClient, "httpClient is null");
this.dataIntegrityVerification = requireNonNull(dataIntegrityVerification, "dataIntegrityVerification is null");
this.maxResponseSize = requireNonNull(maxResponseSize, "maxResponseSize is null");
this.acknowledgePages = acknowledgePages;
this.location = requireNonNull(location, "location is null");
Expand Down Expand Up @@ -301,7 +325,7 @@ private synchronized void sendGetResults()
prepareGet()
.setHeader(PRESTO_MAX_SIZE, maxResponseSize.toString())
.setUri(uri).build(),
new PageResponseHandler());
new PageResponseHandler(dataIntegrityVerification != DataIntegrityVerification.NONE));

future = resultFuture;
Futures.addCallback(resultFuture, new FutureCallback<PagesResponse>()
Expand Down Expand Up @@ -403,6 +427,22 @@ public void onFailure(Throwable t)
log.debug("Request to %s failed %s", uri, t);
checkNotHoldsLock(this);

if (t instanceof ChecksumVerificationException) {
switch (dataIntegrityVerification) {
case NONE:
// In case of NONE, failure is possible in case of inconsistent cluster configuration, so we should not retry.
case ABORT:
// PrestoException will not be retried
t = new PrestoException(GENERIC_INTERNAL_ERROR, format("Checksum verification failure on %s when reading from %s: %s", selfAddress, uri, t.getMessage()), t);
break;
case RETRY:
log.warn("Checksum verification failure on %s when reading from %s, may be retried: %s", selfAddress, uri, t.getMessage());
break;
default:
throw new AssertionError("Unsupported option: " + dataIntegrityVerification);
}
}

t = rewriteException(t);
if (!(t instanceof PrestoException) && backoff.failure()) {
String message = format("%s (%s - %s failures, failure duration %s, total failed request time %s)",
Expand Down Expand Up @@ -542,6 +582,13 @@ private static Throwable rewriteException(Throwable t)
public static class PageResponseHandler
implements ResponseHandler<PagesResponse, RuntimeException>
{
private final boolean dataIntegrityVerificationEnabled;

private PageResponseHandler(boolean dataIntegrityVerificationEnabled)
{
this.dataIntegrityVerificationEnabled = dataIntegrityVerificationEnabled;
}

@Override
public PagesResponse handleException(Request request, Exception exception)
{
Expand Down Expand Up @@ -593,7 +640,15 @@ public PagesResponse handle(Request request, Response response)
boolean complete = getComplete(response);

try (SliceInput input = new InputStreamSliceInput(response.getInputStream())) {
int magic = input.readInt();
if (magic != SERIALIZED_PAGES_MAGIC) {
throw new IllegalStateException(format("Invalid stream header, expected 0x%08x, but was 0x%08x", SERIALIZED_PAGES_MAGIC, magic));
}
long checksum = input.readLong();
int pagesCount = input.readInt();
List<SerializedPage> pages = ImmutableList.copyOf(readSerializedPages(input));
verifyChecksum(checksum, pages);
checkState(pages.size() == pagesCount, "Wrong number of pages, expected %s, but read %s", pagesCount, pages.size());
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does this surface to the user? It'd be ideal if it could be mapped to a proper error code.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In practise, this is reachable only when checksum verification is OFF. This is then treated as if IO error occurred and retried.

return createPagesResponse(taskInstanceId, token, nextToken, pages, complete);
}
catch (IOException e) {
Expand All @@ -605,6 +660,21 @@ public PagesResponse handle(Request request, Response response)
}
}

private void verifyChecksum(long readChecksum, List<SerializedPage> pages)
{
if (dataIntegrityVerificationEnabled) {
long calculatedChecksum = calculateChecksum(pages);
if (readChecksum != calculatedChecksum) {
throw new ChecksumVerificationException(format("Data corruption, read checksum: 0x%08x, calculated checksum: 0x%08x", readChecksum, calculatedChecksum));
}
}
else {
if (readChecksum != NO_CHECKSUM) {
throw new ChecksumVerificationException(format("Expected checksum to be NO_CHECKSUM (0x%08x) but is 0x%08x", NO_CHECKSUM, readChecksum));
}
}
}

private static String getTaskInstanceId(Response response)
{
String taskInstanceId = response.getHeader(PRESTO_TASK_INSTANCE_ID);
Expand Down Expand Up @@ -715,4 +785,13 @@ public String toString()
.toString();
}
}

private static class ChecksumVerificationException
extends RuntimeException
{
public ChecksumVerificationException(String message)
{
super(requireNonNull(message, "message is null"));
}
}
}
Loading