diff --git a/src/substrait/proto/CMakeLists.txt b/src/substrait/proto/CMakeLists.txt index 9670f4da..f3fdcac9 100644 --- a/src/substrait/proto/CMakeLists.txt +++ b/src/substrait/proto/CMakeLists.txt @@ -67,7 +67,10 @@ foreach(PROTO_FILE IN LISTS PROTOBUF_FILELIST) endforeach() # Add the generated protobuf C++ files to our exported library. -add_library(substrait_proto ${PROTO_SRCS} ${PROTO_HDRS}) +add_library( + substrait_proto + ${PROTO_SRCS} ${PROTO_HDRS} + ProtoUtils.cpp ProtoUtils.h) # Include the protobuf library as a dependency to use this class. target_link_libraries(substrait_proto ${PROTOBUF_LIBRARIES}) diff --git a/src/substrait/proto/ProtoUtils.cpp b/src/substrait/proto/ProtoUtils.cpp new file mode 100644 index 00000000..0af55aa4 --- /dev/null +++ b/src/substrait/proto/ProtoUtils.cpp @@ -0,0 +1,47 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#include "substrait/proto/ProtoUtils.h" + +namespace substrait::proto { + +std::string PlanRelTypeCaseName(::substrait::proto::PlanRel::RelTypeCase num) { + static std::vector case_names = { + "unknown", + "rel", + "root", + }; + + if (num >= case_names.size()) { + return "unknown"; + } + + return case_names[num]; +} + +std::string RelTypeCaseName(::substrait::proto::Rel::RelTypeCase num) { + static std::vector case_names = { + "unknown", + "read", + "filter", + "fetch", + "aggregate", + "sort", + "join", + "project", + "set", + "extensionsingle", + "extensionmulti", + "extensionleaf", + "cross", + "hashjoin", + "mergejoin", + }; + + if (num >= case_names.size()) { + return "unknown"; + } + + return case_names[num]; +} + +} // namespace substrait::proto \ No newline at end of file diff --git a/src/substrait/proto/ProtoUtils.h b/src/substrait/proto/ProtoUtils.h new file mode 100644 index 00000000..bcae6ad4 --- /dev/null +++ b/src/substrait/proto/ProtoUtils.h @@ -0,0 +1,16 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#pragma once + +#include + +#include "substrait/proto/algebra.pb.h" +#include "substrait/proto/plan.pb.h" + +namespace substrait::proto { + +std::string PlanRelTypeCaseName(::substrait::proto::PlanRel::RelTypeCase num); + +std::string RelTypeCaseName(::substrait::proto::Rel::RelTypeCase num); + +} // namespace substrait::proto \ No newline at end of file diff --git a/src/substrait/textplan/Any.h b/src/substrait/textplan/Any.h new file mode 100644 index 00000000..eb84a778 --- /dev/null +++ b/src/substrait/textplan/Any.h @@ -0,0 +1,25 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#pragma once + +#include + +#include "fmt/format.h" + +namespace io::substrait::textplan { + +template +inline ValueType any_cast(const std::any& value, const char* file, int line) { + try { + return std::any_cast(value); + } catch (std::bad_any_cast& ex) { + throw std::invalid_argument( + fmt::format("{}:{} - {}", file, line, "bad any cast")); + } +} + +// A wrapper around std::any_cast that provides exceptions with line numbers. +#define ANY_CAST(ValueType, Value) \ + ::io::substrait::textplan::any_cast(Value, __FILE__, __LINE__) + +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/CMakeLists.txt b/src/substrait/textplan/CMakeLists.txt index f1785920..2018a1f3 100644 --- a/src/substrait/textplan/CMakeLists.txt +++ b/src/substrait/textplan/CMakeLists.txt @@ -5,7 +5,8 @@ add_subdirectory(converter) add_library(symbol_table Location.cpp Location.h SymbolTable.cpp SymbolTable.h - SymbolTablePrinter.cpp SymbolTablePrinter.h) + SymbolTablePrinter.cpp SymbolTablePrinter.h + Any.h) add_library(error_listener SubstraitErrorListener.cpp SubstraitErrorListener.h @@ -16,7 +17,15 @@ add_library(parse_result ) add_dependencies(symbol_table - substrait_proto) + substrait_proto + substrait_common + fmt::fmt-header-only + ) + +target_link_libraries( + symbol_table + fmt::fmt-header-only + ) # Provide access to the generated protobuffer headers hierarchy. target_include_directories( @@ -25,4 +34,4 @@ target_include_directories( if (${SUBSTRAIT_CPP_BUILD_TESTING}) add_subdirectory(tests) -endif () \ No newline at end of file +endif () diff --git a/src/substrait/textplan/Location.cpp b/src/substrait/textplan/Location.cpp index bf4ef0a5..22d218f9 100644 --- a/src/substrait/textplan/Location.cpp +++ b/src/substrait/textplan/Location.cpp @@ -3,41 +3,28 @@ #include "substrait/textplan/Location.h" #include -#include -#include namespace io::substrait::textplan { -ProtoLocation ProtoLocation::visit(const std::string& name) const { - ProtoLocation new_location = *this; - new_location.location_.push_back(name); - return new_location; -} - -std::string ProtoLocation::toString() const { - std::stringstream text; - bool written = false; - for (const auto& loc : location_) { - if (!written) { - text << " -> "; - written = true; - } - text << loc; - } - return text.str(); -} +constexpr const Location Location::kUnknownLocation( + static_cast(nullptr)); bool operator==(const Location& c1, const Location& c2) { // Test only one side since we only store one kind of content per table. - if (std::holds_alternative( - c1.loc_)) { - auto s1 = std::get(c1.loc_).toString(); - auto s2 = std::get(c2.loc_).toString(); - return s1 == s2; - } else if (std::holds_alternative(c1.loc_)) { + if (std::holds_alternative(c1.loc_)) { + if (!std::holds_alternative(c2.loc_)) { + return false; + } auto a1 = std::get(c1.loc_); auto a2 = std::get(c2.loc_); return a1 == a2; + } else if (std::holds_alternative(c1.loc_)) { + if (!std::holds_alternative(c2.loc_)) { + return false; + } + auto a1 = std::get(c1.loc_); + auto a2 = std::get(c2.loc_); + return a1 == a2; } // Should not be reached. return false; @@ -47,28 +34,35 @@ bool operator==(const Location& c1, const Location& c2) { std::size_t std::hash<::io::substrait::textplan::Location>::operator()( const ::io::substrait::textplan::Location& loc) const noexcept { - if (std::holds_alternative<::io::substrait::textplan::ProtoLocation>( - loc.loc_)) { - return std::hash()( - std::get<::io::substrait::textplan::ProtoLocation>(loc.loc_) - .toString()); - } else if (std::holds_alternative(loc.loc_)) { + if (std::holds_alternative(loc.loc_)) { return std::hash()( std::get(loc.loc_)); + } else if (std::holds_alternative(loc.loc_)) { + return std::hash()( + std::get(loc.loc_)); } // Should not be reached. return 0; } -std::size_t std::less<::io::substrait::textplan::Location>::operator()( +bool std::less<::io::substrait::textplan::Location>::operator()( const ::io::substrait::textplan::Location& lhs, const ::io::substrait::textplan::Location& rhs) const noexcept { - if (std::holds_alternative<::io::substrait::textplan::ProtoLocation>( - lhs.loc_)) { - return std::get<::io::substrait::textplan::ProtoLocation>(lhs.loc_) - .toString() < - std::get<::io::substrait::textplan::ProtoLocation>(rhs.loc_).toString(); + if (std::holds_alternative(lhs.loc_)) { + if (!std::holds_alternative(rhs.loc_)) { + // This alternative is always less than the remaining choices. + return true; + } + return std::get(lhs.loc_) < + std::get(rhs.loc_); + } else if (std::holds_alternative(lhs.loc_)) { + if (!std::holds_alternative(rhs.loc_)) { + // This alternative is always less than the remaining (zero) choices. + return true; + } + return std::get(lhs.loc_) < + std::get(rhs.loc_); } - return std::get(lhs.loc_) < - std::get(rhs.loc_); + // Should not be reached. + return false; } diff --git a/src/substrait/textplan/Location.h b/src/substrait/textplan/Location.h index 9b59bcb3..d60ac455 100644 --- a/src/substrait/textplan/Location.h +++ b/src/substrait/textplan/Location.h @@ -10,30 +10,23 @@ namespace antlr4 { class ParserRuleContext; } -namespace io::substrait::textplan { - -class ProtoLocation { - public: - [[nodiscard]] ProtoLocation visit(const std::string& name) const; - - [[nodiscard]] std::string toString() const; - - private: - std::vector location_; +namespace google::protobuf { +class Message; }; +namespace io::substrait::textplan { + // Location is used for keeping track of where a symbol is within a parse tree. // Since SymbolTable supports both antlr4 and protobuf messages there are // essentially two flavors of location. It is expected that only one type of // location would be used in any SymbolTable instance. class Location { public: - explicit Location(antlr4::ParserRuleContext* node) { - loc_ = node; - }; - explicit Location(ProtoLocation location) { - loc_ = location; - }; + constexpr explicit Location(antlr4::ParserRuleContext* node) : loc_(node) {} + + constexpr explicit Location(google::protobuf::Message* msg) : loc_(msg) {} + + static const Location kUnknownLocation; protected: friend bool operator==(const Location& c1, const Location& c2); @@ -42,7 +35,7 @@ class Location { friend std::hash; friend std::less; - std::variant loc_; + std::variant loc_; }; } // namespace io::substrait::textplan @@ -55,8 +48,7 @@ struct std::hash<::io::substrait::textplan::Location> { template <> struct std::less<::io::substrait::textplan::Location> { - std::size_t operator()( + bool operator()( const ::io::substrait::textplan::Location& lhs, const ::io::substrait::textplan::Location& rhs) const noexcept; }; - diff --git a/src/substrait/textplan/SubstraitErrorListener.cpp b/src/substrait/textplan/SubstraitErrorListener.cpp index 82b49c54..e2a6203e 100644 --- a/src/substrait/textplan/SubstraitErrorListener.cpp +++ b/src/substrait/textplan/SubstraitErrorListener.cpp @@ -25,4 +25,4 @@ std::vector SubstraitErrorListener::getErrorMessages() { return messages; } -} // namespace io::substrait::textplan \ No newline at end of file +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/SymbolTable.cpp b/src/substrait/textplan/SymbolTable.cpp index 96e6f453..ad3bf30d 100644 --- a/src/substrait/textplan/SymbolTable.cpp +++ b/src/substrait/textplan/SymbolTable.cpp @@ -21,6 +21,15 @@ SymbolTableIterator SymbolTableIterator::operator++() { return *this; } +bool operator==(const SymbolInfo& left, const SymbolInfo& right) { + return (left.name == right.name) && (left.location == right.location) && + (left.type == right.type); +} + +bool operator!=(const SymbolInfo& left, const SymbolInfo& right) { + return !(left == right); +} + std::string SymbolTable::getUniqueName(const std::string& base_name) { auto symbolSeenCount = names_.find(base_name); if (symbolSeenCount == names_.end()) { @@ -58,8 +67,7 @@ SymbolInfo* SymbolTable::defineUniqueSymbol( return defineSymbol(unique_name, location, type, subtype, blob); } -const SymbolInfo& SymbolTable::lookupSymbolByName( - const std::string& name) { +const SymbolInfo& SymbolTable::lookupSymbolByName(const std::string& name) { auto itr = symbols_by_name_.find(name); if (itr == symbols_by_name_.end()) { return kUnknownSymbol; @@ -76,9 +84,7 @@ const SymbolInfo& SymbolTable::lookupSymbolByLocation( return *symbols_[itr->second]; } -const SymbolInfo& SymbolTable::nthSymbolByType( - uint32_t n, - SymbolType type) { +const SymbolInfo& SymbolTable::nthSymbolByType(uint32_t n, SymbolType type) { int count = 0; for (const auto& symbol : symbols_) { if (symbol->type == type) { @@ -97,4 +103,11 @@ SymbolTableIterator SymbolTable::end() const { return {this, symbols_by_name_.size()}; } +const SymbolInfo SymbolTable::kUnknownSymbol = { + "__UNKNOWN__", + Location::kUnknownLocation, + SymbolType::kUnknown, + std::nullopt, + std::nullopt}; + } // namespace io::substrait::textplan diff --git a/src/substrait/textplan/SymbolTable.h b/src/substrait/textplan/SymbolTable.h index c68cd7ca..33ed5a90 100644 --- a/src/substrait/textplan/SymbolTable.h +++ b/src/substrait/textplan/SymbolTable.h @@ -57,6 +57,14 @@ enum class RelationDetailType { kExpression = 1, }; +enum class SourceType { + kUnknown = 0, + kLocalFiles = 1, + kNamedTable = 2, + kVirtualTable = 3, + kExtensionTable = 4, +}; + struct SymbolInfo { std::string name; Location location; @@ -71,10 +79,13 @@ struct SymbolInfo { std::any new_subtype, std::any new_blob) : name(std::move(new_name)), - location(std::move(new_location)), + location(new_location), type(new_type), subtype(std::move(new_subtype)), blob(std::move(new_blob)){}; + + friend bool operator==(const SymbolInfo& left, const SymbolInfo& right); + friend bool operator!=(const SymbolInfo& left, const SymbolInfo& right); }; class SymbolTable; @@ -135,16 +146,6 @@ class SymbolTable { return symbols_; }; - // Temporary functions to allow externally computed text to be saved. - void addCachedOutput(const std::string& text) { - cached_output_ = text; - } - // TODO: Remove after we have the information required to reconstruct the - // plan. - [[nodiscard]] std::string getCachedOutput() const { - return cached_output_; - } - // Add the capability for ::testing::PrintToString to print this. friend std::ostream& operator<<(std::ostream& os, const SymbolTable& result) { os << std::string("{"); @@ -160,12 +161,7 @@ class SymbolTable { return os; } - SymbolInfo kUnknownSymbol = { - "__UNKNOWN__", - Location(ProtoLocation()), - SymbolType::kUnknown, - std::nullopt, - std::nullopt}; + static const SymbolInfo kUnknownSymbol; private: friend SymbolTableIterator; @@ -175,8 +171,6 @@ class SymbolTable { std::vector> symbols_; std::unordered_map symbols_by_name_; std::unordered_map symbols_by_location_; - - std::string cached_output_; }; -} // namespace io::substrait::textplan \ No newline at end of file +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/SymbolTablePrinter.cpp b/src/substrait/textplan/SymbolTablePrinter.cpp index 95dbc95b..96d38e18 100644 --- a/src/substrait/textplan/SymbolTablePrinter.cpp +++ b/src/substrait/textplan/SymbolTablePrinter.cpp @@ -5,17 +5,249 @@ #include #include +#include "substrait/common/Exceptions.h" +#include "substrait/proto/algebra.pb.h" #include "substrait/proto/extensions/extensions.pb.h" +#include "substrait/textplan/Any.h" #include "substrait/textplan/SymbolTable.h" +#include "substrait/textplan/converter/PlanPrinterVisitor.h" namespace io::substrait::textplan { -std::string SymbolTablePrinter::outputToText(const SymbolTable& symbolTable) { - std::string cachedText = symbolTable.getCachedOutput(); - if (!cachedText.empty()) { - return cachedText; +namespace { + +void localFileToText( + const ::substrait::proto::ReadRel::LocalFiles::FileOrFiles& item, + std::stringstream* text) { + switch (item.path_type_case()) { + case ::substrait::proto::ReadRel::LocalFiles::FileOrFiles::kUriPath: + *text << "uri_path: \"" << item.uri_path() << "\""; + break; + case ::substrait::proto::ReadRel::LocalFiles::FileOrFiles::kUriPathGlob: + *text << "uri_path_glob: \"" + item.uri_path_glob() << "\""; + break; + case ::substrait::proto::ReadRel::LocalFiles::FileOrFiles::kUriFile: + *text << "uri_file: \"" << item.uri_file() << "\""; + break; + case ::substrait::proto::ReadRel::LocalFiles::FileOrFiles::kUriFolder: + *text << "uri_folder: \"" << item.uri_folder() << "\""; + break; + case ::substrait::proto::ReadRel::LocalFiles::FileOrFiles:: + PATH_TYPE_NOT_SET: + default: + SUBSTRAIT_UNSUPPORTED( + "Unknown path type " + std::to_string(item.path_type_case()) + + "provided."); + } + if (item.partition_index() != 0) { + *text << " partition_index: " << std::to_string(item.partition_index()); + } + if (item.start() != 0 || item.length() != 0) { + *text << " start: " << std::to_string(item.start()); + *text << " length: " << std::to_string(item.length()); + } + switch (item.file_format_case()) { + case ::substrait::proto::ReadRel::LocalFiles::FileOrFiles::kParquet: + *text << " parquet: {}"; + break; + case ::substrait::proto::ReadRel::LocalFiles::FileOrFiles::kArrow: + *text << " arrow: {}"; + break; + case ::substrait::proto::ReadRel::LocalFiles::FileOrFiles::kOrc: + *text << " orc: {}"; + break; + case ::substrait::proto::ReadRel::LocalFiles::FileOrFiles::kExtension: + SUBSTRAIT_UNSUPPORTED("Extension file format type not yet supported."); + case ::substrait::proto::ReadRel::LocalFiles::FileOrFiles::kDwrf: + *text << " dwrf: {}"; + break; + case ::substrait::proto::ReadRel::LocalFiles::FileOrFiles:: + FILE_FORMAT_NOT_SET: + default: + SUBSTRAIT_UNSUPPORTED( + "Unknown file format type " + + std::to_string(item.file_format_case()) + "provided."); + } +} + +std::string typeToText(const ::substrait::proto::Type& type) { + switch (type.kind_case()) { + case ::substrait::proto::Type::kBool: + if (type.bool_().nullability() == + ::substrait::proto::Type::NULLABILITY_NULLABLE) + return "bool?"; + return "bool"; + case ::substrait::proto::Type::kI8: + if (type.i8().nullability() == + ::substrait::proto::Type::NULLABILITY_NULLABLE) + return "i8?"; + return "i8"; + case ::substrait::proto::Type::kI16: + if (type.i16().nullability() == + ::substrait::proto::Type::NULLABILITY_NULLABLE) + return "i16?"; + return "i16"; + case ::substrait::proto::Type::kI32: + if (type.i32().nullability() == + ::substrait::proto::Type::NULLABILITY_NULLABLE) + return "i32?"; + return "i32"; + case ::substrait::proto::Type::kI64: + if (type.i64().nullability() == + ::substrait::proto::Type::NULLABILITY_NULLABLE) + return "i64?"; + return "i64"; + case ::substrait::proto::Type::kFp32: + if (type.fp32().nullability() == + ::substrait::proto::Type::NULLABILITY_NULLABLE) + return "fp32?"; + return "fp32"; + case ::substrait::proto::Type::kFp64: + if (type.fp64().nullability() == + ::substrait::proto::Type::NULLABILITY_NULLABLE) + return "fp64?"; + return "fp64"; + case ::substrait::proto::Type::kString: + if (type.string().nullability() == + ::substrait::proto::Type::NULLABILITY_NULLABLE) + return "string?"; + return "string"; + case ::substrait::proto::Type::kDecimal: + if (type.string().nullability() == + ::substrait::proto::Type::NULLABILITY_NULLABLE) + return "decimal?"; + return "decimal"; + case ::substrait::proto::Type::kVarchar: + if (type.varchar().nullability() == + ::substrait::proto::Type::NULLABILITY_NULLABLE) + return "varchar?"; + return "varchar"; + case ::substrait::proto::Type::kFixedChar: + if (type.fixed_char().nullability() == + ::substrait::proto::Type::NULLABILITY_NULLABLE) + return "fixedchar?"; + return "fixedchar"; + case ::substrait::proto::Type::kDate: + if (type.date().nullability() == + ::substrait::proto::Type::NULLABILITY_NULLABLE) + return "date?"; + return "date"; + case ::substrait::proto::Type::KIND_NOT_SET: + default: + return "UNSUPPORTED_TYPE"; + } +}; + +std::string relationToText( + const SymbolTable& symbolTable, + const SymbolInfo& info) { + auto relation = ANY_CAST(const ::substrait::proto::Rel*, info.blob); + + PlanPrinterVisitor printer(symbolTable); + return printer.printRelation(info.name, relation); +} + +std::string outputPipelinesSection(const SymbolTable& symbolTable) { + // TODO: Implement. + return ""; +} + +std::string outputRelationsSection(const SymbolTable& symbolTable) { + std::stringstream text; + bool hasPreviousText = false; + for (const SymbolInfo& info : symbolTable) { + if (info.type != SymbolType::kRelation) + continue; + // TODO: Put handling this into the PlanPrinterVisitor. + if (hasPreviousText) + text << "\n"; + text << relationToText(symbolTable, info); + hasPreviousText = true; + } + return text.str(); +} + +std::string outputSchemaSection(const SymbolTable& symbolTable) { + std::stringstream text; + bool hasPreviousText = false; + for (const SymbolInfo& info : symbolTable) { + if (info.type != SymbolType::kSchema) + continue; + + if (hasPreviousText) + text << "\n"; + + const auto* schema = + ANY_CAST(const ::substrait::proto::NamedStruct*, info.blob); + text << "schema " << info.name << " {\n"; + int idx = 0; + while (idx < schema->names_size() && + idx < schema->struct_().types_size()) { + // TODO -- Handle potential whitespace in the names here or elsewhere. + text << " " << schema->names(idx); + text << " " << typeToText(schema->struct_().types(idx)); + text << ";\n"; + ++idx; + } + text << "}\n"; + hasPreviousText = true; + } + return text.str(); +} + +std::string outputSourceSection(const SymbolTable& symbolTable) { + std::stringstream text; + bool hasPreviousText = false; + for (const SymbolInfo& info : symbolTable) { + if (info.type != SymbolType::kSource) + continue; + + if (hasPreviousText) + text << "\n"; + auto subtype = ANY_CAST(SourceType, info.subtype); + switch (subtype) { + case SourceType::kNamedTable: { + auto table = + ANY_CAST(const ::substrait::proto::ReadRel_NamedTable*, info.blob); + text << "source named_table " << info.name << " {\n"; + text << " names = [\n"; + for (const auto& name : table->names()) { + text << " \"" << name << "\",\n"; + } + text << " ]\n"; + text << "}\n"; + hasPreviousText = true; + break; + } + case SourceType::kLocalFiles: { + // TODO: Put handling this into the PlanPrinterVisitor. + auto files = + ANY_CAST(const ::substrait::proto::ReadRel_LocalFiles*, info.blob); + text << "source local_files " << info.name << " {\n"; + text << " items = [\n"; + for (const auto& item : files->items()) { + text << " {"; + localFileToText(item, &text); + text << "}\n"; + } + text << " ]\n"; + text << "}\n"; + hasPreviousText = true; + break; + } + case SourceType::kVirtualTable: + SUBSTRAIT_FAIL("Printing of virtual tables not yet implemented."); + case SourceType::kExtensionTable: + SUBSTRAIT_FAIL("Printing of extension tables not yet implemented."); + case SourceType::kUnknown: + default: + SUBSTRAIT_FAIL("Printing of an unknown read source requested."); + } } + return text.str(); +} +std::string outputFunctionsSection(const SymbolTable& symbolTable) { std::stringstream text; std::map space_names; @@ -26,7 +258,7 @@ std::string SymbolTablePrinter::outputToText(const SymbolTable& symbolTable) { if (info.type != SymbolType::kExtensionSpace) continue; - auto anchor = std::any_cast(info.blob); + auto anchor = ANY_CAST(uint32_t, info.blob); space_names.insert(std::make_pair(anchor, info.name)); } @@ -36,11 +268,11 @@ std::string SymbolTablePrinter::outputToText(const SymbolTable& symbolTable) { if (info.type != SymbolType::kFunction) continue; - auto extension = - std::any_cast<::substrait::proto::extensions:: - SimpleExtensionDeclaration_ExtensionFunction>( - info.blob); - used_spaces.insert(extension.extension_uri_reference()); + auto extension = ANY_CAST( + const ::substrait::proto::extensions:: + SimpleExtensionDeclaration_ExtensionFunction*, + info.blob); + used_spaces.insert(extension->extension_uri_reference()); } // Finally output the extensions by space in the order they were encountered. @@ -56,14 +288,15 @@ std::string SymbolTablePrinter::outputToText(const SymbolTable& symbolTable) { if (info.type != SymbolType::kFunction) continue; - auto extension = - std::any_cast<::substrait::proto::extensions:: - SimpleExtensionDeclaration_ExtensionFunction>( - info.blob); - if (extension.extension_uri_reference() != space) + auto extension = ANY_CAST( + const ::substrait::proto::extensions:: + SimpleExtensionDeclaration_ExtensionFunction*, + info.blob); + if (extension->extension_uri_reference() != space) continue; - text << " function " << extension.name() << " as " << info.name << ";\n"; + text << " function " << extension->name() << " as " << info.name + << ";\n"; } text << "}\n"; } @@ -71,4 +304,53 @@ std::string SymbolTablePrinter::outputToText(const SymbolTable& symbolTable) { return text.str(); } +} // namespace + +std::string SymbolTablePrinter::outputToText(const SymbolTable& symbolTable) { + std::stringstream text; + bool hasPreviousText = false; + + std::string newText = outputPipelinesSection(symbolTable); + if (!newText.empty()) { + text << newText; + hasPreviousText = true; + } + + newText = outputRelationsSection(symbolTable); + if (!newText.empty()) { + if (hasPreviousText) { + text << "\n"; + } + text << newText; + hasPreviousText = true; + } + + newText = outputSchemaSection(symbolTable); + if (!newText.empty()) { + if (hasPreviousText) { + text << "\n"; + } + text << newText; + hasPreviousText = true; + } + + newText = outputSourceSection(symbolTable); + if (!newText.empty()) { + if (hasPreviousText) { + text << "\n"; + } + text << newText; + hasPreviousText = true; + } + + newText = outputFunctionsSection(symbolTable); + if (!newText.empty()) { + if (hasPreviousText) { + text << "\n"; + } + text << newText; + } + return text.str(); +} + } // namespace io::substrait::textplan diff --git a/src/substrait/textplan/converter/BasePlanProtoVisitor.cpp b/src/substrait/textplan/converter/BasePlanProtoVisitor.cpp index 86141109..a00056f0 100644 --- a/src/substrait/textplan/converter/BasePlanProtoVisitor.cpp +++ b/src/substrait/textplan/converter/BasePlanProtoVisitor.cpp @@ -1062,23 +1062,23 @@ std::any BasePlanProtoVisitor::visitExpectedTypeUrl(const std::string& url) { return std::nullopt; } -std::any BasePlanProtoVisitor::visitPlan() { - for (const auto& uri : plan_.extension_uris()) { +std::any BasePlanProtoVisitor::visitPlan(const ::substrait::proto::Plan& plan) { + for (const auto& uri : plan.extension_uris()) { visitExtensionUri(uri); } - for (const auto& extension : plan_.extensions()) { + for (const auto& extension : plan.extensions()) { visitExtension(extension); } - for (const auto& relation : plan_.relations()) { + for (const auto& relation : plan.relations()) { visitPlanRelation(relation); } - if (plan_.has_advanced_extensions()) { - visitAdvancedExtension(plan_.advanced_extensions()); + if (plan.has_advanced_extensions()) { + visitAdvancedExtension(plan.advanced_extensions()); } - for (const auto& url : plan_.expected_type_urls()) { + for (const auto& url : plan.expected_type_urls()) { visitExpectedTypeUrl(url); } return std::nullopt; } -} // namespace io::substrait::textplan \ No newline at end of file +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/converter/BasePlanProtoVisitor.h b/src/substrait/textplan/converter/BasePlanProtoVisitor.h index 167fd318..451ca63c 100644 --- a/src/substrait/textplan/converter/BasePlanProtoVisitor.h +++ b/src/substrait/textplan/converter/BasePlanProtoVisitor.h @@ -13,14 +13,11 @@ namespace io::substrait::textplan { // own functionality. class BasePlanProtoVisitor { public: - BasePlanProtoVisitor() = delete; - - explicit BasePlanProtoVisitor(::substrait::proto::Plan plan) - : plan_(std::move(plan)) {} + BasePlanProtoVisitor() = default; // visit() begins the traversal of the entire plan. - virtual void visit() { - visitPlan(); + virtual void visit(const ::substrait::proto::Plan& plan) { + visitPlan(plan); } protected: @@ -187,10 +184,7 @@ class BasePlanProtoVisitor { const ::substrait::proto::extensions::AdvancedExtension& extension); virtual std::any visitExpectedTypeUrl(const std::string& url); - virtual std::any visitPlan(); - - private: - ::substrait::proto::Plan plan_; + virtual std::any visitPlan(const ::substrait::proto::Plan& plan); }; -} // namespace io::substrait::textplan \ No newline at end of file +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp b/src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp index c389776e..2be0c93b 100644 --- a/src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp +++ b/src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp @@ -7,6 +7,7 @@ #include #include "substrait/common/Exceptions.h" +#include "substrait/proto/ProtoUtils.h" #include "substrait/proto/algebra.pb.h" #include "substrait/proto/plan.pb.h" #include "substrait/textplan/Location.h" @@ -40,10 +41,10 @@ std::any InitialPlanProtoVisitor::visitExtension( shortName(extension.extension_function().name())); symbol_table_->defineSymbol( unique_name, - Location(ProtoLocation()), + Location((::google::protobuf::Message*)&extension.extension_function()), SymbolType::kFunction, std::nullopt, - extension.extension_function()); + &extension.extension_function()); return std::nullopt; } @@ -51,11 +52,108 @@ std::any InitialPlanProtoVisitor::visitExtensionUri( const ::substrait::proto::extensions::SimpleExtensionURI& uri) { symbol_table_->defineSymbol( uri.uri(), - Location(ProtoLocation()), + Location((::google::protobuf::Message*)&uri), SymbolType::kExtensionSpace, std::nullopt, uri.extension_uri_anchor()); return std::nullopt; } -} // namespace io::substrait::textplan \ No newline at end of file +std::any InitialPlanProtoVisitor::visitPlanRelation( + const ::substrait::proto::PlanRel& relation) { + BasePlanProtoVisitor::visitPlanRelation(relation); + std::string name = PlanRelTypeCaseName(relation.rel_type_case()); + auto unique_name = symbol_table_->getUniqueName(name); + symbol_table_->defineSymbol( + unique_name, + Location((google::protobuf::Message*)&relation), + SymbolType::kPlanRelation, + std::nullopt, + &relation); + return std::nullopt; +} + +std::any InitialPlanProtoVisitor::visitRelation( + const ::substrait::proto::Rel& relation) { + std::string name = RelTypeCaseName(relation.rel_type_case()); + BasePlanProtoVisitor::visitRelation(relation); + auto unique_name = symbol_table_->getUniqueName(name); + symbol_table_->defineSymbol( + unique_name, + Location((google::protobuf::Message*)&relation), + SymbolType::kRelation, + relation.rel_type_case(), + &relation); + return std::nullopt; +} + +std::any InitialPlanProtoVisitor::visitRelationRoot( + const ::substrait::proto::RelRoot& relation) { + BasePlanProtoVisitor::visitRelationRoot(relation); + return std::nullopt; +} + +std::any InitialPlanProtoVisitor::visitReadRelation( + const ::substrait::proto::ReadRel& relation) { + if (relation.has_base_schema()) { + const std::string& name = symbol_table_->getUniqueName("schema"); + symbol_table_->defineSymbol( + name, + Location((google::protobuf::Message*)&relation.base_schema()), + SymbolType::kSchema, + std::nullopt, + &relation.base_schema()); + } + + return BasePlanProtoVisitor::visitReadRelation(relation); +} + +std::any InitialPlanProtoVisitor::visitVirtualTable( + const ::substrait::proto::ReadRel_VirtualTable& table) { + const auto& unique_name = symbol_table_->getUniqueName("virtual"); + symbol_table_->defineSymbol( + unique_name, + Location((google::protobuf::Message*)&table), + SymbolType::kSource, + SourceType::kVirtualTable, + &table); + return BasePlanProtoVisitor::visitVirtualTable(table); +} + +std::any InitialPlanProtoVisitor::visitLocalFiles( + const ::substrait::proto::ReadRel_LocalFiles& local) { + const auto& unique_name = symbol_table_->getUniqueName("local"); + symbol_table_->defineSymbol( + unique_name, + Location((google::protobuf::Message*)&local), + SymbolType::kSource, + SourceType::kLocalFiles, + &local); + return BasePlanProtoVisitor::visitLocalFiles(local); +} + +std::any InitialPlanProtoVisitor::visitNamedTable( + const ::substrait::proto::ReadRel_NamedTable& table) { + const auto& unique_name = symbol_table_->getUniqueName("named"); + symbol_table_->defineSymbol( + unique_name, + Location((google::protobuf::Message*)&table), + SymbolType::kSource, + SourceType::kNamedTable, + &table); + return BasePlanProtoVisitor::visitNamedTable(table); +} + +std::any InitialPlanProtoVisitor::visitExtensionTable( + const ::substrait::proto::ReadRel_ExtensionTable& table) { + const auto& unique_name = symbol_table_->getUniqueName("extensiontable"); + symbol_table_->defineSymbol( + unique_name, + Location((google::protobuf::Message*)&table), + SymbolType::kSource, + SourceType::kExtensionTable, + &table); + return BasePlanProtoVisitor::visitExtensionTable(table); +} + +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/converter/InitialPlanProtoVisitor.h b/src/substrait/textplan/converter/InitialPlanProtoVisitor.h index c85b030f..cf2de907 100644 --- a/src/substrait/textplan/converter/InitialPlanProtoVisitor.h +++ b/src/substrait/textplan/converter/InitialPlanProtoVisitor.h @@ -15,8 +15,7 @@ namespace io::substrait::textplan { // process which identifies the prominent symbols and gives them names. class InitialPlanProtoVisitor : public BasePlanProtoVisitor { public: - explicit InitialPlanProtoVisitor(const ::substrait::proto::Plan& plan) - : BasePlanProtoVisitor(plan) { + explicit InitialPlanProtoVisitor() : BasePlanProtoVisitor() { symbol_table_ = std::make_shared(); error_listener_ = std::make_shared(); }; @@ -37,8 +36,25 @@ class InitialPlanProtoVisitor : public BasePlanProtoVisitor { const ::substrait::proto::extensions::SimpleExtensionDeclaration& extension) override; + std::any visitPlanRelation( + const ::substrait::proto::PlanRel& relation) override; + std::any visitRelation(const ::substrait::proto::Rel& relation) override; + std::any visitRelationRoot( + const ::substrait::proto::RelRoot& relation) override; + std::any visitReadRelation( + const ::substrait::proto::ReadRel& relation) override; + + std::any visitVirtualTable( + const ::substrait::proto::ReadRel_VirtualTable& table) override; + std::any visitLocalFiles( + const ::substrait::proto::ReadRel_LocalFiles& local) override; + std::any visitNamedTable( + const ::substrait::proto::ReadRel_NamedTable& table) override; + std::any visitExtensionTable( + const ::substrait::proto::ReadRel_ExtensionTable& table) override; + std::shared_ptr symbol_table_; std::shared_ptr error_listener_; }; -} // namespace io::substrait::textplan \ No newline at end of file +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/converter/ParseBinary.cpp b/src/substrait/textplan/converter/ParseBinary.cpp index b912edb4..af18981e 100644 --- a/src/substrait/textplan/converter/ParseBinary.cpp +++ b/src/substrait/textplan/converter/ParseBinary.cpp @@ -9,14 +9,14 @@ namespace io::substrait::textplan { ParseResult parseBinaryPlan(const ::substrait::proto::Plan& plan) { - InitialPlanProtoVisitor visitor(plan); - visitor.visit(); + InitialPlanProtoVisitor visitor; + visitor.visit(plan); auto symbols = visitor.getSymbolTable(); auto syntaxErrors = visitor.getErrorListener()->getErrorMessages(); std::vector semanticErrors; - PlanPrinterVisitor printer(plan, *symbols); - printer.visit(); + PlanPrinterVisitor printer(*symbols); + printer.visit(plan); auto moreErrors = printer.getErrorListener()->getErrorMessages(); semanticErrors.insert( semanticErrors.end(), moreErrors.begin(), moreErrors.end()); diff --git a/src/substrait/textplan/converter/PlanPrinterVisitor.cpp b/src/substrait/textplan/converter/PlanPrinterVisitor.cpp index 3559709f..59da0144 100644 --- a/src/substrait/textplan/converter/PlanPrinterVisitor.cpp +++ b/src/substrait/textplan/converter/PlanPrinterVisitor.cpp @@ -4,25 +4,189 @@ #include #include +#include #include +#include "substrait/proto/ProtoUtils.h" #include "substrait/proto/algebra.pb.h" -#include "substrait/proto/plan.pb.h" +#include "substrait/textplan/Any.h" namespace io::substrait::textplan { -void PlanPrinterVisitor::visit() { - symbol_table_->addCachedOutput(std::any_cast(visitPlan())); +std::string PlanPrinterVisitor::printRelation( + const std::string& symbolName, + const ::substrait::proto::Rel* relation) { + std::stringstream text; + + text << RelTypeCaseName(relation->rel_type_case()) << " relation " + << symbolName << " {\n"; + auto symbol = symbol_table_->lookupSymbolByLocation( + Location((google::protobuf::Message*)&relation)); + if (symbol != SymbolTable::kUnknownSymbol) { + text << " source " << symbol.name << ";\n"; + } + auto result = this->visitRelation(*relation); + if (result.type() != typeid(std::string)) { + return "ERROR: Relation subtype " + + RelTypeCaseName((relation->rel_type_case())) + " not supported!\n"; + } + text << ANY_CAST(std::string, result); + text << "}\n"; + + return text.str(); +} + +std::any PlanPrinterVisitor::visitAggregateFunction( + const ::substrait::proto::AggregateFunction& function) { + return std::string("AF-NOT-YET-IMPLEMENTED"); +} + +std::any PlanPrinterVisitor::visitExpression( + const ::substrait::proto::Expression& expression) { + return std::string("EXPR-NOT-YET-IMPLEMENTED"); +} + +std::any PlanPrinterVisitor::visitMaskExpression( + const ::substrait::proto::Expression::MaskExpression& expression) { + return std::string("MASKEXPR-NOT-YET-IMPLEMENTED"); +} + +std::any PlanPrinterVisitor::visitReadRelation( + const ::substrait::proto::ReadRel& relation) { + std::stringstream text; + if (relation.has_base_schema()) { + const auto& symbol = symbol_table_->lookupSymbolByLocation( + Location((google::protobuf::Message*)&relation)); + if (symbol != SymbolTable::kUnknownSymbol) { + text << " base_schema " << symbol.name << ";\n"; + } + } + if (relation.has_filter()) { + text << " filter " + << ANY_CAST(std::string, visitExpression(relation.filter())) + ";\n"; + } + if (relation.has_best_effort_filter()) { + text << " filter " + << ANY_CAST( + std::string, visitExpression(relation.best_effort_filter())) + << ";\n"; + } + if (relation.has_projection()) { + text << " projection " + << ANY_CAST(std::string, visitMaskExpression(relation.projection())) + << ";\n"; + } + + return text.str(); +} + +std::any PlanPrinterVisitor::visitFilterRelation( + const ::substrait::proto::FilterRel& relation) { + std::stringstream text; + if (relation.has_condition()) { + text << " condition " + << ANY_CAST(std::string, visitExpression(relation.condition())) + << ";\n"; + } + return text.str(); +} + +std::any PlanPrinterVisitor::visitFetchRelation( + const ::substrait::proto::FetchRel& relation) { + std::stringstream text; + if (relation.offset() != 0) { + text << " offset " << std::to_string(relation.offset()) << ";\n"; + } + text << " count " << std::to_string(relation.count()) << ";\n"; + return text.str(); +} + +std::any PlanPrinterVisitor::visitAggregateRelation( + const ::substrait::proto::AggregateRel& relation) { + std::stringstream text; + for (const auto& group : relation.groupings()) { + for (const auto& expr : group.grouping_expressions()) { + text << " grouping " << ANY_CAST(std::string, visitExpression(expr)) + << ";\n"; + } + } + for (const auto& measure : relation.measures()) { + if (!measure.has_measure()) + continue; + text << " measure {\n"; + text << " measure " + << ANY_CAST(std::string, visitAggregateFunction(measure.measure())) + << ";\n"; + if (measure.has_filter()) { + text << " filter " + + ANY_CAST(std::string, visitExpression(measure.filter())) + << ";\n"; + } + text << " }\n"; + } + return text.str(); } -std::any PlanPrinterVisitor::visitPlan() { - return std::string(); +std::any PlanPrinterVisitor::visitSortRelation( + const ::substrait::proto::SortRel& relation) { + std::stringstream text; + for (const auto& sort : relation.sorts()) { + text << " sort " << ANY_CAST(std::string, visitExpression(sort.expr())); + switch (sort.sort_kind_case()) { + case ::substrait::proto::SortField::kDirection: + switch (sort.direction()) { + case ::substrait::proto:: + SortField_SortDirection_SORT_DIRECTION_ASC_NULLS_FIRST: + text << " by ASC_NULLS_FIRST"; + break; + case ::substrait::proto:: + SortField_SortDirection_SORT_DIRECTION_ASC_NULLS_LAST: + text << " by ASC_NULLS_LAST"; + break; + case ::substrait::proto:: + SortField_SortDirection_SORT_DIRECTION_DESC_NULLS_FIRST: + text << " by DESC_NULLS_FIRST"; + break; + case ::substrait::proto:: + SortField_SortDirection_SORT_DIRECTION_DESC_NULLS_LAST: + text << " by DESC_NULLS_LAST"; + break; + case ::substrait::proto:: + SortField_SortDirection_SORT_DIRECTION_CLUSTERED: + text << " by CLUSTERED"; + break; + case ::substrait::proto:: + SortField_SortDirection_SORT_DIRECTION_UNSPECIFIED: + default: + break; + } + break; + case ::substrait::proto::SortField::kComparisonFunctionReference: { + auto field = symbol_table_->nthSymbolByType( + sort.comparison_function_reference(), SymbolType::kFunction); + if (field == SymbolTable::kUnknownSymbol) { + return field.name; + } else { + return "functionref#" + + std::to_string(sort.comparison_function_reference()); + } + } + case ::substrait::proto::SortField::SORT_KIND_NOT_SET: + break; + } + text << ";\n"; + } + return text.str(); } -std::any PlanPrinterVisitor::visitRelationRoot( - const ::substrait::proto::RelRoot& relation) { - BasePlanProtoVisitor::visitRelationRoot(relation); - return std::nullopt; +std::any PlanPrinterVisitor::visitProjectRelation( + const ::substrait::proto::ProjectRel& relation) { + std::stringstream text; + for (const auto& expr : relation.expressions()) { + text << " expression " << ANY_CAST(std::string, visitExpression(expr)) + << ";\n"; + } + return text.str(); } } // namespace io::substrait::textplan diff --git a/src/substrait/textplan/converter/PlanPrinterVisitor.h b/src/substrait/textplan/converter/PlanPrinterVisitor.h index 03f67435..ec787c25 100644 --- a/src/substrait/textplan/converter/PlanPrinterVisitor.h +++ b/src/substrait/textplan/converter/PlanPrinterVisitor.h @@ -14,10 +14,7 @@ namespace io::substrait::textplan { class PlanPrinterVisitor : public BasePlanProtoVisitor { public: // PlanPrinterVisitor takes ownership of the provided symbol table. - PlanPrinterVisitor( - const ::substrait::proto::Plan& plan, - const SymbolTable& symbol_table) - : BasePlanProtoVisitor(plan) { + explicit PlanPrinterVisitor(const SymbolTable& symbol_table) { symbol_table_ = std::make_shared(symbol_table); error_listener_ = std::make_shared(); }; @@ -31,13 +28,31 @@ class PlanPrinterVisitor : public BasePlanProtoVisitor { return error_listener_; }; - void visit() override; + std::string printRelation( + const std::string& symbolName, + const ::substrait::proto::Rel* relation); private: - std::any visitPlan() override; - - std::any visitRelationRoot( - const ::substrait::proto::RelRoot& relation) override; + std::any visitAggregateFunction( + const ::substrait::proto::AggregateFunction& function) override; + std::any visitExpression( + const ::substrait::proto::Expression& expression) override; + std::any visitMaskExpression( + const ::substrait::proto::Expression::MaskExpression& expression) + override; + + std::any visitReadRelation( + const ::substrait::proto::ReadRel& relation) override; + std::any visitFilterRelation( + const ::substrait::proto::FilterRel& relation) override; + std::any visitFetchRelation( + const ::substrait::proto::FetchRel& relation) override; + std::any visitAggregateRelation( + const ::substrait::proto::AggregateRel& relation) override; + std::any visitSortRelation( + const ::substrait::proto::SortRel& relation) override; + std::any visitProjectRelation( + const ::substrait::proto::ProjectRel& relation) override; std::shared_ptr symbol_table_; std::shared_ptr error_listener_; diff --git a/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp b/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp index 06229e7b..172b8efc 100644 --- a/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp +++ b/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp @@ -3,12 +3,14 @@ #include #include -#include "substrait/textplan/converter/InitialPlanProtoVisitor.h" #include "substrait/textplan/converter/LoadBinary.h" #include "substrait/textplan/converter/ParseBinary.h" #include "substrait/textplan/tests/ParseResultMatchers.h" namespace io::substrait::textplan { + +using ::testing::Eq; + namespace { class TestCase { @@ -32,93 +34,163 @@ std::vector GetTestCases() { { "empty plan", "", - SerializesTo(""), + WhenSerialized(Eq("")), }, { "empty extension space", - "extension_uris: {\n" - " extension_uri_anchor: 42;\n" - " uri: \"http://life@everything\",\n" - "}", - SerializesTo(""), + R"(extension_uris: { + extension_uri_anchor: 42; + uri: "http://life@everything", + })", + WhenSerialized(Eq("")), }, { "used extension space", - "extension_uris: {\n" - " extension_uri_anchor: 42;\n" - " uri: \"http://life@everything\",\n" - "}\n" - "extensions: {\n" - " extension_function: {\n" - " extension_uri_reference: 42\n" - " function_anchor: 5\n" - " name: \"sum:fp64_fp64\"\n" - " }\n" - "}\n", - SerializesTo("extension_space http://life@everything {\n" - " function sum:fp64_fp64 as sum;\n" - "}\n"), + R"(extension_uris: { + extension_uri_anchor: 42; + uri: "http://life@everything", + } + extensions: { + extension_function: { + extension_uri_reference: 42 + function_anchor: 5 + name: "sum:fp64_fp64" + } + })", + WhenSerialized(EqSquashingWhitespace( + R"(extension_space http://life@everything { + function sum:fp64_fp64 as sum; + })")), }, { "seven extensions, no uris", - "extensions: {\n" - " extension_function: {\n" - " extension_uri_reference: 0\n" - " function_anchor: 4\n" - " name: \"lte:fp64_fp64\"\n" - " }\n" - "}\n" - "extensions: {\n" - " extension_function: {\n" - " extension_uri_reference: 0\n" - " function_anchor: 5\n" - " name: \"sum:fp64_fp64\"\n" - " }\n" - "}\n" - "extensions: {\n" - " extension_function: {\n" - " extension_uri_reference: 0\n" - " function_anchor: 3\n" - " name: \"lt:fp64_fp64\"\n" - " }\n" - "}\n" - "extensions: {\n" - " extension_function: {\n" - " extension_uri_reference: 0\n" - " function_anchor: 0\n" - " name: \"is_not_null:fp64\"\n" - " }\n" - "}\n" - "extensions: {\n" - " extension_function: {\n" - " extension_uri_reference: 0\n" - " function_anchor: 1\n" - " name: \"and:bool_bool\"\n" - " }\n" - "}\n" - "extensions: {\n" - " extension_function: {\n" - " extension_uri_reference: 0\n" - " function_anchor: 2\n" - " name: \"gte:fp64_fp64\"\n" - " }\n" - "}\n" - "extensions: {\n" - " extension_function: {\n" - " extension_uri_reference: 0\n" - " function_anchor: 6\n" - " name: \"multiply:opt_fp64_fp64\"\n" - " }\n" - "}\n", - SerializesTo("extension_space {\n" - " function lte:fp64_fp64 as lte;\n" - " function sum:fp64_fp64 as sum;\n" - " function lt:fp64_fp64 as lt;\n" - " function is_not_null:fp64 as is_not_null;\n" - " function and:bool_bool as and;\n" - " function gte:fp64_fp64 as gte;\n" - " function multiply:opt_fp64_fp64 as multiply;\n" - "}\n"), + R"(extensions: { + extension_function: { + extension_uri_reference: 0 + function_anchor: 4 + name: "lte:fp64_fp64" + } + } + extensions: { + extension_function: { + extension_uri_reference: 0 + function_anchor: 5 + name: "sum:fp64_fp64" + } + } + extensions: { + extension_function: { + extension_uri_reference: 0 + function_anchor: 3 + name: "lt:fp64_fp64" + } + } + extensions: { + extension_function: { + extension_uri_reference: 0 + function_anchor: 0 + name: "is_not_null:fp64" + } + } + extensions: { + extension_function: { + extension_uri_reference: 0 + function_anchor: 1 + name: "and:bool_bool" + } + } + extensions: { + extension_function: { + extension_uri_reference: 0 + function_anchor: 2 + name: "gte:fp64_fp64" + } + } + extensions: { + extension_function: { + extension_uri_reference: 0 + function_anchor: 6 + name: "multiply:opt_fp64_fp64" + } + })", + WhenSerialized(EqSquashingWhitespace( + R"(extension_space { + function lte:fp64_fp64 as lte; + function sum:fp64_fp64 as sum; + function lt:fp64_fp64 as lt; + function is_not_null:fp64 as is_not_null; + function and:bool_bool as and; + function gte:fp64_fp64 as gte; + function multiply:opt_fp64_fp64 as multiply; + })")), + }, + { + "read local files", + R"(relations: { + root: { + input: { + read: { + local_files { + items { + partition_index: 0 + start: 0 + length: 3719 + uri_file: "/mock_lineitem.orc" + orc {} + } + } + } + } + } + })", + AllOf( + HasSymbols({"local", "read", "root"}), + WhenSerialized(EqSquashingWhitespace( + R"(read relation read { + } + + source local_files local { + items = [ + {uri_file: "/mock_lineitem.orc" start: 0 length: 3719 orc: {}} + ] + })"))), + }, + { + "read named table", + "relations: { root: { input: { read: { base_schema {} named_table { names: \"#2\" } } } } }", + AllOf( + HasSymbols({"schema", "named", "read", "root"}), + WhenSerialized(EqSquashingWhitespace( + R"(read relation read { + } + + schema schema { + } + + source named_table named { + names = [ + "#2", + ] + })"))), + }, + { + "single three node pipeline", + "relations: { root: { input: { project: { input { read: { local_files {} } } } } } }", + HasSymbols({"local", "read", "project", "root"}), + }, + { + "two identical three node pipelines", + "relations: { root: { input: { project: { input { read: { local_files {} } } } } } }" + "relations: { root: { input: { project: { input { read: { local_files {} } } } } } }", + HasSymbols( + {"local", + "read", + "project", + "root", + "local2", + "read2", + "project2", + "root2"}), }, }; return cases; @@ -175,8 +247,17 @@ TEST_F(BinaryToTextPlanConversionTest, loadFromJSON) { "and", "gte", "multiply", + + "schema", + "local", + + "read", + "filter", + "project", + "aggregate", + "root", })); } } // namespace -} // namespace io::substrait::textplan \ No newline at end of file +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/converter/tests/CMakeLists.txt b/src/substrait/textplan/converter/tests/CMakeLists.txt index 37596780..8bad053d 100644 --- a/src/substrait/textplan/converter/tests/CMakeLists.txt +++ b/src/substrait/textplan/converter/tests/CMakeLists.txt @@ -24,4 +24,4 @@ add_custom_command( "${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/tests/data/q6_first_stage.json" ) -message(STATUS "test data will be here: ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/tests/data/q6_first_stage.json") \ No newline at end of file +message(STATUS "test data will be here: ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/tests/data/q6_first_stage.json") diff --git a/src/substrait/textplan/tests/CMakeLists.txt b/src/substrait/textplan/tests/CMakeLists.txt index 62a59b7c..fb0bbb8a 100644 --- a/src/substrait/textplan/tests/CMakeLists.txt +++ b/src/substrait/textplan/tests/CMakeLists.txt @@ -9,3 +9,15 @@ target_link_libraries( parse_result_matchers parse_result gmock) + +add_test_case( + symbol_table_test + SOURCES + SymbolTableTest.cpp + EXTRA_LINK_LIBS + symbol_table + substrait_proto + fmt::fmt-header-only + gmock + gtest + gtest_main) diff --git a/src/substrait/textplan/tests/ParseResultMatchers.cpp b/src/substrait/textplan/tests/ParseResultMatchers.cpp index 4b117b1b..f6d9d0ff 100644 --- a/src/substrait/textplan/tests/ParseResultMatchers.cpp +++ b/src/substrait/textplan/tests/ParseResultMatchers.cpp @@ -25,6 +25,45 @@ std::vector symbolNames( return names; } +bool StringEqSquashingWhitespace( + const std::string& have, + const std::string& expected) { + auto atHave = have.begin(); + auto atExpected = expected.begin(); + while (atHave != have.end() && atExpected != expected.end()) { + if (isspace(*atExpected)) { + if (!isspace(*atHave)) { + return false; + } + // Have a match, consume all remaining space. + do { + atExpected++; + } while (atExpected != expected.end() && isspace(*atExpected)); + do { + atHave++; + } while (atHave != have.end() && isspace(*atHave)); + continue; + } + if (*atHave != *atExpected) { + return false; + } + atHave++; + atExpected++; + } + // For convenience consume any trailing whitespace on both sides. + if (atExpected != expected.end()) { + do { + atExpected++; + } while (atExpected != expected.end() && isspace(*atExpected)); + } + if (atHave != have.end()) { + do { + atHave++; + } while (atHave != have.end() && isspace(*atHave)); + } + return atHave == have.end() && atExpected == expected.end(); +} + } // namespace class ParsesOkMatcher { @@ -117,39 +156,39 @@ ::testing::Matcher HasSymbols( return HasSymbolsMatcher(std::move(expected_symbols)); } -class SerializesToMatcher { +class WhenSerializedMatcher { public: using is_gtest_matcher = void; - explicit SerializesToMatcher(std::string expected_result) - : expected_result_(std::move(expected_result)) {} + explicit WhenSerializedMatcher( + ::testing::Matcher string_matcher) + : string_matcher_(std::move(string_matcher)) {} - bool MatchAndExplain(const ParseResult& result, std::ostream* listener) - const { + bool MatchAndExplain( + const ParseResult& result, + ::testing::MatchResultListener* listener) const { std::string outputText = SymbolTablePrinter::outputToText(result.getSymbolTable()); - if (listener) { - *listener << "has output text \"" << outputText << "\""; - } - return outputText == expected_result_; + return MatchPrintAndExplain(outputText, string_matcher_, listener); } - void DescribeTo(std::ostream* os) const { - *os << "reparses to: " << ::testing::PrintToString(expected_result_); + void DescribeTo(::std::ostream* os) const { + *os << "matches after serializing "; + string_matcher_.DescribeTo(os); } - void DescribeNegationTo(std::ostream* os) const { - *os << "does not reparse to: " - << ::testing::PrintToString(expected_result_); + void DescribeNegationTo(::std::ostream* os) const { + *os << "does not match after serializing "; + string_matcher_.DescribeTo(os); } private: - const std::string expected_result_; + ::testing::Matcher string_matcher_; }; -::testing::Matcher SerializesTo( - std::string expected_symbols) { - return SerializesToMatcher(std::move(expected_symbols)); +::testing::Matcher WhenSerialized( + ::testing::Matcher string_matcher) { + return WhenSerializedMatcher(std::move(string_matcher)); } class HasErrorsMatcher { @@ -183,4 +222,35 @@ ::testing::Matcher HasErrors( return HasErrorsMatcher(std::move(expected_errors)); } -} // namespace io::substrait::textplan \ No newline at end of file +class EqSquashingWhitespaceMatcher { + public: + using is_gtest_matcher = void; + + explicit EqSquashingWhitespaceMatcher(std::string expected_string) + : expected_string_(std::move(expected_string)) {} + + bool MatchAndExplain(const std::string& str, std::ostream* /* listener */) + const { + return StringEqSquashingWhitespace(str, expected_string_); + } + + void DescribeTo(std::ostream* os) const { + *os << "equals squashing whitespace " + << ::testing::PrintToString(expected_string_); + } + + void DescribeNegationTo(std::ostream* os) const { + *os << "does not equal squashing whitespace " + << ::testing::PrintToString(expected_string_); + } + + private: + std::string expected_string_; +}; + +::testing::Matcher EqSquashingWhitespace( + std::string expected_string) { + return EqSquashingWhitespaceMatcher(std::move(expected_string)); +} + +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/tests/ParseResultMatchers.h b/src/substrait/textplan/tests/ParseResultMatchers.h index 82f841d7..96f4707d 100644 --- a/src/substrait/textplan/tests/ParseResultMatchers.h +++ b/src/substrait/textplan/tests/ParseResultMatchers.h @@ -14,10 +14,15 @@ namespace io::substrait::textplan { [[maybe_unused]] ::testing::Matcher HasSymbols( std::vector expected_symbols); -[[maybe_unused]] ::testing::Matcher SerializesTo( - std::string expected_result); +[[maybe_unused]] ::testing::Matcher WhenSerialized( + ::testing::Matcher string_matcher); [[maybe_unused]] ::testing::Matcher HasErrors( std::vector expected_errors); +// Matches strings ignoring differences in kinds of whitespace (as long as they +// are present) and ignoring trailing whitespace as well. +[[maybe_unused]] ::testing::Matcher EqSquashingWhitespace( + std::string expected_string); + } // namespace io::substrait::textplan diff --git a/src/substrait/textplan/tests/SymbolTableTest.cpp b/src/substrait/textplan/tests/SymbolTableTest.cpp new file mode 100644 index 00000000..68e3aa4d --- /dev/null +++ b/src/substrait/textplan/tests/SymbolTableTest.cpp @@ -0,0 +1,143 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#include "substrait/textplan/SymbolTable.h" +#include "substrait/proto/plan.pb.h" +#include "substrait/textplan/Any.h" +#include "substrait/textplan/Location.h" + +#include +#include + +namespace io::substrait::textplan { +namespace { + +class SymbolTableTest : public ::testing::Test { + public: + SymbolTableTest() : UnspecifiedLocation(Location::kUnknownLocation){}; + + protected: + static std::vector symbolNames( + const std::vector>& symbols) { + std::vector names; + for (const auto& symbol : symbols) { + names.push_back(symbol->name); + } + return names; + } + + static SymbolTable createSimpleTable(::substrait::proto::Plan* plan) { + auto* ptr1 = plan->add_relations(); + auto* ptr2 = plan->add_extension_uris(); + auto* ptr3 = plan->add_extensions(); + + SymbolTable table; + table.defineSymbol( + "symbol1", + Location(ptr1), + SymbolType::kUnknown, + RelationType::kUnknown, + ptr1); + table.defineSymbol( + "symbol2", + Location(ptr2), + SymbolType::kUnknown, + RelationType::kUnknown, + ptr2); + table.defineSymbol( + "symbol3", + Location(ptr3), + SymbolType::kUnknown, + RelationType::kUnknown, + ptr3); + return table; + } + + const Location UnspecifiedLocation; +}; + +TEST_F(SymbolTableTest, DuplicateSymbolsNotDetected) { + SymbolTable table; + table.defineSymbol( + "a", + UnspecifiedLocation, + SymbolType::kUnknown, + RelationType::kUnknown, + nullptr); + table.defineSymbol( + "a", + UnspecifiedLocation, + SymbolType::kUnknown, + RelationType::kUnknown, + nullptr); + + ASSERT_THAT( + symbolNames(table.getSymbols()), ::testing::ElementsAre("a", "a")); +} + +TEST_F(SymbolTableTest, DuplicateSymbolsHandledByUnique) { + SymbolTable table; + table.defineUniqueSymbol( + "a", + UnspecifiedLocation, + SymbolType::kUnknown, + RelationType::kUnknown, + nullptr); + table.defineUniqueSymbol( + "a", + UnspecifiedLocation, + SymbolType::kUnknown, + RelationType::kUnknown, + nullptr); + + ASSERT_THAT( + symbolNames(table.getSymbols()), ::testing::ElementsAre("a", "a2")); +} + +TEST_F(SymbolTableTest, LocationsUnchangedAfterCopy) { + ::substrait::proto::Plan plan; + SymbolTable table = createSimpleTable(&plan); + auto* ptr1 = &plan.relations(0); + auto* ptr2 = plan.mutable_extension_uris(0); + auto* ptr3 = &plan.extensions(0); + + SymbolTable table2 = table; + auto symbols = table2.getSymbols(); + ASSERT_THAT( + symbolNames(symbols), + ::testing::ElementsAre("symbol1", "symbol2", "symbol3")); + + ASSERT_THAT( + ANY_CAST(::substrait::proto::PlanRel*, symbols[0]->blob), + ::testing::Eq(ptr1)); + ASSERT_THAT( + ANY_CAST( + ::substrait::proto::extensions::SimpleExtensionURI*, + symbols[1]->blob), + ::testing::Eq(ptr2)); + ASSERT_THAT( + ANY_CAST( + ::substrait::proto::extensions::SimpleExtensionDeclaration*, + symbols[2]->blob), + ::testing::Eq(ptr3)); + + ASSERT_THAT(symbols[0]->location, ::testing::Eq(symbols[0]->location)); + ASSERT_THAT( + symbols[0]->location, + ::testing::Not(::testing::Eq(symbols[1]->location))); + ASSERT_THAT( + symbols[0]->location, + ::testing::Not(::testing::Eq(symbols[2]->location))); + ASSERT_THAT( + symbols[1]->location, + ::testing::Not(::testing::Eq(symbols[2]->location))); + + ASSERT_THAT( + table.getSymbols()[0]->location, ::testing::Eq(symbols[0]->location)); + ASSERT_THAT( + table.getSymbols()[1]->location, ::testing::Eq(symbols[1]->location)); + ASSERT_THAT( + table.getSymbols()[2]->location, ::testing::Eq(symbols[2]->location)); +} + +} // namespace +} // namespace io::substrait::textplan