diff --git a/.clang-tidy b/.clang-tidy new file mode 100644 index 00000000..930d78ff --- /dev/null +++ b/.clang-tidy @@ -0,0 +1,41 @@ +# SPDX-License-Identifier: Apache-2.0 +--- +Checks: "*, \ + -abseil-*, \ + -cert-env33-c, \ + -cert-err58-cpp, \ + -clang-diagnostic-padded, \ + -clang-analyzer-deadcode.DeadStores, \ + -cppcoreguidelines-avoid-magic-numbers, \ + -cppcoreguidelines-pro-bounds-constant-array-index, \ + -cppcoreguidelines-pro-bounds-pointer-arithmetic, \ + -cppcoreguidelines-pro-type-reinterpret-cast, \ + -cppcoreguidelines-no-malloc, \ + -cppcoreguidelines-owning-memory, \ + -cppcoreguidelines-macro-usage, \ + -cppcoreguidelines-pro-type-vararg, \ + -cppcoreguidelines-pro-bounds-array-to-pointer-decay, \ + -fuchsia-overloaded-operator, \ + -fuchsia-default-arguments, \ + -fuchsia-multiple-inheritance, \ + -fuchsia-default-arguments-calls, \ + -fuchsia-trailing-return, \ + -fuchsia-default-arguments-declarations, \ + -fuchsia-statically-constructed-objects, \ + -google-runtime-references, \ + -google-runtime-int, \ + -google-explicit-constructor, \ + -hicpp-no-malloc, \ + -hicpp-vararg, \ + -hicpp-invalid-access-moved, \ + -hicpp-no-array-decay, \ + -hicpp-signed-bitwise, \ + -llvm-header-guard, \ + -modernize-use-trailing-return-type, \ + -misc-definitions-in-headers, \ + -misc-unused-alias-decls, \ + -modernize-concat-nested-namespaces, \ + -modernize-raw-string-literal, \ + -readability-magic-numbers" +HeaderFilterRegex: "" +... 
\ No newline at end of file diff --git a/.github/workflows/build_test.yml b/.github/workflows/build_test.yml index c1b14dac..b9e5a44a 100644 --- a/.github/workflows/build_test.yml +++ b/.github/workflows/build_test.yml @@ -1,3 +1,4 @@ +# SPDX-License-Identifier: Apache-2.0 name: Ubuntu Build & Test on: @@ -7,6 +8,22 @@ on: branches: [main] jobs: + check: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + with: + submodules: recursive + - name: Setup Ubuntu + run: ./scripts/setup-ubuntu.sh + - name: Check License Header + uses: apache/skywalking-eyes/header@v0.4.0 + - name: Check CMake files + run: find . \( -name '*.cmake' -o -name 'CMakeLists.txt' \) -exec cmake-format $* {} + + - name: Clang-tidy + run: python3 scripts/run-clang-tidy.py "." "build" "third_party,scripts,docker,cmake_modules" "h,hpp,cc,cpp" + - run: mkdir build build: runs-on: ubuntu-latest diff --git a/.licenserc.yaml b/.licenserc.yaml new file mode 100644 index 00000000..a8f6def1 --- /dev/null +++ b/.licenserc.yaml @@ -0,0 +1,22 @@ +header: + license: + spdx-id: Apache-2.0 + content: | + SPDX-License-Identifier: Apache-2.0 + + paths-ignore: + - '.github' + - '.gitignore' + - '.gitmodules' + - '.clang-format' + - '.clang-tidy' + - '.licenserc.yaml' + - 'third_party/fmt' + - 'third_party/googletest' + - 'third_party/substrait' + - 'third_party/yaml-cpp' + - '**/*.md' + - '**/*.json' + - '**/*.log' + + comment: never \ No newline at end of file diff --git a/cmake_modules/BuildUtils.cmake b/cmake_modules/BuildUtils.cmake index 90e0b470..2a059fa5 100644 --- a/cmake_modules/BuildUtils.cmake +++ b/cmake_modules/BuildUtils.cmake @@ -41,5 +41,6 @@ function(ADD_TEST_CASE TEST_NAME) add_dependencies(${TEST_NAME} ${ARG_EXTRA_DEPENDENCIES}) endif() - add_test(NAME ${TEST_NAME} COMMAND $) + add_test(NAME ${TEST_NAME} COMMAND $ + WORKING_DIRECTORY "${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/tests") endfunction() diff --git a/include/substrait/common/Exceptions.h b/include/substrait/common/Exceptions.h 
index 23b97f1f..d05b92f4 100644 --- a/include/substrait/common/Exceptions.h +++ b/include/substrait/common/Exceptions.h @@ -2,9 +2,9 @@ #pragma once -#include #include #include +#include namespace io::substrait::common { namespace error_code { @@ -39,12 +39,14 @@ class SubstraitException : public std::exception { // objects. kUser = 0, - // Errors where the root cause of the problem is some unreliable aspect of the - // system are classified with SYSTEM. + // Errors where the root cause of the problem is some unreliable aspect of + // the system are classified with SYSTEM. kSystem = 1 }; - SubstraitException( + const char* file, + size_t line, + const char* function, const std::string& exceptionCode, const std::string& exceptionMessage, Type exceptionType = Type::kSystem, @@ -62,10 +64,16 @@ class SubstraitException : public std::exception { class SubstraitUserError : public SubstraitException { public: SubstraitUserError( + const char* file, + size_t line, + const char* function, const std::string& exceptionCode, const std::string& exceptionMessage, const std::string& exceptionName = "SubstraitUserError") : SubstraitException( + file, + line, + function, exceptionCode, exceptionMessage, Type::kUser, @@ -75,10 +83,16 @@ class SubstraitUserError : public SubstraitException { class SubstraitRuntimeError final : public SubstraitException { public: SubstraitRuntimeError( + const char* file, + size_t line, + const char* function, const std::string& exceptionCode, const std::string& exceptionMessage, const std::string& exceptionName = "SubstraitRuntimeError") : SubstraitException( + file, + line, + function, exceptionCode, exceptionMessage, Type::kSystem, @@ -90,10 +104,10 @@ std::string errorMessage(fmt::string_view fmt, const Args&... args) { return fmt::vformat(fmt, fmt::make_format_args(args...)); } -#define SUBSTRAIT_THROW(exception, errorCode, ...) 
\ - { \ - auto message = ::io::substrait::common::errorMessage(__VA_ARGS__); \ - throw exception(errorCode, message); \ +#define SUBSTRAIT_THROW(exception, errorCode, ...) \ + { \ + auto message = ::io::substrait::common::errorMessage(__VA_ARGS__); \ + throw exception(__FILE__, __LINE__, __FUNCTION__, errorCode, message); \ } #define SUBSTRAIT_UNSUPPORTED(...) \ diff --git a/scripts/run-clang-tidy.py b/scripts/run-clang-tidy.py new file mode 100644 index 00000000..58f0659b --- /dev/null +++ b/scripts/run-clang-tidy.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python +# SPDX-License-Identifier: Apache-2.0 + +# Run clang-tidy recursively and parallel on directory +# Usage: run-clang-tidy sourcedir builddir excludedirs extensions +# extensions and excludedirs are specified as comma-separated +# string without dot, e.g. 'c,cpp' +# e.g. run-clang-tidy . build test,other c,cpp file + +import os, sys, subprocess, multiprocessing +manager = multiprocessing.Manager() +failedfiles = manager.list() + +# Get absolute current path and remove trailing separators +currentdir = os.path.realpath(os.getcwd()).rstrip(os.sep) +print("Arguments: " + str(sys.argv)) +# Get absolute source dir by joining the input (minus leading and trailing separators) onto the current dir. +sourcedir = os.path.normpath(os.path.join(currentdir, sys.argv[1].strip(os.sep))) +print("Source directory: " + sourcedir) +builddir = sourcedir + os.sep + sys.argv[2].rstrip(os.sep) +print("Build directory: " + builddir) +# Split exclude dirs into a tuple +excludedirs = tuple([(sourcedir + os.sep + s).rstrip(os.sep) for s in sys.argv[3].split(',')]) +# If the build directory is not the same as the source directory, exclude it +if not sourcedir == builddir: + excludedirs = excludedirs + (builddir,) +print("Exclude directories: " + str(excludedirs)) +# Split extensions into a tuple +extensions = tuple([("." 
+ s) for s in sys.argv[4].split(',')]) +print("Extensions: " + str(extensions)) + +def runclangtidy(filepath): + print("Checking: " + filepath) + proc = subprocess.Popen("clang-tidy --quiet -p=" + builddir + " " + filepath, shell=True) + if proc.wait() != 0: + failedfiles.append(filepath) + +def collectfiles(dir, exclude, exts): + collectedfiles = [] + for root, dirs, files in os.walk(dir): + for file in files: + filepath = root + os.sep + file + if (len(exclude) == 0 or not filepath.startswith(exclude)) and filepath.endswith(exts): + collectedfiles.append(filepath) + return collectedfiles + +# Define the pool AFTER the global variables and subprocess function because multiprocessing +# has stricter requirements on member ordering +# See: https://stackoverflow.com/questions/41385708/multiprocessing-example-giving-attributeerror +pool = multiprocessing.Pool() +pool.map(runclangtidy, collectfiles(sourcedir, excludedirs, extensions)) +pool.close() +pool.join() +if len(failedfiles) > 0: + print("Errors in " + str(len(failedfiles)) + " files") + sys.exit(1) +print("No errors found") +sys.exit(0) \ No newline at end of file diff --git a/scripts/setup-ubuntu.sh b/scripts/setup-ubuntu.sh index 16e74f7c..c9b8a2c9 100755 --- a/scripts/setup-ubuntu.sh +++ b/scripts/setup-ubuntu.sh @@ -19,8 +19,11 @@ sudo --preserve-env apt install -y \ ccache \ ninja-build \ checkinstall \ + clang-tidy \ git \ wget \ libprotobuf-dev \ libprotobuf23 \ protobuf-compiler + +pip install cmake-format \ No newline at end of file diff --git a/src/substrait/CMakeLists.txt b/src/substrait/CMakeLists.txt index 2a2e145b..67f694f6 100644 --- a/src/substrait/CMakeLists.txt +++ b/src/substrait/CMakeLists.txt @@ -12,3 +12,4 @@ add_subdirectory(common) add_subdirectory(type) add_subdirectory(function) add_subdirectory(proto) +add_subdirectory(textplan) diff --git a/src/substrait/common/Exceptions.cpp b/src/substrait/common/Exceptions.cpp index 15537cf4..6c86f788 100644 --- a/src/substrait/common/Exceptions.cpp 
+++ b/src/substrait/common/Exceptions.cpp @@ -6,19 +6,22 @@ namespace io::substrait::common { SubstraitException::SubstraitException( + const char* file, + size_t line, + const char* function, const std::string& exceptionCode, const std::string& exceptionMessage, Type exceptionType, const std::string& exceptionName) : msg_(fmt::format( - "Exception: {}\nError Code: {}\nType: {}\nReason: {}\n" - "Function: {}\nFile: {}\n:Line: {}\n", + "Exception: {}\nError Code: {}\nError Type: {}\nReason: {}\n" + "Function: {}\nLocation: {}(Line:{})\n", exceptionName, exceptionCode, exceptionType == Type::kSystem ? "system" : "user", exceptionMessage, - __FUNCTION__, - __FILE__, - std::to_string(__LINE__))) {} + function, + file, + std::to_string(line))) {} } // namespace io::substrait::common diff --git a/src/substrait/proto/CMakeLists.txt b/src/substrait/proto/CMakeLists.txt index 458033d3..f3fdcac9 100644 --- a/src/substrait/proto/CMakeLists.txt +++ b/src/substrait/proto/CMakeLists.txt @@ -41,7 +41,8 @@ set(PROTO_SRCS) set(PROTO_OUTPUT_DIR "${CMAKE_CURRENT_BINARY_DIR}") cmake_path(GET PROTO_OUTPUT_DIR PARENT_PATH PROTO_OUTPUT_PARENT_DIR) -cmake_path(GET PROTO_OUTPUT_PARENT_DIR PARENT_PATH PROTO_OUTPUT_TOPLEVEL_DIR) +cmake_path(GET PROTO_OUTPUT_PARENT_DIR PARENT_PATH PROTO_OUTPUT_MIDLEVEL_DIR) +cmake_path(GET PROTO_OUTPUT_MIDLEVEL_DIR PARENT_PATH PROTO_OUTPUT_TOPLEVEL_DIR) foreach(PROTO_FILE IN LISTS PROTOBUF_FILELIST) file(RELATIVE_PATH RELATIVE_PROTO_PATH "${CMAKE_SOURCE_DIR}/third_party/substrait/proto/substrait" "${PROTO_FILE}") @@ -56,7 +57,7 @@ foreach(PROTO_FILE IN LISTS PROTOBUF_FILELIST) OUTPUT ${PROTO_SRC} ${PROTO_HDR} COMMAND protobuf::protoc "--proto_path=${GENERATED_PROTO_TOPLEVEL_DIR}" - "--cpp_out=${PROTO_OUTPUT_TOPLEVEL_DIR}" + "--cpp_out=${PROTO_OUTPUT_MIDLEVEL_DIR}" ${GENERATED_PROTO_FILE} DEPENDS ${GENERATED_PROTOBUF_LIST} protobuf::protoc COMMENT "Generated C++ protobuf module for ${PROTO_FILE}" @@ -66,7 +67,15 @@ foreach(PROTO_FILE IN LISTS 
PROTOBUF_FILELIST) endforeach() # Add the generated protobuf C++ files to our exported library. -add_library(substrait_proto ${PROTO_SRCS}) +add_library( + substrait_proto + ${PROTO_SRCS} ${PROTO_HDRS} + ProtoUtils.cpp ProtoUtils.h) + +# Include the protobuf library as a dependency to use this class. +target_link_libraries(substrait_proto ${PROTOBUF_LIBRARIES}) # Make sure we can see our own generated include files. -target_include_directories(substrait_proto PUBLIC "${PROTO_OUTPUT_TOPLEVEL_DIR}") +target_include_directories( + substrait_proto + PUBLIC "${PROTO_OUTPUT_TOPLEVEL_DIR}/src") diff --git a/src/substrait/proto/ProtoUtils.cpp b/src/substrait/proto/ProtoUtils.cpp new file mode 100644 index 00000000..0af55aa4 --- /dev/null +++ b/src/substrait/proto/ProtoUtils.cpp @@ -0,0 +1,47 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#include "substrait/proto/ProtoUtils.h" + +namespace substrait::proto { + +std::string PlanRelTypeCaseName(::substrait::proto::PlanRel::RelTypeCase num) { + static std::vector case_names = { + "unknown", + "rel", + "root", + }; + + if (num >= case_names.size()) { + return "unknown"; + } + + return case_names[num]; +} + +std::string RelTypeCaseName(::substrait::proto::Rel::RelTypeCase num) { + static std::vector case_names = { + "unknown", + "read", + "filter", + "fetch", + "aggregate", + "sort", + "join", + "project", + "set", + "extensionsingle", + "extensionmulti", + "extensionleaf", + "cross", + "hashjoin", + "mergejoin", + }; + + if (num >= case_names.size()) { + return "unknown"; + } + + return case_names[num]; +} + +} // namespace substrait::proto \ No newline at end of file diff --git a/src/substrait/proto/ProtoUtils.h b/src/substrait/proto/ProtoUtils.h new file mode 100644 index 00000000..bcae6ad4 --- /dev/null +++ b/src/substrait/proto/ProtoUtils.h @@ -0,0 +1,16 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#pragma once + +#include + +#include "substrait/proto/algebra.pb.h" +#include "substrait/proto/plan.pb.h" + +namespace 
substrait::proto { + +std::string PlanRelTypeCaseName(::substrait::proto::PlanRel::RelTypeCase num); + +std::string RelTypeCaseName(::substrait::proto::Rel::RelTypeCase num); + +} // namespace substrait::proto \ No newline at end of file diff --git a/src/substrait/proto/update_proto_package.pl b/src/substrait/proto/update_proto_package.pl index 259f035c..d9eb0bba 100755 --- a/src/substrait/proto/update_proto_package.pl +++ b/src/substrait/proto/update_proto_package.pl @@ -1,4 +1,5 @@ #!/bin/perl -w +# SPDX-License-Identifier: Apache-2.0 # Renames package declarations for protobuffers from substrait to substrait.proto. # This allows us to modify where the generated C++ have their definitions without diff --git a/src/substrait/textplan/Any.h b/src/substrait/textplan/Any.h new file mode 100644 index 00000000..eb84a778 --- /dev/null +++ b/src/substrait/textplan/Any.h @@ -0,0 +1,25 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#pragma once + +#include + +#include "fmt/format.h" + +namespace io::substrait::textplan { + +template +inline ValueType any_cast(const std::any& value, const char* file, int line) { + try { + return std::any_cast(value); + } catch (std::bad_any_cast& ex) { + throw std::invalid_argument( + fmt::format("{}:{} - {}", file, line, "bad any cast")); + } +} + +// A wrapper around std::any_cast that provides exceptions with line numbers. 
+#define ANY_CAST(ValueType, Value) \ + ::io::substrait::textplan::any_cast(Value, __FILE__, __LINE__) + +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/CMakeLists.txt b/src/substrait/textplan/CMakeLists.txt new file mode 100644 index 00000000..2018a1f3 --- /dev/null +++ b/src/substrait/textplan/CMakeLists.txt @@ -0,0 +1,37 @@ +# SPDX-License-Identifier: Apache-2.0 + +add_subdirectory(converter) + +add_library(symbol_table + Location.cpp Location.h + SymbolTable.cpp SymbolTable.h + SymbolTablePrinter.cpp SymbolTablePrinter.h + Any.h) + +add_library(error_listener + SubstraitErrorListener.cpp SubstraitErrorListener.h + ) + +add_library(parse_result + ParseResult.cpp ParseResult.h + ) + +add_dependencies(symbol_table + substrait_proto + substrait_common + fmt::fmt-header-only + ) + +target_link_libraries( + symbol_table + fmt::fmt-header-only + ) + +# Provide access to the generated protobuffer headers hierarchy. +target_include_directories( + symbol_table + PUBLIC "${CMAKE_CURRENT_BINARY_DIR}/../..") + +if (${SUBSTRAIT_CPP_BUILD_TESTING}) + add_subdirectory(tests) +endif () diff --git a/src/substrait/textplan/Location.cpp b/src/substrait/textplan/Location.cpp new file mode 100644 index 00000000..22d218f9 --- /dev/null +++ b/src/substrait/textplan/Location.cpp @@ -0,0 +1,68 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#include "substrait/textplan/Location.h" + +#include + +namespace io::substrait::textplan { + +constexpr const Location Location::kUnknownLocation( + static_cast(nullptr)); + +bool operator==(const Location& c1, const Location& c2) { + // Test only one side since we only store one kind of content per table. 
+ if (std::holds_alternative(c1.loc_)) { + if (!std::holds_alternative(c2.loc_)) { + return false; + } + auto a1 = std::get(c1.loc_); + auto a2 = std::get(c2.loc_); + return a1 == a2; + } else if (std::holds_alternative(c1.loc_)) { + if (!std::holds_alternative(c2.loc_)) { + return false; + } + auto a1 = std::get(c1.loc_); + auto a2 = std::get(c2.loc_); + return a1 == a2; + } + // Should not be reached. + return false; +} + +} // namespace io::substrait::textplan + +std::size_t std::hash<::io::substrait::textplan::Location>::operator()( + const ::io::substrait::textplan::Location& loc) const noexcept { + if (std::holds_alternative(loc.loc_)) { + return std::hash()( + std::get(loc.loc_)); + } else if (std::holds_alternative(loc.loc_)) { + return std::hash()( + std::get(loc.loc_)); + } + // Should not be reached. + return 0; +} + +bool std::less<::io::substrait::textplan::Location>::operator()( + const ::io::substrait::textplan::Location& lhs, + const ::io::substrait::textplan::Location& rhs) const noexcept { + if (std::holds_alternative(lhs.loc_)) { + if (!std::holds_alternative(rhs.loc_)) { + // This alternative is always less than the remaining choices. + return true; + } + return std::get(lhs.loc_) < + std::get(rhs.loc_); + } else if (std::holds_alternative(lhs.loc_)) { + if (!std::holds_alternative(rhs.loc_)) { + // rhs holds the earlier alternative, which always sorts first. + return false; + } + return std::get(lhs.loc_) < + std::get(rhs.loc_); + } + // Should not be reached. 
+ return false; +} diff --git a/src/substrait/textplan/Location.h b/src/substrait/textplan/Location.h new file mode 100644 index 00000000..d60ac455 --- /dev/null +++ b/src/substrait/textplan/Location.h @@ -0,0 +1,54 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#pragma once + +#include +#include +#include + +namespace antlr4 { +class ParserRuleContext; +} + +namespace google::protobuf { +class Message; +}; + +namespace io::substrait::textplan { + +// Location is used for keeping track of where a symbol is within a parse tree. +// Since SymbolTable supports both antlr4 and protobuf messages there are +// essentially two flavors of location. It is expected that only one type of +// location would be used in any SymbolTable instance. +class Location { + public: + constexpr explicit Location(antlr4::ParserRuleContext* node) : loc_(node) {} + + constexpr explicit Location(google::protobuf::Message* msg) : loc_(msg) {} + + static const Location kUnknownLocation; + + protected: + friend bool operator==(const Location& c1, const Location& c2); + + private: + friend std::hash; + friend std::less; + + std::variant loc_; +}; + +} // namespace io::substrait::textplan + +template <> +struct std::hash<::io::substrait::textplan::Location> { + std::size_t operator()( + const ::io::substrait::textplan::Location& loc) const noexcept; +}; + +template <> +struct std::less<::io::substrait::textplan::Location> { + bool operator()( + const ::io::substrait::textplan::Location& lhs, + const ::io::substrait::textplan::Location& rhs) const noexcept; +}; diff --git a/src/substrait/textplan/ParseResult.cpp b/src/substrait/textplan/ParseResult.cpp new file mode 100644 index 00000000..d539e103 --- /dev/null +++ b/src/substrait/textplan/ParseResult.cpp @@ -0,0 +1,33 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#include "substrait/textplan/ParseResult.h" + +#include +#include + +namespace io::substrait::textplan { + +std::ostream& operator<<(std::ostream& os, const ParseResult& result) { + 
if (result.successful()) { + os << *result.symbol_table_; + } + auto msgs = result.getSyntaxErrors(); + if (!msgs.empty()) { + os << "{" << std::endl; + for (const std::string& msg : msgs) { + os << " \"" << msg << "\"," << std::endl; + } + os << "}"; + } + msgs = result.getSemanticErrors(); + if (!msgs.empty()) { + os << "{" << std::endl; + for (const std::string& msg : msgs) { + os << " \"" << msg << "\"," << std::endl; + } + os << "}"; + } + return os; +} + +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/ParseResult.h b/src/substrait/textplan/ParseResult.h new file mode 100644 index 00000000..fb1d29be --- /dev/null +++ b/src/substrait/textplan/ParseResult.h @@ -0,0 +1,59 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#pragma once + +#include +#include + +#include "substrait/textplan/SymbolTable.h" + +namespace io::substrait::textplan { + +// ParseResult contains the result of a parse (from text to binary) or the +// conversion (from binary to text). The symbol table contains nearly +// all the information necessary to reconstruct either form. 
+class ParseResult { + public: + ParseResult( + SymbolTable symbolTable, + std::vector syntaxErrors, + std::vector semanticErrors) { + symbol_table_ = std::make_shared(std::move(symbolTable)); + syntax_errors_ = std::move(syntaxErrors); + semantic_errors_ = std::move(semanticErrors); + } + + [[nodiscard]] bool successful() const { + return syntax_errors_.empty() && semantic_errors_.empty(); + } + + [[nodiscard]] const SymbolTable& getSymbolTable() const { + return *symbol_table_; + } + + [[nodiscard]] const std::vector& getSyntaxErrors() const { + return syntax_errors_; + } + + [[nodiscard]] const std::vector& getSemanticErrors() const { + return semantic_errors_; + } + + [[nodiscard]] std::vector getAllErrors() const { + std::vector errors; + errors.insert(errors.end(), syntax_errors_.begin(), syntax_errors_.end()); + errors.insert( + errors.end(), semantic_errors_.begin(), semantic_errors_.end()); + return errors; + } + + // Add the capability for ::testing::PrintToString to print ParseResult. 
+ friend std::ostream& operator<<(std::ostream& os, const ParseResult& result); + + private: + std::shared_ptr symbol_table_; + std::vector syntax_errors_; + std::vector semantic_errors_; +}; + +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/SubstraitErrorListener.cpp b/src/substrait/textplan/SubstraitErrorListener.cpp new file mode 100644 index 00000000..e2a6203e --- /dev/null +++ b/src/substrait/textplan/SubstraitErrorListener.cpp @@ -0,0 +1,28 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#include "SubstraitErrorListener.h" + +#include +#include + +namespace io::substrait::textplan { + +void SubstraitErrorListener::addError( + size_t linenum, + size_t charnum, + const std::string& msg) { + errors_.push_back({{linenum, charnum}, msg}); +} + +std::vector SubstraitErrorListener::getErrorMessages() { + std::vector messages; + for (const auto& instance : getErrors()) { + messages.push_back( + std::to_string(instance.location.line) + ":" + + std::to_string(instance.location.char_position_in_line) + " → " + + instance.message); + } + return messages; +} + +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/SubstraitErrorListener.h b/src/substrait/textplan/SubstraitErrorListener.h new file mode 100644 index 00000000..e523bc6f --- /dev/null +++ b/src/substrait/textplan/SubstraitErrorListener.h @@ -0,0 +1,43 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#pragma once + +#include +#include + +namespace io::substrait::textplan { + +struct ErrorLocation { + size_t line; + size_t char_position_in_line; +}; + +struct ErrorInstance { + ErrorLocation location; + std::string message; +}; + +// SubstraitErrorListener is similar in behavior to an antlr4::ErrorListener to +// provide a similar error collection methodology regardless of how the input +// data is obtained. 
+class SubstraitErrorListener { + public: + SubstraitErrorListener() = default; + + void addError(size_t linenum, size_t charnum, const std::string& msg); + + const std::vector& getErrors() { + return errors_; + }; + + bool hasErrors() { + return !errors_.empty(); + } + + std::vector getErrorMessages(); + + private: + std::vector errors_; +}; + +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/SymbolTable.cpp b/src/substrait/textplan/SymbolTable.cpp new file mode 100644 index 00000000..446f9bea --- /dev/null +++ b/src/substrait/textplan/SymbolTable.cpp @@ -0,0 +1,114 @@ +/* SPDX-License-Identifier: Apache-2.0 */ +#include +#include +#include + +#include "SymbolTable.h" + +#include "substrait/textplan/Location.h" + +namespace io::substrait::textplan { + +bool SymbolTableIterator::operator!=(SymbolTableIterator const& other) const { + return index_ != other.index_; +} + +SymbolInfo const& SymbolTableIterator::operator*() const { + return *table_->symbols_[index_]; +} + +SymbolTableIterator SymbolTableIterator::operator++() { + ++index_; + return *this; +} + +bool operator==(const SymbolInfo& left, const SymbolInfo& right) { + return (left.name == right.name) && (left.location == right.location) && + (left.type == right.type); +} + +bool operator!=(const SymbolInfo& left, const SymbolInfo& right) { + return !(left == right); +} + +std::string SymbolTable::getUniqueName(const std::string& base_name) { + auto symbolSeenCount = names_.find(base_name); + if (symbolSeenCount == names_.end()) { + names_.insert(std::make_pair(base_name, 1)); + return base_name; + } + int32_t count = symbolSeenCount->second + 1; + symbolSeenCount->second = count; + return base_name + std::to_string(count); +} + +SymbolInfo* SymbolTable::defineSymbol( + const std::string& name, + const Location& location, + SymbolType type, + const std::any& subtype, + const std::any& blob) { + // TODO -- Detect attempts to reuse the same symbol. 
+ std::shared_ptr info = + std::make_shared(name, location, type, subtype, blob); + symbols_.push_back(std::move(info)); + symbols_by_name_.insert(std::make_pair(name, symbols_.size() - 1)); + symbols_by_location_.insert(std::make_pair(location, symbols_.size() - 1)); + + return symbols_.back().get(); +} + +SymbolInfo* SymbolTable::defineUniqueSymbol( + const std::string& name, + const Location& location, + SymbolType type, + const std::any& subtype, + const std::any& blob) { + std::string unique_name = getUniqueName(name); + return defineSymbol(unique_name, location, type, subtype, blob); +} + +const SymbolInfo& SymbolTable::lookupSymbolByName(const std::string& name) { + auto itr = symbols_by_name_.find(name); + if (itr == symbols_by_name_.end()) { + return kUnknownSymbol; + } + return *symbols_[itr->second]; +} + +const SymbolInfo& SymbolTable::lookupSymbolByLocation( + const Location& location) { + auto itr = symbols_by_location_.find(location); + if (itr == symbols_by_location_.end()) { + return kUnknownSymbol; + } + return *symbols_[itr->second]; +} + +const SymbolInfo& SymbolTable::nthSymbolByType(uint32_t n, SymbolType type) { + uint32_t count = 0; + for (const auto& symbol : symbols_) { + if (symbol->type == type) { + if (n == count++) + return *symbol; + } + } + return kUnknownSymbol; +} + +SymbolTableIterator SymbolTable::begin() const { + return {this, 0}; +} + +SymbolTableIterator SymbolTable::end() const { + return {this, symbols_.size()}; +} + +const SymbolInfo SymbolTable::kUnknownSymbol = { + "__UNKNOWN__", + Location::kUnknownLocation, + SymbolType::kUnknown, + std::nullopt, + std::nullopt}; + +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/SymbolTable.h b/src/substrait/textplan/SymbolTable.h new file mode 100644 index 00000000..33ed5a90 --- /dev/null +++ b/src/substrait/textplan/SymbolTable.h @@ -0,0 +1,176 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#pragma once + +#include +#include +#include +#include +#include 
+#include +#include + +#include "substrait/textplan/Location.h" + +namespace io::substrait::textplan { + +enum class SymbolType { + kExtensionSpace = 0, + kFunction = 1, + kPlanRelation = 2, + kRelation = 3, + kRelationDetail = 4, + kSchema = 5, + kSchemaColumn = 6, + kSource = 7, + kSourceDetail = 8, + kField = 9, + + kUnknown = -1, +}; + +enum class RelationType { + kUnknown = 0, + kRead = 1, + kProject = 2, + kJoin = 3, + kCross = 4, + kFetch = 5, + kAggregate = 6, + kSort = 7, + kFilter = 8, + kSet = 9, + kExchange = 10, + kDdl = 11, + kWrite = 12, + kHashJoin = 13, + kMergeJoin = 14, + kReference = 15, + + kExtensionLeaf = 100, + kExtensionSingle = 101, + kExtensionMulti = 102, +}; + +enum class RelationDetailType { + kUnknown = 0, + kExpression = 1, +}; + +enum class SourceType { + kUnknown = 0, + kLocalFiles = 1, + kNamedTable = 2, + kVirtualTable = 3, + kExtensionTable = 4, +}; + +struct SymbolInfo { + std::string name; + Location location; + SymbolType type; + std::any subtype; + std::any blob; + + SymbolInfo( + std::string new_name, + Location new_location, + SymbolType new_type, + std::any new_subtype, + std::any new_blob) + : name(std::move(new_name)), + location(new_location), + type(new_type), + subtype(std::move(new_subtype)), + blob(std::move(new_blob)){}; + + friend bool operator==(const SymbolInfo& left, const SymbolInfo& right); + friend bool operator!=(const SymbolInfo& left, const SymbolInfo& right); +}; + +class SymbolTable; + +class SymbolTableIterator { + public: + bool operator!=(SymbolTableIterator const& other) const; + + const SymbolInfo& operator*() const; + + SymbolTableIterator operator++(); + + private: + friend SymbolTable; + + SymbolTableIterator(const SymbolTable* table, size_t start) + : table_(table), index_(start){}; + + size_t index_; + const SymbolTable* table_; +}; + +class SymbolTable { + public: + // If the given symbol is not yet defined, returns that symbol. 
Otherwise + // it returns a modified version of the symbol (by adding a number) so that + // it is unique. + std::string getUniqueName(const std::string& base_name); + + // Adds a new symbol to the symbol table. + SymbolInfo* defineSymbol( + const std::string& name, + const Location& location, + SymbolType type, + const std::any& subtype, + const std::any& blob); + + // Convenience function that defines a symbol by first calling getUniqueName. + SymbolInfo* defineUniqueSymbol( + const std::string& name, + const Location& location, + SymbolType type, + const std::any& subtype, + const std::any& blob); + + const SymbolInfo& lookupSymbolByName(const std::string& name); + + const SymbolInfo& lookupSymbolByLocation(const Location& location); + + const SymbolInfo& nthSymbolByType(uint32_t n, SymbolType type); + + [[nodiscard]] SymbolTableIterator begin() const; + + [[nodiscard]] SymbolTableIterator end() const; + + [[nodiscard]] const std::vector>& getSymbols() + const { + return symbols_; + }; + + // Add the capability for ::testing::PrintToString to print this. 
+ friend std::ostream& operator<<(std::ostream& os, const SymbolTable& result) { + os << std::string("{"); + bool outputFirst = false; + for (const auto& symbol : result.getSymbols()) { + if (outputFirst) { + os << std::string(", "); + } + os << symbol->name; + outputFirst = true; + } + os << std::string("}"); + return os; + } + + static const SymbolInfo kUnknownSymbol; + + private: + friend SymbolTableIterator; + + std::unordered_map names_; + + std::vector> symbols_; + std::unordered_map symbols_by_name_; + std::unordered_map symbols_by_location_; +}; + +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/SymbolTablePrinter.cpp b/src/substrait/textplan/SymbolTablePrinter.cpp new file mode 100644 index 00000000..96d38e18 --- /dev/null +++ b/src/substrait/textplan/SymbolTablePrinter.cpp @@ -0,0 +1,356 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#include "SymbolTablePrinter.h" + +#include +#include + +#include "substrait/common/Exceptions.h" +#include "substrait/proto/algebra.pb.h" +#include "substrait/proto/extensions/extensions.pb.h" +#include "substrait/textplan/Any.h" +#include "substrait/textplan/SymbolTable.h" +#include "substrait/textplan/converter/PlanPrinterVisitor.h" + +namespace io::substrait::textplan { + +namespace { + +void localFileToText( + const ::substrait::proto::ReadRel::LocalFiles::FileOrFiles& item, + std::stringstream* text) { + switch (item.path_type_case()) { + case ::substrait::proto::ReadRel::LocalFiles::FileOrFiles::kUriPath: + *text << "uri_path: \"" << item.uri_path() << "\""; + break; + case ::substrait::proto::ReadRel::LocalFiles::FileOrFiles::kUriPathGlob: + *text << "uri_path_glob: \"" + item.uri_path_glob() << "\""; + break; + case ::substrait::proto::ReadRel::LocalFiles::FileOrFiles::kUriFile: + *text << "uri_file: \"" << item.uri_file() << "\""; + break; + case ::substrait::proto::ReadRel::LocalFiles::FileOrFiles::kUriFolder: + *text << "uri_folder: \"" << item.uri_folder() << "\""; + break; + case 
::substrait::proto::ReadRel::LocalFiles::FileOrFiles:: + PATH_TYPE_NOT_SET: + default: + SUBSTRAIT_UNSUPPORTED( + "Unknown path type " + std::to_string(item.path_type_case()) + + " provided."); + } + if (item.partition_index() != 0) { + *text << " partition_index: " << std::to_string(item.partition_index()); + } + if (item.start() != 0 || item.length() != 0) { + *text << " start: " << std::to_string(item.start()); + *text << " length: " << std::to_string(item.length()); + } + switch (item.file_format_case()) { + case ::substrait::proto::ReadRel::LocalFiles::FileOrFiles::kParquet: + *text << " parquet: {}"; + break; + case ::substrait::proto::ReadRel::LocalFiles::FileOrFiles::kArrow: + *text << " arrow: {}"; + break; + case ::substrait::proto::ReadRel::LocalFiles::FileOrFiles::kOrc: + *text << " orc: {}"; + break; + case ::substrait::proto::ReadRel::LocalFiles::FileOrFiles::kExtension: + SUBSTRAIT_UNSUPPORTED("Extension file format type not yet supported."); + case ::substrait::proto::ReadRel::LocalFiles::FileOrFiles::kDwrf: + *text << " dwrf: {}"; + break; + case ::substrait::proto::ReadRel::LocalFiles::FileOrFiles:: + FILE_FORMAT_NOT_SET: + default: + SUBSTRAIT_UNSUPPORTED( + "Unknown file format type " + + std::to_string(item.file_format_case()) + " provided."); + } +} + +std::string typeToText(const ::substrait::proto::Type& type) { + switch (type.kind_case()) { + case ::substrait::proto::Type::kBool: + if (type.bool_().nullability() == + ::substrait::proto::Type::NULLABILITY_NULLABLE) + return "bool?"; + return "bool"; + case ::substrait::proto::Type::kI8: + if (type.i8().nullability() == + ::substrait::proto::Type::NULLABILITY_NULLABLE) + return "i8?"; + return "i8"; + case ::substrait::proto::Type::kI16: + if (type.i16().nullability() == + ::substrait::proto::Type::NULLABILITY_NULLABLE) + return "i16?"; + return "i16"; + case ::substrait::proto::Type::kI32: + if (type.i32().nullability() == + ::substrait::proto::Type::NULLABILITY_NULLABLE) + return "i32?"; +
return "i32"; + case ::substrait::proto::Type::kI64: + if (type.i64().nullability() == + ::substrait::proto::Type::NULLABILITY_NULLABLE) + return "i64?"; + return "i64"; + case ::substrait::proto::Type::kFp32: + if (type.fp32().nullability() == + ::substrait::proto::Type::NULLABILITY_NULLABLE) + return "fp32?"; + return "fp32"; + case ::substrait::proto::Type::kFp64: + if (type.fp64().nullability() == + ::substrait::proto::Type::NULLABILITY_NULLABLE) + return "fp64?"; + return "fp64"; + case ::substrait::proto::Type::kString: + if (type.string().nullability() == + ::substrait::proto::Type::NULLABILITY_NULLABLE) + return "string?"; + return "string"; + case ::substrait::proto::Type::kDecimal: + if (type.decimal().nullability() == + ::substrait::proto::Type::NULLABILITY_NULLABLE) + return "decimal?"; + return "decimal"; + case ::substrait::proto::Type::kVarchar: + if (type.varchar().nullability() == + ::substrait::proto::Type::NULLABILITY_NULLABLE) + return "varchar?"; + return "varchar"; + case ::substrait::proto::Type::kFixedChar: + if (type.fixed_char().nullability() == + ::substrait::proto::Type::NULLABILITY_NULLABLE) + return "fixedchar?"; + return "fixedchar"; + case ::substrait::proto::Type::kDate: + if (type.date().nullability() == + ::substrait::proto::Type::NULLABILITY_NULLABLE) + return "date?"; + return "date"; + case ::substrait::proto::Type::KIND_NOT_SET: + default: + return "UNSUPPORTED_TYPE"; + } +} + +std::string relationToText( + const SymbolTable& symbolTable, + const SymbolInfo& info) { + auto relation = ANY_CAST(const ::substrait::proto::Rel*, info.blob); + + PlanPrinterVisitor printer(symbolTable); + return printer.printRelation(info.name, relation); +} + +std::string outputPipelinesSection(const SymbolTable& symbolTable) { + // TODO: Implement.
+ return ""; +} + +std::string outputRelationsSection(const SymbolTable& symbolTable) { + std::stringstream text; + bool hasPreviousText = false; + for (const SymbolInfo& info : symbolTable) { + if (info.type != SymbolType::kRelation) + continue; + // TODO: Put handling this into the PlanPrinterVisitor. + if (hasPreviousText) + text << "\n"; + text << relationToText(symbolTable, info); + hasPreviousText = true; + } + return text.str(); +} + +std::string outputSchemaSection(const SymbolTable& symbolTable) { + std::stringstream text; + bool hasPreviousText = false; + for (const SymbolInfo& info : symbolTable) { + if (info.type != SymbolType::kSchema) + continue; + + if (hasPreviousText) + text << "\n"; + + const auto* schema = + ANY_CAST(const ::substrait::proto::NamedStruct*, info.blob); + text << "schema " << info.name << " {\n"; + int idx = 0; + while (idx < schema->names_size() && + idx < schema->struct_().types_size()) { + // TODO -- Handle potential whitespace in the names here or elsewhere. + text << " " << schema->names(idx); + text << " " << typeToText(schema->struct_().types(idx)); + text << ";\n"; + ++idx; + } + text << "}\n"; + hasPreviousText = true; + } + return text.str(); +} + +std::string outputSourceSection(const SymbolTable& symbolTable) { + std::stringstream text; + bool hasPreviousText = false; + for (const SymbolInfo& info : symbolTable) { + if (info.type != SymbolType::kSource) + continue; + + if (hasPreviousText) + text << "\n"; + auto subtype = ANY_CAST(SourceType, info.subtype); + switch (subtype) { + case SourceType::kNamedTable: { + auto table = + ANY_CAST(const ::substrait::proto::ReadRel_NamedTable*, info.blob); + text << "source named_table " << info.name << " {\n"; + text << " names = [\n"; + for (const auto& name : table->names()) { + text << " \"" << name << "\",\n"; + } + text << " ]\n"; + text << "}\n"; + hasPreviousText = true; + break; + } + case SourceType::kLocalFiles: { + // TODO: Put handling this into the PlanPrinterVisitor. 
+ auto files = + ANY_CAST(const ::substrait::proto::ReadRel_LocalFiles*, info.blob); + text << "source local_files " << info.name << " {\n"; + text << " items = [\n"; + for (const auto& item : files->items()) { + text << " {"; + localFileToText(item, &text); + text << "}\n"; + } + text << " ]\n"; + text << "}\n"; + hasPreviousText = true; + break; + } + case SourceType::kVirtualTable: + SUBSTRAIT_FAIL("Printing of virtual tables not yet implemented."); + case SourceType::kExtensionTable: + SUBSTRAIT_FAIL("Printing of extension tables not yet implemented."); + case SourceType::kUnknown: + default: + SUBSTRAIT_FAIL("Printing of an unknown read source requested."); + } + } + return text.str(); +} + +std::string outputFunctionsSection(const SymbolTable& symbolTable) { + std::stringstream text; + + std::map space_names; + std::set used_spaces; + + // Record the name of every declared extension space, keyed by its anchor. + for (const SymbolInfo& info : symbolTable) { + if (info.type != SymbolType::kExtensionSpace) + continue; + + auto anchor = ANY_CAST(uint32_t, info.blob); + + space_names.insert(std::make_pair(anchor, info.name)); + } + + // Collect every space referenced by a function, whether declared or not. + for (const SymbolInfo& info : symbolTable) { + if (info.type != SymbolType::kFunction) + continue; + + auto extension = ANY_CAST( + const ::substrait::proto::extensions:: + SimpleExtensionDeclaration_ExtensionFunction*, + info.blob); + used_spaces.insert(extension->extension_uri_reference()); + } + + // Emit functions grouped by space, in used_spaces (std::set) iteration order. + for (const uint32_t space : used_spaces) { + if (space_names.find(space) == space_names.end()) { + // TODO: Handle this case as a warning.
+ text << "extension_space {\n"; + } else { + text << "extension_space " << space_names[space] << " {\n"; + } + + for (const SymbolInfo& info : symbolTable) { + if (info.type != SymbolType::kFunction) + continue; + + auto extension = ANY_CAST( + const ::substrait::proto::extensions:: + SimpleExtensionDeclaration_ExtensionFunction*, + info.blob); + if (extension->extension_uri_reference() != space) + continue; + + text << " function " << extension->name() << " as " << info.name + << ";\n"; + } + text << "}\n"; + } + + return text.str(); +} + +} // namespace + +std::string SymbolTablePrinter::outputToText(const SymbolTable& symbolTable) { + std::stringstream text; + bool hasPreviousText = false; + + std::string newText = outputPipelinesSection(symbolTable); + if (!newText.empty()) { + text << newText; + hasPreviousText = true; + } + + newText = outputRelationsSection(symbolTable); + if (!newText.empty()) { + if (hasPreviousText) { + text << "\n"; + } + text << newText; + hasPreviousText = true; + } + + newText = outputSchemaSection(symbolTable); + if (!newText.empty()) { + if (hasPreviousText) { + text << "\n"; + } + text << newText; + hasPreviousText = true; + } + + newText = outputSourceSection(symbolTable); + if (!newText.empty()) { + if (hasPreviousText) { + text << "\n"; + } + text << newText; + hasPreviousText = true; + } + + newText = outputFunctionsSection(symbolTable); + if (!newText.empty()) { + if (hasPreviousText) { + text << "\n"; + } + text << newText; + } + return text.str(); +} + +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/SymbolTablePrinter.h b/src/substrait/textplan/SymbolTablePrinter.h new file mode 100644 index 00000000..34d3d22b --- /dev/null +++ b/src/substrait/textplan/SymbolTablePrinter.h @@ -0,0 +1,16 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#pragma once + +#include + +#include "SymbolTable.h" + +namespace io::substrait::textplan { + +class SymbolTablePrinter { + public: + static std::string 
outputToText(const SymbolTable& symbolTable); +}; + +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/converter/BasePlanProtoVisitor.cpp b/src/substrait/textplan/converter/BasePlanProtoVisitor.cpp new file mode 100644 index 00000000..a00056f0 --- /dev/null +++ b/src/substrait/textplan/converter/BasePlanProtoVisitor.cpp @@ -0,0 +1,1084 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#include "substrait/textplan/converter/BasePlanProtoVisitor.h" + +#include +#include +#include + +#include "substrait/common/Exceptions.h" +#include "substrait/proto/algebra.pb.h" +#include "substrait/proto/plan.pb.h" + +namespace io::substrait::textplan { + +std::any BasePlanProtoVisitor::visitSubqueryScalar( + const ::substrait::proto::Expression_Subquery_Scalar& query) { + if (query.has_input()) { + visitRelation(query.input()); + } + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitSubqueryInPredicate( + const ::substrait::proto::Expression_Subquery_InPredicate& query) { + if (query.has_haystack()) { + visitRelation(query.haystack()); + } + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitSubquerySetPredicate( + const ::substrait::proto::Expression_Subquery_SetPredicate& query) { + if (query.has_tuples()) { + visitRelation(query.tuples()); + } + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitSubquerySetComparison( + const ::substrait::proto::Expression_Subquery_SetComparison& query) { + if (query.has_left()) { + visitExpression(query.left()); + } + if (query.has_right()) { + visitRelation(query.right()); + } + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitNestedStruct( + const ::substrait::proto::Expression_Nested_Struct& nested) { + for (const auto& field : nested.fields()) { + visitExpression(field); + } + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitNestedList( + const ::substrait::proto::Expression_Nested_List& nested) { + for (const auto& value : nested.values()) { + 
visitExpression(value); + } + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitNestedMap( + const ::substrait::proto::Expression_Nested_Map& nested) { + for (const auto& kv : nested.key_values()) { + visitNestedKeyValue(kv); + } + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitNestedKeyValue( + const ::substrait::proto::Expression_Nested_Map_KeyValue& kv) { + if (kv.has_key()) { + visitExpression(kv.key()); + } + if (kv.has_value()) { + visitExpression(kv.value()); + } + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitStructItem( + const ::substrait::proto::Expression_MaskExpression_StructItem& item) { + if (item.has_child()) { + visitSelect(item.child()); + } + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitReferenceSegmentMapKey( + const ::substrait::proto::Expression_ReferenceSegment_MapKey& mapkey) { + if (mapkey.has_map_key()) { + visitLiteral(mapkey.map_key()); + } + if (mapkey.has_child()) { + visitReferenceSegment(mapkey.child()); + } + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitReferenceSegmentStructField( + const ::substrait::proto::Expression_ReferenceSegment_StructField& + structure) { + if (structure.has_child()) { + visitReferenceSegment(structure.child()); + } + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitReferenceSegmentListElement( + const ::substrait::proto::Expression_ReferenceSegment_ListElement& + element) { + if (element.has_child()) { + visitReferenceSegment(element.child()); + } + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitSelect( + const ::substrait::proto::Expression_MaskExpression_Select& select) { + switch (select.type_case()) { + case ::substrait::proto::Expression_MaskExpression_Select::kStruct: + return visitStructSelect(select.struct_()); + case ::substrait::proto::Expression_MaskExpression_Select::kList: + return visitListSelect(select.list()); + case ::substrait::proto::Expression_MaskExpression_Select::kMap: + 
return visitMapSelect(select.map()); + case ::substrait::proto::Expression_MaskExpression_Select::TYPE_NOT_SET: + break; + } + SUBSTRAIT_UNSUPPORTED( + "Unsupported maskexpression select type encountered: " + + std::to_string(select.type_case())); +} + +std::any BasePlanProtoVisitor::visitType(const ::substrait::proto::Type& type) { + switch (type.kind_case()) { + case ::substrait::proto::Type::kBool: + case ::substrait::proto::Type::kI8: + case ::substrait::proto::Type::kI16: + case ::substrait::proto::Type::kI32: + case ::substrait::proto::Type::kI64: + case ::substrait::proto::Type::kFp32: + case ::substrait::proto::Type::kFp64: + case ::substrait::proto::Type::kString: + case ::substrait::proto::Type::kBinary: + case ::substrait::proto::Type::kTimestamp: + case ::substrait::proto::Type::kDate: + case ::substrait::proto::Type::kTime: + case ::substrait::proto::Type::kIntervalYear: + case ::substrait::proto::Type::kIntervalDay: + case ::substrait::proto::Type::kTimestampTz: + case ::substrait::proto::Type::kUuid: + case ::substrait::proto::Type::kFixedChar: + case ::substrait::proto::Type::kVarchar: + case ::substrait::proto::Type::kFixedBinary: + case ::substrait::proto::Type::kDecimal: + return std::nullopt; + case ::substrait::proto::Type::kStruct: + return visitStruct(type.struct_()); + case ::substrait::proto::Type::kList: + return visitTypeList(type.list()); + case ::substrait::proto::Type::kMap: + return visitTypeMap(type.map()); + case ::substrait::proto::Type::kUserDefined: + return visitTypeUserDefined(type.user_defined()); + case ::substrait::proto::Type::kUserDefinedTypeReference: + SUBSTRAIT_UNSUPPORTED( + "user_defined_type_reference was replaced by user_defined_type. 
Please update your plan version."); + case ::substrait::proto::Type::KIND_NOT_SET: + break; + } + SUBSTRAIT_UNSUPPORTED( + "Unsupported type kind encountered: " + std::to_string(type.kind_case())); +} + +std::any BasePlanProtoVisitor::visitTypeUserDefined( + const ::substrait::proto::Type_UserDefined& type) { + for (const auto& parameter : type.type_parameters()) { + visitTypeParameter(parameter); + } + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitTypeParameter( + const ::substrait::proto::Type_Parameter& type) { + switch (type.parameter_case()) { + case ::substrait::proto::Type_Parameter::kNull: + return std::nullopt; + case ::substrait::proto::Type_Parameter::kDataType: + return visitType(type.data_type()); + case ::substrait::proto::Type_Parameter::kBoolean: + case ::substrait::proto::Type_Parameter::kInteger: + case ::substrait::proto::Type_Parameter::kEnum: + case ::substrait::proto::Type_Parameter::kString: + return std::nullopt; + case ::substrait::proto::Type_Parameter::PARAMETER_NOT_SET: + break; + } + SUBSTRAIT_UNSUPPORTED( + "Unsupported type parameter encountered: " + + std::to_string(type.parameter_case())); +} + +std::any BasePlanProtoVisitor::visitMap( + const ::substrait::proto::Expression_Literal_Map& map) { + for (const auto& kv : map.key_values()) { + visitMapKeyValue(kv); + } + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitMapKeyValue( + const ::substrait::proto::Expression_Literal_Map_KeyValue& kv) { + if (kv.has_key()) { + visitLiteral(kv.key()); + } + if (kv.has_value()) { + visitLiteral(kv.value()); + } + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitList( + const ::substrait::proto::Expression_Literal_List& list) { + for (const auto& value : list.values()) { + visitLiteral(value); + } + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitTypeList( + const ::substrait::proto::Type_List& list) { + if (list.has_type()) { + visitType(list.type()); + } + return std::nullopt; +} + 
+std::any BasePlanProtoVisitor::visitTypeMap( + const ::substrait::proto::Type_Map& list) { + if (list.has_key()) { + visitType(list.key()); + } + if (list.has_value()) { + visitType(list.value()); + } + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitUserDefined( + const ::substrait::proto::Expression_Literal_UserDefined& type) { + for (const auto& parameter : type.type_parameters()) { + visitTypeParameter(parameter); + } + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitFunctionArgument( + const ::substrait::proto::FunctionArgument& argument) { + switch (argument.arg_type_case()) { + case ::substrait::proto::FunctionArgument::kEnum: + return std::nullopt; + case ::substrait::proto::FunctionArgument::kType: + return visitType(argument.type()); + case ::substrait::proto::FunctionArgument::kValue: + return visitExpression(argument.value()); + case ::substrait::proto::FunctionArgument::ARG_TYPE_NOT_SET: + break; + } + SUBSTRAIT_UNSUPPORTED( + "Unsupported function argument type encountered: " + + std::to_string(argument.arg_type_case())); +} + +std::any BasePlanProtoVisitor::visitFunctionOption( + const ::substrait::proto::FunctionOption& option) { + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitRecord( + const ::substrait::proto::Expression_MultiOrList_Record& record) { + for (const auto& field : record.fields()) { + visitExpression(field); + } + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitIfClause( + const ::substrait::proto::Expression_IfThen_IfClause& ifclause) { + if (ifclause.has_if_()) { + visitExpression(ifclause.if_()); + } + if (ifclause.has_then()) { + visitExpression(ifclause.then()); + } + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitIfValue( + const ::substrait::proto::Expression_SwitchExpression_IfValue& ifclause) { + if (ifclause.has_if_()) { + visitLiteral(ifclause.if_()); + } + if (ifclause.has_then()) { + visitExpression(ifclause.then()); + } + return std::nullopt; 
+} + +std::any BasePlanProtoVisitor::visitStruct( + const ::substrait::proto::Type_Struct& structure) { + for (const auto& t : structure.types()) { + visitType(t); + } + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitLiteral( + const ::substrait::proto::Expression::Literal& literal) { + switch (literal.literal_type_case()) { + case ::substrait::proto::Expression_Literal::kBoolean: + case ::substrait::proto::Expression_Literal::kI8: + case ::substrait::proto::Expression_Literal::kI16: + case ::substrait::proto::Expression_Literal::kI32: + case ::substrait::proto::Expression_Literal::kI64: + case ::substrait::proto::Expression_Literal::kFp32: + case ::substrait::proto::Expression_Literal::kFp64: + case ::substrait::proto::Expression_Literal::kString: + case ::substrait::proto::Expression_Literal::kBinary: + case ::substrait::proto::Expression_Literal::kTimestamp: + case ::substrait::proto::Expression_Literal::kDate: + case ::substrait::proto::Expression_Literal::kTime: + case ::substrait::proto::Expression_Literal::kIntervalYearToMonth: + case ::substrait::proto::Expression_Literal::kIntervalDayToSecond: + case ::substrait::proto::Expression_Literal::kFixedChar: + case ::substrait::proto::Expression_Literal::kVarChar: + case ::substrait::proto::Expression_Literal::kFixedBinary: + case ::substrait::proto::Expression_Literal::kDecimal: + return std::nullopt; + case ::substrait::proto::Expression_Literal::kStruct: + return visitExpressionLiteralStruct(literal.struct_()); + case ::substrait::proto::Expression_Literal::kMap: + return visitMap(literal.map()); + case ::substrait::proto::Expression_Literal::kTimestampTz: + case ::substrait::proto::Expression_Literal::kUuid: + return std::nullopt; + case ::substrait::proto::Expression_Literal::kNull: + return visitType(literal.null()); + case ::substrait::proto::Expression_Literal::kList: + return visitList(literal.list()); + case ::substrait::proto::Expression_Literal::kEmptyList: + return 
visitTypeList(literal.empty_list()); + case ::substrait::proto::Expression_Literal::kEmptyMap: + return visitTypeMap(literal.empty_map()); + case ::substrait::proto::Expression_Literal::kUserDefined: + return visitUserDefined(literal.user_defined()); + case ::substrait::proto::Expression_Literal::LITERAL_TYPE_NOT_SET: + break; + } + SUBSTRAIT_UNSUPPORTED( + "Unsupported literal type encountered: " + + std::to_string(literal.literal_type_case())); +} + +std::any BasePlanProtoVisitor::visitScalarFunction( + const ::substrait::proto::Expression::ScalarFunction& function) { + for (const auto& arg : function.arguments()) { + visitFunctionArgument(arg); + } + for (const auto& arg : function.options()) { + visitFunctionOption(arg); + } + if (function.has_output_type()) { + visitType(function.output_type()); + } + for (const auto& arg : function.args()) { + visitExpression(arg); + } + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitWindowFunction( + const ::substrait::proto::Expression::WindowFunction& function) { + for (const auto& arg : function.arguments()) { + visitFunctionArgument(arg); + } + for (const auto& arg : function.options()) { + visitFunctionOption(arg); + } + if (function.has_output_type()) { + visitType(function.output_type()); + } + for (const auto& sort : function.sorts()) { + visitSortField(sort); + } + for (const auto& partition : function.partitions()) { + visitExpression(partition); + } + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitIfThen( + const ::substrait::proto::Expression::IfThen& ifthen) { + for (const auto& if_ : ifthen.ifs()) { + visitIfClause(if_); + } + if (ifthen.has_else_()) { + visitExpression(ifthen.else_()); + } + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitSwitchExpression( + const ::substrait::proto::Expression::SwitchExpression& expression) { + if (expression.has_match()) { + visitExpression(expression.match()); + } + for (const auto& if_ : expression.ifs()) { + 
visitIfValue(if_); + } + if (expression.has_else_()) { + visitExpression(expression.else_()); + } + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitSingularOrList( + const ::substrait::proto::Expression::SingularOrList& expression) { + if (expression.has_value()) { + visitExpression(expression.value()); + } + for (const auto& option : expression.options()) { + visitExpression(option); + } + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitMultiOrList( + const ::substrait::proto::Expression::MultiOrList& expression) { + for (const auto& value : expression.value()) { + visitExpression(value); + } + for (const auto& option : expression.options()) { + visitRecord(option); + } + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitCast( + const ::substrait::proto::Expression::Cast& cast) { + if (cast.has_input()) { + visitExpression(cast.input()); + } + if (cast.has_type()) { + visitType(cast.type()); + } + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitSubquery( + const ::substrait::proto::Expression_Subquery& query) { + switch (query.subquery_type_case()) { + case ::substrait::proto::Expression_Subquery::kScalar: + return visitSubqueryScalar(query.scalar()); + case ::substrait::proto::Expression_Subquery::kInPredicate: + return visitSubqueryInPredicate(query.in_predicate()); + case ::substrait::proto::Expression_Subquery::kSetPredicate: + return visitSubquerySetPredicate(query.set_predicate()); + case ::substrait::proto::Expression_Subquery::kSetComparison: + return visitSubquerySetComparison(query.set_comparison()); + case ::substrait::proto::Expression_Subquery::SUBQUERY_TYPE_NOT_SET: + break; + } + SUBSTRAIT_UNSUPPORTED( + "Unsupported subquery type encountered: " + + std::to_string(query.subquery_type_case())); +} + +std::any BasePlanProtoVisitor::visitNested( + const ::substrait::proto::Expression_Nested& structure) { + switch (structure.nested_type_case()) { + case 
::substrait::proto::Expression_Nested::kStruct: + return visitNestedStruct(structure.struct_()); + case ::substrait::proto::Expression_Nested::kList: + return visitNestedList(structure.list()); + case ::substrait::proto::Expression_Nested::kMap: + return visitNestedMap(structure.map()); + case ::substrait::proto::Expression_Nested::NESTED_TYPE_NOT_SET: + break; + } + SUBSTRAIT_UNSUPPORTED( + "Unsupported nested type encountered: " + + std::to_string(structure.nested_type_case())); +} + +std::any BasePlanProtoVisitor::visitEnum( + const ::substrait::proto::Expression_Enum& value) { + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitStructSelect( + const ::substrait::proto::Expression_MaskExpression_StructSelect& + structure) { + for (const auto& item : structure.struct_items()) { + visitStructItem(item); + } + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitListSelect( + const ::substrait::proto::Expression_MaskExpression_ListSelect& select) { + for (const auto& item : select.selection()) { + visitListSelectItem(item); + } + if (select.has_child()) { + visitSelect(select.child()); + } + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitListSelectItem( + const ::substrait::proto:: + Expression_MaskExpression_ListSelect_ListSelectItem& item) { + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitMapSelect( + const ::substrait::proto::Expression_MaskExpression_MapSelect& select) { + if (select.has_child()) { + visitSelect(select.child()); + } + switch (select.select_case()) { + case ::substrait::proto::Expression_MaskExpression_MapSelect::kKey: + return std::nullopt; + case ::substrait::proto::Expression_MaskExpression_MapSelect::kExpression: + return std::nullopt; + case ::substrait::proto::Expression_MaskExpression_MapSelect:: + SELECT_NOT_SET: + break; + } + SUBSTRAIT_UNSUPPORTED( + "Unsupported map select type encountered: " + + std::to_string(select.select_case())); +} + +std::any 
BasePlanProtoVisitor::visitExpressionLiteralStruct( + const ::substrait::proto::Expression_Literal_Struct& structure) { + for (const auto& literal : structure.fields()) { + visitLiteral(literal); + } + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitFileOrFiles( + const ::substrait::proto::ReadRel_LocalFiles_FileOrFiles& structure) { + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitAggregateFunction( + const ::substrait::proto::AggregateFunction& structure) { + for (const auto& arg : structure.arguments()) { + visitFunctionArgument(arg); + } + for (const auto& option : structure.options()) { + visitFunctionOption(option); + } + if (structure.has_output_type()) { + visitType(structure.output_type()); + } + for (const auto& sort : structure.sorts()) { + visitSortField(sort); + } + for (const auto& arg : structure.args()) { + visitExpression(arg); + } + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitReferenceSegment( + const ::substrait::proto::Expression_ReferenceSegment& segment) { + switch (segment.reference_type_case()) { + case ::substrait::proto::Expression_ReferenceSegment::kMapKey: + return visitReferenceSegmentMapKey(segment.map_key()); + case ::substrait::proto::Expression_ReferenceSegment::kStructField: + return visitReferenceSegmentStructField(segment.struct_field()); + case ::substrait::proto::Expression_ReferenceSegment::kListElement: + return visitReferenceSegmentListElement(segment.list_element()); + case ::substrait::proto::Expression_ReferenceSegment:: + REFERENCE_TYPE_NOT_SET: + break; + } + SUBSTRAIT_UNSUPPORTED( + "Unsupported reference segment type encountered: " + + std::to_string(segment.reference_type_case())); +} + +std::any BasePlanProtoVisitor::visitRelationCommon( + const ::substrait::proto::RelCommon& common) { + if (common.has_advanced_extension()) { + visitAdvancedExtension(common.advanced_extension()); + }; + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitNamedStruct( + 
const ::substrait::proto::NamedStruct& named) { + return visitStruct(named.struct_()); +} + +/* Dispatches on rex_type_case() to the matching visit* hook and returns that hook's result. REX_TYPE_NOT_SET breaks out of the switch and reaches SUBSTRAIT_UNSUPPORTED. */ std::any BasePlanProtoVisitor::visitExpression( + const ::substrait::proto::Expression& expression) { + switch (expression.rex_type_case()) { + case ::substrait::proto::Expression::RexTypeCase::kLiteral: + return visitLiteral(expression.literal()); + case ::substrait::proto::Expression::RexTypeCase::kSelection: + return visitFieldReference(expression.selection()); + case ::substrait::proto::Expression::RexTypeCase::kScalarFunction: + return visitScalarFunction(expression.scalar_function()); + case ::substrait::proto::Expression::RexTypeCase::kWindowFunction: + return visitWindowFunction(expression.window_function()); + case ::substrait::proto::Expression::RexTypeCase::kIfThen: + return visitIfThen(expression.if_then()); + case ::substrait::proto::Expression::RexTypeCase::kSwitchExpression: + return visitSwitchExpression(expression.switch_expression()); + case ::substrait::proto::Expression::RexTypeCase::kSingularOrList: + return visitSingularOrList(expression.singular_or_list()); + case ::substrait::proto::Expression::RexTypeCase::kMultiOrList: + return visitMultiOrList(expression.multi_or_list()); + case ::substrait::proto::Expression::RexTypeCase::kCast: + return visitCast(expression.cast()); + case ::substrait::proto::Expression::RexTypeCase::kSubquery: + return visitSubquery(expression.subquery()); + case ::substrait::proto::Expression::RexTypeCase::kNested: + return visitNested(expression.nested()); + case ::substrait::proto::Expression::RexTypeCase::kEnum: + return visitEnum(expression.enum_()); + case ::substrait::proto::Expression::RexTypeCase::REX_TYPE_NOT_SET: + break; + } + SUBSTRAIT_UNSUPPORTED( + "Unsupported expression type encountered: " + + std::to_string(expression.rex_type_case())); +} + +/* Visits the struct select of a mask expression when one is set. */ std::any BasePlanProtoVisitor::visitMaskExpression( + const ::substrait::proto::Expression::MaskExpression& expression) { + if (expression.has_select()) { + 
visitStructSelect(expression.select()); + } + return std::nullopt; +} + +/* Visits each literal struct in values(). */ std::any BasePlanProtoVisitor::visitVirtualTable( + const ::substrait::proto::ReadRel_VirtualTable& table) { + for (const auto& value : table.values()) { + visitExpressionLiteralStruct(value); + } + return std::nullopt; +} + +/* Visits each entry of items(), then the advanced extension when set. */ std::any BasePlanProtoVisitor::visitLocalFiles( + const ::substrait::proto::ReadRel_LocalFiles& local) { + for (const auto& item : local.items()) { + visitFileOrFiles(item); + } + if (local.has_advanced_extension()) { + visitAdvancedExtension(local.advanced_extension()); + }; + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitNamedTable( + const ::substrait::proto::ReadRel_NamedTable& table) { + if (table.has_advanced_extension()) { + visitAdvancedExtension(table.advanced_extension()); + }; + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitExtensionTable( + const ::substrait::proto::ReadRel_ExtensionTable& table) { + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitGrouping( + const ::substrait::proto::AggregateRel_Grouping& grouping) { + for (const auto& expr : grouping.grouping_expressions()) { + visitExpression(expr); + } + return std::nullopt; +} + +/* Visits the aggregate function and the filter expression of a measure, each only when set. */ std::any BasePlanProtoVisitor::visitMeasure( + const ::substrait::proto::AggregateRel_Measure& measure) { + if (measure.has_measure()) { + visitAggregateFunction(measure.measure()); + } + if (measure.has_filter()) { + visitExpression(measure.filter()); + } + return std::nullopt; +} + +/* Unlike most traversal hooks in this file, this one returns the result of visitExpression when an expr is present; otherwise it returns std::nullopt. */ std::any BasePlanProtoVisitor::visitSortField( + const ::substrait::proto::SortField& sort) { + if (sort.has_expr()) { + return visitExpression(sort.expr()); + } + return std::nullopt; +} + +/* Visits whichever of direct_reference, masked_reference and expression are set on the field reference. */ std::any BasePlanProtoVisitor::visitFieldReference( + const ::substrait::proto::Expression::FieldReference& ref) { + if (ref.has_direct_reference()) { + visitReferenceSegment(ref.direct_reference()); + } + if (ref.has_masked_reference()) { + visitMaskExpression(ref.masked_reference()); + } + if 
(ref.has_expression()) { + visitExpression(ref.expression()); + } + return std::nullopt; +} + +/* Visits the optional parts of a ReadRel (common, base schema, filter, best-effort filter, projection, advanced extension), each only when set, then dispatches on read_type_case() and returns the selected source hook's result. An unset read type falls through to SUBSTRAIT_UNSUPPORTED. */ std::any BasePlanProtoVisitor::visitReadRelation( + const ::substrait::proto::ReadRel& relation) { + if (relation.has_common()) { + visitRelationCommon(relation.common()); + } + if (relation.has_base_schema()) { + visitNamedStruct(relation.base_schema()); + } + if (relation.has_filter()) { + visitExpression(relation.filter()); + } + if (relation.has_best_effort_filter()) { + visitExpression(relation.best_effort_filter()); + } + if (relation.has_projection()) { + visitMaskExpression(relation.projection()); + } + /* Only visit the extension when it is actually set, matching the has_advanced_extension() guard used by the other relation visitors; otherwise overrides would be handed a default-constructed message. */ + if (relation.has_advanced_extension()) { + visitAdvancedExtension(relation.advanced_extension()); + } + switch (relation.read_type_case()) { + case ::substrait::proto::ReadRel::ReadTypeCase::kVirtualTable: + return visitVirtualTable(relation.virtual_table()); + case ::substrait::proto::ReadRel::ReadTypeCase::kLocalFiles: + return visitLocalFiles(relation.local_files()); + case ::substrait::proto::ReadRel::ReadTypeCase::kNamedTable: + return visitNamedTable(relation.named_table()); + case ::substrait::proto::ReadRel::ReadTypeCase::kExtensionTable: + return visitExtensionTable(relation.extension_table()); + case ::substrait::proto::ReadRel::ReadTypeCase::READ_TYPE_NOT_SET: + break; + } + SUBSTRAIT_UNSUPPORTED( + "Unsupported read type encountered: " + + std::to_string(relation.read_type_case())); +} + +/* Visits common, the filter condition, the advanced extension and the input relation, each only when set. */ std::any BasePlanProtoVisitor::visitFilterRelation( + const ::substrait::proto::FilterRel& relation) { + if (relation.has_common()) { + visitRelationCommon(relation.common()); + } + if (relation.has_condition()) { + visitExpression(relation.condition()); + } + if (relation.has_advanced_extension()) { + visitAdvancedExtension(relation.advanced_extension()); + } + if (relation.has_input()) { + visitRelation(relation.input()); + } + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitFetchRelation( + const ::substrait::proto::FetchRel& relation) { + if (relation.has_common()) { + 
visitRelationCommon(relation.common()); + } + if (relation.has_advanced_extension()) { + visitAdvancedExtension(relation.advanced_extension()); + }; + if (relation.has_input()) { + visitRelation(relation.input()); + } + return std::nullopt; +} + +/* Visits common, every grouping, every measure, the advanced extension, and finally the input relation; the optional parts are visited only when set. */ std::any BasePlanProtoVisitor::visitAggregateRelation( + const ::substrait::proto::AggregateRel& relation) { + if (relation.has_common()) { + visitRelationCommon(relation.common()); + } + for (const auto& grouping : relation.groupings()) { + visitGrouping(grouping); + } + for (const auto& measure : relation.measures()) { + visitMeasure(measure); + } + if (relation.has_advanced_extension()) { + visitAdvancedExtension(relation.advanced_extension()); + }; + if (relation.has_input()) { + visitRelation(relation.input()); + } + return std::nullopt; +} + +/* Visits common, every sort field, the advanced extension, and finally the input relation. */ std::any BasePlanProtoVisitor::visitSortRelation( + const ::substrait::proto::SortRel& relation) { + if (relation.has_common()) { + visitRelationCommon(relation.common()); + } + for (const auto& sort : relation.sorts()) { + visitSortField(sort); + } + if (relation.has_advanced_extension()) { + visitAdvancedExtension(relation.advanced_extension()); + }; + if (relation.has_input()) { + visitRelation(relation.input()); + } + return std::nullopt; +} + +/* Traversal order: common, left input, right input, join expression, post-join filter, then the advanced extension; each is visited only when set. */ std::any BasePlanProtoVisitor::visitJoinRelation( + const ::substrait::proto::JoinRel& relation) { + if (relation.has_common()) { + visitRelationCommon(relation.common()); + } + if (relation.has_left()) { + visitRelation(relation.left()); + } + if (relation.has_right()) { + visitRelation(relation.right()); + } + if (relation.has_expression()) { + visitExpression(relation.expression()); + } + if (relation.has_post_join_filter()) { + visitExpression(relation.post_join_filter()); + } + if (relation.has_advanced_extension()) { + visitAdvancedExtension(relation.advanced_extension()); + }; + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitProjectRelation( + const ::substrait::proto::ProjectRel& relation) { + if 
(relation.has_common()) { + visitRelationCommon(relation.common()); + } + if (relation.has_input()) { + visitRelation(relation.input()); + } + for (const auto& expr : relation.expressions()) { + visitExpression(expr); + } + if (relation.has_advanced_extension()) { + visitAdvancedExtension(relation.advanced_extension()); + }; + return std::nullopt; +} + +/* Visits common, every input relation, and the advanced extension when set. */ std::any BasePlanProtoVisitor::visitSetRelation( + const ::substrait::proto::SetRel& relation) { + if (relation.has_common()) { + visitRelationCommon(relation.common()); + } + for (const auto& input : relation.inputs()) { + visitRelation(input); + } + if (relation.has_advanced_extension()) { + visitAdvancedExtension(relation.advanced_extension()); + }; + return std::nullopt; +} + +/* Only common and input are traversed by this base implementation; no other field of the message is visited here. */ std::any BasePlanProtoVisitor::visitExtensionSingleRelation( + const ::substrait::proto::ExtensionSingleRel& relation) { + if (relation.has_common()) { + visitRelationCommon(relation.common()); + } + if (relation.has_input()) { + visitRelation(relation.input()); + } + return std::nullopt; +} + +/* Only common and the input relations are traversed by this base implementation. */ std::any BasePlanProtoVisitor::visitExtensionMultiRelation( + const ::substrait::proto::ExtensionMultiRel& relation) { + if (relation.has_common()) { + visitRelationCommon(relation.common()); + } + for (const auto& input : relation.inputs()) { + visitRelation(input); + } + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitExtensionLeafRelation( + const ::substrait::proto::ExtensionLeafRel& relation) { + if (relation.has_common()) { + visitRelationCommon(relation.common()); + } + return std::nullopt; +} + +/* Visits common, the left and right inputs, and the advanced extension, each only when set. */ std::any BasePlanProtoVisitor::visitCrossRelation( + const ::substrait::proto::CrossRel& relation) { + if (relation.has_common()) { + visitRelationCommon(relation.common()); + } + if (relation.has_left()) { + visitRelation(relation.left()); + } + if (relation.has_right()) { + visitRelation(relation.right()); + } + if (relation.has_advanced_extension()) { + visitAdvancedExtension(relation.advanced_extension()); + }; + return 
std::nullopt; +} + +/* Traversal order: common, left input, right input, every left key, every right key, post-join filter, then the advanced extension. */ std::any BasePlanProtoVisitor::visitHashJoinRelation( + const ::substrait::proto::HashJoinRel& relation) { + if (relation.has_common()) { + visitRelationCommon(relation.common()); + } + if (relation.has_left()) { + visitRelation(relation.left()); + } + if (relation.has_right()) { + visitRelation(relation.right()); + } + for (const auto& key : relation.left_keys()) { + visitFieldReference(key); + } + for (const auto& key : relation.right_keys()) { + visitFieldReference(key); + } + if (relation.has_post_join_filter()) { + visitExpression(relation.post_join_filter()); + } + if (relation.has_advanced_extension()) { + visitAdvancedExtension(relation.advanced_extension()); + }; + return std::nullopt; +} + +/* Same traversal order as visitHashJoinRelation, applied to a MergeJoinRel. */ std::any BasePlanProtoVisitor::visitMergeJoinRelation( + const ::substrait::proto::MergeJoinRel& relation) { + if (relation.has_common()) { + visitRelationCommon(relation.common()); + } + if (relation.has_left()) { + visitRelation(relation.left()); + } + if (relation.has_right()) { + visitRelation(relation.right()); + } + for (const auto& key : relation.left_keys()) { + visitFieldReference(key); + } + for (const auto& key : relation.right_keys()) { + visitFieldReference(key); + } + if (relation.has_post_join_filter()) { + visitExpression(relation.post_join_filter()); + } + if (relation.has_advanced_extension()) { + visitAdvancedExtension(relation.advanced_extension()); + }; + return std::nullopt; +} + +/* Central dispatch for relations: switches on rel_type_case() and returns the result of the matching visit*Relation hook. */ std::any BasePlanProtoVisitor::visitRelation( + const ::substrait::proto::Rel& relation) { + switch (relation.rel_type_case()) { + case ::substrait::proto::Rel::RelTypeCase::kRead: + return visitReadRelation(relation.read()); + case ::substrait::proto::Rel::RelTypeCase::kFilter: + return visitFilterRelation(relation.filter()); + case ::substrait::proto::Rel::RelTypeCase::kFetch: + return visitFetchRelation(relation.fetch()); + case ::substrait::proto::Rel::RelTypeCase::kAggregate: + return visitAggregateRelation(relation.aggregate()); + case 
::substrait::proto::Rel::RelTypeCase::kSort: + return visitSortRelation(relation.sort()); + case ::substrait::proto::Rel::RelTypeCase::kJoin: + return visitJoinRelation(relation.join()); + case ::substrait::proto::Rel::RelTypeCase::kProject: + return visitProjectRelation(relation.project()); + case ::substrait::proto::Rel::RelTypeCase::kSet: + return visitSetRelation(relation.set()); + case ::substrait::proto::Rel::RelTypeCase::kExtensionSingle: + return visitExtensionSingleRelation(relation.extension_single()); + case ::substrait::proto::Rel::RelTypeCase::kExtensionMulti: + return visitExtensionMultiRelation(relation.extension_multi()); + case ::substrait::proto::Rel::RelTypeCase::kExtensionLeaf: + return visitExtensionLeafRelation(relation.extension_leaf()); + case ::substrait::proto::Rel::RelTypeCase::kCross: + return visitCrossRelation(relation.cross()); + case ::substrait::proto::Rel::RelTypeCase::kHashJoin: + return visitHashJoinRelation(relation.hash_join()); + case ::substrait::proto::Rel::RelTypeCase::kMergeJoin: + return visitMergeJoinRelation(relation.merge_join()); + case ::substrait::proto::Rel::REL_TYPE_NOT_SET: + break; + } + SUBSTRAIT_UNSUPPORTED( + "Unsupported relation type encountered: " + + std::to_string(relation.rel_type_case())); +} + +/* Delegates straight to visitRelation on the root's input, without a has_input() check; no other RelRoot field is traversed here. */ std::any BasePlanProtoVisitor::visitRelationRoot( + const ::substrait::proto::RelRoot& relation) { + return visitRelation(relation.input()); +} + +std::any BasePlanProtoVisitor::visitExtensionUri( + const ::substrait::proto::extensions::SimpleExtensionURI& uri) { + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitExtension( + const ::substrait::proto::extensions::SimpleExtensionDeclaration& + extension) { + return std::nullopt; +} + +/* Dispatches on rel_type_case() to either visitRelation (kRel) or visitRelationRoot (kRoot) and returns that result. */ std::any BasePlanProtoVisitor::visitPlanRelation( + const ::substrait::proto::PlanRel& relation) { + switch (relation.rel_type_case()) { + case ::substrait::proto::PlanRel::RelTypeCase::kRel: + return visitRelation(relation.rel()); + case 
::substrait::proto::PlanRel::RelTypeCase::kRoot: + return visitRelationRoot(relation.root()); + case ::substrait::proto::PlanRel::RelTypeCase::REL_TYPE_NOT_SET: + break; + } + SUBSTRAIT_UNSUPPORTED( + "Unsupported plan relation type encountered: " + + std::to_string(relation.rel_type_case())); +} + +std::any BasePlanProtoVisitor::visitAdvancedExtension( + const ::substrait::proto::extensions::AdvancedExtension& extension) { + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitExpectedTypeUrl(const std::string& url) { + return std::nullopt; +} + +/* Top-level traversal: extension URIs, extension declarations, plan relations, the advanced extensions when set, and finally the expected type URLs. */ std::any BasePlanProtoVisitor::visitPlan(const ::substrait::proto::Plan& plan) { + for (const auto& uri : plan.extension_uris()) { + visitExtensionUri(uri); + } + for (const auto& extension : plan.extensions()) { + visitExtension(extension); + } + for (const auto& relation : plan.relations()) { + visitPlanRelation(relation); + } + if (plan.has_advanced_extensions()) { + visitAdvancedExtension(plan.advanced_extensions()); + } + for (const auto& url : plan.expected_type_urls()) { + visitExpectedTypeUrl(url); + } + return std::nullopt; +} + +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/converter/BasePlanProtoVisitor.h b/src/substrait/textplan/converter/BasePlanProtoVisitor.h new file mode 100644 index 00000000..451ca63c --- /dev/null +++ b/src/substrait/textplan/converter/BasePlanProtoVisitor.h @@ -0,0 +1,190 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#pragma once + +#include + +#include "substrait/proto/plan.pb.h" + +namespace io::substrait::textplan { + +// BasePlanProtoVisitor provides a visitor that will visit all of the major +// messages within a Plan protobuf object. Subclass the visitor to add your +// own functionality. +class BasePlanProtoVisitor { + public: + BasePlanProtoVisitor() = default; + + /* This class is designed to be subclassed and has virtual hooks, so the destructor must be virtual for deletion through a base pointer to be well defined. The copy and move operations are defaulted explicitly so that declaring the destructor does not suppress them. */ + virtual ~BasePlanProtoVisitor() = default; + BasePlanProtoVisitor(const BasePlanProtoVisitor&) = default; + BasePlanProtoVisitor& operator=(const BasePlanProtoVisitor&) = default; + BasePlanProtoVisitor(BasePlanProtoVisitor&&) = default; + BasePlanProtoVisitor& operator=(BasePlanProtoVisitor&&) = default; + + // visit() begins the traversal of the entire plan. 
+ /* Forwards to visitPlan and discards the std::any it returns. */ virtual void visit(const ::substrait::proto::Plan& plan) { + visitPlan(plan); + } + + protected: + /* Per-message hooks. Each returns a std::any; subclasses override the ones they need. */ virtual std::any visitSubqueryScalar( + const ::substrait::proto::Expression_Subquery_Scalar& query); + virtual std::any visitSubqueryInPredicate( + const ::substrait::proto::Expression_Subquery_InPredicate& query); + virtual std::any visitSubquerySetPredicate( + const ::substrait::proto::Expression_Subquery_SetPredicate& query); + virtual std::any visitSubquerySetComparison( + const ::substrait::proto::Expression_Subquery_SetComparison& query); + virtual std::any visitNestedStruct( + const ::substrait::proto::Expression_Nested_Struct& nested); + virtual std::any visitNestedList( + const ::substrait::proto::Expression_Nested_List& nested); + virtual std::any visitNestedMap( + const ::substrait::proto::Expression_Nested_Map& nested); + virtual std::any visitNestedKeyValue( + const ::substrait::proto::Expression_Nested_Map_KeyValue& kv); + virtual std::any visitStructItem( + const ::substrait::proto::Expression_MaskExpression_StructItem& item); + virtual std::any visitReferenceSegmentMapKey( + const ::substrait::proto::Expression_ReferenceSegment_MapKey& mapkey); + virtual std::any visitReferenceSegmentStructField( + const ::substrait::proto::Expression_ReferenceSegment_StructField& + structure); + virtual std::any visitReferenceSegmentListElement( + const ::substrait::proto::Expression_ReferenceSegment_ListElement& + element); + virtual std::any visitSelect( + const ::substrait::proto::Expression_MaskExpression_Select& select); + + virtual std::any visitType(const ::substrait::proto::Type& type); + virtual std::any visitMap( + const ::substrait::proto::Expression_Literal_Map& map); + virtual std::any visitMapKeyValue( + const ::substrait::proto::Expression_Literal_Map_KeyValue& kv); + virtual std::any visitList( + const ::substrait::proto::Expression_Literal_List& list); + virtual std::any visitTypeList(const ::substrait::proto::Type_List& list); + virtual std::any 
visitTypeMap(const ::substrait::proto::Type_Map& list); + virtual std::any visitTypeUserDefined( + const ::substrait::proto::Type_UserDefined& type); + virtual std::any visitTypeParameter( + const ::substrait::proto::Type_Parameter& parameter); + virtual std::any visitUserDefined( + const ::substrait::proto::Expression_Literal_UserDefined& list); + virtual std::any visitFunctionArgument( + const ::substrait::proto::FunctionArgument& argument); + virtual std::any visitFunctionOption( + const ::substrait::proto::FunctionOption& option); + virtual std::any visitIfClause( + const ::substrait::proto::Expression_IfThen_IfClause& ifclause); + virtual std::any visitIfValue( + const ::substrait::proto::Expression_SwitchExpression_IfValue& ifclause); + virtual std::any visitRecord( + const ::substrait::proto::Expression_MultiOrList_Record& record); + + virtual std::any visitStruct( + const ::substrait::proto::Type_Struct& structure); + /* The hooks from visitLiteral through visitEnum correspond to the Expression cases dispatched by visitExpression (the selection case uses visitFieldReference, declared further down). */ virtual std::any visitLiteral( + const ::substrait::proto::Expression::Literal& literal); + virtual std::any visitScalarFunction( + const ::substrait::proto::Expression::ScalarFunction& function); + virtual std::any visitWindowFunction( + const ::substrait::proto::Expression::WindowFunction& function); + virtual std::any visitIfThen( + const ::substrait::proto::Expression::IfThen& ifthen); + virtual std::any visitSwitchExpression( + const ::substrait::proto::Expression::SwitchExpression& expression); + virtual std::any visitSingularOrList( + const ::substrait::proto::Expression::SingularOrList& expression); + virtual std::any visitMultiOrList( + const ::substrait::proto::Expression::MultiOrList& expression); + virtual std::any visitCast(const ::substrait::proto::Expression::Cast& cast); + virtual std::any visitSubquery( + const ::substrait::proto::Expression_Subquery& query); + virtual std::any visitNested( + const ::substrait::proto::Expression_Nested& structure); + virtual std::any visitEnum(const ::substrait::proto::Expression_Enum& value); 
+ virtual std::any visitStructSelect( + const ::substrait::proto::Expression_MaskExpression_StructSelect& + structure); + virtual std::any visitListSelect( + const ::substrait::proto::Expression_MaskExpression_ListSelect& select); + virtual std::any visitListSelectItem( + const ::substrait::proto:: + Expression_MaskExpression_ListSelect_ListSelectItem& item); + virtual std::any visitMapSelect( + const ::substrait::proto::Expression_MaskExpression_MapSelect& select); + virtual std::any visitExpressionLiteralStruct( + const ::substrait::proto::Expression_Literal_Struct& structure); + virtual std::any visitFileOrFiles( + const ::substrait::proto::ReadRel_LocalFiles_FileOrFiles& structure); + virtual std::any visitAggregateFunction( + const ::substrait::proto::AggregateFunction& function); + virtual std::any visitReferenceSegment( + const ::substrait::proto::Expression_ReferenceSegment& segment); + + virtual std::any visitRelationCommon( + const ::substrait::proto::RelCommon& common); + virtual std::any visitNamedStruct( + const ::substrait::proto::NamedStruct& named); + virtual std::any visitExpression( + const ::substrait::proto::Expression& expression); + virtual std::any visitMaskExpression( + const ::substrait::proto::Expression::MaskExpression& expression); + /* The next four hooks are the ReadRel read-type cases selected by visitReadRelation. */ virtual std::any visitVirtualTable( + const ::substrait::proto::ReadRel_VirtualTable& table); + virtual std::any visitLocalFiles( + const ::substrait::proto::ReadRel_LocalFiles& local); + virtual std::any visitNamedTable( + const ::substrait::proto::ReadRel_NamedTable& table); + virtual std::any visitExtensionTable( + const ::substrait::proto::ReadRel_ExtensionTable& table); + virtual std::any visitGrouping( + const ::substrait::proto::AggregateRel_Grouping& grouping); + virtual std::any visitMeasure( + const ::substrait::proto::AggregateRel_Measure& measure); + virtual std::any visitSortField(const ::substrait::proto::SortField& sort); + virtual std::any visitFieldReference( + const 
::substrait::proto::Expression::FieldReference& ref); + + /* One hook per Rel variant; visitRelation selects among them. */ virtual std::any visitReadRelation( + const ::substrait::proto::ReadRel& relation); + virtual std::any visitFilterRelation( + const ::substrait::proto::FilterRel& relation); + virtual std::any visitFetchRelation( + const ::substrait::proto::FetchRel& relation); + virtual std::any visitAggregateRelation( + const ::substrait::proto::AggregateRel& relation); + virtual std::any visitSortRelation( + const ::substrait::proto::SortRel& relation); + virtual std::any visitJoinRelation( + const ::substrait::proto::JoinRel& relation); + virtual std::any visitProjectRelation( + const ::substrait::proto::ProjectRel& relation); + virtual std::any visitSetRelation(const ::substrait::proto::SetRel& relation); + virtual std::any visitExtensionSingleRelation( + const ::substrait::proto::ExtensionSingleRel& relation); + virtual std::any visitExtensionMultiRelation( + const ::substrait::proto::ExtensionMultiRel& relation); + virtual std::any visitExtensionLeafRelation( + const ::substrait::proto::ExtensionLeafRel& relation); + virtual std::any visitCrossRelation( + const ::substrait::proto::CrossRel& relation); + virtual std::any visitHashJoinRelation( + const ::substrait::proto::HashJoinRel& relation); + virtual std::any visitMergeJoinRelation( + const ::substrait::proto::MergeJoinRel& relation); + + virtual std::any visitRelation(const ::substrait::proto::Rel& relation); + virtual std::any visitRelationRoot( + const ::substrait::proto::RelRoot& relation); + + /* Plan-level hooks called from visitPlan. */ virtual std::any visitExtensionUri( + const ::substrait::proto::extensions::SimpleExtensionURI& uri); + virtual std::any visitExtension( + const ::substrait::proto::extensions::SimpleExtensionDeclaration& + extension); + virtual std::any visitPlanRelation( + const ::substrait::proto::PlanRel& relation); + virtual std::any visitAdvancedExtension( + const ::substrait::proto::extensions::AdvancedExtension& extension); + virtual std::any visitExpectedTypeUrl(const 
std::string& url); + + virtual std::any visitPlan(const ::substrait::proto::Plan& plan); +}; + +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/converter/CMakeLists.txt b/src/substrait/textplan/converter/CMakeLists.txt new file mode 100644 index 00000000..0a7b66b3 --- /dev/null +++ b/src/substrait/textplan/converter/CMakeLists.txt @@ -0,0 +1,26 @@ +# SPDX-License-Identifier: Apache-2.0 + +set(TEXTPLAN_SRCS + InitialPlanProtoVisitor.cpp InitialPlanProtoVisitor.h + BasePlanProtoVisitor.cpp BasePlanProtoVisitor.h + PlanPrinterVisitor.cpp PlanPrinterVisitor.h + LoadBinary.cpp LoadBinary.h ParseBinary.cpp ParseBinary.h) + +add_library(substrait_textplan_converter ${TEXTPLAN_SRCS}) + +target_link_libraries( + substrait_textplan_converter + substrait_common + substrait_proto + symbol_table + error_listener) + +if (${SUBSTRAIT_CPP_BUILD_TESTING}) + add_subdirectory(tests) +endif () + +add_executable(planconverter Tool.cpp) + +target_link_libraries( + planconverter + substrait_textplan_converter) diff --git a/src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp b/src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp new file mode 100644 index 00000000..2be0c93b --- /dev/null +++ b/src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp @@ -0,0 +1,159 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#include "substrait/textplan/converter/InitialPlanProtoVisitor.h" + +#include +#include +#include + +#include "substrait/common/Exceptions.h" +#include "substrait/proto/ProtoUtils.h" +#include "substrait/proto/algebra.pb.h" +#include "substrait/proto/plan.pb.h" +#include "substrait/textplan/Location.h" +#include "substrait/textplan/SymbolTable.h" + +namespace io::substrait::textplan { + +namespace { + +/* Returns the part of str before its first ':'; when str contains no ':' the whole string is returned unchanged. */ std::string shortName(std::string str) { + auto loc = str.find(':'); + if (loc != std::string::npos) { + return str.substr(0, loc); + } + return str; +} + +} // namespace + +std::any InitialPlanProtoVisitor::visitExtension( + const 
::substrait::proto::extensions::SimpleExtensionDeclaration& + extension) { + /* Only extension-function mappings are handled; any other mapping type is reported through SUBSTRAIT_FAIL. */ if (extension.mapping_type_case() != + ::substrait::proto::extensions::SimpleExtensionDeclaration:: + kExtensionFunction) { + SUBSTRAIT_FAIL( + "Unknown mapping type case " + + std::to_string(extension.mapping_type_case()) + " encountered."); + } + /* The symbol name is the function name truncated at its first ':' (see shortName), made unique by the symbol table. */ const auto& unique_name = symbol_table_->getUniqueName( + shortName(extension.extension_function().name())); + symbol_table_->defineSymbol( + unique_name, + Location((::google::protobuf::Message*)&extension.extension_function()), + SymbolType::kFunction, + std::nullopt, + &extension.extension_function()); + return std::nullopt; +} + +/* Defines a kExtensionSpace symbol named after the URI, passing extension_uri_anchor() as the final defineSymbol argument. */ std::any InitialPlanProtoVisitor::visitExtensionUri( + const ::substrait::proto::extensions::SimpleExtensionURI& uri) { + symbol_table_->defineSymbol( + uri.uri(), + Location((::google::protobuf::Message*)&uri), + SymbolType::kExtensionSpace, + std::nullopt, + uri.extension_uri_anchor()); + return std::nullopt; +} + +/* Runs the base traversal first, then defines a kPlanRelation symbol whose name comes from PlanRelTypeCaseName. */ std::any InitialPlanProtoVisitor::visitPlanRelation( + const ::substrait::proto::PlanRel& relation) { + BasePlanProtoVisitor::visitPlanRelation(relation); + std::string name = PlanRelTypeCaseName(relation.rel_type_case()); + auto unique_name = symbol_table_->getUniqueName(name); + symbol_table_->defineSymbol( + unique_name, + Location((google::protobuf::Message*)&relation), + SymbolType::kPlanRelation, + std::nullopt, + &relation); + return std::nullopt; +} + +/* The symbol for this relation is defined only after the base traversal has run, so relations nested beneath it are registered before it is. */ std::any InitialPlanProtoVisitor::visitRelation( + const ::substrait::proto::Rel& relation) { + std::string name = RelTypeCaseName(relation.rel_type_case()); + BasePlanProtoVisitor::visitRelation(relation); + auto unique_name = symbol_table_->getUniqueName(name); + symbol_table_->defineSymbol( + unique_name, + Location((google::protobuf::Message*)&relation), + SymbolType::kRelation, + relation.rel_type_case(), + &relation); + return std::nullopt; +} + +std::any InitialPlanProtoVisitor::visitRelationRoot( + const ::substrait::proto::RelRoot& 
relation) { + BasePlanProtoVisitor::visitRelationRoot(relation); + return std::nullopt; +} + +/* Defines a kSchema symbol first when a base_schema is present, then lets the base class traverse the ReadRel and returns its result. */ std::any InitialPlanProtoVisitor::visitReadRelation( + const ::substrait::proto::ReadRel& relation) { + if (relation.has_base_schema()) { + const std::string& name = symbol_table_->getUniqueName("schema"); + symbol_table_->defineSymbol( + name, + Location((google::protobuf::Message*)&relation.base_schema()), + SymbolType::kSchema, + std::nullopt, + &relation.base_schema()); + } + + return BasePlanProtoVisitor::visitReadRelation(relation); +} + +/* Each of the source hooks below defines a kSource symbol tagged with the matching SourceType, then defers to the base-class hook. */ std::any InitialPlanProtoVisitor::visitVirtualTable( + const ::substrait::proto::ReadRel_VirtualTable& table) { + const auto& unique_name = symbol_table_->getUniqueName("virtual"); + symbol_table_->defineSymbol( + unique_name, + Location((google::protobuf::Message*)&table), + SymbolType::kSource, + SourceType::kVirtualTable, + &table); + return BasePlanProtoVisitor::visitVirtualTable(table); +} + +std::any InitialPlanProtoVisitor::visitLocalFiles( + const ::substrait::proto::ReadRel_LocalFiles& local) { + const auto& unique_name = symbol_table_->getUniqueName("local"); + symbol_table_->defineSymbol( + unique_name, + Location((google::protobuf::Message*)&local), + SymbolType::kSource, + SourceType::kLocalFiles, + &local); + return BasePlanProtoVisitor::visitLocalFiles(local); +} + +std::any InitialPlanProtoVisitor::visitNamedTable( + const ::substrait::proto::ReadRel_NamedTable& table) { + const auto& unique_name = symbol_table_->getUniqueName("named"); + symbol_table_->defineSymbol( + unique_name, + Location((google::protobuf::Message*)&table), + SymbolType::kSource, + SourceType::kNamedTable, + &table); + return BasePlanProtoVisitor::visitNamedTable(table); +} + +std::any InitialPlanProtoVisitor::visitExtensionTable( + const ::substrait::proto::ReadRel_ExtensionTable& table) { + const auto& unique_name = symbol_table_->getUniqueName("extensiontable"); + symbol_table_->defineSymbol( + unique_name, + 
Location((google::protobuf::Message*)&table), + SymbolType::kSource, + SourceType::kExtensionTable, + &table); + return BasePlanProtoVisitor::visitExtensionTable(table); +} + +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/converter/InitialPlanProtoVisitor.h b/src/substrait/textplan/converter/InitialPlanProtoVisitor.h new file mode 100644 index 00000000..cf2de907 --- /dev/null +++ b/src/substrait/textplan/converter/InitialPlanProtoVisitor.h @@ -0,0 +1,60 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#pragma once + +#include + +#include "substrait/proto/plan.pb.h" +#include "substrait/textplan/SubstraitErrorListener.h" +#include "substrait/textplan/SymbolTable.h" +#include "substrait/textplan/converter/BasePlanProtoVisitor.h" + +namespace io::substrait::textplan { + +// InitialPlanProtoVisitor is the first part of the binary to text conversion +// process which identifies the prominent symbols and gives them names. +class InitialPlanProtoVisitor : public BasePlanProtoVisitor { + public: + /* Allocates the symbol table and error listener that this pass fills in; both are exposed through the getters below. */ explicit InitialPlanProtoVisitor() : BasePlanProtoVisitor() { + symbol_table_ = std::make_shared(); + error_listener_ = std::make_shared(); + }; + + [[nodiscard]] std::shared_ptr getSymbolTable() const { + return symbol_table_; + }; + + [[nodiscard]] std::shared_ptr getErrorListener() + const { + return error_listener_; + }; + + private: + std::any visitExtensionUri( + const ::substrait::proto::extensions::SimpleExtensionURI& uri) override; + std::any visitExtension( + const ::substrait::proto::extensions::SimpleExtensionDeclaration& + extension) override; + + std::any visitPlanRelation( + const ::substrait::proto::PlanRel& relation) override; + std::any visitRelation(const ::substrait::proto::Rel& relation) override; + std::any visitRelationRoot( + const ::substrait::proto::RelRoot& relation) override; + std::any visitReadRelation( + const ::substrait::proto::ReadRel& relation) override; + + std::any visitVirtualTable( + const 
::substrait::proto::ReadRel_VirtualTable& table) override; + std::any visitLocalFiles( + const ::substrait::proto::ReadRel_LocalFiles& local) override; + std::any visitNamedTable( + const ::substrait::proto::ReadRel_NamedTable& table) override; + std::any visitExtensionTable( + const ::substrait::proto::ReadRel_ExtensionTable& table) override; + + std::shared_ptr symbol_table_; + std::shared_ptr error_listener_; +}; + +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/converter/LoadBinary.cpp b/src/substrait/textplan/converter/LoadBinary.cpp new file mode 100644 index 00000000..4706bf8e --- /dev/null +++ b/src/substrait/textplan/converter/LoadBinary.cpp @@ -0,0 +1,89 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#include "substrait/textplan/converter/LoadBinary.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "substrait/common/Exceptions.h" +#include "substrait/proto/plan.pb.h" + +namespace io::substrait::textplan { + +namespace { + +/* Collects protobuf text-format parse errors as strings of the form line:column → message, adding 1 to the line and column values it is given. */ class StringErrorCollector : public google::protobuf::io::ErrorCollector { + public: + void AddError(int line, int column, const std::string& message) override { + errors_.push_back( + std::to_string(line + 1) + ":" + std::to_string(column + 1) + " → " + + message); + } + + [[nodiscard]] std::vector GetErrors() const { + return errors_; + } + + private: + std::vector errors_; +}; + +} // namespace + +/* Reads the whole file at msgPath into a string. If the stream fails to open, SUBSTRAIT_FAIL is invoked with the path, the current working directory and strerror(errno). */ std::string readFromFile(std::string_view msgPath) { + std::ifstream textFile(std::string{msgPath}); + if (textFile.fail()) { + auto currdir = std::filesystem::current_path().string(); + SUBSTRAIT_FAIL( + "Failed to open file {} when running in {}: {}", + msgPath, + currdir, + strerror(errno)); + } + std::stringstream buffer; + buffer << textFile.rdbuf(); + return buffer.str(); +} + +/* Parses a JSON-encoded plan. Empty input yields an error result; a first line starting with '#' is skipped before parsing. */ PlanOrErrors loadFromJSON(std::string_view json) { + if (json.empty()) { + return PlanOrErrors({"Provided JSON string was empty."}); + } + std::string_view 
usable_json = json; + if (json[0] == '#') { + int idx = 0; + while (idx < json.size() && json[idx] != '\n') { + idx++; + } + usable_json.remove_prefix(idx); + } + ::substrait::proto::Plan plan; + auto status = google::protobuf::util::JsonStringToMessage( + std::string{usable_json}, &plan); + if (!status.ok()) { + return PlanOrErrors({fmt::format( + "Failed to parse Substrait JSON: {}", status.message().ToString())}); + } + return PlanOrErrors(plan); +} + +PlanOrErrors loadFromText(const std::string& text) { + ::substrait::proto::Plan plan; + ::google::protobuf::TextFormat::Parser parser; + StringErrorCollector collector; + parser.RecordErrorsTo(&collector); + if (!parser.ParseFromString(text, &plan)) { + return PlanOrErrors(collector.GetErrors()); + } + return PlanOrErrors(plan); +} + +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/converter/LoadBinary.h b/src/substrait/textplan/converter/LoadBinary.h new file mode 100644 index 00000000..36a72af5 --- /dev/null +++ b/src/substrait/textplan/converter/LoadBinary.h @@ -0,0 +1,55 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#pragma once + +#include +#include +#include +#include + +#include "substrait/proto/plan.pb.h" + +namespace substrait::proto { +class Plan; +} + +namespace io::substrait::textplan { + +// PlanOrErrors behaves similarly to abseil::StatusOr. +class PlanOrErrors { + public: + explicit PlanOrErrors(::substrait::proto::Plan plan) + : plan_(std::move(plan)){}; + explicit PlanOrErrors(std::vector errors) + : errors_(std::move(errors)){}; + + bool ok() { + return errors_.empty(); + } + + const ::substrait::proto::Plan& operator*() { + return plan_; + } + + const std::vector& errors() { + return errors_; + } + + private: + ::substrait::proto::Plan plan_; + std::vector errors_; +}; + +// Read the contents of a file from disk. +// Throws an exception if file cannot be read. +std::string readFromFile(std::string_view msgPath); + +// Reads a plan from a json-encoded text proto. 
+// Returns a list of errors if the file cannot be parsed. +PlanOrErrors loadFromJSON(std::string_view json); + +// Reads a plan encoded as a text protobuf. +// Returns a list of errors if the file cannot be parsed. +PlanOrErrors loadFromText(const std::string& text); + +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/converter/ParseBinary.cpp b/src/substrait/textplan/converter/ParseBinary.cpp new file mode 100644 index 00000000..af18981e --- /dev/null +++ b/src/substrait/textplan/converter/ParseBinary.cpp @@ -0,0 +1,27 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#include "substrait/textplan/converter/ParseBinary.h" + +#include "substrait/proto/plan.pb.h" +#include "substrait/textplan/converter/InitialPlanProtoVisitor.h" +#include "substrait/textplan/converter/PlanPrinterVisitor.h" + +namespace io::substrait::textplan { + +ParseResult parseBinaryPlan(const ::substrait::proto::Plan& plan) { + InitialPlanProtoVisitor visitor; + visitor.visit(plan); + auto symbols = visitor.getSymbolTable(); + auto syntaxErrors = visitor.getErrorListener()->getErrorMessages(); + std::vector semanticErrors; + + PlanPrinterVisitor printer(*symbols); + printer.visit(plan); + auto moreErrors = printer.getErrorListener()->getErrorMessages(); + semanticErrors.insert( + semanticErrors.end(), moreErrors.begin(), moreErrors.end()); + + return {*printer.getSymbolTable(), syntaxErrors, semanticErrors}; +} + +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/converter/ParseBinary.h b/src/substrait/textplan/converter/ParseBinary.h new file mode 100644 index 00000000..c811d1d8 --- /dev/null +++ b/src/substrait/textplan/converter/ParseBinary.h @@ -0,0 +1,15 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#pragma once + +#include "substrait/textplan/ParseResult.h" + +namespace substrait::proto { +class Plan; +} + +namespace io::substrait::textplan { + +ParseResult parseBinaryPlan(const ::substrait::proto::Plan& plan); + +} // namespace 
io::substrait::textplan diff --git a/src/substrait/textplan/converter/PlanPrinterVisitor.cpp b/src/substrait/textplan/converter/PlanPrinterVisitor.cpp new file mode 100644 index 00000000..59da0144 --- /dev/null +++ b/src/substrait/textplan/converter/PlanPrinterVisitor.cpp @@ -0,0 +1,192 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#include "substrait/textplan/converter/PlanPrinterVisitor.h" + +#include +#include +#include +#include + +#include "substrait/proto/ProtoUtils.h" +#include "substrait/proto/algebra.pb.h" +#include "substrait/textplan/Any.h" + +namespace io::substrait::textplan { + +std::string PlanPrinterVisitor::printRelation( + const std::string& symbolName, + const ::substrait::proto::Rel* relation) { + std::stringstream text; + + text << RelTypeCaseName(relation->rel_type_case()) << " relation " + << symbolName << " {\n"; + auto symbol = symbol_table_->lookupSymbolByLocation( + Location((google::protobuf::Message*)&relation)); + if (symbol != SymbolTable::kUnknownSymbol) { + text << " source " << symbol.name << ";\n"; + } + auto result = this->visitRelation(*relation); + if (result.type() != typeid(std::string)) { + return "ERROR: Relation subtype " + + RelTypeCaseName((relation->rel_type_case())) + " not supported!\n"; + } + text << ANY_CAST(std::string, result); + text << "}\n"; + + return text.str(); +} + +std::any PlanPrinterVisitor::visitAggregateFunction( + const ::substrait::proto::AggregateFunction& function) { + return std::string("AF-NOT-YET-IMPLEMENTED"); +} + +std::any PlanPrinterVisitor::visitExpression( + const ::substrait::proto::Expression& expression) { + return std::string("EXPR-NOT-YET-IMPLEMENTED"); +} + +std::any PlanPrinterVisitor::visitMaskExpression( + const ::substrait::proto::Expression::MaskExpression& expression) { + return std::string("MASKEXPR-NOT-YET-IMPLEMENTED"); +} + +std::any PlanPrinterVisitor::visitReadRelation( + const ::substrait::proto::ReadRel& relation) { + std::stringstream text; + if 
(relation.has_base_schema()) { + const auto& symbol = symbol_table_->lookupSymbolByLocation( + Location((google::protobuf::Message*)&relation)); + if (symbol != SymbolTable::kUnknownSymbol) { + text << " base_schema " << symbol.name << ";\n"; + } + } + if (relation.has_filter()) { + text << " filter " + << ANY_CAST(std::string, visitExpression(relation.filter())) + ";\n"; + } + if (relation.has_best_effort_filter()) { + text << " filter " + << ANY_CAST( + std::string, visitExpression(relation.best_effort_filter())) + << ";\n"; + } + if (relation.has_projection()) { + text << " projection " + << ANY_CAST(std::string, visitMaskExpression(relation.projection())) + << ";\n"; + } + + return text.str(); +} + +std::any PlanPrinterVisitor::visitFilterRelation( + const ::substrait::proto::FilterRel& relation) { + std::stringstream text; + if (relation.has_condition()) { + text << " condition " + << ANY_CAST(std::string, visitExpression(relation.condition())) + << ";\n"; + } + return text.str(); +} + +std::any PlanPrinterVisitor::visitFetchRelation( + const ::substrait::proto::FetchRel& relation) { + std::stringstream text; + if (relation.offset() != 0) { + text << " offset " << std::to_string(relation.offset()) << ";\n"; + } + text << " count " << std::to_string(relation.count()) << ";\n"; + return text.str(); +} + +std::any PlanPrinterVisitor::visitAggregateRelation( + const ::substrait::proto::AggregateRel& relation) { + std::stringstream text; + for (const auto& group : relation.groupings()) { + for (const auto& expr : group.grouping_expressions()) { + text << " grouping " << ANY_CAST(std::string, visitExpression(expr)) + << ";\n"; + } + } + for (const auto& measure : relation.measures()) { + if (!measure.has_measure()) + continue; + text << " measure {\n"; + text << " measure " + << ANY_CAST(std::string, visitAggregateFunction(measure.measure())) + << ";\n"; + if (measure.has_filter()) { + text << " filter " + + ANY_CAST(std::string, visitExpression(measure.filter())) + 
<< ";\n"; + } + text << " }\n"; + } + return text.str(); +} + +std::any PlanPrinterVisitor::visitSortRelation( + const ::substrait::proto::SortRel& relation) { + std::stringstream text; + for (const auto& sort : relation.sorts()) { + text << " sort " << ANY_CAST(std::string, visitExpression(sort.expr())); + switch (sort.sort_kind_case()) { + case ::substrait::proto::SortField::kDirection: + switch (sort.direction()) { + case ::substrait::proto:: + SortField_SortDirection_SORT_DIRECTION_ASC_NULLS_FIRST: + text << " by ASC_NULLS_FIRST"; + break; + case ::substrait::proto:: + SortField_SortDirection_SORT_DIRECTION_ASC_NULLS_LAST: + text << " by ASC_NULLS_LAST"; + break; + case ::substrait::proto:: + SortField_SortDirection_SORT_DIRECTION_DESC_NULLS_FIRST: + text << " by DESC_NULLS_FIRST"; + break; + case ::substrait::proto:: + SortField_SortDirection_SORT_DIRECTION_DESC_NULLS_LAST: + text << " by DESC_NULLS_LAST"; + break; + case ::substrait::proto:: + SortField_SortDirection_SORT_DIRECTION_CLUSTERED: + text << " by CLUSTERED"; + break; + case ::substrait::proto:: + SortField_SortDirection_SORT_DIRECTION_UNSPECIFIED: + default: + break; + } + break; + case ::substrait::proto::SortField::kComparisonFunctionReference: { + auto field = symbol_table_->nthSymbolByType( + sort.comparison_function_reference(), SymbolType::kFunction); + if (field == SymbolTable::kUnknownSymbol) { + return field.name; + } else { + return "functionref#" + + std::to_string(sort.comparison_function_reference()); + } + } + case ::substrait::proto::SortField::SORT_KIND_NOT_SET: + break; + } + text << ";\n"; + } + return text.str(); +} + +std::any PlanPrinterVisitor::visitProjectRelation( + const ::substrait::proto::ProjectRel& relation) { + std::stringstream text; + for (const auto& expr : relation.expressions()) { + text << " expression " << ANY_CAST(std::string, visitExpression(expr)) + << ";\n"; + } + return text.str(); +} + +} // namespace io::substrait::textplan diff --git 
a/src/substrait/textplan/converter/PlanPrinterVisitor.h b/src/substrait/textplan/converter/PlanPrinterVisitor.h new file mode 100644 index 00000000..ec787c25 --- /dev/null +++ b/src/substrait/textplan/converter/PlanPrinterVisitor.h @@ -0,0 +1,61 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#pragma once + +#include + +#include "substrait/proto/plan.pb.h" +#include "substrait/textplan/SubstraitErrorListener.h" +#include "substrait/textplan/SymbolTable.h" +#include "substrait/textplan/converter/BasePlanProtoVisitor.h" + +namespace io::substrait::textplan { + +class PlanPrinterVisitor : public BasePlanProtoVisitor { + public: + // PlanPrinterVisitor takes ownership of the provided symbol table. + explicit PlanPrinterVisitor(const SymbolTable& symbol_table) { + symbol_table_ = std::make_shared(symbol_table); + error_listener_ = std::make_shared(); + }; + + [[nodiscard]] std::shared_ptr getSymbolTable() const { + return symbol_table_; + }; + + [[nodiscard]] std::shared_ptr getErrorListener() + const { + return error_listener_; + }; + + std::string printRelation( + const std::string& symbolName, + const ::substrait::proto::Rel* relation); + + private: + std::any visitAggregateFunction( + const ::substrait::proto::AggregateFunction& function) override; + std::any visitExpression( + const ::substrait::proto::Expression& expression) override; + std::any visitMaskExpression( + const ::substrait::proto::Expression::MaskExpression& expression) + override; + + std::any visitReadRelation( + const ::substrait::proto::ReadRel& relation) override; + std::any visitFilterRelation( + const ::substrait::proto::FilterRel& relation) override; + std::any visitFetchRelation( + const ::substrait::proto::FetchRel& relation) override; + std::any visitAggregateRelation( + const ::substrait::proto::AggregateRel& relation) override; + std::any visitSortRelation( + const ::substrait::proto::SortRel& relation) override; + std::any visitProjectRelation( + const ::substrait::proto::ProjectRel& 
relation) override; + + std::shared_ptr symbol_table_; + std::shared_ptr error_listener_; +}; + +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/converter/Tool.cpp b/src/substrait/textplan/converter/Tool.cpp new file mode 100644 index 00000000..97498a93 --- /dev/null +++ b/src/substrait/textplan/converter/Tool.cpp @@ -0,0 +1,52 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#include + +#include "substrait/textplan/SymbolTablePrinter.h" +#include "substrait/textplan/converter/LoadBinary.h" +#include "substrait/textplan/converter/ParseBinary.h" + +namespace io::substrait::textplan { +namespace { + +void convertJSONToText(const char* filename) { + std::string json = readFromFile(filename); + auto planOrError = loadFromJSON(json); + if (!planOrError.ok()) { + std::cerr << "An error occurred while reading: " << filename << std::endl; + for (const auto& err : planOrError.errors()) { + std::cerr << err << std::endl; + } + return; + } + + auto result = parseBinaryPlan(*planOrError); + std::cout << SymbolTablePrinter::outputToText(result.getSymbolTable()); +} + +} // namespace +} // namespace io::substrait::textplan + +int main(int argc, char* argv[]) { + while (true) { + int option_index = 0; + static struct option long_options[] = {{nullptr, 0, nullptr, 0}}; + + int c = getopt_long(argc, argv, "", long_options, &option_index); + if (c == -1) + break; + } + + if (optind >= argc) { + printf("Usage: planconverter ...\n"); + return EXIT_FAILURE; + } + + int curr_arg = optind; + for (; curr_arg < argc; curr_arg++) { + printf("===== %s =====\n", argv[curr_arg]); + io::substrait::textplan::convertJSONToText(argv[curr_arg]); + } + + return EXIT_SUCCESS; +} diff --git a/src/substrait/textplan/converter/data/q6_first_stage.json b/src/substrait/textplan/converter/data/q6_first_stage.json new file mode 100644 index 00000000..1985ad74 --- /dev/null +++ b/src/substrait/textplan/converter/data/q6_first_stage.json @@ -0,0 +1,701 @@ +{ + "extension_uris": [], 
+ "extensions": [ + { + "extension_function": { + "extension_uri_reference": 0, + "function_anchor": 4, + "name": "lte:fp64_fp64" + } + }, + { + "extension_function": { + "extension_uri_reference": 0, + "function_anchor": 5, + "name": "sum:opt_fp64" + } + }, + { + "extension_function": { + "extension_uri_reference": 0, + "function_anchor": 3, + "name": "lt:fp64_fp64" + } + }, + { + "extension_function": { + "extension_uri_reference": 0, + "function_anchor": 0, + "name": "is_not_null:fp64" + } + }, + { + "extension_function": { + "extension_uri_reference": 0, + "function_anchor": 1, + "name": "and:bool_bool" + } + }, + { + "extension_function": { + "extension_uri_reference": 0, + "function_anchor": 2, + "name": "gte:fp64_fp64" + } + }, + { + "extension_function": { + "extension_uri_reference": 0, + "function_anchor": 6, + "name": "multiply:opt_fp64_fp64" + } + } + ], + "relations": [ + { + "root": { + "input": { + "aggregate": { + "common": { + "direct": {} + }, + "input": { + "project": { + "common": { + "direct": {} + }, + "input": { + "filter": { + "common": { + "direct": {} + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "base_schema": { + "names": [ + "l_quantity", + "l_extendedprice", + "l_discount", + "l_shipdate_new" + ], + "struct": { + "types": [ + { + "fp64": { + "type_variation_reference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fp64": { + "type_variation_reference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fp64": { + "type_variation_reference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fp64": { + "type_variation_reference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "type_variation_reference": 0, + "nullability": "NULLABILITY_UNSPECIFIED" + } + }, + "filter": { + "scalar_function": { + "function_reference": 1, + "args": [ + { + "scalar_function": { + "function_reference": 1, + "args": [ + { + "scalar_function": { + "function_reference": 1, + "args": [ + { + 
"scalar_function": { + "function_reference": 1, + "args": [ + { + "scalar_function": { + "function_reference": 1, + "args": [ + { + "scalar_function": { + "function_reference": 1, + "args": [ + { + "scalar_function": { + "function_reference": 1, + "args": [ + { + "scalar_function": { + "function_reference": 0, + "args": [ + { + "selection": { + "direct_reference": { + "struct_field": { + "field": 3 + } + } + } + } + ], + "output_type": { + "bool": { + "type_variation_reference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + } + }, + { + "scalar_function": { + "function_reference": 0, + "args": [ + { + "selection": { + "direct_reference": { + "struct_field": { + "field": 2 + } + } + } + } + ], + "output_type": { + "bool": { + "type_variation_reference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + } + } + ], + "output_type": { + "bool": { + "type_variation_reference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + } + }, + { + "scalar_function": { + "function_reference": 0, + "args": [ + { + "selection": { + "direct_reference": { + "struct_field": { + "field": 0 + } + } + } + } + ], + "output_type": { + "bool": { + "type_variation_reference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + } + } + ], + "output_type": { + "bool": { + "type_variation_reference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + } + }, + { + "scalar_function": { + "function_reference": 2, + "args": [ + { + "selection": { + "direct_reference": { + "struct_field": { + "field": 3 + } + } + } + }, + { + "literal": { + "nullable": false, + "fp64": 8766 + } + } + ], + "output_type": { + "bool": { + "type_variation_reference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + } + } + ], + "output_type": { + "bool": { + "type_variation_reference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + } + }, + { + "scalar_function": { + "function_reference": 3, + "args": [ + { + "selection": { + "direct_reference": { + "struct_field": { + "field": 3 + } + } + } + 
}, + { + "literal": { + "nullable": false, + "fp64": 9131 + } + } + ], + "output_type": { + "bool": { + "type_variation_reference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + } + } + ], + "output_type": { + "bool": { + "type_variation_reference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + } + }, + { + "scalar_function": { + "function_reference": 2, + "args": [ + { + "selection": { + "direct_reference": { + "struct_field": { + "field": 2 + } + } + } + }, + { + "literal": { + "nullable": false, + "fp64": 0.05 + } + } + ], + "output_type": { + "bool": { + "type_variation_reference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + } + } + ], + "output_type": { + "bool": { + "type_variation_reference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + } + }, + { + "scalar_function": { + "function_reference": 4, + "args": [ + { + "selection": { + "direct_reference": { + "struct_field": { + "field": 2 + } + } + } + }, + { + "literal": { + "nullable": false, + "fp64": 0.07 + } + } + ], + "output_type": { + "bool": { + "type_variation_reference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + } + } + ], + "output_type": { + "bool": { + "type_variation_reference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + } + }, + { + "scalar_function": { + "function_reference": 3, + "args": [ + { + "selection": { + "direct_reference": { + "struct_field": { + "field": 0 + } + } + } + }, + { + "literal": { + "nullable": false, + "fp64": 24 + } + } + ], + "output_type": { + "bool": { + "type_variation_reference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + } + } + ], + "output_type": { + "bool": { + "type_variation_reference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + } + }, + "local_files": { + "items": [ + { + "partition_index": "0", + "start": "0", + "length": "3719", + "uri_file": "/mock_lineitem.orc", + "orc": {} + } + ] + } + } + }, + "condition": { + "scalar_function": { + "function_reference": 1, + "args": [ + { + 
"scalar_function": { + "function_reference": 1, + "args": [ + { + "scalar_function": { + "function_reference": 1, + "args": [ + { + "scalar_function": { + "function_reference": 1, + "args": [ + { + "scalar_function": { + "function_reference": 2, + "args": [ + { + "selection": { + "direct_reference": { + "struct_field": { + "field": 3 + } + } + } + }, + { + "literal": { + "nullable": false, + "fp64": 8766 + } + } + ], + "output_type": { + "bool": { + "type_variation_reference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + } + }, + { + "scalar_function": { + "function_reference": 3, + "args": [ + { + "selection": { + "direct_reference": { + "struct_field": { + "field": 3 + } + } + } + }, + { + "literal": { + "nullable": false, + "fp64": 9131 + } + } + ], + "output_type": { + "bool": { + "type_variation_reference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + } + } + ], + "output_type": { + "bool": { + "type_variation_reference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + } + }, + { + "scalar_function": { + "function_reference": 2, + "args": [ + { + "selection": { + "direct_reference": { + "struct_field": { + "field": 2 + } + } + } + }, + { + "literal": { + "nullable": false, + "fp64": 0.05 + } + } + ], + "output_type": { + "bool": { + "type_variation_reference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + } + } + ], + "output_type": { + "bool": { + "type_variation_reference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + } + }, + { + "scalar_function": { + "function_reference": 4, + "args": [ + { + "selection": { + "direct_reference": { + "struct_field": { + "field": 2 + } + } + } + }, + { + "literal": { + "nullable": false, + "fp64": 0.07 + } + } + ], + "output_type": { + "bool": { + "type_variation_reference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + } + } + ], + "output_type": { + "bool": { + "type_variation_reference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + } + }, + { + "scalar_function": { + 
"function_reference": 3, + "args": [ + { + "selection": { + "direct_reference": { + "struct_field": { + "field": 0 + } + } + } + }, + { + "literal": { + "nullable": false, + "fp64": 24 + } + } + ], + "output_type": { + "bool": { + "type_variation_reference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + } + } + ], + "output_type": { + "bool": { + "type_variation_reference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + } + } + } + }, + "expressions": [ + { + "selection": { + "direct_reference": { + "struct_field": { + "field": 1 + } + } + } + }, + { + "selection": { + "direct_reference": { + "struct_field": { + "field": 2 + } + } + } + } + ] + } + }, + "groupings": [ + { + "grouping_expressions": [] + } + ], + "measures": [ + { + "measure": { + "function_reference": 5, + "args": [ + { + "scalar_function": { + "function_reference": 6, + "args": [ + { + "selection": { + "direct_reference": { + "struct_field": { + "field": 0 + } + } + } + }, + { + "selection": { + "direct_reference": { + "struct_field": { + "field": 1 + } + } + } + } + ], + "output_type": { + "fp64": { + "type_variation_reference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + } + } + ], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_INTERMEDIATE", + "output_type": { + "fp64": { + "type_variation_reference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + } + } + ] + } + }, + "names": [] + } + } + ], + "expected_type_urls": [] +} diff --git a/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp b/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp new file mode 100644 index 00000000..172b8efc --- /dev/null +++ b/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp @@ -0,0 +1,263 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#include +#include + +#include "substrait/textplan/converter/LoadBinary.h" +#include "substrait/textplan/converter/ParseBinary.h" +#include 
"substrait/textplan/tests/ParseResultMatchers.h" + +namespace io::substrait::textplan { + +using ::testing::Eq; + +namespace { + +class TestCase { + public: + std::string name; + std::string input; + ::testing::Matcher expectedMatch; +}; + +class BinaryToTextPlanConverterTestFixture + : public ::testing::TestWithParam {}; + +std::vector GetTestCases() { + static std::vector cases = { + { + "bad proto input", + "gibberish", + HasErrors( + {"1:10 → Message type \"substrait.proto.Plan\" has no field named \"gibberish\"."}), + }, + { + "empty plan", + "", + WhenSerialized(Eq("")), + }, + { + "empty extension space", + R"(extension_uris: { + extension_uri_anchor: 42; + uri: "http://life@everything", + })", + WhenSerialized(Eq("")), + }, + { + "used extension space", + R"(extension_uris: { + extension_uri_anchor: 42; + uri: "http://life@everything", + } + extensions: { + extension_function: { + extension_uri_reference: 42 + function_anchor: 5 + name: "sum:fp64_fp64" + } + })", + WhenSerialized(EqSquashingWhitespace( + R"(extension_space http://life@everything { + function sum:fp64_fp64 as sum; + })")), + }, + { + "seven extensions, no uris", + R"(extensions: { + extension_function: { + extension_uri_reference: 0 + function_anchor: 4 + name: "lte:fp64_fp64" + } + } + extensions: { + extension_function: { + extension_uri_reference: 0 + function_anchor: 5 + name: "sum:fp64_fp64" + } + } + extensions: { + extension_function: { + extension_uri_reference: 0 + function_anchor: 3 + name: "lt:fp64_fp64" + } + } + extensions: { + extension_function: { + extension_uri_reference: 0 + function_anchor: 0 + name: "is_not_null:fp64" + } + } + extensions: { + extension_function: { + extension_uri_reference: 0 + function_anchor: 1 + name: "and:bool_bool" + } + } + extensions: { + extension_function: { + extension_uri_reference: 0 + function_anchor: 2 + name: "gte:fp64_fp64" + } + } + extensions: { + extension_function: { + extension_uri_reference: 0 + function_anchor: 6 + name: 
"multiply:opt_fp64_fp64" + } + })", + WhenSerialized(EqSquashingWhitespace( + R"(extension_space { + function lte:fp64_fp64 as lte; + function sum:fp64_fp64 as sum; + function lt:fp64_fp64 as lt; + function is_not_null:fp64 as is_not_null; + function and:bool_bool as and; + function gte:fp64_fp64 as gte; + function multiply:opt_fp64_fp64 as multiply; + })")), + }, + { + "read local files", + R"(relations: { + root: { + input: { + read: { + local_files { + items { + partition_index: 0 + start: 0 + length: 3719 + uri_file: "/mock_lineitem.orc" + orc {} + } + } + } + } + } + })", + AllOf( + HasSymbols({"local", "read", "root"}), + WhenSerialized(EqSquashingWhitespace( + R"(read relation read { + } + + source local_files local { + items = [ + {uri_file: "/mock_lineitem.orc" start: 0 length: 3719 orc: {}} + ] + })"))), + }, + { + "read named table", + "relations: { root: { input: { read: { base_schema {} named_table { names: \"#2\" } } } } }", + AllOf( + HasSymbols({"schema", "named", "read", "root"}), + WhenSerialized(EqSquashingWhitespace( + R"(read relation read { + } + + schema schema { + } + + source named_table named { + names = [ + "#2", + ] + })"))), + }, + { + "single three node pipeline", + "relations: { root: { input: { project: { input { read: { local_files {} } } } } } }", + HasSymbols({"local", "read", "project", "root"}), + }, + { + "two identical three node pipelines", + "relations: { root: { input: { project: { input { read: { local_files {} } } } } } }" + "relations: { root: { input: { project: { input { read: { local_files {} } } } } } }", + HasSymbols( + {"local", + "read", + "project", + "root", + "local2", + "read2", + "project2", + "root2"}), + }, + }; + return cases; +} + +TEST_P(BinaryToTextPlanConverterTestFixture, Parse) { + auto [name, input, matcher] = GetParam(); + + auto planOrError = loadFromText(input); + if (!planOrError.ok()) { + ParseResult result(SymbolTable(), planOrError.errors(), {}); + ASSERT_THAT(result, matcher); + return; + } 
+ + auto result = parseBinaryPlan(*planOrError); + ASSERT_THAT(result, matcher); +} + +INSTANTIATE_TEST_SUITE_P( + BinaryPlanConversionTests, + BinaryToTextPlanConverterTestFixture, + ::testing::ValuesIn(GetTestCases()), + [](const testing::TestParamInfo& info) { + std::string identifier = info.param.name; + // Remove non-alphanumeric characters to make the test framework happy. + identifier.erase( + std::remove_if( + identifier.begin(), + identifier.end(), + [](auto const& c) -> bool { return !std::isalnum(c); }), + identifier.end()); + return identifier; + }); + +class BinaryToTextPlanConversionTest : public ::testing::Test {}; + +TEST_F(BinaryToTextPlanConversionTest, loadFromJSON) { + std::string json = readFromFile("data/q6_first_stage.json"); + auto planOrError = loadFromJSON(json); + ASSERT_TRUE(planOrError.ok()); + auto plan = *planOrError; + EXPECT_THAT(plan.extensions_size(), ::testing::Eq(7)); + + auto result = parseBinaryPlan(plan); + auto symbols = result.getSymbolTable().getSymbols(); + ASSERT_THAT( + result, + HasSymbols({ + "lte", + "sum", + "lt", + "is_not_null", + "and", + "gte", + "multiply", + + "schema", + "local", + + "read", + "filter", + "project", + "aggregate", + "root", + })); +} + +} // namespace +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/converter/tests/CMakeLists.txt b/src/substrait/textplan/converter/tests/CMakeLists.txt new file mode 100644 index 00000000..8bad053d --- /dev/null +++ b/src/substrait/textplan/converter/tests/CMakeLists.txt @@ -0,0 +1,27 @@ +# SPDX-License-Identifier: Apache-2.0 + +add_test_case( + textplan_conversion_test + SOURCES + BinaryToTextPlanConversionTest.cpp + EXTRA_LINK_LIBS + substrait_textplan_converter + substrait_common + parse_result_matchers + gmock + gtest + gtest_main) + +cmake_path(GET CMAKE_CURRENT_SOURCE_DIR PARENT_PATH TEXTPLAN_SOURCE_DIR) + +add_custom_command( + TARGET textplan_conversion_test + POST_BUILD + COMMAND ${CMAKE_COMMAND} -E echo "Copying unit test 
data.." + COMMAND ${CMAKE_COMMAND} -E make_directory "${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/tests/data" + COMMAND ${CMAKE_COMMAND} -E copy + "${TEXTPLAN_SOURCE_DIR}/data/q6_first_stage.json" + "${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/tests/data/q6_first_stage.json" + ) + +message(STATUS "test data will be here: ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/tests/data/q6_first_stage.json") diff --git a/src/substrait/textplan/tests/CMakeLists.txt b/src/substrait/textplan/tests/CMakeLists.txt new file mode 100644 index 00000000..fb0bbb8a --- /dev/null +++ b/src/substrait/textplan/tests/CMakeLists.txt @@ -0,0 +1,23 @@ +# SPDX-License-Identifier: Apache-2.0 + +add_library(parse_result_matchers + ParseResultMatchers.cpp ParseResultMatchers.h) + +add_dependencies(parse_result_matchers parse_result) + +target_link_libraries( + parse_result_matchers + parse_result + gmock) + +add_test_case( + symbol_table_test + SOURCES + SymbolTableTest.cpp + EXTRA_LINK_LIBS + symbol_table + substrait_proto + fmt::fmt-header-only + gmock + gtest + gtest_main) diff --git a/src/substrait/textplan/tests/ParseResultMatchers.cpp b/src/substrait/textplan/tests/ParseResultMatchers.cpp new file mode 100644 index 00000000..f6d9d0ff --- /dev/null +++ b/src/substrait/textplan/tests/ParseResultMatchers.cpp @@ -0,0 +1,256 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#include "substrait/textplan/tests/ParseResultMatchers.h" + +#include +#include +#include + +#include +#include + +#include "substrait/textplan/ParseResult.h" +#include "substrait/textplan/SymbolTable.h" +#include "substrait/textplan/SymbolTablePrinter.h" + +namespace io::substrait::textplan { +namespace { + +std::vector symbolNames( + const std::vector>& symbols) { + std::vector names; + for (const auto& symbol : symbols) { + names.push_back(symbol->name); + } + return names; +} + +bool StringEqSquashingWhitespace( + const std::string& have, + const std::string& expected) { + auto atHave = have.begin(); + auto atExpected = expected.begin(); + while (atHave 
!= have.end() && atExpected != expected.end()) { + if (isspace(*atExpected)) { + if (!isspace(*atHave)) { + return false; + } + // Have a match, consume all remaining space. + do { + atExpected++; + } while (atExpected != expected.end() && isspace(*atExpected)); + do { + atHave++; + } while (atHave != have.end() && isspace(*atHave)); + continue; + } + if (*atHave != *atExpected) { + return false; + } + atHave++; + atExpected++; + } + // For convenience consume any trailing whitespace on both sides. + if (atExpected != expected.end()) { + do { + atExpected++; + } while (atExpected != expected.end() && isspace(*atExpected)); + } + if (atHave != have.end()) { + do { + atHave++; + } while (atHave != have.end() && isspace(*atHave)); + } + return atHave == have.end() && atExpected == expected.end(); +} + +} // namespace + +class ParsesOkMatcher { + public: + using is_gtest_matcher = void; + + static bool MatchAndExplain( + const ParseResult& result, + std::ostream* /* listener */) { + return result.successful(); + } + + static void DescribeTo(std::ostream* os) { + *os << "parses successfully"; + } + + static void DescribeNegationTo(std::ostream* os) { + *os << "does not parse successfully"; + } +}; + +[[maybe_unused]] ::testing::Matcher ParsesOk() { + return ParsesOkMatcher(); +} + +class HasSymbolsMatcher { + public: + using is_gtest_matcher = void; + + explicit HasSymbolsMatcher(std::vector expected_symbols) + : expected_symbols_(std::move(expected_symbols)) {} + + bool MatchAndExplain(const ParseResult& result, std::ostream* listener) + const { + auto actual_symbols = symbolNames(result.getSymbolTable().getSymbols()); + if (listener != nullptr) { + std::vector extra_symbols(actual_symbols.size()); + auto end = std::set_difference( + actual_symbols.begin(), + actual_symbols.end(), + expected_symbols_.begin(), + expected_symbols_.end(), + extra_symbols.begin()); + extra_symbols.resize(end - extra_symbols.begin()); + if (!extra_symbols.empty()) { + *listener << std::endl 
<< " with missing symbols: "; + for (const auto& symbol : extra_symbols) { + *listener << " \"" << symbol << "\""; + } + } + + std::vector missing_symbols(expected_symbols_.size()); + end = std::set_difference( + expected_symbols_.begin(), + expected_symbols_.end(), + actual_symbols.begin(), + actual_symbols.end(), + missing_symbols.begin()); + missing_symbols.resize(end - missing_symbols.begin()); + if (!missing_symbols.empty()) { + if (!extra_symbols.empty()) { + *listener << ", and extra symbols: "; + } else { + *listener << " with extra symbols: "; + } + for (const auto& symbol : missing_symbols) { + *listener << " \"" << symbol << "\""; + } + } + } + return actual_symbols == expected_symbols_; + } + + void DescribeTo(std::ostream* os) const { + *os << "has exactly these symbols: " + << ::testing::PrintToString(expected_symbols_); + } + + void DescribeNegationTo(std::ostream* os) const { + *os << "does not have exactly these symbols: " + << ::testing::PrintToString(expected_symbols_); + } + + private: + const std::vector expected_symbols_; +}; + +::testing::Matcher HasSymbols( + std::vector expected_symbols) { + return HasSymbolsMatcher(std::move(expected_symbols)); +} + +class WhenSerializedMatcher { + public: + using is_gtest_matcher = void; + + explicit WhenSerializedMatcher( + ::testing::Matcher string_matcher) + : string_matcher_(std::move(string_matcher)) {} + + bool MatchAndExplain( + const ParseResult& result, + ::testing::MatchResultListener* listener) const { + std::string outputText = + SymbolTablePrinter::outputToText(result.getSymbolTable()); + return MatchPrintAndExplain(outputText, string_matcher_, listener); + } + + void DescribeTo(::std::ostream* os) const { + *os << "matches after serializing "; + string_matcher_.DescribeTo(os); + } + + void DescribeNegationTo(::std::ostream* os) const { + *os << "does not match after serializing "; + string_matcher_.DescribeTo(os); + } + + private: + ::testing::Matcher string_matcher_; +}; + 
+::testing::Matcher WhenSerialized( + ::testing::Matcher string_matcher) { + return WhenSerializedMatcher(std::move(string_matcher)); +} + +class HasErrorsMatcher { + public: + using is_gtest_matcher = void; + + explicit HasErrorsMatcher(std::vector expected_errors) + : expected_errors_(std::move(expected_errors)) {} + + bool MatchAndExplain(const ParseResult& result, std::ostream* /* listener */) + const { + return result.getAllErrors() == expected_errors_; + } + + void DescribeTo(std::ostream* os) const { + *os << "has exactly these symbols: " + << ::testing::PrintToString(expected_errors_); + } + + void DescribeNegationTo(std::ostream* os) const { + *os << "does not have exactly these symbols: " + << ::testing::PrintToString(expected_errors_); + } + + private: + const std::vector expected_errors_; +}; + +::testing::Matcher HasErrors( + std::vector expected_errors) { + return HasErrorsMatcher(std::move(expected_errors)); +} + +class EqSquashingWhitespaceMatcher { + public: + using is_gtest_matcher = void; + + explicit EqSquashingWhitespaceMatcher(std::string expected_string) + : expected_string_(std::move(expected_string)) {} + + bool MatchAndExplain(const std::string& str, std::ostream* /* listener */) + const { + return StringEqSquashingWhitespace(str, expected_string_); + } + + void DescribeTo(std::ostream* os) const { + *os << "equals squashing whitespace " + << ::testing::PrintToString(expected_string_); + } + + void DescribeNegationTo(std::ostream* os) const { + *os << "does not equal squashing whitespace " + << ::testing::PrintToString(expected_string_); + } + + private: + std::string expected_string_; +}; + +::testing::Matcher EqSquashingWhitespace( + std::string expected_string) { + return EqSquashingWhitespaceMatcher(std::move(expected_string)); +} + +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/tests/ParseResultMatchers.h b/src/substrait/textplan/tests/ParseResultMatchers.h new file mode 100644 index 00000000..96f4707d --- 
/dev/null +++ b/src/substrait/textplan/tests/ParseResultMatchers.h @@ -0,0 +1,28 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#pragma once + +#include +#include + +#include "substrait/textplan/ParseResult.h" + +namespace io::substrait::textplan { + +[[maybe_unused]] ::testing::Matcher ParsesOk(); + +[[maybe_unused]] ::testing::Matcher HasSymbols( + std::vector expected_symbols); + +[[maybe_unused]] ::testing::Matcher WhenSerialized( + ::testing::Matcher string_matcher); + +[[maybe_unused]] ::testing::Matcher HasErrors( + std::vector expected_errors); + +// Matches strings ignoring differences in kinds of whitespace (as long as they +// are present) and ignoring trailing whitespace as well. +[[maybe_unused]] ::testing::Matcher EqSquashingWhitespace( + std::string expected_string); + +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/tests/SymbolTableTest.cpp b/src/substrait/textplan/tests/SymbolTableTest.cpp new file mode 100644 index 00000000..68e3aa4d --- /dev/null +++ b/src/substrait/textplan/tests/SymbolTableTest.cpp @@ -0,0 +1,143 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#include "substrait/textplan/SymbolTable.h" +#include "substrait/proto/plan.pb.h" +#include "substrait/textplan/Any.h" +#include "substrait/textplan/Location.h" + +#include +#include + +namespace io::substrait::textplan { +namespace { + +class SymbolTableTest : public ::testing::Test { + public: + SymbolTableTest() : UnspecifiedLocation(Location::kUnknownLocation){}; + + protected: + static std::vector symbolNames( + const std::vector>& symbols) { + std::vector names; + for (const auto& symbol : symbols) { + names.push_back(symbol->name); + } + return names; + } + + static SymbolTable createSimpleTable(::substrait::proto::Plan* plan) { + auto* ptr1 = plan->add_relations(); + auto* ptr2 = plan->add_extension_uris(); + auto* ptr3 = plan->add_extensions(); + + SymbolTable table; + table.defineSymbol( + "symbol1", + Location(ptr1), + SymbolType::kUnknown, + 
RelationType::kUnknown, + ptr1); + table.defineSymbol( + "symbol2", + Location(ptr2), + SymbolType::kUnknown, + RelationType::kUnknown, + ptr2); + table.defineSymbol( + "symbol3", + Location(ptr3), + SymbolType::kUnknown, + RelationType::kUnknown, + ptr3); + return table; + } + + const Location UnspecifiedLocation; +}; + +TEST_F(SymbolTableTest, DuplicateSymbolsNotDetected) { + SymbolTable table; + table.defineSymbol( + "a", + UnspecifiedLocation, + SymbolType::kUnknown, + RelationType::kUnknown, + nullptr); + table.defineSymbol( + "a", + UnspecifiedLocation, + SymbolType::kUnknown, + RelationType::kUnknown, + nullptr); + + ASSERT_THAT( + symbolNames(table.getSymbols()), ::testing::ElementsAre("a", "a")); +} + +TEST_F(SymbolTableTest, DuplicateSymbolsHandledByUnique) { + SymbolTable table; + table.defineUniqueSymbol( + "a", + UnspecifiedLocation, + SymbolType::kUnknown, + RelationType::kUnknown, + nullptr); + table.defineUniqueSymbol( + "a", + UnspecifiedLocation, + SymbolType::kUnknown, + RelationType::kUnknown, + nullptr); + + ASSERT_THAT( + symbolNames(table.getSymbols()), ::testing::ElementsAre("a", "a2")); +} + +TEST_F(SymbolTableTest, LocationsUnchangedAfterCopy) { + ::substrait::proto::Plan plan; + SymbolTable table = createSimpleTable(&plan); + auto* ptr1 = &plan.relations(0); + auto* ptr2 = plan.mutable_extension_uris(0); + auto* ptr3 = &plan.extensions(0); + + SymbolTable table2 = table; + auto symbols = table2.getSymbols(); + ASSERT_THAT( + symbolNames(symbols), + ::testing::ElementsAre("symbol1", "symbol2", "symbol3")); + + ASSERT_THAT( + ANY_CAST(::substrait::proto::PlanRel*, symbols[0]->blob), + ::testing::Eq(ptr1)); + ASSERT_THAT( + ANY_CAST( + ::substrait::proto::extensions::SimpleExtensionURI*, + symbols[1]->blob), + ::testing::Eq(ptr2)); + ASSERT_THAT( + ANY_CAST( + ::substrait::proto::extensions::SimpleExtensionDeclaration*, + symbols[2]->blob), + ::testing::Eq(ptr3)); + + ASSERT_THAT(symbols[0]->location, ::testing::Eq(symbols[0]->location)); + 
ASSERT_THAT( + symbols[0]->location, + ::testing::Not(::testing::Eq(symbols[1]->location))); + ASSERT_THAT( + symbols[0]->location, + ::testing::Not(::testing::Eq(symbols[2]->location))); + ASSERT_THAT( + symbols[1]->location, + ::testing::Not(::testing::Eq(symbols[2]->location))); + + ASSERT_THAT( + table.getSymbols()[0]->location, ::testing::Eq(symbols[0]->location)); + ASSERT_THAT( + table.getSymbols()[1]->location, ::testing::Eq(symbols[1]->location)); + ASSERT_THAT( + table.getSymbols()[2]->location, ::testing::Eq(symbols[2]->location)); +} + +} // namespace +} // namespace io::substrait::textplan