-
Notifications
You must be signed in to change notification settings - Fork 3.6k
Verify data integrity in exchanges #3438
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,11 +17,14 @@ | |
| import io.airlift.slice.Slice; | ||
| import io.airlift.slice.SliceInput; | ||
| import io.airlift.slice.SliceOutput; | ||
| import io.airlift.slice.Slices; | ||
| import io.airlift.slice.XxHash64; | ||
| import io.prestosql.spi.Page; | ||
| import io.prestosql.spi.block.Block; | ||
| import io.prestosql.spi.block.BlockEncodingSerde; | ||
|
|
||
| import java.util.Iterator; | ||
| import java.util.List; | ||
|
|
||
| import static io.prestosql.block.BlockSerdeUtil.readBlock; | ||
| import static io.prestosql.block.BlockSerdeUtil.writeBlock; | ||
|
|
@@ -32,6 +35,13 @@ public final class PagesSerdeUtil | |
| { | ||
| private PagesSerdeUtil() {} | ||
|
|
||
| /** | ||
| * Special checksum value used to verify configuration consistency across nodes (all nodes need to have data integrity configured the same way). | ||
| * | ||
| * @implNote It's not just 0, so that hypothetical zero-ed out data is not treated as valid payload with no checksum. | ||
| */ | ||
| public static final long NO_CHECKSUM = 0x0123456789abcdefL; | ||
|
|
||
| static void writeRawPage(Page page, SliceOutput output, BlockEncodingSerde serde) | ||
| { | ||
| output.writeInt(page.getChannelCount()); | ||
|
|
@@ -53,13 +63,24 @@ static Page readRawPage(int positionCount, SliceInput input, BlockEncodingSerde | |
|
|
||
| public static void writeSerializedPage(SliceOutput output, SerializedPage page) | ||
| { | ||
| // Every new field being written here must be added in updateChecksum() too. | ||
| output.writeInt(page.getPositionCount()); | ||
| output.writeByte(page.getPageCodecMarkers()); | ||
| output.writeInt(page.getUncompressedSizeInBytes()); | ||
| output.writeInt(page.getSizeInBytes()); | ||
| output.writeBytes(page.getSlice()); | ||
| } | ||
|
|
||
| private static void updateChecksum(XxHash64 hash, SerializedPage page) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: aren't we typically ordering methods that usage is first and declaration follows?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| { | ||
| hash.update(Slices.wrappedIntArray( | ||
| page.getPositionCount(), | ||
| page.getPageCodecMarkers(), | ||
| page.getUncompressedSizeInBytes(), | ||
| page.getSizeInBytes())); | ||
| hash.update(page.getSlice()); | ||
| } | ||
|
|
||
| private static SerializedPage readSerializedPage(SliceInput sliceInput) | ||
| { | ||
| int positionCount = sliceInput.readInt(); | ||
|
|
@@ -82,6 +103,20 @@ public static long writeSerializedPages(SliceOutput sliceOutput, Iterable<Serial | |
| return size; | ||
| } | ||
|
|
||
| public static long calculateChecksum(List<SerializedPage> pages) | ||
| { | ||
| XxHash64 hash = new XxHash64(); | ||
| for (SerializedPage page : pages) { | ||
| updateChecksum(hash, page); | ||
| } | ||
| long checksum = hash.hash(); | ||
| // Since NO_CHECKSUM is assigned a special meaning, it is not a valid checksum. | ||
| if (checksum == NO_CHECKSUM) { | ||
| return checksum + 1; | ||
| } | ||
| return checksum; | ||
| } | ||
|
|
||
| public static long writePages(PagesSerde serde, SliceOutput sliceOutput, Page... pages) | ||
| { | ||
| return writePages(serde, sliceOutput, asList(pages).iterator()); | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -34,6 +34,7 @@ | |
| import io.prestosql.execution.buffer.SerializedPage; | ||
| import io.prestosql.server.remotetask.Backoff; | ||
| import io.prestosql.spi.PrestoException; | ||
| import io.prestosql.sql.analyzer.FeaturesConfig.DataIntegrityVerification; | ||
| import org.joda.time.DateTime; | ||
|
|
||
| import javax.annotation.Nullable; | ||
|
|
@@ -70,10 +71,14 @@ | |
| import static io.prestosql.client.PrestoHeaders.PRESTO_PAGE_NEXT_TOKEN; | ||
| import static io.prestosql.client.PrestoHeaders.PRESTO_PAGE_TOKEN; | ||
| import static io.prestosql.client.PrestoHeaders.PRESTO_TASK_INSTANCE_ID; | ||
| import static io.prestosql.execution.buffer.PagesSerdeUtil.NO_CHECKSUM; | ||
| import static io.prestosql.execution.buffer.PagesSerdeUtil.calculateChecksum; | ||
| import static io.prestosql.execution.buffer.PagesSerdeUtil.readSerializedPages; | ||
| import static io.prestosql.operator.HttpPageBufferClient.PagesResponse.createEmptyPagesResponse; | ||
| import static io.prestosql.operator.HttpPageBufferClient.PagesResponse.createPagesResponse; | ||
| import static io.prestosql.server.PagesResponseWriter.SERIALIZED_PAGES_MAGIC; | ||
| import static io.prestosql.spi.HostAddress.fromUri; | ||
| import static io.prestosql.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; | ||
| import static io.prestosql.spi.StandardErrorCode.REMOTE_BUFFER_CLOSE_FAILED; | ||
| import static io.prestosql.spi.StandardErrorCode.REMOTE_TASK_MISMATCH; | ||
| import static io.prestosql.util.Failures.REMOTE_TASK_MISMATCH_ERROR; | ||
|
|
@@ -109,7 +114,9 @@ public interface ClientCallback | |
| void clientFailed(HttpPageBufferClient client, Throwable cause); | ||
| } | ||
|
|
||
| private final String selfAddress; | ||
| private final HttpClient httpClient; | ||
| private final DataIntegrityVerification dataIntegrityVerification; | ||
| private final DataSize maxResponseSize; | ||
| private final boolean acknowledgePages; | ||
| private final URI location; | ||
|
|
@@ -145,7 +152,9 @@ public interface ClientCallback | |
| private final Executor pageBufferClientCallbackExecutor; | ||
|
|
||
| public HttpPageBufferClient( | ||
| String selfAddress, | ||
| HttpClient httpClient, | ||
| DataIntegrityVerification dataIntegrityVerification, | ||
| DataSize maxResponseSize, | ||
| Duration maxErrorDuration, | ||
| boolean acknowledgePages, | ||
|
|
@@ -154,11 +163,24 @@ public HttpPageBufferClient( | |
| ScheduledExecutorService scheduler, | ||
| Executor pageBufferClientCallbackExecutor) | ||
| { | ||
| this(httpClient, maxResponseSize, maxErrorDuration, acknowledgePages, location, clientCallback, scheduler, Ticker.systemTicker(), pageBufferClientCallbackExecutor); | ||
| this( | ||
| selfAddress, | ||
| httpClient, | ||
| dataIntegrityVerification, | ||
| maxResponseSize, | ||
| maxErrorDuration, | ||
| acknowledgePages, | ||
| location, | ||
| clientCallback, | ||
| scheduler, | ||
| Ticker.systemTicker(), | ||
| pageBufferClientCallbackExecutor); | ||
| } | ||
|
|
||
| public HttpPageBufferClient( | ||
| String selfAddress, | ||
| HttpClient httpClient, | ||
| DataIntegrityVerification dataIntegrityVerification, | ||
| DataSize maxResponseSize, | ||
| Duration maxErrorDuration, | ||
| boolean acknowledgePages, | ||
|
|
@@ -168,7 +190,9 @@ public HttpPageBufferClient( | |
| Ticker ticker, | ||
| Executor pageBufferClientCallbackExecutor) | ||
| { | ||
| this.selfAddress = requireNonNull(selfAddress, "selfAddress is null"); | ||
| this.httpClient = requireNonNull(httpClient, "httpClient is null"); | ||
| this.dataIntegrityVerification = requireNonNull(dataIntegrityVerification, "dataIntegrityVerification is null"); | ||
| this.maxResponseSize = requireNonNull(maxResponseSize, "maxResponseSize is null"); | ||
| this.acknowledgePages = acknowledgePages; | ||
| this.location = requireNonNull(location, "location is null"); | ||
|
|
@@ -301,7 +325,7 @@ private synchronized void sendGetResults() | |
| prepareGet() | ||
| .setHeader(PRESTO_MAX_SIZE, maxResponseSize.toString()) | ||
| .setUri(uri).build(), | ||
| new PageResponseHandler()); | ||
| new PageResponseHandler(dataIntegrityVerification != DataIntegrityVerification.NONE)); | ||
|
|
||
| future = resultFuture; | ||
| Futures.addCallback(resultFuture, new FutureCallback<PagesResponse>() | ||
|
|
@@ -403,6 +427,22 @@ public void onFailure(Throwable t) | |
| log.debug("Request to %s failed %s", uri, t); | ||
| checkNotHoldsLock(this); | ||
|
|
||
| if (t instanceof ChecksumVerificationException) { | ||
| switch (dataIntegrityVerification) { | ||
| case NONE: | ||
| // In case of NONE, failure is possible in case of inconsistent cluster configuration, so we should not retry. | ||
| case ABORT: | ||
| // PrestoException will not be retried | ||
| t = new PrestoException(GENERIC_INTERNAL_ERROR, format("Checksum verification failure on %s when reading from %s: %s", selfAddress, uri, t.getMessage()), t); | ||
| break; | ||
| case RETRY: | ||
| log.warn("Checksum verification failure on %s when reading from %s, may be retried: %s", selfAddress, uri, t.getMessage()); | ||
| break; | ||
| default: | ||
| throw new AssertionError("Unsupported option: " + dataIntegrityVerification); | ||
| } | ||
| } | ||
|
|
||
| t = rewriteException(t); | ||
| if (!(t instanceof PrestoException) && backoff.failure()) { | ||
| String message = format("%s (%s - %s failures, failure duration %s, total failed request time %s)", | ||
|
|
@@ -542,6 +582,13 @@ private static Throwable rewriteException(Throwable t) | |
| public static class PageResponseHandler | ||
| implements ResponseHandler<PagesResponse, RuntimeException> | ||
| { | ||
| private final boolean dataIntegrityVerificationEnabled; | ||
|
|
||
| private PageResponseHandler(boolean dataIntegrityVerificationEnabled) | ||
| { | ||
| this.dataIntegrityVerificationEnabled = dataIntegrityVerificationEnabled; | ||
| } | ||
|
|
||
| @Override | ||
| public PagesResponse handleException(Request request, Exception exception) | ||
| { | ||
|
|
@@ -593,7 +640,15 @@ public PagesResponse handle(Request request, Response response) | |
| boolean complete = getComplete(response); | ||
|
|
||
| try (SliceInput input = new InputStreamSliceInput(response.getInputStream())) { | ||
| int magic = input.readInt(); | ||
| if (magic != SERIALIZED_PAGES_MAGIC) { | ||
| throw new IllegalStateException(format("Invalid stream header, expected 0x%08x, but was 0x%08x", SERIALIZED_PAGES_MAGIC, magic)); | ||
| } | ||
| long checksum = input.readLong(); | ||
| int pagesCount = input.readInt(); | ||
| List<SerializedPage> pages = ImmutableList.copyOf(readSerializedPages(input)); | ||
| verifyChecksum(checksum, pages); | ||
| checkState(pages.size() == pagesCount, "Wrong number of pages, expected %s, but read %s", pagesCount, pages.size()); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How does this surface to the user? It'd be ideal if it could be mapped to a proper error code.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In practise, this is reachable only when checksum verification is OFF. This is then treated as if IO error occurred and retried. |
||
| return createPagesResponse(taskInstanceId, token, nextToken, pages, complete); | ||
| } | ||
| catch (IOException e) { | ||
|
|
@@ -605,6 +660,21 @@ public PagesResponse handle(Request request, Response response) | |
| } | ||
| } | ||
|
|
||
| private void verifyChecksum(long readChecksum, List<SerializedPage> pages) | ||
| { | ||
| if (dataIntegrityVerificationEnabled) { | ||
| long calculatedChecksum = calculateChecksum(pages); | ||
| if (readChecksum != calculatedChecksum) { | ||
| throw new ChecksumVerificationException(format("Data corruption, read checksum: 0x%08x, calculated checksum: 0x%08x", readChecksum, calculatedChecksum)); | ||
| } | ||
| } | ||
| else { | ||
| if (readChecksum != NO_CHECKSUM) { | ||
| throw new ChecksumVerificationException(format("Expected checksum to be NO_CHECKSUM (0x%08x) but is 0x%08x", NO_CHECKSUM, readChecksum)); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| private static String getTaskInstanceId(Response response) | ||
| { | ||
| String taskInstanceId = response.getHeader(PRESTO_TASK_INSTANCE_ID); | ||
|
|
@@ -715,4 +785,13 @@ public String toString() | |
| .toString(); | ||
| } | ||
| } | ||
|
|
||
| private static class ChecksumVerificationException | ||
| extends RuntimeException | ||
| { | ||
| public ChecksumVerificationException(String message) | ||
| { | ||
| super(requireNonNull(message, "message is null")); | ||
| } | ||
| } | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.