Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/substrait/proto/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,10 @@ foreach(PROTO_FILE IN LISTS PROTOBUF_FILELIST)
endforeach()

# Add the generated protobuf C++ files to our exported library.
add_library(substrait_proto ${PROTO_SRCS} ${PROTO_HDRS})
add_library(
substrait_proto
${PROTO_SRCS} ${PROTO_HDRS}
ProtoUtils.cpp ProtoUtils.h)

# Include the protobuf library as a dependency to use this class.
target_link_libraries(substrait_proto ${PROTOBUF_LIBRARIES})
Expand Down
47 changes: 47 additions & 0 deletions src/substrait/proto/ProtoUtils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/* SPDX-License-Identifier: Apache-2.0 */

#include "substrait/proto/ProtoUtils.h"

namespace substrait::proto {

// Returns a lowercase, human-readable name for a PlanRel oneof case.
//
// The table is indexed by the numeric value of |num|: 0 maps to "unknown",
// 1 to "rel" and 2 to "root".  Any value outside the table (including a
// negative one) also yields "unknown".
std::string PlanRelTypeCaseName(::substrait::proto::PlanRel::RelTypeCase num) {
  // Immutable and constant-initialized: nothing is allocated at first use and
  // the table cannot be modified by accident.
  static constexpr const char* kCaseNames[] = {
      "unknown",
      "rel",
      "root",
  };
  constexpr std::size_t kCaseCount =
      sizeof(kCaseNames) / sizeof(kCaseNames[0]);

  // Convert explicitly so the bounds check is an unsigned comparison; a
  // negative value wraps to a large index and is rejected below.
  const auto index = static_cast<std::size_t>(num);
  if (index >= kCaseCount) {
    return "unknown";
  }

  return kCaseNames[index];
}

// Returns a lowercase, human-readable name for a Rel oneof case.
//
// The table is indexed by the numeric value of |num|; slot 0 is "unknown"
// and slots 1..14 hold the relation names listed below.  Any value outside
// the table (including a negative one) also yields "unknown".
std::string RelTypeCaseName(::substrait::proto::Rel::RelTypeCase num) {
  // Immutable and constant-initialized: nothing is allocated at first use and
  // the table cannot be modified by accident.
  // NOTE(review): the order must track the numeric values of
  // Rel::RelTypeCase -- confirm against algebra.proto when adding relations.
  static constexpr const char* kCaseNames[] = {
      "unknown",
      "read",
      "filter",
      "fetch",
      "aggregate",
      "sort",
      "join",
      "project",
      "set",
      "extensionsingle",
      "extensionmulti",
      "extensionleaf",
      "cross",
      "hashjoin",
      "mergejoin",
  };
  constexpr std::size_t kCaseCount =
      sizeof(kCaseNames) / sizeof(kCaseNames[0]);

  // Convert explicitly so the bounds check is an unsigned comparison; a
  // negative value wraps to a large index and is rejected below.
  const auto index = static_cast<std::size_t>(num);
  if (index >= kCaseCount) {
    return "unknown";
  }

  return kCaseNames[index];
}

} // namespace substrait::proto
16 changes: 16 additions & 0 deletions src/substrait/proto/ProtoUtils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
/* SPDX-License-Identifier: Apache-2.0 */

#pragma once

#include <string>

#include "substrait/proto/algebra.pb.h"
#include "substrait/proto/plan.pb.h"

namespace substrait::proto {

std::string PlanRelTypeCaseName(::substrait::proto::PlanRel::RelTypeCase num);

std::string RelTypeCaseName(::substrait::proto::Rel::RelTypeCase num);

} // namespace substrait::proto
25 changes: 25 additions & 0 deletions src/substrait/textplan/Any.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/* SPDX-License-Identifier: Apache-2.0 */

#pragma once

#include <any>

#include "fmt/format.h"

