From 374613aadfbe1a70716138f764df51e0b0b47d7c Mon Sep 17 00:00:00 2001 From: Zhang Xiaofeng Date: Fri, 27 Feb 2026 14:39:55 +0800 Subject: [PATCH] fix(shuffle): Reset outputPos after flushing output buffer in CompressInternal --- bolt/shuffle/sparksql/CompressionStream.h | 1 + .../tests/AdaptiveParallelZstdCodecTest.cpp | 48 +++++++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/bolt/shuffle/sparksql/CompressionStream.h b/bolt/shuffle/sparksql/CompressionStream.h index 08de4d8f8..a36f6f47b 100644 --- a/bolt/shuffle/sparksql/CompressionStream.h +++ b/bolt/shuffle/sparksql/CompressionStream.h @@ -317,6 +317,7 @@ class AdaptiveParallelZstdCodec { // stream bytedance::bolt::NanosecondTimer timer1(&writeTime); RETURN_NOT_OK(outputStream->Write(output, outputPos)); + outputPos = 0; } } while (inputLen > 0); } diff --git a/bolt/shuffle/sparksql/tests/AdaptiveParallelZstdCodecTest.cpp b/bolt/shuffle/sparksql/tests/AdaptiveParallelZstdCodecTest.cpp index c055ec14d..db93b77d4 100644 --- a/bolt/shuffle/sparksql/tests/AdaptiveParallelZstdCodecTest.cpp +++ b/bolt/shuffle/sparksql/tests/AdaptiveParallelZstdCodecTest.cpp @@ -115,6 +115,31 @@ void runRoundTrip( EXPECT_EQ(0, std::memcmp(output.data(), expected.data(), expected.size())); } +std::vector buildIncompressibleData(size_t size) { + std::vector data(size); + uint32_t state = 0x12345678u; + for (size_t i = 0; i < size; ++i) { + state = state * 1664525u + 1013904223u; + data[i] = static_cast(state >> 24); + } + return data; +} + +std::vector buildIncompressibleRow(size_t payloadSize, uint32_t seed) { + std::vector row(sizeof(int32_t) + payloadSize); + auto payloadSize32 = static_cast(payloadSize); + std::memcpy(row.data(), &payloadSize32, sizeof(int32_t)); + + auto payload = buildIncompressibleData(payloadSize); + uint32_t state = seed; + for (size_t i = 0; i < payloadSize; ++i) { + state = state * 1103515245u + 12345u; + payload[i] ^= static_cast(state >> 24); + } + std::memcpy(row.data() + sizeof(int32_t), payload.data(), payload.size()); + return row; +} + } // namespace TEST(AdaptiveParallelZstdCodecTest, RoundTripSmallPayloads) { @@ -144,4 +169,27 @@ TEST(AdaptiveParallelZstdCodecTest, RoundTripLargePayload) { runRoundTrip(rows, rawSize, RowVectorLayout::kComposite); } +TEST( + AdaptiveParallelZstdCodecTest, + CompressAndFlushStressRoundTripWithoutCorruption) { + constexpr int32_t kRounds = 6; + constexpr int32_t kRowsPerRound = 256; + const auto payloadSize = + static_cast(ZSTD_CStreamInSize() - sizeof(int32_t) - 1); + + for (int32_t round = 0; round < kRounds; ++round) { + std::vector> rows; + rows.reserve(kRowsPerRound); + + int64_t rawSize = 0; + for (int32_t i = 0; i < kRowsPerRound; ++i) { + rows.emplace_back(buildIncompressibleRow( + payloadSize, static_cast(round * kRowsPerRound + i + 1))); + rawSize += static_cast(rows.back().size()); + } + + runRoundTrip(rows, rawSize, RowVectorLayout::kComposite); + } +} + } // namespace bytedance::bolt::shuffle::sparksql::test