diff --git a/presto-native-execution/presto_cpp/main/operators/LocalShuffle.cpp b/presto-native-execution/presto_cpp/main/operators/LocalShuffle.cpp index c3a17902d9780..0212918556e1a 100644 --- a/presto-native-execution/presto_cpp/main/operators/LocalShuffle.cpp +++ b/presto-native-execution/presto_cpp/main/operators/LocalShuffle.cpp @@ -112,7 +112,7 @@ class LocalShuffleSerializedPage : public ShuffleSerializedPage { velox::BufferPtr buffer) : rows_{std::move(rows)}, buffer_{std::move(buffer)} {} - const std::vector& rows() override { + const std::vector& rows(int32_t /*driverId*/) override { return rows_; } diff --git a/presto-native-execution/presto_cpp/main/operators/ShuffleInterface.h b/presto-native-execution/presto_cpp/main/operators/ShuffleInterface.h index 456dcaf27099f..32e00f72ea013 100644 --- a/presto-native-execution/presto_cpp/main/operators/ShuffleInterface.h +++ b/presto-native-execution/presto_cpp/main/operators/ShuffleInterface.h @@ -63,7 +63,14 @@ class ShuffleSerializedPage : public velox::exec::SerializedPageBase { VELOX_UNSUPPORTED(); } - virtual const std::vector& rows() = 0; + /// Legacy single-consumer path that delegates to rows(0).. + /// retained for backward compatibility. + virtual const std::vector& rows() { + return rows(0); + } + + /// @param driverId Driver ID for per-consumer checksum tracking. + virtual const std::vector& rows(int32_t driverId) = 0; }; class ShuffleReader { diff --git a/presto-native-execution/presto_cpp/main/operators/ShuffleRead.cpp b/presto-native-execution/presto_cpp/main/operators/ShuffleRead.cpp index aa4f2e9df43d1..5e873db9d3d02 100644 --- a/presto-native-execution/presto_cpp/main/operators/ShuffleRead.cpp +++ b/presto-native-execution/presto_cpp/main/operators/ShuffleRead.cpp @@ -84,9 +84,10 @@ RowVectorPtr ShuffleRead::getOutput() { numRows += pageRows; } rows_.reserve(numRows); + const int32_t driverId = operatorCtx()->driverCtx()->driverId; for (const auto& page : currentPages_) { auto* batch = checkedPointerCast(page.get()); - const auto& rows = batch->rows(); + const auto& rows = batch->rows(driverId); for (const auto& row : rows) { rows_.emplace_back(row); }