namespace io::substrait::textplan {

// Casts |value| to ValueType, reporting the call site on failure.
//
// Behaves like std::any_cast<ValueType>(value), except that a type mismatch
// is rethrown as std::invalid_argument with the message
// "<file>:<line> - bad any cast".  |file| and |line| identify the caller and
// are normally supplied by the ANY_CAST macro.
template <class ValueType>
inline ValueType any_cast(const std::any& value, const char* file, int line) {
  try {
    return std::any_cast<ValueType>(value);
  } catch (const std::bad_any_cast&) {
    // The original exception carries no location, so replace it with one
    // that points at the call site.
    throw std::invalid_argument(
        fmt::format("{}:{} - bad any cast", file, line));
  }
}

// A wrapper around std::any_cast that provides exceptions with line numbers.
// Expands to a call to ::io::substrait::textplan::any_cast, passing the
// __FILE__ and __LINE__ of the expansion site so that a failed cast can be
// traced back to the caller.  Usage: ANY_CAST(ValueType, value).
// NOTE: a ValueType containing a top-level comma (e.g. a two-argument
// template) must be given a type alias first, because the preprocessor would
// otherwise split it into extra macro arguments.
#define ANY_CAST(ValueType, Value) \
::io::substrait::textplan::any_cast<ValueType>(Value, __FILE__, __LINE__)

} // namespace io::substrait::textplan
15 changes: 12 additions & 3 deletions src/substrait/textplan/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ add_subdirectory(converter)
add_library(symbol_table
Location.cpp Location.h
SymbolTable.cpp SymbolTable.h
SymbolTablePrinter.cpp SymbolTablePrinter.h)
SymbolTablePrinter.cpp SymbolTablePrinter.h
Any.h)

add_library(error_listener
SubstraitErrorListener.cpp SubstraitErrorListener.h
Expand All @@ -16,7 +17,15 @@ add_library(parse_result
)

add_dependencies(symbol_table
substrait_proto)
substrait_proto
substrait_common
fmt::fmt-header-only
)

target_link_libraries(
symbol_table
fmt::fmt-header-only
)

