diff --git a/presto-native-execution/presto_cpp/main/CMakeLists.txt b/presto-native-execution/presto_cpp/main/CMakeLists.txt index bfe482f598ede..acc52a56d6582 100644 --- a/presto-native-execution/presto_cpp/main/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/CMakeLists.txt @@ -9,6 +9,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +add_subdirectory(operators) add_subdirectory(types) add_subdirectory(http) add_subdirectory(common) diff --git a/presto-native-execution/presto_cpp/main/operators/CMakeLists.txt b/presto-native-execution/presto_cpp/main/operators/CMakeLists.txt new file mode 100644 index 0000000000000..a0f1e61f5d973 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/operators/CMakeLists.txt @@ -0,0 +1,30 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +add_library( + presto_operators + PartitionAndSerialize.cpp + ShuffleWrite.cpp + UnsafeRowExchangeSource.cpp) + +target_link_libraries( + presto_operators + velox_core + velox_exec + velox_expression + velox_hive_partition_function + velox_vector) + +if(PRESTO_ENABLE_TESTING) + add_subdirectory(tests) +endif() diff --git a/presto-native-execution/presto_cpp/main/operators/PartitionAndSerialize.cpp b/presto-native-execution/presto_cpp/main/operators/PartitionAndSerialize.cpp new file mode 100644 index 0000000000000..008924055c6f5 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/operators/PartitionAndSerialize.cpp @@ -0,0 +1,174 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "presto_cpp/main/operators/PartitionAndSerialize.h" +#include "velox/connectors/hive/HivePartitionFunction.h" +#include "velox/row/UnsafeRowDynamicSerializer.h" + +using namespace facebook::velox::exec; +using namespace facebook::velox; + +namespace facebook::presto::operators { + +void PartitionAndSerializeNode::addDetails(std::stringstream& stream) const { + stream << "("; + for (auto i = 0; i < keys_.size(); ++i) { + const auto& expr = keys_[i]; + if (i > 0) { + stream << ", "; + } + if (auto field = + std::dynamic_pointer_cast(expr)) { + stream << field->name(); + } else if ( + auto constant = + std::dynamic_pointer_cast(expr)) { + stream << constant->toString(); + } else { + stream << expr->toString(); + } + } + stream << ") " << numPartitions_; +} + +namespace { + +class PartitionAndSerializeOperator : public Operator { + public: + PartitionAndSerializeOperator( + int32_t operatorId, + DriverCtx* FOLLY_NONNULL ctx, + const std::shared_ptr& planNode) + : Operator( + ctx, + planNode->outputType(), + operatorId, + planNode->id(), + "PartitionAndSerialize") { + auto inputType = planNode->sources()[0]->outputType(); + auto keyChannels = toChannels(inputType, planNode->keys()); + + // Initialize the hive partition function. + auto numPartitions = planNode->numPartitions(); + std::vector bucketToPartition(numPartitions); + std::iota(bucketToPartition.begin(), bucketToPartition.end(), 0); + partitionFunction_ = + std::make_unique( + planNode->numPartitions(), + std::move(bucketToPartition), + keyChannels); + } + + bool needsInput() const override { + return !input_; + } + + void addInput(RowVectorPtr input) override { + input_ = std::move(input); + } + + RowVectorPtr getOutput() override { + if (!input_) { + return nullptr; + } + + auto numInput = input_->size(); + + // TODO Reuse output vector. + auto output = std::dynamic_pointer_cast( + BaseVector::create(outputType_, numInput, pool())); + + computePartitions(*output->childAt(0)->asFlatVector()); + + serializeRows(*output->childAt(1)->asFlatVector()); + + input_.reset(); + + return output; + } + + BlockingReason isBlocked(ContinueFuture* future) override { + return BlockingReason::kNotBlocked; + } + + bool isFinished() override { + return noMoreInput_; + } + + private: + void computePartitions(FlatVector& partitionsVector) { + auto numInput = input_->size(); + + partitions_.resize(numInput); + partitionFunction_->partition(*input_, partitions_); + + // TODO Avoid copy. + partitionsVector.resize(numInput); + auto rawPartitions = partitionsVector.mutableRawValues(); + std::memcpy(rawPartitions, partitions_.data(), sizeof(int32_t) * numInput); + } + + void serializeRows(FlatVector& dataVector) { + auto numInput = input_->size(); + + dataVector.resize(numInput); + + // Compute row sizes. + rowSizes_.resize(numInput); + + size_t totalSize = 0; + for (auto i = 0; i < numInput; ++i) { + size_t rowSize = velox::row::UnsafeRowDynamicSerializer::getSizeRow( + input_->type(), input_.get(), i); + rowSizes_[i] = rowSize; + totalSize += rowSize; + } + + // Allocate memory. + auto buffer = dataVector.getBufferWithSpace(totalSize); + + // Serialize rows. + auto rawBuffer = buffer->asMutable(); + size_t offset = 0; + for (auto i = 0; i < numInput; ++i) { + dataVector.setNoCopy(i, StringView(rawBuffer + offset, rowSizes_[i])); + + // Write row data. + auto size = velox::row::UnsafeRowDynamicSerializer::serialize( + input_->type(), input_, rawBuffer + offset, i) + .value_or(0); + VELOX_DCHECK_EQ(size, rowSizes_[i]); + offset += size; + } + } + + std::unique_ptr partitionFunction_; + std::vector partitions_; + std::vector rowSizes_; +}; +} // namespace + +std::unique_ptr PartitionAndSerializeTranslator::toOperator( + DriverCtx* ctx, + int32_t id, + const core::PlanNodePtr& node) { + if (auto partitionNode = + std::dynamic_pointer_cast(node)) { + return std::make_unique( + id, ctx, partitionNode); + } + return nullptr; +} +} // namespace facebook::presto::operators \ No newline at end of file diff --git a/presto-native-execution/presto_cpp/main/operators/PartitionAndSerialize.h b/presto-native-execution/presto_cpp/main/operators/PartitionAndSerialize.h new file mode 100644 index 0000000000000..33d3492c36900 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/operators/PartitionAndSerialize.h @@ -0,0 +1,81 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/core/PlanNode.h" +#include "velox/exec/Operator.h" + +namespace facebook::presto::operators { + +class PartitionAndSerializeNode : public velox::core::PlanNode { + public: + PartitionAndSerializeNode( + const velox::core::PlanNodeId& id, + std::vector keys, + int numPartitions, + velox::RowTypePtr outputType, + velox::core::PlanNodePtr source) + : velox::core::PlanNode(id), + keys_{std::move(keys)}, + numPartitions_{numPartitions}, + outputType_{std::move(outputType)}, + sources_{std::move(source)} { + VELOX_USER_CHECK( + velox::ROW( + {"partition", "data"}, {velox::INTEGER(), velox::VARBINARY()}) + ->equivalent(*outputType_)); + VELOX_USER_CHECK(!keys_.empty(), "Empty keys for hive hash"); + + } + + const velox::RowTypePtr& outputType() const override { + return outputType_; + } + + const std::vector& sources() const override { + return sources_; + } + + const std::vector& keys() const { + return keys_; + } + + int numPartitions() const { + return numPartitions_; + } + + std::string_view name() const override { + return "PartitionAndSerialize"; + } + + private: + void addDetails(std::stringstream& stream) const override; + + const std::vector keys_; + const int numPartitions_; + const velox::RowTypePtr outputType_; + const std::vector sources_; +}; + +class PartitionAndSerializeTranslator + : public velox::exec::Operator::PlanNodeTranslator { + public: + std::unique_ptr toOperator( + velox::exec::DriverCtx* ctx, + int32_t id, + const velox::core::PlanNodePtr& node) override; +}; +} // namespace facebook::presto::operators \ No newline at end of file diff --git a/presto-native-execution/presto_cpp/main/operators/ShuffleInterface.h b/presto-native-execution/presto_cpp/main/operators/ShuffleInterface.h new file mode 100644 index 0000000000000..a85b5c88cb388 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/operators/ShuffleInterface.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/exec/Operator.h" + +namespace facebook::presto::operators { + +class ShuffleInterface { + public: + /// Write to the shuffle one row at a time. + virtual void collect(int32_t partition, std::string_view data) = 0; + + /// Tell the shuffle system the writer is done. + /// @param success set to false indicate aborted client. + virtual void noMoreData(bool success) = 0; + + /// Check by the reader to see if more blocks are available for this + /// partition. + virtual bool hasNext(int32_t partition) const = 0; + + /// Read the next block of data for this partition. + /// @param success set to false indicate aborted client. + virtual velox::BufferPtr next(int32_t partition, bool success) = 0; +}; + +} // namespace facebook::presto::operators \ No newline at end of file diff --git a/presto-native-execution/presto_cpp/main/operators/ShuffleWrite.cpp b/presto-native-execution/presto_cpp/main/operators/ShuffleWrite.cpp new file mode 100644 index 0000000000000..8f286d03a3914 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/operators/ShuffleWrite.cpp @@ -0,0 +1,85 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "presto_cpp/main/operators/ShuffleWrite.h" + +using namespace facebook::velox::exec; +using namespace facebook::velox; + +namespace facebook::presto::operators { +namespace { + +class ShuffleWriteOperator : public Operator { + public: + ShuffleWriteOperator( + int32_t operatorId, + DriverCtx* FOLLY_NONNULL ctx, + const std::shared_ptr& planNode) + : Operator( + ctx, + planNode->outputType(), + operatorId, + planNode->id(), + "ShuffleWrite"), + shuffle_{planNode->shuffle()} {} + + bool needsInput() const override { + return !noMoreInput_; + } + + void addInput(RowVectorPtr input) override { + auto partitions = input->childAt(0)->as>(); + auto serializedRows = input->childAt(1)->as>(); + for (auto i = 0; i < input->size(); ++i) { + auto partition = partitions->valueAt(i); + auto data = serializedRows->valueAt(i); + + shuffle_->collect(partition, std::string_view(data.data(), data.size())); + } + } + + void noMoreInput() override { + Operator::noMoreInput(); + shuffle_->noMoreData(true); + } + + RowVectorPtr getOutput() override { + return nullptr; + } + + BlockingReason isBlocked(ContinueFuture* future) override { + return BlockingReason::kNotBlocked; + } + + bool isFinished() override { + return noMoreInput_; + } + + private: + ShuffleInterface* shuffle_; +}; +} // namespace + +std::unique_ptr ShuffleWriteTranslator::toOperator( + DriverCtx* ctx, + int32_t id, + const core::PlanNodePtr& node) { + if (auto shuffleWriteNode = + std::dynamic_pointer_cast(node)) { + return std::make_unique(id, ctx, shuffleWriteNode); + } + return nullptr; +} +} // namespace facebook::presto::operators \ No newline at end of file diff --git a/presto-native-execution/presto_cpp/main/operators/ShuffleWrite.h b/presto-native-execution/presto_cpp/main/operators/ShuffleWrite.h new file mode 100644 index 0000000000000..d237f4ce55c78 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/operators/ShuffleWrite.h @@ -0,0 +1,66 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "presto_cpp/main/operators/ShuffleInterface.h" +#include "velox/core/PlanNode.h" +#include "velox/exec/Operator.h" + +namespace facebook::presto::operators { + +class ShuffleWriteNode : public velox::core::PlanNode { + public: + ShuffleWriteNode( + const velox::core::PlanNodeId& id, + ShuffleInterface* shuffle, + velox::core::PlanNodePtr source) + : velox::core::PlanNode(id), + shuffle_{shuffle}, + sources_{std::move(source)} {} + + const velox::RowTypePtr& outputType() const override { + return sources_[0]->outputType(); + } + + const std::vector& sources() const override { + return sources_; + } + + ShuffleInterface* shuffle() const { + return shuffle_; + } + + std::string_view name() const override { + return "ShuffleWrite"; + } + + private: + void addDetails(std::stringstream& stream) const override {} + + ShuffleInterface* shuffle_; + + const std::vector sources_; +}; + +class ShuffleWriteTranslator + : public velox::exec::Operator::PlanNodeTranslator { + public: + std::unique_ptr toOperator( + velox::exec::DriverCtx* ctx, + int32_t id, + const velox::core::PlanNodePtr& node) override; +}; +} // namespace facebook::presto::operators \ No newline at end of file diff --git a/presto-native-execution/presto_cpp/main/operators/UnsafeRowExchangeSource.cpp b/presto-native-execution/presto_cpp/main/operators/UnsafeRowExchangeSource.cpp new file mode 100644 index 0000000000000..68ed705a5851f --- /dev/null +++ b/presto-native-execution/presto_cpp/main/operators/UnsafeRowExchangeSource.cpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "presto_cpp/main/operators/UnsafeRowExchangeSource.h" + +namespace facebook::presto::operators { + +void UnsafeRowExchangeSource::request() { + std::lock_guard l(queue_->mutex()); + + if (!shuffle_->hasNext(destination_)) { + atEnd_ = true; + queue_->enqueue(nullptr); + return; + } + + auto buffer = shuffle_->next(destination_, true); + + auto ioBuf = folly::IOBuf::wrapBuffer(buffer->as(), buffer->size()); + queue_->enqueue(std::make_unique( + std::move(ioBuf), pool_, [buffer](auto&) { buffer->release(); })); +} +}; // namespace facebook::presto::operators \ No newline at end of file diff --git a/presto-native-execution/presto_cpp/main/operators/UnsafeRowExchangeSource.h b/presto-native-execution/presto_cpp/main/operators/UnsafeRowExchangeSource.h new file mode 100644 index 0000000000000..6e04a0aaee7f1 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/operators/UnsafeRowExchangeSource.h @@ -0,0 +1,46 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "presto_cpp/main/operators/ShuffleWrite.h" +#include "velox/core/PlanNode.h" +#include "velox/exec/Exchange.h" +#include "velox/exec/Operator.h" + +namespace facebook::presto::operators { + +class UnsafeRowExchangeSource : public velox::exec::ExchangeSource { + public: + UnsafeRowExchangeSource( + const std::string& taskId, + int destination, + std::shared_ptr queue, + ShuffleInterface* shuffle, + velox::memory::MemoryPool* pool) + : ExchangeSource(taskId, destination, queue, pool), shuffle_(shuffle) {} + + bool shouldRequestLocked() override { + return !atEnd_; + } + + void request() override; + + void close() override {} + + private: + ShuffleInterface* shuffle_; +}; +} // namespace facebook::presto::operators \ No newline at end of file diff --git a/presto-native-execution/presto_cpp/main/operators/tests/CMakeLists.txt b/presto-native-execution/presto_cpp/main/operators/tests/CMakeLists.txt new file mode 100644 index 0000000000000..6dfaf6a724bf0 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/operators/tests/CMakeLists.txt @@ -0,0 +1,29 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +add_executable( + presto_operators_test + UnsaferowShuffleTest.cpp) + +add_test(presto_operators_test presto_operators_test) + +target_link_libraries( + presto_operators_test + presto_operators + velox_exec_test_lib + velox_vector_test_lib + velox_type + velox_vector + velox_exec + velox_memory + velox_exec + gtest + gtest_main) diff --git a/presto-native-execution/presto_cpp/main/operators/tests/UnsaferowShuffleTest.cpp b/presto-native-execution/presto_cpp/main/operators/tests/UnsaferowShuffleTest.cpp new file mode 100644 index 0000000000000..7f3582a98b431 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/operators/tests/UnsaferowShuffleTest.cpp @@ -0,0 +1,405 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "folly/init/Init.h" +#include "presto_cpp/main/operators/PartitionAndSerialize.h" +#include "presto_cpp/main/operators/ShuffleWrite.h" +#include "presto_cpp/main/operators/UnsafeRowExchangeSource.h" +#include "velox/exec/Exchange.h" +#include "velox/exec/tests/utils/OperatorTestBase.h" +#include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/expression/VectorFunction.h" +#include "velox/serializers/UnsafeRowSerializer.h" + +using namespace facebook::velox; +using namespace facebook::presto; +using namespace facebook::presto::operators; + +namespace facebook::presto::operators::test { + +namespace { + +class TestShuffle : public ShuffleInterface { + public: + TestShuffle( + memory::MemoryPool* pool, + uint32_t numPartitions, + uint32_t maxBytesPerPartition) + : pool_{pool}, + numPartitions_{numPartitions}, + maxBytesPerPartition_{maxBytesPerPartition}, + inProgressSizes_(numPartitions, 0) { + inProgressPartitions_.resize(numPartitions_); + readyPartitions_.resize(numPartitions_); + } + + void collect(int32_t partition, std::string_view data) { + auto& buffer = inProgressPartitions_[partition]; + + // Check if there is enough space in the buffer. + if (buffer && + inProgressSizes_[partition] + data.size() + sizeof(size_t) >= + maxBytesPerPartition_) { + buffer->setSize(inProgressSizes_[partition]); + readyPartitions_[partition].emplace_back(std::move(buffer)); + inProgressPartitions_[partition].reset(); + } + + // Allocate buffer if needed. + if (!buffer) { + buffer = AlignedBuffer::allocate(maxBytesPerPartition_, pool_); + inProgressSizes_[partition] = 0; + } + + // Copy data. + auto rawBuffer = buffer->asMutable(); + auto offset = inProgressSizes_[partition]; + + *(size_t*)(rawBuffer + offset) = data.size(); + + offset += sizeof(size_t); + memcpy(rawBuffer + offset, data.data(), data.size()); + + inProgressSizes_[partition] += sizeof(size_t) + data.size(); + } + + void noMoreData(bool success) { + VELOX_CHECK(success, "Unexpected error") + for (auto i = 0; i < numPartitions_; ++i) { + if (inProgressSizes_[i] > 0) { + auto& buffer = inProgressPartitions_[i]; + buffer->setSize(inProgressSizes_[i]); + readyPartitions_[i].emplace_back(std::move(buffer)); + inProgressPartitions_[i].reset(); + } + } + } + + bool hasNext(int32_t partition) const { + return !readyPartitions_[partition].empty(); + } + + BufferPtr next(int32_t partition, bool success) { + VELOX_CHECK(success, "Unexpected error") + VELOX_CHECK(!readyPartitions_[partition].empty()); + + auto buffer = readyPartitions_[partition].back(); + readyPartitions_[partition].pop_back(); + return buffer; + } + + private: + memory::MemoryPool* pool_; + const uint32_t numPartitions_; + const uint32_t maxBytesPerPartition_; + std::vector inProgressPartitions_; + std::vector inProgressSizes_; + std::vector> readyPartitions_; +}; + +void registerExchangeSource(ShuffleInterface* shuffle) { + exec::ExchangeSource::registerFactory( + [shuffle]( + const std::string& taskId, + int destination, + std::shared_ptr queue, + memory::MemoryPool* FOLLY_NONNULL pool) + -> std::unique_ptr { + if (strncmp(taskId.c_str(), "spark://", 8) == 0) { + return std::make_unique( + taskId, destination, std::move(queue), shuffle, pool); + } + return nullptr; + }); +} + +auto addPartitionAndSerializeNode(int numPartitions) { + return [numPartitions]( + core::PlanNodeId nodeId, + core::PlanNodePtr source) -> core::PlanNodePtr { + auto outputType = ROW({"p", "d"}, {INTEGER(), VARBINARY()}); + + std::vector keys; + keys.push_back( + std::make_shared(INTEGER(), "c0")); + + return std::make_shared( + nodeId, keys, numPartitions, outputType, std::move(source)); + }; +} + +auto addShuffleWriteNode(ShuffleInterface* shuffle) { + return [shuffle]( + core::PlanNodeId nodeId, + core::PlanNodePtr source) -> core::PlanNodePtr { + return std::make_shared( + nodeId, shuffle, std::move(source)); + }; +} +} // namespace + +class UnsafeRowShuffleTest : public exec::test::OperatorTestBase { + protected: + + void registerVectorSerde() override { + serializer::spark::UnsafeRowVectorSerde::registerVectorSerde(); + } + + static std::string makeTaskId(const std::string& prefix, int num) { + return fmt::format("spark://{}-{}", prefix, num); + } + + std::shared_ptr makeTask( + const std::string& taskId, + core::PlanNodePtr planNode, + int destination) { + auto queryCtx = + core::QueryCtx::createForTest(std::make_shared()); + core::PlanFragment planFragment{planNode}; + return std::make_shared( + taskId, std::move(planFragment), destination, std::move(queryCtx)); + } + + void addRemoteSplits( + exec::Task* task, + const std::vector& remoteTaskIds) { + for (auto& taskId : remoteTaskIds) { + auto split = + exec::Split(std::make_shared(taskId), -1); + task->addSplit("0", std::move(split)); + } + task->noMoreSplits("0"); + } + + RowVectorPtr deserialize( + const RowVectorPtr& serializedResult, + const RowTypePtr& rowType) { + auto serializedData = + serializedResult->childAt(1)->as>(); + + // Serialize data into a single block. + + // Calculate total size. + size_t totalSize = 0; + for (auto i = 0; i < serializedData->size(); ++i) { + totalSize += serializedData->valueAt(i).size(); + } + + // Allocate the block. Add an extra sizeof(size_t) bytes for each row to + // hold row size. + BufferPtr buffer = AlignedBuffer::allocate( + totalSize + sizeof(size_t) * serializedData->size(), pool()); + auto rawBuffer = buffer->asMutable(); + + // Copy data. + size_t offset = 0; + for (auto i = 0; i < serializedData->size(); ++i) { + auto value = serializedData->valueAt(i); + + *(size_t*)(rawBuffer + offset) = value.size(); + offset += sizeof(size_t); + + memcpy(rawBuffer + offset, value.data(), value.size()); + offset += value.size(); + } + + // Deserialize the block. + return deserialize(buffer, rowType); + } + + RowVectorPtr deserialize(BufferPtr& serialized, const RowTypePtr& rowType) { + auto serializer = + std::make_unique(); + + ByteRange byteRange = { + serialized->asMutable(), (int32_t)serialized->size(), 0}; + + auto input = std::make_unique(); + input->resetInput({byteRange}); + + RowVectorPtr result; + serializer->deserialize(input.get(), pool(), rowType, &result, nullptr); + return result; + } +}; + +TEST_F(UnsafeRowShuffleTest, operators) { + exec::Operator::registerOperator( + std::make_unique()); + exec::Operator::registerOperator(std::make_unique()); + + TestShuffle shuffle(pool(), 4, 1 << 20 /* 1MB */); + + auto data = makeRowVector({ + makeFlatVector({1, 2, 3, 4}), + makeFlatVector({10, 20, 30, 40}), + }); + + auto plan = exec::test::PlanBuilder() + .values({data}, true) + .addNode(addPartitionAndSerializeNode(4)) + .localPartition({}) + .addNode(addShuffleWriteNode(&shuffle)) + .planNode(); + + exec::test::CursorParameters params; + params.planNode = plan; + params.maxDrivers = 2; + + auto [taskCursor, serializedResults] = + readCursor(params, [](auto /*task*/) {}); + ASSERT_EQ(serializedResults.size(), 0); +} + +TEST_F(UnsafeRowShuffleTest, endToEnd) { + exec::Operator::registerOperator( + std::make_unique()); + exec::Operator::registerOperator(std::make_unique()); + + size_t numPartitions = 5; + TestShuffle shuffle(pool(), numPartitions, 1 << 20 /* 1MB */); + + registerExchangeSource(&shuffle); + + // Create and run single leaf task to partition data and write it to shuffle. + auto data = makeRowVector({ + makeFlatVector({1, 2, 3, 4, 5, 6}), + makeFlatVector({10, 20, 30, 40, 50, 60}), + }); + + auto dataType = asRowType(data->type()); + + auto leafPlan = exec::test::PlanBuilder() + .values({data}, true) + .addNode(addPartitionAndSerializeNode(numPartitions)) + .localPartition({}) + .addNode(addShuffleWriteNode(&shuffle)) + .planNode(); + + auto leafTaskId = makeTaskId("leaf", 0); + auto leafTask = makeTask(leafTaskId, leafPlan, 0); + exec::Task::start(leafTask, 2); + ASSERT_TRUE(exec::test::waitForTaskCompletion(leafTask.get())); + + // Create and run multiple downstream tasks, one per partition, to read data + // from shuffle. + for (auto i = 0; i < numPartitions; ++i) { + auto plan = exec::test::PlanBuilder() + .exchange(dataType) + .project(dataType->names()) + .planNode(); + + exec::test::CursorParameters params; + params.planNode = plan; + params.destination = i; + + bool noMoreSplits = false; + auto [taskCursor, serializedResults] = readCursor(params, [&](auto* task) { + if (noMoreSplits) { + return; + } + addRemoteSplits(task, {leafTaskId}); + noMoreSplits = true; + }); + + ASSERT_FALSE(shuffle.hasNext(i)) << i; + } +} + +TEST_F(UnsafeRowShuffleTest, partitionAndSerializeOperator) { + auto data = makeRowVector({ + makeFlatVector(1'000, [](auto row) { return row; }), + makeFlatVector(1'000, [](auto row) { return row * 10; }), + }); + + auto plan = + exec::test::PlanBuilder() + .values({data}, true) + .addNode(addPartitionAndSerializeNode(4)) + .planNode(); + + exec::test::CursorParameters params; + params.planNode = plan; + params.maxDrivers = 2; + + auto [taskCursor, serializedResults] = + readCursor(params, [](auto /*task*/) {}); + ASSERT_EQ(serializedResults.size(), 2); + + for (auto& serializedResult : serializedResults) { + // Verify that serialized data can be deserialized successfully into the + // original data. + auto deserialized = deserialize(serializedResult, asRowType(data->type())); + velox::test::assertEqualVectors(data, deserialized); + } +} + +TEST_F(UnsafeRowShuffleTest, ShuffleWriterToString) { + auto data = makeRowVector({ + makeFlatVector(1'000, [](auto row) { return row; }), + makeFlatVector(1'000, [](auto row) { return row * 10; }), + }); + + auto plan = exec::test::PlanBuilder() + .values({data}, true) + .addNode(addPartitionAndSerializeNode(4)) + .localPartition({}) + .addNode(addShuffleWriteNode(nullptr)) + .planNode(); + + ASSERT_EQ( + plan->toString(true, false), + "-- ShuffleWrite[] -> p:INTEGER, d:VARBINARY\n"); + ASSERT_EQ( + plan->toString(true, true), + "-- ShuffleWrite[] -> p:INTEGER, d:VARBINARY\n""" + " -- LocalPartition[GATHER] -> p:INTEGER, d:VARBINARY\n" + " -- PartitionAndSerialize[(c0) 4] -> p:INTEGER, d:VARBINARY\n" + " -- Values[1000 rows in 1 vectors] -> c0:INTEGER, c1:BIGINT\n"); + ASSERT_EQ( + plan->toString(false, false), + "-- ShuffleWrite\n"); +} + +TEST_F(UnsafeRowShuffleTest, partitionAndSerializeToString) { + auto data = makeRowVector({ + makeFlatVector(1'000, [](auto row) { return row; }), + makeFlatVector(1'000, [](auto row) { return row * 10; }), + }); + + auto plan = + exec::test::PlanBuilder() + .values({data}, true) + .addNode(addPartitionAndSerializeNode(4)) + .planNode(); + + ASSERT_EQ( + plan->toString(true, false), + "-- PartitionAndSerialize[(c0) 4] -> p:INTEGER, d:VARBINARY\n"); + ASSERT_EQ( + plan->toString(true, true), + "-- PartitionAndSerialize[(c0) 4] -> p:INTEGER, d:VARBINARY\n" + " -- Values[1000 rows in 1 vectors] -> c0:INTEGER, c1:BIGINT\n"); + ASSERT_EQ( + plan->toString(false, false), + "-- PartitionAndSerialize\n"); +} +} // namespace facebook::presto::operators::test + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + folly::init(&argc, &argv, false); + return RUN_ALL_TESTS(); +}