diff --git a/presto-main/src/main/java/io/prestosql/execution/buffer/PagesSerdeUtil.java b/presto-main/src/main/java/io/prestosql/execution/buffer/PagesSerdeUtil.java index 841d4ea4037d..d6c8735d5232 100644 --- a/presto-main/src/main/java/io/prestosql/execution/buffer/PagesSerdeUtil.java +++ b/presto-main/src/main/java/io/prestosql/execution/buffer/PagesSerdeUtil.java @@ -17,11 +17,14 @@ import io.airlift.slice.Slice; import io.airlift.slice.SliceInput; import io.airlift.slice.SliceOutput; +import io.airlift.slice.Slices; +import io.airlift.slice.XxHash64; import io.prestosql.spi.Page; import io.prestosql.spi.block.Block; import io.prestosql.spi.block.BlockEncodingSerde; import java.util.Iterator; +import java.util.List; import static io.prestosql.block.BlockSerdeUtil.readBlock; import static io.prestosql.block.BlockSerdeUtil.writeBlock; @@ -32,6 +35,13 @@ public final class PagesSerdeUtil { private PagesSerdeUtil() {} + /** + * Special checksum value used to verify configuration consistency across nodes (all nodes need to have data integrity configured the same way). + * + * @implNote It's not just 0, so that hypothetical zero-ed out data is not treated as valid payload with no checksum. + */ + public static final long NO_CHECKSUM = 0x0123456789abcdefL; + static void writeRawPage(Page page, SliceOutput output, BlockEncodingSerde serde) { output.writeInt(page.getChannelCount()); @@ -53,6 +63,7 @@ static Page readRawPage(int positionCount, SliceInput input, BlockEncodingSerde public static void writeSerializedPage(SliceOutput output, SerializedPage page) { + // Every new field being written here must be added in updateChecksum() too. output.writeInt(page.getPositionCount()); output.writeByte(page.getPageCodecMarkers()); output.writeInt(page.getUncompressedSizeInBytes()); @@ -60,6 +71,16 @@ public static void writeSerializedPage(SliceOutput output, SerializedPage page) output.writeBytes(page.getSlice()); } + private static void updateChecksum(XxHash64 hash, SerializedPage page) + { + hash.update(Slices.wrappedIntArray( + page.getPositionCount(), + page.getPageCodecMarkers(), + page.getUncompressedSizeInBytes(), + page.getSizeInBytes())); + hash.update(page.getSlice()); + } + private static SerializedPage readSerializedPage(SliceInput sliceInput) { int positionCount = sliceInput.readInt(); @@ -82,6 +103,20 @@ public static long writeSerializedPages(SliceOutput sliceOutput, Iterable pages) + { + XxHash64 hash = new XxHash64(); + for (SerializedPage page : pages) { + updateChecksum(hash, page); + } + long checksum = hash.hash(); + // Since NO_CHECKSUM is assigned a special meaning, it is not a valid checksum. + if (checksum == NO_CHECKSUM) { + return checksum + 1; + } + return checksum; + } + public static long writePages(PagesSerde serde, SliceOutput sliceOutput, Page... pages) { return writePages(serde, sliceOutput, asList(pages).iterator()); diff --git a/presto-main/src/main/java/io/prestosql/operator/ExchangeClient.java b/presto-main/src/main/java/io/prestosql/operator/ExchangeClient.java index 14ea3d8d6e11..9ca9887781d2 100644 --- a/presto-main/src/main/java/io/prestosql/operator/ExchangeClient.java +++ b/presto-main/src/main/java/io/prestosql/operator/ExchangeClient.java @@ -25,6 +25,7 @@ import io.prestosql.memory.context.LocalMemoryContext; import io.prestosql.operator.HttpPageBufferClient.ClientCallback; import io.prestosql.operator.WorkProcessor.ProcessState; +import io.prestosql.sql.analyzer.FeaturesConfig.DataIntegrityVerification; import javax.annotation.Nullable; import javax.annotation.concurrent.GuardedBy; @@ -57,6 +58,8 @@ public class ExchangeClient { private static final SerializedPage NO_MORE_PAGES = new SerializedPage(EMPTY_SLICE, PageCodecMarker.MarkerSet.empty(), 0, 0); + private final String selfAddress; + private final DataIntegrityVerification dataIntegrityVerification; private final long bufferCapacity; private final DataSize maxResponseSize; private final int concurrentRequestMultiplier; @@ -97,6 +100,8 @@ public class ExchangeClient // ExchangeClientStatus.mergeWith assumes all clients have the same bufferCapacity. // Please change that method accordingly when this assumption becomes not true. public ExchangeClient( + String selfAddress, + DataIntegrityVerification dataIntegrityVerification, DataSize bufferCapacity, DataSize maxResponseSize, int concurrentRequestMultiplier, @@ -107,6 +112,8 @@ public ExchangeClient( LocalMemoryContext systemMemoryContext, Executor pageBufferClientCallbackExecutor) { + this.selfAddress = requireNonNull(selfAddress, "selfAddress is null"); + this.dataIntegrityVerification = requireNonNull(dataIntegrityVerification, "dataIntegrityVerification is null"); this.bufferCapacity = bufferCapacity.toBytes(); this.maxResponseSize = maxResponseSize; this.concurrentRequestMultiplier = concurrentRequestMultiplier; @@ -156,7 +163,9 @@ public synchronized void addLocation(URI location) checkState(!noMoreLocations, "No more locations already set"); HttpPageBufferClient client = new HttpPageBufferClient( + selfAddress, httpClient, + dataIntegrityVerification, maxResponseSize, maxErrorDuration, acknowledgePages, diff --git a/presto-main/src/main/java/io/prestosql/operator/ExchangeClientFactory.java b/presto-main/src/main/java/io/prestosql/operator/ExchangeClientFactory.java index 1774d9234844..6a43d3667a11 100644 --- a/presto-main/src/main/java/io/prestosql/operator/ExchangeClientFactory.java +++ b/presto-main/src/main/java/io/prestosql/operator/ExchangeClientFactory.java @@ -15,9 +15,12 @@ import io.airlift.concurrent.ThreadPoolExecutorMBean; import io.airlift.http.client.HttpClient; +import io.airlift.node.NodeInfo; import io.airlift.units.DataSize; import io.airlift.units.Duration; import io.prestosql.memory.context.LocalMemoryContext; +import io.prestosql.sql.analyzer.FeaturesConfig; +import io.prestosql.sql.analyzer.FeaturesConfig.DataIntegrityVerification; import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; @@ -36,6 +39,8 @@ public class ExchangeClientFactory implements ExchangeClientSupplier { + private final NodeInfo nodeInfo; + private final DataIntegrityVerification dataIntegrityVerification; private final DataSize maxBufferedBytes; private final int concurrentRequestMultiplier; private final Duration maxErrorDuration; @@ -48,11 +53,15 @@ public class ExchangeClientFactory @Inject public ExchangeClientFactory( + NodeInfo nodeInfo, + FeaturesConfig featuresConfig, ExchangeClientConfig config, @ForExchange HttpClient httpClient, @ForExchange ScheduledExecutorService scheduler) { this( + nodeInfo, + featuresConfig.getExchangeDataIntegrityVerification(), config.getMaxBufferSize(), config.getMaxResponseSize(), config.getConcurrentRequestMultiplier(), @@ -64,6 +73,8 @@ public ExchangeClientFactory( } public ExchangeClientFactory( + NodeInfo nodeInfo, + DataIntegrityVerification dataIntegrityVerification, DataSize maxBufferedBytes, DataSize maxResponseSize, int concurrentRequestMultiplier, @@ -73,6 +84,8 @@ public ExchangeClientFactory( HttpClient httpClient, ScheduledExecutorService scheduler) { + this.nodeInfo = requireNonNull(nodeInfo, "nodeInfo is null"); + this.dataIntegrityVerification = requireNonNull(dataIntegrityVerification, "dataIntegrityVerification is null"); this.maxBufferedBytes = requireNonNull(maxBufferedBytes, "maxBufferedBytes is null"); this.concurrentRequestMultiplier = concurrentRequestMultiplier; this.maxErrorDuration = requireNonNull(maxErrorDuration, "maxErrorDuration is null"); @@ -112,6 +125,8 @@ public ThreadPoolExecutorMBean getExecutor() public ExchangeClient get(LocalMemoryContext systemMemoryContext) { return new ExchangeClient( + nodeInfo.getExternalAddress(), + dataIntegrityVerification, maxBufferedBytes, maxResponseSize, concurrentRequestMultiplier, diff --git a/presto-main/src/main/java/io/prestosql/operator/HttpPageBufferClient.java b/presto-main/src/main/java/io/prestosql/operator/HttpPageBufferClient.java index 513548f9599a..15a88f0a539e 100644 --- a/presto-main/src/main/java/io/prestosql/operator/HttpPageBufferClient.java +++ b/presto-main/src/main/java/io/prestosql/operator/HttpPageBufferClient.java @@ -34,6 +34,7 @@ import io.prestosql.execution.buffer.SerializedPage; import io.prestosql.server.remotetask.Backoff; import io.prestosql.spi.PrestoException; +import io.prestosql.sql.analyzer.FeaturesConfig.DataIntegrityVerification; import org.joda.time.DateTime; import javax.annotation.Nullable; @@ -70,10 +71,14 @@ import static io.prestosql.client.PrestoHeaders.PRESTO_PAGE_NEXT_TOKEN; import static io.prestosql.client.PrestoHeaders.PRESTO_PAGE_TOKEN; import static io.prestosql.client.PrestoHeaders.PRESTO_TASK_INSTANCE_ID; +import static io.prestosql.execution.buffer.PagesSerdeUtil.NO_CHECKSUM; +import static io.prestosql.execution.buffer.PagesSerdeUtil.calculateChecksum; import static io.prestosql.execution.buffer.PagesSerdeUtil.readSerializedPages; import static io.prestosql.operator.HttpPageBufferClient.PagesResponse.createEmptyPagesResponse; import static io.prestosql.operator.HttpPageBufferClient.PagesResponse.createPagesResponse; +import static io.prestosql.server.PagesResponseWriter.SERIALIZED_PAGES_MAGIC; import static io.prestosql.spi.HostAddress.fromUri; +import static io.prestosql.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static io.prestosql.spi.StandardErrorCode.REMOTE_BUFFER_CLOSE_FAILED; import static io.prestosql.spi.StandardErrorCode.REMOTE_TASK_MISMATCH; import static io.prestosql.util.Failures.REMOTE_TASK_MISMATCH_ERROR; @@ -109,7 +114,9 @@ public interface ClientCallback void clientFailed(HttpPageBufferClient client, Throwable cause); } + private final String selfAddress; private final HttpClient httpClient; + private final DataIntegrityVerification dataIntegrityVerification; private final DataSize maxResponseSize; private final boolean acknowledgePages; private final URI location; @@ -145,7 +152,9 @@ public interface ClientCallback private final Executor pageBufferClientCallbackExecutor; public HttpPageBufferClient( + String selfAddress, HttpClient httpClient, + DataIntegrityVerification dataIntegrityVerification, DataSize maxResponseSize, Duration maxErrorDuration, boolean acknowledgePages, @@ -154,11 +163,24 @@ public HttpPageBufferClient( ScheduledExecutorService scheduler, Executor pageBufferClientCallbackExecutor) { - this(httpClient, maxResponseSize, maxErrorDuration, acknowledgePages, location, clientCallback, scheduler, Ticker.systemTicker(), pageBufferClientCallbackExecutor); + this( + selfAddress, + httpClient, + dataIntegrityVerification, + maxResponseSize, + maxErrorDuration, + acknowledgePages, + location, + clientCallback, + scheduler, + Ticker.systemTicker(), + pageBufferClientCallbackExecutor); } public HttpPageBufferClient( + String selfAddress, HttpClient httpClient, + DataIntegrityVerification dataIntegrityVerification, DataSize maxResponseSize, Duration maxErrorDuration, boolean acknowledgePages, @@ -168,7 +190,9 @@ public HttpPageBufferClient( Ticker ticker, Executor pageBufferClientCallbackExecutor) { + this.selfAddress = requireNonNull(selfAddress, "selfAddress is null"); this.httpClient = requireNonNull(httpClient, "httpClient is null"); + this.dataIntegrityVerification = requireNonNull(dataIntegrityVerification, "dataIntegrityVerification is null"); this.maxResponseSize = requireNonNull(maxResponseSize, "maxResponseSize is null"); this.acknowledgePages = acknowledgePages; this.location = requireNonNull(location, "location is null"); @@ -301,7 +325,7 @@ private synchronized void sendGetResults() prepareGet() .setHeader(PRESTO_MAX_SIZE, maxResponseSize.toString()) .setUri(uri).build(), - new PageResponseHandler()); + new PageResponseHandler(dataIntegrityVerification != DataIntegrityVerification.NONE)); future = resultFuture; Futures.addCallback(resultFuture, new FutureCallback() @@ -403,6 +427,22 @@ public void onFailure(Throwable t) log.debug("Request to %s failed %s", uri, t); checkNotHoldsLock(this); + if (t instanceof ChecksumVerificationException) { + switch (dataIntegrityVerification) { + case NONE: + // In case of NONE, failure is possible in case of inconsistent cluster configuration, so we should not retry. + case ABORT: + // PrestoException will not be retried + t = new PrestoException(GENERIC_INTERNAL_ERROR, format("Checksum verification failure on %s when reading from %s: %s", selfAddress, uri, t.getMessage()), t); + break; + case RETRY: + log.warn("Checksum verification failure on %s when reading from %s, may be retried: %s", selfAddress, uri, t.getMessage()); + break; + default: + throw new AssertionError("Unsupported option: " + dataIntegrityVerification); + } + } + t = rewriteException(t); if (!(t instanceof PrestoException) && backoff.failure()) { String message = format("%s (%s - %s failures, failure duration %s, total failed request time %s)", @@ -542,6 +582,13 @@ private static Throwable rewriteException(Throwable t) public static class PageResponseHandler implements ResponseHandler { + private final boolean dataIntegrityVerificationEnabled; + + private PageResponseHandler(boolean dataIntegrityVerificationEnabled) + { + this.dataIntegrityVerificationEnabled = dataIntegrityVerificationEnabled; + } + @Override public PagesResponse handleException(Request request, Exception exception) { @@ -593,7 +640,15 @@ public PagesResponse handle(Request request, Response response) boolean complete = getComplete(response); try (SliceInput input = new InputStreamSliceInput(response.getInputStream())) { + int magic = input.readInt(); + if (magic != SERIALIZED_PAGES_MAGIC) { + throw new IllegalStateException(format("Invalid stream header, expected 0x%08x, but was 0x%08x", SERIALIZED_PAGES_MAGIC, magic)); + } + long checksum = input.readLong(); + int pagesCount = input.readInt(); List pages = ImmutableList.copyOf(readSerializedPages(input)); + verifyChecksum(checksum, pages); + checkState(pages.size() == pagesCount, "Wrong number of pages, expected %s, but read %s", pagesCount, pages.size()); return createPagesResponse(taskInstanceId, token, nextToken, pages, complete); } catch (IOException e) { @@ -605,6 +660,21 @@ public PagesResponse handle(Request request, Response response) } } + private void verifyChecksum(long readChecksum, List pages) + { + if (dataIntegrityVerificationEnabled) { + long calculatedChecksum = calculateChecksum(pages); + if (readChecksum != calculatedChecksum) { + throw new ChecksumVerificationException(format("Data corruption, read checksum: 0x%08x, calculated checksum: 0x%08x", readChecksum, calculatedChecksum)); + } + } + else { + if (readChecksum != NO_CHECKSUM) { + throw new ChecksumVerificationException(format("Expected checksum to be NO_CHECKSUM (0x%08x) but is 0x%08x", NO_CHECKSUM, readChecksum)); + } + } + } + private static String getTaskInstanceId(Response response) { String taskInstanceId = response.getHeader(PRESTO_TASK_INSTANCE_ID); @@ -715,4 +785,13 @@ public String toString() .toString(); } } + + private static class ChecksumVerificationException + extends RuntimeException + { + public ChecksumVerificationException(String message) + { + super(requireNonNull(message, "message is null")); + } + } } diff --git a/presto-main/src/main/java/io/prestosql/server/PagesResponseWriter.java b/presto-main/src/main/java/io/prestosql/server/PagesResponseWriter.java index 0d334a8d0fa0..bd99974e2642 100644 --- a/presto-main/src/main/java/io/prestosql/server/PagesResponseWriter.java +++ b/presto-main/src/main/java/io/prestosql/server/PagesResponseWriter.java @@ -17,7 +17,10 @@ import io.airlift.slice.OutputStreamSliceOutput; import io.airlift.slice.SliceOutput; import io.prestosql.execution.buffer.SerializedPage; +import io.prestosql.sql.analyzer.FeaturesConfig; +import io.prestosql.sql.analyzer.FeaturesConfig.DataIntegrityVerification; +import javax.inject.Inject; import javax.ws.rs.Produces; import javax.ws.rs.WebApplicationException; import javax.ws.rs.core.MediaType; @@ -34,13 +37,18 @@ import java.util.List; import static io.prestosql.PrestoMediaTypes.PRESTO_PAGES; +import static io.prestosql.execution.buffer.PagesSerdeUtil.NO_CHECKSUM; +import static io.prestosql.execution.buffer.PagesSerdeUtil.calculateChecksum; import static io.prestosql.execution.buffer.PagesSerdeUtil.writeSerializedPages; +import static java.util.Objects.requireNonNull; @Provider @Produces(PRESTO_PAGES) public class PagesResponseWriter implements MessageBodyWriter> { + public static final int SERIALIZED_PAGES_MAGIC = 0xfea4f001; + private static final MediaType PRESTO_PAGES_TYPE = MediaType.valueOf(PRESTO_PAGES); private static final Type LIST_GENERIC_TOKEN; @@ -53,6 +61,15 @@ public class PagesResponseWriter } } + private final boolean dataIntegrityVerificationEnabled; + + @Inject + public PagesResponseWriter(FeaturesConfig featuresConfig) + { + requireNonNull(featuresConfig, "featuresConfig is null"); + this.dataIntegrityVerificationEnabled = featuresConfig.getExchangeDataIntegrityVerification() != DataIntegrityVerification.NONE; + } + @Override public boolean isWriteable(Class type, Type genericType, Annotation[] annotations, MediaType mediaType) { @@ -79,6 +96,9 @@ public void writeTo(List serializedPages, { try { SliceOutput sliceOutput = new OutputStreamSliceOutput(output); + sliceOutput.writeInt(SERIALIZED_PAGES_MAGIC); + sliceOutput.writeLong(dataIntegrityVerificationEnabled ? calculateChecksum(serializedPages) : NO_CHECKSUM); + sliceOutput.writeInt(serializedPages.size()); writeSerializedPages(sliceOutput, serializedPages); // We use flush instead of close, because the underlying stream would be closed and that is not allowed. sliceOutput.flush(); diff --git a/presto-main/src/main/java/io/prestosql/sql/analyzer/FeaturesConfig.java b/presto-main/src/main/java/io/prestosql/sql/analyzer/FeaturesConfig.java index 72e04e97b4c2..6cb42c0cc090 100644 --- a/presto-main/src/main/java/io/prestosql/sql/analyzer/FeaturesConfig.java +++ b/presto-main/src/main/java/io/prestosql/sql/analyzer/FeaturesConfig.java @@ -84,6 +84,7 @@ public class FeaturesConfig private boolean optimizeHashGeneration = true; private boolean enableIntermediateAggregations; private boolean pushTableWriteThroughUnion = true; + private DataIntegrityVerification exchangeDataIntegrityVerification = DataIntegrityVerification.ABORT; private boolean exchangeCompressionEnabled; private boolean legacyTimestamp = true; private boolean optimizeMixedDistinctAggregations; @@ -159,6 +160,14 @@ public boolean canReplicate() } } + public enum DataIntegrityVerification + { + NONE, + ABORT, + RETRY, + /**/; + } + public double getCpuCostWeight() { return cpuCostWeight; @@ -814,6 +823,18 @@ public FeaturesConfig setExchangeCompressionEnabled(boolean exchangeCompressionE return this; } + public DataIntegrityVerification getExchangeDataIntegrityVerification() + { + return exchangeDataIntegrityVerification; + } + + @Config("exchange.data-integrity-verification") + public FeaturesConfig setExchangeDataIntegrityVerification(DataIntegrityVerification exchangeDataIntegrityVerification) + { + this.exchangeDataIntegrityVerification = exchangeDataIntegrityVerification; + return this; + } + public boolean isEnableIntermediateAggregations() { return enableIntermediateAggregations; diff --git a/presto-main/src/test/java/io/prestosql/operator/MockExchangeRequestProcessor.java b/presto-main/src/test/java/io/prestosql/operator/MockExchangeRequestProcessor.java index 55614d0d69b9..081fdd32a7a8 100644 --- a/presto-main/src/test/java/io/prestosql/operator/MockExchangeRequestProcessor.java +++ b/presto-main/src/test/java/io/prestosql/operator/MockExchangeRequestProcessor.java @@ -27,7 +27,6 @@ import io.prestosql.client.PrestoHeaders; import io.prestosql.execution.buffer.BufferResult; import io.prestosql.execution.buffer.PagesSerde; -import io.prestosql.execution.buffer.PagesSerdeUtil; import io.prestosql.execution.buffer.SerializedPage; import io.prestosql.spi.Page; @@ -47,7 +46,10 @@ import static io.prestosql.client.PrestoHeaders.PRESTO_PAGE_NEXT_TOKEN; import static io.prestosql.client.PrestoHeaders.PRESTO_PAGE_TOKEN; import static io.prestosql.client.PrestoHeaders.PRESTO_TASK_INSTANCE_ID; +import static io.prestosql.execution.buffer.PagesSerdeUtil.calculateChecksum; +import static io.prestosql.execution.buffer.PagesSerdeUtil.writeSerializedPages; import static io.prestosql.execution.buffer.TestingPagesSerdeFactory.testingPagesSerde; +import static io.prestosql.server.PagesResponseWriter.SERIALIZED_PAGES_MAGIC; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertTrue; @@ -97,7 +99,10 @@ public Response handle(Request request) HttpStatus status; if (!result.getSerializedPages().isEmpty()) { DynamicSliceOutput sliceOutput = new DynamicSliceOutput(64); - PagesSerdeUtil.writeSerializedPages(sliceOutput, result.getSerializedPages()); + sliceOutput.writeInt(SERIALIZED_PAGES_MAGIC); + sliceOutput.writeLong(calculateChecksum(result.getSerializedPages())); + sliceOutput.writeInt(result.getSerializedPages().size()); + writeSerializedPages(sliceOutput, result.getSerializedPages()); bytes = sliceOutput.slice().getBytes(); status = HttpStatus.OK; } diff --git a/presto-main/src/test/java/io/prestosql/operator/TestExchangeClient.java b/presto-main/src/test/java/io/prestosql/operator/TestExchangeClient.java index a723af398653..d0ac81c48472 100644 --- a/presto-main/src/test/java/io/prestosql/operator/TestExchangeClient.java +++ b/presto-main/src/test/java/io/prestosql/operator/TestExchangeClient.java @@ -14,9 +14,14 @@ package io.prestosql.operator; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ListMultimap; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; +import io.airlift.http.client.HttpStatus; +import io.airlift.http.client.Request; +import io.airlift.http.client.Response; import io.airlift.http.client.testing.TestingHttpClient; +import io.airlift.http.client.testing.TestingResponse; import io.airlift.units.DataSize; import io.airlift.units.DataSize.Unit; import io.airlift.units.Duration; @@ -25,17 +30,24 @@ import io.prestosql.execution.buffer.SerializedPage; import io.prestosql.memory.context.SimpleLocalMemoryContext; import io.prestosql.spi.Page; +import io.prestosql.spi.PrestoException; +import io.prestosql.sql.analyzer.FeaturesConfig.DataIntegrityVerification; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; import java.net.URI; +import java.util.Map; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableListMultimap.toImmutableListMultimap; import static com.google.common.collect.Maps.uniqueIndex; +import static com.google.common.io.ByteStreams.toByteArray; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static com.google.common.util.concurrent.Uninterruptibles.sleepUninterruptibly; import static io.airlift.concurrent.MoreFutures.tryGetFutureValue; @@ -46,6 +58,7 @@ import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.Executors.newScheduledThreadPool; import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertNotNull; @@ -94,6 +107,8 @@ public void testHappyPath() @SuppressWarnings("resource") ExchangeClient exchangeClient = new ExchangeClient( + "localhost", + DataIntegrityVerification.ABORT, DataSize.of(32, Unit.MEGABYTE), maxResponseSize, 1, @@ -133,6 +148,8 @@ public void testAddLocation() @SuppressWarnings("resource") ExchangeClient exchangeClient = new ExchangeClient( + "localhost", + DataIntegrityVerification.ABORT, DataSize.of(32, Unit.MEGABYTE), maxResponseSize, 1, @@ -205,6 +222,8 @@ public void testBufferLimit() @SuppressWarnings("resource") ExchangeClient exchangeClient = new ExchangeClient( + "localhost", + DataIntegrityVerification.ABORT, DataSize.ofBytes(1), maxResponseSize, 1, @@ -273,6 +292,109 @@ public void testBufferLimit() assertStatus(exchangeClient.getStatus().getPageBufferClientStatuses().get(0), location, "closed", 3, 5, 5, "not scheduled"); } + @Test + public void testAbortOnDataCorruption() + { + URI location = URI.create("http://localhost:8080"); + ExchangeClient exchangeClient = setUpDataCorruption(DataIntegrityVerification.ABORT, location); + + assertFalse(exchangeClient.isClosed()); + assertThatThrownBy(() -> getNextPage(exchangeClient)) + .isInstanceOf(PrestoException.class) + .hasMessageMatching("Checksum verification failure on localhost when reading from http://localhost:8080/0: Data corruption, read checksum: 0xf91cfe5d2bc6e1c2, calculated checksum: 0x3c51297c7b78052f"); + + assertThatThrownBy(exchangeClient::isFinished) + .isInstanceOf(PrestoException.class) + .hasMessageMatching("Checksum verification failure on localhost when reading from http://localhost:8080/0: Data corruption, read checksum: 0xf91cfe5d2bc6e1c2, calculated checksum: 0x3c51297c7b78052f"); + + exchangeClient.close(); + } + + @Test + public void testRetryDataCorruption() + { + URI location = URI.create("http://localhost:8080"); + ExchangeClient exchangeClient = setUpDataCorruption(DataIntegrityVerification.RETRY, location); + + assertFalse(exchangeClient.isClosed()); + assertPageEquals(getNextPage(exchangeClient), createPage(1)); + assertFalse(exchangeClient.isClosed()); + assertPageEquals(getNextPage(exchangeClient), createPage(2)); + assertNull(getNextPage(exchangeClient)); + assertTrue(exchangeClient.isClosed()); + + ExchangeClientStatus status = exchangeClient.getStatus(); + assertEquals(status.getBufferedPages(), 0); + assertEquals(status.getBufferedBytes(), 0); + + assertStatus(status.getPageBufferClientStatuses().get(0), location, "closed", 2, 4, 4, "not scheduled"); + } + + private ExchangeClient setUpDataCorruption(DataIntegrityVerification dataIntegrityVerification, URI location) + { + DataSize maxResponseSize = DataSize.of(10, Unit.MEGABYTE); + + MockExchangeRequestProcessor delegate = new MockExchangeRequestProcessor(maxResponseSize); + delegate.addPage(location, createPage(1)); + delegate.addPage(location, createPage(2)); + delegate.setComplete(location); + + TestingHttpClient.Processor processor = new TestingHttpClient.Processor() + { + private int completedRequests; + private TestingResponse savedResponse; + + @Override + public synchronized Response handle(Request request) + throws Exception + { + if (completedRequests == 0) { + verify(savedResponse == null); + TestingResponse response = (TestingResponse) delegate.handle(request); + checkState(response.getStatusCode() == HttpStatus.OK.code(), "Unexpected status code: %s", response.getStatusCode()); + ListMultimap headers = response.getHeaders().entries().stream() + .collect(toImmutableListMultimap(entry -> entry.getKey().toString(), Map.Entry::getValue)); + byte[] bytes = toByteArray(response.getInputStream()); + checkState(bytes.length > 42, "too short"); + savedResponse = new TestingResponse(HttpStatus.OK, headers, bytes.clone()); + // corrupt + bytes[42]++; + completedRequests++; + return new TestingResponse(HttpStatus.OK, headers, bytes); + } + + if (completedRequests == 1) { + verify(savedResponse != null); + Response response = savedResponse; + savedResponse = null; + completedRequests++; + return response; + } + + completedRequests++; + return delegate.handle(request); + } + }; + + ExchangeClient exchangeClient = new ExchangeClient( + "localhost", + dataIntegrityVerification, + DataSize.of(32, Unit.MEGABYTE), + maxResponseSize, + 1, + new Duration(1, TimeUnit.MINUTES), + true, + new TestingHttpClient(processor, scheduler), + scheduler, + new SimpleLocalMemoryContext(newSimpleAggregatedMemoryContext(), "test"), + pageBufferClientCallbackExecutor); + + exchangeClient.addLocation(location); + exchangeClient.noMoreLocations(); + + return exchangeClient; + } + @Test public void testClose() throws Exception @@ -287,6 +409,8 @@ public void testClose() @SuppressWarnings("resource") ExchangeClient exchangeClient = new ExchangeClient( + "localhost", + DataIntegrityVerification.ABORT, DataSize.ofBytes(1), maxResponseSize, 1, diff --git a/presto-main/src/test/java/io/prestosql/operator/TestExchangeOperator.java b/presto-main/src/test/java/io/prestosql/operator/TestExchangeOperator.java index c5eac5cf6716..dfbe44d54d26 100644 --- a/presto-main/src/test/java/io/prestosql/operator/TestExchangeOperator.java +++ b/presto-main/src/test/java/io/prestosql/operator/TestExchangeOperator.java @@ -29,6 +29,7 @@ import io.prestosql.spi.Page; import io.prestosql.spi.type.Type; import io.prestosql.split.RemoteSplit; +import io.prestosql.sql.analyzer.FeaturesConfig.DataIntegrityVerification; import io.prestosql.sql.planner.plan.PlanNodeId; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; @@ -84,6 +85,8 @@ public void setUp() httpClient = new TestingHttpClient(new TestingExchangeHttpClientHandler(taskBuffers), scheduler); exchangeClientSupplier = (systemMemoryUsageListener) -> new ExchangeClient( + "localhost", + DataIntegrityVerification.ABORT, DataSize.of(32, MEGABYTE), DataSize.of(10, MEGABYTE), 3, diff --git a/presto-main/src/test/java/io/prestosql/operator/TestHttpPageBufferClient.java b/presto-main/src/test/java/io/prestosql/operator/TestHttpPageBufferClient.java index f1dea48e0296..85e3ad230f91 100644 --- a/presto-main/src/test/java/io/prestosql/operator/TestHttpPageBufferClient.java +++ b/presto-main/src/test/java/io/prestosql/operator/TestHttpPageBufferClient.java @@ -28,6 +28,7 @@ import io.prestosql.operator.HttpPageBufferClient.ClientCallback; import io.prestosql.spi.HostAddress; import io.prestosql.spi.Page; +import io.prestosql.sql.analyzer.FeaturesConfig.DataIntegrityVerification; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; @@ -101,7 +102,10 @@ public void testHappyPath() TestingClientCallback callback = new TestingClientCallback(requestComplete); URI location = URI.create("http://localhost:8080"); - HttpPageBufferClient client = new HttpPageBufferClient(new TestingHttpClient(processor, scheduler), + HttpPageBufferClient client = new HttpPageBufferClient( + "localhost", + new TestingHttpClient(processor, scheduler), + DataIntegrityVerification.ABORT, expectedMaxSize, new Duration(1, TimeUnit.MINUTES), true, @@ -186,7 +190,10 @@ public void testLifecycle() TestingClientCallback callback = new TestingClientCallback(requestComplete); URI location = URI.create("http://localhost:8080"); - HttpPageBufferClient client = new HttpPageBufferClient(new TestingHttpClient(processor, scheduler), + HttpPageBufferClient client = new HttpPageBufferClient( + "localhost", + new TestingHttpClient(processor, scheduler), + DataIntegrityVerification.ABORT, DataSize.of(10, Unit.MEGABYTE), new Duration(1, TimeUnit.MINUTES), true, @@ -226,7 +233,10 @@ public void testInvalidResponses() TestingClientCallback callback = new TestingClientCallback(requestComplete); URI location = URI.create("http://localhost:8080"); - HttpPageBufferClient client = new HttpPageBufferClient(new TestingHttpClient(processor, scheduler), + HttpPageBufferClient client = new HttpPageBufferClient( + "localhost", + new TestingHttpClient(processor, scheduler), + DataIntegrityVerification.ABORT, DataSize.of(10, Unit.MEGABYTE), new Duration(1, TimeUnit.MINUTES), true, @@ -294,7 +304,10 @@ public void testCloseDuringPendingRequest() TestingClientCallback callback = new TestingClientCallback(requestComplete); URI location = URI.create("http://localhost:8080"); - HttpPageBufferClient client = new HttpPageBufferClient(new TestingHttpClient(processor, scheduler), + HttpPageBufferClient client = new HttpPageBufferClient( + "localhost", + new TestingHttpClient(processor, scheduler), + DataIntegrityVerification.ABORT, DataSize.of(10, Unit.MEGABYTE), new Duration(1, TimeUnit.MINUTES), true, @@ -348,7 +361,10 @@ public void testExceptionFromResponseHandler() TestingClientCallback callback = new TestingClientCallback(requestComplete); URI location = URI.create("http://localhost:8080"); - HttpPageBufferClient client = new HttpPageBufferClient(new TestingHttpClient(processor, scheduler), + HttpPageBufferClient client = new HttpPageBufferClient( + "localhost", + new TestingHttpClient(processor, scheduler), + DataIntegrityVerification.ABORT, DataSize.of(10, Unit.MEGABYTE), new Duration(30, TimeUnit.SECONDS), true, diff --git a/presto-main/src/test/java/io/prestosql/operator/TestMergeOperator.java b/presto-main/src/test/java/io/prestosql/operator/TestMergeOperator.java index 1c23c7fba8d1..1e6480369424 100644 --- a/presto-main/src/test/java/io/prestosql/operator/TestMergeOperator.java +++ b/presto-main/src/test/java/io/prestosql/operator/TestMergeOperator.java @@ -19,6 +19,7 @@ import com.google.common.collect.ImmutableList; import io.airlift.http.client.HttpClient; import io.airlift.http.client.testing.TestingHttpClient; +import io.airlift.node.NodeInfo; import io.prestosql.execution.Lifespan; import io.prestosql.execution.buffer.PagesSerdeFactory; import io.prestosql.execution.buffer.TestingPagesSerdeFactory; @@ -27,6 +28,7 @@ import io.prestosql.spi.block.SortOrder; import io.prestosql.spi.type.Type; import io.prestosql.split.RemoteSplit; +import io.prestosql.sql.analyzer.FeaturesConfig; import io.prestosql.sql.gen.OrderingCompiler; import io.prestosql.sql.planner.plan.PlanNodeId; import org.testng.annotations.AfterMethod; @@ -82,7 +84,7 @@ public void setUp() taskBuffers = CacheBuilder.newBuilder().build(CacheLoader.from(TestingTaskBuffer::new)); httpClient = new TestingHttpClient(new TestingExchangeHttpClientHandler(taskBuffers), executor); - exchangeClientFactory = new ExchangeClientFactory(new ExchangeClientConfig(), httpClient, executor); + exchangeClientFactory = new ExchangeClientFactory(new NodeInfo("test"), new FeaturesConfig(), new ExchangeClientConfig(), httpClient, executor); orderingCompiler = new OrderingCompiler(); } diff --git a/presto-main/src/test/java/io/prestosql/operator/TestingExchangeHttpClientHandler.java b/presto-main/src/test/java/io/prestosql/operator/TestingExchangeHttpClientHandler.java index a2faec7a9254..20e78272d669 100644 --- a/presto-main/src/test/java/io/prestosql/operator/TestingExchangeHttpClientHandler.java +++ b/presto-main/src/test/java/io/prestosql/operator/TestingExchangeHttpClientHandler.java @@ -24,7 +24,7 @@ import io.airlift.http.client.testing.TestingResponse; import io.airlift.slice.DynamicSliceOutput; import io.prestosql.execution.buffer.PagesSerde; -import io.prestosql.execution.buffer.PagesSerdeUtil; +import io.prestosql.execution.buffer.SerializedPage; import io.prestosql.spi.Page; import static io.prestosql.PrestoMediaTypes.PRESTO_PAGES; @@ -32,7 +32,10 @@ import static io.prestosql.client.PrestoHeaders.PRESTO_PAGE_NEXT_TOKEN; import static io.prestosql.client.PrestoHeaders.PRESTO_PAGE_TOKEN; import static io.prestosql.client.PrestoHeaders.PRESTO_TASK_INSTANCE_ID; +import static io.prestosql.execution.buffer.PagesSerdeUtil.calculateChecksum; +import static io.prestosql.execution.buffer.PagesSerdeUtil.writeSerializedPage; import static io.prestosql.execution.buffer.TestingPagesSerdeFactory.testingPagesSerde; +import static io.prestosql.server.PagesResponseWriter.SERIALIZED_PAGES_MAGIC; import static java.util.Objects.requireNonNull; import static javax.ws.rs.core.HttpHeaders.CONTENT_TYPE; import static org.testng.Assert.assertEquals; @@ -72,14 +75,22 @@ public Response handle(Request request) if (page != null) { headers.put(PRESTO_PAGE_NEXT_TOKEN, String.valueOf(pageToken + 1)); headers.put(PRESTO_BUFFER_COMPLETE, String.valueOf(false)); + SerializedPage serializedPage = PAGES_SERDE.serialize(page); DynamicSliceOutput output = new DynamicSliceOutput(256); - PagesSerdeUtil.writePages(PAGES_SERDE, output, page); + output.writeInt(SERIALIZED_PAGES_MAGIC); + output.writeLong(calculateChecksum(ImmutableList.of(serializedPage))); + output.writeInt(1); + writeSerializedPage(output, serializedPage); return new TestingResponse(HttpStatus.OK, headers.build(), output.slice().getInput()); } else if (taskBuffer.isFinished()) { headers.put(PRESTO_PAGE_NEXT_TOKEN, String.valueOf(pageToken)); headers.put(PRESTO_BUFFER_COMPLETE, String.valueOf(true)); - return new TestingResponse(HttpStatus.OK, headers.build(), new byte[0]); + DynamicSliceOutput output = new DynamicSliceOutput(8); + output.writeInt(SERIALIZED_PAGES_MAGIC); + output.writeLong(calculateChecksum(ImmutableList.of())); + output.writeInt(0); + return new TestingResponse(HttpStatus.OK, headers.build(), output.slice().getInput()); } else { headers.put(PRESTO_PAGE_NEXT_TOKEN, String.valueOf(pageToken)); diff --git a/presto-main/src/test/java/io/prestosql/sql/analyzer/TestFeaturesConfig.java b/presto-main/src/test/java/io/prestosql/sql/analyzer/TestFeaturesConfig.java index b76cc844719b..f67000099c90 100644 --- a/presto-main/src/test/java/io/prestosql/sql/analyzer/TestFeaturesConfig.java +++ b/presto-main/src/test/java/io/prestosql/sql/analyzer/TestFeaturesConfig.java @@ -19,6 +19,7 @@ import io.prestosql.operator.aggregation.arrayagg.ArrayAggGroupImplementation; import io.prestosql.operator.aggregation.histogram.HistogramGroupImplementation; import io.prestosql.operator.aggregation.multimapagg.MultimapAggGroupImplementation; +import io.prestosql.sql.analyzer.FeaturesConfig.DataIntegrityVerification; import io.prestosql.sql.analyzer.FeaturesConfig.JoinDistributionType; import io.prestosql.sql.analyzer.FeaturesConfig.JoinReorderingStrategy; import org.testng.annotations.Test; @@ -88,6 +89,7 @@ public void testDefaults() .setDefaultFilterFactorEnabled(false) .setEnableForcedExchangeBelowGroupId(true) .setExchangeCompressionEnabled(false) + .setExchangeDataIntegrityVerification(DataIntegrityVerification.ABORT) .setLegacyTimestamp(true) .setEnableIntermediateAggregations(false) .setPushAggregationThroughOuterJoin(true) @@ -164,6 +166,7 @@ public void testExplicitPropertyMappings() .put("memory-revoking-threshold", "0.2") .put("memory-revoking-target", "0.8") .put("exchange.compression-enabled", "true") + .put("exchange.data-integrity-verification", "RETRY") .put("deprecated.legacy-timestamp", "false") .put("optimizer.enable-intermediate-aggregations", "true") .put("parse-decimal-literals-as-double", "true") @@ -234,6 +237,7 @@ public void testExplicitPropertyMappings() .setMemoryRevokingThreshold(0.2) .setMemoryRevokingTarget(0.8) .setExchangeCompressionEnabled(true) + .setExchangeDataIntegrityVerification(DataIntegrityVerification.RETRY) .setLegacyTimestamp(false) .setEnableIntermediateAggregations(true) .setParseDecimalLiteralsAsDouble(true)