diff --git a/presto-native-execution/presto_cpp/main/tests/PrestoQueryRunner.cpp b/presto-native-execution/presto_cpp/main/tests/PrestoQueryRunner.cpp index 1ab1fa14cb080..bd1c27a4405c0 100644 --- a/presto-native-execution/presto_cpp/main/tests/PrestoQueryRunner.cpp +++ b/presto-native-execution/presto_cpp/main/tests/PrestoQueryRunner.cpp @@ -15,7 +15,6 @@ #include #include #include -#include "presto_cpp/main/types/ParseTypeSignature.h" #include "velox/common/base/Fs.h" #include "velox/common/encode/Base64.h" #include "velox/common/file/FileSystems.h" @@ -23,6 +22,12 @@ #include "velox/exec/tests/utils/QueryAssertions.h" #include "velox/serializers/PrestoSerializer.h" +// ANTLR defines an INVALID_INDEX macro, and DuckDB has a constant variable of +// the same name. So we have to include TypeParser.h after Velox. +// clang-format off +#include "presto_cpp/main/types/TypeParser.h" +// clang-format on + using namespace facebook::velox; namespace facebook::presto::test { @@ -122,9 +127,10 @@ class ServerResponse { std::vector names; std::vector types; + TypeParser parser; for (const auto& column : response_["columns"]) { names.push_back(column["name"].asString()); - types.push_back(parseTypeSignature(column["type"].asString())); + types.push_back(parser.parse(column["type"].asString())); } auto rowType = ROW(std::move(names), std::move(types)); diff --git a/presto-native-execution/presto_cpp/main/types/CMakeLists.txt b/presto-native-execution/presto_cpp/main/types/CMakeLists.txt index f2a70faee24e3..9f3f14e3bf592 100644 --- a/presto-native-execution/presto_cpp/main/types/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/types/CMakeLists.txt @@ -10,9 +10,8 @@ # See the License for the specific language governing permissions and # limitations under the License. add_library( - presto_type_converter OBJECT - TypeSignatureTypeConverter.cpp antlr/TypeSignatureLexer.cpp - antlr/TypeSignatureParser.cpp) + presto_type_converter OBJECT TypeParser.cpp antlr/TypeSignatureLexer.cpp + antlr/TypeSignatureParser.cpp) target_link_libraries(presto_type_converter velox_type) diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp index c9132601d5a07..551368be43b72 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp +++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp @@ -14,7 +14,6 @@ #include "presto_cpp/main/types/PrestoToVeloxExpr.h" #include -#include "presto_cpp/main/types/ParseTypeSignature.h" #include "presto_cpp/presto_protocol/Base64Util.h" #include "velox/common/base/Exceptions.h" #include "velox/functions/prestosql/types/JsonType.h" @@ -170,7 +169,8 @@ namespace { std::optional tryConvertCast( const protocol::Signature& signature, const std::string& returnType, - const std::vector& args) { + const std::vector& args, + const TypeParser* typeParser) { static const char* kCast = "presto.default.$operator$cast"; static const char* kTryCast = "presto.default.try_cast"; static const char* kJsonToArrayCast = @@ -197,7 +197,7 @@ std::optional tryConvertCast( signature.name.compare(kJsonToArrayCast) == 0 || signature.name.compare(kJsonToMapCast) == 0 || signature.name.compare(kJsonToRowCast) == 0) { - auto type = parseTypeSignature(returnType); + auto type = typeParser->parse(returnType); return std::make_shared( type, std::vector{std::make_shared( @@ -219,14 +219,15 @@ std::optional tryConvertCast( return args[0]; } - auto type = parseTypeSignature(returnType); + auto type = typeParser->parse(returnType); return std::make_shared(type, args, nullOnFailure); } std::optional tryConvertTry( const protocol::Signature& signature, const std::string& returnType, - const std::vector& args) { + const std::vector& args, + const TypeParser* typeParser) { static const char* kTry = "presto.default.$internal$try"; if (signature.kind != protocol::FunctionKind::SCALAR) { @@ -243,7 +244,7 @@ std::optional tryConvertTry( VELOX_CHECK(lambda); VELOX_CHECK_EQ(lambda->signature()->size(), 0); - auto type = parseTypeSignature(returnType); + auto type = typeParser->parse(returnType); std::vector newArgs = {lambda->body()}; return std::make_shared(type, newArgs, "try"); } @@ -252,7 +253,8 @@ std::optional tryConvertLiteralArray( const protocol::Signature& signature, const std::string& returnType, const std::vector& args, - velox::memory::MemoryPool* pool) { + velox::memory::MemoryPool* pool, + const TypeParser* typeParser) { static const char* kLiteralArray = "presto.default.$literal$array"; static const char* kFromBase64 = "presto.default.from_base64"; @@ -272,7 +274,7 @@ std::optional tryConvertLiteralArray( return std::nullopt; } - auto type = parseTypeSignature(returnType); + auto type = typeParser->parse(returnType); auto encoded = std::dynamic_pointer_cast(call->inputs()[0]); @@ -310,7 +312,7 @@ std::optional VeloxExprConverter::tryConvertDate( // a VARCHAR or TIMESTAMP (with an optional timezone) type. args.emplace_back(toVeloxExpr(pexpr.arguments[0])); - auto returnType = parseTypeSignature(pexpr.returnType); + auto returnType = typeParser_->parse(pexpr.returnType); return std::make_shared(returnType, args, false); } @@ -360,7 +362,7 @@ std::optional VeloxExprConverter::tryConvertLike( } // Construct the returnType and CallTypedExpr for 'like' - auto returnType = parseTypeSignature(pexpr.returnType); + auto returnType = typeParser_->parse(pexpr.returnType); return std::make_shared( returnType, args, getFunctionName(signature)); } @@ -384,23 +386,24 @@ TypedExprPtr VeloxExprConverter::toVeloxExpr( auto args = toVeloxExpr(pexpr.arguments); auto signature = builtin->signature; - auto cast = tryConvertCast(signature, pexpr.returnType, args); + auto cast = tryConvertCast(signature, pexpr.returnType, args, typeParser_); if (cast.has_value()) { return cast.value(); } - auto tryExpr = tryConvertTry(signature, pexpr.returnType, args); + auto tryExpr = + tryConvertTry(signature, pexpr.returnType, args, typeParser_); if (tryExpr.has_value()) { return tryExpr.value(); } - auto literal = - tryConvertLiteralArray(signature, pexpr.returnType, args, pool_); + auto literal = tryConvertLiteralArray( + signature, pexpr.returnType, args, pool_, typeParser_); if (literal.has_value()) { return literal.value(); } - auto returnType = parseTypeSignature(pexpr.returnType); + auto returnType = typeParser_->parse(pexpr.returnType); return std::make_shared( returnType, args, getFunctionName(signature)); @@ -409,7 +412,7 @@ TypedExprPtr VeloxExprConverter::toVeloxExpr( std::dynamic_pointer_cast( pexpr.functionHandle)) { auto args = toVeloxExpr(pexpr.arguments); - auto returnType = parseTypeSignature(pexpr.returnType); + auto returnType = typeParser_->parse(pexpr.returnType); return std::make_shared( returnType, args, getFunctionName(sqlFunctionHandle->functionId)); } @@ -419,7 +422,7 @@ TypedExprPtr VeloxExprConverter::toVeloxExpr( std::shared_ptr VeloxExprConverter::toVeloxExpr( std::shared_ptr pexpr) const { - const auto type = parseTypeSignature(pexpr->type); + const auto type = typeParser_->parse(pexpr->type); switch (type->kind()) { case TypeKind::ROW: FOLLY_FALLTHROUGH; @@ -676,7 +679,7 @@ TypedExprPtr VeloxExprConverter::toVeloxExpr( return convertInExpr(args, pool_); } - auto returnType = parseTypeSignature(pexpr->returnType); + auto returnType = typeParser_->parse(pexpr->returnType); if (pexpr->form == protocol::Form::SWITCH) { return convertSwitchExpr(returnType, std::move(args)); @@ -703,7 +706,7 @@ TypedExprPtr VeloxExprConverter::toVeloxExpr( std::shared_ptr VeloxExprConverter::toVeloxExpr( std::shared_ptr pexpr) const { return std::make_shared( - parseTypeSignature(pexpr->type), pexpr->name); + typeParser_->parse(pexpr->type), pexpr->name); } std::shared_ptr VeloxExprConverter::toVeloxExpr( @@ -711,7 +714,7 @@ std::shared_ptr VeloxExprConverter::toVeloxExpr( std::vector argumentTypes; argumentTypes.reserve(lambda->argumentTypes.size()); for (auto& typeName : lambda->argumentTypes) { - argumentTypes.emplace_back(parseTypeSignature(typeName)); + argumentTypes.emplace_back(typeParser_->parse(typeName)); } // TODO(spershin): In some cases we can visit this method with the same lambda @@ -728,7 +731,7 @@ std::shared_ptr VeloxExprConverter::toVeloxExpr( std::shared_ptr VeloxExprConverter::toVeloxExpr( const protocol::VariableReferenceExpression& pexpr) const { return std::make_shared( - parseTypeSignature(pexpr.type), pexpr.name); + typeParser_->parse(pexpr.type), pexpr.name); } TypedExprPtr VeloxExprConverter::toVeloxExpr( diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.h b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.h index d45b753e22703..719514c0c6744 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.h +++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.h @@ -14,14 +14,20 @@ #pragma once #include +// antlr-common.h undefines the EOF macro that external/json/nlohmann/json.hpp +// relies on, so include presto_protcol.h before TypeParser.h +// clang-format off #include "presto_cpp/presto_protocol/presto_protocol.h" +#include "presto_cpp/main/types/TypeParser.h" +// clang-format on #include "velox/core/Expressions.h" namespace facebook::presto { class VeloxExprConverter { public: - explicit VeloxExprConverter(velox::memory::MemoryPool* pool) : pool_(pool) {} + VeloxExprConverter(velox::memory::MemoryPool* pool, TypeParser* typeParser) + : pool_(pool), typeParser_(typeParser) {} std::shared_ptr toVeloxExpr( std::shared_ptr pexpr) const; @@ -60,7 +66,8 @@ class VeloxExprConverter { std::optional tryConvertDate( const protocol::CallExpression& pexpr) const; - velox::memory::MemoryPool* pool_; + velox::memory::MemoryPool* const pool_; + TypeParser* const typeParser_; }; } // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.cpp b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.cpp index d94a69f9ed5b2..53486a4be3ae5 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.cpp +++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.cpp @@ -27,12 +27,10 @@ #include "velox/expression/Expr.h" #include "velox/vector/ComplexVector.h" #include "velox/vector/FlatVector.h" -#include "presto_cpp/main/types/TypeSignatureTypeConverter.h" #include "presto_cpp/main/operators/BroadcastWrite.h" #include "presto_cpp/main/operators/PartitionAndSerialize.h" #include "presto_cpp/main/operators/ShuffleWrite.h" #include "presto_cpp/main/operators/ShuffleRead.h" -#include "presto_cpp/presto_protocol/presto_protocol.h" #include #include "velox/common/compression/Compression.h" // clang-format on @@ -47,16 +45,19 @@ namespace facebook::presto { namespace { -TypePtr stringToType(const std::string& typeString) { - return TypeSignatureTypeConverter::parse(typeString); +TypePtr stringToType( + const std::string& typeString, + const TypeParser& typeParser) { + return typeParser.parse(typeString); } std::vector stringToTypes( - const std::shared_ptr>& typeStrings) { + const std::shared_ptr>& typeStrings, + const TypeParser& typeParser) { std::vector types; types.reserve(typeStrings->size()); for (const auto& typeString : *typeStrings) { - types.push_back(stringToType(typeString)); + types.push_back(stringToType(typeString, typeParser)); } return types; } @@ -74,6 +75,7 @@ std::vector getNames(const protocol::Assignments& assignments) { RowTypePtr toRowType( const std::vector& variables, + const TypeParser& typeParser, const std::unordered_set& excludeNames = {}) { std::vector names; std::vector types; @@ -85,7 +87,7 @@ RowTypePtr toRowType( continue; } names.emplace_back(variable.name); - types.emplace_back(stringToType(variable.type)); + types.emplace_back(stringToType(variable.type, typeParser)); } return ROW(std::move(names), std::move(types)); @@ -122,7 +124,8 @@ std::vector toRequiredSubfields( } std::shared_ptr toColumnHandle( - const protocol::ColumnHandle* column) { + const protocol::ColumnHandle* column, + const TypeParser& typeParser) { velox::type::fbhive::HiveTypeParser hiveTypeParser; if (auto hiveColumn = dynamic_cast(column)) { @@ -131,7 +134,7 @@ std::shared_ptr toColumnHandle( return std::make_shared( hiveColumn->name, toHiveColumnType(hiveColumn->columnType), - stringToType(hiveColumn->typeSignature), + stringToType(hiveColumn->typeSignature, typeParser), hiveTypeParser.parse(hiveColumn->hiveType), toRequiredSubfields(hiveColumn->requiredSubfields)); } @@ -689,11 +692,12 @@ std::unique_ptr toFilter( std::unique_ptr toFilter( const protocol::Domain& domain, - const VeloxExprConverter& exprConverter) { + const VeloxExprConverter& exprConverter, + const TypeParser& typeParser) { auto nullAllowed = domain.nullAllowed; if (auto sortedRangeSet = std::dynamic_pointer_cast(domain.values)) { - auto type = stringToType(sortedRangeSet->type); + auto type = stringToType(sortedRangeSet->type, typeParser); auto ranges = sortedRangeSet->ranges; if (ranges.empty()) { @@ -802,13 +806,14 @@ std::unique_ptr toFilter( std::shared_ptr toConnectorTableHandle( const protocol::TableHandle& tableHandle, const VeloxExprConverter& exprConverter, + const TypeParser& typeParser, std::unordered_map>& partitionColumns) { if (auto hiveLayout = std::dynamic_pointer_cast( tableHandle.connectorTableLayout)) { for (const auto& entry : hiveLayout->partitionColumns) { - partitionColumns.emplace(entry.name, toColumnHandle(&entry)); + partitionColumns.emplace(entry.name, toColumnHandle(&entry, typeParser)); } connector::hive::SubfieldFilters subfieldFilters; @@ -816,7 +821,7 @@ std::shared_ptr toConnectorTableHandle( for (const auto& domain : *domains) { auto filter = domain.second; subfieldFilters[common::Subfield(domain.first)] = - toFilter(domain.second, exprConverter); + toFilter(domain.second, exprConverter, typeParser); } auto remainingFilter = @@ -1122,6 +1127,7 @@ core::LocalPartitionNode::Type toLocalExchangeType( std::vector> toHiveColumns( const protocol::List& inputColumns, + TypeParser& typeParser, bool& hasPartitionColumn) { hasPartitionColumn = false; std::vector> @@ -1132,7 +1138,7 @@ toHiveColumns( columnHandle.columnType == protocol::ColumnType::PARTITION_KEY; hiveColumns.emplace_back( std::dynamic_pointer_cast( - toColumnHandle(&columnHandle))); + toColumnHandle(&columnHandle, typeParser))); } return hiveColumns; } @@ -1180,7 +1186,8 @@ std::vector> toHiveSortingColumns( std::shared_ptr toHiveBucketProperty( const std::vector>& inputColumns, - const std::shared_ptr& bucketProperty) { + const std::shared_ptr& bucketProperty, + const TypeParser& typeParser) { if (bucketProperty == nullptr) { return nullptr; } @@ -1222,7 +1229,7 @@ std::shared_ptr toHiveBucketProperty( bucketProperty->bucketedBy.size(), "Bucketed types is not set properly for presto native bucket function: {}", toJsonString(*bucketProperty)); - bucketedTypes = stringToTypes(bucketProperty->types); + bucketedTypes = stringToTypes(bucketProperty->types, typeParser); } const auto sortedBy = toHiveSortingColumns(bucketProperty->sortedBy); @@ -1326,7 +1333,8 @@ core::PlanNodePtr VeloxQueryPlanConverterBase::toVeloxQueryPlan( const auto type = toLocalExchangeType(node->type); - const auto outputType = toRowType(node->partitioningScheme.outputLayout); + const auto outputType = + toRowType(node->partitioningScheme.outputLayout, typeParser_); // Different source nodes may have different output layouts. // Add ProjectNode on top of each source node to re-arrange the output columns @@ -1336,7 +1344,7 @@ core::PlanNodePtr VeloxQueryPlanConverterBase::toVeloxQueryPlan( std::vector projections; projections.reserve(outputType->size()); - const auto desiredSourceOutput = toRowType(node->inputs[i]); + const auto desiredSourceOutput = toRowType(node->inputs[i], typeParser_); for (auto j = 0; j < outputType->size(); j++) { projections.emplace_back(std::make_shared( @@ -1766,7 +1774,8 @@ void VeloxQueryPlanConverterBase::toAggregations( const auto& signature = builtin->signature; aggregate.rawInputTypes.reserve(signature.argumentTypes.size()); for (const auto& argumentType : signature.argumentTypes) { - aggregate.rawInputTypes.push_back(stringToType(argumentType)); + aggregate.rawInputTypes.push_back( + stringToType(argumentType, typeParser_)); } } else if ( auto sqlFunction = @@ -1782,12 +1791,14 @@ void VeloxQueryPlanConverterBase::toAggregations( auto pos = functionId.find(";", start + 1); if (pos == std::string::npos) { auto argumentType = functionId.substr(start + 1); - aggregate.rawInputTypes.push_back(stringToType(argumentType)); + aggregate.rawInputTypes.push_back( + stringToType(argumentType, typeParser_)); break; } auto argumentType = functionId.substr(start + 1, pos - start - 1); - aggregate.rawInputTypes.push_back(stringToType(argumentType)); + aggregate.rawInputTypes.push_back( + stringToType(argumentType, typeParser_)); pos = start + 1; } } @@ -1821,7 +1832,7 @@ VeloxQueryPlanConverterBase::toVeloxQueryPlan( const std::shared_ptr& node, const std::shared_ptr& /* tableWriteInfo */, const protocol::TaskId& taskId) { - auto rowType = toRowType(node->outputVariables); + auto rowType = toRowType(node->outputVariables, typeParser_); vector_size_t numRows = node->rows.size(); auto numColumns = rowType->size(); std::vector vectors; @@ -1866,14 +1877,15 @@ VeloxQueryPlanConverterBase::toVeloxQueryPlan( const std::shared_ptr& node, const std::shared_ptr& /* tableWriteInfo */, const protocol::TaskId& taskId) { - auto rowType = toRowType(node->outputVariables); + auto rowType = toRowType(node->outputVariables, typeParser_); std::unordered_map> assignments; for (const auto& entry : node->assignments) { - assignments.emplace(entry.first.name, toColumnHandle(entry.second.get())); + assignments.emplace( + entry.first.name, toColumnHandle(entry.second.get(), typeParser_)); } - auto connectorTableHandle = - toConnectorTableHandle(node->table, exprConverter_, assignments); + auto connectorTableHandle = toConnectorTableHandle( + node->table, exprConverter_, typeParser_, assignments); return std::make_shared( node->id, rowType, connectorTableHandle, assignments); } @@ -2066,7 +2078,7 @@ core::PlanNodePtr VeloxQueryPlanConverterBase::toVeloxQueryPlan( node->filter ? exprConverter_.toVeloxExpr(*node->filter) : nullptr, toVeloxQueryPlan(node->left, tableWriteInfo, taskId), toVeloxQueryPlan(node->right, tableWriteInfo, taskId), - toRowType(node->outputVariables)); + toRowType(node->outputVariables, typeParser_)); } VELOX_UNSUPPORTED( "JoinNode has empty criteria that cannot be " @@ -2092,7 +2104,7 @@ core::PlanNodePtr VeloxQueryPlanConverterBase::toVeloxQueryPlan( node->filter ? exprConverter_.toVeloxExpr(*node->filter) : nullptr, toVeloxQueryPlan(node->left, tableWriteInfo, taskId), toVeloxQueryPlan(node->right, tableWriteInfo, taskId), - toRowType(node->outputVariables)); + toRowType(node->outputVariables, typeParser_)); } velox::core::PlanNodePtr VeloxQueryPlanConverterBase::toVeloxQueryPlan( @@ -2161,7 +2173,7 @@ core::PlanNodePtr VeloxQueryPlanConverterBase::toVeloxQueryPlan( node->filter ? exprConverter_.toVeloxExpr(*node->filter) : nullptr, toVeloxQueryPlan(node->left, tableWriteInfo, taskId), toVeloxQueryPlan(node->right, tableWriteInfo, taskId), - toRowType(node->outputVariables)); + toRowType(node->outputVariables, typeParser_)); } std::shared_ptr @@ -2236,8 +2248,8 @@ VeloxQueryPlanConverterBase::toVeloxQueryPlan( VELOX_USER_CHECK_NOT_NULL(hiveOutputTableHandle); bool isPartitioned{false}; - const auto inputColumns = - toHiveColumns(hiveOutputTableHandle->inputColumns, isPartitioned); + const auto inputColumns = toHiveColumns( + hiveOutputTableHandle->inputColumns, typeParser_, isPartitioned); VELOX_USER_CHECK( hiveOutputTableHandle->bucketProperty == nullptr || isPartitioned, "Bucketed table must be partitioned: {}", @@ -2247,7 +2259,7 @@ VeloxQueryPlanConverterBase::toVeloxQueryPlan( toLocationHandle(hiveOutputTableHandle->locationHandle), toFileFormat(hiveOutputTableHandle->tableStorageFormat, "TableWrite"), toHiveBucketProperty( - inputColumns, hiveOutputTableHandle->bucketProperty), + inputColumns, hiveOutputTableHandle->bucketProperty, typeParser_), std::optional( toFileCompressionKind(hiveOutputTableHandle->compressionCodec))); } else if ( @@ -2261,8 +2273,8 @@ VeloxQueryPlanConverterBase::toVeloxQueryPlan( VELOX_USER_CHECK_NOT_NULL(hiveInsertTableHandle); bool isPartitioned{false}; - const auto inputColumns = - toHiveColumns(hiveInsertTableHandle->inputColumns, isPartitioned); + const auto inputColumns = toHiveColumns( + hiveInsertTableHandle->inputColumns, typeParser_, isPartitioned); VELOX_USER_CHECK( hiveInsertTableHandle->bucketProperty == nullptr || isPartitioned, "Bucketed table must be partitioned: {}", @@ -2275,7 +2287,7 @@ VeloxQueryPlanConverterBase::toVeloxQueryPlan( toLocationHandle(hiveInsertTableHandle->locationHandle), toFileFormat(hiveInsertTableHandle->tableStorageFormat, "TableWrite"), toHiveBucketProperty( - inputColumns, hiveInsertTableHandle->bucketProperty), + inputColumns, hiveInsertTableHandle->bucketProperty, typeParser_), std::optional( toFileCompressionKind(hiveInsertTableHandle->compressionCodec)), std::unordered_map( @@ -2290,11 +2302,13 @@ VeloxQueryPlanConverterBase::toVeloxQueryPlan( auto insertTableHandle = std::make_shared(connectorId, hiveTableHandle); - const auto outputType = toRowType(generateOutputVariables( - {node->rowCountVariable, - node->fragmentVariable, - node->tableCommitContextVariable}, - node->statisticsAggregation)); + const auto outputType = toRowType( + generateOutputVariables( + {node->rowCountVariable, + node->fragmentVariable, + node->tableCommitContextVariable}, + node->statisticsAggregation), + typeParser_); const auto sourceVeloxPlan = toVeloxQueryPlan(node->source, tableWriteInfo, taskId); std::shared_ptr aggregationNode = @@ -2307,7 +2321,7 @@ VeloxQueryPlanConverterBase::toVeloxQueryPlan( taskId); return std::make_shared( node->id, - toRowType(node->columns), + toRowType(node->columns, typeParser_), node->columnNames, std::move(aggregationNode), std::move(insertTableHandle), @@ -2322,11 +2336,13 @@ VeloxQueryPlanConverterBase::toVeloxQueryPlan( const std::shared_ptr& node, const std::shared_ptr& tableWriteInfo, const protocol::TaskId& taskId) { - const auto outputType = toRowType(generateOutputVariables( - {node->rowCountVariable, - node->fragmentVariable, - node->tableCommitContextVariable}, - node->statisticsAggregation)); + const auto outputType = toRowType( + generateOutputVariables( + {node->rowCountVariable, + node->fragmentVariable, + node->tableCommitContextVariable}, + node->statisticsAggregation), + typeParser_); const auto sourceVeloxPlan = toVeloxQueryPlan(node->source, tableWriteInfo, taskId); std::shared_ptr aggregationNode = @@ -2491,10 +2507,11 @@ VeloxQueryPlanConverterBase::toVeloxQueryPlan( namespace { core::WindowNode::Function makeRowNumberFunction( - const protocol::VariableReferenceExpression& rowNumberVariable) { + const protocol::VariableReferenceExpression& rowNumberVariable, + const TypeParser& typeParser) { core::WindowNode::Function function; function.functionCall = std::make_shared( - stringToType(rowNumberVariable.type), + stringToType(rowNumberVariable.type, typeParser), std::vector{}, "presto.default.row_number"); @@ -2753,7 +2770,7 @@ core::PlanFragment VeloxQueryPlanConverterBase::toVeloxQueryPlan( setCellFromVariant(constValues.back(), 0, constExpr->value()); } } - auto outputType = toRowType(partitioningScheme.outputLayout); + auto outputType = toRowType(partitioningScheme.outputLayout, typeParser_); if (auto systemPartitioningHandle = std::dynamic_pointer_cast( @@ -2870,7 +2887,7 @@ core::PlanFragment VeloxQueryPlanConverterBase::toVeloxQueryPlan( bucketToPartition, keyChannels, constValues), - toRowType(partitioningScheme.outputLayout), + toRowType(partitioningScheme.outputLayout, typeParser_), sourceNode); return planFragment; } @@ -2885,7 +2902,7 @@ core::PlanNodePtr VeloxQueryPlanConverterBase::toVeloxQueryPlan( const protocol::TaskId& taskId) { return core::PartitionedOutputNode::single( node->id, - toRowType(node->outputVariables), + toRowType(node->outputVariables, typeParser_), VeloxQueryPlanConverterBase::toVeloxQueryPlan( node->source, tableWriteInfo, taskId)); } @@ -2894,7 +2911,7 @@ velox::core::PlanNodePtr VeloxInteractiveQueryPlanConverter::toVeloxQueryPlan( const std::shared_ptr& node, const std::shared_ptr& /* tableWriteInfo */, const protocol::TaskId& taskId) { - auto rowType = toRowType(node->outputVariables); + auto rowType = toRowType(node->outputVariables, typeParser_); if (node->orderingScheme) { std::vector sortingKeys; std::vector sortingOrders; @@ -2992,7 +3009,7 @@ velox::core::PlanNodePtr VeloxBatchQueryPlanConverter::toVeloxQueryPlan( const std::shared_ptr& node, const std::shared_ptr& /* tableWriteInfo */, const protocol::TaskId& taskId) { - auto rowType = toRowType(node->outputVariables); + auto rowType = toRowType(node->outputVariables, typeParser_); // Broadcast exchange source. if (node->exchangeType == protocol::ExchangeNodeType::REPLICATE) { return std::make_shared(node->id, rowType); diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.h b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.h index 1418305231bcf..daa46b70ebdc2 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.h +++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.h @@ -22,20 +22,22 @@ #include "velox/type/Variant.h" #include "presto_cpp/main/types/PrestoTaskId.h" -// TypeSignatureTypeConverter.h must be included after presto_protocol.h -// because it changes the macro EOF in some way (maybe deleting it?) which -// is used in external/json/nlohmann/json.hpp -// #include "presto_cpp/main/types/PrestoToVeloxExpr.h" +// antlr-common.h undefines the EOF macro that external/json/nlohmann/json.hpp +// relies on, so include presto_protcol.h before TypeParser.h +// clang-format off +#include "presto_cpp/presto_protocol/presto_protocol.h" +#include "presto_cpp/main/types/TypeParser.h" +// clang-format on namespace facebook::presto { class VeloxQueryPlanConverterBase { public: - explicit VeloxQueryPlanConverterBase( + VeloxQueryPlanConverterBase( velox::core::QueryCtx* queryCtx, velox::memory::MemoryPool* pool) - : pool_(pool), queryCtx_{queryCtx}, exprConverter_(pool) {} + : pool_(pool), queryCtx_{queryCtx}, exprConverter_(pool, &typeParser_) {} virtual ~VeloxQueryPlanConverterBase() = default; @@ -215,9 +217,10 @@ class VeloxQueryPlanConverterBase { std::vector& aggregates, std::vector& aggregateNames); - velox::memory::MemoryPool* pool_; - velox::core::QueryCtx* queryCtx_; + velox::memory::MemoryPool* const pool_; + velox::core::QueryCtx* const queryCtx_; VeloxExprConverter exprConverter_; + TypeParser typeParser_; }; class VeloxInteractiveQueryPlanConverter : public VeloxQueryPlanConverterBase { diff --git a/presto-native-execution/presto_cpp/main/types/TypeSignatureTypeConverter.cpp b/presto-native-execution/presto_cpp/main/types/TypeParser.cpp similarity index 77% rename from presto-native-execution/presto_cpp/main/types/TypeSignatureTypeConverter.cpp rename to presto-native-execution/presto_cpp/main/types/TypeParser.cpp index 5c28d4b3d1cc5..8a9b6a5f04438 100644 --- a/presto-native-execution/presto_cpp/main/types/TypeSignatureTypeConverter.cpp +++ b/presto-native-execution/presto_cpp/main/types/TypeParser.cpp @@ -15,8 +15,7 @@ #include #include -#include "presto_cpp/main/types/ParseTypeSignature.h" -#include "presto_cpp/main/types/TypeSignatureTypeConverter.h" +#include "presto_cpp/main/types/TypeParser.h" #include "presto_cpp/main/types/antlr/TypeSignatureLexer.h" #include "velox/functions/prestosql/types/HyperLogLogType.h" #include "velox/functions/prestosql/types/JsonType.h" @@ -25,30 +24,36 @@ using namespace facebook::velox; namespace facebook::presto { - -TypePtr parseTypeSignature(const std::string& signature) { - return TypeSignatureTypeConverter::parse(signature); -} - -// static -TypePtr TypeSignatureTypeConverter::parse(const std::string& text) { - antlr4::ANTLRInputStream input(text); - type::TypeSignatureLexer lexer(&input); - antlr4::CommonTokenStream tokens(&lexer); - type::TypeSignatureParser parser(&tokens); - - parser.setErrorHandler(std::make_shared()); - - try { - auto ctx = parser.start(); - TypeSignatureTypeConverter c; - return c.visit(ctx).as(); - } catch (const std::exception& e) { - VELOX_USER_FAIL("Failed to parse type [{}]: {}", text, e.what()); - } -} - namespace { +class TypeSignatureTypeConverter : public type::TypeSignatureBaseVisitor { + public: + virtual antlrcpp::Any visitStart( + type::TypeSignatureParser::StartContext* ctx) override; + virtual antlrcpp::Any visitNamed_type( + type::TypeSignatureParser::Named_typeContext* ctx) override; + virtual antlrcpp::Any visitType_spec( + type::TypeSignatureParser::Type_specContext* ctx) override; + virtual antlrcpp::Any visitType( + type::TypeSignatureParser::TypeContext* ctx) override; + virtual antlrcpp::Any visitSimple_type( + type::TypeSignatureParser::Simple_typeContext* ctx) override; + virtual antlrcpp::Any visitDecimal_type( + type::TypeSignatureParser::Decimal_typeContext* ctx) override; + virtual antlrcpp::Any visitVariable_type( + type::TypeSignatureParser::Variable_typeContext* ctx) override; + virtual antlrcpp::Any visitType_list( + type::TypeSignatureParser::Type_listContext* ctx) override; + virtual antlrcpp::Any visitRow_type( + type::TypeSignatureParser::Row_typeContext* ctx) override; + virtual antlrcpp::Any visitMap_type( + type::TypeSignatureParser::Map_typeContext* ctx) override; + virtual antlrcpp::Any visitArray_type( + type::TypeSignatureParser::Array_typeContext* ctx) override; + virtual antlrcpp::Any visitFunction_type( + type::TypeSignatureParser::Function_typeContext* ctx) override; + virtual antlrcpp::Any visitIdentifier( + type::TypeSignatureParser::IdentifierContext* ctx) override; +}; TypePtr typeFromString(const std::string& typeName) { auto upper = boost::to_upper_copy(typeName); @@ -95,8 +100,6 @@ struct NamedType { velox::TypePtr type; }; -} // namespace - antlrcpp::Any TypeSignatureTypeConverter::visitStart( type::TypeSignatureParser::StartContext* ctx) { NamedType named = visit(ctx->type_spec()).as(); @@ -209,5 +212,29 @@ antlrcpp::Any TypeSignatureTypeConverter::visitIdentifier( 1, ctx->QUOTED_ID()->getText().length() - 2); } } +} + +TypePtr TypeParser::parse(const std::string& text) const { + auto it = cache_.find(text); + if (it != cache_.end()) { + return it->second; + } + + antlr4::ANTLRInputStream input(text); + type::TypeSignatureLexer lexer(&input); + antlr4::CommonTokenStream tokens(&lexer); + type::TypeSignatureParser parser(&tokens); + parser.setErrorHandler(std::make_shared()); + + try { + auto ctx = parser.start(); + TypeSignatureTypeConverter c; + auto result = c.visit(ctx).as(); + cache_.insert({text, result}); + return result; + } catch (const std::exception& e) { + VELOX_USER_FAIL("Failed to parse type [{}]: {}", text, e.what()); + } +} } // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/types/TypeParser.h b/presto-native-execution/presto_cpp/main/types/TypeParser.h new file mode 100644 index 0000000000000..c9c4d0ecacc55 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/types/TypeParser.h @@ -0,0 +1,31 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/type/Type.h" + +#include "presto_cpp/main/types/antlr/TypeSignatureBaseVisitor.h" + +namespace facebook::presto { + +class TypeParser { + public: + velox::TypePtr parse(const std::string& text) const; + + private: + mutable std::unordered_map cache_; +}; + +} // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/types/TypeSignatureTypeConverter.h b/presto-native-execution/presto_cpp/main/types/TypeSignatureTypeConverter.h deleted file mode 100644 index d48b266ca7406..0000000000000 --- a/presto-native-execution/presto_cpp/main/types/TypeSignatureTypeConverter.h +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "velox/type/Type.h" - -#include "presto_cpp/main/types/antlr/TypeSignatureBaseVisitor.h" - -namespace facebook::presto { - -class TypeSignatureTypeConverter : type::TypeSignatureBaseVisitor { - public: - static velox::TypePtr parse(const std::string& text); - - private: - virtual antlrcpp::Any visitStart( - type::TypeSignatureParser::StartContext* ctx) override; - virtual antlrcpp::Any visitNamed_type( - type::TypeSignatureParser::Named_typeContext* ctx) override; - virtual antlrcpp::Any visitType_spec( - type::TypeSignatureParser::Type_specContext* ctx) override; - virtual antlrcpp::Any visitType( - type::TypeSignatureParser::TypeContext* ctx) override; - virtual antlrcpp::Any visitSimple_type( - type::TypeSignatureParser::Simple_typeContext* ctx) override; - virtual antlrcpp::Any visitDecimal_type( - type::TypeSignatureParser::Decimal_typeContext* ctx) override; - virtual antlrcpp::Any visitVariable_type( - type::TypeSignatureParser::Variable_typeContext* ctx) override; - virtual antlrcpp::Any visitType_list( - type::TypeSignatureParser::Type_listContext* ctx) override; - virtual antlrcpp::Any visitRow_type( - type::TypeSignatureParser::Row_typeContext* ctx) override; - virtual antlrcpp::Any visitMap_type( - type::TypeSignatureParser::Map_typeContext* ctx) override; - virtual antlrcpp::Any visitArray_type( - type::TypeSignatureParser::Array_typeContext* ctx) override; - virtual antlrcpp::Any visitFunction_type( - type::TypeSignatureParser::Function_typeContext* ctx) override; - virtual antlrcpp::Any visitIdentifier( - type::TypeSignatureParser::IdentifierContext* ctx) override; -}; - -} // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/types/antlr/TypeSignatureBaseVisitor.cpp b/presto-native-execution/presto_cpp/main/types/antlr/TypeSignatureBaseVisitor.cpp index d952b393a4237..d60c4376261ae 100644 --- a/presto-native-execution/presto_cpp/main/types/antlr/TypeSignatureBaseVisitor.cpp +++ b/presto-native-execution/presto_cpp/main/types/antlr/TypeSignatureBaseVisitor.cpp @@ -13,7 +13,7 @@ */ #include -#include "presto_cpp/main/types/TypeSignatureTypeConverter.h" +#include "presto_cpp/main/types/TypeParser.h" // Generated from TypeSignature.g4 by ANTLR 4.9.3 diff --git a/presto-native-execution/presto_cpp/main/types/antlr/TypeSignatureBaseVisitor.h b/presto-native-execution/presto_cpp/main/types/antlr/TypeSignatureBaseVisitor.h index 8cff043a8d880..ad848c4fa9704 100644 --- a/presto-native-execution/presto_cpp/main/types/antlr/TypeSignatureBaseVisitor.h +++ b/presto-native-execution/presto_cpp/main/types/antlr/TypeSignatureBaseVisitor.h @@ -13,7 +13,7 @@ */ #include -#include "presto_cpp/main/types/TypeSignatureTypeConverter.h" +#include "presto_cpp/main/types/TypeParser.h" // Generated from TypeSignature.g4 by ANTLR 4.9.3 diff --git a/presto-native-execution/presto_cpp/main/types/antlr/TypeSignatureParser.cpp b/presto-native-execution/presto_cpp/main/types/antlr/TypeSignatureParser.cpp index de519535d4c03..a4b2bf569a171 100644 --- a/presto-native-execution/presto_cpp/main/types/antlr/TypeSignatureParser.cpp +++ b/presto-native-execution/presto_cpp/main/types/antlr/TypeSignatureParser.cpp @@ -13,7 +13,7 @@ */ #include -#include "presto_cpp/main/types/TypeSignatureTypeConverter.h" +#include "presto_cpp/main/types/TypeParser.h" // Generated from TypeSignature.g4 by ANTLR 4.9.3 diff --git a/presto-native-execution/presto_cpp/main/types/antlr/TypeSignatureParser.h b/presto-native-execution/presto_cpp/main/types/antlr/TypeSignatureParser.h index 99fa967a7c631..645eccdf24f22 100644 --- a/presto-native-execution/presto_cpp/main/types/antlr/TypeSignatureParser.h +++ b/presto-native-execution/presto_cpp/main/types/antlr/TypeSignatureParser.h @@ -13,7 +13,7 @@ */ #include -#include "presto_cpp/main/types/TypeSignatureTypeConverter.h" +#include "presto_cpp/main/types/TypeParser.h" // Generated from TypeSignature.g4 by ANTLR 4.9.3 diff --git a/presto-native-execution/presto_cpp/main/types/antlr/TypeSignatureVisitor.cpp b/presto-native-execution/presto_cpp/main/types/antlr/TypeSignatureVisitor.cpp index 9b9e4745aca20..f4fc07cc6c8a8 100644 --- a/presto-native-execution/presto_cpp/main/types/antlr/TypeSignatureVisitor.cpp +++ b/presto-native-execution/presto_cpp/main/types/antlr/TypeSignatureVisitor.cpp @@ -13,7 +13,7 @@ */ #include -#include "presto_cpp/main/types/TypeSignatureTypeConverter.h" +#include "presto_cpp/main/types/TypeParser.h" // Generated from TypeSignature.g4 by ANTLR 4.9.3 diff --git a/presto-native-execution/presto_cpp/main/types/antlr/TypeSignatureVisitor.h b/presto-native-execution/presto_cpp/main/types/antlr/TypeSignatureVisitor.h index 7360b1e81cbdf..25ef8e877df60 100644 --- a/presto-native-execution/presto_cpp/main/types/antlr/TypeSignatureVisitor.h +++ b/presto-native-execution/presto_cpp/main/types/antlr/TypeSignatureVisitor.h @@ -13,7 +13,7 @@ */ #include -#include "presto_cpp/main/types/TypeSignatureTypeConverter.h" +#include "presto_cpp/main/types/TypeParser.h" // Generated from TypeSignature.g4 by ANTLR 4.9.3 diff --git a/presto-native-execution/presto_cpp/main/types/tests/RowExpressionTest.cpp b/presto-native-execution/presto_cpp/main/types/tests/RowExpressionTest.cpp index 2b837ab06e06d..605492a6a723e 100644 --- a/presto-native-execution/presto_cpp/main/types/tests/RowExpressionTest.cpp +++ b/presto-native-execution/presto_cpp/main/types/tests/RowExpressionTest.cpp @@ -27,7 +27,8 @@ class RowExpressionTest : public ::testing::Test { public: void SetUp() override { pool_ = memory::addDefaultLeafMemoryPool(); - converter_ = std::make_unique(pool_.get()); + converter_ = + std::make_unique(pool_.get(), &typeParser_); } void testConstantExpression( @@ -46,6 +47,7 @@ class RowExpressionTest : public ::testing::Test { std::shared_ptr pool_; std::unique_ptr converter_; + TypeParser typeParser_; }; TEST_F(RowExpressionTest, bigInt) { diff --git a/presto-native-execution/presto_cpp/main/types/tests/TypeSignatureTest.cpp b/presto-native-execution/presto_cpp/main/types/tests/TypeSignatureTest.cpp index e20d5410dc961..72832723b546e 100644 --- a/presto-native-execution/presto_cpp/main/types/tests/TypeSignatureTest.cpp +++ b/presto-native-execution/presto_cpp/main/types/tests/TypeSignatureTest.cpp @@ -13,7 +13,7 @@ */ #include -#include "presto_cpp/main/types/ParseTypeSignature.h" +#include "presto_cpp/main/types/TypeParser.h" #include "velox/common/base/tests/GTestUtils.h" using namespace facebook::velox; @@ -21,184 +21,184 @@ using namespace facebook::velox; namespace facebook::presto { namespace { -class TestTypeSignature : public ::testing::Test {}; +class TestTypeSignature : public ::testing::Test { + public: + TypeParser typeParser; +}; TEST_F(TestTypeSignature, booleanType) { - ASSERT_EQ(*parseTypeSignature("boolean"), *BOOLEAN()); + ASSERT_EQ(*typeParser.parse("boolean"), *BOOLEAN()); } TEST_F(TestTypeSignature, integerType) { - ASSERT_EQ(*parseTypeSignature("int"), *INTEGER()); - ASSERT_EQ(*parseTypeSignature("integer"), *INTEGER()); + ASSERT_EQ(*typeParser.parse("int"), *INTEGER()); + ASSERT_EQ(*typeParser.parse("integer"), *INTEGER()); } TEST_F(TestTypeSignature, varcharType) { - ASSERT_EQ(*parseTypeSignature("varchar"), *VARCHAR()); + ASSERT_EQ(*typeParser.parse("varchar"), *VARCHAR()); } TEST_F(TestTypeSignature, varbinary) { - ASSERT_EQ(*parseTypeSignature("varbinary"), *VARBINARY()); + ASSERT_EQ(*typeParser.parse("varbinary"), *VARBINARY()); } TEST_F(TestTypeSignature, arrayType) { - ASSERT_EQ(*parseTypeSignature("array(bigint)"), *ARRAY(BIGINT())); + ASSERT_EQ(*typeParser.parse("array(bigint)"), *ARRAY(BIGINT())); - ASSERT_EQ(*parseTypeSignature("array(int)"), *ARRAY(INTEGER())); - ASSERT_EQ(*parseTypeSignature("array(integer)"), *ARRAY(INTEGER())); + ASSERT_EQ(*typeParser.parse("array(int)"), *ARRAY(INTEGER())); + ASSERT_EQ(*typeParser.parse("array(integer)"), *ARRAY(INTEGER())); - ASSERT_EQ( - *parseTypeSignature("array(array(bigint))"), *ARRAY(ARRAY(BIGINT()))); + ASSERT_EQ(*typeParser.parse("array(array(bigint))"), *ARRAY(ARRAY(BIGINT()))); - ASSERT_EQ(*parseTypeSignature("array(array(int))"), *ARRAY(ARRAY(INTEGER()))); + ASSERT_EQ(*typeParser.parse("array(array(int))"), *ARRAY(ARRAY(INTEGER()))); } TEST_F(TestTypeSignature, mapType) { - ASSERT_EQ( - *parseTypeSignature("map(bigint,bigint)"), *MAP(BIGINT(), BIGINT())); + ASSERT_EQ(*typeParser.parse("map(bigint,bigint)"), *MAP(BIGINT(), BIGINT())); ASSERT_EQ( - *parseTypeSignature("map(bigint,array(bigint))"), + *typeParser.parse("map(bigint,array(bigint))"), *MAP(BIGINT(), ARRAY(BIGINT()))); ASSERT_EQ( - *parseTypeSignature("map(bigint,map(bigint,map(varchar,bigint)))"), + *typeParser.parse("map(bigint,map(bigint,map(varchar,bigint)))"), *MAP(BIGINT(), MAP(BIGINT(), MAP(VARCHAR(), BIGINT())))); } TEST_F(TestTypeSignature, invalidType) { VELOX_ASSERT_THROW( - parseTypeSignature("blah()"), "Failed to parse type [blah()]"); + typeParser.parse("blah()"), "Failed to parse type [blah()]"); VELOX_ASSERT_THROW( - parseTypeSignature("array()"), "Failed to parse type [array()]"); + typeParser.parse("array()"), "Failed to parse type [array()]"); - VELOX_ASSERT_THROW( - parseTypeSignature("map()"), "Failed to parse type [map()]"); + VELOX_ASSERT_THROW(typeParser.parse("map()"), "Failed to parse type [map()]"); - VELOX_ASSERT_THROW(parseTypeSignature("x"), "Failed to parse type [x]"); + VELOX_ASSERT_THROW(typeParser.parse("x"), "Failed to parse type [x]"); // Ensure this is not treated as a row type. VELOX_ASSERT_THROW( - parseTypeSignature("rowxxx(a)"), "Failed to parse type [rowxxx(a)]"); + typeParser.parse("rowxxx(a)"), "Failed to parse type [rowxxx(a)]"); } TEST_F(TestTypeSignature, rowType) { ASSERT_EQ( - *parseTypeSignature("row(a bigint,b varchar,c real)"), + *typeParser.parse("row(a bigint,b varchar,c real)"), *ROW({"a", "b", "c"}, {BIGINT(), VARCHAR(), REAL()})); ASSERT_EQ( - *parseTypeSignature("row(a bigint,b array(bigint),c row(a bigint))"), + *typeParser.parse("row(a bigint,b array(bigint),c row(a bigint))"), *ROW( {"a", "b", "c"}, {BIGINT(), ARRAY(BIGINT()), ROW({"a"}, {BIGINT()})})); ASSERT_EQ( - *parseTypeSignature("row(\"12\" bigint,b bigint,c bigint)"), + *typeParser.parse("row(\"12\" bigint,b bigint,c bigint)"), *ROW({"12", "b", "c"}, {BIGINT(), BIGINT(), BIGINT()})); ASSERT_EQ( - *parseTypeSignature("row(a varchar(10),b row(a bigint))"), + *typeParser.parse("row(a varchar(10),b row(a bigint))"), *ROW({"a", "b"}, {VARCHAR(), ROW({"a"}, {BIGINT()})})); ASSERT_EQ( - *parseTypeSignature("array(row(col0 bigint,col1 double))"), + *typeParser.parse("array(row(col0 bigint,col1 double))"), *ARRAY(ROW({"col0", "col1"}, {BIGINT(), DOUBLE()}))); ASSERT_EQ( - *parseTypeSignature("row(col0 array(row(col0 bigint,col1 double)))"), + *typeParser.parse("row(col0 array(row(col0 bigint,col1 double)))"), *ROW({"col0"}, {ARRAY(ROW({"col0", "col1"}, {BIGINT(), DOUBLE()}))})); ASSERT_EQ( - *parseTypeSignature("row(bigint,varchar)"), *ROW({BIGINT(), VARCHAR()})); + *typeParser.parse("row(bigint,varchar)"), *ROW({BIGINT(), VARCHAR()})); ASSERT_EQ( - *parseTypeSignature("row(bigint,array(bigint),row(a bigint))"), + *typeParser.parse("row(bigint,array(bigint),row(a bigint))"), *ROW({BIGINT(), ARRAY(BIGINT()), ROW({"a"}, {BIGINT()})})); ASSERT_EQ( - *parseTypeSignature("row(varchar(10),b row(bigint))"), + *typeParser.parse("row(varchar(10),b row(bigint))"), *ROW({"", "b"}, {VARCHAR(), ROW({BIGINT()})})); ASSERT_EQ( - *parseTypeSignature("array(row(col0 bigint,double))"), + *typeParser.parse("array(row(col0 bigint,double))"), *ARRAY(ROW({"col0", ""}, {BIGINT(), DOUBLE()}))); ASSERT_EQ( - *parseTypeSignature("row(col0 array(row(bigint,double)))"), + *typeParser.parse("row(col0 array(row(bigint,double)))"), *ROW({"col0"}, {ARRAY(ROW({BIGINT(), DOUBLE()}))})); ASSERT_EQ( - *parseTypeSignature("row(double double precision)"), + *typeParser.parse("row(double double precision)"), *ROW({"double"}, {DOUBLE()})); - ASSERT_EQ(*parseTypeSignature("row(double precision)"), *ROW({DOUBLE()})); + ASSERT_EQ(*typeParser.parse("row(double precision)"), *ROW({DOUBLE()})); ASSERT_EQ( - *parseTypeSignature("RoW(a bigint,b varchar)"), + *typeParser.parse("RoW(a bigint,b varchar)"), *ROW({"a", "b"}, {BIGINT(), VARCHAR()})); // Field type canonicalization. - ASSERT_EQ(*parseTypeSignature("row(col iNt)"), *ROW({"col"}, {INTEGER()})); + ASSERT_EQ(*typeParser.parse("row(col iNt)"), *ROW({"col"}, {INTEGER()})); } TEST_F(TestTypeSignature, typesWithSpaces) { VELOX_ASSERT_THROW( - parseTypeSignature("row(time time with time zone)"), + typeParser.parse("row(time time with time zone)"), "Specified element is not found : TIME WITH TIME ZONE"); ASSERT_EQ( - *parseTypeSignature("row(double double precision)"), + *typeParser.parse("row(double double precision)"), *ROW({"double"}, {DOUBLE()})); VELOX_ASSERT_THROW( - parseTypeSignature("row(time with time zone)"), + typeParser.parse("row(time with time zone)"), "Specified element is not found : TIME WITH TIME ZONE"); - ASSERT_EQ(*parseTypeSignature("row(double precision)"), *ROW({DOUBLE()})); + ASSERT_EQ(*typeParser.parse("row(double precision)"), *ROW({DOUBLE()})); VELOX_ASSERT_THROW( - parseTypeSignature("row(array(time with time zone))"), + typeParser.parse("row(array(time with time zone))"), "Specified element is not found : TIME WITH TIME ZONE"); // quoted field names VELOX_ASSERT_THROW( - parseTypeSignature( + typeParser.parse( "row(\"time with time zone\" time with time zone,\"double\" double)"), "Specified element is not found : TIME WITH TIME ZONE"); } TEST_F(TestTypeSignature, intervalYearToMonthType) { ASSERT_EQ( - *parseTypeSignature("row(interval interval year to month)"), + *typeParser.parse("row(interval interval year to month)"), *ROW({"interval"}, {INTERVAL_YEAR_MONTH()})); ASSERT_EQ( - *parseTypeSignature("row(interval year to month)"), + *typeParser.parse("row(interval year to month)"), *ROW({INTERVAL_YEAR_MONTH()})); } TEST_F(TestTypeSignature, functionType) { ASSERT_EQ( - *parseTypeSignature("function(bigint,bigint,bigint)"), + *typeParser.parse("function(bigint,bigint,bigint)"), *FUNCTION({BIGINT(), BIGINT()}, BIGINT())); ASSERT_EQ( - *parseTypeSignature("function(bigint,array(varchar),varchar)"), + *typeParser.parse("function(bigint,array(varchar),varchar)"), *FUNCTION({BIGINT(), ARRAY(VARCHAR())}, VARCHAR())); } TEST_F(TestTypeSignature, decimalType) { - ASSERT_EQ(*parseTypeSignature("decimal(10, 5)"), *DECIMAL(10, 5)); - ASSERT_EQ(*parseTypeSignature("decimal(20,10)"), *DECIMAL(20, 10)); + ASSERT_EQ(*typeParser.parse("decimal(10, 5)"), *DECIMAL(10, 5)); + ASSERT_EQ(*typeParser.parse("decimal(20,10)"), *DECIMAL(20, 10)); VELOX_ASSERT_THROW( - parseTypeSignature("decimal"), "Failed to parse type [decimal]"); + typeParser.parse("decimal"), "Failed to parse type [decimal]"); VELOX_ASSERT_THROW( - parseTypeSignature("decimal()"), "Failed to parse type [decimal()]"); + typeParser.parse("decimal()"), "Failed to parse type [decimal()]"); VELOX_ASSERT_THROW( - parseTypeSignature("decimal(20)"), "Failed to parse type [decimal(20)]"); + typeParser.parse("decimal(20)"), "Failed to parse type [decimal(20)]"); VELOX_ASSERT_THROW( - parseTypeSignature("decimal(, 20)"), + typeParser.parse("decimal(, 20)"), "Failed to parse type [decimal(, 20)]"); }