diff --git a/.clang-format b/.clang-format new file mode 100644 index 00000000..eab4576f --- /dev/null +++ b/.clang-format @@ -0,0 +1,87 @@ +--- +AccessModifierOffset: -1 +AlignAfterOpenBracket: AlwaysBreak +AlignConsecutiveAssignments: false +AlignConsecutiveDeclarations: false +AlignEscapedNewlinesLeft: true +AlignOperands: false +AlignTrailingComments: false +AllowAllParametersOfDeclarationOnNextLine: false +AllowShortBlocksOnASingleLine: false +AllowShortCaseLabelsOnASingleLine: false +AllowShortFunctionsOnASingleLine: Empty +AllowShortIfStatementsOnASingleLine: false +AllowShortLoopsOnASingleLine: false +AlwaysBreakAfterReturnType: None +AlwaysBreakBeforeMultilineStrings: true +AlwaysBreakTemplateDeclarations: true +BinPackArguments: false +BinPackParameters: false +BraceWrapping: + AfterClass: false + AfterControlStatement: false + AfterEnum: false + AfterFunction: false + AfterNamespace: false + AfterObjCDeclaration: false + AfterStruct: false + AfterUnion: false + BeforeCatch: false + BeforeElse: false + IndentBraces: false +BreakBeforeBinaryOperators: None +BreakBeforeBraces: Attach +BreakBeforeTernaryOperators: true +BreakConstructorInitializersBeforeComma: false +BreakAfterJavaFieldAnnotations: false +BreakStringLiterals: false +ColumnLimit: 80 +CommentPragmas: '^ IWYU pragma:' +ConstructorInitializerAllOnOneLineOrOnePerLine: true +ConstructorInitializerIndentWidth: 4 +ContinuationIndentWidth: 4 +Cpp11BracedListStyle: true +DerivePointerAlignment: false +DisableFormat: false +ForEachMacros: [ FOR_EACH, FOR_EACH_R, FOR_EACH_RANGE, ] +IncludeCategories: + - Regex: '^<.*\.h(pp)?>' + Priority: 1 + - Regex: '^<.*' + Priority: 2 + - Regex: '.*' + Priority: 3 +IndentCaseLabels: true +IndentWidth: 2 +IndentWrappedFunctionNames: false +KeepEmptyLinesAtTheStartOfBlocks: false +MacroBlockBegin: '' +MacroBlockEnd: '' +MaxEmptyLinesToKeep: 1 +NamespaceIndentation: None +ObjCBlockIndentWidth: 2 +ObjCSpaceAfterProperty: false +ObjCSpaceBeforeProtocolList: false +PenaltyBreakBeforeFirstCallParameter: 1 +PenaltyBreakComment: 300 +PenaltyBreakFirstLessLess: 120 +PenaltyBreakString: 1000 +PenaltyExcessCharacter: 1000000 +PenaltyReturnTypeOnItsOwnLine: 200 +PointerAlignment: Left +ReflowComments: true +SortIncludes: true +SpaceAfterCStyleCast: false +SpaceBeforeAssignmentOperators: true +SpaceBeforeParens: ControlStatements +SpaceInEmptyParentheses: false +SpacesBeforeTrailingComments: 1 +SpacesInAngles: false +SpacesInContainerLiterals: true +SpacesInCStyleCastParentheses: false +SpacesInParentheses: false +SpacesInSquareBrackets: false +Standard: Cpp11 +TabWidth: 8 +UseTab: Never +... diff --git a/.gitignore b/.gitignore index 259148fa..ca9d349b 100644 --- a/.gitignore +++ b/.gitignore @@ -30,3 +30,5 @@ *.exe *.out *.app + +src/proto/substrait diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..aa615775 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,12 @@ +[submodule "third_party/yaml-cpp"] + path = third_party/yaml-cpp + url = https://github.com/jbeder/yaml-cpp.git +[submodule "third_party/googletest"] + path = third_party/googletest + url = https://github.com/google/googletest.git +[submodule "third_party/substrait"] + path = third_party/substrait + url = https://github.com/substrait-io/substrait.git +[submodule "third_party/fmt"] + path = third_party/fmt + url = https://github.com/fmtlib/fmt diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 00000000..ea70a95a --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,33 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +cmake_minimum_required(VERSION 3.10) + +# set the project name +project(substrait-cpp) + +message(STATUS "Build type: ${CMAKE_BUILD_TYPE}") + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED True) +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +option( + BUILD_TESTING + "Enable substrait-cpp tests. This will enable all other build options automatically." + ON) + +find_package(Protobuf REQUIRED) +include_directories(${PROTOBUF_INCLUDE_DIRS}) + +add_subdirectory(third_party) +include_directories(src) +add_subdirectory(src) diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..6b59d481 --- /dev/null +++ b/Makefile @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +.PHONY: all clean build debug release + +BUILD_TYPE := Release + +all: debug + +clean: + @rm -rf build-* + +build-common: + @mkdir -p build-${BUILD_TYPE} + @cd build-${BUILD_TYPE} && \ + cmake -Wno-dev \ + -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \ + -DPREFER_STATIC_LIBS=OFF \ + $(FORCE_COLOR) \ + .. + +build: + VERBOSE=1 cmake --build build-${BUILD_TYPE} -j $${CPU_COUNT:-`nproc`} || \ + cmake --build build-${BUILD_TYPE} + +debug: + @$(MAKE) build-common BUILD_TYPE=Debug + @$(MAKE) build BUILD_TYPE=Debug + +release: + @$(MAKE) build-common BUILD_TYPE=Release + @$(MAKE) build BUILD_TYPE=Release diff --git a/README.md b/README.md index 36ca729e..65fbdfaa 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,36 @@ # substrait-cpp Planned home for CPP libraries to help build/consume Substrait query plans. + +## Getting Started + +We provide scripts to help developers setup and install substrait-cpp dependencies. + +### Get the substrait-cpp Source +``` +git clone --recursive https://github.com/substrait-io/substrait-cpp.git +cd substrait-cpp +# if you are updating an existing checkout +git submodule sync --recursive +git submodule update --init --recursive +``` + +### Setting up on Linux (Ubuntu 20.04 or later) + +Once you have checked out substrait-cpp, you can setup and build like so: + +```shell +$ ./scripts/setup-ubuntu.sh +$ make +``` + +## Community + +The main communication channel with the substrait through the +[substrait chanel](http://substrait.slack.com). + + +## License + +substrait-cpp is licensed under the Apache 2.0 License. A copy of the license +[can be found here.](LICENSE) \ No newline at end of file diff --git a/scripts/setup-helper-functions.sh b/scripts/setup-helper-functions.sh new file mode 100755 index 00000000..b6238d1d --- /dev/null +++ b/scripts/setup-helper-functions.sh @@ -0,0 +1,139 @@ +#!/bin/bash +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# github_checkout $REPO $VERSION $GIT_CLONE_PARAMS clones or re-uses an existing clone of the +# specified repo, checking out the requested version. +function github_checkout { + local REPO=$1 + shift + local VERSION=$1 + shift + local GIT_CLONE_PARAMS=$@ + local DIRNAME=$(basename $REPO) + cd "${DEPENDENCY_DIR}" + if [ -z "${DIRNAME}" ]; then + echo "Failed to get repo name from ${REPO}" + exit 1 + fi + if [ -d "${DIRNAME}" ] && prompt "${DIRNAME} already exists. Delete?"; then + rm -rf "${DIRNAME}" + fi + if [ ! -d "${DIRNAME}" ]; then + git clone -q -b $VERSION $GIT_CLONE_PARAMS "https://github.com/${REPO}.git" + fi + cd "${DIRNAME}" +} + + +# get_cxx_flags [$CPU_ARCH] +# Sets and exports the variable VELOX_CXX_FLAGS with appropriate compiler flags. +# If $CPU_ARCH is set then we use that else we determine best possible set of flags +# to use based on current cpu architecture. +# The goal of this function is to consolidate all architecture specific flags to one +# location. +# The values that CPU_ARCH can take are as follows: +# arm64 : Target Apple silicon. +# aarch64: Target general 64 bit arm cpus. +# avx: Target Intel CPUs with AVX. +# sse: Target Intel CPUs with sse. +# Echo's the appropriate compiler flags which can be captured as so +# CXX_FLAGS=$(get_cxx_flags) or +# CXX_FLAGS=$(get_cxx_flags "avx") + +function get_cxx_flags { + local CPU_ARCH=$1 + + local OS + OS=$(uname) + local MACHINE + MACHINE=$(uname -m) + + if [ -z "$CPU_ARCH" ]; then + + if [ "$OS" = "Darwin" ]; then + + if [ "$MACHINE" = "x86_64" ]; then + local CPU_CAPABILITIES + CPU_CAPABILITIES=$(sysctl -a | grep machdep.cpu.features | awk '{print tolower($0)}') + + if [[ $CPU_CAPABILITIES =~ "avx" ]]; then + CPU_ARCH="avx" + else + CPU_ARCH="sse" + fi + + elif [[ $(sysctl -a | grep machdep.cpu.brand_string) =~ "Apple" ]]; then + # Apple silicon. + CPU_ARCH="arm64" + fi + else [ "$OS" = "Linux" ]; + + local CPU_CAPABILITIES + CPU_CAPABILITIES=$(cat /proc/cpuinfo | grep flags | head -n 1| awk '{print tolower($0)}') + + if [[ "$CPU_CAPABILITIES" =~ "avx" ]]; then + CPU_ARCH="avx" + elif [[ "$CPU_CAPABILITIES" =~ "sse" ]]; then + CPU_ARCH="sse" + elif [ "$MACHINE" = "aarch64" ]; then + CPU_ARCH="aarch64" + fi + fi + fi + + case $CPU_ARCH in + + "arm64") + echo -n "-mcpu=apple-m1+crc -std=c++17" + ;; + + "avx") + echo -n "-mavx2 -mfma -mavx -mf16c -mlzcnt -std=c++17" + ;; + + "sse") + echo -n "-msse4.2 -std=c++17" + ;; + + "aarch64") + echo -n "-mcpu=neoverse-n1 -std=c++17" + ;; + *) + echo -n "Architecture not supported!" + esac + +} + +function cmake_install { + local NAME=$(basename "$(pwd)") + local BINARY_DIR=_build + if [ -d "${BINARY_DIR}" ] && prompt "Do you want to rebuild ${NAME}?"; then + rm -rf "${BINARY_DIR}" + fi + mkdir -p "${BINARY_DIR}" + CPU_TARGET="${CPU_TARGET:-avx}" + COMPILER_FLAGS=$(get_cxx_flags $CPU_TARGET) + + # CMAKE_POSITION_INDEPENDENT_CODE is required so that Velox can be built into dynamic libraries \ + cmake -Wno-dev -B"${BINARY_DIR}" \ + -GNinja \ + -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ + -DCMAKE_CXX_STANDARD=17 \ + "${INSTALL_PREFIX+-DCMAKE_PREFIX_PATH=}${INSTALL_PREFIX-}" \ + "${INSTALL_PREFIX+-DCMAKE_INSTALL_PREFIX=}${INSTALL_PREFIX-}" \ + -DCMAKE_CXX_FLAGS="$COMPILER_FLAGS" \ + -DBUILD_TESTING=OFF \ + "$@" + ninja -C "${BINARY_DIR}" install +} + diff --git a/scripts/setup-ubuntu.sh b/scripts/setup-ubuntu.sh new file mode 100755 index 00000000..dc0fe5f2 --- /dev/null +++ b/scripts/setup-ubuntu.sh @@ -0,0 +1,79 @@ +#!/bin/bash +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Minimal setup for Ubuntu 20.04. +set -eufx -o pipefail +SCRIPTDIR=$(dirname "${BASH_SOURCE[0]}") +source $SCRIPTDIR/setup-helper-functions.sh + +CPU_TARGET="${CPU_TARGET:-avx}" +export COMPILER_FLAGS=$(get_cxx_flags $CPU_TARGET) +NPROC=$(getconf _NPROCESSORS_ONLN) +DEPENDENCY_DIR=${DEPENDENCY_DIR:-$(pwd)} + +# Install all dependencies. +sudo --preserve-env apt install -y \ + g++ \ + cmake \ + ccache \ + ninja-build \ + checkinstall \ + git + +function run_and_time { + time "$@" + { echo "+ Finished running $*"; } 2> /dev/null +} + +function prompt { + ( + while true; do + local input="${PROMPT_ALWAYS_RESPOND:-}" + echo -n "$(tput bold)$* [Y, n]$(tput sgr0) " + [[ -z "${input}" ]] && read input + if [[ "${input}" == "Y" || "${input}" == "y" || "${input}" == "" ]]; then + return 0 + elif [[ "${input}" == "N" || "${input}" == "n" ]]; then + return 1 + fi + done + ) 2> /dev/null +} + +function install_protobuf { + wget https://github.com/protocolbuffers/protobuf/releases/download/v21.4/protobuf-all-21.4.tar.gz + tar -xzf protobuf-all-21.4.tar.gz + cd protobuf-21.4 + ./configure --prefix=/usr + make "-j$(nproc)" + make install + ldconfig +} + +function install_deps { + run_and_time install_protobuf +} + +(return 2> /dev/null) && return # If script was sourced, don't run commands. + +( + if [[ $# -ne 0 ]]; then + for cmd in "$@"; do + run_and_time "${cmd}" + done + else + install_deps + fi +) + +echo "All deps installed! Now try \"make\"" diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt new file mode 100644 index 00000000..d53e7c7a --- /dev/null +++ b/src/CMakeLists.txt @@ -0,0 +1,15 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +add_subdirectory(common) +add_subdirectory(core) diff --git a/src/common/CMakeLists.txt b/src/common/CMakeLists.txt new file mode 100644 index 00000000..d5ee9b15 --- /dev/null +++ b/src/common/CMakeLists.txt @@ -0,0 +1,23 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +add_library( + substrait_common + Exceptions.cpp) + +target_link_libraries( + substrait_common + fmt) + +if (${BUILD_TESTING}) + add_subdirectory(tests) +endif () diff --git a/src/common/Exceptions.cpp b/src/common/Exceptions.cpp new file mode 100644 index 00000000..e76e9d59 --- /dev/null +++ b/src/common/Exceptions.cpp @@ -0,0 +1,34 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "Exceptions.h" +#include "fmt/format.h" + +namespace io::substrait { +namespace common { +SubstraitException::SubstraitException( + std::string exceptionCode, + std::string& exceptionMessage, + Type exceptionType, + std::string exceptionName) + : msg_(fmt::format( + "Exception: {}\nError Code: {}\nReason: {}\n" + "Function: {}\nFile: {}\n:Line: {}\n", + exceptionName, + exceptionCode, + exceptionMessage, + __FUNCTION__, + __FILE__, + std::to_string(__LINE__))) {} +} // namespace common +} // namespace io::substrait diff --git a/src/common/Exceptions.h b/src/common/Exceptions.h new file mode 100644 index 00000000..f7ef68dc --- /dev/null +++ b/src/common/Exceptions.h @@ -0,0 +1,132 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include "fmt/format.h" + +namespace io::substrait { +namespace common { +namespace error_code { + +//====================== User Error Codes ======================: + +// An error raised when an argument verification fails +inline constexpr auto kInvalidArgument = "INVALID_ARGUMENT"; + +// An error raised when a requested operation is not supported. +inline constexpr auto kUnsupported = "UNSUPPORTED"; + +//====================== Runtime Error Codes ======================: + +// An error raised when the current state of a component is invalid. +inline constexpr auto kInvalidState = "INVALID_STATE"; + +// An error raised when unreachable code point was executed. +inline constexpr auto kUnreachableCode = "UNREACHABLE_CODE"; + +// An error raised when a requested operation is not yet supported. +inline constexpr auto kNotImplemented = "NOT_IMPLEMENTED"; + +} // namespace error_code + +class SubstraitException : public std::exception { + public: + enum class Type { kUser = 0, kSystem = 1 }; + + SubstraitException( + std::string exceptionCode, + std::string& exceptionMessage, + Type exceptionType = Type::kSystem, + std::string exceptionName = "SubstraitException"); + + // Inherited + const char* what() const noexcept override { + return msg_.c_str(); + } + + private: + const std::string msg_; +}; + +class SubstraitUserError : public SubstraitException { + public: + SubstraitUserError( + std::string exceptionCode, + std::string& exceptionMessage, + std::string exceptionName = "SubstraitUserError") + : SubstraitException( + exceptionCode, + exceptionMessage, + Type::kUser, + exceptionName) {} +}; + +class SubstraitRuntimeError final : public SubstraitException { + public: + SubstraitRuntimeError( + std::string exceptionCode, + std::string& exceptionMessage, + std::string exceptionName = "SubstraitRuntimeError") + : SubstraitException( + exceptionCode, + exceptionMessage, + Type::kSystem, + exceptionName) {} +}; + +template +std::string errorMessage(fmt::string_view fmt, const Args&... args) { + return fmt::vformat(fmt, fmt::make_format_args(args...)); +} + +#define _SUBSTRAIT_THROW(exception, errorCode, ...) \ + { \ + auto message = io::substrait::common::errorMessage(__VA_ARGS__); \ + throw exception(errorCode, message); \ + } + +#define SUBSTRAIT_UNSUPPORTED(...) \ + _SUBSTRAIT_THROW( \ + ::io::substrait::common::SubstraitUserError, \ + ::io::substrait::common::error_code::kUnsupported, \ + ##__VA_ARGS__) + +#define SUBSTRAIT_UNREACHABLE(...) \ + _SUBSTRAIT_THROW( \ + ::io::substrait::common::SubstraitRuntimeError, \ + ::io::substrait::common::error_code::kUnreachableCode, \ + ##__VA_ARGS__) + +#define SUBSTRAIT_FAIL(...) \ + _SUBSTRAIT_THROW( \ + ::io::substrait::common::SubstraitRuntimeError, \ + ::io::substrait::common::error_code::kInvalidState, \ + ##__VA_ARGS__) + +#define SUBSTRAIT_USER_FAIL(...) \ + _SUBSTRAIT_THROW( \ + ::io::substrait::common::SubstraitUserError, \ + ::io::substrait::common::error_code::kInvalidState, \ + ##__VA_ARGS__) + +#define SUBSTRAIT_NYI(...) \ + _SUBSTRAIT_THROW( \ + ::io::substrait::common::SubstraitRuntimeError, \ + ::io::substrait::common::error_code::kNotImplemented, \ + ##__VA_ARGS__) + +} // namespace common +} // namespace io::substrait diff --git a/src/common/tests/CMakeLists.txt b/src/common/tests/CMakeLists.txt new file mode 100644 index 00000000..dbf17ecb --- /dev/null +++ b/src/common/tests/CMakeLists.txt @@ -0,0 +1,25 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +add_executable( + substrait_common_test + ExceptionsTest.cpp) + +add_test( + substrait_common_test + substrait_common_test) + +target_link_libraries( + substrait_common_test + substrait_common + gtest + gtest_main) diff --git a/src/common/tests/ExceptionsTest.cpp b/src/common/tests/ExceptionsTest.cpp new file mode 100644 index 00000000..6956ca11 --- /dev/null +++ b/src/common/tests/ExceptionsTest.cpp @@ -0,0 +1,20 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/Exceptions.h" +#include + +class SubstraitExceptionTest : public ::testing::Test {}; + +TEST_F(SubstraitExceptionTest, decodeTest) {} diff --git a/src/core/CMakeLists.txt b/src/core/CMakeLists.txt new file mode 100644 index 00000000..8e3a4974 --- /dev/null +++ b/src/core/CMakeLists.txt @@ -0,0 +1,58 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Set up Proto +set(PROTO_DIRECTORY ${CMAKE_SOURCE_DIR}/third_party/substrait/proto) +set(SUBSTRAIT_PROTO_DIRECTORY ${PROTO_DIRECTORY}/substrait) +set(PROTO_OUTPUT_DIR "${CMAKE_SOURCE_DIR}/third_party/substrait/proto") +file(GLOB PROTO_FILES ${SUBSTRAIT_PROTO_DIRECTORY}/*.proto + ${SUBSTRAIT_PROTO_DIRECTORY}/extensions/*.proto) +foreach (PROTO ${PROTO_FILES}) + file(RELATIVE_PATH REL_PROTO ${SUBSTRAIT_PROTO_DIRECTORY} ${PROTO}) + string(REGEX REPLACE "\\.proto" "" PROTO_NAME ${REL_PROTO}) + list(APPEND PROTO_SRCS "${PROTO_OUTPUT_DIR}/substrait/${PROTO_NAME}.pb.cc") + list(APPEND PROTO_HDRS "${PROTO_OUTPUT_DIR}/substrait/${PROTO_NAME}.pb.h") +endforeach () +set(PROTO_OUTPUT_FILES ${PROTO_HDRS} ${PROTO_SRCS}) +set_source_files_properties(${PROTO_OUTPUT_FILES} PROPERTIES GENERATED TRUE) +get_filename_component(PROTO_DIR ${SUBSTRAIT_PROTO_DIRECTORY}/, DIRECTORY) + +# Generate Substrait hearders +add_custom_command( + OUTPUT ${PROTO_OUTPUT_FILES} + COMMAND protoc --proto_path ${PROTO_DIRECTORY}/ --cpp_out ${PROTO_OUTPUT_DIR} + ${PROTO_FILES} + DEPENDS ${PROTO_DIR} + COMMENT "Running PROTO compiler" + VERBATIM) + +include_directories(${PROTO_OUTPUT_DIR}) + +set(SRCS + Type.cpp + Function.cpp + Extension.cpp + FunctionMapping.h + FunctionLookup.cpp + FunctionSignature.h) + +add_library(substrait-core ${SRCS}) + +target_link_libraries( + substrait-core + substrait_common + yaml-cpp) + +if (${BUILD_TESTING}) + add_subdirectory(tests) +endif () diff --git a/src/core/Extension.cpp b/src/core/Extension.cpp new file mode 100644 index 00000000..e73554d0 --- /dev/null +++ b/src/core/Extension.cpp @@ -0,0 +1,292 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "Extension.h" +#include "yaml-cpp/yaml.h" + +bool decodeFunctionVariant( + const YAML::Node& node, + io::substrait::FunctionVariant& function) { + const auto& returnType = node["return"]; + if (returnType && returnType.IsScalar()) { + /// Return type can be an expression. + const auto& returnExpr = returnType.as(); + std::stringstream ss(returnExpr); + + // TODO: currently we only parse the last sentence of type definition, use + // ANTLR in future. + std::string lastReturnType; + while (std::getline(ss, lastReturnType, '\n')) { + } + function.returnType = io::substrait::Type::decode(lastReturnType); + } + const auto& args = node["args"]; + if (args && args.IsSequence()) { + for (auto& arg : args) { + if (arg["options"]) { // enum argument + auto enumArgument = std::make_shared( + arg.as()); + function.arguments.emplace_back(enumArgument); + } else if (arg["value"]) { // value argument + auto valueArgument = std::make_shared( + arg.as()); + function.arguments.emplace_back(valueArgument); + } else { // type argument + auto typeArgument = std::make_shared( + arg.as()); + function.arguments.emplace_back(typeArgument); + } + } + } + + const auto& variadic = node["variadic"]; + if (variadic) { + auto& min = variadic["min"]; + auto& max = variadic["max"]; + if (min) { + function.variadic = std::make_optional( + {min.as(), + max ? std::make_optional(max.as()) : std::nullopt}); + } else { + function.variadic = std::nullopt; + } + } else { + function.variadic = std::nullopt; + } + + return true; +} + +template <> +struct YAML::convert { + static bool decode(const Node& node, io::substrait::EnumArgument& argument) { + // 'options' is required property + const auto& options = node["options"]; + if (options && options.IsSequence()) { + auto& required = node["required"]; + argument.required = required && required.as(); + return true; + } else { + return false; + } + } +}; + +template <> +struct YAML::convert { + static bool decode(const Node& node, io::substrait::ValueArgument& argument) { + const auto& value = node["value"]; + if (value && value.IsScalar()) { + auto valueType = value.as(); + argument.type = io::substrait::Type::decode(valueType); + return true; + } + return false; + } +}; + +template <> +struct YAML::convert { + static bool decode( + const YAML::Node& node, + io::substrait::TypeArgument& argument) { + // no properties need to populate for type argument, just return true if + // 'type' element exists. + if (node["type"]) { + return true; + } + return false; + } +}; + +template <> +struct YAML::convert { + static bool decode( + const Node& node, + io::substrait::ScalarFunctionVariant& function) { + return decodeFunctionVariant(node, function); + }; +}; + +template <> +struct YAML::convert { + static bool decode( + const Node& node, + io::substrait::AggregateFunctionVariant& function) { + const auto& res = decodeFunctionVariant(node, function); + if (res) { + const auto& intermediate = node["intermediate"]; + if (intermediate) { + function.intermediate = + io::substrait::Type::decode(intermediate.as()); + } + } + return res; + } +}; + +template <> +struct YAML::convert { + static bool decode(const Node& node, io::substrait::TypeVariant& typeAnchor) { + const auto& name = node["name"]; + if (name && name.IsScalar()) { + typeAnchor.name = name.as(); + return true; + } + return false; + } +}; + +namespace io::substrait { + +std::shared_ptr Extension::load(const std::string& basePath) { + static const std::vector extensionFiles{ + "functions_aggregate_approx.yaml", + "functions_aggregate_generic.yaml", + "functions_arithmetic.yaml", + "functions_arithmetic_decimal.yaml", + "functions_boolean.yaml", + "functions_comparison.yaml", + "functions_datetime.yaml", + "functions_logarithmic.yaml", + "functions_rounding.yaml", + "functions_string.yaml", + "functions_set.yaml", + "unknown.yaml", + }; + return load(basePath, extensionFiles); +} + +std::shared_ptr Extension::load( + const std::string& basePath, + const std::vector& extensionFiles) { + std::vector yamlExtensionFiles; + yamlExtensionFiles.reserve(extensionFiles.size()); + for (auto& extensionFile : extensionFiles) { + auto const pos = basePath.find_last_of('/'); + const auto& extensionUri = basePath.substr(0, pos) + "/" + extensionFile; + yamlExtensionFiles.emplace_back(extensionUri); + } + return load(yamlExtensionFiles); +} + +std::shared_ptr Extension::load( + const std::vector& extensionFiles) { + auto extension = std::make_shared(); + for (const auto& extensionUri : extensionFiles) { + const auto& node = YAML::LoadFile(extensionUri); + + const auto& scalarFunctions = node["scalar_functions"]; + if (scalarFunctions && scalarFunctions.IsSequence()) { + for (auto& scalarFunctionNode : scalarFunctions) { + const auto functionName = scalarFunctionNode["name"].as(); + for (auto& scalaFunctionVariantNode : scalarFunctionNode["impls"]) { + auto scalarFunctionVariant = + scalaFunctionVariantNode.as(); + scalarFunctionVariant.name = functionName; + scalarFunctionVariant.uri = extensionUri; + extension->addScalarFunctionVariant( + std::make_shared(scalarFunctionVariant)); + } + } + } + + const auto& aggregateFunctions = node["aggregate_functions"]; + if (aggregateFunctions && aggregateFunctions.IsSequence()) { + for (auto& aggregateFunctionNode : aggregateFunctions) { + const auto functionName = + aggregateFunctionNode["name"].as(); + for (auto& aggregateFunctionVariantNode : + aggregateFunctionNode["impls"]) { + auto aggregateFunctionVariant = + aggregateFunctionVariantNode.as(); + aggregateFunctionVariant.name = functionName; + aggregateFunctionVariant.uri = extensionUri; + extension->addAggregateFunctionVariant( + std::make_shared( + aggregateFunctionVariant)); + } + } + } + + const auto& types = node["types"]; + if (types && types.IsSequence()) { + for (auto& type : types) { + auto typeAnchor = type.as(); + typeAnchor.uri = extensionUri; + extension->addTypeVariant(std::make_shared(typeAnchor)); + } + } + } + return extension; +} + +void Extension::addWindowFunctionVariant( + const FunctionVariantPtr& functionVariant) { + const auto& functionVariants = + windowFunctionVariantMap_.find(functionVariant->name); + if (functionVariants != windowFunctionVariantMap_.end()) { + auto& variants = functionVariants->second; + variants.emplace_back(functionVariant); + } else { + std::vector variants; + variants.emplace_back(functionVariant); + windowFunctionVariantMap_.insert( + {functionVariant->name, std::move(variants)}); + } +} + +void Extension::addTypeVariant(const TypeVariantPtr& functionVariant) { + typeVariantMap_.insert({functionVariant->name, functionVariant}); +} + +TypeVariantPtr Extension::lookupType(const std::string& typeName) const { + auto typeVariantIter = typeVariantMap_.find(typeName); + if (typeVariantIter != typeVariantMap_.end()) { + return typeVariantIter->second; + } + return nullptr; +} + +void Extension::addScalarFunctionVariant( + const FunctionVariantPtr& functionVariant) { + const auto& functionVariants = + scalarFunctionVariantMap_.find(functionVariant->name); + if (functionVariants != scalarFunctionVariantMap_.end()) { + auto& variants = functionVariants->second; + variants.emplace_back(functionVariant); + } else { + std::vector variants; + variants.emplace_back(functionVariant); + scalarFunctionVariantMap_.insert( + {functionVariant->name, std::move(variants)}); + } +} + +void Extension::addAggregateFunctionVariant( + const FunctionVariantPtr& functionVariant) { + const auto& functionVariants = + aggregateFunctionVariantMap_.find(functionVariant->name); + if (functionVariants != aggregateFunctionVariantMap_.end()) { + auto& variants = functionVariants->second; + variants.emplace_back(functionVariant); + } else { + std::vector variants; + variants.emplace_back(functionVariant); + aggregateFunctionVariantMap_.insert( + {functionVariant->name, std::move(variants)}); + } +} + +} // namespace io::substrait diff --git a/src/core/Extension.h b/src/core/Extension.h new file mode 100644 index 00000000..0b1a1692 --- /dev/null +++ b/src/core/Extension.h @@ -0,0 +1,90 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "Function.h" +#include "FunctionSignature.h" +#include "Type.h" + +namespace io::substrait { + +struct TypeVariant { + std::string name; + std::string uri; +}; + +using TypeVariantPtr = std::shared_ptr; + +using FunctionVariantMap = + std::unordered_map>; + +using TypeVariantMap = std::unordered_map; + +class Extension { + public: + /// Deserialize default substrait extension by given basePath + /// @throws exception if file not found + static std::shared_ptr load(const std::string& basePath); + + /// Deserialize substrait extension by given basePath and extensionFiles. + static std::shared_ptr load( + const std::string& basePath, + const std::vector& extensionFiles); + + /// Deserialize substrait extension by given extensionFiles. + static std::shared_ptr load( + const std::vector& extensionFiles); + + /// Add a scalar function variant. + void addScalarFunctionVariant(const FunctionVariantPtr& functionVariant); + + /// Add a aggregate function variant. + void addAggregateFunctionVariant(const FunctionVariantPtr& functionVariant); + + /// Add a window function variant. + void addWindowFunctionVariant(const FunctionVariantPtr& functionVariant); + + /// Add a type variant. + void addTypeVariant(const TypeVariantPtr& functionVariant); + + /// Lookup type variant by given type name. + /// @return matched type variant + TypeVariantPtr lookupType(const std::string& typeName) const; + + const FunctionVariantMap& scalaFunctionVariantMap() const { + return scalarFunctionVariantMap_; + } + + const FunctionVariantMap& windowFunctionVariantMap() const { + return windowFunctionVariantMap_; + } + + const FunctionVariantMap& aggregateFunctionVariantMap() const { + return aggregateFunctionVariantMap_; + } + + private: + FunctionVariantMap scalarFunctionVariantMap_; + + FunctionVariantMap aggregateFunctionVariantMap_; + + FunctionVariantMap windowFunctionVariantMap_; + + TypeVariantMap typeVariantMap_; +}; + +using ExtensionPtr = std::shared_ptr; + +} // namespace io::substrait diff --git a/src/core/Function.cpp b/src/core/Function.cpp new file mode 100644 index 00000000..bfd450f5 --- /dev/null +++ b/src/core/Function.cpp @@ -0,0 +1,100 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "Function.h" +#include + +namespace io::substrait { + +std::string FunctionVariant::signature( + const std::string& name, + const std::vector& arguments) { + std::stringstream ss; + ss << name; + if (!arguments.empty()) { + ss << ":"; + for (auto it = arguments.begin(); it != arguments.end(); ++it) { + const auto& typeSign = (*it)->toTypeString(); + if (it == arguments.end() - 1) { + ss << typeSign; + } else { + ss << typeSign << "_"; + } + } + } + + return ss.str(); +} + +bool FunctionVariant::tryMatch(const FunctionSignature& signature) { + const auto& actualTypes = signature.arguments; + if (variadic.has_value()) { + // return false if actual types length less than min of variadic + const auto max = variadic->max; + if ((actualTypes.size() < variadic->min) || + (max.has_value() && actualTypes.size() > max.value())) { + return false; + } + + const auto& variadicArgument = arguments[0]; + // actual type must same as the variadicArgument + if (const auto& variadicValueArgument = + std::dynamic_pointer_cast(variadicArgument)) { + for (auto& actualType : actualTypes) { + if (!variadicValueArgument->type->isSameAs(actualType)) { + return false; + } + } + } + } else { + std::vector> valueArguments; + for (const auto& argument : arguments) { + if (const auto& variadicValueArgument = + std::dynamic_pointer_cast(argument)) { + valueArguments.emplace_back(variadicValueArgument); + } + } + // return false if size of actual types not equal to size of value + // arguments. + if (valueArguments.size() != actualTypes.size()) { + return false; + } + + for (auto i = 0; i < actualTypes.size(); i++) { + const auto& valueArgument = valueArguments[i]; + if (!valueArgument->type->isSameAs(actualTypes[i])) { + return false; + } + } + } + const auto& sigReturnType = signature.returnType; + if (this->returnType && sigReturnType) { + return returnType->isSameAs(sigReturnType); + } else { + return true; + } +} + +bool AggregateFunctionVariant::tryMatch(const FunctionSignature& signature) { + bool matched = FunctionVariant::tryMatch(signature); + if (!matched && intermediate) { + const auto& actualTypes = signature.arguments; + if (actualTypes.size() == 1) { + return intermediate->isSameAs(actualTypes[0]); + } + } + return matched; +} + +} // namespace io::substrait diff --git a/src/core/Function.h b/src/core/Function.h new file mode 100644 index 00000000..2b15628a --- /dev/null +++ b/src/core/Function.h @@ -0,0 +1,118 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "FunctionSignature.h" +#include "Type.h" + +namespace io::substrait { + +struct FunctionArgument { + virtual bool isRequired() const = 0; + + /// Convert argument type to short type string based on + /// https://substrait.io/extensions/#function-signature-compound-names + virtual std::string toTypeString() const = 0; + + virtual bool isWildcardType() const { + return false; + }; + + virtual bool isValueArgument() const { + return false; + } +}; + +using FunctionArgumentPtr = std::shared_ptr; + +struct EnumArgument : public FunctionArgument { + bool required; + + bool isRequired() const override { + return required; + } + + std::string toTypeString() const override { + return required ? "req" : "opt"; + } +}; + +struct TypeArgument : public FunctionArgument { + std::string toTypeString() const override { + return "type"; + } + + bool isRequired() const override { + return true; + } +}; + +struct ValueArgument : public FunctionArgument { + TypePtr type; + + std::string toTypeString() const override { + return type->signature(); + } + + bool isRequired() const override { + return true; + } + + bool isWildcardType() const override { + return type->isWildcard(); + } + + bool isValueArgument() const override { + return true; + } +}; + +struct FunctionVariadic { + int min; + std::optional max; +}; + +struct FunctionVariant { + std::string name; + std::string uri; + std::vector arguments; + TypePtr returnType; + std::optional variadic; + + /// Test if the actual types matched with this function variant. + virtual bool tryMatch(const FunctionSignature& signature); + + /// Create function signature by given function name and arguments. + static std::string signature( + const std::string& name, + const std::vector& arguments); + + /// Create function signature by function name and arguments. + const std::string signature() const { + return signature(name, arguments); + } +}; + +using FunctionVariantPtr = std::shared_ptr; + +struct ScalarFunctionVariant : public FunctionVariant {}; + +struct AggregateFunctionVariant : public FunctionVariant { + TypePtr intermediate; + + bool tryMatch(const FunctionSignature& signature) override; +}; + +} // namespace io::substrait diff --git a/src/core/FunctionLookup.cpp b/src/core/FunctionLookup.cpp new file mode 100644 index 00000000..0e4eb3d9 --- /dev/null +++ b/src/core/FunctionLookup.cpp @@ -0,0 +1,40 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "FunctionLookup.h" + +namespace io::substrait { + +FunctionVariantPtr FunctionLookup::lookupFunction( + const FunctionSignature& signature) const { + const auto& functionMappings = getFunctionMap(); + + const auto& substraitFunctionName = + functionMappings.find(signature.name) != functionMappings.end() + ? functionMappings.at(signature.name) + : signature.name; + + const auto& functionVariants = getFunctionVariants(); + auto functionVariantIter = functionVariants.find(substraitFunctionName); + if (functionVariantIter != functionVariants.end()) { + for (const auto& candidateFunctionVariant : functionVariantIter->second) { + if (candidateFunctionVariant->tryMatch(signature)) { + return candidateFunctionVariant; + } + } + } + return nullptr; +} + +} // namespace io::substrait diff --git a/src/core/FunctionLookup.h b/src/core/FunctionLookup.h new file mode 100644 index 00000000..6a42d347 --- /dev/null +++ b/src/core/FunctionLookup.h @@ -0,0 +1,98 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "Extension.h" +#include "FunctionMapping.h" +#include "FunctionSignature.h" + +namespace io::substrait { + +class FunctionLookup { + public: + FunctionLookup( + const ExtensionPtr& extension, + const FunctionMappingPtr& functionMapping) + : extension_(extension), functionMapping_(functionMapping) {} + + virtual FunctionVariantPtr lookupFunction( + const FunctionSignature& signature) const; + + virtual ~FunctionLookup() {} + + protected: + virtual const FunctionMap getFunctionMap() const = 0; + + virtual const FunctionVariantMap& getFunctionVariants() const = 0; + + const FunctionMappingPtr functionMapping_; + + ExtensionPtr extension_; +}; + +using FunctionLookupPtr = std::shared_ptr; + +class ScalarFunctionLookup : public FunctionLookup { + public: + ScalarFunctionLookup( + const ExtensionPtr& extension, + const FunctionMappingPtr& functionMapping) + : FunctionLookup(extension, functionMapping) {} + + protected: + const FunctionMap getFunctionMap() const override { + return functionMapping_->scalaMapping(); + } + + const FunctionVariantMap& getFunctionVariants() const override { + return extension_->scalaFunctionVariantMap(); + } +}; + +class AggregateFunctionLookup : public FunctionLookup { + public: + AggregateFunctionLookup( + const ExtensionPtr& extension, + const FunctionMappingPtr& functionMapping) + : FunctionLookup(extension, functionMapping) {} + + protected: + const FunctionMap getFunctionMap() const override { + return functionMapping_->aggregateMapping(); + } + + const FunctionVariantMap& getFunctionVariants() const override { + return extension_->aggregateFunctionVariantMap(); + } +}; + +class WindowFunctionLookup : public FunctionLookup { + public: + WindowFunctionLookup( + const ExtensionPtr& extension, + const FunctionMappingPtr& functionMapping) + : FunctionLookup(extension, functionMapping) {} + + protected: + const FunctionMap getFunctionMap() const override { + return functionMapping_->windowMapping(); + } + + const FunctionVariantMap& getFunctionVariants() const override { + return extension_->windowFunctionVariantMap(); + } +}; + +} // namespace io::substrait diff --git a/src/core/FunctionMapping.h b/src/core/FunctionMapping.h new file mode 100644 index 00000000..6bee1327 --- /dev/null +++ b/src/core/FunctionMapping.h @@ -0,0 +1,47 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace io::substrait { + +using FunctionMap = std::unordered_map; + +/// An interface describe the function names in difference between engine-own +/// and substrait system. +class FunctionMapping { + public: + /// Scalar function names in difference between engine own and substrait. + virtual const FunctionMap& scalaMapping() const { + static const FunctionMap scalaFunctionMap{}; + return scalaFunctionMap; + } + + /// Scalar function names in difference between engine own and substrait. + virtual const FunctionMap& aggregateMapping() const { + static const FunctionMap aggregateFunctionMap{}; + return aggregateFunctionMap; + } + + /// Window function names in difference between engine own and substrait. + virtual const FunctionMap& windowMapping() const { + static const FunctionMap windowFunctionMap{}; + return windowFunctionMap; + } +}; + +using FunctionMappingPtr = std::shared_ptr; +} // namespace io::substrait diff --git a/src/core/FunctionSignature.h b/src/core/FunctionSignature.h new file mode 100644 index 00000000..722fb781 --- /dev/null +++ b/src/core/FunctionSignature.h @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#pragma once + +#include "Type.h" + +namespace io::substrait { + +struct FunctionSignature { + std::string name; + std::vector arguments; + TypePtr returnType; +}; + +} // namespace io::substrait diff --git a/src/core/Type.cpp b/src/core/Type.cpp new file mode 100644 index 00000000..0885e761 --- /dev/null +++ b/src/core/Type.cpp @@ -0,0 +1,356 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "Type.h" +#include +#include +#include +#include "../common/Exceptions.h" + +namespace io::substrait { + +namespace { + +size_t findNextComma(const std::string& str, size_t start) { + int cnt = 0; + for (auto i = start; i < str.size(); i++) { + if (str[i] == '<') { + cnt++; + } else if (str[i] == '>') { + cnt--; + } else if (cnt == 0 && str[i] == ',') { + return i; + } + } + + return std::string::npos; +} + +} // namespace + +TypePtr Type::decode(const std::string& rawType) { + std::string matchingType = rawType; + const auto& questionMaskPos = rawType.find_last_of('?'); + // deal with type and with a question mask like "i32?". + if (questionMaskPos != std::string::npos) { + matchingType = rawType.substr(0, questionMaskPos); + } + std::transform( + matchingType.begin(), + matchingType.end(), + matchingType.begin(), + [](unsigned char c) { return std::tolower(c); }); + + const auto& leftAngleBracketPos = rawType.find('<'); + if (leftAngleBracketPos == std::string::npos) { + const auto& scalarType = scalarTypeMapping().find(matchingType); + if (scalarType != scalarTypeMapping().end()) { + return scalarType->second; + } else if (matchingType.rfind("unknown", 0) == 0) { + return std::make_shared(rawType); + } else { + return std::make_shared(rawType); + } + } + const auto& rightAngleBracketPos = rawType.rfind('>'); + + auto baseType = matchingType.substr(0, leftAngleBracketPos); + + std::vector nestedTypes; + auto prevPos = leftAngleBracketPos + 1; + auto commaPos = findNextComma(rawType, prevPos); + while (commaPos != std::string::npos) { + auto token = rawType.substr(prevPos, commaPos - prevPos); + nestedTypes.emplace_back(decode(token)); + prevPos = commaPos + 1; + commaPos = findNextComma(rawType, prevPos); + } + auto token = rawType.substr(prevPos, rightAngleBracketPos - prevPos); + nestedTypes.emplace_back(decode(token)); + + if (TypeTraits::typeString == baseType) { + return std::make_shared(nestedTypes[0]); + } else if (TypeTraits::typeString == baseType) { + return std::make_shared(nestedTypes[0], nestedTypes[1]); + } else if (TypeTraits::typeString == baseType) { + auto precision = + std::dynamic_pointer_cast(nestedTypes[0]); + auto scale = + std::dynamic_pointer_cast(nestedTypes[1]); + return std::make_shared(precision, scale); + } else if (TypeTraits::typeString == baseType) { + auto length = + std::dynamic_pointer_cast(nestedTypes[0]); + return std::make_shared(length); + } else if (TypeTraits::typeString == baseType) { + auto length = + std::dynamic_pointer_cast(nestedTypes[0]); + return std::make_shared(length); + } else if (TypeTraits::typeString == baseType) { + auto length = + std::dynamic_pointer_cast(nestedTypes[0]); + return std::make_shared(length); + } else if (TypeTraits::typeString == baseType) { + return std::make_shared(nestedTypes); + } else { + SUBSTRAIT_UNSUPPORTED("Unsupported substrait type: " + rawType); + } +} + +#define SCALAR_TYPE_MAPPING(typeKind) \ + { \ + TypeTraits::typeString, \ + std::make_shared>( \ + TypeBase()) \ + } + +const std::unordered_map& Type::scalarTypeMapping() { + static const std::unordered_map scalarTypeMap{ + SCALAR_TYPE_MAPPING(kBool), + SCALAR_TYPE_MAPPING(kI8), + SCALAR_TYPE_MAPPING(kI16), + SCALAR_TYPE_MAPPING(kI32), + SCALAR_TYPE_MAPPING(kI64), + SCALAR_TYPE_MAPPING(kFp32), + SCALAR_TYPE_MAPPING(kFp64), + SCALAR_TYPE_MAPPING(kString), + SCALAR_TYPE_MAPPING(kBinary), + SCALAR_TYPE_MAPPING(kTimestamp), + SCALAR_TYPE_MAPPING(kTimestampTz), + SCALAR_TYPE_MAPPING(kDate), + SCALAR_TYPE_MAPPING(kTime), + SCALAR_TYPE_MAPPING(kIntervalDay), + SCALAR_TYPE_MAPPING(kIntervalYear), + SCALAR_TYPE_MAPPING(kUuid), + }; + return scalarTypeMap; +} + +std::string FixedBinaryType::signature() const { + std::stringstream sign; + sign << TypeBase::signature(); + sign << "<"; + sign << length_->value(); + sign << ">"; + return sign.str(); +} + +bool FixedBinaryType::isSameAs(const std::shared_ptr& other) const { + if (const auto& type = + std::dynamic_pointer_cast(other)) { + return true; + } + return false; +} + +std::string DecimalType::signature() const { + std::stringstream signature; + signature << TypeBase::signature(); + signature << "<"; + signature << precision_->value() << "," << scale_->value(); + signature << ">"; + return signature.str(); +} + +bool DecimalType::isSameAs(const std::shared_ptr& other) const { + if (const auto& type = std::dynamic_pointer_cast(other)) { + return true; + } + return false; +} + +std::string FixedCharType::signature() const { + std::ostringstream sign; + sign << TypeBase::signature(); + sign << "<"; + sign << length_->value(); + sign << ">"; + return sign.str(); +} + +bool FixedCharType::isSameAs(const std::shared_ptr& other) const { + if (const auto& type = + std::dynamic_pointer_cast(other)) { + return true; + } + return false; +} + +std::string VarcharType::signature() const { + std::ostringstream sign; + sign << TypeBase::signature(); + sign << "<"; + sign << length_->value(); + sign << ">"; + return sign.str(); +} + +bool VarcharType::isSameAs(const std::shared_ptr& other) const { + if (const auto& type = std::dynamic_pointer_cast(other)) { + return true; + } + return false; +} + +std::string StructType::signature() const { + std::ostringstream signature; + signature << TypeBase::signature(); + signature << "<"; + for (auto it = children_.begin(); it != children_.end(); ++it) { + const auto& typeSign = (*it)->signature(); + if (it == children_.end() - 1) { + signature << typeSign; + } else { + signature << typeSign << ","; + } + } + signature << ">"; + return signature.str(); +} + +bool StructType::isSameAs(const std::shared_ptr& other) const { + if (const auto& type = std::dynamic_pointer_cast(other)) { + bool sameSize = type->children_.size() == children_.size(); + if (sameSize) { + for (int i = 0; i < children_.size(); i++) { + if (!children_[i]->isSameAs(type->children_[i])) { + return false; + } + } + return true; + } + } + return false; +} + +std::string MapType::signature() const { + std::ostringstream signature; + signature << TypeBase::signature(); + signature << "<"; + signature << keyType_->signature(); + signature << ","; + signature << valueType_->signature(); + signature << ">"; + return signature.str(); +} + +bool MapType::isSameAs(const std::shared_ptr& other) const { + if (const auto& type = std::dynamic_pointer_cast(other)) { + return keyType_->isSameAs(type->keyType_) && + valueType_->isSameAs(type->valueType_); + } + return false; +} + +std::string ListType::signature() const { + std::ostringstream signature; + signature << TypeBase::signature(); + signature << "<"; + signature << elementType_->signature(); + signature << ">"; + return signature.str(); +} + +bool ListType::isSameAs(const std::shared_ptr& other) const { + if (const auto& type = std::dynamic_pointer_cast(other)) { + return elementType_->isSameAs(type->elementType_); + } + return false; +} + +bool UsedDefinedType::isSameAs(const std::shared_ptr& other) const { + if (const auto& type = + std::dynamic_pointer_cast(other)) { + return type->value_ == value_; + } + return false; +} + +bool StringLiteralType::isSameAs( + const std::shared_ptr& other) const { + if (isWildcard()) { + return true; + } + if (const auto& type = + std::dynamic_pointer_cast(other)) { + return type->value_ == value_; + } + return false; +} + +std::shared_ptr> BOOL() { + return std::make_shared>(); +} + +std::shared_ptr> TINYINT() { + return std::make_shared>(); +} + +std::shared_ptr> SMALLINT() { + return std::make_shared>(); +} + +std::shared_ptr> INTEGER() { + return std::make_shared>(); +} + +std::shared_ptr> BIGINT() { + return std::make_shared>(); +} + +std::shared_ptr> FLOAT() { + return std::make_shared>(); +} + +std::shared_ptr> DOUBLE() { + return std::make_shared>(); +} + +std::shared_ptr> STRING() { + return std::make_shared>(); +} + +std::shared_ptr> BINARY() { + return std::make_shared>(); +} + +std::shared_ptr> TIMESTAMP() { + return std::make_shared>(); +} + +std::shared_ptr> DATE() { + return std::make_shared>(); +} + +std::shared_ptr> TIME() { + return std::make_shared>(); +} + +std::shared_ptr> INTERVAL_YEAR() { + return std::make_shared>(); +} + +std::shared_ptr> INTERVAL_DAY() { + return std::make_shared>(); +} + +std::shared_ptr> TIMESTAMP_TZ() { + return std::make_shared>(); +} + +std::shared_ptr> UUID() { + return std::make_shared>(); +} + +} // namespace io::substrait diff --git a/src/core/Type.h b/src/core/Type.h new file mode 100644 index 00000000..f9cd2bf3 --- /dev/null +++ b/src/core/Type.h @@ -0,0 +1,459 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +namespace io::substrait { + +enum class TypeKind : int8_t { + kBool = 1, + kI8 = 2, + kI16 = 3, + kI32 = 5, + kI64 = 7, + kFp32 = 10, + kFp64 = 11, + kString = 12, + kBinary = 13, + kTimestamp = 14, + kDate = 16, + kTime = 17, + kIntervalYear = 19, + kIntervalDay = 20, + kTimestampTz = 29, + kUuid = 32, + kFixedChar = 21, + kVarchar = 22, + kFixedBinary = 23, + kDecimal = 24, + kStruct = 25, + kList = 27, + kMap = 28, + kUserDefined = 30, + KIND_NOT_SET = 0, +}; + +template +struct TypeTraits {}; + +template <> +struct TypeTraits { + static constexpr const char* signature = "bool"; + static constexpr const char* typeString = "boolean"; +}; + +template <> +struct TypeTraits { + static constexpr const char* signature = "i8"; + static constexpr const char* typeString = "i8"; +}; + +template <> +struct TypeTraits { + static constexpr const char* signature = "i16"; + static constexpr const char* typeString = "i16"; +}; + +template <> +struct TypeTraits { + static constexpr const char* signature = "i32"; + static constexpr const char* typeString = "i32"; +}; + +template <> +struct TypeTraits { + static constexpr const char* signature = "i64"; + static constexpr const char* typeString = "i64"; +}; + +template <> +struct TypeTraits { + static constexpr const char* signature = "fp32"; + static constexpr const char* typeString = "fp32"; +}; + +template <> +struct TypeTraits { + static constexpr const char* signature = "fp64"; + static constexpr const char* typeString = "fp64"; +}; + +template <> +struct TypeTraits { + static constexpr const char* signature = "str"; + static constexpr const char* typeString = "string"; +}; + +template <> +struct TypeTraits { + static constexpr const char* signature = "vbin"; + static constexpr const char* typeString = "binary"; +}; + +template <> +struct TypeTraits { + static constexpr const char* signature = "ts"; + static constexpr const char* typeString = "timestamp"; +}; + +template <> +struct TypeTraits { + static constexpr const char* signature = "tstz"; + static constexpr const char* typeString = "timestamp_tz"; +}; + +template <> +struct TypeTraits { + static constexpr const char* signature = "date"; + static constexpr const char* typeString = "date"; +}; + +template <> +struct TypeTraits { + static constexpr const char* signature = "time"; + static constexpr const char* typeString = "time"; +}; + +template <> +struct TypeTraits { + static constexpr const char* signature = "iyear"; + static constexpr const char* typeString = "interval_year"; +}; + +template <> +struct TypeTraits { + static constexpr const char* signature = "iday"; + static constexpr const char* typeString = "interval_day"; +}; + +template <> +struct TypeTraits { + static constexpr const char* signature = "uuid"; + static constexpr const char* typeString = "uuid"; +}; + +template <> +struct TypeTraits { + static constexpr const char* signature = "fchar"; + static constexpr const char* typeString = "fixedchar"; +}; + +template <> +struct TypeTraits { + static constexpr const char* signature = "vchar"; + static constexpr const char* typeString = "varchar"; +}; + +template <> +struct TypeTraits { + static constexpr const char* signature = "fbin"; + static constexpr const char* typeString = "fixedbinary"; +}; + +template <> +struct TypeTraits { + static constexpr const char* signature = "dec"; + static constexpr const char* typeString = "decimal"; +}; + +template <> +struct TypeTraits { + static constexpr const char* signature = "struct"; + static constexpr const char* typeString = "struct"; +}; + +template <> +struct TypeTraits { + static constexpr const char* signature = "list"; + static constexpr const char* typeString = "list"; +}; + +template <> +struct TypeTraits { + static constexpr const char* signature = "map"; + static constexpr const char* typeString = "map"; +}; + +template <> +struct TypeTraits { + static constexpr const char* signature = "u!name"; + static constexpr const char* typeString = "user defined type"; +}; + +class Type { + public: + /// Deserialize substrait raw type string into Substrait extension type. + /// @param rawType - substrait extension raw string type + static std::shared_ptr decode(const std::string& rawType); + + virtual std::string signature() const = 0; + + /// Test type is a Wildcard type or not. + virtual bool isWildcard() const { + return false; + } + + virtual TypeKind kind() const = 0; + + virtual std::string typeString() const = 0; + + /// Test whether two types are same as each other + virtual bool isSameAs(const std::shared_ptr& other) const { + return kind() == other->kind(); + } + + private: + /// A map store the raw type string and corresponding Substrait Type + static const std::unordered_map>& + scalarTypeMapping(); +}; + +using TypePtr = std::shared_ptr; + +/// Types used in function argument declarations. +template +class TypeBase : public Type { + public: + std::string signature() const override { + return TypeTraits::signature; + } + + virtual TypeKind kind() const override { + return Kind; + } + + std::string typeString() const override { + return TypeTraits::typeString; + } +}; + +template +class ScalarType : public TypeBase {}; + +/// A string literal type can present the 'any1'. +class StringLiteralType : public Type { + public: + StringLiteralType(const std::string& value) : value_(value) {} + + const std::string& value() const { + return value_; + } + + std::string signature() const override { + return value_; + } + + std::string typeString() const override { + return value_; + } + + bool isWildcard() const override { + return value_.find("any") == 0 || value_ == "T"; + } + + bool isSameAs(const std::shared_ptr& other) const override; + + TypeKind kind() const override { + return TypeKind ::KIND_NOT_SET; + } + + private: + const std::string value_; +}; + +using StringLiteralTypePtr = std::shared_ptr; + +class DecimalType : public TypeBase { + public: + DecimalType( + const StringLiteralTypePtr& precision, + const StringLiteralTypePtr& scale) + : precision_(precision), scale_(scale) {} + + DecimalType(const std::string& precision, const std::string& scale) + : precision_(std::make_shared(precision)), + scale_(std::make_shared(scale)) {} + + DecimalType(const int precision, const int scale) + : DecimalType(std::to_string(precision), std::to_string(scale)) {} + + bool isSameAs(const std::shared_ptr& other) const override; + + std::string signature() const override; + + std::string precision() const { + return precision_->value(); + } + + std::string scale() const { + return scale_->value(); + } + + private: + StringLiteralTypePtr precision_; + StringLiteralTypePtr scale_; +}; + +class FixedBinaryType : public TypeBase { + public: + FixedBinaryType(const StringLiteralTypePtr& length) : length_(length) {} + + FixedBinaryType(const int length) + : FixedBinaryType( + std::make_shared(std::to_string(length))) {} + + bool isSameAs(const std::shared_ptr& other) const override; + + const StringLiteralTypePtr& length() const { + return length_; + } + + std::string signature() const override; + + protected: + StringLiteralTypePtr length_; +}; + +class FixedCharType : public TypeBase { + public: + FixedCharType(const StringLiteralTypePtr& length) : length_(length) {} + + FixedCharType(const int length) + : FixedCharType( + std::make_shared(std::to_string(length))) {} + + bool isSameAs(const std::shared_ptr& other) const override; + + const StringLiteralTypePtr& length() const { + return length_; + } + + std::string signature() const override; + + protected: + StringLiteralTypePtr length_; +}; + +class VarcharType : public TypeBase { + public: + VarcharType(const StringLiteralTypePtr& length) : length_(length) {} + + VarcharType(const int length) + : VarcharType( + std::make_shared(std::to_string(length))) {} + + bool isSameAs(const std::shared_ptr& other) const override; + + const StringLiteralTypePtr& length() const { + return length_; + } + + std::string signature() const override; + + protected: + StringLiteralTypePtr length_; +}; + +class ListType : public TypeBase { + public: + ListType(const TypePtr& elementType) : elementType_(elementType){}; + + const TypePtr elementType() const { + return elementType_; + } + + bool isSameAs(const std::shared_ptr& other) const override; + + std::string signature() const override; + + private: + TypePtr elementType_; +}; + +class StructType : public TypeBase { + public: + StructType(const std::vector& types) : children_(types) {} + + bool isSameAs(const std::shared_ptr& other) const override; + + std::string signature() const override; + + const std::vector& children() const { + return children_; + } + + private: + std::vector children_; +}; + +class MapType : public TypeBase { + public: + MapType(const TypePtr& keyType, const TypePtr& valueType) + : keyType_(keyType), valueType_(valueType) {} + + TypePtr keyType() const { + return keyType_; + } + + TypePtr valueType() const { + return valueType_; + } + + bool isSameAs(const std::shared_ptr& other) const override; + + std::string signature() const override; + + private: + TypePtr keyType_; + TypePtr valueType_; +}; + +class UsedDefinedType : public TypeBase { + public: + UsedDefinedType(const std::string& value) : value_(value) {} + + const std::string& value() const { + return value_; + } + + bool isSameAs(const std::shared_ptr& other) const override; + + private: + /// raw string of wildcard type. + const std::string value_; +}; + +std::shared_ptr> BOOL(); +std::shared_ptr> TINYINT(); +std::shared_ptr> SMALLINT(); +std::shared_ptr> INTEGER(); +std::shared_ptr> BIGINT(); +std::shared_ptr> FLOAT(); +std::shared_ptr> DOUBLE(); +std::shared_ptr> STRING(); +std::shared_ptr> BINARY(); +std::shared_ptr> TIMESTAMP(); +std::shared_ptr> DATE(); +std::shared_ptr> TIME(); +std::shared_ptr> INTERVAL_YEAR(); +std::shared_ptr> INTERVAL_DAY(); +std::shared_ptr> TIMESTAMP_TZ(); +std::shared_ptr> UUID(); + +} // namespace io::substrait diff --git a/src/core/tests/CMakeLists.txt b/src/core/tests/CMakeLists.txt new file mode 100644 index 00000000..76dd9afc --- /dev/null +++ b/src/core/tests/CMakeLists.txt @@ -0,0 +1,26 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +add_executable( + substrait_core_test + TypeTest.cpp + FunctionLookupTest.cpp) + +add_test( + substrait_core_test + substrait_core_test) + +target_link_libraries( + substrait_core_test + substrait-core + gtest + gtest_main) diff --git a/src/core/tests/FunctionLookupTest.cpp b/src/core/tests/FunctionLookupTest.cpp new file mode 100644 index 00000000..29376e68 --- /dev/null +++ b/src/core/tests/FunctionLookupTest.cpp @@ -0,0 +1,143 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "core/FunctionLookup.h" +#include +#include "iostream" + +using namespace io::substrait; + +class VeloxFunctionMappings : public FunctionMapping { + public: + /// scalar function names in difference between velox and Substrait. + const FunctionMap& scalaMapping() const override { + static const FunctionMap scalarMappings{ + {"plus", "add"}, + {"minus", "subtract"}, + {"mod", "modulus"}, + {"eq", "equal"}, + {"neq", "not_equal"}, + {"substr", "substring"}, + }; + return scalarMappings; + }; +}; + +class FunctionLookupTest : public ::testing::Test { + protected: + std::string getExtensionAbsolutePath() { + const std::string absolute_path = __FILE__; + auto const pos = absolute_path.find_last_of('/'); + return absolute_path.substr(0, pos) + + "/../../../third_party/substrait/extensions/"; + } + + void SetUp() override { + ExtensionPtr extension_ = Extension::load(getExtensionAbsolutePath()); + FunctionMappingPtr mappings_ = + std::make_shared(); + scalarFunctionLookup_ = + std::make_shared(extension_, mappings_); + aggregateFunctionLookup_ = + std::make_shared(extension_, mappings_); + } + + void testScalarFunctionLookup( + const FunctionSignature& inputSignature, + const std::string& outputSignature) { + const auto& functionVariant = + scalarFunctionLookup_->lookupFunction(inputSignature); + + ASSERT_TRUE(functionVariant != nullptr); + ASSERT_EQ(functionVariant->signature(), outputSignature); + } + + void testAggregateFunctionLookup( + const FunctionSignature& inputSignature, + const std::string& outputSignature) { + const auto& functionVariant = + aggregateFunctionLookup_->lookupFunction(inputSignature); + + ASSERT_TRUE(functionVariant != nullptr); + ASSERT_EQ(functionVariant->signature(), outputSignature); + } + + private: + FunctionLookupPtr scalarFunctionLookup_; + FunctionLookupPtr aggregateFunctionLookup_; +}; + +TEST_F(FunctionLookupTest, compare_function) { + testScalarFunctionLookup( + {"lt", {TINYINT(), TINYINT()}, BOOL()}, "lt:any1_any1"); + + testScalarFunctionLookup( + {"lt", {SMALLINT(), SMALLINT()}, BOOL()}, "lt:any1_any1"); + + testScalarFunctionLookup( + {"lt", {INTEGER(), INTEGER()}, BOOL()}, "lt:any1_any1"); + + testScalarFunctionLookup( + {"lt", {BIGINT(), BIGINT()}, BOOL()}, "lt:any1_any1"); + + testScalarFunctionLookup({"lt", {FLOAT(), FLOAT()}, BOOL()}, "lt:any1_any1"); + + testScalarFunctionLookup( + {"lt", {DOUBLE(), DOUBLE()}, BOOL()}, "lt:any1_any1"); + testScalarFunctionLookup( + {"between", {TINYINT(), TINYINT(), TINYINT()}, BOOL()}, + "between:any1_any1_any1"); +} + +TEST_F(FunctionLookupTest, arithmetic_function) { + testScalarFunctionLookup( + {"add", {TINYINT(), TINYINT()}, TINYINT()}, "add:opt_i8_i8"); + + testScalarFunctionLookup( + {"plus", {TINYINT(), TINYINT()}, TINYINT()}, "add:opt_i8_i8"); + testScalarFunctionLookup( + {"divide", + { + FLOAT(), + FLOAT(), + }, + FLOAT()}, + "divide:opt_opt_opt_fp32_fp32"); +} + +TEST_F(FunctionLookupTest, aggregate) { + // for intermediate type + testAggregateFunctionLookup( + {"avg", {Type::decode("struct")}, FLOAT()}, "avg:opt_fp32"); +} + +TEST_F(FunctionLookupTest, logical) { + testScalarFunctionLookup({"and", {BOOL(), BOOL()}, BOOL()}, "and:bool"); + testScalarFunctionLookup({"or", {BOOL(), BOOL()}, BOOL()}, "or:bool"); + testScalarFunctionLookup({"not", {BOOL()}, BOOL()}, "not:bool"); + testScalarFunctionLookup({"xor", {BOOL(), BOOL()}, BOOL()}, "xor:bool_bool"); +} + +TEST_F(FunctionLookupTest, string_function) { + testScalarFunctionLookup( + {"like", {STRING(), STRING()}, BOOL()}, "like:opt_str_str"); + testScalarFunctionLookup( + {"like", + {Type::decode("varchar"), Type::decode("varchar")}, + BOOL()}, + "like:opt_vchar_vchar"); + testScalarFunctionLookup( + {"substr", {STRING(), INTEGER(), INTEGER()}, STRING()}, + "substring:str_i32_i32"); +} diff --git a/src/core/tests/TypeTest.cpp b/src/core/tests/TypeTest.cpp new file mode 100644 index 00000000..3c5c3106 --- /dev/null +++ b/src/core/tests/TypeTest.cpp @@ -0,0 +1,130 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../Type.h" +#include + +using namespace io::substrait; + +class TypeTest : public ::testing::Test { + protected: + template + void testDecode(const std::string& rawType, const std::string& signature) { + const auto& type = Type::decode(rawType); + ASSERT_TRUE(type->kind() == kind); + ASSERT_EQ(type->signature(), signature); + } + + template + void testDecode( + const std::string& rawType, + const std::function&)>& + typeCallBack) { + const auto& type = Type::decode(rawType); + if (typeCallBack) { + typeCallBack(std::dynamic_pointer_cast(type)); + } + } +}; + +TEST_F(TypeTest, decodeTest) { + testDecode("i32?", "i32"); + testDecode("BOOLEAN", "bool"); + testDecode("boolean", "bool"); + testDecode("i8", "i8"); + testDecode("i16", "i16"); + testDecode("i32", "i32"); + testDecode("i64", "i64"); + testDecode("fp32", "fp32"); + testDecode("fp64", "fp64"); + testDecode("binary", "vbin"); + testDecode("timestamp", "ts"); + testDecode("date", "date"); + testDecode("time", "time"); + testDecode("interval_day", "iday"); + testDecode("interval_year", "iyear"); + testDecode("timestamp_tz", "tstz"); + testDecode("uuid", "uuid"); + + testDecode( + "fixedchar", [](const std::shared_ptr& typePtr) { + ASSERT_EQ(typePtr->length()->value(), "L1"); + ASSERT_EQ(typePtr->signature(), "fchar"); + }); + + testDecode( + "fixedbinary", + [](const std::shared_ptr& typePtr) { + ASSERT_EQ(typePtr->length()->value(), "L1"); + ASSERT_EQ(typePtr->signature(), "fbin"); + }); + + testDecode( + "varchar", [](const std::shared_ptr& typePtr) { + ASSERT_EQ(typePtr->signature(), "vchar"); + ASSERT_EQ(typePtr->length()->value(), "L1"); + }); + + testDecode( + "decimal", [](const std::shared_ptr& typePtr) { + ASSERT_EQ(typePtr->signature(), "dec"); + ASSERT_EQ(typePtr->precision(), "P"); + ASSERT_EQ(typePtr->scale(), "S"); + }); + + testDecode( + "struct", + [](const std::shared_ptr& typePtr) { + ASSERT_EQ(typePtr->signature(), "struct"); + }); + + testDecode( + "struct>", + [](const std::shared_ptr& typePtr) { + ASSERT_EQ(typePtr->signature(), "struct>"); + }); + + testDecode( + "list", [](const std::shared_ptr& typePtr) { + ASSERT_EQ(typePtr->signature(), "list"); + }); + + testDecode( + "map", [](const std::shared_ptr& typePtr) { + ASSERT_EQ(typePtr->signature(), "map"); + }); + + testDecode( + "any1", [](const std::shared_ptr& typePtr) { + ASSERT_EQ(typePtr->signature(), "any1"); + ASSERT_TRUE(typePtr->isWildcard()); + }); + + testDecode( + "any", [](const std::shared_ptr& typePtr) { + ASSERT_EQ(typePtr->signature(), "any"); + ASSERT_TRUE(typePtr->isWildcard()); + }); + + testDecode( + "T", [](const std::shared_ptr& typePtr) { + ASSERT_EQ(typePtr->signature(), "T"); + ASSERT_TRUE(typePtr->isWildcard()); + }); + + testDecode( + "unknown", [](const std::shared_ptr& typePtr) { + ASSERT_EQ(typePtr->signature(), "u!name"); + }); +} diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt new file mode 100644 index 00000000..59405e6a --- /dev/null +++ b/third_party/CMakeLists.txt @@ -0,0 +1,22 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +add_subdirectory(fmt) +include_directories(fmt/include) + +add_subdirectory(googletest) + +set(YAML_CPP_BUILD_TESTS OFF CACHE BOOL "Enable testing") +include_directories(yaml-cpp/include) +add_subdirectory(yaml-cpp) diff --git a/third_party/fmt b/third_party/fmt new file mode 160000 index 00000000..80f8d344 --- /dev/null +++ b/third_party/fmt @@ -0,0 +1 @@ +Subproject commit 80f8d34427d40ec5e7ce3b10ededc46bd4bd5759 diff --git a/third_party/googletest b/third_party/googletest new file mode 160000 index 00000000..3026483a --- /dev/null +++ b/third_party/googletest @@ -0,0 +1 @@ +Subproject commit 3026483ae575e2de942db5e760cf95e973308dd5 diff --git a/third_party/substrait b/third_party/substrait new file mode 160000 index 00000000..f3f6bdc9 --- /dev/null +++ b/third_party/substrait @@ -0,0 +1 @@ +Subproject commit f3f6bdc947e689e800279666ff33f118e42d2146 diff --git a/third_party/yaml-cpp b/third_party/yaml-cpp new file mode 160000 index 00000000..c90c08cc --- /dev/null +++ b/third_party/yaml-cpp @@ -0,0 +1 @@ +Subproject commit c90c08ccc9a08abcca609064fb9a856dfdbbb7b4