Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions presto-native-execution/presto_cpp/main/TaskResource.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ proxygen::RequestHandler* TaskResource::createOrUpdateTaskImpl(
protocol::TaskId taskId = pathMatch[1];
bool summarize = message->hasQueryParam("summarize");

auto& headers = message->getHeaders();
const auto& headers = message->getHeaders();
const auto& acceptHeader = headers.getSingleOrEmpty(proxygen::HTTP_HEADER_ACCEPT);
const auto sendThrift =
acceptHeader.find(http::kMimeTypeApplicationThrift) != std::string::npos;
Expand Down Expand Up @@ -392,7 +392,7 @@ proxygen::RequestHandler* TaskResource::deleteTask(
message->getQueryParam(protocol::PRESTO_ABORT_TASK_URL_PARAM) == "true";
}
bool summarize = message->hasQueryParam("summarize");
auto& headers = message->getHeaders();
const auto& headers = message->getHeaders();
const auto& acceptHeader = headers.getSingleOrEmpty(proxygen::HTTP_HEADER_ACCEPT);
const auto sendThrift =
acceptHeader.find(http::kMimeTypeApplicationThrift) != std::string::npos;
Expand Down Expand Up @@ -544,7 +544,7 @@ proxygen::RequestHandler* TaskResource::getTaskStatus(
auto currentState = getCurrentState(message);
auto maxWait = getMaxWait(message);

auto& headers = message->getHeaders();
const auto& headers = message->getHeaders();
const auto& acceptHeader = headers.getSingleOrEmpty(proxygen::HTTP_HEADER_ACCEPT);
const auto sendThrift =
acceptHeader.find(http::kMimeTypeApplicationThrift) != std::string::npos;
Expand Down Expand Up @@ -615,7 +615,7 @@ proxygen::RequestHandler* TaskResource::getTaskInfo(
auto maxWait = getMaxWait(message);
bool summarize = message->hasQueryParam("summarize");

auto& headers = message->getHeaders();
const auto& headers = message->getHeaders();
const auto& acceptHeader = headers.getSingleOrEmpty(proxygen::HTTP_HEADER_ACCEPT);
const auto sendThrift =
acceptHeader.find(http::kMimeTypeApplicationThrift) != std::string::npos;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include <gtest/gtest.h>
#include "presto_cpp/main/thrift/ProtocolToThrift.h"
#include "presto_cpp/main/thrift/ThriftIO.h"
#include "presto_cpp/main/common/tests/test_json.h"
#include "presto_cpp/main/connectors/PrestoToVeloxConnector.h"

Expand Down Expand Up @@ -100,7 +101,7 @@ TEST_F(TaskUpdateRequestTest, mapOutputBuffers) {
ASSERT_EQ(outputBuffers.buffers["2"], 20);
}

TEST_F(TaskUpdateRequestTest, binarySplitFromThrift) {
TEST_F(TaskUpdateRequestTest, binaryHiveSplitFromThrift) {
thrift::Split thriftSplit;
thriftSplit.connectorId()->catalogName_ref() = "hive";
thriftSplit.transactionHandle()->jsonValue_ref() = R"({
Expand All @@ -127,14 +128,89 @@ TEST_F(TaskUpdateRequestTest, binarySplitFromThrift) {
protocol::NodeSelectionStrategy::NO_PREFERENCE);
}

TEST_F(TaskUpdateRequestTest, binaryTableWriteInfo) {
std::string str = slurp(getDataPath(BASE_DATA_PATH, "TableWriteInfo.json"));
protocol::TableWriteInfo tableWriteInfo;
TEST_F(TaskUpdateRequestTest, binaryRemoteSplitFromThrift) {
thrift::Split thriftSplit;
thrift::RemoteTransactionHandle thriftTransactionHandle;
thrift::RemoteSplit thriftRemoteSplit;

thriftSplit.connectorId()->catalogName_ref() = "$remote";
thriftSplit.transactionHandle()->customSerializedValue_ref() =
thriftWrite(thriftTransactionHandle);

thriftRemoteSplit.location()->location_ref() = "/test_location";
thriftRemoteSplit.remoteSourceTaskId()->id_ref() = 100;
thriftRemoteSplit.remoteSourceTaskId()->attemptNumber_ref() = 200;
thriftRemoteSplit.remoteSourceTaskId()->stageExecutionId()->id_ref() = 300;
thriftRemoteSplit.remoteSourceTaskId()->stageExecutionId()->stageId()->id_ref() = 400;
thriftRemoteSplit.remoteSourceTaskId()->stageExecutionId()->stageId()->queryId_ref() = "test_query_id";

thriftSplit.connectorSplit()->connectorId_ref() = "$remote";
thriftSplit.connectorSplit()->customSerializedValue_ref() =
thriftWrite(thriftRemoteSplit);

protocol::Split split;
thrift::fromThrift(thriftSplit, split);

// Verify that connector specific fields are set correctly with thrift codec
auto remoteSplit = std::dynamic_pointer_cast<protocol::RemoteSplit>(
split.connectorSplit);
ASSERT_EQ((remoteSplit->location).location, "/test_location");
ASSERT_EQ(remoteSplit->remoteSourceTaskId, "test_query_id.400.300.100.200");
}

TEST_F(TaskUpdateRequestTest, unionExecutionWriterTargetFromThrift) {
// Construct ExecutionWriterTarget with CreateHandle
thrift::CreateHandle thriftCreateHandle;
thrift::ExecutionWriterTargetUnion thriftWriterTarget;
thriftCreateHandle.schemaTableName()->schema_ref() = "test_schema";
thriftCreateHandle.schemaTableName()->table_ref() = "test_table";
thriftCreateHandle.handle()->connectorId()->catalogName_ref() = "hive";
thriftCreateHandle.handle()->transactionHandle()->jsonValue_ref() = R"({
"@type": "hive",
"uuid": "8a4d6c83-60ee-46de-9715-bc91755619fa"
})";
thriftCreateHandle.handle()->connectorHandle()->jsonValue_ref() = slurp(getDataPath(BASE_DATA_PATH, "HiveOutputTableHandle.json"));;
thriftWriterTarget.set_createHandle(std::move(thriftCreateHandle));

// Convert from thrift to protocol and verify fields
auto writerTarget = std::make_shared<protocol::ExecutionWriterTarget>();
thrift::fromThrift(thriftWriterTarget, writerTarget);

ASSERT_EQ(writerTarget->_type, "CreateHandle");
auto createHandle = std::dynamic_pointer_cast<protocol::CreateHandle>(writerTarget);
ASSERT_NE(createHandle, nullptr);
ASSERT_EQ(createHandle->schemaTableName.schema, "test_schema");
ASSERT_EQ(createHandle->schemaTableName.table, "test_table");

auto* hiveTxnHandle = dynamic_cast<protocol::hive::HiveTransactionHandle*>(createHandle->handle.transactionHandle.get());
ASSERT_NE(hiveTxnHandle, nullptr);
ASSERT_EQ(hiveTxnHandle->uuid, "8a4d6c83-60ee-46de-9715-bc91755619fa");

auto* hiveOutputTableHandle = dynamic_cast<protocol::hive::HiveOutputTableHandle*>(createHandle->handle.connectorHandle.get());
ASSERT_NE(hiveOutputTableHandle, nullptr);
ASSERT_EQ(hiveOutputTableHandle->schemaName, "test_schema");
ASSERT_EQ(hiveOutputTableHandle->tableName, "test_table");
ASSERT_EQ(hiveOutputTableHandle->tableStorageFormat, protocol::hive::HiveStorageFormat::ORC);
ASSERT_EQ(hiveOutputTableHandle->locationHandle.targetPath, "/path/to/target");
}

TEST_F(TaskUpdateRequestTest, unionExecutionWriterTargetToThrift) {
// Construct thrift ExecutionWriterTarget with CreateHandle
auto createHandle = std::make_shared<protocol::CreateHandle>();
createHandle->schemaTableName.schema = "test_schema";
createHandle->schemaTableName.table = "test_table";

auto writerTarget = std::make_shared<protocol::ExecutionWriterTarget>();
writerTarget->_type = "CreateHandle";
writerTarget = createHandle;

thrift::fromThrift(str, tableWriteInfo);
auto hiveTableHandle = std::dynamic_pointer_cast<protocol::hive::HiveTableHandle>((*tableWriteInfo.analyzeTableHandle).connectorHandle);
ASSERT_EQ(hiveTableHandle->tableName, "test_table");
ASSERT_EQ(hiveTableHandle->analyzePartitionValues->size(), 2);
// Convert to thrift and verify fields. Note that toThrift functions for connector fields are not implemented.
thrift::ExecutionWriterTargetUnion thriftWriterTarget;
thrift::toThrift(writerTarget, thriftWriterTarget);
ASSERT_TRUE(thriftWriterTarget.createHandle_ref().has_value());
const auto& thriftCreateHandle = thriftWriterTarget.createHandle_ref().value();
ASSERT_EQ(thriftCreateHandle.schemaTableName()->schema_ref().value(), "test_schema");
ASSERT_EQ(thriftCreateHandle.schemaTableName()->table_ref().value(), "test_table");
}

TEST_F(TaskUpdateRequestTest, fragment) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
{
"@type": "hive",
"schemaName": "test_schema",
"tableName": "test_table",
"inputColumns": [],
"pageSinkMetadata": {
"schemaTableName": {
"schema": "test_schema",
"table": "test_table"
},
"modifiedPartitions": {}
},
"locationHandle": {
"targetPath": "/path/to/target",
"writePath": "/path/to/write",
"tableType": "NEW",
"writeMode": "STAGE_AND_MOVE_TO_TARGET_DIRECTORY"
},
"tableStorageFormat": "ORC",
"partitionStorageFormat": "ORC",
"actualStorageFormat": "ORC",
"compressionCodec": "NONE",
"partitionedBy": [],
"preferredOrderingColumns": [],
"tableOwner": "owner_name",
"additionalTableParameters": {}
}

This file was deleted.

2 changes: 1 addition & 1 deletion presto-native-execution/presto_cpp/main/thrift/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

.PHONY: presto_protocol-to-thrift-json.json presto_thrift.json
all: ProtocolToThrift.h ProtocolToThrift.cpp

ProtocolToThrift.h: ProtocolToThrift-hpp.mustache presto_protocol-to-thrift-json.json
Expand All @@ -27,4 +28,3 @@ presto_protocol-to-thrift-json.json: presto_protocol-to-thrift-json.py presto_pr

presto_thrift.json: presto_thrift.thrift ./thrift2json.py
./thrift2json.py presto_thrift.thrift | jq . > presto_thrift.json

Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
{{/.}}

#include "presto_cpp/main/thrift/ProtocolToThrift.h"
#include "presto_cpp/main/thrift/ThriftIO.h"
#include "presto_cpp/presto_protocol/core/ConnectorProtocol.h"

namespace facebook::presto::thrift {
Expand Down Expand Up @@ -190,13 +191,69 @@ void fromThrift(const std::map<K1, V1>& thriftMap, std::map<K2, V2>& protoMap) {
{{&cinc}}
{{/cinc}}
{{^cinc}}
{{#connector}}
void toThrift(const std::shared_ptr<facebook::presto::protocol::{{class_name}}>& proto, {{class_name}}& thrift) {
}
void fromThrift(const {{class_name}}& thrift, std::shared_ptr<facebook::presto::protocol::{{class_name}}>& proto) {
if (thrift.connectorId().has_value() && thrift.customSerializedValue().has_value()) {
facebook::presto::protocol::getConnectorProtocol(thrift.connectorId().value())
.deserialize(thrift.customSerializedValue().value(), proto);
} else if (thrift.jsonValue().has_value()) {
json j = json::parse(thrift.jsonValue().value());
from_json(j, proto);
}
}

{{/connector}}
{{#union}}
void toThrift(
const std::shared_ptr<facebook::presto::protocol::{{proto_name}}>& proto,
apache::thrift::optional_field_ref<{{class_name}}&> thrift) {
if (proto) {
thrift.ensure();
toThrift(proto, *thrift);
}
}
void toThrift(
const std::shared_ptr<facebook::presto::protocol::{{proto_name}}>& proto,
{{class_name}}& thrift) {
{{#fields}}
if (auto {{field_name}} =
std::dynamic_pointer_cast<facebook::presto::protocol::{{field_type}}>(proto)) {
{{field_type}} thrift{{field_type}};
toThrift(*{{field_name}}, thrift{{field_type}});
thrift.set_{{field_name}}(std::move(thrift{{field_type}}));
}
{{/fields}}
}
void fromThrift(
apache::thrift::optional_field_ref<const {{class_name}}&> thrift,
std::shared_ptr<facebook::presto::protocol::{{proto_name}}>& proto) {
if (thrift.has_value()) {
proto = std::make_shared<facebook::presto::protocol::{{proto_name}}>();
fromThrift(thrift.value(), proto);
}
}
void fromThrift(
const {{class_name}}& thrift,
std::shared_ptr<facebook::presto::protocol::{{proto_name}}>& proto) {
{{#fields}}
if (thrift.getType() == {{class_name}}::Type::{{field_name}}) {
std::shared_ptr<facebook::presto::protocol::{{#proto_field_type}}{{proto_field_type}}{{/proto_field_type}}{{^proto_field_type}}{{field_type}}{{/proto_field_type}}> {{field_name}};
fromThrift(thrift.get_{{field_name}}(), {{field_name}});
proto = {{field_name}};
}
{{/fields}}
}

{{/union}}
{{#struct}}
void toThrift(const facebook::presto::protocol::{{class_name}}& proto, {{&class_name}}& thrift) {
{{#fields}}
toThrift(proto.{{proto_name}}, {{^optional}}*{{/optional}}thrift.{{field_name}}_ref());
{{/fields}}
}
void fromThrift(const {{&class_name}}& thrift, facebook::presto::protocol::{{class_name}}& proto) {
void fromThrift(const {{&class_name}}& thrift, facebook::presto::protocol::{{#proto_name}}{{proto_name}}{{/proto_name}}{{^proto_name}}{{class_name}}{{/proto_name}}& proto) {
{{#fields}}
fromThrift({{^optional}}*{{/optional}}thrift.{{field_name}}_ref(), proto.{{proto_name}});
{{/fields}}
Expand All @@ -216,11 +273,11 @@ void fromThrift(const {{class_name}}& thrift, facebook::presto::protocol::{{clas
}
{{/wrapper}}
{{#enum}}
void toThrift(const facebook::presto::protocol::{{class_name}}& proto, {{class_name}}& thrift) {
void toThrift(const facebook::presto::protocol::{{proto_name}}& proto, {{class_name}}& thrift) {
thrift = ({{class_name}})(static_cast<int>(proto));
}
void fromThrift(const {{class_name}}& thrift, facebook::presto::protocol::{{class_name}}& proto) {
proto = (facebook::presto::protocol::{{class_name}})(static_cast<int>(thrift));
void fromThrift(const {{class_name}}& thrift, facebook::presto::protocol::{{proto_name}}& proto) {
proto = (facebook::presto::protocol::{{proto_name}})(static_cast<int>(thrift));
}

{{/enum}}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,28 @@ void fromThrift(const double& thrift, facebook::presto::protocol::DataSize& data
{{&hinc}}
{{/hinc}}
{{^hinc}}
{{#connector}}
void toThrift(const std::shared_ptr<facebook::presto::protocol::{{class_name}}>& proto, {{class_name}}& thrift);
void fromThrift(const {{class_name}}& thrift, std::shared_ptr<facebook::presto::protocol::{{class_name}}>& proto);

{{/connector}}
{{#union}}
void toThrift(
const std::shared_ptr<facebook::presto::protocol::{{proto_name}}>& proto,
apache::thrift::optional_field_ref<{{class_name}}&> thrift);
void toThrift(
const std::shared_ptr<facebook::presto::protocol::{{proto_name}}>& proto,
{{class_name}}& thrift);
void fromThrift(
apache::thrift::optional_field_ref<const {{class_name}}&> thrift,
std::shared_ptr<facebook::presto::protocol::{{proto_name}}>& proto);
void fromThrift(
const {{class_name}}& thrift,
std::shared_ptr<facebook::presto::protocol::{{proto_name}}>& proto);
{{/union}}
{{#struct}}
void toThrift(const facebook::presto::protocol::{{class_name}}& proto, {{class_name}}& thrift);
void fromThrift(const {{&class_name}}& thrift, facebook::presto::protocol::{{class_name}}& proto);
void toThrift(const facebook::presto::protocol::{{class_name}}& proto, {{&class_name}}& thrift);
void fromThrift(const {{&class_name}}& thrift, facebook::presto::protocol::{{#proto_name}}{{proto_name}}{{/proto_name}}{{^proto_name}}{{class_name}}{{/proto_name}}& proto);

{{/struct}}
{{#wrapper}}
Expand Down
Loading
Loading