# Provide access to the generated protobuffer headers hierarchy.
target_include_directories(
Expand All @@ -25,4 +34,4 @@ target_include_directories(

if (${SUBSTRAIT_CPP_BUILD_TESTING})
add_subdirectory(tests)
endif ()
endif ()
74 changes: 34 additions & 40 deletions src/substrait/textplan/Location.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,41 +3,28 @@
#include "substrait/textplan/Location.h"

#include <functional>
#include <sstream>
#include <string>

namespace io::substrait::textplan {

ProtoLocation ProtoLocation::visit(const std::string& name) const {
ProtoLocation new_location = *this;
new_location.location_.push_back(name);
return new_location;
}

std::string ProtoLocation::toString() const {
std::stringstream text;
bool written = false;
for (const auto& loc : location_) {
if (!written) {
text << " -> ";
written = true;
}
text << loc;
}
return text.str();
}
// Sentinel Location that refers to no node: it holds a null
// google::protobuf::Message pointer.
constexpr const Location Location::kUnknownLocation(
static_cast<google::protobuf::Message*>(nullptr));

bool operator==(const Location& c1, const Location& c2) {
// Test only one side since we only store one kind of content per table.
if (std::holds_alternative<ProtoLocation>(
c1.loc_)) {
auto s1 = std::get<ProtoLocation>(c1.loc_).toString();
auto s2 = std::get<ProtoLocation>(c2.loc_).toString();
return s1 == s2;
} else if (std::holds_alternative<antlr4::ParserRuleContext*>(c1.loc_)) {
if (std::holds_alternative<antlr4::ParserRuleContext*>(c1.loc_)) {
if (!std::holds_alternative<antlr4::ParserRuleContext*>(c2.loc_)) {
return false;
}
auto a1 = std::get<antlr4::ParserRuleContext*>(c1.loc_);
auto a2 = std::get<antlr4::ParserRuleContext*>(c2.loc_);
return a1 == a2;
} else if (std::holds_alternative<google::protobuf::Message*>(c1.loc_)) {
if (!std::holds_alternative<google::protobuf::Message*>(c2.loc_)) {
return false;
}
auto a1 = std::get<google::protobuf::Message*>(c1.loc_);
auto a2 = std::get<google::protobuf::Message*>(c2.loc_);
return a1 == a2;
}
// Should not be reached.
return false;
Expand All @@ -47,28 +34,35 @@ bool operator==(const Location& c1, const Location& c2) {

// Hashes a Location by the address of the parse-tree node or protobuf
// message it holds, which keeps the hash consistent with operator== (that
// operator compares the same pointers).
std::size_t std::hash<::io::substrait::textplan::Location>::operator()(
    const ::io::substrait::textplan::Location& loc) const noexcept {
  if (const auto* node = std::get_if<antlr4::ParserRuleContext*>(&loc.loc_)) {
    return std::hash<antlr4::ParserRuleContext*>()(*node);
  }
  if (const auto* msg = std::get_if<google::protobuf::Message*>(&loc.loc_)) {
    return std::hash<google::protobuf::Message*>()(*msg);
  }
  // Should not be reached.
  return 0;
}

std::size_t std::less<::io::substrait::textplan::Location>::operator()(
bool std::less<::io::substrait::textplan::Location>::operator()(
const ::io::substrait::textplan::Location& lhs,
const ::io::substrait::textplan::Location& rhs) const noexcept {
if (std::holds_alternative<::io::substrait::textplan::ProtoLocation>(
lhs.loc_)) {
return std::get<::io::substrait::textplan::ProtoLocation>(lhs.loc_)
.toString() <
std::get<::io::substrait::textplan::ProtoLocation>(rhs.loc_).toString();
if (std::holds_alternative<antlr4::ParserRuleContext*>(lhs.loc_)) {
if (!std::holds_alternative<antlr4::ParserRuleContext*>(rhs.loc_)) {
// This alternative is always less than the remaining choices.
return true;
}
return std::get<antlr4::ParserRuleContext*>(lhs.loc_) <
std::get<antlr4::ParserRuleContext*>(rhs.loc_);
} else if (std::holds_alternative<google::protobuf::Message*>(lhs.loc_)) {
if (!std::holds_alternative<google::protobuf::Message*>(rhs.loc_)) {
// This alternative is always less than the remaining (zero) choices.
return true;
}
return std::get<google::protobuf::Message*>(lhs.loc_) <
std::get<google::protobuf::Message*>(rhs.loc_);
}
return std::get<antlr4::ParserRuleContext*>(lhs.loc_) <
std::get<antlr4::ParserRuleContext*>(rhs.loc_);
// Should not be reached.
return false;
}
30 changes: 11 additions & 19 deletions src/substrait/textplan/Location.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,30 +10,23 @@ namespace antlr4 {
class ParserRuleContext;
}

namespace io::substrait::textplan {

class ProtoLocation {
public:
[[nodiscard]] ProtoLocation visit(const std::string& name) const;

[[nodiscard]] std::string toString() const;

private:
std::vector<std::string> location_;
namespace google::protobuf {
class Message;
};

namespace io::substrait::textplan {

// Location is used for keeping track of where a symbol is within a parse tree.
// Since SymbolTable supports both antlr4 and protobuf messages there are
// essentially two flavors of location. It is expected that only one type of
// location would be used in any SymbolTable instance.
class Location {
public:
explicit Location(antlr4::ParserRuleContext* node) {
loc_ = node;
};
explicit Location(ProtoLocation location) {
loc_ = location;
};
constexpr explicit Location(antlr4::ParserRuleContext* node) : loc_(node) {}

constexpr explicit Location(google::protobuf::Message* msg) : loc_(msg) {}

static const Location kUnknownLocation;

protected:
friend bool operator==(const Location& c1, const Location& c2);
Expand All @@ -42,7 +35,7 @@ class Location {
friend std::hash<Location>;
friend std::less<Location>;

std::variant<ProtoLocation, antlr4::ParserRuleContext*> loc_;
std::variant<antlr4::ParserRuleContext*, google::protobuf::Message*> loc_;
};

} // namespace io::substrait::textplan
Expand All @@ -55,8 +48,7 @@ struct std::hash<::io::substrait::textplan::Location> {

// Ordering functor for Location.
template <>
struct std::less<::io::substrait::textplan::Location> {
  // Returns true when |lhs| orders strictly before |rhs|.
  bool operator()(
      const ::io::substrait::textplan::Location& lhs,
      const ::io::substrait::textplan::Location& rhs) const noexcept;
};

2 changes: 1 addition & 1 deletion src/substrait/textplan/SubstraitErrorListener.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@ std::vector<std::string> SubstraitErrorListener::getErrorMessages() {
return messages;
}

} // namespace io::substrait::textplan
} // namespace io::substrait::textplan
23 changes: 18 additions & 5 deletions src/substrait/textplan/SymbolTable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@ SymbolTableIterator SymbolTableIterator::operator++() {
return *this;
}

// Two symbols compare equal when their name, location, and type all match.
// Only these three members take part in the comparison.
bool operator==(const SymbolInfo& left, const SymbolInfo& right) {
  if (!(left.name == right.name)) {
    return false;
  }
  if (!(left.location == right.location)) {
    return false;
  }
  return left.type == right.type;
}

// Exact negation of operator== for SymbolInfo.
bool operator!=(const SymbolInfo& left, const SymbolInfo& right) {
  const bool same = (left == right);
  return !same;
}

std::string SymbolTable::getUniqueName(const std::string& base_name) {
auto symbolSeenCount = names_.find(base_name);
if (symbolSeenCount == names_.end()) {
Expand Down Expand Up @@ -58,8 +67,7 @@ SymbolInfo* SymbolTable::defineUniqueSymbol(
return defineSymbol(unique_name, location, type, subtype, blob);
}

const SymbolInfo& SymbolTable::lookupSymbolByName(
const std::string& name) {
const SymbolInfo& SymbolTable::lookupSymbolByName(const std::string& name) {
auto itr = symbols_by_name_.find(name);
if (itr == symbols_by_name_.end()) {
return kUnknownSymbol;
Expand All @@ -76,9 +84,7 @@ const SymbolInfo& SymbolTable::lookupSymbolByLocation(
return *symbols_[itr->second];
}

const SymbolInfo& SymbolTable::nthSymbolByType(
uint32_t n,
SymbolType type) {
const SymbolInfo& SymbolTable::nthSymbolByType(uint32_t n, SymbolType type) {
int count = 0;
for (const auto& symbol : symbols_) {
if (symbol->type == type) {
Expand All @@ -97,4 +103,11 @@ SymbolTableIterator SymbolTable::end() const {
return {this, symbols_by_name_.size()};
}

// Sentinel symbol handed back by lookups that find no match (see
// lookupSymbolByName).  It is named "__UNKNOWN__", is placed at
// Location::kUnknownLocation, and has type SymbolType::kUnknown.
// NOTE(review): the two std::nullopt members are presumably the subtype and
// blob that defineSymbol accepts -- confirm against the SymbolInfo
// declaration.
// NOTE(review): this copies Location::kUnknownLocation, which is defined in
// another translation unit, during static initialization; that is only safe
// while kUnknownLocation remains constant-initialized -- confirm.
const SymbolInfo SymbolTable::kUnknownSymbol = {
"__UNKNOWN__",
Location::kUnknownLocation,
SymbolType::kUnknown,
std::nullopt,
std::nullopt};

} // namespace io::substrait::textplan
Loading