Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions bolt/shuffle/sparksql/CompressionStream.h
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,7 @@ class AdaptiveParallelZstdCodec {
// stream
bytedance::bolt::NanosecondTimer timer1(&writeTime);
RETURN_NOT_OK(outputStream->Write(output, outputPos));
outputPos = 0;
}
} while (inputLen > 0);
}
Expand Down
48 changes: 48 additions & 0 deletions bolt/shuffle/sparksql/tests/AdaptiveParallelZstdCodecTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,31 @@ void runRoundTrip(
EXPECT_EQ(0, std::memcmp(output.data(), expected.data(), expected.size()));
}

// Returns `size` deterministic pseudo-random bytes.
//
// The bytes come from a 32-bit linear congruential generator that always
// starts from the same fixed seed; each output byte is the most significant
// byte of the next generator state. Because the seed is fixed, a shorter
// request yields a prefix of a longer one. Judging by the name, the data is
// meant to be hard to compress.
std::vector<uint8_t> buildIncompressibleData(size_t size) {
  constexpr uint32_t kMultiplier = 1664525u;
  constexpr uint32_t kIncrement = 1013904223u;

  std::vector<uint8_t> bytes(size);
  uint32_t lcg = 0x12345678u;
  for (auto& byte : bytes) {
    // Unsigned arithmetic wraps modulo 2^32, which is the LCG modulus.
    lcg = lcg * kMultiplier + kIncrement;
    byte = static_cast<uint8_t>(lcg >> 24);
  }
  return bytes;
}

// Builds one test row: an int32 length header followed by `payloadSize`
// pseudo-random payload bytes.
//
// The header holds `payloadSize` narrowed to int32_t and is copied in the
// host's native byte order. The payload is the output of
// buildIncompressibleData(payloadSize) XOR-ed byte by byte with the top byte
// of a second linear congruential generator started from `seed`, so rows
// built with different seeds get different masks over the same base data.
std::vector<uint8_t> buildIncompressibleRow(size_t payloadSize, uint32_t seed) {
  constexpr size_t kHeaderSize = sizeof(int32_t);

  std::vector<uint8_t> row(kHeaderSize + payloadSize);
  const int32_t header = static_cast<int32_t>(payloadSize);
  std::memcpy(row.data(), &header, kHeaderSize);

  // Mask the shared base bytes and write them straight after the header.
  const auto base = buildIncompressibleData(payloadSize);
  uint8_t* out = row.data() + kHeaderSize;
  uint32_t mask = seed;
  for (size_t i = 0; i < payloadSize; ++i) {
    mask = mask * 1103515245u + 12345u;
    out[i] = static_cast<uint8_t>(base[i] ^ static_cast<uint8_t>(mask >> 24));
  }
  return row;
}

} // namespace

TEST(AdaptiveParallelZstdCodecTest, RoundTripSmallPayloads) {
Expand Down Expand Up @@ -144,4 +169,27 @@ TEST(AdaptiveParallelZstdCodecTest, RoundTripLargePayload) {
runRoundTrip(rows, rawSize, RowVectorLayout::kComposite);
}

// Stress round trip: pushes several batches of pseudo-random rows through
// runRoundTrip using the composite row layout.
//
// Each row is sizeof(int32_t) + payloadSize bytes, which is one byte less
// than ZSTD_CStreamInSize(). Every row across all rounds gets a distinct
// seed, numbered consecutively from 1.
// NOTE(review): presumably this guards the codec's compress-and-flush path
// against output corruption when many writes occur — confirm against the
// codec change this test accompanies.
TEST(
    AdaptiveParallelZstdCodecTest,
    CompressAndFlushStressRoundTripWithoutCorruption) {
  constexpr int32_t kRounds = 6;
  constexpr int32_t kRowsPerRound = 256;
  const size_t payloadSize = ZSTD_CStreamInSize() - sizeof(int32_t) - 1;

  for (int32_t round = 0; round < kRounds; ++round) {
    // Seeds continue from where the previous round stopped.
    const int32_t firstSeed = round * kRowsPerRound + 1;

    std::vector<std::vector<uint8_t>> rows;
    rows.reserve(kRowsPerRound);
    int64_t rawSize = 0;

    for (int32_t rowIndex = 0; rowIndex < kRowsPerRound; ++rowIndex) {
      const auto seed = static_cast<uint32_t>(firstSeed + rowIndex);
      rows.push_back(buildIncompressibleRow(payloadSize, seed));
      // Track the total uncompressed byte count handed to the round trip.
      rawSize += static_cast<int64_t>(rows.back().size());
    }

    runRoundTrip(rows, rawSize, RowVectorLayout::kComposite);
  }
}

} // namespace bytedance::bolt::shuffle::sparksql::test