diff --git a/cpp/src/gandiva/CMakeLists.txt b/cpp/src/gandiva/CMakeLists.txt
new file mode 100644
index 00000000000..61d54f7eeb6
--- /dev/null
+++ b/cpp/src/gandiva/CMakeLists.txt
@@ -0,0 +1,45 @@
+# Copyright (C) 2017-2018 Dremio Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+cmake_minimum_required(VERSION 3.10)
+
+project(gandiva)
+
+# Defines CMAKE_INSTALL_INCLUDEDIR, which the install paths below are built from.
+include(GNUInstallDirs)
+
+# LLVM/Clang is required by multiple subdirs, so fail at configure time if missing.
+find_package(LLVM REQUIRED)
+
+# Set the path where the byte-code files will be installed.
+set(GANDIVA_BC_INSTALL_DIR
+  ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_INCLUDEDIR}/gandiva)
+
+set(GANDIVA_BC_FILE_NAME irhelpers.bc)
+set(GANDIVA_BC_INSTALL_PATH ${GANDIVA_BC_INSTALL_DIR}/${GANDIVA_BC_FILE_NAME})
+set(GANDIVA_BC_OUTPUT_PATH ${CMAKE_BINARY_DIR}/${GANDIVA_BC_FILE_NAME})
+
+# Set the path where the so lib file will be installed.
+if (APPLE)
+  set(GANDIVA_HELPER_LIB_FILE_NAME libgandiva_helpers.dylib)
+else()
+  set(GANDIVA_HELPER_LIB_FILE_NAME libgandiva_helpers.so)
+endif()
+
+set(GANDIVA_HELPER_LIB_INSTALL_PATH ${GANDIVA_BC_INSTALL_DIR}/${GANDIVA_HELPER_LIB_FILE_NAME})
+set(GANDIVA_HELPER_LIB_OUTPUT_PATH ${CMAKE_BINARY_DIR}/src/codegen/${GANDIVA_HELPER_LIB_FILE_NAME})
+
+add_subdirectory(codegen)
+add_subdirectory(jni)
+add_subdirectory(precompiled)
diff --git a/cpp/src/gandiva/README.md b/cpp/src/gandiva/README.md
new file mode 100644
index 00000000000..b8faf584f6e
--- /dev/null
+++ b/cpp/src/gandiva/README.md
@@ -0,0 +1,69 @@
+# Gandiva C++
+
+## System setup
+
+Gandiva uses CMake as a build configuration system. Currently, it supports
+out-of-source builds only.
+
+Building Gandiva requires:
+
+* A C++11-enabled compiler. On Linux, gcc 4.8 and higher should be sufficient.
+* CMake
+* LLVM
+* Arrow
+* Boost
+* Protobuf
+
+On macOS, you can use [Homebrew][1]:
+
+```shell
+brew install cmake llvm boost protobuf
+```
+
+To install arrow, follow the steps in the [arrow Readme][2].
+
+## Building Gandiva
+
+Debug build :
+
+```shell
+git clone https://github.com/dremio/gandiva.git
+cd gandiva/cpp
+mkdir debug
+cd debug
+cmake ..
+make
+ctest
+```
+
+Release build :
+
+```shell
+git clone https://github.com/dremio/gandiva.git
+cd gandiva/cpp
+mkdir release
+cd release
+cmake .. -DCMAKE_BUILD_TYPE=Release
+make
+ctest
+```
+
+## Validating code style
+
+We follow the [google cpp code style][3].
To validate compliance, + +```shell +cd debug +make stylecheck +``` + +## Fixing code style + +```shell +cd debug +make stylefix +``` + +[1]: https://brew.sh/ +[2]: https://github.com/apache/arrow/tree/master/cpp +[3]: https://google.github.io/styleguide/cppguide.html diff --git a/cpp/src/gandiva/codegen/CMakeLists.txt b/cpp/src/gandiva/codegen/CMakeLists.txt new file mode 100644 index 00000000000..0bc760acad4 --- /dev/null +++ b/cpp/src/gandiva/codegen/CMakeLists.txt @@ -0,0 +1,125 @@ +# Copyright (C) 2017-2018 Dremio Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +project(gandiva) + +# Find arrow +find_package(ARROW) + +find_package(Boost COMPONENTS system regex filesystem REQUIRED) + +set(BC_FILE_PATH_CC "${CMAKE_CURRENT_BINARY_DIR}/bc_file_path.cc") +configure_file(bc_file_path.cc.in ${BC_FILE_PATH_CC}) + +# helper files that are shared between libgandiva and libgandiva_helpers +set(SHARED_HELPER_FILES + like_holder.cc + regex_util.cc) + +set(SRC_FILES annotator.cc + bitmap_accumulator.cc + configuration.cc + engine.cc + expr_decomposer.cc + expr_validator.cc + expression.cc + expression_registry.cc + filter.cc + function_registry.cc + function_signature.cc + llvm_generator.cc + llvm_types.cc + projector.cc + selection_vector.cc + tree_expr_builder.cc + ${SHARED_HELPER_FILES} + ${BC_FILE_PATH_CC}) + +add_library(gandiva_obj_lib OBJECT ${SRC_FILES}) + +# set PIC so that object library can be included in shared libs. 
+set_target_properties(gandiva_obj_lib PROPERTIES POSITION_INDEPENDENT_CODE ON)
+
+# For users of gandiva library (including integ tests), include-dir is :
+#   /usr/**/include dir after install,
+#   cpp/include during build
+# For building gandiva library itself, include-dir (in addition to above) is :
+#   cpp/src
+target_include_directories(gandiva_obj_lib
+  PUBLIC
+    $
+    $
+  PRIVATE
+    ${CMAKE_SOURCE_DIR}/src
+    $
+    $
+    $
+    $
+)
+
+build_gandiva_lib("shared")
+
+build_gandiva_lib("static")
+
+# install for gandiva
+include(GNUInstallDirs)
+
+# install libgandiva
+install(
+  TARGETS gandiva_shared gandiva_static
+  DESTINATION ${CMAKE_INSTALL_LIBDIR}
+)
+
+# install the header files.
+install(
+  DIRECTORY ${CMAKE_SOURCE_DIR}/include/gandiva
+  DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
+)
+
+# Pre-compiled .so library for function helpers.
+add_library(gandiva_helpers SHARED
+  ${SHARED_HELPER_FILES}
+  function_holder_stubs.cc)
+
+target_compile_definitions(gandiva_helpers
+  PRIVATE GDV_HELPERS
+)
+
+target_include_directories(gandiva_helpers
+  PRIVATE
+    ${CMAKE_SOURCE_DIR}/include
+    ${CMAKE_SOURCE_DIR}/src
+    $
+)
+
+target_link_libraries(gandiva_helpers PRIVATE Boost::boost)
+if (NOT APPLE)
+  target_link_libraries(gandiva_helpers PRIVATE -static-libstdc++ -static-libgcc)
+endif()
+
+# args: test-file followed by the source files the test is compiled with
+add_gandiva_unit_test(bitmap_accumulator_test.cc bitmap_accumulator.cc)
+add_gandiva_unit_test(engine_llvm_test.cc engine.cc llvm_types.cc configuration.cc ${BC_FILE_PATH_CC})
+add_gandiva_unit_test(function_signature_test.cc function_signature.cc)
+add_gandiva_unit_test(function_registry_test.cc function_registry.cc function_signature.cc)
+add_gandiva_unit_test(llvm_types_test.cc llvm_types.cc)
+add_gandiva_unit_test(llvm_generator_test.cc llvm_generator.cc engine.cc llvm_types.cc expr_decomposer.cc function_registry.cc annotator.cc bitmap_accumulator.cc configuration.cc function_signature.cc like_holder.cc regex_util.cc ${BC_FILE_PATH_CC})
+add_gandiva_unit_test(annotator_test.cc annotator.cc function_signature.cc) +add_gandiva_unit_test(tree_expr_test.cc tree_expr_builder.cc expr_decomposer.cc annotator.cc function_registry.cc function_signature.cc like_holder.cc regex_util.cc) +add_gandiva_unit_test(expr_decomposer_test.cc expr_decomposer.cc tree_expr_builder.cc annotator.cc function_registry.cc function_signature.cc like_holder.cc regex_util.cc) +add_gandiva_unit_test(status_test.cc) +add_gandiva_unit_test(expression_registry_test.cc llvm_types.cc expression_registry.cc function_signature.cc function_registry.cc) +add_gandiva_unit_test(selection_vector_test.cc selection_vector.cc) +add_gandiva_unit_test(lru_cache_test.cc) +add_gandiva_unit_test(like_holder_test.cc like_holder.cc regex_util.cc) diff --git a/cpp/src/gandiva/codegen/annotator.cc b/cpp/src/gandiva/codegen/annotator.cc new file mode 100644 index 00000000000..afd0e269b51 --- /dev/null +++ b/cpp/src/gandiva/codegen/annotator.cc @@ -0,0 +1,105 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "codegen/annotator.h" + +#include +#include + +#include "codegen/field_descriptor.h" + +namespace gandiva { + +FieldDescriptorPtr Annotator::CheckAndAddInputFieldDescriptor(FieldPtr field) { + // If the field is already in the map, return the entry. 
+ auto found = in_name_to_desc_.find(field->name()); + if (found != in_name_to_desc_.end()) { + return found->second; + } + + auto desc = MakeDesc(field); + in_name_to_desc_[field->name()] = desc; + return desc; +} + +FieldDescriptorPtr Annotator::AddOutputFieldDescriptor(FieldPtr field) { + auto desc = MakeDesc(field); + out_descs_.push_back(desc); + return desc; +} + +FieldDescriptorPtr Annotator::MakeDesc(FieldPtr field) { + // TODO: + // - validity is optional + int data_idx = buffer_count_++; + int validity_idx = buffer_count_++; + int offsets_idx = FieldDescriptor::kInvalidIdx; + if (arrow::is_binary_like(field->type()->id())) { + offsets_idx = buffer_count_++; + } + return std::make_shared(field, data_idx, validity_idx, offsets_idx); +} + +void Annotator::PrepareBuffersForField(const FieldDescriptor &desc, + const arrow::ArrayData &array_data, + EvalBatch *eval_batch) { + int buffer_idx = 0; + + // TODO: + // - validity is optional + + uint8_t *validity_buf = const_cast(array_data.buffers[buffer_idx]->data()); + eval_batch->SetBuffer(desc.validity_idx(), validity_buf); + ++buffer_idx; + + if (desc.HasOffsetsIdx()) { + uint8_t *offsets_buf = const_cast(array_data.buffers[buffer_idx]->data()); + eval_batch->SetBuffer(desc.offsets_idx(), offsets_buf); + ++buffer_idx; + } + + uint8_t *data_buf = const_cast(array_data.buffers[buffer_idx]->data()); + eval_batch->SetBuffer(desc.data_idx(), data_buf); + ++buffer_idx; +} + +EvalBatchPtr Annotator::PrepareEvalBatch(const arrow::RecordBatch &record_batch, + const ArrayDataVector &out_vector) { + EvalBatchPtr eval_batch = std::make_shared( + record_batch.num_rows(), buffer_count_, local_bitmap_count_); + + // Fill in the entries for the input fields. + for (int i = 0; i < record_batch.num_columns(); ++i) { + const std::string &name = record_batch.column_name(i); + auto found = in_name_to_desc_.find(name); + if (found == in_name_to_desc_.end()) { + // skip columns not involved in the expression. 
+ continue; + } + + PrepareBuffersForField(*(found->second), *(record_batch.column(i))->data(), + eval_batch.get()); + } + + // Fill in the entries for the output fields. + int idx = 0; + for (auto &arraydata : out_vector) { + const FieldDescriptorPtr &desc = out_descs_.at(idx); + PrepareBuffersForField(*desc, *arraydata, eval_batch.get()); + ++idx; + } + return eval_batch; +} + +} // namespace gandiva diff --git a/cpp/src/gandiva/codegen/annotator.h b/cpp/src/gandiva/codegen/annotator.h new file mode 100644 index 00000000000..7a363c98410 --- /dev/null +++ b/cpp/src/gandiva/codegen/annotator.h @@ -0,0 +1,77 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GANDIVA_EXPR_ANNOTATOR_H +#define GANDIVA_EXPR_ANNOTATOR_H + +#include +#include +#include +#include + +#include "codegen/eval_batch.h" +#include "gandiva/arrow.h" +#include "gandiva/gandiva_aliases.h" +#include "gandiva/logging.h" + +namespace gandiva { + +/// \brief annotate the arrow fields in an expression, and use that +/// to convert the incoming arrow-format row batch to an EvalBatch. +class Annotator { + public: + Annotator() : buffer_count_(0), local_bitmap_count_(0) {} + + /// Add an annotated field descriptor for a field in an input schema. + /// If the field is already annotated, returns that instead. + FieldDescriptorPtr CheckAndAddInputFieldDescriptor(FieldPtr field); + + /// Add an annotated field descriptor for an output field. 
+ FieldDescriptorPtr AddOutputFieldDescriptor(FieldPtr field); + + /// Add a local bitmap (for saving validity bits of an intermediate node). + /// Returns the index of the bitmap in the list of local bitmaps. + int AddLocalBitMap() { return local_bitmap_count_++; } + + /// Prepare an eval batch for the incoming record batch. + EvalBatchPtr PrepareEvalBatch(const arrow::RecordBatch &record_batch, + const ArrayDataVector &out_vector); + + private: + /// Annotate a field and return the descriptor. + FieldDescriptorPtr MakeDesc(FieldPtr field); + + /// Populate eval_batch by extracting the raw buffers from the arrow array, whose + /// contents are represent by the annotated descriptor 'desc'. + void PrepareBuffersForField(const FieldDescriptor &desc, + const arrow::ArrayData &array_data, EvalBatch *eval_batch); + + /// The list of input/output buffers (includes bitmap buffers, value buffers and + /// offset buffers). + int buffer_count_; + + /// The number of local bitmaps. These are used to save the validity bits for + /// intermediate nodes in the expression tree. + int local_bitmap_count_; + + /// map between field name and annotated input field descriptor. + std::unordered_map in_name_to_desc_; + + /// vector of annotated output field descriptors. + std::vector out_descs_; +}; + +} // namespace gandiva + +#endif // GANDIVA_EXPR_ANNOTATOR_H diff --git a/cpp/src/gandiva/codegen/annotator_test.cc b/cpp/src/gandiva/codegen/annotator_test.cc new file mode 100644 index 00000000000..b26807e9c5b --- /dev/null +++ b/cpp/src/gandiva/codegen/annotator_test.cc @@ -0,0 +1,99 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "codegen/annotator.h" + +#include + +#include +#include +#include "codegen/field_descriptor.h" + +namespace gandiva { + +class TestAnnotator : public ::testing::Test { + protected: + ArrayPtr MakeInt32Array(int length); +}; + +ArrayPtr TestAnnotator::MakeInt32Array(int length) { + arrow::Status status; + + std::shared_ptr validity; + status = + arrow::AllocateBuffer(arrow::default_memory_pool(), (length + 63) / 8, &validity); + DCHECK_EQ(status.ok(), true); + + std::shared_ptr value; + status = AllocateBuffer(arrow::default_memory_pool(), length * sizeof(int32_t), &value); + DCHECK_EQ(status.ok(), true); + + auto array_data = arrow::ArrayData::Make(arrow::int32(), length, {validity, value}); + return arrow::MakeArray(array_data); +} + +TEST_F(TestAnnotator, TestAdd) { + Annotator annotator; + + auto field_a = arrow::field("a", arrow::int32()); + auto field_b = arrow::field("b", arrow::int32()); + auto in_schema = arrow::schema({field_a, field_b}); + auto field_sum = arrow::field("sum", arrow::int32()); + + FieldDescriptorPtr desc_a = annotator.CheckAndAddInputFieldDescriptor(field_a); + EXPECT_EQ(desc_a->field(), field_a); + EXPECT_EQ(desc_a->data_idx(), 0); + EXPECT_EQ(desc_a->validity_idx(), 1); + + // duplicate add shouldn't cause a new descriptor. 
+ FieldDescriptorPtr dup = annotator.CheckAndAddInputFieldDescriptor(field_a); + EXPECT_EQ(dup, desc_a); + EXPECT_EQ(dup->validity_idx(), desc_a->validity_idx()); + + FieldDescriptorPtr desc_b = annotator.CheckAndAddInputFieldDescriptor(field_b); + EXPECT_EQ(desc_b->field(), field_b); + EXPECT_EQ(desc_b->data_idx(), 2); + EXPECT_EQ(desc_b->validity_idx(), 3); + + FieldDescriptorPtr desc_sum = annotator.AddOutputFieldDescriptor(field_sum); + EXPECT_EQ(desc_sum->field(), field_sum); + EXPECT_EQ(desc_sum->data_idx(), 4); + EXPECT_EQ(desc_sum->validity_idx(), 5); + + // prepare record batch + int num_records = 100; + auto arrow_v0 = MakeInt32Array(num_records); + auto arrow_v1 = MakeInt32Array(num_records); + + // prepare input record batch + auto record_batch = + arrow::RecordBatch::Make(in_schema, num_records, {arrow_v0, arrow_v1}); + + auto arrow_sum = MakeInt32Array(num_records); + EvalBatchPtr batch = annotator.PrepareEvalBatch(*record_batch, {arrow_sum->data()}); + EXPECT_EQ(batch->GetNumBuffers(), 6); + + auto buffers = batch->GetBufferArray(); + EXPECT_EQ(buffers[desc_a->validity_idx()], arrow_v0->data()->buffers.at(0)->data()); + EXPECT_EQ(buffers[desc_a->data_idx()], arrow_v0->data()->buffers.at(1)->data()); + EXPECT_EQ(buffers[desc_b->validity_idx()], arrow_v1->data()->buffers.at(0)->data()); + EXPECT_EQ(buffers[desc_b->data_idx()], arrow_v1->data()->buffers.at(1)->data()); + EXPECT_EQ(buffers[desc_sum->validity_idx()], arrow_sum->data()->buffers.at(0)->data()); + EXPECT_EQ(buffers[desc_sum->data_idx()], arrow_sum->data()->buffers.at(1)->data()); + + auto bitmaps = batch->GetLocalBitMapArray(); + EXPECT_EQ(bitmaps, nullptr); +} + +} // namespace gandiva diff --git a/cpp/src/gandiva/codegen/arrow.h b/cpp/src/gandiva/codegen/arrow.h new file mode 100644 index 00000000000..4a88970c994 --- /dev/null +++ b/cpp/src/gandiva/codegen/arrow.h @@ -0,0 +1,45 @@ +/* + * Copyright (C) 2017-2018 Dremio Corporation + * + * Licensed under the Apache License, Version 2.0 (the 
"License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef GANDIVA_EXPR_ARROW_H +#define GANDIVA_EXPR_ARROW_H + +#include +#include + +#include +#include +#include + +namespace gandiva { + +using ArrayPtr = std::shared_ptr; + +using DataTypePtr = std::shared_ptr; +using DataTypeVector = std::vector; + +using FieldPtr = std::shared_ptr; +using FieldVector = std::vector; + +using RecordBatchPtr = std::shared_ptr; + +using SchemaPtr = std::shared_ptr; + +using ArrayDataPtr = std::shared_ptr; +using ArrayDataVector = std::vector; + +} // namespace gandiva + +#endif // GANDIVA_EXPR_ARROW_H diff --git a/cpp/src/gandiva/codegen/bc_file_path.cc.in b/cpp/src/gandiva/codegen/bc_file_path.cc.in new file mode 100644 index 00000000000..bcf6fd9771f --- /dev/null +++ b/cpp/src/gandiva/codegen/bc_file_path.cc.in @@ -0,0 +1,23 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +namespace gandiva { + +// Path to the byte-code file. 
+extern const char kByteCodeFilePath[] = "${GANDIVA_BC_OUTPUT_PATH}"; + +// Path to the pre-compiled solib file. +extern const char kHelperLibFilePath[] = "${GANDIVA_HELPER_LIB_OUTPUT_PATH}"; + +} // namespace gandiva diff --git a/cpp/src/gandiva/codegen/bitmap_accumulator.cc b/cpp/src/gandiva/codegen/bitmap_accumulator.cc new file mode 100644 index 00000000000..b4e7a496f9e --- /dev/null +++ b/cpp/src/gandiva/codegen/bitmap_accumulator.cc @@ -0,0 +1,80 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "codegen/bitmap_accumulator.h" + +#include + +namespace gandiva { + +void BitMapAccumulator::ComputeResult(uint8_t *dst_bitmap) { + int num_records = eval_batch_.num_records(); + + if (all_invalid_) { + // set all bits to 0. + memset(dst_bitmap, 0, arrow::BitUtil::BytesForBits(num_records)); + } else { + IntersectBitMaps(dst_bitmap, src_maps_, num_records); + } +} + +/// Compute the intersection of multiple bitmaps. +void BitMapAccumulator::IntersectBitMaps(uint8_t *dst_map, + const std::vector &src_maps, + int num_records) { + uint64_t *dst_map64 = reinterpret_cast(dst_map); + int num_words = (num_records + 63) / 64; // aligned to 8-byte. + int num_bytes = num_words * 8; + int nmaps = src_maps.size(); + + switch (nmaps) { + case 0: { + // no src_maps_ bitmap. simply set all bits + memset(dst_map, 0xff, num_bytes); + break; + } + + case 1: { + // one src_maps_ bitmap. 
copy to dst_map + memcpy(dst_map, src_maps[0], num_bytes); + break; + } + + case 2: { + // two src_maps bitmaps. do 64-bit ANDs + uint64_t *src_maps0_64 = reinterpret_cast(src_maps[0]); + uint64_t *src_maps1_64 = reinterpret_cast(src_maps[1]); + for (int i = 0; i < num_words; ++i) { + dst_map64[i] = src_maps0_64[i] & src_maps1_64[i]; + } + break; + } + + default: { + /* > 2 src_maps bitmaps. do 64-bit ANDs */ + uint64_t *src_maps0_64 = reinterpret_cast(src_maps[0]); + memcpy(dst_map64, src_maps0_64, num_bytes); + for (int m = 1; m < nmaps; ++m) { + for (int i = 0; i < num_words; ++i) { + uint64_t *src_mapsm_64 = reinterpret_cast(src_maps[m]); + dst_map64[i] &= src_mapsm_64[i]; + } + } + + break; + } + } +} + +} // namespace gandiva diff --git a/cpp/src/gandiva/codegen/bitmap_accumulator.h b/cpp/src/gandiva/codegen/bitmap_accumulator.h new file mode 100644 index 00000000000..94c31092d2e --- /dev/null +++ b/cpp/src/gandiva/codegen/bitmap_accumulator.h @@ -0,0 +1,70 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GANDIVA_BITMAP_ACCUMULATOR_H +#define GANDIVA_BITMAP_ACCUMULATOR_H + +#include + +#include "codegen/dex.h" +#include "codegen/dex_visitor.h" +#include "codegen/eval_batch.h" + +namespace gandiva { + +/// \brief Extract bitmap buffer from either the input/buffer vectors or the +/// local validity bitmap, and accumultes them to do the final computation. 
+class BitMapAccumulator : public DexDefaultVisitor {
+ public:
+  explicit BitMapAccumulator(const EvalBatch &eval_batch)
+      : eval_batch_(eval_batch), all_invalid_(false) {}
+
+  void Visit(const VectorReadValidityDex &dex) {
+    int idx = dex.ValidityIdx();
+    auto bitmap = eval_batch_.GetBuffer(idx);
+    src_maps_.push_back(bitmap);
+  }
+
+  void Visit(const LocalBitMapValidityDex &dex) {
+    int idx = dex.local_bitmap_idx();
+    auto bitmap = eval_batch_.GetLocalBitMap(idx);
+    src_maps_.push_back(bitmap);
+  }
+
+  void Visit(const TrueDex &dex) {
+    // bitwise-and with 1 is always 1. so, ignore.
+  }
+
+  void Visit(const FalseDex &dex) {
+    // The final result is "all 0s".
+    all_invalid_ = true;
+  }
+
+  /// Compute dst_bitmap from the accumulated bitmaps (all 0s if a FalseDex was seen).
+  void ComputeResult(uint8_t *dst_bitmap);
+
+  /// Compute the intersection of the bitmaps in src_maps and save the result in
+  /// dst_map.
+  static void IntersectBitMaps(uint8_t *dst_map, const std::vector &src_maps,
+                               int num_records);
+
+ private:
+  const EvalBatch &eval_batch_;
+  std::vector src_maps_;
+  bool all_invalid_;
+};
+
+}  // namespace gandiva
+
+#endif  // GANDIVA_BITMAP_ACCUMULATOR_H
diff --git a/cpp/src/gandiva/codegen/bitmap_accumulator_test.cc b/cpp/src/gandiva/codegen/bitmap_accumulator_test.cc
new file mode 100644
index 00000000000..1d81dcdb89e
--- /dev/null
+++ b/cpp/src/gandiva/codegen/bitmap_accumulator_test.cc
@@ -0,0 +1,92 @@
+// Copyright (C) 2017-2018 Dremio Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "codegen/bitmap_accumulator.h"
+
+#include
+#include
+
+#include
+#include "codegen/dex.h"
+
+namespace gandiva {
+
+/// Fixture providing pseudo-random bitmap generation and a byte-wise reference
+/// implementation to check BitMapAccumulator::IntersectBitMaps against.
+class TestBitMapAccumulator : public ::testing::Test {
+ protected:
+  /// Fill the first nrecords / 8 bytes of bmap with pseudo-random values.
+  void FillBitMap(uint8_t *bmap, int nrecords);
+
+  /// Reference implementation: AND the first nrecords / 8 bytes of every bitmap
+  /// in srcs into dst, one byte at a time (all ones if srcs is empty).
+  void ByteWiseIntersectBitMaps(uint8_t *dst, const std::vector &srcs,
+                                int nrecords);
+
+  /// rand_r() state. The fixed seed keeps runs reproducible; it is shared across
+  /// FillBitMap calls so that every bitmap gets different contents.
+  unsigned int seed_ = 42;
+};
+
+void TestBitMapAccumulator::FillBitMap(uint8_t *bmap, int nrecords) {
+  int nbytes = nrecords / 8;
+
+  for (int i = 0; i < nbytes; ++i) {
+    // Use the value returned by rand_r (the seed is only its internal state) and
+    // reduce it modulo 256 so that 0xff is a possible byte value too.
+    bmap[i] = static_cast<uint8_t>(rand_r(&seed_) % (UINT8_MAX + 1));
+  }
+}
+
+void TestBitMapAccumulator::ByteWiseIntersectBitMaps(uint8_t *dst,
+                                                     const std::vector &srcs,
+                                                     int nrecords) {
+  int nbytes = nrecords / 8;
+  for (int i = 0; i < nbytes; ++i) {
+    dst[i] = 0xff;
+    for (uint32_t j = 0; j < srcs.size(); ++j) {
+      dst[i] &= srcs[j][i];
+    }
+  }
+}
+
+TEST_F(TestBitMapAccumulator, TestIntersectBitMaps) {
+  const int length = 128;
+  const int nrecords = length * 8;
+  const int max_bitmaps = 4;
+  // IntersectBitMaps reads and writes the bitmaps as 64-bit words, so keep them
+  // 8-byte aligned.
+  alignas(8) uint8_t src_bitmaps[max_bitmaps][length];
+  alignas(8) uint8_t dst_bitmap[length];
+  uint8_t expected_bitmap[length];
+
+  for (int i = 0; i < max_bitmaps; i++) {
+    FillBitMap(src_bitmaps[i], nrecords);
+  }
+
+  // Intersect the first 0, 1, ..., max_bitmaps bitmaps (inclusive, so that the last
+  // bitmap is exercised as well).
+  for (int i = 0; i <= max_bitmaps; i++) {
+    std::vector src_bitmap_ptrs;
+    for (int j = 0; j < i; ++j) {
+      src_bitmap_ptrs.push_back(src_bitmaps[j]);
+    }
+
+    BitMapAccumulator::IntersectBitMaps(dst_bitmap, src_bitmap_ptrs, nrecords);
+    ByteWiseIntersectBitMaps(expected_bitmap, src_bitmap_ptrs, nrecords);
+    EXPECT_EQ(memcmp(dst_bitmap, expected_bitmap, length), 0);
+  }
+}
+
+}  // namespace gandiva
diff --git a/cpp/src/gandiva/codegen/cache.h b/cpp/src/gandiva/codegen/cache.h
new file mode 100644
index 00000000000..2aa8354a5dd
--- /dev/null
+++ b/cpp/src/gandiva/codegen/cache.h
@@ -0,0 +1,48 @@
+// Copyright (C) 2017-2018 Dremio Corporation
+//
+// Licensed
under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GANDIVA_MODULE_CACHE_H
+#define GANDIVA_MODULE_CACHE_H
+
+#include
+
+#include "codegen/lru_cache.h"
+
+namespace gandiva {
+
+/// \brief Wrapper around LruCache in which every access is serialized by a mutex.
+template
+class Cache {
+ public:
+  Cache(size_t capacity = CACHE_SIZE) : cache_(capacity) {}
+  /// Return the value cached for cache_key, or nullptr if there is none.
+  ValueType GetModule(KeyType cache_key) {
+    std::lock_guard<std::mutex> guard(mtx_);  // released on every exit path
+    auto result = cache_.get(cache_key);
+    return result != boost::none ? result.value() : nullptr;
+  }
+
+  /// Store module in the cache under cache_key.
+  void PutModule(KeyType cache_key, ValueType module) {
+    std::lock_guard<std::mutex> guard(mtx_);  // released even if insert throws
+    cache_.insert(cache_key, module);
+  }
+
+ private:
+  LruCache cache_;
+  static const int CACHE_SIZE = 100;
+  std::mutex mtx_;
+};
+}  // namespace gandiva
+#endif  // GANDIVA_MODULE_CACHE_H
diff --git a/cpp/src/gandiva/codegen/compiled_expr.h b/cpp/src/gandiva/codegen/compiled_expr.h
new file mode 100644
index 00000000000..6148c8950ce
--- /dev/null
+++ b/cpp/src/gandiva/codegen/compiled_expr.h
@@ -0,0 +1,61 @@
+// Copyright (C) 2017-2018 Dremio Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GANDIVA_COMPILED_EXPR_H
+#define GANDIVA_COMPILED_EXPR_H
+
+#include
+#include "codegen/value_validity_pair.h"
+
+namespace gandiva {
+
+using EvalFunc = int (*)(uint8_t **buffers, uint8_t **local_bitmaps, int record_count);
+
+/// \brief Tracks the compiled state for one expression.
+class CompiledExpr {
+ public:
+  CompiledExpr(ValueValidityPairPtr value_validity, FieldDescriptorPtr output,
+               llvm::Function *ir_function)
+      : value_validity_(value_validity),
+        output_(output),
+        ir_function_(ir_function),
+        jit_function_(nullptr) {}
+
+  ValueValidityPairPtr value_validity() const { return value_validity_; }
+
+  FieldDescriptorPtr output() const { return output_; }
+
+  llvm::Function *ir_function() const { return ir_function_; }
+
+  EvalFunc jit_function() const { return jit_function_; }
+
+  void set_jit_function(EvalFunc jit_function) { jit_function_ = jit_function; }
+
+ private:
+  // value & validities for the expression tree (root)
+  ValueValidityPairPtr value_validity_;
+
+  // output field
+  FieldDescriptorPtr output_;
+
+  // IR function in the generated code
+  llvm::Function *ir_function_;
+
+  // JIT function in the generated code; nullptr until set_jit_function() is called
+  EvalFunc jit_function_;
+};
+
+}  // namespace gandiva
+
+#endif  // GANDIVA_COMPILED_EXPR_H
diff --git a/cpp/src/gandiva/codegen/condition.h b/cpp/src/gandiva/codegen/condition.h
new file mode 100644
index 00000000000..2c5fd51e74b
--- /dev/null
+++ b/cpp/src/gandiva/codegen/condition.h
@@ -0,0 +1,36 @@
+/*
+ * Copyright (C) 2017-2018 Dremio Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef GANDIVA_CONDITION_H
+#define GANDIVA_CONDITION_H
+
+#include "gandiva/arrow.h"
+#include "gandiva/expression.h"
+#include "gandiva/gandiva_aliases.h"
+
+namespace gandiva {
+
+/// \brief A condition expression.
+class Condition : public Expression {
+ public:
+  Condition(const NodePtr root)
+      : Expression(root, std::make_shared("cond", arrow::boolean())) {}
+
+  virtual ~Condition() = default;
+};
+
+}  // namespace gandiva
+
+#endif  // GANDIVA_CONDITION_H
diff --git a/cpp/src/gandiva/codegen/configuration.cc b/cpp/src/gandiva/codegen/configuration.cc
new file mode 100644
index 00000000000..8bf56f8bddd
--- /dev/null
+++ b/cpp/src/gandiva/codegen/configuration.cc
@@ -0,0 +1,43 @@
+// Copyright (C) 2017-2018 Dremio Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "gandiva/configuration.h"
+
+#include "boost/functional/hash.hpp"
+
+namespace gandiva {
+
+const std::shared_ptr ConfigurationBuilder::default_configuration_ =
+    InitDefaultConfig();
+
+// Hash every field that operator== compares, so that equal configurations always
+// hash to the same value.
+std::size_t Configuration::Hash() const {
+  std::size_t result = 0;
+  boost::hash_combine(result, byte_code_file_path_);
+  boost::hash_combine(result, helper_lib_file_path_);
+  return result;
+}
+
+// Two configurations are equal only if both of their paths match.
+bool Configuration::operator==(const Configuration &other) const {
+  return byte_code_file_path_ == other.byte_code_file_path_ &&
+         helper_lib_file_path_ == other.helper_lib_file_path_;
+}
+
+bool Configuration::operator!=(const Configuration &other) const {
+  return !(*this == other);
+}
+
+}  // namespace gandiva
diff --git a/cpp/src/gandiva/codegen/configuration.h b/cpp/src/gandiva/codegen/configuration.h
new file mode 100644
index 00000000000..978cc6282aa
--- /dev/null
+++ b/cpp/src/gandiva/codegen/configuration.h
@@ -0,0 +1,100 @@
+/*
+ * Copyright (C) 2017-2018 Dremio Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef GANDIVA_CONFIGURATION_H
+#define GANDIVA_CONFIGURATION_H
+
+#include
+#include
+
+#include "gandiva/status.h"
+
+namespace gandiva {
+
+extern const char kByteCodeFilePath[];
+extern const char kHelperLibFilePath[];
+
+class ConfigurationBuilder;
+/// \brief runtime config for gandiva
+///
+/// It contains elements to customize gandiva execution
+/// at run time.
+class Configuration { + public: + friend class ConfigurationBuilder; + + const std::string &byte_code_file_path() const { return byte_code_file_path_; } + const std::string &helper_lib_file_path() const { return helper_lib_file_path_; } + + std::size_t Hash() const; + bool operator==(const Configuration &other) const; + bool operator!=(const Configuration &other) const; + + private: + explicit Configuration(const std::string &byte_code_file_path, + const std::string &helper_lib_file_path) + : byte_code_file_path_(byte_code_file_path), + helper_lib_file_path_(helper_lib_file_path) {} + + const std::string byte_code_file_path_; + const std::string helper_lib_file_path_; +}; + +/// \brief configuration builder for gandiva +/// +/// Provides a default configuration and convenience methods +/// to override specific values and build a custom instance +class ConfigurationBuilder { + public: + ConfigurationBuilder() + : byte_code_file_path_(kByteCodeFilePath), + helper_lib_file_path_(kHelperLibFilePath) {} + + ConfigurationBuilder &set_byte_code_file_path(const std::string &byte_code_file_path) { + byte_code_file_path_ = byte_code_file_path; + return *this; + } + + ConfigurationBuilder &set_helper_lib_file_path( + const std::string &helper_lib_file_path) { + helper_lib_file_path_ = helper_lib_file_path; + return *this; + } + + std::shared_ptr build() { + std::shared_ptr configuration( + new Configuration(byte_code_file_path_, helper_lib_file_path_)); + return configuration; + } + + static std::shared_ptr DefaultConfiguration() { + return default_configuration_; + } + + private: + std::string byte_code_file_path_; + std::string helper_lib_file_path_; + + static std::shared_ptr InitDefaultConfig() { + std::shared_ptr configuration( + new Configuration(kByteCodeFilePath, kHelperLibFilePath)); + return configuration; + } + + static const std::shared_ptr default_configuration_; +}; + +} // namespace gandiva +#endif // GANDIVA_CONFIGURATION_H diff --git 
a/cpp/src/gandiva/codegen/dex.h b/cpp/src/gandiva/codegen/dex.h new file mode 100644 index 00000000000..4484d37db16 --- /dev/null +++ b/cpp/src/gandiva/codegen/dex.h @@ -0,0 +1,274 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GANDIVA_DEX_DEX_H +#define GANDIVA_DEX_DEX_H + +#include +#include + +#include "codegen/dex_visitor.h" +#include "codegen/field_descriptor.h" +#include "codegen/func_descriptor.h" +#include "codegen/function_holder.h" +#include "codegen/literal_holder.h" +#include "codegen/native_function.h" +#include "codegen/value_validity_pair.h" +#include "gandiva/gandiva_aliases.h" + +namespace gandiva { + +/// \brief Decomposed expression : the validity and value are separated. +class Dex { + public: + /// Derived classes should simply invoke the Visit api of the visitor. + virtual void Accept(DexVisitor &visitor) = 0; + virtual ~Dex() = default; +}; + +/// Base class for other Vector related Dex. 
+class VectorReadBaseDex : public Dex { + public: + explicit VectorReadBaseDex(FieldDescriptorPtr field_desc) : field_desc_(field_desc) {} + + const std::string &FieldName() const { return field_desc_->Name(); } + + DataTypePtr FieldType() const { return field_desc_->Type(); } + + FieldPtr Field() const { return field_desc_->field(); } + + protected: + FieldDescriptorPtr field_desc_; +}; + +/// validity component of a ValueVector +class VectorReadValidityDex : public VectorReadBaseDex { + public: + explicit VectorReadValidityDex(FieldDescriptorPtr field_desc) + : VectorReadBaseDex(field_desc) {} + + int ValidityIdx() const { return field_desc_->validity_idx(); } + + void Accept(DexVisitor &visitor) override { visitor.Visit(*this); } +}; + +/// value component of a fixed-len ValueVector +class VectorReadFixedLenValueDex : public VectorReadBaseDex { + public: + explicit VectorReadFixedLenValueDex(FieldDescriptorPtr field_desc) + : VectorReadBaseDex(field_desc) {} + + int DataIdx() const { return field_desc_->data_idx(); } + + void Accept(DexVisitor &visitor) override { visitor.Visit(*this); } +}; + +/// value component of a variable-len ValueVector +class VectorReadVarLenValueDex : public VectorReadBaseDex { + public: + explicit VectorReadVarLenValueDex(FieldDescriptorPtr field_desc) + : VectorReadBaseDex(field_desc) {} + + int DataIdx() const { return field_desc_->data_idx(); } + + int OffsetsIdx() const { return field_desc_->offsets_idx(); } + + void Accept(DexVisitor &visitor) override { visitor.Visit(*this); } +}; + +/// validity based on a local bitmap. 
+class LocalBitMapValidityDex : public Dex { + public: + explicit LocalBitMapValidityDex(int local_bitmap_idx) + : local_bitmap_idx_(local_bitmap_idx) {} + + int local_bitmap_idx() const { return local_bitmap_idx_; } + + void Accept(DexVisitor &visitor) override { visitor.Visit(*this); } + + private: + int local_bitmap_idx_; +}; + +/// base function expression +class FuncDex : public Dex { + public: + FuncDex(FuncDescriptorPtr func_descriptor, const NativeFunction *native_function, + FunctionHolderPtr function_holder, const ValueValidityPairVector &args) + : func_descriptor_(func_descriptor), + native_function_(native_function), + function_holder_(function_holder), + args_(args) {} + + FuncDescriptorPtr func_descriptor() const { return func_descriptor_; } + + const NativeFunction *native_function() const { return native_function_; } + + FunctionHolderPtr function_holder() const { return function_holder_; } + + const ValueValidityPairVector &args() const { return args_; } + + private: + FuncDescriptorPtr func_descriptor_; + const NativeFunction *native_function_; + FunctionHolderPtr function_holder_; + ValueValidityPairVector args_; +}; + +/// A function expression that only deals with non-null inputs, and generates non-null +/// outputs. +class NonNullableFuncDex : public FuncDex { + public: + NonNullableFuncDex(FuncDescriptorPtr func_descriptor, + const NativeFunction *native_function, + FunctionHolderPtr function_holder, + const ValueValidityPairVector &args) + : FuncDex(func_descriptor, native_function, function_holder, args) {} + + void Accept(DexVisitor &visitor) override { visitor.Visit(*this); } +}; + +/// A function expression that deals with nullable inputs, but generates non-null +/// outputs. 
+class NullableNeverFuncDex : public FuncDex { + public: + NullableNeverFuncDex(FuncDescriptorPtr func_descriptor, + const NativeFunction *native_function, + FunctionHolderPtr function_holder, + const ValueValidityPairVector &args) + : FuncDex(func_descriptor, native_function, function_holder, args) {} + + void Accept(DexVisitor &visitor) override { visitor.Visit(*this); } +}; + +/// A function expression that deals with nullable inputs, and +/// nullable outputs. +class NullableInternalFuncDex : public FuncDex { + public: + NullableInternalFuncDex(FuncDescriptorPtr func_descriptor, + const NativeFunction *native_function, + FunctionHolderPtr function_holder, + const ValueValidityPairVector &args, int local_bitmap_idx) + : FuncDex(func_descriptor, native_function, function_holder, args), + local_bitmap_idx_(local_bitmap_idx) {} + + void Accept(DexVisitor &visitor) override { visitor.Visit(*this); } + + /// The validity of the function result is saved in this bitmap. + int local_bitmap_idx() const { return local_bitmap_idx_; } + + private: + int local_bitmap_idx_; +}; + +/// special validity type that always returns true. +class TrueDex : public Dex { + void Accept(DexVisitor &visitor) override { visitor.Visit(*this); } +}; + +/// special validity type that always returns false. +class FalseDex : public Dex { + void Accept(DexVisitor &visitor) override { visitor.Visit(*this); } +}; + +/// decomposed expression for a literal. +class LiteralDex : public Dex { + public: + LiteralDex(DataTypePtr type, const LiteralHolder &holder) + : type_(type), holder_(holder) {} + + const DataTypePtr &type() const { return type_; } + + const LiteralHolder &holder() const { return holder_; } + + void Accept(DexVisitor &visitor) override { visitor.Visit(*this); } + + private: + DataTypePtr type_; + LiteralHolder holder_; +}; + +/// decomposed if-else expression. 
+class IfDex : public Dex { + public: + IfDex(ValueValidityPairPtr condition_vv, ValueValidityPairPtr then_vv, + ValueValidityPairPtr else_vv, DataTypePtr result_type, int local_bitmap_idx, + bool is_terminal_else) + : condition_vv_(condition_vv), + then_vv_(then_vv), + else_vv_(else_vv), + result_type_(result_type), + local_bitmap_idx_(local_bitmap_idx), + is_terminal_else_(is_terminal_else) {} + + void Accept(DexVisitor &visitor) override { visitor.Visit(*this); } + + const ValueValidityPair &condition_vv() const { return *condition_vv_; } + const ValueValidityPair &then_vv() const { return *then_vv_; } + const ValueValidityPair &else_vv() const { return *else_vv_; } + + /// The validity of the result is saved in this bitmap. + int local_bitmap_idx() const { return local_bitmap_idx_; } + + /// is this a terminal else ? i.e no nested if-else underneath. + bool is_terminal_else() const { return is_terminal_else_; } + + const DataTypePtr &result_type() const { return result_type_; } + + private: + ValueValidityPairPtr condition_vv_; + ValueValidityPairPtr then_vv_; + ValueValidityPairPtr else_vv_; + DataTypePtr result_type_; + int local_bitmap_idx_; + bool is_terminal_else_; +}; + +// decomposed boolean expression. +class BooleanDex : public Dex { + public: + BooleanDex(const ValueValidityPairVector &args, int local_bitmap_idx) + : args_(args), local_bitmap_idx_(local_bitmap_idx) {} + + const ValueValidityPairVector &args() const { return args_; } + + /// The validity of the result is saved in this bitmap. 
+ int local_bitmap_idx() const { return local_bitmap_idx_; } + + private: + ValueValidityPairVector args_; + int local_bitmap_idx_; +}; + +/// Boolean-AND expression +class BooleanAndDex : public BooleanDex { + public: + BooleanAndDex(const ValueValidityPairVector &args, int local_bitmap_idx) + : BooleanDex(args, local_bitmap_idx) {} + + void Accept(DexVisitor &visitor) override { visitor.Visit(*this); } +}; + +/// Boolean-OR expression +class BooleanOrDex : public BooleanDex { + public: + BooleanOrDex(const ValueValidityPairVector &args, int local_bitmap_idx) + : BooleanDex(args, local_bitmap_idx) {} + + void Accept(DexVisitor &visitor) override { visitor.Visit(*this); } +}; + +} // namespace gandiva + +#endif // GANDIVA_DEX_DEX_H diff --git a/cpp/src/gandiva/codegen/dex_visitor.h b/cpp/src/gandiva/codegen/dex_visitor.h new file mode 100644 index 00000000000..5beee1b5444 --- /dev/null +++ b/cpp/src/gandiva/codegen/dex_visitor.h @@ -0,0 +1,76 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef GANDIVA_DEX_DEXVISITOR_H +#define GANDIVA_DEX_DEXVISITOR_H + +#include "gandiva/logging.h" + +namespace gandiva { + +class VectorReadValidityDex; +class VectorReadFixedLenValueDex; +class VectorReadVarLenValueDex; +class LocalBitMapValidityDex; +class LiteralDex; +class TrueDex; +class FalseDex; +class NonNullableFuncDex; +class NullableNeverFuncDex; +class NullableInternalFuncDex; +class IfDex; +class BooleanAndDex; +class BooleanOrDex; + +/// \brief Visitor for decomposed expression. +class DexVisitor { + public: + virtual void Visit(const VectorReadValidityDex &dex) = 0; + virtual void Visit(const VectorReadFixedLenValueDex &dex) = 0; + virtual void Visit(const VectorReadVarLenValueDex &dex) = 0; + virtual void Visit(const LocalBitMapValidityDex &dex) = 0; + virtual void Visit(const TrueDex &dex) = 0; + virtual void Visit(const FalseDex &dex) = 0; + virtual void Visit(const LiteralDex &dex) = 0; + virtual void Visit(const NonNullableFuncDex &dex) = 0; + virtual void Visit(const NullableNeverFuncDex &dex) = 0; + virtual void Visit(const NullableInternalFuncDex &dex) = 0; + virtual void Visit(const IfDex &dex) = 0; + virtual void Visit(const BooleanAndDex &dex) = 0; + virtual void Visit(const BooleanOrDex &dex) = 0; +}; + +/// Default implementation with only DCHECK(). 
+#define VISIT_DCHECK(DEX_CLASS) \ + void Visit(const DEX_CLASS &dex) override { DCHECK(0); } + +class DexDefaultVisitor : public DexVisitor { + VISIT_DCHECK(VectorReadValidityDex); + VISIT_DCHECK(VectorReadFixedLenValueDex); + VISIT_DCHECK(VectorReadVarLenValueDex); + VISIT_DCHECK(LocalBitMapValidityDex); + VISIT_DCHECK(TrueDex); + VISIT_DCHECK(FalseDex); + VISIT_DCHECK(LiteralDex); + VISIT_DCHECK(NonNullableFuncDex); + VISIT_DCHECK(NullableNeverFuncDex); + VISIT_DCHECK(NullableInternalFuncDex); + VISIT_DCHECK(IfDex); + VISIT_DCHECK(BooleanAndDex); + VISIT_DCHECK(BooleanOrDex); +}; + +} // namespace gandiva + +#endif // GANDIVA_DEX_DEXVISITOR_H diff --git a/cpp/src/gandiva/codegen/engine.cc b/cpp/src/gandiva/codegen/engine.cc new file mode 100644 index 00000000000..29748a3455f --- /dev/null +++ b/cpp/src/gandiva/codegen/engine.cc @@ -0,0 +1,226 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "codegen/engine.h" + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace gandiva { + +std::once_flag init_once_flag; + +bool Engine::init_once_done_ = false; +std::set Engine::loaded_libs_ = {}; +std::mutex Engine::mtx_; + +// One-time initializations. 
+void Engine::InitOnce() { + DCHECK_EQ(init_once_done_, false); + + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + llvm::InitializeNativeTargetAsmParser(); + llvm::InitializeNativeTargetDisassembler(); + + init_once_done_ = true; +} + +/// factory method to construct the engine. +Status Engine::Make(std::shared_ptr config, + std::unique_ptr *engine) { + std::unique_ptr engine_obj(new Engine()); + + std::call_once(init_once_flag, [&engine_obj] { engine_obj->InitOnce(); }); + engine_obj->context_.reset(new llvm::LLVMContext()); + engine_obj->ir_builder_.reset(new llvm::IRBuilder<>(*(engine_obj->context()))); + + // Create the execution engine + std::unique_ptr cg_module( + new llvm::Module("codegen", *(engine_obj->context()))); + engine_obj->module_ = cg_module.get(); + + llvm::EngineBuilder engineBuilder(std::move(cg_module)); + engineBuilder.setEngineKind(llvm::EngineKind::JIT); + engineBuilder.setOptLevel(llvm::CodeGenOpt::Aggressive); + engineBuilder.setErrorStr(&(engine_obj->llvm_error_)); + engine_obj->execution_engine_.reset(engineBuilder.create()); + if (engine_obj->execution_engine_ == NULL) { + engine_obj->module_ = NULL; + return Status::CodeGenError(engine_obj->llvm_error_); + } + + auto status = engine_obj->LoadPreCompiledHelperLibs(config->helper_lib_file_path()); + GANDIVA_RETURN_NOT_OK(status); + + status = engine_obj->LoadPreCompiledIRFiles(config->byte_code_file_path()); + GANDIVA_RETURN_NOT_OK(status); + + *engine = std::move(engine_obj); + return Status::OK(); +} + +Status Engine::LoadPreCompiledHelperLibs(const std::string &file_path) { + int err = 0; + + mtx_.lock(); + // Load each so lib only once. + if (loaded_libs_.find(file_path) == loaded_libs_.end()) { + err = llvm::sys::DynamicLibrary::LoadLibraryPermanently(file_path.c_str()); + if (!err) { + loaded_libs_.insert(file_path); + } + } + mtx_.unlock(); + + return (err == 0) + ? 
Status::OK() + : Status::CodeGenError("loading precompiled native file " + file_path + + " failed with error " + std::to_string(err)); +} + +// Handling for pre-compiled IR libraries. +Status Engine::LoadPreCompiledIRFiles(const std::string &byte_code_file_path) { + /// Read from file into memory buffer. + llvm::ErrorOr> buffer_or_error = + llvm::MemoryBuffer::getFile(byte_code_file_path); + if (!buffer_or_error) { + std::stringstream ss; + ss << "Could not load module from IR " << byte_code_file_path << ": " + << buffer_or_error.getError().message(); + return Status::CodeGenError(ss.str()); + } + std::unique_ptr buffer = move(buffer_or_error.get()); + + /// Parse the IR module. + llvm::Expected> module_or_error = + llvm::getOwningLazyBitcodeModule(move(buffer), *context()); + if (!module_or_error) { + std::string error_string; + llvm::handleAllErrors(module_or_error.takeError(), [&](llvm::ErrorInfoBase &eib) { + error_string = eib.message(); + }); + return Status::CodeGenError(error_string); + } + std::unique_ptr ir_module = move(module_or_error.get()); + + /// Verify the IR module + if (llvm::verifyModule(*ir_module, &llvm::errs())) { + return Status::CodeGenError("verify of IR Module failed"); + } + + // Link this to the primary module. + if (llvm::Linker::linkModules(*module_, move(ir_module))) { + return Status::CodeGenError("failed to link IR Modules"); + } + return Status::OK(); +} + +// Optimise and compile the module. +Status Engine::FinalizeModule(bool optimise_ir, bool dump_ir) { + if (dump_ir) { + DumpIR("Before optimise"); + } + + // Setup an optimiser pipeline + if (optimise_ir) { + std::unique_ptr pass_manager( + new llvm::legacy::PassManager()); + + // First round : get rid of all functions that don't need to be compiled. + // This helps in reducing the overall compilation time. + // (Adapted from Apache Impala) + // + // Done by marking all the unused functions as internal, and then, running + // a pass for dead code elimination. 
+ std::unordered_set used_functions; + used_functions.insert(functions_to_compile_.begin(), functions_to_compile_.end()); + + pass_manager->add( + llvm::createInternalizePass([&used_functions](const llvm::GlobalValue &func) { + return (used_functions.find(func.getName().str()) != used_functions.end()); + })); + pass_manager->add(llvm::createGlobalDCEPass()); + pass_manager->run(*module_); + + // Second round : misc passes to allow for inlining, vectorization, .. + pass_manager.reset(new llvm::legacy::PassManager()); + llvm::TargetIRAnalysis target_analysis = + execution_engine_->getTargetMachine()->getTargetIRAnalysis(); + pass_manager->add(llvm::createTargetTransformInfoWrapperPass(target_analysis)); + pass_manager->add(llvm::createFunctionInliningPass()); + pass_manager->add(llvm::createInstructionCombiningPass()); + pass_manager->add(llvm::createPromoteMemoryToRegisterPass()); + pass_manager->add(llvm::createGVNPass()); + pass_manager->add(llvm::createNewGVNPass()); + pass_manager->add(llvm::createCFGSimplificationPass()); + pass_manager->add(llvm::createLoopVectorizePass()); + pass_manager->add(llvm::createSLPVectorizerPass()); + pass_manager->add(llvm::createGlobalOptimizerPass()); + + // run the optimiser + llvm::PassManagerBuilder pass_builder; + pass_builder.OptLevel = 2; + pass_builder.populateModulePassManager(*pass_manager); + pass_manager->run(*module_); + + if (dump_ir) { + DumpIR("After optimise"); + } + } + + if (llvm::verifyModule(*module_, &llvm::errs())) { + return Status::CodeGenError("verify of module failed after optimisation passes"); + } + + // do the compilation + execution_engine_->finalizeObject(); + module_finalized_ = true; + return Status::OK(); +} + +void *Engine::CompiledFunction(llvm::Function *irFunction) { + DCHECK(module_finalized_); + return execution_engine_->getPointerToFunction(irFunction); +} + +void Engine::DumpIR(std::string prefix) { + std::string str; + + llvm::raw_string_ostream stream(str); + module_->print(stream, 
NULL); + std::cout << "====" << prefix << "===" << str << "\n"; +} + +} // namespace gandiva diff --git a/cpp/src/gandiva/codegen/engine.h b/cpp/src/gandiva/codegen/engine.h new file mode 100644 index 00000000000..2b539d6c3a2 --- /dev/null +++ b/cpp/src/gandiva/codegen/engine.h @@ -0,0 +1,97 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GANDIVA_ENGINE_H +#define GANDIVA_ENGINE_H + +#include +#include +#include +#include + +#include +#include +#include +#include +#include "gandiva/configuration.h" +#include "gandiva/logging.h" +#include "gandiva/status.h" + +namespace gandiva { + +/// \brief LLVM Execution engine wrapper. +class Engine { + public: + llvm::LLVMContext *context() { return context_.get(); } + llvm::IRBuilder<> &ir_builder() { return *ir_builder_.get(); } + + llvm::Module *module() { return module_; } + + /// factory method to create and initialize the engine object. + /// + /// \param[out] engine the created engine. + static Status Make(std::shared_ptr config, + std::unique_ptr *engine); + + /// Add the function to the list of IR functions that need to be compiled. + /// Compiling only the functions that are used by the module saves time. + void AddFunctionToCompile(const std::string &fname) { + DCHECK(!module_finalized_); + functions_to_compile_.push_back(fname); + } + + /// Optimise and compile the module. 
+ Status FinalizeModule(bool optimise_ir, bool dump_ir); + + /// Get the compiled function corresponding to the irfunction. + void *CompiledFunction(llvm::Function *irFunction); + + private: + /// private constructor to ensure engine is created + /// only through the factory. + Engine() : module_finalized_(false) {} + + /// do one time inits. + static void InitOnce(); + static bool init_once_done_; + + llvm::ExecutionEngine &execution_engine() { return *execution_engine_.get(); } + + /// load pre-compiled so libraries and merge them into the main module. + Status LoadPreCompiledHelperLibs(const std::string &helper_lib_file_path); + + /// load pre-compiled IR modules and merge them into the main module. + Status LoadPreCompiledIRFiles(const std::string &byte_code_file_path); + + /// dump the IR code to stdout with the prefix string. + void DumpIR(std::string prefix); + + std::unique_ptr context_; + std::unique_ptr execution_engine_; + std::unique_ptr> ir_builder_; + llvm::Module *module_; // This is owned by the execution_engine_, so doesn't need to be + // explicitly deleted. + + std::vector functions_to_compile_; + + bool module_finalized_; + std::string llvm_error_; + + static std::set loaded_libs_; + static std::mutex mtx_; +}; + +} // namespace gandiva + +#endif // GANDIVA_ENGINE_H diff --git a/cpp/src/gandiva/codegen/engine_llvm_test.cc b/cpp/src/gandiva/codegen/engine_llvm_test.cc new file mode 100644 index 00000000000..ee5e74cb581 --- /dev/null +++ b/cpp/src/gandiva/codegen/engine_llvm_test.cc @@ -0,0 +1,131 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "codegen/engine.h" + +#include +#include "codegen/llvm_types.h" + +namespace gandiva { + +typedef int64_t (*add_vector_func_t)(int64_t *elements, int nelements); + +class TestEngine : public ::testing::Test { + protected: + llvm::Function *BuildVecAdd(Engine *engine, LLVMTypes *types); +}; + +llvm::Function *TestEngine::BuildVecAdd(Engine *engine, LLVMTypes *types) { + llvm::IRBuilder<> &builder = engine->ir_builder(); + llvm::LLVMContext *context = engine->context(); + + // Create fn prototype : + // int64_t add_longs(int64_t *elements, int32_t nelements) + std::vector arguments; + arguments.push_back(types->i64_ptr_type()); + arguments.push_back(types->i32_type()); + llvm::FunctionType *prototype = + llvm::FunctionType::get(types->i64_type(), arguments, false /*isVarArg*/); + + // Create fn + std::string func_name = "add_longs"; + engine->AddFunctionToCompile(func_name); + llvm::Function *fn = llvm::Function::Create( + prototype, llvm::GlobalValue::ExternalLinkage, func_name, engine->module()); + assert(fn != NULL); + + // Name the arguments + llvm::Function::arg_iterator args = fn->arg_begin(); + llvm::Value *arg_elements = &*args; + arg_elements->setName("elements"); + ++args; + llvm::Value *arg_nelements = &*args; + arg_nelements->setName("nelements"); + ++args; + + llvm::BasicBlock *loop_entry = llvm::BasicBlock::Create(*context, "entry", fn); + llvm::BasicBlock *loop_body = llvm::BasicBlock::Create(*context, "loop", fn); + llvm::BasicBlock *loop_exit = llvm::BasicBlock::Create(*context, "exit", fn); + + // Loop entry + 
builder.SetInsertPoint(loop_entry); + builder.CreateBr(loop_body); + + // Loop body + builder.SetInsertPoint(loop_body); + + llvm::PHINode *loop_var = builder.CreatePHI(types->i32_type(), 2, "loop_var"); + llvm::PHINode *sum = builder.CreatePHI(types->i64_type(), 2, "sum"); + + loop_var->addIncoming(types->i32_constant(0), loop_entry); + sum->addIncoming(types->i64_constant(0), loop_entry); + + // setup loop PHI + llvm::Value *loop_update = + builder.CreateAdd(loop_var, types->i32_constant(1), "loop_var+1"); + loop_var->addIncoming(loop_update, loop_body); + + // get the current value + llvm::Value *offset = builder.CreateGEP(arg_elements, loop_var, "offset"); + llvm::Value *current_value = builder.CreateLoad(offset, "value"); + + // setup sum PHI + llvm::Value *sum_update = builder.CreateAdd(sum, current_value, "sum+ith"); + sum->addIncoming(sum_update, loop_body); + + // check loop_var + llvm::Value *loop_var_check = + builder.CreateICmpSLT(loop_update, arg_nelements, "loop_var < nrec"); + builder.CreateCondBr(loop_var_check, loop_body, loop_exit); + + // Loop exit + builder.SetInsertPoint(loop_exit); + builder.CreateRet(sum_update); + return fn; +} + +TEST_F(TestEngine, TestAddUnoptimised) { + std::unique_ptr engine; + Engine::Make(ConfigurationBuilder::DefaultConfiguration(), &engine); + LLVMTypes types(*engine->context()); + llvm::Function *ir_func = BuildVecAdd(engine.get(), &types); + engine->FinalizeModule(false, false); + + add_vector_func_t add_func = + reinterpret_cast(engine->CompiledFunction(ir_func)); + + int64_t my_array[] = {1, 3, -5, 8, 10}; + EXPECT_EQ(add_func(my_array, 5), 17); +} + +TEST_F(TestEngine, TestAddOptimised) { + std::unique_ptr engine; + Engine::Make(ConfigurationBuilder::DefaultConfiguration(), &engine); + LLVMTypes types(*engine->context()); + llvm::Function *ir_func = BuildVecAdd(engine.get(), &types); + engine->FinalizeModule(true, false); + + add_vector_func_t add_func = + reinterpret_cast(engine->CompiledFunction(ir_func)); + + 
int64_t my_array[] = {1, 3, -5, 8, 10}; + EXPECT_EQ(add_func(my_array, 5), 17); +} + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} + +} // namespace gandiva diff --git a/cpp/src/gandiva/codegen/eval_batch.h b/cpp/src/gandiva/codegen/eval_batch.h new file mode 100644 index 00000000000..c4211a0b5f3 --- /dev/null +++ b/cpp/src/gandiva/codegen/eval_batch.h @@ -0,0 +1,83 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GANDIVA_EXPR_EVALBATCH_H +#define GANDIVA_EXPR_EVALBATCH_H + +#include +#include "gandiva/arrow.h" +#include "gandiva/gandiva_aliases.h" +#include "local_bitmaps_holder.h" + +namespace gandiva { + +/// \brief The buffers corresponding to one batch of records, used for +/// expression evaluation. 
+class EvalBatch { + public: + explicit EvalBatch(int num_records, int num_buffers, int num_local_bitmaps) + : num_records_(num_records), num_buffers_(num_buffers) { + if (num_buffers > 0) { + buffers_array_.reset(new uint8_t *[num_buffers]); + } + local_bitmaps_holder_.reset(new LocalBitMapsHolder(num_records, num_local_bitmaps)); + } + + int num_records() const { return num_records_; } + + uint8_t **GetBufferArray() const { return buffers_array_.get(); } + + int GetNumBuffers() const { return num_buffers_; } + + uint8_t *GetBuffer(int idx) const { + DCHECK(idx <= num_buffers_); + return (buffers_array_.get())[idx]; + } + + void SetBuffer(int idx, uint8_t *buffer) { + DCHECK(idx <= num_buffers_); + (buffers_array_.get())[idx] = buffer; + } + + int GetNumLocalBitMaps() const { return local_bitmaps_holder_->GetNumLocalBitMaps(); } + + int GetLocalBitmapSize() const { return local_bitmaps_holder_->GetLocalBitMapSize(); } + + uint8_t *GetLocalBitMap(int idx) const { + DCHECK(idx <= GetNumLocalBitMaps()); + return local_bitmaps_holder_->GetLocalBitMap(idx); + } + + uint8_t **GetLocalBitMapArray() const { + return local_bitmaps_holder_->GetLocalBitMapArray(); + } + + private: + /// number of records in the current batch. + int num_records_; + + // number of buffers. + int num_buffers_; + + /// An array of 'num_buffers_', each containing a buffer. The buffer + /// sizes depends on the data type, but all of them have the same + /// number of slots (equal to num_records_). 
+ std::unique_ptr buffers_array_; + + std::unique_ptr local_bitmaps_holder_; +}; + +} // namespace gandiva + +#endif // GANDIVA_EXPR_EVALBATCH_H diff --git a/cpp/src/gandiva/codegen/expr_decomposer.cc b/cpp/src/gandiva/codegen/expr_decomposer.cc new file mode 100644 index 00000000000..d477407c76e --- /dev/null +++ b/cpp/src/gandiva/codegen/expr_decomposer.cc @@ -0,0 +1,230 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "codegen/expr_decomposer.h" + +#include +#include +#include +#include + +#include "codegen/annotator.h" +#include "codegen/dex.h" +#include "codegen/function_holder_registry.h" +#include "codegen/function_registry.h" +#include "codegen/node.h" +#include "gandiva/function_signature.h" + +namespace gandiva { + +// Decompose a field node - simply separate out validity & value arrays. +Status ExprDecomposer::Visit(const FieldNode &node) { + auto desc = annotator_.CheckAndAddInputFieldDescriptor(node.field()); + + DexPtr validity_dex = std::make_shared(desc); + DexPtr value_dex; + if (desc->HasOffsetsIdx()) { + value_dex = std::make_shared(desc); + } else { + value_dex = std::make_shared(desc); + } + result_ = std::make_shared(validity_dex, value_dex); + return Status::OK(); +} + +// Decompose a function node - wherever possible, merge the validity vectors of the +// child nodes. 
+Status ExprDecomposer::Visit(const FunctionNode &node) { + auto desc = node.descriptor(); + FunctionSignature signature(desc->name(), desc->params(), desc->return_type()); + const NativeFunction *native_function = registry_.LookupSignature(signature); + DCHECK(native_function) << "Missing Signature " << signature.ToString(); + + // decompose the children. + std::vector args; + for (auto &child : node.children()) { + auto status = child->Accept(*this); + GANDIVA_RETURN_NOT_OK(status); + + args.push_back(result()); + } + + // Make a function holder, if required. + std::shared_ptr holder; + if (native_function->needs_holder()) { + auto status = FunctionHolderRegistry::Make(desc->name(), node, &holder); + GANDIVA_RETURN_NOT_OK(status); + } + + if (native_function->result_nullable_type() == RESULT_NULL_IF_NULL) { + // These functions are decomposable, merge the validity bits of the children. + + std::vector merged_validity; + for (auto &decomposed : args) { + // Merge the validity_expressions of the children to build a combined validity + // expression. + merged_validity.insert(merged_validity.end(), decomposed->validity_exprs().begin(), + decomposed->validity_exprs().end()); + } + + auto value_dex = + std::make_shared(desc, native_function, holder, args); + result_ = std::make_shared(merged_validity, value_dex); + } else if (native_function->result_nullable_type() == RESULT_NULL_NEVER) { + // These functions always output valid results. So, no validity dex. + auto value_dex = + std::make_shared(desc, native_function, holder, args); + result_ = std::make_shared(value_dex); + } else { + DCHECK(native_function->result_nullable_type() == RESULT_NULL_INTERNAL); + + // Add a local bitmap to track the output validity. 
+ int local_bitmap_idx = annotator_.AddLocalBitMap(); + auto validity_dex = std::make_shared(local_bitmap_idx); + + auto value_dex = std::make_shared( + desc, native_function, holder, args, local_bitmap_idx); + result_ = std::make_shared(validity_dex, value_dex); + } + return Status::OK(); +} + +// Decompose an IfNode +Status ExprDecomposer::Visit(const IfNode &node) { + // Add a local bitmap to track the output validity. + auto status = node.condition()->Accept(*this); + GANDIVA_RETURN_NOT_OK(status); + auto condition_vv = result(); + + int local_bitmap_idx = PushThenEntry(node); + status = node.then_node()->Accept(*this); + GANDIVA_RETURN_NOT_OK(status); + auto then_vv = result(); + PopThenEntry(node); + + PushElseEntry(node, local_bitmap_idx); + status = node.else_node()->Accept(*this); + GANDIVA_RETURN_NOT_OK(status); + auto else_vv = result(); + bool is_terminal_else = PopElseEntry(node); + + auto validity_dex = std::make_shared(local_bitmap_idx); + auto value_dex = + std::make_shared(condition_vv, then_vv, else_vv, node.return_type(), + local_bitmap_idx, is_terminal_else); + + result_ = std::make_shared(validity_dex, value_dex); + return Status::OK(); +} + +// Decompose a BooleanNode +Status ExprDecomposer::Visit(const BooleanNode &node) { + // decompose the children. + std::vector args; + for (auto &child : node.children()) { + auto status = child->Accept(*this); + GANDIVA_RETURN_NOT_OK(status); + + args.push_back(result()); + } + + // Add a local bitmap to track the output validity. 
+ int local_bitmap_idx = annotator_.AddLocalBitMap(); + auto validity_dex = std::make_shared(local_bitmap_idx); + + std::shared_ptr value_dex; + switch (node.expr_type()) { + case BooleanNode::AND: + value_dex = std::make_shared(args, local_bitmap_idx); + break; + case BooleanNode::OR: + value_dex = std::make_shared(args, local_bitmap_idx); + break; + } + result_ = std::make_shared(validity_dex, value_dex); + return Status::OK(); +} + +// Decompose a LiteralNode - the validity dex is selected by node.is_null(). +Status ExprDecomposer::Visit(const LiteralNode &node) { + auto value_dex = std::make_shared(node.return_type(), node.holder()); + DexPtr validity_dex; + if (node.is_null()) { + validity_dex = std::make_shared(); + } else { + validity_dex = std::make_shared(); + } + result_ = std::make_shared(validity_dex, value_dex); + return Status::OK(); +} + +// The below functions use a stack to detect : +// a. nested if-else expressions. +// In such cases, the local bitmap can be re-used. +// b. detect terminal else expressions +// The non-terminal else expressions do not need to track validity (the if statement +// that has a match will do it). +// Both of the above optimisations save CPU cycles during expression evaluation. + +int ExprDecomposer::PushThenEntry(const IfNode &node) { + int local_bitmap_idx; + + if (!if_entries_stack_.empty() && !if_entries_stack_.top()->is_then_) { + auto top = if_entries_stack_.top().get(); + + // inside a nested else statement (i.e if-else-if). use the parent's bitmap. + local_bitmap_idx = top->local_bitmap_idx_; + + // clear the is_terminal bit in the current top entry (else). + top->is_terminal_else_ = false; + } else { + // alloc a new bitmap. + local_bitmap_idx = annotator_.AddLocalBitMap(); + } + + // push new entry to the stack. 
+ std::unique_ptr entry(new IfStackEntry( + node, true /*is_then*/, false /*is_terminal_else*/, local_bitmap_idx)); + if_entries_stack_.push(std::move(entry)); + return local_bitmap_idx; +} + +// Pop the 'then' entry of 'node'; it is expected to be the top of the stack. +void ExprDecomposer::PopThenEntry(const IfNode &node) { + DCHECK_EQ(if_entries_stack_.empty(), false) << "PopThenEntry: found empty stack"; + + auto top = if_entries_stack_.top().get(); + DCHECK_EQ(top->is_then_, true) << "PopThenEntry: found else, expected then"; + DCHECK_EQ(&top->if_node_, &node) << "PopThenEntry: found mismatched node"; + + if_entries_stack_.pop(); +} + +// Push the 'else' entry of 'node'. It starts out marked as a terminal else; +// PushThenEntry clears that mark when an if is nested directly inside it. +void ExprDecomposer::PushElseEntry(const IfNode &node, int local_bitmap_idx) { + std::unique_ptr entry(new IfStackEntry( + node, false /*is_then*/, true /*is_terminal_else*/, local_bitmap_idx)); + if_entries_stack_.push(std::move(entry)); +} + +// Pop the 'else' entry of 'node' (expected to be the top of the stack) and +// return whether it is still marked as a terminal else. +bool ExprDecomposer::PopElseEntry(const IfNode &node) { + DCHECK_EQ(if_entries_stack_.empty(), false) << "PopElseEntry: found empty stack"; + + auto top = if_entries_stack_.top().get(); + DCHECK_EQ(top->is_then_, false) << "PopElseEntry: found then, expected else"; + DCHECK_EQ(&top->if_node_, &node) << "PopElseEntry: found mismatched node"; + bool is_terminal_else = top->is_terminal_else_; + + if_entries_stack_.pop(); + return is_terminal_else; +} + +} // namespace gandiva diff --git a/cpp/src/gandiva/codegen/expr_decomposer.h b/cpp/src/gandiva/codegen/expr_decomposer.h new file mode 100644 index 00000000000..12fcd9e3baf --- /dev/null +++ b/cpp/src/gandiva/codegen/expr_decomposer.h @@ -0,0 +1,99 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GANDIVA_EXPR_DECOMPOSER_H +#define GANDIVA_EXPR_DECOMPOSER_H + +#include +#include +#include + +#include +#include "codegen/node.h" +#include "codegen/node_visitor.h" +#include "gandiva/expression.h" + +namespace gandiva { + +class FunctionRegistry; +class Annotator; + +/// \brief Decomposes an expression tree to separate out the validity and +/// value expressions. +class ExprDecomposer : public NodeVisitor { + public: + explicit ExprDecomposer(const FunctionRegistry &registry, Annotator &annotator) + : registry_(registry), annotator_(annotator) {} + + /// \brief Visits 'root' and, on success, moves the decomposed result into '*out'. + /// On failure '*out' is left untouched and the failing status is returned. + Status Decompose(const Node &root, ValueValidityPairPtr *out) { + auto status = root.Accept(*this); + if (status.ok()) { + *out = std::move(result_); + } + return status; + } + + private: + FRIEND_TEST(TestExprDecomposer, TestStackSimple); + FRIEND_TEST(TestExprDecomposer, TestNested); + FRIEND_TEST(TestExprDecomposer, TestInternalIf); + FRIEND_TEST(TestExprDecomposer, TestParallelIf); + + Status Visit(const FieldNode &node) override; + Status Visit(const FunctionNode &node) override; + Status Visit(const IfNode &node) override; + Status Visit(const LiteralNode &node) override; + Status Visit(const BooleanNode &node) override; + + // stack of if nodes. 
+ class IfStackEntry { + public: + IfStackEntry(const IfNode &if_node, bool is_then, bool is_terminal_else, + int local_bitmap_idx) + : if_node_(if_node), + is_then_(is_then), + is_terminal_else_(is_terminal_else), + local_bitmap_idx_(local_bitmap_idx) {} + + const IfNode &if_node_; + bool is_then_; + bool is_terminal_else_; + int local_bitmap_idx_; + }; + + // push 'then entry' to stack. returns either a new local bitmap or the parent's + // bitmap (in case of nested if-else). + int PushThenEntry(const IfNode &node); + + // pop 'then entry' from stack. + void PopThenEntry(const IfNode &node); + + // push 'else entry' into stack. + void PushElseEntry(const IfNode &node, int local_bitmap_idx); + + // pop 'else entry' from stack. returns 'true' if this is a terminal else condition + // i.e no nested if condition below this node. + bool PopElseEntry(const IfNode &node); + + // hands the result of the latest visit to the caller; result_ is moved-from afterwards. + ValueValidityPairPtr result() { return std::move(result_); } + + const FunctionRegistry &registry_; + Annotator &annotator_; + std::stack> if_entries_stack_; + ValueValidityPairPtr result_; +}; + +} // namespace gandiva + +#endif // GANDIVA_EXPR_DECOMPOSER_H diff --git a/cpp/src/gandiva/codegen/expr_decomposer_test.cc b/cpp/src/gandiva/codegen/expr_decomposer_test.cc new file mode 100644 index 00000000000..463529e6927 --- /dev/null +++ b/cpp/src/gandiva/codegen/expr_decomposer_test.cc @@ -0,0 +1,156 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "codegen/expr_decomposer.h" + +#include +#include "codegen/annotator.h" +#include "codegen/dex.h" +#include "codegen/function_registry.h" +#include "codegen/node.h" +#include "gandiva/function_signature.h" +#include "gandiva/gandiva_aliases.h" +#include "gandiva/tree_expr_builder.h" + +namespace gandiva { + +using arrow::int32; + +class TestExprDecomposer : public ::testing::Test { + protected: + FunctionRegistry registry_; +}; + +TEST_F(TestExprDecomposer, TestStackSimple) { + Annotator annotator; + ExprDecomposer decomposer(registry_, annotator); + + // if (a) _ + // else _ + IfNode node_a(nullptr, nullptr, nullptr, int32()); + + int idx_a = decomposer.PushThenEntry(node_a); + EXPECT_EQ(idx_a, 0); + decomposer.PopThenEntry(node_a); + + decomposer.PushElseEntry(node_a, idx_a); + bool is_terminal_a = decomposer.PopElseEntry(node_a); + EXPECT_EQ(is_terminal_a, true); + EXPECT_EQ(decomposer.if_entries_stack_.empty(), true); +} + +TEST_F(TestExprDecomposer, TestNested) { + Annotator annotator; + ExprDecomposer decomposer(registry_, annotator); + + // if (a) _ + // else _ + // if (b) _ + // else _ + IfNode node_a(nullptr, nullptr, nullptr, int32()); + IfNode node_b(nullptr, nullptr, nullptr, int32()); + + int idx_a = decomposer.PushThenEntry(node_a); + EXPECT_EQ(idx_a, 0); + decomposer.PopThenEntry(node_a); + + decomposer.PushElseEntry(node_a, idx_a); + + { // start b + int idx_b = decomposer.PushThenEntry(node_b); + EXPECT_EQ(idx_b, 0); // must reuse bitmap. + decomposer.PopThenEntry(node_b); + + decomposer.PushElseEntry(node_b, idx_b); + bool is_terminal_b = decomposer.PopElseEntry(node_b); + EXPECT_EQ(is_terminal_b, true); + } // end b + + bool is_terminal_a = decomposer.PopElseEntry(node_a); + EXPECT_EQ(is_terminal_a, false); // there was a nested if. 
+ + EXPECT_EQ(decomposer.if_entries_stack_.empty(), true); +} + +TEST_F(TestExprDecomposer, TestInternalIf) { + Annotator annotator; + ExprDecomposer decomposer(registry_, annotator); + + // if (a) _ + // if (b) _ + // else _ + // else _ + IfNode node_a(nullptr, nullptr, nullptr, int32()); + IfNode node_b(nullptr, nullptr, nullptr, int32()); + + int idx_a = decomposer.PushThenEntry(node_a); + EXPECT_EQ(idx_a, 0); + + { // start b + int idx_b = decomposer.PushThenEntry(node_b); + EXPECT_EQ(idx_b, 1); // must not reuse bitmap. + decomposer.PopThenEntry(node_b); + + decomposer.PushElseEntry(node_b, idx_b); + bool is_terminal_b = decomposer.PopElseEntry(node_b); + EXPECT_EQ(is_terminal_b, true); + } // end b + + decomposer.PopThenEntry(node_a); + decomposer.PushElseEntry(node_a, idx_a); + + bool is_terminal_a = decomposer.PopElseEntry(node_a); + EXPECT_EQ(is_terminal_a, true); // there was no nested if. + + EXPECT_EQ(decomposer.if_entries_stack_.empty(), true); +} + +TEST_F(TestExprDecomposer, TestParallelIf) { + Annotator annotator; + ExprDecomposer decomposer(registry_, annotator); + + // if (a) _ + // else _ + // if (b) _ + // else _ + IfNode node_a(nullptr, nullptr, nullptr, int32()); + IfNode node_b(nullptr, nullptr, nullptr, int32()); + + int idx_a = decomposer.PushThenEntry(node_a); + EXPECT_EQ(idx_a, 0); + + decomposer.PopThenEntry(node_a); + decomposer.PushElseEntry(node_a, idx_a); + + bool is_terminal_a = decomposer.PopElseEntry(node_a); + EXPECT_EQ(is_terminal_a, true); // there was no nested if. + + // start b + int idx_b = decomposer.PushThenEntry(node_b); + EXPECT_EQ(idx_b, 1); // must not reuse bitmap. 
+ decomposer.PopThenEntry(node_b); + + decomposer.PushElseEntry(node_b, idx_b); + bool is_terminal_b = decomposer.PopElseEntry(node_b); + EXPECT_EQ(is_terminal_b, true); + + EXPECT_EQ(decomposer.if_entries_stack_.empty(), true); +} + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} + +} // namespace gandiva diff --git a/cpp/src/gandiva/codegen/expr_validator.cc b/cpp/src/gandiva/codegen/expr_validator.cc new file mode 100644 index 00000000000..9ca286206e7 --- /dev/null +++ b/cpp/src/gandiva/codegen/expr_validator.cc @@ -0,0 +1,154 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "codegen/expr_validator.h" + +namespace gandiva { + +Status ExprValidator::Validate(const ExpressionPtr &expr) { + if (expr == nullptr) { + return Status::ExpressionValidationError("Expression cannot be null."); + } + Node &root = *expr->root(); + Status status = root.Accept(*this); + if (!status.ok()) { + return status; + } + // validate return type matches + // no need to check if type is supported + // since root type has been validated. 
+ if (!root.return_type()->Equals(*expr->result()->type())) { + std::stringstream ss; + ss << "Return type of root node " << root.return_type()->name() + << " does not match that of expression " << *expr->result()->type(); + return Status::ExpressionValidationError(ss.str()); + } + return Status::OK(); +} + +Status ExprValidator::Visit(const FieldNode &node) { + auto llvm_type = types_.IRType(node.return_type()->id()); + if (llvm_type == nullptr) { + std::stringstream ss; + ss << "Field " << node.field()->name() << " has unsupported data type " + << node.return_type()->name(); + return Status::ExpressionValidationError(ss.str()); + } + + auto field_in_schema_entry = field_map_.find(node.field()->name()); + + // validate that field is in schema. + if (field_in_schema_entry == field_map_.end()) { + std::stringstream ss; + ss << "Field " << node.field()->name() << " not in schema."; + return Status::ExpressionValidationError(ss.str()); + } + + FieldPtr field_in_schema = field_in_schema_entry->second; + // validate that field matches the definition in schema. + if (!field_in_schema->Equals(node.field())) { + std::stringstream ss; + ss << "Field definition in schema " << field_in_schema->ToString() + << " different from field in expression " << node.field()->ToString(); + return Status::ExpressionValidationError(ss.str()); + } + return Status::OK(); +} + +Status ExprValidator::Visit(const FunctionNode &node) { + auto desc = node.descriptor(); + FunctionSignature signature(desc->name(), desc->params(), desc->return_type()); + const NativeFunction *native_function = registry_.LookupSignature(signature); + if (native_function == nullptr) { + std::stringstream ss; + ss << "Function " << signature.ToString() << " not supported yet. 
"; + return Status::ExpressionValidationError(ss.str()); + } + + for (auto &child : node.children()) { + Status status = child->Accept(*this); + GANDIVA_RETURN_NOT_OK(status); + } + return Status::OK(); +} + +// Validates the condition, then and else sub-trees, and then checks that the +// then/else return types equal the return type of the if node itself. +Status ExprValidator::Visit(const IfNode &node) { + Status status = node.condition()->Accept(*this); + GANDIVA_RETURN_NOT_OK(status); + status = node.then_node()->Accept(*this); + GANDIVA_RETURN_NOT_OK(status); + status = node.else_node()->Accept(*this); + GANDIVA_RETURN_NOT_OK(status); + + auto if_node_ret_type = node.return_type(); + auto then_node_ret_type = node.then_node()->return_type(); + auto else_node_ret_type = node.else_node()->return_type(); + + // Compare the types by value (as Validate() does for the root), not by pointer: + // two equal types may be held in different DataType instances. + if (!if_node_ret_type->Equals(*then_node_ret_type)) { + std::stringstream ss; + ss << "Return type of if " << *if_node_ret_type << " and then " + << then_node_ret_type->name() << " not matching."; + return Status::ExpressionValidationError(ss.str()); + } + + if (!if_node_ret_type->Equals(*else_node_ret_type)) { + std::stringstream ss; + ss << "Return type of if " << *if_node_ret_type << " and else " + << else_node_ret_type->name() << " not matching."; + return Status::ExpressionValidationError(ss.str()); + } + + return Status::OK(); +} + +Status ExprValidator::Visit(const LiteralNode &node) { + auto llvm_type = types_.IRType(node.return_type()->id()); + if (llvm_type == nullptr) { + std::stringstream ss; + ss << "Value " << node.holder() << " has unsupported data type " + << node.return_type()->name(); + return Status::ExpressionValidationError(ss.str()); + } + return Status::OK(); +} + +// A boolean node needs at least two children, each returning boolean. +Status ExprValidator::Visit(const BooleanNode &node) { + Status status; + + if (node.children().size() < 2) { + std::stringstream ss; + ss << "Boolean expression has " << node.children().size() + << " children, expected atleast two"; + return Status::ExpressionValidationError(ss.str()); + } + + for (auto &child : node.children()) { + // value comparison, for the same reason as in Visit(const IfNode &). + if (!child->return_type()->Equals(*arrow::boolean())) { + std::stringstream ss; + ss << "Boolean expression has a
child with return type " + << child->return_type()->name() << ", expected return type boolean"; + return Status::ExpressionValidationError(ss.str()); + } + + status = child->Accept(*this); + GANDIVA_RETURN_NOT_OK(status); + } + return Status::OK(); +} + +} // namespace gandiva diff --git a/cpp/src/gandiva/codegen/expr_validator.h b/cpp/src/gandiva/codegen/expr_validator.h new file mode 100644 index 00000000000..83d3d2a2215 --- /dev/null +++ b/cpp/src/gandiva/codegen/expr_validator.h @@ -0,0 +1,72 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GANDIVA_EXPR_VALIDATOR_H +#define GANDIVA_EXPR_VALIDATOR_H + +#include +#include + +#include "boost/functional/hash.hpp" +#include "codegen/function_registry.h" +#include "codegen/llvm_types.h" +#include "codegen/node.h" +#include "codegen/node_visitor.h" +#include "gandiva/arrow.h" +#include "gandiva/expression.h" +#include "gandiva/status.h" + +namespace gandiva { + +class FunctionRegistry; + +/// \brief Validates the entire expression tree including +/// data types, signatures and return types +class ExprValidator : public NodeVisitor { + public: + explicit ExprValidator(LLVMTypes &types, SchemaPtr schema) + : types_(types), schema_(schema) { + for (auto &field : schema_->fields()) { + field_map_[field->name()] = field; + } + } + + /// \brief Validates the root node + /// of an expression. + /// 1. Data type of fields and literals. + /// 2. 
Function signature is supported. + /// 3. For if nodes that return types match + /// for if, then and else nodes. + Status Validate(const ExpressionPtr &expr); + + private: + Status Visit(const FieldNode &node) override; + Status Visit(const FunctionNode &node) override; + Status Visit(const IfNode &node) override; + Status Visit(const LiteralNode &node) override; + Status Visit(const BooleanNode &node) override; + + FunctionRegistry registry_; + + LLVMTypes &types_; + + SchemaPtr schema_; + + using FieldMap = std::unordered_map>; + FieldMap field_map_; +}; + +} // namespace gandiva + +#endif // GANDIVA_EXPR_VALIDATOR_H diff --git a/cpp/src/gandiva/codegen/expression.cc b/cpp/src/gandiva/codegen/expression.cc new file mode 100644 index 00000000000..35d995e1f24 --- /dev/null +++ b/cpp/src/gandiva/codegen/expression.cc @@ -0,0 +1,22 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "gandiva/expression.h" +#include "codegen/node.h" + +namespace gandiva { + +std::string Expression::ToString() { return root()->ToString(); } + +} // namespace gandiva diff --git a/cpp/src/gandiva/codegen/expression.h b/cpp/src/gandiva/codegen/expression.h new file mode 100644 index 00000000000..6bd1d0e631c --- /dev/null +++ b/cpp/src/gandiva/codegen/expression.h @@ -0,0 +1,44 @@ +/* + * Copyright (C) 2017-2018 Dremio Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef GANDIVA_EXPR_EXPRESSION_H +#define GANDIVA_EXPR_EXPRESSION_H + +#include "gandiva/arrow.h" +#include "gandiva/gandiva_aliases.h" + +namespace gandiva { + +/// \brief An expression tree with a root node, and a result field. 
+class Expression { + public: + Expression(const NodePtr root, const FieldPtr result) : root_(root), result_(result) {} + + virtual ~Expression() = default; + + const NodePtr &root() const { return root_; } + + const FieldPtr &result() const { return result_; } + + std::string ToString(); + + private: + const NodePtr root_; + const FieldPtr result_; +}; + +} // namespace gandiva + +#endif // GANDIVA_EXPR_EXPRESSION_H diff --git a/cpp/src/gandiva/codegen/expression_registry.cc b/cpp/src/gandiva/codegen/expression_registry.cc new file mode 100644 index 00000000000..0a5875b5b93 --- /dev/null +++ b/cpp/src/gandiva/codegen/expression_registry.cc @@ -0,0 +1,150 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "gandiva/expression_registry.h" + +#include "boost/iterator/transform_iterator.hpp" + +#include "codegen/function_registry.h" +#include "codegen/llvm_types.h" + +namespace gandiva { + +ExpressionRegistry::ExpressionRegistry() { + function_registry_.reset(new FunctionRegistry()); +} + +ExpressionRegistry::~ExpressionRegistry() {} + +const ExpressionRegistry::FunctionSignatureIterator +ExpressionRegistry::function_signature_begin() { + return FunctionSignatureIterator(function_registry_->begin()); +} + +const ExpressionRegistry::FunctionSignatureIterator +ExpressionRegistry::function_signature_end() const { + return FunctionSignatureIterator(function_registry_->end()); +} + +bool ExpressionRegistry::FunctionSignatureIterator::operator!=( + const FunctionSignatureIterator &func_sign_it) { + return func_sign_it.it_ != this->it_; +} + +FunctionSignature ExpressionRegistry::FunctionSignatureIterator::operator*() { + return (*it_).signature(); +} + +ExpressionRegistry::iterator ExpressionRegistry::FunctionSignatureIterator::operator++( + int increment) { + return it_++; +} + +DataTypeVector ExpressionRegistry::supported_types_ = + ExpressionRegistry::InitSupportedTypes(); + +DataTypeVector ExpressionRegistry::InitSupportedTypes() { + DataTypeVector data_type_vector; + llvm::LLVMContext llvm_context; + LLVMTypes llvm_types(llvm_context); + auto supported_arrow_types = llvm_types.GetSupportedArrowTypes(); + for (auto &type_id : supported_arrow_types) { + AddArrowTypesToVector(type_id, data_type_vector); + } + return data_type_vector; +} + +void ExpressionRegistry::AddArrowTypesToVector(arrow::Type::type &type, + DataTypeVector &vector) { + switch (type) { + case arrow::Type::type::BOOL: + vector.push_back(arrow::boolean()); + break; + case arrow::Type::type::UINT8: + vector.push_back(arrow::uint8()); + break; + case arrow::Type::type::INT8: + vector.push_back(arrow::int8()); + break; + case arrow::Type::type::UINT16: + vector.push_back(arrow::uint16()); + break; 
+ case arrow::Type::type::INT16: + vector.push_back(arrow::int16()); + break; + case arrow::Type::type::UINT32: + vector.push_back(arrow::uint32()); + break; + case arrow::Type::type::INT32: + vector.push_back(arrow::int32()); + break; + case arrow::Type::type::UINT64: + vector.push_back(arrow::uint64()); + break; + case arrow::Type::type::INT64: + vector.push_back(arrow::int64()); + break; + case arrow::Type::type::HALF_FLOAT: + vector.push_back(arrow::float16()); + break; + case arrow::Type::type::FLOAT: + vector.push_back(arrow::float32()); + break; + case arrow::Type::type::DOUBLE: + vector.push_back(arrow::float64()); + break; + case arrow::Type::type::STRING: + vector.push_back(arrow::utf8()); + break; + case arrow::Type::type::BINARY: + vector.push_back(arrow::binary()); + break; + case arrow::Type::type::DATE32: + vector.push_back(arrow::date32()); + break; + case arrow::Type::type::DATE64: + vector.push_back(arrow::date64()); + break; + case arrow::Type::type::TIMESTAMP: + vector.push_back(arrow::timestamp(arrow::TimeUnit::SECOND)); + vector.push_back(arrow::timestamp(arrow::TimeUnit::MILLI)); + vector.push_back(arrow::timestamp(arrow::TimeUnit::NANO)); + vector.push_back(arrow::timestamp(arrow::TimeUnit::MICRO)); + break; + case arrow::Type::type::TIME32: + vector.push_back(arrow::time32(arrow::TimeUnit::SECOND)); + vector.push_back(arrow::time32(arrow::TimeUnit::MILLI)); + break; + case arrow::Type::type::TIME64: + vector.push_back(arrow::time64(arrow::TimeUnit::MICRO)); + vector.push_back(arrow::time64(arrow::TimeUnit::NANO)); + break; + case arrow::Type::type::NA: + vector.push_back(arrow::null()); + break; + case arrow::Type::type::FIXED_SIZE_BINARY: + case arrow::Type::type::MAP: + case arrow::Type::type::INTERVAL: + case arrow::Type::type::DECIMAL: + case arrow::Type::type::LIST: + case arrow::Type::type::STRUCT: + case arrow::Type::type::UNION: + case arrow::Type::type::DICTIONARY: + // un-supported types. 
test ensures that + // when one of these is added the build breaks. + DCHECK(false); + } +} + +} // namespace gandiva diff --git a/cpp/src/gandiva/codegen/expression_registry.h b/cpp/src/gandiva/codegen/expression_registry.h new file mode 100644 index 00000000000..dba698a117d --- /dev/null +++ b/cpp/src/gandiva/codegen/expression_registry.h @@ -0,0 +1,63 @@ +/* + * Copyright (C) 2017-2018 Dremio Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef GANDIVA_TYPES_H +#define GANDIVA_TYPES_H + +#include +#include + +#include "gandiva/arrow.h" +#include "gandiva/function_signature.h" +#include "gandiva/gandiva_aliases.h" + +namespace gandiva { + +class NativeFunction; +class FunctionRegistry; +/// \brief Exports types supported by Gandiva for processing. +/// +/// Has helper methods for clients to programmatically discover +/// data types and functions supported by Gandiva. 
+class ExpressionRegistry { + public: + /// Raw cursor type: a pointer to a const NativeFunction. + using iterator = const NativeFunction *; + ExpressionRegistry(); + ~ExpressionRegistry(); + /// Returns a copy of the static list of supported data types. + static DataTypeVector supported_types() { return supported_types_; } + /// Wraps an `iterator`; dereferencing it produces a FunctionSignature. + class FunctionSignatureIterator { + public: + FunctionSignatureIterator(iterator it) : it_(it) {} + + bool operator!=(const FunctionSignatureIterator &func_sign_it); + + FunctionSignature operator*(); + + /// Post-increment. Note that it returns the raw `iterator`, not the wrapper. + iterator operator++(int); + + private: + iterator it_; + }; + const FunctionSignatureIterator function_signature_begin(); + const FunctionSignatureIterator function_signature_end() const; + + private: + static DataTypeVector supported_types_; + static DataTypeVector InitSupportedTypes(); + static void AddArrowTypesToVector(arrow::Type::type &type, DataTypeVector &vector); + /// Registry owned by this object; FunctionRegistry is only forward-declared in + /// this header, which is why the destructor is declared here and not defaulted. + std::unique_ptr<FunctionRegistry> function_registry_; +}; +} // namespace gandiva +#endif // GANDIVA_TYPES_H diff --git a/cpp/src/gandiva/codegen/expression_registry_test.cc b/cpp/src/gandiva/codegen/expression_registry_test.cc new file mode 100644 index 00000000000..95b8fa732c9 --- /dev/null +++ b/cpp/src/gandiva/codegen/expression_registry_test.cc @@ -0,0 +1,64 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#include "gandiva/expression_registry.h" + +#include <algorithm> +#include <vector> + +#include <gtest/gtest.h> +#include "codegen/function_registry.h" +#include "codegen/llvm_types.h" +#include "gandiva/function_signature.h" + +namespace gandiva { + +typedef int64_t (*add_vector_func_t)(int64_t *elements, int nelements); + +// Fixture that owns a FunctionRegistry to compare the exported signatures against. +class TestExpressionRegistry : public ::testing::Test { + protected: + FunctionRegistry registry_; +}; + +// Verify all functions in registry are exported. +TEST_F(TestExpressionRegistry, VerifySupportedFunctions) { + // Collect every signature the ExpressionRegistry exposes. + std::vector<FunctionSignature> functions; + ExpressionRegistry expr_registry; + for (auto iter = expr_registry.function_signature_begin(); + iter != expr_registry.function_signature_end(); iter++) { + functions.push_back((*iter)); + } + // Every entry of the function registry must appear among the exported signatures. + for (auto &native_function : registry_) { + auto function = native_function.signature(); + auto element = std::find(functions.begin(), functions.end(), function); + EXPECT_NE(element, functions.end()) + << "function " << native_function.pc_name() << " missing in supported functions.\n"; + } +} + +// Verify all types are supported.
+TEST_F(TestExpressionRegistry, VerifyDataTypes) { + DataTypeVector data_types = ExpressionRegistry::supported_types(); + llvm::LLVMContext llvm_context; + LLVMTypes llvm_types(llvm_context); + auto supported_arrow_types = llvm_types.GetSupportedArrowTypes(); + for (auto &type_id : supported_arrow_types) { + // Look the id up in the exported list. Searching supported_arrow_types for one of + // its own elements would always succeed and verify nothing. + // NOTE(review): assumes type_id is an arrow::Type::type - confirm against LLVMTypes. + auto element = std::find_if( + data_types.begin(), data_types.end(), + [&type_id](const DataTypePtr &data_type) { return data_type->id() == type_id; }); + EXPECT_NE(element, data_types.end()) + << "data type " << type_id << " missing in supported data types.\n"; + } +} + +} // namespace gandiva diff --git a/cpp/src/gandiva/codegen/field_descriptor.h b/cpp/src/gandiva/codegen/field_descriptor.h new file mode 100644 index 00000000000..303482514ea --- /dev/null +++ b/cpp/src/gandiva/codegen/field_descriptor.h @@ -0,0 +1,62 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GANDIVA_FIELDDESCRIPTOR_H +#define GANDIVA_FIELDDESCRIPTOR_H + +#include <string> + +#include "gandiva/arrow.h" + +namespace gandiva { + +/// \brief Descriptor for an arrow field. Holds indexes into the flattened array of +/// buffers that is passed to LLVM generated functions.
+class FieldDescriptor { + public: + /// Sentinel stored when a buffer index is absent. + static const int kInvalidIdx = -1; + + /// Records the field and the indexes of its buffers; the validity and offsets + /// indexes default to kInvalidIdx. + FieldDescriptor(FieldPtr field, int data_idx, int validity_idx = kInvalidIdx, + int offsets_idx = kInvalidIdx) + : field_{field}, + data_idx_{data_idx}, + validity_idx_{validity_idx}, + offsets_idx_{offsets_idx} {} + + /// The arrow field being described. + FieldPtr field() const { return field_; } + + /// Name of the described field. + const std::string &Name() const { return field_->name(); } + + /// Data type of the described field. + DataTypePtr Type() const { return field_->type(); } + + /// Index of data array in the array-of-buffers + int data_idx() const { return data_idx_; } + + /// Index of validity array in the array-of-buffers + int validity_idx() const { return validity_idx_; } + + /// Index of offsets array in the array-of-buffers + int offsets_idx() const { return offsets_idx_; } + + /// True when offsets_idx_ differs from kInvalidIdx. + bool HasOffsetsIdx() const { return !(offsets_idx_ == kInvalidIdx); } + + private: + FieldPtr field_; + int data_idx_; + int validity_idx_; + int offsets_idx_; +}; + +} // namespace gandiva + +#endif // GANDIVA_FIELDDESCRIPTOR_H diff --git a/cpp/src/gandiva/codegen/filter.cc b/cpp/src/gandiva/codegen/filter.cc new file mode 100644 index 00000000000..b9255d98c1c --- /dev/null +++ b/cpp/src/gandiva/codegen/filter.cc @@ -0,0 +1,114 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#include "gandiva/filter.h" + +#include <memory> +#include <sstream> +#include <utility> + +#include "codegen/bitmap_accumulator.h" +#include "codegen/cache.h" +#include "codegen/expr_validator.h" +#include "codegen/filter_cache_key.h" +#include "codegen/llvm_generator.h" +#include "codegen/selection_vector_impl.h" +#include "gandiva/condition.h" +#include "gandiva/status.h" + +namespace gandiva { + +// Takes ownership of the generator and keeps the schema and configuration the +// filter was built for. +Filter::Filter(std::unique_ptr<LLVMGenerator> llvm_generator, SchemaPtr schema, + std::shared_ptr<Configuration> configuration) + : llvm_generator_(std::move(llvm_generator)), + schema_(schema), + configuration_(configuration) {} + +// Rejects null arguments, returns the cached filter for this (schema, configuration, +// condition) key when there is one, and otherwise validates the condition, generates +// code for it and caches the new filter before returning it through `filter`. +Status Filter::Make(SchemaPtr schema, ConditionPtr condition, + std::shared_ptr<Configuration> configuration, + std::shared_ptr<Filter> *filter) { + GANDIVA_RETURN_FAILURE_IF_FALSE(schema != nullptr, + Status::Invalid("schema cannot be null")); + GANDIVA_RETURN_FAILURE_IF_FALSE(condition != nullptr, + Status::Invalid("condition cannot be null")); + GANDIVA_RETURN_FAILURE_IF_FALSE(configuration != nullptr, + Status::Invalid("configuration cannot be null")); + // Function-local static: one cache shared by every call to Make(). + static Cache<FilterCacheKey, std::shared_ptr<Filter>> cache; + FilterCacheKey cacheKey(schema, configuration, *condition); + std::shared_ptr<Filter> cachedFilter = cache.GetModule(cacheKey); + if (cachedFilter != nullptr) { + *filter = cachedFilter; + return Status::OK(); + } + // Build LLVM generator, and generate code for the specified expression + std::unique_ptr<LLVMGenerator> llvm_gen; + Status status = LLVMGenerator::Make(configuration, &llvm_gen); + GANDIVA_RETURN_NOT_OK(status); + + // Run the validation on the expression. + // Return if the expression is invalid since we will not be able to process further.
+ ExprValidator expr_validator(llvm_gen->types(), schema); + status = expr_validator.Validate(condition); + GANDIVA_RETURN_NOT_OK(status); + + status = llvm_gen->Build({condition}); + GANDIVA_RETURN_NOT_OK(status); + + // Instantiate the filter with the completely built llvm generator + *filter = std::make_shared<Filter>(std::move(llvm_gen), schema, configuration); + cache.PutModule(cacheKey, *filter); + return Status::OK(); +} + +// Checks the batch and selection-vector preconditions, runs the generated code into +// local validity/value bitmaps, intersects the two and fills out_selection from the +// result. +Status Filter::Evaluate(const arrow::RecordBatch &batch, + std::shared_ptr<SelectionVector> out_selection) { + if (!batch.schema()->Equals(*schema_)) { + return Status::Invalid("Schema in RecordBatch must match the schema in Make()"); + } + if (batch.num_rows() == 0) { + return Status::Invalid("RecordBatch must be non-empty."); + } + if (out_selection == nullptr) { + return Status::Invalid("out_selection must be non-null."); + } + if (out_selection->GetMaxSlots() < batch.num_rows()) { + std::stringstream ss; + ss << "out_selection has " << out_selection->GetMaxSlots() + << " slots, which is less than the batch size " << batch.num_rows(); + return Status::Invalid(ss.str()); + } + + // Allocate three local_bitmaps (one for output, one for validity, one to compute the + // intersection). + LocalBitMapsHolder bitmaps(batch.num_rows(), 3 /*local_bitmaps*/); + int bitmap_size = bitmaps.GetLocalBitMapSize(); + + // Bitmap 0 backs the validity buffer and bitmap 1 the value buffer of the output. + auto validity = std::make_shared<arrow::Buffer>(bitmaps.GetLocalBitMap(0), bitmap_size); + auto value = std::make_shared<arrow::Buffer>(bitmaps.GetLocalBitMap(1), bitmap_size); + auto array_data = + arrow::ArrayData::Make(arrow::boolean(), batch.num_rows(), {validity, value}); + + // Execute the expression(s). + auto status = llvm_generator_->Execute(batch, {array_data}); + GANDIVA_RETURN_NOT_OK(status); + + // Compute the intersection of the value and validity.
+ auto result = bitmaps.GetLocalBitMap(2); + BitMapAccumulator::IntersectBitMaps( + result, {bitmaps.GetLocalBitMap(0), bitmaps.GetLocalBitMap(1)}, batch.num_rows()); + + // NOTE(review): the last argument looks like the highest valid bit index - confirm + // against SelectionVector::PopulateFromBitMap. + return out_selection->PopulateFromBitMap(result, bitmap_size, batch.num_rows() - 1); +} + +} // namespace gandiva diff --git a/cpp/src/gandiva/codegen/filter.h b/cpp/src/gandiva/codegen/filter.h new file mode 100644 index 00000000000..966d0a43750 --- /dev/null +++ b/cpp/src/gandiva/codegen/filter.h @@ -0,0 +1,82 @@ +/* + * Copyright (C) 2017-2018 Dremio Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef GANDIVA_EXPR_FILTER_H +#define GANDIVA_EXPR_FILTER_H + +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "gandiva/arrow.h" +#include "gandiva/condition.h" +#include "gandiva/configuration.h" +#include "gandiva/selection_vector.h" +#include "gandiva/status.h" + +namespace gandiva { + +class LLVMGenerator; + +/// \brief filter records based on a condition. +/// +/// A filter is built for a specific schema and condition. Once the filter is built, it +/// can be used to evaluate many row batches. +class Filter { + public: + Filter(std::unique_ptr<LLVMGenerator> llvm_generator, SchemaPtr schema, + std::shared_ptr<Configuration> config); + + ~Filter() = default; + + /// Build a filter for the given schema and condition, with the default configuration. + /// + /// \param[in] : schema schema for the record batches, and the condition. + /// \param[in] : condition filter condition.
+ /// \param[out]: filter the returned filter object + static Status Make(SchemaPtr schema, ConditionPtr condition, + std::shared_ptr<Filter> *filter) { + return Make(schema, condition, ConfigurationBuilder::DefaultConfiguration(), filter); + } + + /// \brief Build a filter for the given schema and condition. + /// Customize the filter with runtime configuration. + /// + /// \param[in] : schema schema for the record batches, and the condition. + /// \param[in] : condition filter condition. + /// \param[in] : config run time configuration. + /// \param[out]: filter the returned filter object + static Status Make(SchemaPtr schema, ConditionPtr condition, + std::shared_ptr<Configuration> config, + std::shared_ptr<Filter> *filter); + + /// Evaluate the specified record batch, and populate output selection vector. + /// + /// \param[in] : batch the record batch. schema should be the same as the one in 'Make' + /// \param[in/out]: out_selection the selection array with indices of rows that match + /// the condition. + Status Evaluate(const arrow::RecordBatch &batch, + std::shared_ptr<SelectionVector> out_selection); + + private: + const std::unique_ptr<LLVMGenerator> llvm_generator_; + const SchemaPtr schema_; + const std::shared_ptr<Configuration> configuration_; +}; + +} // namespace gandiva + +#endif // GANDIVA_EXPR_FILTER_H diff --git a/cpp/src/gandiva/codegen/filter_cache_key.h b/cpp/src/gandiva/codegen/filter_cache_key.h new file mode 100644 index 00000000000..c591c6823a4 --- /dev/null +++ b/cpp/src/gandiva/codegen/filter_cache_key.h @@ -0,0 +1,66 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GANDIVA_FILTER_CACHE_KEY_H +#define GANDIVA_FILTER_CACHE_KEY_H + +#include <memory> +#include <string> + +#include "boost/functional/hash.hpp" +#include "gandiva/arrow.h" +#include "gandiva/filter.h" + +namespace gandiva { +/// Cache key built from a schema, a configuration and the string form of an +/// expression. The hash is computed once, in the constructor. +class FilterCacheKey { + public: + FilterCacheKey(SchemaPtr schema, std::shared_ptr<Configuration> configuration, + Expression &expression) + : schema_(schema), configuration_(configuration) { + static const int kSeedValue = 4; + size_t result = kSeedValue; + expression_as_string_ = expression.ToString(); + boost::hash_combine(result, expression_as_string_); + // The configuration is hashed (and compared below) as a shared_ptr, not by value. + boost::hash_combine(result, configuration); + boost::hash_combine(result, schema_->ToString()); + hash_code_ = result; + } + + /// Hash value computed at construction. + std::size_t Hash() const { return hash_code_; } + + /// Keys are equal when the schemas are Equals(), the configuration pointers are + /// the same and the expression strings match. + bool operator==(const FilterCacheKey &other) const { + // arrow schema does not overload equality operators.
+ if (!(schema_->Equals(*other.schema().get(), true))) { + return false; + } + + if (configuration_ != other.configuration_) { + return false; + } + + if (expression_as_string_ != other.expression_as_string_) { + return false; + } + return true; + } + + bool operator!=(const FilterCacheKey &other) const { return !(*this == other); } + + SchemaPtr schema() const { return schema_; } + + private: + const SchemaPtr schema_; + const std::shared_ptr<Configuration> configuration_; + std::string expression_as_string_; + size_t hash_code_; +}; +} // namespace gandiva +#endif // GANDIVA_FILTER_CACHE_KEY_H diff --git a/cpp/src/gandiva/codegen/func_descriptor.h b/cpp/src/gandiva/codegen/func_descriptor.h new file mode 100644 index 00000000000..25aff174138 --- /dev/null +++ b/cpp/src/gandiva/codegen/func_descriptor.h @@ -0,0 +1,49 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GANDIVA_FUNCDESCRIPTOR_H +#define GANDIVA_FUNCDESCRIPTOR_H + +#include <string> +#include <vector> + +#include "gandiva/arrow.h" + +namespace gandiva { + +/// Descriptor for a function in the expression. +class FuncDescriptor { + public: + FuncDescriptor(const std::string &name, const DataTypeVector &params, + DataTypePtr return_type) + : name_(name), params_(params), return_type_(return_type) {} + + /// base function name. + const std::string &name() const { return name_; } + + /// Data types of the input params.
+ const DataTypeVector &params() const { return params_; } + + /// Data type of the return parameter. + DataTypePtr return_type() const { return return_type_; } + + private: + std::string name_; + DataTypeVector params_; + DataTypePtr return_type_; +}; + +} // namespace gandiva + +#endif // GANDIVA_FUNCDESCRIPTOR_H diff --git a/cpp/src/gandiva/codegen/function_holder.h b/cpp/src/gandiva/codegen/function_holder.h new file mode 100644 index 00000000000..d5f9c4ee425 --- /dev/null +++ b/cpp/src/gandiva/codegen/function_holder.h @@ -0,0 +1,30 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GANDIVA_FUNCTION_HOLDER_H +#define GANDIVA_FUNCTION_HOLDER_H + +#include <memory> + +namespace gandiva { + +/// Holder for a function that can be invoked from LLVM. +class FunctionHolder { + public: + virtual ~FunctionHolder() = default; +}; + +using FunctionHolderPtr = std::shared_ptr<FunctionHolder>; + +} // namespace gandiva + +#endif // GANDIVA_FUNCTION_HOLDER_H diff --git a/cpp/src/gandiva/codegen/function_holder_registry.h b/cpp/src/gandiva/codegen/function_holder_registry.h new file mode 100644 index 00000000000..876bfee3ecf --- /dev/null +++ b/cpp/src/gandiva/codegen/function_holder_registry.h @@ -0,0 +1,62 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GANDIVA_FUNCTION_HOLDER_REGISTRY_H +#define GANDIVA_FUNCTION_HOLDER_REGISTRY_H + +#include <functional> +#include <memory> +#include <string> +#include <unordered_map> + +#include "codegen/function_holder.h" +#include "codegen/like_holder.h" +#include "codegen/node.h" +#include "gandiva/status.h" + +namespace gandiva { + +// Expands to a lambda that calls derived::Make(node, ...) and, when that succeeds, +// stores the new instance in *holder. The lambda returns the status of Make(). +#define LAMBDA_MAKER(derived) \ + [](const FunctionNode &node, FunctionHolderPtr *holder) { \ + std::shared_ptr<derived> derived_instance; \ + auto status = derived::Make(node, &derived_instance); \ + if (status.ok()) { \ + *holder = derived_instance; \ + } \ + return status; \ + } + +/// Static registry of function holders. +class FunctionHolderRegistry { + public: + using maker_type = std::function<Status(const FunctionNode &, FunctionHolderPtr *)>; + using map_type = std::unordered_map<std::string, maker_type>; + + /// Looks up the maker registered under `name` and invokes it on `node`. + /// Returns Status::Invalid when no maker is registered for `name`. + static Status Make(const std::string &name, const FunctionNode &node, + FunctionHolderPtr *holder) { + map_type &registered = makers(); + auto found = registered.find(name); + if (found == registered.end()) { + return Status::Invalid("function holder not registered for function " + name); + } + + return found->second(node, holder); + } + + private: + /// Function-local static map from function name to its maker. + static map_type &makers() { + static map_type maker_map = { + {"like", LAMBDA_MAKER(LikeHolder)}, + }; + return maker_map; + } +}; + +} // namespace gandiva + +#endif // GANDIVA_FUNCTION_HOLDER_REGISTRY_H diff --git a/cpp/src/gandiva/codegen/function_holder_stubs.cc b/cpp/src/gandiva/codegen/function_holder_stubs.cc new file mode 100644 index 00000000000..45d8c4b5f55 --- /dev/null +++ b/cpp/src/gandiva/codegen/function_holder_stubs.cc @@ -0,0 +1,23 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "codegen/like_holder.h" + +// Wrapper C functions for "like" to be invoked from LLVM. +// +// `ptr` carries the address of a LikeHolder as an integer. `pattern` and +// `pattern_len` are part of the signature but are not read here; only the holder +// and the data bytes are used. +extern "C" bool like_utf8_utf8(int64_t ptr, const char *data, int data_len, + const char *pattern, int pattern_len) { + gandiva::helpers::LikeHolder *holder = + reinterpret_cast<gandiva::helpers::LikeHolder *>(ptr); + return (*holder)(std::string(data, data_len)); +} diff --git a/cpp/src/gandiva/codegen/function_registry.cc b/cpp/src/gandiva/codegen/function_registry.cc new file mode 100644 index 00000000000..d9cfcbc6897 --- /dev/null +++ b/cpp/src/gandiva/codegen/function_registry.cc @@ -0,0 +1,383 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#include "codegen/function_registry.h" + +#include <vector> + +namespace gandiva { + +using arrow::binary; +using arrow::boolean; +using arrow::date64; +using arrow::float32; +using arrow::float64; +using arrow::int16; +using arrow::int32; +using arrow::int64; +using arrow::int8; +using arrow::uint16; +using arrow::uint32; +using arrow::uint64; +using arrow::uint8; +using arrow::utf8; +using std::vector; + +#define STRINGIFY(a) #a + +// Binary functions that : +// - have the same input type for both params +// - output type is same as the input type +// - NULL handling is of type NULL_IF_NULL +// +// The pre-compiled fn name includes the base name & input type names. eg. add_int32_int32 +#define BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(NAME, TYPE) \ + NativeFunction(#NAME, DataTypeVector{TYPE(), TYPE()}, TYPE(), true, \ + RESULT_NULL_IF_NULL, STRINGIFY(NAME##_##TYPE##_##TYPE)) + +// Binary functions that : +// - have different input types, or output type +// - NULL handling is of type NULL_IF_NULL +// +// The pre-compiled fn name includes the base name & input type names. eg. mod_int64_int32 +#define BINARY_GENERIC_SAFE_NULL_IF_NULL(NAME, IN_TYPE1, IN_TYPE2, OUT_TYPE) \ + NativeFunction(#NAME, DataTypeVector{IN_TYPE1(), IN_TYPE2()}, OUT_TYPE(), true, \ + RESULT_NULL_IF_NULL, STRINGIFY(NAME##_##IN_TYPE1##_##IN_TYPE2)) + +// Binary functions that : +// - have the same input type +// - output type is boolean +// - NULL handling is of type NULL_IF_NULL +// +// The pre-compiled fn name includes the base name & input type names. +// eg. equal_int32_int32 +#define BINARY_RELATIONAL_SAFE_NULL_IF_NULL(NAME, TYPE) \ + NativeFunction(#NAME, DataTypeVector{TYPE(), TYPE()}, boolean(), true, \ + RESULT_NULL_IF_NULL, STRINGIFY(NAME##_##TYPE##_##TYPE)) + +// Unary functions that : +// - NULL handling is of type NULL_IF_NULL +// +// The pre-compiled fn name includes the base name & input type name. eg.
castFLOAT4_int32 +#define UNARY_SAFE_NULL_IF_NULL(NAME, IN_TYPE, OUT_TYPE) \ + NativeFunction(#NAME, DataTypeVector{IN_TYPE()}, OUT_TYPE(), true, \ + RESULT_NULL_IF_NULL, STRINGIFY(NAME##_##IN_TYPE)) + +// Unary functions that : +// - NULL handling is of type NULL_NEVER +// +// The pre-compiled fn name includes the base name & input type name. eg. isnull_int32 +#define UNARY_SAFE_NULL_NEVER_BOOL(NAME, TYPE) \ + NativeFunction(#NAME, DataTypeVector{TYPE()}, boolean(), true, RESULT_NULL_NEVER, \ + STRINGIFY(NAME##_##TYPE)) + +// Binary functions that : +// - NULL handling is of type NULL_NEVER +// +// The pre-compiled fn name includes the base name & input type names, +// eg. is_distinct_from_int32_int32 +#define BINARY_SAFE_NULL_NEVER_BOOL(NAME, TYPE) \ + NativeFunction(#NAME, DataTypeVector{TYPE(), TYPE()}, boolean(), true, \ + RESULT_NULL_NEVER, STRINGIFY(NAME##_##TYPE##_##TYPE)) + +// Extract functions (used with date/time types) that : +// - NULL handling is of type NULL_IF_NULL +// +// The pre-compiled fn name includes the base name & input type name. eg. extractYear_date64 +#define EXTRACT_SAFE_NULL_IF_NULL(NAME, TYPE) \ + NativeFunction(#NAME, DataTypeVector{TYPE()}, int64(), true, RESULT_NULL_IF_NULL, \ + STRINGIFY(NAME##_##TYPE)) + +// Hash32 functions that : +// - NULL handling is of type NULL_NEVER +// +// The pre-compiled fn name includes the base name & input type name. eg. hash32_int8 +#define HASH32_SAFE_NULL_NEVER(NAME, TYPE) \ + NativeFunction(#NAME, DataTypeVector{TYPE()}, int32(), true, RESULT_NULL_NEVER, \ + STRINGIFY(NAME##_##TYPE)) + +// Hash64 functions that : +// - NULL handling is of type NULL_NEVER +// +// The pre-compiled fn name includes the base name & input type name. eg.
hash64_int8 +#define HASH64_SAFE_NULL_NEVER(NAME, TYPE) \ + NativeFunction(#NAME, DataTypeVector{TYPE()}, int64(), true, RESULT_NULL_NEVER, \ + STRINGIFY(NAME##_##TYPE)) + +// Hash32 functions with seed that : +// - NULL handling is of type NULL_NEVER +// +// The pre-compiled fn name includes the base name & input type name. eg. hash32WithSeed_int8 +#define HASH32_SEED_SAFE_NULL_NEVER(NAME, TYPE) \ + NativeFunction(#NAME, DataTypeVector{TYPE(), int32()}, int32(), true, \ + RESULT_NULL_NEVER, STRINGIFY(NAME##WithSeed_##TYPE)) + +// Hash64 functions with seed that : +// - NULL handling is of type NULL_NEVER +// +// The pre-compiled fn name includes the base name & input type name. eg. hash64WithSeed_int8 +#define HASH64_SEED_SAFE_NULL_NEVER(NAME, TYPE) \ + NativeFunction(#NAME, DataTypeVector{TYPE(), int64()}, int64(), true, \ + RESULT_NULL_NEVER, STRINGIFY(NAME##WithSeed_##TYPE)) + +// Iterate the inner macro over all numeric types +#define NUMERIC_TYPES(INNER, NAME) \ + INNER(NAME, int8), INNER(NAME, int16), INNER(NAME, int32), INNER(NAME, int64), \ + INNER(NAME, uint8), INNER(NAME, uint16), INNER(NAME, uint32), INNER(NAME, uint64), \ + INNER(NAME, float32), INNER(NAME, float64) + +// Iterate the inner macro over numeric and date/time types +#define NUMERIC_DATE_TYPES(INNER, NAME) \ + NUMERIC_TYPES(INNER, NAME), DATE_TYPES(INNER, NAME), TIME_TYPES(INNER, NAME) + +// Iterate the inner macro over all date types +#define DATE_TYPES(INNER, NAME) INNER(NAME, date64), INNER(NAME, timestamp) + +// Iterate the inner macro over all time types +#define TIME_TYPES(INNER, NAME) INNER(NAME, time32) + +// Iterate the inner macro over all variable-length types +#define VAR_LEN_TYPES(INNER, NAME) INNER(NAME, utf8), INNER(NAME, binary) + +// Iterate the inner macro over all numeric types, date types and bool type +#define NUMERIC_BOOL_DATE_TYPES(INNER, NAME) \ + NUMERIC_DATE_TYPES(INNER, NAME), INNER(NAME, boolean) + +// Iterate the inner macro over all numeric types, date types, bool and varlen
types +#define NUMERIC_BOOL_DATE_VAR_LEN_TYPES(INNER, NAME) \ + NUMERIC_BOOL_DATE_TYPES(INNER, NAME), VAR_LEN_TYPES(INNER, NAME) + +// list of registered native functions. +NativeFunction FunctionRegistry::pc_registry_[] = { + // Arithmetic operations + NUMERIC_TYPES(BINARY_SYMMETRIC_SAFE_NULL_IF_NULL, add), + NUMERIC_TYPES(BINARY_SYMMETRIC_SAFE_NULL_IF_NULL, subtract), + NUMERIC_TYPES(BINARY_SYMMETRIC_SAFE_NULL_IF_NULL, multiply), + NUMERIC_TYPES(BINARY_SYMMETRIC_SAFE_NULL_IF_NULL, divide), + BINARY_GENERIC_SAFE_NULL_IF_NULL(mod, int64, int32, int32), + BINARY_GENERIC_SAFE_NULL_IF_NULL(mod, int64, int64, int64), + NUMERIC_BOOL_DATE_TYPES(BINARY_RELATIONAL_SAFE_NULL_IF_NULL, equal), + NUMERIC_BOOL_DATE_TYPES(BINARY_RELATIONAL_SAFE_NULL_IF_NULL, not_equal), + NUMERIC_DATE_TYPES(BINARY_RELATIONAL_SAFE_NULL_IF_NULL, less_than), + NUMERIC_DATE_TYPES(BINARY_RELATIONAL_SAFE_NULL_IF_NULL, less_than_or_equal_to), + NUMERIC_DATE_TYPES(BINARY_RELATIONAL_SAFE_NULL_IF_NULL, greater_than), + NUMERIC_DATE_TYPES(BINARY_RELATIONAL_SAFE_NULL_IF_NULL, greater_than_or_equal_to), + + // cast operations + UNARY_SAFE_NULL_IF_NULL(castBIGINT, int32, int64), + UNARY_SAFE_NULL_IF_NULL(castFLOAT4, int32, float32), + UNARY_SAFE_NULL_IF_NULL(castFLOAT4, int64, float32), + UNARY_SAFE_NULL_IF_NULL(castFLOAT8, int32, float64), + UNARY_SAFE_NULL_IF_NULL(castFLOAT8, int64, float64), + UNARY_SAFE_NULL_IF_NULL(castFLOAT8, float32, float64), + + // nullable never operations + NUMERIC_BOOL_DATE_TYPES(UNARY_SAFE_NULL_NEVER_BOOL, isnull), + NUMERIC_BOOL_DATE_TYPES(UNARY_SAFE_NULL_NEVER_BOOL, isnotnull), + NUMERIC_TYPES(UNARY_SAFE_NULL_NEVER_BOOL, isnumeric), + + // nullable never binary operations + NUMERIC_BOOL_DATE_TYPES(BINARY_SAFE_NULL_NEVER_BOOL, is_distinct_from), + NUMERIC_BOOL_DATE_TYPES(BINARY_SAFE_NULL_NEVER_BOOL, is_not_distinct_from), + + // date/timestamp operations + DATE_TYPES(EXTRACT_SAFE_NULL_IF_NULL, extractMillennium), + DATE_TYPES(EXTRACT_SAFE_NULL_IF_NULL, extractCentury), +
DATE_TYPES(EXTRACT_SAFE_NULL_IF_NULL, extractDecade), + DATE_TYPES(EXTRACT_SAFE_NULL_IF_NULL, extractYear), + DATE_TYPES(EXTRACT_SAFE_NULL_IF_NULL, extractDoy), + DATE_TYPES(EXTRACT_SAFE_NULL_IF_NULL, extractQuarter), + DATE_TYPES(EXTRACT_SAFE_NULL_IF_NULL, extractMonth), + DATE_TYPES(EXTRACT_SAFE_NULL_IF_NULL, extractWeek), + DATE_TYPES(EXTRACT_SAFE_NULL_IF_NULL, extractDow), + DATE_TYPES(EXTRACT_SAFE_NULL_IF_NULL, extractDay), + DATE_TYPES(EXTRACT_SAFE_NULL_IF_NULL, extractHour), + DATE_TYPES(EXTRACT_SAFE_NULL_IF_NULL, extractMinute), + DATE_TYPES(EXTRACT_SAFE_NULL_IF_NULL, extractSecond), + DATE_TYPES(EXTRACT_SAFE_NULL_IF_NULL, extractEpoch), + + // date_trunc operations on date/timestamp + DATE_TYPES(EXTRACT_SAFE_NULL_IF_NULL, date_trunc_Millennium), + DATE_TYPES(EXTRACT_SAFE_NULL_IF_NULL, date_trunc_Century), + DATE_TYPES(EXTRACT_SAFE_NULL_IF_NULL, date_trunc_Decade), + DATE_TYPES(EXTRACT_SAFE_NULL_IF_NULL, date_trunc_Year), + DATE_TYPES(EXTRACT_SAFE_NULL_IF_NULL, date_trunc_Quarter), + DATE_TYPES(EXTRACT_SAFE_NULL_IF_NULL, date_trunc_Month), + DATE_TYPES(EXTRACT_SAFE_NULL_IF_NULL, date_trunc_Week), + DATE_TYPES(EXTRACT_SAFE_NULL_IF_NULL, date_trunc_Day), + DATE_TYPES(EXTRACT_SAFE_NULL_IF_NULL, date_trunc_Hour), + DATE_TYPES(EXTRACT_SAFE_NULL_IF_NULL, date_trunc_Minute), + DATE_TYPES(EXTRACT_SAFE_NULL_IF_NULL, date_trunc_Second), + + // time operations + TIME_TYPES(EXTRACT_SAFE_NULL_IF_NULL, extractHour), + TIME_TYPES(EXTRACT_SAFE_NULL_IF_NULL, extractMinute), + TIME_TYPES(EXTRACT_SAFE_NULL_IF_NULL, extractSecond), + + // timestamp diff operations + BINARY_GENERIC_SAFE_NULL_IF_NULL(timestampdiffSecond, timestamp, timestamp, int32), + BINARY_GENERIC_SAFE_NULL_IF_NULL(timestampdiffMinute, timestamp, timestamp, int32), + BINARY_GENERIC_SAFE_NULL_IF_NULL(timestampdiffHour, timestamp, timestamp, int32), + BINARY_GENERIC_SAFE_NULL_IF_NULL(timestampdiffDay, timestamp, timestamp, int32), + BINARY_GENERIC_SAFE_NULL_IF_NULL(timestampdiffWeek, timestamp, timestamp, 
int32), + BINARY_GENERIC_SAFE_NULL_IF_NULL(timestampdiffMonth, timestamp, timestamp, int32), + BINARY_GENERIC_SAFE_NULL_IF_NULL(timestampdiffQuarter, timestamp, timestamp, int32), + BINARY_GENERIC_SAFE_NULL_IF_NULL(timestampdiffYear, timestamp, timestamp, int32), + + // timestamp add int32 operations + BINARY_GENERIC_SAFE_NULL_IF_NULL(timestampaddSecond, timestamp, int32, timestamp), + BINARY_GENERIC_SAFE_NULL_IF_NULL(timestampaddMinute, timestamp, int32, timestamp), + BINARY_GENERIC_SAFE_NULL_IF_NULL(timestampaddHour, timestamp, int32, timestamp), + BINARY_GENERIC_SAFE_NULL_IF_NULL(timestampaddDay, timestamp, int32, timestamp), + BINARY_GENERIC_SAFE_NULL_IF_NULL(timestampaddWeek, timestamp, int32, timestamp), + BINARY_GENERIC_SAFE_NULL_IF_NULL(timestampaddMonth, timestamp, int32, timestamp), + BINARY_GENERIC_SAFE_NULL_IF_NULL(timestampaddQuarter, timestamp, int32, timestamp), + BINARY_GENERIC_SAFE_NULL_IF_NULL(timestampaddYear, timestamp, int32, timestamp), + // date add int32 operations + BINARY_GENERIC_SAFE_NULL_IF_NULL(timestampaddSecond, date64, int32, date64), + BINARY_GENERIC_SAFE_NULL_IF_NULL(timestampaddMinute, date64, int32, date64), + BINARY_GENERIC_SAFE_NULL_IF_NULL(timestampaddHour, date64, int32, date64), + BINARY_GENERIC_SAFE_NULL_IF_NULL(timestampaddDay, date64, int32, date64), + BINARY_GENERIC_SAFE_NULL_IF_NULL(timestampaddWeek, date64, int32, date64), + BINARY_GENERIC_SAFE_NULL_IF_NULL(timestampaddMonth, date64, int32, date64), + BINARY_GENERIC_SAFE_NULL_IF_NULL(timestampaddQuarter, date64, int32, date64), + BINARY_GENERIC_SAFE_NULL_IF_NULL(timestampaddYear, date64, int32, date64), + + // timestamp add int64 operations + BINARY_GENERIC_SAFE_NULL_IF_NULL(timestampaddSecond, timestamp, int64, timestamp), + BINARY_GENERIC_SAFE_NULL_IF_NULL(timestampaddMinute, timestamp, int64, timestamp), + BINARY_GENERIC_SAFE_NULL_IF_NULL(timestampaddHour, timestamp, int64, timestamp), + BINARY_GENERIC_SAFE_NULL_IF_NULL(timestampaddDay, timestamp, int64, timestamp), 
+ BINARY_GENERIC_SAFE_NULL_IF_NULL(timestampaddWeek, timestamp, int64, timestamp), + BINARY_GENERIC_SAFE_NULL_IF_NULL(timestampaddMonth, timestamp, int64, timestamp), + BINARY_GENERIC_SAFE_NULL_IF_NULL(timestampaddQuarter, timestamp, int64, timestamp), + BINARY_GENERIC_SAFE_NULL_IF_NULL(timestampaddYear, timestamp, int64, timestamp), + // date add int64 operations + BINARY_GENERIC_SAFE_NULL_IF_NULL(timestampaddSecond, date64, int64, date64), + BINARY_GENERIC_SAFE_NULL_IF_NULL(timestampaddMinute, date64, int64, date64), + BINARY_GENERIC_SAFE_NULL_IF_NULL(timestampaddHour, date64, int64, date64), + BINARY_GENERIC_SAFE_NULL_IF_NULL(timestampaddDay, date64, int64, date64), + BINARY_GENERIC_SAFE_NULL_IF_NULL(timestampaddWeek, date64, int64, date64), + BINARY_GENERIC_SAFE_NULL_IF_NULL(timestampaddMonth, date64, int64, date64), + BINARY_GENERIC_SAFE_NULL_IF_NULL(timestampaddQuarter, date64, int64, date64), + BINARY_GENERIC_SAFE_NULL_IF_NULL(timestampaddYear, date64, int64, date64), + + // date_add(date64, int32), date_add(timestamp, int32) + BINARY_GENERIC_SAFE_NULL_IF_NULL(date_add, date64, int32, date64), + BINARY_GENERIC_SAFE_NULL_IF_NULL(add, date64, int32, date64), + BINARY_GENERIC_SAFE_NULL_IF_NULL(date_add, timestamp, int32, timestamp), + BINARY_GENERIC_SAFE_NULL_IF_NULL(add, timestamp, int32, timestamp), + + // date_add(date64, int64), date_add(timestamp, int64) + BINARY_GENERIC_SAFE_NULL_IF_NULL(date_add, date64, int64, date64), + BINARY_GENERIC_SAFE_NULL_IF_NULL(add, date64, int64, date64), + BINARY_GENERIC_SAFE_NULL_IF_NULL(date_add, timestamp, int64, timestamp), + BINARY_GENERIC_SAFE_NULL_IF_NULL(add, timestamp, int64, timestamp), + + // date_add(int32, date64), date_add(int32, timestamp) + BINARY_GENERIC_SAFE_NULL_IF_NULL(date_add, int32, date64, date64), + BINARY_GENERIC_SAFE_NULL_IF_NULL(add, int32, date64, date64), + BINARY_GENERIC_SAFE_NULL_IF_NULL(date_add, int32, timestamp, timestamp), + BINARY_GENERIC_SAFE_NULL_IF_NULL(add, int32, timestamp, 
timestamp), + + // date_add(int64, date64), date_add(int64, timestamp) + BINARY_GENERIC_SAFE_NULL_IF_NULL(date_add, int64, date64, date64), + BINARY_GENERIC_SAFE_NULL_IF_NULL(add, int64, date64, date64), + BINARY_GENERIC_SAFE_NULL_IF_NULL(date_add, int64, timestamp, timestamp), + BINARY_GENERIC_SAFE_NULL_IF_NULL(add, int64, timestamp, timestamp), + + // date_sub(date64, int32), subtract and date_diff + BINARY_GENERIC_SAFE_NULL_IF_NULL(date_sub, date64, int32, date64), + BINARY_GENERIC_SAFE_NULL_IF_NULL(subtract, date64, int32, date64), + BINARY_GENERIC_SAFE_NULL_IF_NULL(date_diff, date64, int32, date64), + // date_sub(timestamp, int32), subtract and date_diff + BINARY_GENERIC_SAFE_NULL_IF_NULL(date_sub, timestamp, int32, date64), + BINARY_GENERIC_SAFE_NULL_IF_NULL(subtract, timestamp, int32, date64), + BINARY_GENERIC_SAFE_NULL_IF_NULL(date_diff, timestamp, int32, date64), + + // date_sub(date64, int64), subtract and date_diff + BINARY_GENERIC_SAFE_NULL_IF_NULL(date_sub, date64, int64, date64), + BINARY_GENERIC_SAFE_NULL_IF_NULL(subtract, date64, int64, date64), + BINARY_GENERIC_SAFE_NULL_IF_NULL(date_diff, date64, int64, date64), + // date_sub(timestamp, int64), subtract and date_diff + BINARY_GENERIC_SAFE_NULL_IF_NULL(date_sub, timestamp, int64, date64), + BINARY_GENERIC_SAFE_NULL_IF_NULL(subtract, timestamp, int64, date64), + BINARY_GENERIC_SAFE_NULL_IF_NULL(date_diff, timestamp, int64, date64), + + // hash functions + NUMERIC_BOOL_DATE_VAR_LEN_TYPES(HASH32_SAFE_NULL_NEVER, hash), + NUMERIC_BOOL_DATE_VAR_LEN_TYPES(HASH32_SAFE_NULL_NEVER, hash32), + NUMERIC_BOOL_DATE_VAR_LEN_TYPES(HASH32_SAFE_NULL_NEVER, hash32AsDouble), + NUMERIC_BOOL_DATE_VAR_LEN_TYPES(HASH32_SEED_SAFE_NULL_NEVER, hash32), + NUMERIC_BOOL_DATE_VAR_LEN_TYPES(HASH32_SEED_SAFE_NULL_NEVER, hash32AsDouble), + + NUMERIC_BOOL_DATE_VAR_LEN_TYPES(HASH64_SAFE_NULL_NEVER, hash64), + NUMERIC_BOOL_DATE_VAR_LEN_TYPES(HASH64_SAFE_NULL_NEVER, hash64AsDouble), + 
NUMERIC_BOOL_DATE_VAR_LEN_TYPES(HASH64_SEED_SAFE_NULL_NEVER, hash64), + NUMERIC_BOOL_DATE_VAR_LEN_TYPES(HASH64_SEED_SAFE_NULL_NEVER, hash64AsDouble), + + // utf8/binary operations + UNARY_SAFE_NULL_IF_NULL(octet_length, utf8, int32), + UNARY_SAFE_NULL_IF_NULL(octet_length, binary, int32), + UNARY_SAFE_NULL_IF_NULL(bit_length, utf8, int32), + UNARY_SAFE_NULL_IF_NULL(bit_length, binary, int32), + VAR_LEN_TYPES(BINARY_RELATIONAL_SAFE_NULL_IF_NULL, equal), + VAR_LEN_TYPES(BINARY_RELATIONAL_SAFE_NULL_IF_NULL, not_equal), + VAR_LEN_TYPES(BINARY_RELATIONAL_SAFE_NULL_IF_NULL, less_than), + VAR_LEN_TYPES(BINARY_RELATIONAL_SAFE_NULL_IF_NULL, less_than_or_equal_to), + VAR_LEN_TYPES(BINARY_RELATIONAL_SAFE_NULL_IF_NULL, greater_than), + VAR_LEN_TYPES(BINARY_RELATIONAL_SAFE_NULL_IF_NULL, greater_than_or_equal_to), + + NativeFunction("like", DataTypeVector{utf8(), utf8()}, boolean(), true /*null_safe*/, + RESULT_NULL_IF_NULL, "like_utf8_utf8", true /*needs_holder*/), + + // Null internal (sample) + NativeFunction("half_or_null", DataTypeVector{int32()}, int32(), true /*null_safe*/, + RESULT_NULL_INTERNAL, "half_or_null_int32"), +}; // namespace gandiva + +FunctionRegistry::iterator FunctionRegistry::begin() const { + return std::begin(pc_registry_); +} + +FunctionRegistry::iterator FunctionRegistry::end() const { + return std::end(pc_registry_); +} + +FunctionRegistry::SignatureMap FunctionRegistry::pc_registry_map_ = InitPCMap(); + +FunctionRegistry::SignatureMap FunctionRegistry::InitPCMap() { + SignatureMap map; + + int num_entries = sizeof(pc_registry_) / sizeof(NativeFunction); + printf("Registry has %d pre-compiled functions\n", num_entries); + + for (int i = 0; i < num_entries; i++) { + const NativeFunction *entry = &pc_registry_[i]; + + DCHECK(map.find(&entry->signature()) == map.end()); + map[&entry->signature()] = entry; + // printf("%s -> %s\n", entry->signature().ToString().c_str(), + // entry->pc_name().c_str()); + } + return map; +} + +const NativeFunction 
*FunctionRegistry::LookupSignature( + const FunctionSignature &signature) const { + auto got = pc_registry_map_.find(&signature); + return got == pc_registry_map_.end() ? NULL : got->second; +} + +} // namespace gandiva diff --git a/cpp/src/gandiva/codegen/function_registry.h b/cpp/src/gandiva/codegen/function_registry.h new file mode 100644 index 00000000000..27f749da018 --- /dev/null +++ b/cpp/src/gandiva/codegen/function_registry.h @@ -0,0 +1,64 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GANDIVA_FUNCTION_REGISTRY_H +#define GANDIVA_FUNCTION_REGISTRY_H + +#include + +#include "codegen/native_function.h" +#include "gandiva/gandiva_aliases.h" + +namespace gandiva { + +///\brief Registry of pre-compiled IR functions. +class FunctionRegistry { + public: + using iterator = const NativeFunction *; + + /// Lookup a pre-compiled function by its signature. 
+ const NativeFunction *LookupSignature(const FunctionSignature &signature) const; + + iterator begin() const; + iterator end() const; + + private: + struct KeyHash { + std::size_t operator()(const FunctionSignature *k) const { return k->Hash(); } + }; + + struct KeyEquals { + bool operator()(const FunctionSignature *s1, const FunctionSignature *s2) const { + return *s1 == *s2; + } + }; + + static DataTypePtr time32() { return arrow::time32(arrow::TimeUnit::MILLI); } + + static DataTypePtr time64() { return arrow::time64(arrow::TimeUnit::MICRO); } + + static DataTypePtr timestamp() { return arrow::timestamp(arrow::TimeUnit::MILLI); } + + typedef std::unordered_map + SignatureMap; + static SignatureMap InitPCMap(); + + static NativeFunction pc_registry_[]; + static SignatureMap pc_registry_map_; +}; + +} // namespace gandiva + +#endif // GANDIVA_FUNCTION_REGISTRY_H diff --git a/cpp/src/gandiva/codegen/function_registry_test.cc b/cpp/src/gandiva/codegen/function_registry_test.cc new file mode 100644 index 00000000000..7d3a7230a0f --- /dev/null +++ b/cpp/src/gandiva/codegen/function_registry_test.cc @@ -0,0 +1,50 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "codegen/function_registry.h"
+
+#include
+
+namespace gandiva {
+
+/// Test fixture that owns the FunctionRegistry instance queried by each test.
+class TestFunctionRegistry : public ::testing::Test {
+ protected:
+  FunctionRegistry registry_;
+};
+
+// Looking up "add(int32, int32) -> int32" is expected to return an entry with
+// the same signature and the pre-compiled name "add_int32_int32".
+TEST_F(TestFunctionRegistry, TestFound) {
+  FunctionSignature add_i32_i32("add", {arrow::int32(), arrow::int32()}, arrow::int32());
+
+  const NativeFunction *function = registry_.LookupSignature(add_i32_i32);
+  // Fatal on failure: the expectations below dereference the pointer.
+  ASSERT_NE(function, nullptr);
+  EXPECT_EQ(function->signature(), add_i32_i32);
+  EXPECT_EQ(function->pc_name(), "add_int32_int32");
+}
+
+// A lookup is expected to miss when either the base name or the return type
+// differs from the queried signature.
+TEST_F(TestFunctionRegistry, TestNotFound) {
+  FunctionSignature addX_i32_i32("addX", {arrow::int32(), arrow::int32()},
+                                 arrow::int32());
+  EXPECT_EQ(registry_.LookupSignature(addX_i32_i32), nullptr);
+
+  FunctionSignature add_i32_i32_ret64("add", {arrow::int32(), arrow::int32()},
+                                      arrow::int64());
+  EXPECT_EQ(registry_.LookupSignature(add_i32_i32_ret64), nullptr);
+}
+
+}  // namespace gandiva
+
+// Defined at global scope: a main() nested inside a namespace is an ordinary
+// function and is never used as the program entry point.
+int main(int argc, char **argv) {
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/cpp/src/gandiva/codegen/function_signature.cc b/cpp/src/gandiva/codegen/function_signature.cc
new file mode 100644
index 00000000000..e6f92b8b71c
--- /dev/null
+++ b/cpp/src/gandiva/codegen/function_signature.cc
@@ -0,0 +1,63 @@
+// Copyright (C) 2017-2018 Dremio Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include
+
+#include "boost/functional/hash.hpp"
+
+namespace gandiva {
+
+/// Two signatures are equal when the base name, the return type and every
+/// parameter type (compared position by position) are equal.
+bool FunctionSignature::operator==(const FunctionSignature &other) const {
+  // Cheap rejections first: arity, return type, name.
+  if (param_types_.size() != other.param_types_.size() ||
+      !DataTypeEquals(ret_type_, other.ret_type_) || base_name_ != other.base_name_) {
+    return false;
+  }
+
+  for (size_t idx = 0; idx < param_types_.size(); idx++) {
+    if (!DataTypeEquals(param_types_[idx], other.param_types_[idx])) {
+      return false;
+    }
+  }
+  return true;
+}
+
+/// Calculated based on base_name, datatype id of parameters and datatype id
+/// of return type.
+std::size_t FunctionSignature::Hash() const {
+  static const size_t kSeedValue = 17;
+  size_t result = kSeedValue;
+  boost::hash_combine(result, base_name_);
+  boost::hash_combine(result, ret_type_->id());
+  // not using hash_range since we only want to include the id from the data type
+  for (const auto &param_type : param_types_) {
+    boost::hash_combine(result, param_type->id());
+  }
+  return result;
+}
+
+/// Render the signature as "<return type> <base name>(<param>, <param>, ...)".
+std::string FunctionSignature::ToString() const {
+  std::stringstream s;
+
+  s << ret_type_->ToString() << " " << base_name_ << "(";
+  // size_t index matches the type returned by size().
+  for (size_t i = 0; i < param_types_.size(); i++) {
+    if (i > 0) {
+      s << ", ";
+    }
+
+    s << param_types_[i]->ToString();
+  }
+
+  s << ")";
+  return s.str();
+}
+}  // namespace gandiva
diff --git a/cpp/src/gandiva/codegen/function_signature.h b/cpp/src/gandiva/codegen/function_signature.h
new file mode 100644
index 00000000000..76c9888128b
--- /dev/null
+++ b/cpp/src/gandiva/codegen/function_signature.h
@@ -0,0 +1,70 @@
+// Copyright (C) 2017-2018 Dremio Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GANDIVA_FUNCTION_SIGNATURE_H
+#define GANDIVA_FUNCTION_SIGNATURE_H
+
+#include
+#include
+#include
+
+#include "gandiva/arrow.h"
+#include "gandiva/logging.h"
+
+namespace gandiva {
+
+/// \brief Signature for a function : includes the base name, input param types and
+/// output types.
+class FunctionSignature {
+ public:
+  /// \param base_name   name of the function; must be non-empty.
+  /// \param param_types parameter types, in order; every entry must be non-null.
+  /// \param ret_type    return type; must be non-null.
+  ///
+  /// The preconditions are checked with DCHECK only.
+  FunctionSignature(const std::string &base_name, const DataTypeVector &param_types,
+                    DataTypePtr ret_type)
+      : base_name_(base_name), param_types_(param_types), ret_type_(ret_type) {
+    DCHECK_GT(base_name.length(), 0);
+    for (auto it = param_types_.begin(); it != param_types_.end(); it++) {
+      DCHECK(*it);
+    }
+    DCHECK(ret_type);
+  }
+
+  bool operator==(const FunctionSignature &other) const;
+
+  /// Calculated based on base_name, datatype id of parameters and datatype id
+  /// of return type.
+  std::size_t Hash() const;
+
+  DataTypePtr ret_type() const { return ret_type_; }
+
+  const std::string &base_name() const { return base_name_; }
+
+  DataTypeVector param_types() const { return param_types_; }
+
+  /// Human-readable form of the signature.
+  std::string ToString() const;
+
+ private:
+  // TODO : for some of the types, this shouldn't match type specific data. eg. for
+  // decimals, this shouldn't match precision/scale.
+  // Takes the pointers by const reference so no shared_ptr is copied per call.
+  bool DataTypeEquals(const DataTypePtr &left, const DataTypePtr &right) const {
+    return left->Equals(right);
+  }
+
+  std::string base_name_;
+  DataTypeVector param_types_;
+  DataTypePtr ret_type_;
+};
+
+}  // namespace gandiva
+
+#endif  // GANDIVA_FUNCTION_SIGNATURE_H
diff --git a/cpp/src/gandiva/codegen/function_signature_test.cc b/cpp/src/gandiva/codegen/function_signature_test.cc
new file mode 100644
index 00000000000..3316a2d9273
--- /dev/null
+++ b/cpp/src/gandiva/codegen/function_signature_test.cc
@@ -0,0 +1,102 @@
+// Copyright (C) 2017-2018 Dremio Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "gandiva/function_signature.h"
+
+#include
+
+#include
+
+namespace gandiva {
+
+/// Fixture that builds data type objects locally, so the tests can check that
+/// signatures compare and hash equal to ones built from the arrow factory
+/// functions (arrow::int32() etc.).
+class TestFunctionSignature : public ::testing::Test {
+ protected:
+  void SetUp() override {
+    local_i32_type_ = std::make_shared<arrow::Int32Type>();
+    local_i64_type_ = std::make_shared<arrow::Int64Type>();
+    local_date32_type_ = std::make_shared<arrow::Date32Type>();
+  }
+
+  void TearDown() override {
+    local_i32_type_.reset();
+    local_i64_type_.reset();
+    local_date32_type_.reset();
+  }
+
+  DataTypePtr local_i32_type_;
+  DataTypePtr local_i64_type_;
+  DataTypePtr local_date32_type_;
+};
+
+// ToString() output format: "<return type> <name>(<params>)".
+TEST_F(TestFunctionSignature, TestToString) {
+  EXPECT_EQ(
+      FunctionSignature("myfunc", {arrow::int32(), arrow::float32()}, arrow::float64())
+          .ToString(),
+      "double myfunc(int32, float)");
+}
+
+// Equality depends on the base name; locally built types compare equal to the
+// factory-built ones.
+TEST_F(TestFunctionSignature, TestEqualsName) {
+  EXPECT_EQ(FunctionSignature("add", {arrow::int32()}, arrow::int32()),
+            FunctionSignature("add", {arrow::int32()}, arrow::int32()));
+
+  EXPECT_EQ(FunctionSignature("add", {arrow::int32()}, arrow::int64()),
+            FunctionSignature("add", {local_i32_type_}, local_i64_type_));
+
+  EXPECT_FALSE(FunctionSignature("add", {arrow::int32()}, arrow::int32()) ==
+               FunctionSignature("sub", {arrow::int32()}, arrow::int32()));
+}
+
+// Signatures with a different number of parameters are not equal.
+TEST_F(TestFunctionSignature, TestEqualsParamCount) {
+  EXPECT_FALSE(
+      FunctionSignature("add", {arrow::int32(), arrow::int32()}, arrow::int32()) ==
+      FunctionSignature("add", {arrow::int32()}, arrow::int32()));
+}
+
+// Parameter types are compared position by position.
+TEST_F(TestFunctionSignature, TestEqualsParamValue) {
+  EXPECT_FALSE(FunctionSignature("add", {arrow::int32()}, arrow::int32()) ==
+               FunctionSignature("add", {arrow::int64()}, arrow::int32()));
+
+  EXPECT_FALSE(
+      FunctionSignature("add", {arrow::int32()}, arrow::int32()) ==
+      FunctionSignature("add", {arrow::float32(), arrow::float32()}, arrow::int32()));
+
+  EXPECT_FALSE(
+      FunctionSignature("add", {arrow::int32(), arrow::int64()}, arrow::int32()) ==
+      FunctionSignature("add", {arrow::int64(), arrow::int32()}, arrow::int32()));
+
+  EXPECT_EQ(FunctionSignature("extract_month", {arrow::date32()}, arrow::int64()),
+            FunctionSignature("extract_month", {local_date32_type_}, local_i64_type_));
+
+  EXPECT_FALSE(FunctionSignature("extract_month", {arrow::date32()}, arrow::int64()) ==
+               FunctionSignature("extract_month", {arrow::date64()}, arrow::date32()));
+}
+
+// Signatures that differ only in the return type are not equal.
+TEST_F(TestFunctionSignature, TestEqualsReturn) {
+  EXPECT_FALSE(FunctionSignature("add", {arrow::int32()}, arrow::int64()) ==
+               FunctionSignature("add", {arrow::int32()}, arrow::int32()));
+}
+
+// Equal signatures built from distinct type objects must hash identically.
+TEST_F(TestFunctionSignature, TestHash) {
+  FunctionSignature f1("add", {arrow::int32(), arrow::int32()}, arrow::int64());
+  FunctionSignature f2("add", {local_i32_type_, local_i32_type_}, local_i64_type_);
+  EXPECT_EQ(f1.Hash(), f2.Hash());
+}
+
+}  // namespace gandiva
+
+// Defined at global scope: a main() nested inside a namespace is an ordinary
+// function and is never used as the program entry point.
+int main(int argc, char **argv) {
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/cpp/src/gandiva/codegen/gandiva_aliases.h b/cpp/src/gandiva/codegen/gandiva_aliases.h
new file mode 100644
index 00000000000..696cfefe4a8
--- /dev/null
+++ b/cpp/src/gandiva/codegen/gandiva_aliases.h
@@ -0,0 +1,61 @@
+/*
+ * Copyright (C) 2017-2018 Dremio Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef GANDIVA_ALIASES_H
+#define GANDIVA_ALIASES_H
+
+#include
+#include
+
+// Forward declarations of the core codegen classes together with the
+// shared-pointer / vector aliases used for them across the code base.
+namespace gandiva {
+
+class Dex;
+using DexPtr = std::shared_ptr<Dex>;
+using DexVector = std::vector<std::shared_ptr<Dex>>;
+
+class ValueValidityPair;
+using ValueValidityPairPtr = std::shared_ptr<ValueValidityPair>;
+using ValueValidityPairVector = std::vector<ValueValidityPairPtr>;
+
+class FieldDescriptor;
+using FieldDescriptorPtr = std::shared_ptr<FieldDescriptor>;
+
+class FuncDescriptor;
+using FuncDescriptorPtr = std::shared_ptr<FuncDescriptor>;
+
+class LValue;
+using LValuePtr = std::shared_ptr<LValue>;
+
+class Expression;
+using ExpressionPtr = std::shared_ptr<Expression>;
+using ExpressionVector = std::vector<ExpressionPtr>;
+
+class Condition;
+using ConditionPtr = std::shared_ptr<Condition>;
+
+class Node;
+using NodePtr = std::shared_ptr<Node>;
+using NodeVector = std::vector<std::shared_ptr<Node>>;
+
+class EvalBatch;
+using EvalBatchPtr = std::shared_ptr<EvalBatch>;
+
+class FunctionSignature;
+using FuncSignaturePtr = std::shared_ptr<FunctionSignature>;
+using FuncSignatureVector = std::vector<FuncSignaturePtr>;
+
+}  // namespace gandiva
+
+#endif  // GANDIVA_ALIASES_H
diff --git a/cpp/src/gandiva/codegen/like_holder.cc b/cpp/src/gandiva/codegen/like_holder.cc
new file mode 100644
index 00000000000..6992b51efb6
--- /dev/null
+++ b/cpp/src/gandiva/codegen/like_holder.cc
@@ -0,0 +1,60 @@
+// Copyright (C) 2017-2018 Dremio Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "codegen/like_holder.h"
+
+#include
+#include "codegen/node.h"
+#include "codegen/regex_util.h"
+
+namespace gandiva {
+
+#ifdef GDV_HELPERS
+namespace helpers {
+#endif
+
+/// Build a LikeHolder from a 'like' function node.
+///
+/// The node must have exactly two children, and the second one must be a
+/// string/binary literal holding the SQL pattern; otherwise Status::Invalid is
+/// returned and *holder is left untouched.
+Status LikeHolder::Make(const FunctionNode &node,
+                        std::shared_ptr<LikeHolder> *holder) {
+  if (node.children().size() != 2) {
+    return Status::Invalid("'like' function requires two parameters");
+  }
+
+  // The pattern has to be known when the holder is built, so it must be a literal.
+  auto literal = dynamic_cast<LiteralNode *>(node.children().at(1).get());
+  if (literal == nullptr) {
+    return Status::Invalid("'like' function requires a literal as the second parameter");
+  }
+
+  auto literal_type = literal->return_type()->id();
+  if (literal_type != arrow::Type::STRING && literal_type != arrow::Type::BINARY) {
+    return Status::Invalid(
+        "'like' function requires a string literal as the second parameter");
+  }
+  auto pattern = boost::get<std::string>(literal->holder());
+  return Make(pattern, holder);
+}
+
+/// Build a LikeHolder from a SQL 'like' pattern.
+///
+/// The pattern is first converted to a posix regex; a conversion failure is
+/// returned to the caller and *holder is left untouched.
+Status LikeHolder::Make(const std::string &sql_pattern,
+                        std::shared_ptr<LikeHolder> *holder) {
+  std::string posix_pattern;
+  auto status = RegexUtil::SqlLikePatternToPosix(sql_pattern, posix_pattern);
+  GANDIVA_RETURN_NOT_OK(status);
+
+  // The constructor is private, hence the explicit new.
+  *holder = std::shared_ptr<LikeHolder>(new LikeHolder(posix_pattern));
+  return Status::OK();
+}
+
+#ifdef GDV_HELPERS
+}  // namespace helpers
+#endif
+
+}  // namespace gandiva
diff --git a/cpp/src/gandiva/codegen/like_holder.h b/cpp/src/gandiva/codegen/like_holder.h
new file mode 100644
index 00000000000..51b14e42f0f
--- /dev/null
+++ b/cpp/src/gandiva/codegen/like_holder.h
@@ -0,0 +1,54 @@
+// Copyright (C) 2017-2018 Dremio Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GANDIVA_LIKE_HOLDER_H
+#define GANDIVA_LIKE_HOLDER_H
+
+#include
+#include "codegen/function_holder.h"
+#include "codegen/node.h"
+#include "gandiva/status.h"
+
+namespace gandiva {
+
+#ifdef GDV_HELPERS
+namespace helpers {
+#endif
+
+/// Function Holder for SQL 'like'
+///
+/// Instances are created through the Make() factories; the regex is compiled
+/// once in the constructor and reused for every operator() call.
+class LikeHolder : public FunctionHolder {
+ public:
+  ~LikeHolder() override = default;
+
+  /// Build a holder from a 'like' function node.
+  static Status Make(const FunctionNode &node, std::shared_ptr<LikeHolder> *holder);
+
+  /// Build a holder directly from a SQL 'like' pattern.
+  static Status Make(const std::string &sql_pattern,
+                     std::shared_ptr<LikeHolder> *holder);
+
+  /// Return true if the data matches the pattern.
+  bool operator()(const std::string &data) { return std::regex_match(data, regex_); }
+
+ private:
+  // explicit: a plain string must never convert to a holder implicitly.
+  explicit LikeHolder(const std::string &pattern)
+      : pattern_(pattern), regex_(pattern) {}
+
+  std::string pattern_;  // posix pattern string, to help debugging
+  std::regex regex_;     // compiled regex for the pattern
+};
+
+#ifdef GDV_HELPERS
+}  // namespace helpers
+#endif
+
+}  // namespace gandiva
+
+#endif  // GANDIVA_LIKE_HOLDER_H
diff --git a/cpp/src/gandiva/codegen/like_holder_test.cc b/cpp/src/gandiva/codegen/like_holder_test.cc
new file mode 100644
index 00000000000..d349e4f0726
--- /dev/null
+++ b/cpp/src/gandiva/codegen/like_holder_test.cc
@@ -0,0 +1,81 @@
+// Copyright (C) 2017-2018 Dremio Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "codegen/like_holder.h"
+#include "codegen/regex_util.h"
+
+#include
+#include
+
+#include
+
+namespace gandiva {
+
+class TestLikeHolder : public ::testing::Test {};
+
+// '%' matches any run of characters, including the empty one.
+TEST_F(TestLikeHolder, TestMatchAny) {
+  std::shared_ptr<LikeHolder> like_holder;
+
+  auto status = LikeHolder::Make("ab%", &like_holder);
+  // Fatal on failure: like_holder is dereferenced right below.
+  ASSERT_TRUE(status.ok()) << status.message();
+
+  // Bind by reference so the holder (and its compiled regex) is not copied.
+  auto &like = *like_holder;
+  EXPECT_TRUE(like("ab"));
+  EXPECT_TRUE(like("abc"));
+  EXPECT_TRUE(like("abcd"));
+
+  EXPECT_FALSE(like("a"));
+  EXPECT_FALSE(like("cab"));
+}
+
+// '_' matches exactly one character.
+TEST_F(TestLikeHolder, TestMatchOne) {
+  std::shared_ptr<LikeHolder> like_holder;
+
+  auto status = LikeHolder::Make("ab_", &like_holder);
+  ASSERT_TRUE(status.ok()) << status.message();
+
+  auto &like = *like_holder;
+  EXPECT_TRUE(like("abc"));
+  EXPECT_TRUE(like("abd"));
+
+  EXPECT_FALSE(like("a"));
+  EXPECT_FALSE(like("abcd"));
+  EXPECT_FALSE(like("dabc"));
+}
+
+// Characters that are special in posix regexes must be matched literally.
+TEST_F(TestLikeHolder, TestPosixSpecial) {
+  std::shared_ptr<LikeHolder> like_holder;
+
+  auto status = LikeHolder::Make(".*ab_", &like_holder);
+  ASSERT_TRUE(status.ok()) << status.message();
+
+  auto &like = *like_holder;
+  EXPECT_TRUE(like(".*abc"));  // . and * aren't special in sql regex
+  EXPECT_FALSE(like("xxabc"));
+}
+
+// With '#' as the escape character, "#%" and "#_" become literal '%' and '_',
+// and "##" becomes a literal '#'.
+TEST_F(TestLikeHolder, TestRegexEscape) {
+  std::string res;
+  auto status = RegexUtil::SqlLikePatternToPosix("#%hello#_abc_def##", '#', res);
+  EXPECT_TRUE(status.ok()) << status.message();
+
+  EXPECT_EQ(res, "%hello_abc.def#");
+}
+
+}  // namespace gandiva
+
+// Defined at global scope: a main() nested inside a namespace is an ordinary
+// function and is never used as the program entry point.
+int main(int argc, char **argv) {
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/cpp/src/gandiva/codegen/literal_holder.h b/cpp/src/gandiva/codegen/literal_holder.h
new file mode 100644
index 00000000000..c05121f5218
--- /dev/null
+++ b/cpp/src/gandiva/codegen/literal_holder.h
@@ -0,0 +1,30 @@
+// Copyright (C) 2017-2018 Dremio Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GANDIVA_LITERAL_HOLDER
+#define GANDIVA_LITERAL_HOLDER
+
+#include
+
+#include
+
+namespace gandiva {
+
+using LiteralHolder =
+    boost::variant;
+
+}  // namespace gandiva
+
+#endif  // GANDIVA_LITERAL_HOLDER
diff --git a/cpp/src/gandiva/codegen/llvm_generator.cc b/cpp/src/gandiva/codegen/llvm_generator.cc
new file mode 100644
index 00000000000..6e92a8420e8
--- /dev/null
+++ b/cpp/src/gandiva/codegen/llvm_generator.cc
@@ -0,0 +1,1020 @@
+// Copyright (C) 2017-2018 Dremio Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "codegen/llvm_generator.h"
+
+#include
+#include
+#include
+#include
+#include
+
+#include "codegen/bitmap_accumulator.h"
+#include "codegen/dex.h"
+#include "codegen/expr_decomposer.h"
+#include "codegen/function_registry.h"
+#include "codegen/lvalue.h"
+#include "gandiva/expression.h"
+
+namespace gandiva {
+
+// Forwards its arguments to AddTrace() only when IR tracing is enabled on the
+// generator.
+#define ADD_TRACE(...)     \
+  if (enable_ir_traces_) { \
+    AddTrace(__VA_ARGS__); \
+  }
+
+// Defaults: no IR dump, IR optimisation on, IR traces off.
+LLVMGenerator::LLVMGenerator()
+    : dump_ir_(false), optimise_ir_(true), enable_ir_traces_(false) {}
+
+/// Factory: creates a generator, builds its engine from the configuration and
+/// sets up the LLVM type helper from the engine's context. On failure of
+/// Engine::Make the error status is returned and *llvm_generator is not set.
+Status LLVMGenerator::Make(std::shared_ptr config,
+                           std::unique_ptr *llvm_generator) {
+  std::unique_ptr llvmgen_obj(new LLVMGenerator());
+  Status status = Engine::Make(config, &(llvmgen_obj->engine_));
+  GANDIVA_RETURN_NOT_OK(status);
+  llvmgen_obj->types_.reset(new LLVMTypes(*(llvmgen_obj->engine_)->context()));
+  *llvm_generator = std::move(llvmgen_obj);
+  return Status::OK();
+}
+
+/// Decompose one expression, generate its IR function and record the result
+/// in compiled_exprs_.
+Status LLVMGenerator::Add(const ExpressionPtr expr, const FieldDescriptorPtr output) {
+  // Position of this expression in compiled_exprs_; passed to
+  // CodeGenExprValue below.
+  int idx = compiled_exprs_.size();
+
+  // decompose the expression to separate out value and validities.
+  ExprDecomposer decomposer(function_registry_, annotator_);
+  ValueValidityPairPtr value_validity;
+  auto status = decomposer.Decompose(*expr->root(), &value_validity);
+  GANDIVA_RETURN_NOT_OK(status);
+
+  // Generate the IR function for the decomposed expression.
+  llvm::Function *ir_function = nullptr;
+  status = CodeGenExprValue(value_validity->value_expr(), output, idx, &ir_function);
+  GANDIVA_RETURN_NOT_OK(status);
+
+  std::unique_ptr compiled_expr(
+      new CompiledExpr(value_validity, output, ir_function));
+  compiled_exprs_.push_back(std::move(compiled_expr));
+  return Status::OK();
+}
+
+/// Build and optimise module for projection expression.
+Status LLVMGenerator::Build(const ExpressionVector &exprs) {
+  Status status;
+
+  // Generate IR for every expression; stop at the first failure.
+  for (auto &expr : exprs) {
+    auto output = annotator_.AddOutputFieldDescriptor(expr->result());
+    status = Add(expr, output);
+    GANDIVA_RETURN_NOT_OK(status);
+  }
+
+  // optimise, compile and finalize the module
+  status = engine_->FinalizeModule(optimise_ir_, dump_ir_);
+  GANDIVA_RETURN_NOT_OK(status);
+
+  // setup the jit functions for each expression.
+  for (auto &compiled_expr : compiled_exprs_) {
+    llvm::Function *ir_func = compiled_expr->ir_function();
+    EvalFunc fn = reinterpret_cast(engine_->CompiledFunction(ir_func));
+    compiled_expr->set_jit_function(fn);
+  }
+  return Status::OK();
+}
+
+/// Execute the compiled module against the provided vectors.
+Status LLVMGenerator::Execute(const arrow::RecordBatch &record_batch,
+                              const ArrayDataVector &output_vector) {
+  // NOTE(review): an empty record batch trips this DCHECK; confirm callers
+  // never pass one.
+  DCHECK_GT(record_batch.num_rows(), 0);
+
+  auto eval_batch = annotator_.PrepareEvalBatch(record_batch, output_vector);
+  DCHECK_GT(eval_batch->GetNumBuffers(), 0);
+
+  for (auto &compiled_expr : compiled_exprs_) {
+    // generate data/offset vectors.
+    EvalFunc jit_function = compiled_expr->jit_function();
+    jit_function(eval_batch->GetBufferArray(), eval_batch->GetLocalBitMapArray(),
+                 record_batch.num_rows());
+
+    // generate validity vectors.
+    ComputeBitMapsForExpr(*compiled_expr, *eval_batch);
+  }
+  return Status::OK();
+}
+
+/// Emit IR that loads the element at position idx of the arg_addrs array.
+/// `name` is only used to label the generated IR values.
+llvm::Value *LLVMGenerator::LoadVectorAtIndex(llvm::Value *arg_addrs, int idx,
+                                              const std::string &name) {
+  llvm::IRBuilder<> &builder = ir_builder();
+  llvm::Value *offset =
+      builder.CreateGEP(arg_addrs, types_->i32_constant(idx), name + "_mem_addr");
+  return builder.CreateLoad(offset, name + "_mem");
+}
+
+/// Get reference to validity array at specified index in the args list.
+llvm::Value *LLVMGenerator::GetValidityReference(llvm::Value *arg_addrs, int idx,
+                                                 FieldPtr field) {
+  const std::string &name = field->name();
+  llvm::Value *load = LoadVectorAtIndex(arg_addrs, idx, name);
+  // The loaded integer is reinterpreted as a pointer to i64.
+  return ir_builder().CreateIntToPtr(load, types_->i64_ptr_type(), name + "_varray");
+}
+
+/// Get reference to data array at specified index in the args list.
+llvm::Value *LLVMGenerator::GetDataReference(llvm::Value *arg_addrs, int idx,
+                                             FieldPtr field) {
+  const std::string &name = field->name();
+  llvm::Value *load = LoadVectorAtIndex(arg_addrs, idx, name);
+  llvm::Type *base_type = types_->DataVecType(field->type());
+  llvm::Value *ret;
+  // If the vector type is already a pointer type, cast to it directly;
+  // otherwise cast to a pointer to that type.
+  if (base_type->isPointerTy()) {
+    ret = ir_builder().CreateIntToPtr(load, base_type, name + "_darray");
+  } else {
+    llvm::Type *pointer_type = types_->ptr_type(base_type);
+    ret = ir_builder().CreateIntToPtr(load, pointer_type, name + "_darray");
+  }
+  return ret;
+}
+
+/// Get reference to offsets array at specified index in the args list.
+llvm::Value *LLVMGenerator::GetOffsetsReference(llvm::Value *arg_addrs, int idx,
+                                                FieldPtr field) {
+  const std::string &name = field->name();
+  llvm::Value *load = LoadVectorAtIndex(arg_addrs, idx, name);
+  // The loaded integer is reinterpreted as a pointer to i32.
+  return ir_builder().CreateIntToPtr(load, types_->i32_ptr_type(), name + "_oarray");
+}
+
+/// Get reference to local bitmap array at specified index in the args list.
+llvm::Value *LLVMGenerator::GetLocalBitMapReference(llvm::Value *arg_bitmaps, int idx) { + llvm::Value *load = LoadVectorAtIndex(arg_bitmaps, idx, ""); + return ir_builder().CreateIntToPtr(load, types_->i64_ptr_type(), + std::to_string(idx) + "_lbmap"); +} + +/// \brief Generate code for one expression. + +// Sample IR code for "c1:int + c2:int" +// +// The C-code equivalent is : +// ------------------------------ +// int expr_0(int64_t *addrs, int64_t *local_bitmaps, int nrecords) { +// int *outVec = (int *) addrs[5]; +// int *c0Vec = (int *) addrs[1]; +// int *c1Vec = (int *) addrs[3]; +// for (int loop_var = 0; loop_var < nrecords; ++loop_var) { +// int c0 = c0Vec[loop_var]; +// int c1 = c1Vec[loop_var]; +// int out = c0 + c1; +// outVec[loop_var] = out; +// } +// } +// +// IR Code +// -------- +// +// define i32 @expr_0(i64* %args, i64* %local_bitmaps, i32 %nrecords) { +// entry: +// %outmemAddr = getelementptr i64, i64* %args, i32 5 +// %outmem = load i64, i64* %outmemAddr +// %outVec = inttoptr i64 %outmem to i32* +// %c0memAddr = getelementptr i64, i64* %args, i32 1 +// %c0mem = load i64, i64* %c0memAddr +// %c0Vec = inttoptr i64 %c0mem to i32* +// %c1memAddr = getelementptr i64, i64* %args, i32 3 +// %c1mem = load i64, i64* %c1memAddr +// %c1Vec = inttoptr i64 %c1mem to i32* +// br label %loop +// loop: ; preds = %loop, %entry +// %loop_var = phi i32 [ 0, %entry ], [ %"loop_var+1", %loop ] +// %"loop_var+1" = add i32 %loop_var, 1 +// %0 = getelementptr i32, i32* %c0Vec, i32 %loop_var +// %c0 = load i32, i32* %0 +// %1 = getelementptr i32, i32* %c1Vec, i32 %loop_var +// %c1 = load i32, i32* %1 +// %add_int_int = call i32 @add_int_int(i32 %c0, i32 %c1) +// %2 = getelementptr i32, i32* %outVec, i32 %loop_var +// store i32 %add_int_int, i32* %2 +// %"loop_var < nrec" = icmp slt i32 %"loop_var+1", %nrecords +// br i1 %"loop_var < nrec", label %loop, label %exit +// exit: ; preds = %loop +// ret i32 0 +// } + +Status LLVMGenerator::CodeGenExprValue(DexPtr 
value_expr, FieldDescriptorPtr output, + int suffix_idx, llvm::Function **fn) { + llvm::IRBuilder<> &builder = ir_builder(); + + // Create fn prototype : + // int expr_1 (long **addrs, long **bitmaps, int nrec) + std::vector arguments; + arguments.push_back(types_->i64_ptr_type()); + arguments.push_back(types_->i64_ptr_type()); + arguments.push_back(types_->i32_type()); + llvm::FunctionType *prototype = + llvm::FunctionType::get(types_->i32_type(), arguments, false /*isVarArg*/); + + // Create fn + std::string func_name = "expr_" + std::to_string(suffix_idx); + engine_->AddFunctionToCompile(func_name); + *fn = llvm::Function::Create(prototype, llvm::GlobalValue::ExternalLinkage, func_name, + module()); + GANDIVA_RETURN_FAILURE_IF_FALSE((*fn != nullptr), + Status::CodeGenError("Error creating function.")); + // Name the arguments + llvm::Function::arg_iterator args = (*fn)->arg_begin(); + llvm::Value *arg_addrs = &*args; + arg_addrs->setName("args"); + ++args; + llvm::Value *arg_local_bitmaps = &*args; + arg_local_bitmaps->setName("local_bitmaps"); + ++args; + llvm::Value *arg_nrecords = &*args; + arg_nrecords->setName("nrecords"); + ++args; + + llvm::BasicBlock *loop_entry = llvm::BasicBlock::Create(context(), "entry", *fn); + llvm::BasicBlock *loop_body = llvm::BasicBlock::Create(context(), "loop", *fn); + llvm::BasicBlock *loop_exit = llvm::BasicBlock::Create(context(), "exit", *fn); + + // Add reference to output vector (in entry block) + builder.SetInsertPoint(loop_entry); + llvm::Value *output_ref = + GetDataReference(arg_addrs, output->data_idx(), output->field()); + + // Loop body + builder.SetInsertPoint(loop_body); + + // define loop_var : start with 0, +1 after each iter + llvm::PHINode *loop_var = builder.CreatePHI(types_->i32_type(), 2, "loop_var"); + + // The visitor can add code to both the entry/loop blocks. 
+ Visitor visitor(this, *fn, loop_entry, arg_addrs, arg_local_bitmaps, loop_var); + value_expr->Accept(visitor); + LValuePtr output_value = visitor.result(); + + // The "current" block may have changed due to code generation in the visitor. + llvm::BasicBlock *loop_body_tail = builder.GetInsertBlock(); + + // add jump to "loop block" at the end of the "setup block". + builder.SetInsertPoint(loop_entry); + builder.CreateBr(loop_body); + + // save the value in the output vector. + builder.SetInsertPoint(loop_body_tail); + if (output->Type()->id() == arrow::Type::BOOL) { + SetPackedBitValue(output_ref, loop_var, output_value->data()); + } else { + llvm::Value *slot_offset = builder.CreateGEP(output_ref, loop_var); + builder.CreateStore(output_value->data(), slot_offset); + } + ADD_TRACE("saving result " + output->Name() + " value %T", output_value->data()); + + // check loop_var + loop_var->addIncoming(types_->i32_constant(0), loop_entry); + llvm::Value *loop_update = + builder.CreateAdd(loop_var, types_->i32_constant(1), "loop_var+1"); + loop_var->addIncoming(loop_update, loop_body_tail); + + llvm::Value *loop_var_check = + builder.CreateICmpSLT(loop_update, arg_nrecords, "loop_var < nrec"); + builder.CreateCondBr(loop_var_check, loop_body, loop_exit); + + // Loop exit + builder.SetInsertPoint(loop_exit); + builder.CreateRet(types_->i32_constant(0)); + return Status::OK(); +} + +/// Return value of a bit in bitMap. +llvm::Value *LLVMGenerator::GetPackedBitValue(llvm::Value *bitmap, + llvm::Value *position) { + ADD_TRACE("fetch bit at position %T", position); + + llvm::Value *bitmap8 = ir_builder().CreateBitCast( + bitmap, types_->ptr_type(types_->i8_type()), "bitMapCast"); + return AddFunctionCall("bitMapGetBit", types_->i1_type(), {bitmap8, position}); +} + +/// Set the value of a bit in bitMap. 
+void LLVMGenerator::SetPackedBitValue(llvm::Value *bitmap, llvm::Value *position, + llvm::Value *value) { + ADD_TRACE("set bit at position %T", position); + ADD_TRACE(" to value %T ", value); + + llvm::Value *bitmap8 = ir_builder().CreateBitCast( + bitmap, types_->ptr_type(types_->i8_type()), "bitMapCast"); + AddFunctionCall("bitMapSetBit", types_->void_type(), {bitmap8, position, value}); +} + +/// Clear the bit in bitMap if value = false. +void LLVMGenerator::ClearPackedBitValueIfFalse(llvm::Value *bitmap, llvm::Value *position, + llvm::Value *value) { + ADD_TRACE("ClearIfFalse bit at position %T", position); + ADD_TRACE(" value %T ", value); + + llvm::Value *bitmap8 = ir_builder().CreateBitCast( + bitmap, types_->ptr_type(types_->i8_type()), "bitMapCast"); + AddFunctionCall("bitMapClearBitIfFalse", types_->void_type(), + {bitmap8, position, value}); +} + +/// Extract the bitmap addresses, and do an intersection. +void LLVMGenerator::ComputeBitMapsForExpr(const CompiledExpr &compiled_expr, + const EvalBatch &eval_batch) { + auto validities = compiled_expr.value_validity()->validity_exprs(); + + // Extract all the source bitmap addresses. + BitMapAccumulator accumulator(eval_batch); + for (auto &validity_dex : validities) { + validity_dex->Accept(accumulator); + } + + // Extract the destination bitmap address. + int out_idx = compiled_expr.output()->validity_idx(); + uint8_t *dst_bitmap = eval_batch.GetBuffer(out_idx); + + // Compute the destination bitmap. + accumulator.ComputeResult(dst_bitmap); +} + +void LLVMGenerator::CheckAndAddPrototype(const std::string &full_name, + llvm::Type *ret_type, + const std::vector &args) { + auto fn = module()->getFunction(full_name); + if (fn != nullptr) { + // prototype already added to module. 
+ return; + } + + // Create fn prototype for evaluation + std::vector arg_types; + for (auto &value : args) { + arg_types.push_back(value->getType()); + } + llvm::FunctionType *prototype = + llvm::FunctionType::get(ret_type, arg_types, false /*isVarArg*/); + + fn = llvm::Function::Create(prototype, llvm::GlobalValue::ExternalLinkage, full_name, + module()); + DCHECK_NE(fn, nullptr) << " cpp function " << full_name << " does not exist"; +} + +llvm::Value *LLVMGenerator::AddFunctionCall(const std::string &full_name, + llvm::Type *ret_type, + const std::vector &args, + bool has_holder) { + if (has_holder) { + CheckAndAddPrototype(full_name, ret_type, args); + } else { + // add to list of functions that need to be compiled + engine_->AddFunctionToCompile(full_name); + } + + // find the llvm function. + llvm::Function *fn = module()->getFunction(full_name); + DCHECK_NE(fn, nullptr) << "missing function " + full_name; + + if (enable_ir_traces_ && !full_name.compare("printf") && + !full_name.compare("printff")) { + // Trace for debugging + ADD_TRACE("invoke native fn " + full_name); + } + + // build a call to the llvm function. + llvm::Value *value; + if (ret_type->isVoidTy()) { + // void functions can't have a name for the call. + value = ir_builder().CreateCall(fn, args); + } else { + value = ir_builder().CreateCall(fn, args, full_name); + DCHECK(value->getType() == ret_type); + } + return value; +} + +#define ADD_VISITOR_TRACE(...) \ + if (generator_->enable_ir_traces_) { \ + generator_->AddTrace(__VA_ARGS__); \ + } + +// Visitor for generating the code for a decomposed expression. 
+LLVMGenerator::Visitor::Visitor(LLVMGenerator *generator, llvm::Function *function, + llvm::BasicBlock *entry_block, llvm::Value *arg_addrs, + llvm::Value *arg_local_bitmaps, llvm::Value *loop_var) + : generator_(generator), + function_(function), + entry_block_(entry_block), + arg_addrs_(arg_addrs), + arg_local_bitmaps_(arg_local_bitmaps), + loop_var_(loop_var) { + ADD_VISITOR_TRACE("Iteration %T", loop_var); +} + +void LLVMGenerator::Visitor::Visit(const VectorReadFixedLenValueDex &dex) { + llvm::IRBuilder<> &builder = ir_builder(); + + llvm::Value *slot_ref = GetBufferReference(dex.DataIdx(), kBufferTypeData, dex.Field()); + + llvm::Value *slot_value; + if (dex.FieldType()->id() == arrow::Type::BOOL) { + slot_value = generator_->GetPackedBitValue(slot_ref, loop_var_); + } else { + llvm::Value *slot_offset = builder.CreateGEP(slot_ref, loop_var_); + slot_value = builder.CreateLoad(slot_offset, dex.FieldName()); + } + + ADD_VISITOR_TRACE("visit fixed-len data vector " + dex.FieldName() + " value %T", + slot_value); + result_.reset(new LValue(slot_value)); +} + +void LLVMGenerator::Visitor::Visit(const VectorReadVarLenValueDex &dex) { + llvm::IRBuilder<> &builder = ir_builder(); + llvm::Value *slot; + + // compute len from the offsets array. 
+ llvm::Value *offsets_slot_ref = + GetBufferReference(dex.OffsetsIdx(), kBufferTypeOffsets, dex.Field()); + + // => offset_start = offsets[loop_var] + slot = builder.CreateGEP(offsets_slot_ref, loop_var_); + llvm::Value *offset_start = builder.CreateLoad(slot, "offset_start"); + + // => offset_end = offsets[loop_var + 1] + llvm::Value *loop_var_next = + builder.CreateAdd(loop_var_, generator_->types_->i32_constant(1), "loop_var+1"); + slot = builder.CreateGEP(offsets_slot_ref, loop_var_next); + llvm::Value *offset_end = builder.CreateLoad(slot, "offset_end"); + + // => len_value = offset_end - offset_start + llvm::Value *len_value = + builder.CreateSub(offset_end, offset_start, dex.FieldName() + "Len"); + + // get the data from the data array, at offset 'offset_start'. + llvm::Value *data_slot_ref = + GetBufferReference(dex.DataIdx(), kBufferTypeData, dex.Field()); + llvm::Value *data_value = builder.CreateGEP(data_slot_ref, offset_start); + ADD_VISITOR_TRACE("visit var-len data vector " + dex.FieldName() + " len %T", + len_value); + result_.reset(new LValue(data_value, len_value)); +} + +void LLVMGenerator::Visitor::Visit(const VectorReadValidityDex &dex) { + llvm::Value *slot_ref = + GetBufferReference(dex.ValidityIdx(), kBufferTypeValidity, dex.Field()); + llvm::Value *validity = generator_->GetPackedBitValue(slot_ref, loop_var_); + + ADD_VISITOR_TRACE("visit validity vector " + dex.FieldName() + " value %T", validity); + result_.reset(new LValue(validity)); +} + +void LLVMGenerator::Visitor::Visit(const LocalBitMapValidityDex &dex) { + llvm::Value *slot_ref = GetLocalBitMapReference(dex.local_bitmap_idx()); + llvm::Value *validity = generator_->GetPackedBitValue(slot_ref, loop_var_); + + ADD_VISITOR_TRACE( + "visit local bitmap " + std::to_string(dex.local_bitmap_idx()) + " value %T", + validity); + result_.reset(new LValue(validity)); +} + +void LLVMGenerator::Visitor::Visit(const TrueDex &dex) { + result_.reset(new 
LValue(generator_->types_->true_constant())); +} + +void LLVMGenerator::Visitor::Visit(const FalseDex &dex) { + result_.reset(new LValue(generator_->types_->false_constant())); +} + +void LLVMGenerator::Visitor::Visit(const LiteralDex &dex) { + LLVMTypes *types = generator_->types_.get(); + llvm::Value *value = nullptr; + llvm::Value *len = nullptr; + + switch (dex.type()->id()) { + case arrow::Type::BOOL: + value = types->i1_constant(boost::get(dex.holder())); + break; + + case arrow::Type::UINT8: + value = types->i8_constant(boost::get(dex.holder())); + break; + + case arrow::Type::UINT16: + value = types->i16_constant(boost::get(dex.holder())); + break; + + case arrow::Type::UINT32: + value = types->i32_constant(boost::get(dex.holder())); + break; + + case arrow::Type::UINT64: + value = types->i64_constant(boost::get(dex.holder())); + break; + + case arrow::Type::INT8: + value = types->i8_constant(boost::get(dex.holder())); + break; + + case arrow::Type::INT16: + value = types->i16_constant(boost::get(dex.holder())); + break; + + case arrow::Type::INT32: + value = types->i32_constant(boost::get(dex.holder())); + break; + + case arrow::Type::INT64: + value = types->i64_constant(boost::get(dex.holder())); + break; + + case arrow::Type::FLOAT: + value = types->float_constant(boost::get(dex.holder())); + break; + + case arrow::Type::DOUBLE: + value = types->double_constant(boost::get(dex.holder())); + break; + + case arrow::Type::STRING: + case arrow::Type::BINARY: { + const std::string &str = boost::get(dex.holder()); + + llvm::Constant *str_int_cast = types->i64_constant((int64_t)str.c_str()); + value = llvm::ConstantExpr::getIntToPtr(str_int_cast, types->i8_ptr_type()); + len = types->i32_constant(str.length()); + break; + } + + case arrow::Type::DATE64: + value = types->i64_constant(boost::get(dex.holder())); + break; + + case arrow::Type::TIME32: + value = types->i32_constant(boost::get(dex.holder())); + break; + + case arrow::Type::TIME64: + value = 
types->i64_constant(boost::get(dex.holder())); + break; + + case arrow::Type::TIMESTAMP: + value = types->i64_constant(boost::get(dex.holder())); + break; + + default: + DCHECK(0); + } + ADD_VISITOR_TRACE("visit Literal %T", value); + result_.reset(new LValue(value, len)); +} + +void LLVMGenerator::Visitor::Visit(const NonNullableFuncDex &dex) { + ADD_VISITOR_TRACE("visit NonNullableFunc base function " + + dex.func_descriptor()->name()); + LLVMTypes *types = generator_->types_.get(); + + // build the function params (ignore validity). + auto params = BuildParams(dex.function_holder().get(), dex.args(), false); + + const NativeFunction *native_function = dex.native_function(); + llvm::Type *ret_type = types->IRType(native_function->signature().ret_type()->id()); + + llvm::Value *value = generator_->AddFunctionCall( + native_function->pc_name(), ret_type, params, native_function->needs_holder()); + result_.reset(new LValue(value)); +} + +void LLVMGenerator::Visitor::Visit(const NullableNeverFuncDex &dex) { + ADD_VISITOR_TRACE("visit NullableNever base function " + dex.func_descriptor()->name()); + LLVMTypes *types = generator_->types_.get(); + + // build function params along with validity. + auto params = BuildParams(dex.function_holder().get(), dex.args(), true); + + const NativeFunction *native_function = dex.native_function(); + llvm::Type *ret_type = types->IRType(native_function->signature().ret_type()->id()); + llvm::Value *value = generator_->AddFunctionCall( + native_function->pc_name(), ret_type, params, native_function->needs_holder()); + result_.reset(new LValue(value)); +} + +void LLVMGenerator::Visitor::Visit(const NullableInternalFuncDex &dex) { + ADD_VISITOR_TRACE("visit NullableInternal base function " + + dex.func_descriptor()->name()); + llvm::IRBuilder<> &builder = ir_builder(); + LLVMTypes *types = generator_->types_.get(); + + // build function params along with validity. 
+ auto params = BuildParams(dex.function_holder().get(), dex.args(), true); + + // add an extra arg for validity (alloced on stack). + llvm::AllocaInst *result_valid_ptr = + new llvm::AllocaInst(types->i8_type(), 0, "result_valid", entry_block_); + params.push_back(result_valid_ptr); + + const NativeFunction *native_function = dex.native_function(); + llvm::Type *ret_type = types->IRType(native_function->signature().ret_type()->id()); + llvm::Value *value = generator_->AddFunctionCall( + native_function->pc_name(), ret_type, params, native_function->needs_holder()); + + // load the result validity and truncate to i1. + llvm::Value *result_valid_i8 = builder.CreateLoad(result_valid_ptr); + llvm::Value *result_valid = builder.CreateTrunc(result_valid_i8, types->i1_type()); + + // set validity bit in the local bitmap. + ClearLocalBitMapIfNotValid(dex.local_bitmap_idx(), result_valid); + + result_.reset(new LValue(value)); +} + +void LLVMGenerator::Visitor::Visit(const IfDex &dex) { + ADD_VISITOR_TRACE("visit IfExpression"); + llvm::IRBuilder<> &builder = ir_builder(); + LLVMTypes *types = generator_->types_.get(); + + // Evaluate condition. + LValuePtr if_condition = BuildValueAndValidity(dex.condition_vv()); + + // Check if the result is valid, and there is match. + llvm::Value *validAndMatched = + builder.CreateAnd(if_condition->data(), if_condition->validity(), "validAndMatch"); + + // Create blocks for the then, else and merge cases. + llvm::LLVMContext &context = generator_->context(); + llvm::BasicBlock *then_bb = llvm::BasicBlock::Create(context, "then", function_); + llvm::BasicBlock *else_bb = llvm::BasicBlock::Create(context, "else", function_); + llvm::BasicBlock *merge_bb = llvm::BasicBlock::Create(context, "merge", function_); + + builder.CreateCondBr(validAndMatched, then_bb, else_bb); + + // Emit the then block. 
+ builder.SetInsertPoint(then_bb); + ADD_VISITOR_TRACE("branch to then block"); + LValuePtr then_lvalue = BuildValueAndValidity(dex.then_vv()); + ClearLocalBitMapIfNotValid(dex.local_bitmap_idx(), then_lvalue->validity()); + ADD_VISITOR_TRACE("IfExpression result validity %T in matching then", + then_lvalue->validity()); + builder.CreateBr(merge_bb); + + // refresh then_bb for phi (could have changed due to code generation of then_vv). + then_bb = builder.GetInsertBlock(); + + // Emit the else block. + builder.SetInsertPoint(else_bb); + LValuePtr else_lvalue; + if (dex.is_terminal_else()) { + ADD_VISITOR_TRACE("branch to terminal else block"); + + else_lvalue = BuildValueAndValidity(dex.else_vv()); + // update the local bitmap with the validity. + ClearLocalBitMapIfNotValid(dex.local_bitmap_idx(), else_lvalue->validity()); + ADD_VISITOR_TRACE("IfExpression result validity %T in terminal else", + else_lvalue->validity()); + } else { + ADD_VISITOR_TRACE("branch to non-terminal else block"); + + // this is a non-terminal else. let the child (nested if/else) handle validity. + auto value_expr = dex.else_vv().value_expr(); + value_expr->Accept(*this); + else_lvalue = result(); + } + builder.CreateBr(merge_bb); + + // refresh else_bb for phi (could have changed due to code generation of else_vv). + else_bb = builder.GetInsertBlock(); + + // Emit the merge block. 
+ builder.SetInsertPoint(merge_bb); + llvm::Type *result_llvm_type = types->DataVecType(dex.result_type()); + llvm::PHINode *result_value = builder.CreatePHI(result_llvm_type, 2, "res_value"); + result_value->addIncoming(then_lvalue->data(), then_bb); + result_value->addIncoming(else_lvalue->data(), else_bb); + + llvm::PHINode *result_length = nullptr; + if (then_lvalue->length() != nullptr) { + result_length = builder.CreatePHI(types->i32_type(), 2, "res_length"); + result_length->addIncoming(then_lvalue->length(), then_bb); + result_length->addIncoming(else_lvalue->length(), else_bb); + + ADD_VISITOR_TRACE("IfExpression result length %T", result_length); + } + ADD_VISITOR_TRACE("IfExpression result value %T", result_value); + + result_.reset(new LValue(result_value, result_length)); +} + +// Boolean AND +// if any arg is valid and false, +// short-circuit and return FALSE (value=false, valid=true) +// else if all args are valid and true +// return TRUE (value=true, valid=true) +// else +// return NULL (value=true, valid=false) + +void LLVMGenerator::Visitor::Visit(const BooleanAndDex &dex) { + ADD_VISITOR_TRACE("visit BooleanAndExpression"); + llvm::IRBuilder<> &builder = ir_builder(); + LLVMTypes *types = generator_->types_.get(); + llvm::LLVMContext &context = generator_->context(); + + // Create blocks for short-circuit. 
+ llvm::BasicBlock *short_circuit_bb = + llvm::BasicBlock::Create(context, "short_circuit", function_); + llvm::BasicBlock *non_short_circuit_bb = + llvm::BasicBlock::Create(context, "non_short_circuit", function_); + llvm::BasicBlock *merge_bb = llvm::BasicBlock::Create(context, "merge", function_); + + llvm::Value *all_exprs_valid = types->true_constant(); + for (auto &pair : dex.args()) { + LValuePtr current = BuildValueAndValidity(*pair); + + ADD_VISITOR_TRACE("BooleanAndExpression arg value %T", current->data()); + ADD_VISITOR_TRACE("BooleanAndExpression arg valdity %T", current->validity()); + + // short-circuit if valid and false + llvm::Value *is_false = builder.CreateNot(current->data()); + llvm::Value *valid_and_false = + builder.CreateAnd(is_false, current->validity(), "valid_and_false"); + + llvm::BasicBlock *else_bb = llvm::BasicBlock::Create(context, "else", function_); + builder.CreateCondBr(valid_and_false, short_circuit_bb, else_bb); + + // Emit the else block. + builder.SetInsertPoint(else_bb); + // remember if any nulls were encountered. + all_exprs_valid = + builder.CreateAnd(all_exprs_valid, current->validity(), "validityBitAnd"); + // continue to evaluate the next pair in list. + } + builder.CreateBr(non_short_circuit_bb); + + // Short-circuit case (atleast one of the expressions is valid and false). + // No need to set validity bit (valid by default). + builder.SetInsertPoint(short_circuit_bb); + ADD_VISITOR_TRACE("BooleanAndExpression result value false"); + ADD_VISITOR_TRACE("BooleanAndExpression result valdity true"); + builder.CreateBr(merge_bb); + + // non short-circuit case (All expressions are either true or null). + // result valid if all of the exprs are non-null. 
+ builder.SetInsertPoint(non_short_circuit_bb); + ClearLocalBitMapIfNotValid(dex.local_bitmap_idx(), all_exprs_valid); + ADD_VISITOR_TRACE("BooleanAndExpression result value true"); + ADD_VISITOR_TRACE("BooleanAndExpression result valdity %T", all_exprs_valid); + builder.CreateBr(merge_bb); + + builder.SetInsertPoint(merge_bb); + llvm::PHINode *result_value = builder.CreatePHI(types->i1_type(), 2, "res_value"); + result_value->addIncoming(types->false_constant(), short_circuit_bb); + result_value->addIncoming(types->true_constant(), non_short_circuit_bb); + result_.reset(new LValue(result_value)); +} + +// Boolean OR +// if any arg is valid and true, +// short-circuit and return TRUE (value=true, valid=true) +// else if all args are valid and false +// return FALSE (value=false, valid=true) +// else +// return NULL (value=false, valid=false) + +void LLVMGenerator::Visitor::Visit(const BooleanOrDex &dex) { + ADD_VISITOR_TRACE("visit BooleanOrExpression"); + llvm::IRBuilder<> &builder = ir_builder(); + LLVMTypes *types = generator_->types_.get(); + llvm::LLVMContext &context = generator_->context(); + + // Create blocks for short-circuit. + llvm::BasicBlock *short_circuit_bb = + llvm::BasicBlock::Create(context, "short_circuit", function_); + llvm::BasicBlock *non_short_circuit_bb = + llvm::BasicBlock::Create(context, "non_short_circuit", function_); + llvm::BasicBlock *merge_bb = llvm::BasicBlock::Create(context, "merge", function_); + + llvm::Value *all_exprs_valid = types->true_constant(); + for (auto &pair : dex.args()) { + LValuePtr current = BuildValueAndValidity(*pair); + + ADD_VISITOR_TRACE("BooleanOrExpression arg value %T", current->data()); + ADD_VISITOR_TRACE("BooleanOrExpression arg valdity %T", current->validity()); + + // short-circuit if valid and true. 
+ llvm::Value *valid_and_true = + builder.CreateAnd(current->data(), current->validity(), "valid_and_true"); + + llvm::BasicBlock *else_bb = llvm::BasicBlock::Create(context, "else", function_); + builder.CreateCondBr(valid_and_true, short_circuit_bb, else_bb); + + // Emit the else block. + builder.SetInsertPoint(else_bb); + // remember if any nulls were encountered. + all_exprs_valid = + builder.CreateAnd(all_exprs_valid, current->validity(), "validityBitAnd"); + // continue to evaluate the next pair in list. + } + builder.CreateBr(non_short_circuit_bb); + + // Short-circuit case (atleast one of the expressions is valid and true). + // No need to set validity bit (valid by default). + builder.SetInsertPoint(short_circuit_bb); + ADD_VISITOR_TRACE("BooleanOrExpression result value true"); + ADD_VISITOR_TRACE("BooleanOrExpression result valdity true"); + builder.CreateBr(merge_bb); + + // non short-circuit case (All expressions are either false or null). + // result valid if all of the exprs are non-null. 
+ builder.SetInsertPoint(non_short_circuit_bb); + ClearLocalBitMapIfNotValid(dex.local_bitmap_idx(), all_exprs_valid); + ADD_VISITOR_TRACE("BooleanOrExpression result value false"); + ADD_VISITOR_TRACE("BooleanOrExpression result valdity %T", all_exprs_valid); + builder.CreateBr(merge_bb); + + builder.SetInsertPoint(merge_bb); + llvm::PHINode *result_value = builder.CreatePHI(types->i1_type(), 2, "res_value"); + result_value->addIncoming(types->true_constant(), short_circuit_bb); + result_value->addIncoming(types->false_constant(), non_short_circuit_bb); + result_.reset(new LValue(result_value)); +} + +LValuePtr LLVMGenerator::Visitor::BuildValueAndValidity(const ValueValidityPair &pair) { + // generate code for value + auto value_expr = pair.value_expr(); + value_expr->Accept(*this); + auto value = result()->data(); + auto length = result()->length(); + + // generate code for validity + auto validity = BuildCombinedValidity(pair.validity_exprs()); + + return std::make_shared(value, length, validity); +} + +std::vector LLVMGenerator::Visitor::BuildParams( + FunctionHolder *holder, const ValueValidityPairVector &args, bool with_validity) { + LLVMTypes *types = generator_->types_.get(); + std::vector params; + + // if the function has holder, add the holder pointer first. + if (holder != nullptr) { + llvm::Constant *ptr_int_cast = types->i64_constant((int64_t)holder); + auto ptr = llvm::ConstantExpr::getIntToPtr(ptr_int_cast, types->i8_ptr_type()); + params.push_back(ptr); + } + + // build the function params, along with the validities. + for (auto &pair : args) { + // build value. + DexPtr value_expr = pair->value_expr(); + value_expr->Accept(*this); + LValue &result_ref = *result(); + params.push_back(result_ref.data()); + + // build length (for var len data types) + if (result_ref.length() != nullptr) { + params.push_back(result_ref.length()); + } + + // build validity. 
+ if (with_validity) { + llvm::Value *validity_expr = BuildCombinedValidity(pair->validity_exprs()); + params.push_back(validity_expr); + } + } + return params; +} + +// Bitwise-AND of a vector of bits to get the combined validity. +llvm::Value *LLVMGenerator::Visitor::BuildCombinedValidity(const DexVector &validities) { + llvm::IRBuilder<> &builder = ir_builder(); + LLVMTypes *types = generator_->types_.get(); + + llvm::Value *isValid = types->true_constant(); + for (auto &dex : validities) { + dex->Accept(*this); + isValid = builder.CreateAnd(isValid, result()->data(), "validityBitAnd"); + } + ADD_VISITOR_TRACE("combined validity is %T", isValid); + return isValid; +} + +llvm::Value *LLVMGenerator::Visitor::GetBufferReference(int idx, BufferType buffer_type, + FieldPtr field) { + llvm::IRBuilder<> &builder = ir_builder(); + + // Switch to the entry block to create a reference. + llvm::BasicBlock *saved_block = builder.GetInsertBlock(); + builder.SetInsertPoint(entry_block_); + + llvm::Value *slot_ref = nullptr; + switch (buffer_type) { + case kBufferTypeValidity: + slot_ref = generator_->GetValidityReference(arg_addrs_, idx, field); + break; + + case kBufferTypeData: + slot_ref = generator_->GetDataReference(arg_addrs_, idx, field); + break; + + case kBufferTypeOffsets: + slot_ref = generator_->GetOffsetsReference(arg_addrs_, idx, field); + break; + } + + // Revert to the saved block. + builder.SetInsertPoint(saved_block); + return slot_ref; +} + +llvm::Value *LLVMGenerator::Visitor::GetLocalBitMapReference(int idx) { + llvm::IRBuilder<> &builder = ir_builder(); + + // Switch to the entry block to create a reference. + llvm::BasicBlock *saved_block = builder.GetInsertBlock(); + builder.SetInsertPoint(entry_block_); + + llvm::Value *slot_ref = generator_->GetLocalBitMapReference(arg_local_bitmaps_, idx); + + // Revert to the saved block. + builder.SetInsertPoint(saved_block); + return slot_ref; +} + +/// The local bitmap is pre-filled with 1s. 
Clear only if invalid. +void LLVMGenerator::Visitor::ClearLocalBitMapIfNotValid(int local_bitmap_idx, + llvm::Value *is_valid) { + llvm::Value *slot_ref = GetLocalBitMapReference(local_bitmap_idx); + generator_->ClearPackedBitValueIfFalse(slot_ref, loop_var_, is_valid); +} + +// Hooks for tracing/printfs. +// +// replace %T with the type-specific format specifier. +// For some reason, float/double literals are getting lost when printing with the generic +// printf. so, use a wrapper instead. +std::string LLVMGenerator::ReplaceFormatInTrace(const std::string &in_msg, + llvm::Value *value, + std::string *print_fn) { + std::string msg = in_msg; + std::size_t pos = msg.find("%T"); + if (pos == std::string::npos) { + DCHECK(0); + return msg; + } + + llvm::Type *type = value->getType(); + const char *fmt = ""; + if (type->isIntegerTy(1) || type->isIntegerTy(8) || type->isIntegerTy(16) || + type->isIntegerTy(32)) { + fmt = "%d"; + } else if (type->isIntegerTy(64)) { + // bigint + fmt = "%lld"; + } else if (type->isFloatTy()) { + // float + fmt = "%f"; + *print_fn = "print_float"; + } else if (type->isDoubleTy()) { + // float + fmt = "%lf"; + *print_fn = "print_double"; + } else { + DCHECK(0); + } + msg.replace(pos, 2, fmt); + return msg; +} + +void LLVMGenerator::AddTrace(const std::string &msg, llvm::Value *value) { + if (!enable_ir_traces_) { + return; + } + + std::string dmsg = "IR_TRACE:: " + msg + "\n"; + std::string print_fn_name = "printf"; + if (value != nullptr) { + dmsg = ReplaceFormatInTrace(dmsg, value, &print_fn_name); + } + trace_strings_.push_back(dmsg); + + // cast this to an llvm pointer. 
+ const char *str = trace_strings_.back().c_str(); + llvm::Constant *str_int_cast = types_->i64_constant((int64_t)str); + llvm::Constant *str_ptr_cast = + llvm::ConstantExpr::getIntToPtr(str_int_cast, types_->i8_ptr_type()); + + std::vector args; + args.push_back(str_ptr_cast); + if (value != nullptr) { + args.push_back(value); + } + AddFunctionCall(print_fn_name, types_->i32_type(), args); +} + +} // namespace gandiva diff --git a/cpp/src/gandiva/codegen/llvm_generator.h b/cpp/src/gandiva/codegen/llvm_generator.h new file mode 100644 index 00000000000..b5d12920278 --- /dev/null +++ b/cpp/src/gandiva/codegen/llvm_generator.h @@ -0,0 +1,201 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GANDIVA_LLVMGENERATOR_H +#define GANDIVA_LLVMGENERATOR_H + +#include +#include +#include +#include + +#include +#include "codegen/annotator.h" +#include "codegen/compiled_expr.h" +#include "codegen/dex_visitor.h" +#include "codegen/engine.h" +#include "codegen/function_registry.h" +#include "codegen/llvm_types.h" +#include "codegen/lvalue.h" +#include "codegen/value_validity_pair.h" +#include "gandiva/configuration.h" +#include "gandiva/gandiva_aliases.h" + +namespace gandiva { + +class FunctionHolder; + +/// Builds an LLVM module and generates code for the specified set of expressions. +class LLVMGenerator { + public: + /// \brief Factory method to initialize the generator. 
+ static Status Make(std::shared_ptr config, + std::unique_ptr *llvm_generator); + + /// \brief Build the code for the expression trees. Each element in the vector + /// represents an expression tree. + Status Build(const ExpressionVector &exprs); + + /// \brief Execute the built expression against the provided arguments. + Status Execute(const arrow::RecordBatch &record_batch, + const ArrayDataVector &output_vector); + + /// Accessors for the llvm types helper, and for the module obtained from the engine. + LLVMTypes &types() { return *types_; } + llvm::Module *module() { return engine_->module(); } + + private: + // Private constructor; instances are created through the Make() factory. + LLVMGenerator(); + + // These tests access private members of the generator. + FRIEND_TEST(TestLLVMGenerator, TestAdd); + FRIEND_TEST(TestLLVMGenerator, TestNullInternal); + + // Shortcuts to the llvm context and IR builder obtained from the engine. + llvm::LLVMContext &context() { return *(engine_->context()); } + llvm::IRBuilder<> &ir_builder() { return engine_->ir_builder(); } + + /// Visitor to generate the code for a decomposed expression. + class Visitor : public DexVisitor { + public: + Visitor(LLVMGenerator *generator, llvm::Function *function, + llvm::BasicBlock *entry_block, llvm::Value *arg_addrs, + llvm::Value *arg_local_bitmaps, llvm::Value *loop_var); + + void Visit(const VectorReadValidityDex &dex) override; + void Visit(const VectorReadFixedLenValueDex &dex) override; + void Visit(const VectorReadVarLenValueDex &dex) override; + void Visit(const LocalBitMapValidityDex &dex) override; + void Visit(const TrueDex &dex) override; + void Visit(const FalseDex &dex) override; + void Visit(const LiteralDex &dex) override; + void Visit(const NonNullableFuncDex &dex) override; + void Visit(const NullableNeverFuncDex &dex) override; + void Visit(const NullableInternalFuncDex &dex) override; + void Visit(const IfDex &dex) override; + void Visit(const BooleanAndDex &dex) override; + void Visit(const BooleanOrDex &dex) override; + + /// Returns result_ (presumably set by the most recent Visit call - confirm in the .cc). + LValuePtr result() { return result_; } + + private: + /// The kinds of buffers that GetBufferReference() can address. + enum BufferType { kBufferTypeValidity = 0, kBufferTypeData, kBufferTypeOffsets }; + + // Forwarders to the generator. + llvm::IRBuilder<> &ir_builder() { return generator_->ir_builder(); } + llvm::Module *module() { return
generator_->module(); } + + // Generate the code to build the combined validity (bitwise and) from the + // vector of validities. + llvm::Value *BuildCombinedValidity(const DexVector &validities); + + // Generate the code to build the validity and the value for the given pair. + LValuePtr BuildValueAndValidity(const ValueValidityPair &pair); + + // Generate code to build the params. + std::vector BuildParams(FunctionHolder *holder, + const ValueValidityPairVector &args, + bool with_validity); + + // Switch to the entry_block and get a reference to the validity/data/offsets buffer, + // as selected by 'buffer_type'. + llvm::Value *GetBufferReference(int idx, BufferType buffer_type, FieldPtr field); + + // Switch to the entry_block and get reference to the local bitmap. + llvm::Value *GetLocalBitMapReference(int idx); + + // Clear the bit (at the position given by the loop variable) in the local bitmap, if + // is_valid is 'false'. + void ClearLocalBitMapIfNotValid(int local_bitmap_idx, llvm::Value *is_valid); + + // State of the visitor (presumably set from the constructor arguments of the same + // names - confirm in the .cc). + LLVMGenerator *generator_; + LValuePtr result_; + llvm::Function *function_; + llvm::BasicBlock *entry_block_; + llvm::Value *arg_addrs_; + llvm::Value *arg_local_bitmaps_; + llvm::Value *loop_var_; + }; + + // Generate the code for one expression, with the output of the expression going to + // 'output'. + Status Add(const ExpressionPtr expr, const FieldDescriptorPtr output); + + /// Generate code to load the vector at specified index in the 'arg_addrs' array. + llvm::Value *LoadVectorAtIndex(llvm::Value *arg_addrs, int idx, + const std::string &name); + + /// Generate code to load the vector at specified index and cast it as bitmap. + llvm::Value *GetValidityReference(llvm::Value *arg_addrs, int idx, FieldPtr field); + + /// Generate code to load the vector at specified index and cast it as data array. + llvm::Value *GetDataReference(llvm::Value *arg_addrs, int idx, FieldPtr field); + + /// Generate code to load the vector at specified index and cast it as offsets array.
+ llvm::Value *GetOffsetsReference(llvm::Value *arg_addrs, int idx, FieldPtr field); + + /// Generate code for the value array of one expression. + Status CodeGenExprValue(DexPtr value_expr, FieldDescriptorPtr output, int suffix_idx, + llvm::Function **fn); + + /// Generate code to load the local bitmap at the specified index and cast it as bitmap. + llvm::Value *GetLocalBitMapReference(llvm::Value *arg_bitmaps, int idx); + + /// Generate code to get the bit value at 'position' in the bitmap. + llvm::Value *GetPackedBitValue(llvm::Value *bitmap, llvm::Value *position); + + /// Generate code to set the bit value at 'position' in the bitmap to 'value'. + void SetPackedBitValue(llvm::Value *bitmap, llvm::Value *position, llvm::Value *value); + + /// Generate code to clear the bit value at 'position' in the bitmap if 'value' + /// is false. + void ClearPackedBitValueIfFalse(llvm::Value *bitmap, llvm::Value *position, + llvm::Value *value); + + /// For non-IR functions, add prototype to the module on first encounter. + void CheckAndAddPrototype(const std::string &full_name, llvm::Type *ret_type, + const std::vector &args); + + /// Generate code to make a function call (to a pre-compiled IR function) which takes + /// 'args' and has a return type 'ret_type'. + llvm::Value *AddFunctionCall(const std::string &full_name, llvm::Type *ret_type, + const std::vector &args, + bool has_holder = false); + + /// Compute the result bitmap for the expression. + /// + /// \param[in] compiled_expr the compiled expression (includes the bitmap indices to be + /// used for computing the validity bitmap of the result). + /// \param[in] eval_batch the batch being evaluated (includes input/output buffer + /// addresses). + void ComputeBitMapsForExpr(const CompiledExpr &compiled_expr, + const EvalBatch &eval_batch); + + /// Replace the %T in the trace msg with the format specifier matching the type of + /// 'value', eg. %d for int32, %lld for int64, ..
+ std::string ReplaceFormatInTrace(const std::string &msg, llvm::Value *value, + std::string *print_fn); + + /// Generate the code to print a trace msg with one optional argument (%T). + /// No code is generated unless enable_ir_traces_ is set. + void AddTrace(const std::string &msg, llvm::Value *value = nullptr); + + std::unique_ptr engine_; + std::vector> compiled_exprs_; + std::unique_ptr types_; + FunctionRegistry function_registry_; + Annotator annotator_; + + // used for debug + bool dump_ir_; + bool optimise_ir_; + bool enable_ir_traces_; + // Trace messages added by AddTrace(). The generated code refers to these strings by + // address, so they must outlive the compiled code. + std::vector trace_strings_; +}; + +} // namespace gandiva + +#endif // GANDIVA_LLVMGENERATOR_H diff --git a/cpp/src/gandiva/codegen/llvm_generator_test.cc b/cpp/src/gandiva/codegen/llvm_generator_test.cc new file mode 100644 index 00000000000..701a148236f --- /dev/null +++ b/cpp/src/gandiva/codegen/llvm_generator_test.cc @@ -0,0 +1,194 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "codegen/llvm_generator.h" + +#include +#include + +#include +#include "codegen/dex.h" +#include "codegen/func_descriptor.h" +#include "codegen/function_registry.h" +#include "gandiva/configuration.h" +#include "gandiva/expression.h" + +namespace gandiva { + +typedef int64_t (*add_vector_func_t)(int64_t *elements, int nelements); + +// Fixture providing a function registry to iterate over. +class TestLLVMGenerator : public ::testing::Test { + protected: + FunctionRegistry registry_; +}; + +// Verify that a valid pc function exists for every function in the registry.
+TEST_F(TestLLVMGenerator, VerifyPCFunctions) { + std::unique_ptr generator; + Status status = + LLVMGenerator::Make(ConfigurationBuilder::DefaultConfiguration(), &generator); + // fatal on failure: 'generator' is dereferenced right below. + ASSERT_TRUE(status.ok()) << status.message(); + + llvm::Module *module = generator->module(); + for (auto &iter : registry_) { + if (iter.needs_holder()) { + // TODO : need a way to verify these too. + continue; + } + + llvm::Function *fn = module->getFunction(iter.pc_name()); + EXPECT_NE(fn, nullptr) << "function " << iter.pc_name() + << " missing in precompiled module\n"; + } +} + +// Generate code for add(int32, int32) over two input vectors and check the sums. +TEST_F(TestLLVMGenerator, TestAdd) { + // Setup LLVM generator to do an arithmetic add of two vectors + std::unique_ptr generator; + Status status = + LLVMGenerator::Make(ConfigurationBuilder::DefaultConfiguration(), &generator); + // fatal on failure: 'generator' is dereferenced below. + ASSERT_TRUE(status.ok()) << status.message(); + Annotator annotator; + + auto field0 = std::make_shared("f0", arrow::int32()); + auto desc0 = annotator.CheckAndAddInputFieldDescriptor(field0); + auto validity_dex0 = std::make_shared(desc0); + auto value_dex0 = std::make_shared(desc0); + auto pair0 = std::make_shared(validity_dex0, value_dex0); + + auto field1 = std::make_shared("f1", arrow::int32()); + auto desc1 = annotator.CheckAndAddInputFieldDescriptor(field1); + auto validity_dex1 = std::make_shared(desc1); + auto value_dex1 = std::make_shared(desc1); + auto pair1 = std::make_shared(validity_dex1, value_dex1); + + DataTypeVector params{arrow::int32(), arrow::int32()}; + auto func_desc = std::make_shared("add", params, arrow::int32()); + FunctionSignature signature(func_desc->name(), func_desc->params(), + func_desc->return_type()); + const NativeFunction *native_func = + generator->function_registry_.LookupSignature(signature); + + std::vector pairs{pair0, pair1}; + auto func_dex = std::make_shared(func_desc, native_func, + FunctionHolderPtr(nullptr), pairs); + + auto field_sum = std::make_shared("out", arrow::int32()); + auto desc_sum = annotator.CheckAndAddInputFieldDescriptor(field_sum); + +
llvm::Function *ir_func = nullptr; + + status = generator->CodeGenExprValue(func_dex, desc_sum, 0, &ir_func); + ASSERT_TRUE(status.ok()) << status.message(); + + generator->engine_->FinalizeModule(true /*optimise_ir*/, false /*dump_ir*/); + EvalFunc eval_func = + reinterpret_cast<EvalFunc>(generator->engine_->CompiledFunction(ir_func)); + + int num_records = 4; + uint32_t a0[] = {1, 2, 3, 4}; + uint32_t a1[] = {5, 6, 7, 8}; + uint64_t in_bitmap = 0xffffffffffffffffull; + + uint32_t out[] = {0, 0, 0, 0}; + uint64_t out_bitmap = 0; + + uint8_t *addrs[] = { + reinterpret_cast(a0), reinterpret_cast(&in_bitmap), + reinterpret_cast(a1), reinterpret_cast(&in_bitmap), + reinterpret_cast(out), reinterpret_cast(&out_bitmap), + }; + eval_func(addrs, nullptr, num_records); + + uint32_t expected[] = {6, 8, 10, 12}; + for (int i = 0; i < num_records; i++) { + EXPECT_EQ(expected[i], out[i]); + } +} + +// Generate code for a function that computes its own validity, and check both the values +// and the bits left in the local bitmap. +TEST_F(TestLLVMGenerator, TestNullInternal) { + // Setup LLVM generator to evaluate a NULL_INTERNAL type function. + std::unique_ptr generator; + Status status = + LLVMGenerator::Make(ConfigurationBuilder::DefaultConfiguration(), &generator); + // fatal on failure: 'generator' is dereferenced below. + ASSERT_TRUE(status.ok()) << status.message(); + Annotator annotator; + + // To debug the generated code, set generator->enable_ir_traces_ = true here. + auto field0 = std::make_shared("f0", arrow::int32()); + auto desc0 = annotator.CheckAndAddInputFieldDescriptor(field0); + auto validity_dex0 = std::make_shared(desc0); + auto value_dex0 = std::make_shared(desc0); + auto pair0 = std::make_shared(validity_dex0, value_dex0); + + DataTypeVector params{arrow::int32()}; + auto func_desc = + std::make_shared("half_or_null", params, arrow::int32()); + FunctionSignature signature(func_desc->name(), func_desc->params(), + func_desc->return_type()); + const NativeFunction *native_func = + generator->function_registry_.LookupSignature(signature); + + int local_bitmap_idx = annotator.AddLocalBitMap(); + std::vector pairs{pair0}; + auto func_dex = std::make_shared( + func_desc, native_func, FunctionHolderPtr(nullptr), pairs, local_bitmap_idx); + + auto field_result =
std::make_shared("out", arrow::int32()); + auto desc_result = annotator.CheckAndAddInputFieldDescriptor(field_result); + + llvm::Function *ir_func = nullptr; + status = generator->CodeGenExprValue(func_dex, desc_result, 0, &ir_func); + ASSERT_TRUE(status.ok()) << status.message(); + + generator->engine_->FinalizeModule(true /*optimise_ir*/, false /*dump_ir*/); + + EvalFunc eval_func = + reinterpret_cast<EvalFunc>(generator->engine_->CompiledFunction(ir_func)); + + int num_records = 4; + uint32_t a0[] = {1, 2, 3, 4}; + uint64_t in_bitmap = 0xffffffffffffffffull; + + uint32_t out[] = {0, 0, 0, 0}; + uint64_t out_bitmap = 0; + + // all bits start as valid; the generated code clears the bits of the null results. + uint64_t local_bitmap = UINT64_MAX; + + uint8_t *addrs[] = { + reinterpret_cast(a0), + reinterpret_cast(&in_bitmap), + reinterpret_cast(out), + reinterpret_cast(&out_bitmap), + }; + + uint8_t *local_bitmap_addrs[] = { + reinterpret_cast(&local_bitmap), + }; + + eval_func(addrs, local_bitmap_addrs, num_records); + + uint32_t expected_value[] = {0, 1, 0, 2}; + bool expected_validity[] = {false, true, false, true}; + + for (int i = 0; i < num_records; i++) { + EXPECT_EQ(expected_value[i], out[i]); + // shift a 64-bit one, to match the width of the bitmap word. + EXPECT_EQ(expected_validity[i], (local_bitmap & (1ULL << i)) != 0); + } +} + +// NOTE(review): this main() is declared inside namespace gandiva, so it is gandiva::main +// and not the program entry point - confirm that the test binary links gtest_main. +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} + +} // namespace gandiva diff --git a/cpp/src/gandiva/codegen/llvm_types.cc b/cpp/src/gandiva/codegen/llvm_types.cc new file mode 100644 index 00000000000..3b474f39ac8 --- /dev/null +++ b/cpp/src/gandiva/codegen/llvm_types.cc @@ -0,0 +1,43 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "codegen/llvm_types.h" + +namespace gandiva { + +// LLVM doesn't distinguish between signed and unsigned types, so the signed and unsigned +// arrow integer types of the same width map to the same IR type. + +// Builds the table from arrow type ids to IR types that IRType() looks up. Arrow types +// that are not listed here are not supported (IRType() finds no entry for them). +LLVMTypes::LLVMTypes(llvm::LLVMContext &context) : context_(context) { + arrow_id_to_llvm_type_map_ = { + {arrow::Type::type::BOOL, i1_type()}, + {arrow::Type::type::INT8, i8_type()}, + {arrow::Type::type::INT16, i16_type()}, + {arrow::Type::type::INT32, i32_type()}, + {arrow::Type::type::INT64, i64_type()}, + {arrow::Type::type::UINT8, i8_type()}, + {arrow::Type::type::UINT16, i16_type()}, + {arrow::Type::type::UINT32, i32_type()}, + {arrow::Type::type::UINT64, i64_type()}, + {arrow::Type::type::FLOAT, float_type()}, + {arrow::Type::type::DOUBLE, double_type()}, + {arrow::Type::type::DATE64, i64_type()}, + {arrow::Type::type::TIME32, i32_type()}, + {arrow::Type::type::TIME64, i64_type()}, + {arrow::Type::type::TIMESTAMP, i64_type()}, + {arrow::Type::type::STRING, i8_ptr_type()}, + {arrow::Type::type::BINARY, i8_ptr_type()}, + }; +} + +} // namespace gandiva diff --git a/cpp/src/gandiva/codegen/llvm_types.h b/cpp/src/gandiva/codegen/llvm_types.h new file mode 100644 index 00000000000..cf4603168d8 --- /dev/null +++ b/cpp/src/gandiva/codegen/llvm_types.h @@ -0,0 +1,121 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GANDIVA_LLVM_TYPES_H +#define GANDIVA_LLVM_TYPES_H + +#include +#include + +#include +#include +#include "gandiva/arrow.h" + +namespace gandiva { + +/// \brief Holder for llvm types, and mappings between arrow types and llvm types. +/// +/// All types and constants are created in the llvm context passed to the constructor. +class LLVMTypes { + public: + explicit LLVMTypes(llvm::LLVMContext &context); + + // Scalar types. + llvm::Type *i1_type() { return llvm::Type::getInt1Ty(context_); } + + llvm::Type *i8_type() { return llvm::Type::getInt8Ty(context_); } + + llvm::Type *i16_type() { return llvm::Type::getInt16Ty(context_); } + + llvm::Type *i32_type() { return llvm::Type::getInt32Ty(context_); } + + llvm::Type *i64_type() { return llvm::Type::getInt64Ty(context_); } + + llvm::Type *float_type() { return llvm::Type::getFloatTy(context_); } + + llvm::Type *double_type() { return llvm::Type::getDoubleTy(context_); } + + // Pointer types (all in address space 0). + llvm::PointerType *i8_ptr_type() { return llvm::PointerType::get(i8_type(), 0); } + + llvm::PointerType *i32_ptr_type() { return llvm::PointerType::get(i32_type(), 0); } + + llvm::PointerType *i64_ptr_type() { return llvm::PointerType::get(i64_type(), 0); } + + llvm::PointerType *ptr_type(llvm::Type *base_type) { + return llvm::PointerType::get(base_type, 0); + } + + llvm::Type *void_type() { return llvm::Type::getVoidTy(context_); } + + // Constants. + llvm::Constant *true_constant() { + return llvm::ConstantInt::get(context_, llvm::APInt(1, 1)); + } + + llvm::Constant *false_constant() { + return llvm::ConstantInt::get(context_, llvm::APInt(1, 0)); + } + + llvm::Constant *i1_constant(bool val) { + return llvm::ConstantInt::get(context_,
llvm::APInt(1, val)); + } + + // The integer constants take a value of the matching width (a 'bool' parameter would + // collapse every value to 0 or 1). + llvm::Constant *i8_constant(int8_t val) { + return llvm::ConstantInt::get(context_, llvm::APInt(8, val)); + } + + llvm::Constant *i16_constant(int16_t val) { + return llvm::ConstantInt::get(context_, llvm::APInt(16, val)); + } + + llvm::Constant *i32_constant(int32_t val) { + return llvm::ConstantInt::get(context_, llvm::APInt(32, val)); + } + + llvm::Constant *i64_constant(int64_t val) { + return llvm::ConstantInt::get(context_, llvm::APInt(64, val)); + } + + llvm::Constant *float_constant(float val) { + return llvm::ConstantFP::get(float_type(), val); + } + + llvm::Constant *double_constant(double val) { + return llvm::ConstantFP::get(double_type(), val); + } + + /// For a given data type, find the ir type used for the data vector slot. + /// Returns nullptr if the type is not supported. + llvm::Type *DataVecType(const DataTypePtr &data_type) { + return IRType(data_type->id()); + } + + /// For a given minor type, find the corresponding ir type. + /// Returns nullptr if the type is not supported. + llvm::Type *IRType(arrow::Type::type arrow_type) { + auto found = arrow_id_to_llvm_type_map_.find(arrow_type); + return (found == arrow_id_to_llvm_type_map_.end()) ? nullptr : found->second; + } + + /// The arrow type ids that have a mapping to an ir type. + std::vector GetSupportedArrowTypes() { + std::vector retval; + for (auto const &element : arrow_id_to_llvm_type_map_) { + retval.push_back(element.first); + } + return retval; + } + + private: + std::map arrow_id_to_llvm_type_map_; + + llvm::LLVMContext &context_; +}; + +} // namespace gandiva + +#endif // GANDIVA_LLVM_TYPES_H diff --git a/cpp/src/gandiva/codegen/llvm_types_test.cc b/cpp/src/gandiva/codegen/llvm_types_test.cc new file mode 100644 index 00000000000..6d75768a804 --- /dev/null +++ b/cpp/src/gandiva/codegen/llvm_types_test.cc @@ -0,0 +1,62 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "codegen/llvm_types.h" + +#include + +namespace gandiva { + +// Fixture that creates a fresh LLVMTypes (on its own llvm context) for every test. +class TestLLVMTypes : public ::testing::Test { + protected: + void SetUp() override { types_ = new LLVMTypes(context_); } + void TearDown() override { delete types_; } + + llvm::LLVMContext context_; + // never left indeterminate: deleting a null pointer is a no-op. + LLVMTypes *types_ = nullptr; +}; + +// Supported arrow types map to the expected ir types. +TEST_F(TestLLVMTypes, TestFound) { + EXPECT_EQ(types_->IRType(arrow::Type::BOOL), types_->i1_type()); + EXPECT_EQ(types_->IRType(arrow::Type::INT32), types_->i32_type()); + EXPECT_EQ(types_->IRType(arrow::Type::INT64), types_->i64_type()); + EXPECT_EQ(types_->IRType(arrow::Type::FLOAT), types_->float_type()); + EXPECT_EQ(types_->IRType(arrow::Type::DOUBLE), types_->double_type()); + EXPECT_EQ(types_->IRType(arrow::Type::DATE64), types_->i64_type()); + EXPECT_EQ(types_->IRType(arrow::Type::TIME64), types_->i64_type()); + EXPECT_EQ(types_->IRType(arrow::Type::TIMESTAMP), types_->i64_type()); + + EXPECT_EQ(types_->DataVecType(arrow::boolean()), types_->i1_type()); + EXPECT_EQ(types_->DataVecType(arrow::int32()), types_->i32_type()); + EXPECT_EQ(types_->DataVecType(arrow::int64()), types_->i64_type()); + EXPECT_EQ(types_->DataVecType(arrow::float32()), types_->float_type()); + EXPECT_EQ(types_->DataVecType(arrow::float64()), types_->double_type()); + EXPECT_EQ(types_->DataVecType(arrow::date64()), types_->i64_type()); + EXPECT_EQ(types_->DataVecType(arrow::time64(arrow::TimeUnit::MICRO)), + types_->i64_type()); + EXPECT_EQ(types_->DataVecType(arrow::timestamp(arrow::TimeUnit::MILLI)), + types_->i64_type()); +} + +// Unsupported arrow types have no ir type. +TEST_F(TestLLVMTypes, TestNotFound) { +
EXPECT_EQ(types_->IRType(arrow::Type::type::UNION), nullptr); + EXPECT_EQ(types_->DataVecType(arrow::null()), nullptr); +} + +// NOTE(review): this main() is declared inside namespace gandiva, so it is gandiva::main +// and not the program entry point - confirm that the test binary links gtest_main. +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} + +} // namespace gandiva diff --git a/cpp/src/gandiva/codegen/local_bitmaps_holder.h b/cpp/src/gandiva/codegen/local_bitmaps_holder.h new file mode 100644 index 00000000000..9d1c28f0bdc --- /dev/null +++ b/cpp/src/gandiva/codegen/local_bitmaps_holder.h @@ -0,0 +1,81 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GANDIVA_LOCAL_BITMAPS_HOLDER_H +#define GANDIVA_LOCAL_BITMAPS_HOLDER_H + +#include +#include "gandiva/arrow.h" +#include "gandiva/gandiva_aliases.h" + +namespace gandiva { + +/// \brief Holder for the local bitmaps allocated for one batch of records, used for +/// expression evaluation. +class LocalBitMapsHolder { + public: + LocalBitMapsHolder(int num_records, int num_local_bitmaps); + + /// number of local bitmaps held. + int GetNumLocalBitMaps() const { return static_cast<int>(local_bitmaps_vec_.size()); } + + /// size of each local bitmap, in bytes. + int GetLocalBitMapSize() const { return local_bitmap_size_; } + + /// array of pointers to the local bitmaps (null if there are no local bitmaps). + uint8_t **GetLocalBitMapArray() const { return local_bitmaps_array_.get(); } + + /// the local bitmap at 'idx', which must be in [0, GetNumLocalBitMaps()). + uint8_t *GetLocalBitMap(int idx) const { + DCHECK(idx >= 0 && idx < GetNumLocalBitMaps()); + return local_bitmaps_array_.get()[idx]; + } + + private: + /// number of records in the current batch.
+ int num_records_; + + /// A container of the local bitmaps (owns their memory), each sized to accommodate + /// 'num_records'. + std::vector> local_bitmaps_vec_; + + /// An array of raw pointers to the local bitmaps held in 'local_bitmaps_vec_'. + std::unique_ptr local_bitmaps_array_; + + /// size of each local bitmap, in bytes. + int local_bitmap_size_; +}; + +/// Allocate 'num_local_bitmaps' bitmaps, each able to hold 'num_records' bits and with +/// all bits initially set. +inline LocalBitMapsHolder::LocalBitMapsHolder(int num_records, int num_local_bitmaps) + : num_records_(num_records) { + // alloc an array for the pointers to the bitmaps. + if (num_local_bitmaps > 0) { + local_bitmaps_array_.reset(new uint8_t *[num_local_bitmaps]); + } + + // round the bitmap size up to a multiple of 64 bits (8 bytes). + uint32_t roundUp64Multiple = (num_records_ + 63) >> 6; + local_bitmap_size_ = roundUp64Multiple * 8; + + // Alloc 'num_local_bitmaps' number of bitmaps, each of capacity 'num_records_'. + for (int i = 0; i < num_local_bitmaps; ++i) { + // TODO : round-up to a slab friendly multiple. + std::unique_ptr bitmap(new uint8_t[local_bitmap_size_]); + + // keep pointer to the bitmap in the array. + (local_bitmaps_array_.get())[i] = bitmap.get(); + + // pre-fill with 1s (assuming that the probability of is_valid is higher). + memset(bitmap.get(), 0xff, local_bitmap_size_); + // the vector takes ownership; the raw pointer stored above stays valid. + local_bitmaps_vec_.push_back(std::move(bitmap)); + } +} + +} // namespace gandiva + +#endif // GANDIVA_LOCAL_BITMAPS_HOLDER_H diff --git a/cpp/src/gandiva/codegen/logging.h b/cpp/src/gandiva/codegen/logging.h new file mode 100644 index 00000000000..84c42045366 --- /dev/null +++ b/cpp/src/gandiva/codegen/logging.h @@ -0,0 +1,22 @@ +/* + * Copyright (C) 2017-2018 Dremio Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef GANDIVA_LOGGING_H +#define GANDIVA_LOGGING_H + +// TODO : setup logging or use glog. +#include + +#endif // GANDIVA_LOGGING_H diff --git a/cpp/src/gandiva/codegen/lru_cache.h b/cpp/src/gandiva/codegen/lru_cache.h new file mode 100644 index 00000000000..9d4f3e04605 --- /dev/null +++ b/cpp/src/gandiva/codegen/lru_cache.h @@ -0,0 +1,120 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LRU_CACHE_H +#define LRU_CACHE_H + +#include +#include +#include + +#include + +// modified from boost LRU cache -> the boost cache supported only an +// ordered map. 
+namespace gandiva { +// a cache which evicts the least recently used item when it is full +template +class LruCache { + public: + using key_type = Key; + using value_type = Value; + using list_type = std::list; + // hashes a key by calling its Hash() method. + struct hasher { + template + std::size_t operator()(const I &i) const { + return i.Hash(); + } + }; + using map_type = + std::unordered_map, + hasher>; + + LruCache(size_t capacity) : cache_capacity_(capacity) {} + + ~LruCache() {} + + size_t size() const { return map_.size(); } + + size_t capacity() const { return cache_capacity_; } + + bool empty() const { return map_.empty(); } + + bool contains(const key_type &key) const { return map_.find(key) != map_.end(); } + + // Insert 'value' for 'key', evicting the least recently used item if the cache is full. + // If 'key' is already present, this is a no-op : neither its value nor its position in + // the recently-used order is changed. + void insert(const key_type &key, const value_type &value) { + if (cache_capacity_ == 0) { + // a zero-capacity cache can hold nothing, and has nothing to evict. + return; + } + + typename map_type::iterator i = map_.find(key); + if (i == map_.end()) { + // insert item into the cache, but first check if it is full + if (size() >= cache_capacity_) { + // cache is full, evict the least recently used item + evict(); + } + + // insert the new item + lru_list_.push_front(key); + map_[key] = std::make_pair(value, lru_list_.begin()); + } + } + + // Return the value for 'key' (boost::none if it is not cached), and mark the key as the + // most recently used one. + boost::optional get(const key_type &key) { + // lookup value in the cache + typename map_type::iterator value_for_key = map_.find(key); + if (value_for_key == map_.end()) { + // value not in cache + return boost::none; + } + + // return the value, but first update its place in the most + // recently used list + typename list_type::iterator position_in_lru_list = value_for_key->second.second; + if (position_in_lru_list != lru_list_.begin()) { + // move item to the front of the most recently used list. splice() relinks the + // node in place, so the iterator stored in the map stays valid. + lru_list_.splice(lru_list_.begin(), lru_list_, position_in_lru_list); + + // return the value + return value_for_key->second.first; + } else { + // the item is already at the front of the most recently + // used
list so just return it + return value_for_key->second.first; + } + } + + void clear() { + map_.clear(); + lru_list_.clear(); + } + + private: + void evict() { + // evict item from the end of most recently used list + typename list_type::iterator i = --lru_list_.end(); + map_.erase(*i); + lru_list_.erase(i); + } + + private: + map_type map_; + list_type lru_list_; + size_t cache_capacity_; +}; +} // namespace gandiva +#endif // LRU_CACHE_H diff --git a/cpp/src/gandiva/codegen/lru_cache_test.cc b/cpp/src/gandiva/codegen/lru_cache_test.cc new file mode 100644 index 00000000000..b729803d4ba --- /dev/null +++ b/cpp/src/gandiva/codegen/lru_cache_test.cc @@ -0,0 +1,60 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "codegen/lru_cache.h" + +#include +#include + +#include + +namespace gandiva { + +// Minimal cache key : hashable via Hash() and equality-comparable. +class TestCacheKey { + public: + // intentionally implicit : the tests below pass plain ints as keys. + TestCacheKey(int tmp) : tmp_(tmp) {} + std::size_t Hash() const { return tmp_; } + bool operator==(const TestCacheKey &other) const { return tmp_ == other.tmp_; } + + private: + int tmp_; +}; + +// Fixture with a cache of capacity 2. +class TestLruCache : public ::testing::Test { + public: + TestLruCache() : cache_(2) {} + + protected: + LruCache cache_; +}; + +TEST_F(TestLruCache, TestEvict) { + cache_.insert(TestCacheKey(1), "hello"); + cache_.insert(TestCacheKey(2), "hello"); + // re-inserting an existing key does not refresh it. + cache_.insert(TestCacheKey(1), "hello"); + cache_.insert(TestCacheKey(3), "hello"); + // should have evicted key 1 + ASSERT_EQ(2u, cache_.size()); + ASSERT_EQ(cache_.get(1), boost::none); +} + +TEST_F(TestLruCache, TestLruBehavior) { + cache_.insert(TestCacheKey(1), "hello"); + cache_.insert(TestCacheKey(2), "hello"); + // reading key 1 makes key 2 the least recently used one. + cache_.get(TestCacheKey(1)); + cache_.insert(TestCacheKey(3), "hello"); + // should have evicted key 2. + ASSERT_EQ(cache_.get(1).value(), "hello"); + ASSERT_EQ(cache_.get(2), boost::none); +} +} // namespace gandiva diff --git a/cpp/src/gandiva/codegen/lvalue.h b/cpp/src/gandiva/codegen/lvalue.h new file mode 100644 index 00000000000..076d64be91d --- /dev/null +++ b/cpp/src/gandiva/codegen/lvalue.h @@ -0,0 +1,41 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#ifndef GANDIVA_LVALUE_H +#define GANDIVA_LVALUE_H + +#include + +namespace gandiva { + +/// \brief Tracks validity/value builders in LLVM. +/// +/// Holds non-owning pointers to the llvm values; 'length' and 'validity' are optional +/// and are null when not supplied. +class LValue { + public: + explicit LValue(llvm::Value *data, llvm::Value *length = nullptr, + llvm::Value *validity = nullptr) + : data_(data), length_(length), validity_(validity) {} + + // The accessors do not modify the holder, so they are usable on a const LValue too. + llvm::Value *data() const { return data_; } + llvm::Value *length() const { return length_; } + llvm::Value *validity() const { return validity_; } + + private: + llvm::Value *data_; + llvm::Value *length_; + llvm::Value *validity_; +}; + +} // namespace gandiva + +#endif // GANDIVA_LVALUE_H diff --git a/cpp/src/gandiva/codegen/native_function.h b/cpp/src/gandiva/codegen/native_function.h new file mode 100644 index 00000000000..8e846dad619 --- /dev/null +++ b/cpp/src/gandiva/codegen/native_function.h @@ -0,0 +1,71 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+class NativeFunction { + public: + const FunctionSignature &signature() const { return signature_; } + std::string pc_name() const { return pc_name_; } + ResultNullableType result_nullable_type() const { return result_nullable_type_; } + bool param_null_safe() const { return param_null_safe_; } + bool needs_holder() const { return needs_holder_; } + + private: + NativeFunction(const std::string &base_name, const DataTypeVector ¶m_types, + DataTypePtr ret_type, bool param_null_safe, + const ResultNullableType &result_nullable_type, + const std::string &pc_name, bool needs_holder = false) + : signature_(base_name, param_types, ret_type), + param_null_safe_(param_null_safe), + needs_holder_(needs_holder), + result_nullable_type_(result_nullable_type), + pc_name_(pc_name) {} + + FunctionSignature signature_; + + /// attributes + bool param_null_safe_; + bool needs_holder_; + ResultNullableType result_nullable_type_; + + /// pre-compiled function name. + std::string pc_name_; + + friend class FunctionRegistry; +}; + +} // end namespace gandiva + +#endif // GANDIVA_NATIVE_FUNCTION_H diff --git a/cpp/src/gandiva/codegen/node.h b/cpp/src/gandiva/codegen/node.h new file mode 100644 index 00000000000..2d781777b11 --- /dev/null +++ b/cpp/src/gandiva/codegen/node.h @@ -0,0 +1,209 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef GANDIVA_EXPR_NODE_H +#define GANDIVA_EXPR_NODE_H + +#include +#include + +#include "codegen/func_descriptor.h" +#include "codegen/literal_holder.h" +#include "codegen/node_visitor.h" +#include "gandiva/arrow.h" +#include "gandiva/gandiva_aliases.h" +#include "gandiva/status.h" + +namespace gandiva { + +/// \brief Represents a node in the expression tree. Validity and value are +/// in a joined state. +class Node { + public: + explicit Node(DataTypePtr return_type) : return_type_(return_type) {} + + virtual ~Node() = default; + + const DataTypePtr &return_type() const { return return_type_; } + + /// Derived classes should simply invoke the Visit api of the visitor. + virtual Status Accept(NodeVisitor &visitor) const = 0; + + virtual std::string ToString() = 0; + + protected: + DataTypePtr return_type_; +}; + +/// \brief Node in the expression tree, representing a literal. +class LiteralNode : public Node { + public: + LiteralNode(DataTypePtr type, const LiteralHolder &holder, bool is_null) + : Node(type), holder_(holder), is_null_(is_null) {} + + Status Accept(NodeVisitor &visitor) const override { return visitor.Visit(*this); } + + const LiteralHolder &holder() const { return holder_; } + + bool is_null() const { return is_null_; } + + std::string ToString() override { + std::stringstream ss; + ss << "(" << return_type()->ToString() << ") "; + if (is_null()) { + ss << std::string("null"); + return ss.str(); + } + ss << holder(); + return ss.str(); + } + + private: + LiteralHolder holder_; + bool is_null_; +}; + +/// \brief Node in the expression tree, representing an arrow field. 
+class FieldNode : public Node { + public: + explicit FieldNode(FieldPtr field) : Node(field->type()), field_(field) {} + + Status Accept(NodeVisitor &visitor) const override { return visitor.Visit(*this); } + + const FieldPtr &field() const { return field_; } + + std::string ToString() override { return field()->type()->name(); } + + private: + FieldPtr field_; +}; + +/// \brief Node in the expression tree, representing a function. +class FunctionNode : public Node { + public: + FunctionNode(FuncDescriptorPtr descriptor, const NodeVector &children, + DataTypePtr retType) + : Node(retType), descriptor_(descriptor), children_(children) {} + + Status Accept(NodeVisitor &visitor) const override { return visitor.Visit(*this); } + + const FuncDescriptorPtr &descriptor() const { return descriptor_; } + const NodeVector &children() const { return children_; } + + std::string ToString() override { + std::stringstream ss; + ss << descriptor()->return_type()->name() << " " << descriptor()->name() << "("; + bool skip_comma = true; + for (auto child : children()) { + if (skip_comma) { + ss << child->ToString(); + skip_comma = false; + } else { + ss << ", " << child->ToString(); + } + } + ss << ")"; + return ss.str(); + } + + /// Make a function node with params types specified by 'children', and + /// having return type ret_type. 
+ static NodePtr MakeFunction(const std::string &name, const NodeVector &children, + DataTypePtr return_type); + + private: + FuncDescriptorPtr descriptor_; + NodeVector children_; +}; + +inline NodePtr FunctionNode::MakeFunction(const std::string &name, + const NodeVector &children, + DataTypePtr return_type) { + DataTypeVector param_types; + for (auto &child : children) { + param_types.push_back(child->return_type()); + } + + auto func_desc = FuncDescriptorPtr(new FuncDescriptor(name, param_types, return_type)); + return NodePtr(new FunctionNode(func_desc, children, return_type)); +} + +/// \brief Node in the expression tree, representing an if-else expression. +class IfNode : public Node { + public: + IfNode(NodePtr condition, NodePtr then_node, NodePtr else_node, DataTypePtr result_type) + : Node(result_type), + condition_(condition), + then_node_(then_node), + else_node_(else_node) {} + + Status Accept(NodeVisitor &visitor) const override { return visitor.Visit(*this); } + + const NodePtr &condition() const { return condition_; } + const NodePtr &then_node() const { return then_node_; } + const NodePtr &else_node() const { return else_node_; } + + std::string ToString() override { + std::stringstream ss; + ss << "if (" << condition()->ToString() << ") { "; + ss << then_node()->ToString() << " } else { "; + ss << else_node()->ToString() << " }"; + return ss.str(); + } + + private: + NodePtr condition_; + NodePtr then_node_; + NodePtr else_node_; +}; + +/// \brief Node in the expression tree, representing an and/or boolean expression. 
+class BooleanNode : public Node { + public: + enum ExprType : char { AND, OR }; + + BooleanNode(ExprType expr_type, const NodeVector &children) + : Node(arrow::boolean()), expr_type_(expr_type), children_(children) {} + + Status Accept(NodeVisitor &visitor) const override { return visitor.Visit(*this); } + + ExprType expr_type() const { return expr_type_; } + + const NodeVector &children() const { return children_; } + + std::string ToString() override { + std::stringstream ss; + bool first = true; + for (auto &child : children_) { + if (!first) { + if (expr_type() == BooleanNode::AND) { + ss << " && "; + } else { + ss << " || "; + } + } + ss << child->ToString(); + first = false; + } + return ss.str(); + } + + private: + ExprType expr_type_; + NodeVector children_; +}; + +} // namespace gandiva + +#endif // GANDIVA_EXPR_NODE_H diff --git a/cpp/src/gandiva/codegen/node_visitor.h b/cpp/src/gandiva/codegen/node_visitor.h new file mode 100644 index 00000000000..26e992d85c6 --- /dev/null +++ b/cpp/src/gandiva/codegen/node_visitor.h @@ -0,0 +1,41 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GANDIVA_NODE_VISITOR_H +#define GANDIVA_NODE_VISITOR_H + +#include "gandiva/logging.h" +#include "gandiva/status.h" + +namespace gandiva { + +class FieldNode; +class FunctionNode; +class IfNode; +class LiteralNode; +class BooleanNode; + +/// \brief Visitor for nodes in the expression tree. 
+class NodeVisitor { + public: + virtual Status Visit(const FieldNode &node) = 0; + virtual Status Visit(const FunctionNode &node) = 0; + virtual Status Visit(const IfNode &node) = 0; + virtual Status Visit(const LiteralNode &node) = 0; + virtual Status Visit(const BooleanNode &node) = 0; +}; + +} // namespace gandiva + +#endif // GANDIVA_NODE_VISITOR_H diff --git a/cpp/src/gandiva/codegen/projector.cc b/cpp/src/gandiva/codegen/projector.cc new file mode 100644 index 00000000000..38768b1d538 --- /dev/null +++ b/cpp/src/gandiva/codegen/projector.cc @@ -0,0 +1,222 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "gandiva/projector.h" + +#include +#include +#include + +#include "codegen/cache.h" +#include "codegen/expr_validator.h" +#include "codegen/llvm_generator.h" +#include "codegen/projector_cache_key.h" +#include "gandiva/status.h" + +namespace gandiva { + +Projector::Projector(std::unique_ptr llvm_generator, SchemaPtr schema, + const FieldVector &output_fields, + std::shared_ptr configuration) + : llvm_generator_(std::move(llvm_generator)), + schema_(schema), + output_fields_(output_fields), + configuration_(configuration) {} + +Status Projector::Make(SchemaPtr schema, const ExpressionVector &exprs, + std::shared_ptr *projector) { + return Projector::Make(schema, exprs, ConfigurationBuilder::DefaultConfiguration(), + projector); +} + +Status Projector::Make(SchemaPtr schema, const ExpressionVector &exprs, + std::shared_ptr configuration, + std::shared_ptr *projector) { + GANDIVA_RETURN_FAILURE_IF_FALSE(schema != nullptr, + Status::Invalid("schema cannot be null")); + GANDIVA_RETURN_FAILURE_IF_FALSE(!exprs.empty(), + Status::Invalid("expressions need to be non-empty")); + GANDIVA_RETURN_FAILURE_IF_FALSE(configuration != nullptr, + Status::Invalid("configuration cannot be null")); + + // see if equivalent projector was already built + static Cache> cache; + ProjectorCacheKey cache_key(schema, configuration, exprs); + std::shared_ptr cached_projector = cache.GetModule(cache_key); + if (cached_projector != nullptr) { + *projector = cached_projector; + return Status::OK(); + } + // Build LLVM generator, and generate code for the specified expressions + std::unique_ptr llvm_gen; + Status status = LLVMGenerator::Make(configuration, &llvm_gen); + GANDIVA_RETURN_NOT_OK(status); + + // Run the validation on the expressions. + // Return if any of the expression is invalid since + // we will not be able to process further. 
+ ExprValidator expr_validator(llvm_gen->types(), schema); + for (auto &expr : exprs) { + status = expr_validator.Validate(expr); + GANDIVA_RETURN_NOT_OK(status); + } + + status = llvm_gen->Build(exprs); + GANDIVA_RETURN_NOT_OK(status); + + // save the output field types. Used for validation at Evaluate() time. + std::vector output_fields; + for (auto &expr : exprs) { + output_fields.push_back(expr->result()); + } + + // Instantiate the projector with the completely built llvm generator + *projector = std::shared_ptr( + new Projector(std::move(llvm_gen), schema, output_fields, configuration)); + cache.PutModule(cache_key, *projector); + return Status::OK(); +} + +Status Projector::Evaluate(const arrow::RecordBatch &batch, + const ArrayDataVector &output_data_vecs) { + Status status = ValidateEvaluateArgsCommon(batch); + GANDIVA_RETURN_NOT_OK(status); + + if (output_data_vecs.size() != output_fields_.size()) { + std::stringstream ss; + ss << "number of buffers for output_data_vecs is " << output_data_vecs.size() + << ", expected " << output_fields_.size(); + return Status::Invalid(ss.str()); + } + + int idx = 0; + for (auto &array_data : output_data_vecs) { + if (array_data == nullptr) { + std::stringstream ss; + ss << "array for output field " << output_fields_[idx]->name() << "is null."; + return Status::Invalid(ss.str()); + } + + Status status = + ValidateArrayDataCapacity(*array_data, *(output_fields_[idx]), batch.num_rows()); + GANDIVA_RETURN_NOT_OK(status); + ++idx; + } + return llvm_generator_->Execute(batch, output_data_vecs); +} + +Status Projector::Evaluate(const arrow::RecordBatch &batch, arrow::MemoryPool *pool, + arrow::ArrayVector *output) { + Status status = ValidateEvaluateArgsCommon(batch); + GANDIVA_RETURN_NOT_OK(status); + + if (output == nullptr) { + return Status::Invalid("output must be non-null."); + } + + if (pool == nullptr) { + return Status::Invalid("memory pool must be non-null."); + } + + // Allocate the output data vecs. 
+ ArrayDataVector output_data_vecs; + for (auto &field : output_fields_) { + ArrayDataPtr output_data; + + status = AllocArrayData(field->type(), batch.num_rows(), pool, &output_data); + GANDIVA_RETURN_NOT_OK(status); + + output_data_vecs.push_back(output_data); + } + + // Execute the expression(s). + status = llvm_generator_->Execute(batch, output_data_vecs); + GANDIVA_RETURN_NOT_OK(status); + + // Create and return array arrays. + output->clear(); + for (auto &array_data : output_data_vecs) { + output->push_back(arrow::MakeArray(array_data)); + } + return Status::OK(); +} + +// TODO : handle variable-len vectors +Status Projector::AllocArrayData(const DataTypePtr &type, int num_records, + arrow::MemoryPool *pool, ArrayDataPtr *array_data) { + if (!arrow::is_primitive(type->id())) { + return Status::Invalid("Unsupported output data type " + type->ToString()); + } + + arrow::Status astatus; + std::shared_ptr null_bitmap; + int64_t size = arrow::BitUtil::BytesForBits(num_records); + astatus = arrow::AllocateBuffer(pool, size, &null_bitmap); + GANDIVA_RETURN_ARROW_NOT_OK(astatus); + + std::shared_ptr data; + const auto &fw_type = dynamic_cast(*type); + int64_t data_len = arrow::BitUtil::BytesForBits(num_records * fw_type.bit_width()); + astatus = arrow::AllocateBuffer(pool, data_len, &data); + GANDIVA_RETURN_ARROW_NOT_OK(astatus); + + *array_data = arrow::ArrayData::Make(type, num_records, {null_bitmap, data}); + return Status::OK(); +} + +Status Projector::ValidateEvaluateArgsCommon(const arrow::RecordBatch &batch) { + if (!batch.schema()->Equals(*schema_)) { + return Status::Invalid("Schema in RecordBatch must match the schema in Make()"); + } + if (batch.num_rows() == 0) { + return Status::Invalid("RecordBatch must be non-empty."); + } + return Status::OK(); +} + +Status Projector::ValidateArrayDataCapacity(const arrow::ArrayData &array_data, + const arrow::Field &field, int num_records) { + // verify that there are atleast two buffers (validity and data). 
+ if (array_data.buffers.size() < 2) { + std::stringstream ss; + ss << "number of buffers for output field " << field.name() << "is " + << array_data.buffers.size() << ", must have minimum 2."; + return Status::Invalid(ss.str()); + } + + // verify size of bitmap buffer. + int64_t min_bitmap_len = arrow::BitUtil::BytesForBits(num_records); + int64_t bitmap_len = array_data.buffers[0]->capacity(); + if (bitmap_len < min_bitmap_len) { + std::stringstream ss; + ss << "bitmap buffer for output field " << field.name() << "has size " << bitmap_len + << ", must have minimum size " << min_bitmap_len; + return Status::Invalid(ss.str()); + } + + // verify size of data buffer. + // TODO : handle variable-len vectors + const auto &fw_type = dynamic_cast(*field.type()); + int64_t min_data_len = arrow::BitUtil::BytesForBits(num_records * fw_type.bit_width()); + int64_t data_len = array_data.buffers[1]->capacity(); + if (data_len < min_data_len) { + std::stringstream ss; + ss << "data buffer for output field " << field.name() << " has size " << data_len + << ", must have minimum size " << min_data_len; + return Status::Invalid(ss.str()); + } + return Status::OK(); +} + +} // namespace gandiva diff --git a/cpp/src/gandiva/codegen/projector.h b/cpp/src/gandiva/codegen/projector.h new file mode 100644 index 00000000000..2d10c48b8de --- /dev/null +++ b/cpp/src/gandiva/codegen/projector.h @@ -0,0 +1,100 @@ +/* + * Copyright (C) 2017-2018 Dremio Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef GANDIVA_EXPR_PROJECTOR_H +#define GANDIVA_EXPR_PROJECTOR_H + +#include +#include +#include +#include + +#include "gandiva/arrow.h" +#include "gandiva/configuration.h" +#include "gandiva/expression.h" +#include "gandiva/status.h" + +namespace gandiva { + +class LLVMGenerator; + +/// \brief projection using expressions. +/// +/// A projector is built for a specific schema and vector of expressions. +/// Once the projector is built, it can be used to evaluate many row batches. +class Projector { + public: + /// Build a default projector for the given schema to evaluate + /// the vector of expressions. + /// + /// \param[in] : schema schema for the record batches, and the expressions. + /// \param[in] : exprs vector of expressions. + /// \param[out]: projector the returned projector object + static Status Make(SchemaPtr schema, const ExpressionVector &exprs, + std::shared_ptr *projector); + + /// Build a projector for the given schema to evaluate the vector of expressions. + /// Customize the projector with runtime configuration. + /// + /// \param[in] : schema schema for the record batches, and the expressions. + /// \param[in] : exprs vector of expressions. + /// \param[in] : run time configuration. + /// \param[out]: projector the returned projector object + static Status Make(SchemaPtr schema, const ExpressionVector &exprs, + std::shared_ptr, + std::shared_ptr *projector); + + /// Evaluate the specified record batch, and return the allocated and populated output + /// arrays. The output arrays will be allocated from the memory pool 'pool', and added + /// to the vector 'output'. + /// + /// \param[in] : batch the record batch. schema should be the same as the one in 'Make' + /// \param[in] : pool memory pool used to allocate output arrays (if required). + /// \param[out]: output the vector of allocated/populated arrays. 
+ Status Evaluate(const arrow::RecordBatch &batch, arrow::MemoryPool *pool, + arrow::ArrayVector *ouput); + + /// Evaluate the specified record batch, and populate the output arrays. The output + /// arrays of sufficient capacity must be allocated by the caller. + /// + /// \param[in] : batch the record batch. schema should be the same as the one in 'Make' + /// \param[in/out]: vector of arrays, the arrays are allocated by the caller and + /// populated by Evaluate. + Status Evaluate(const arrow::RecordBatch &batch, const ArrayDataVector &output); + + private: + Projector(std::unique_ptr llvm_generator, SchemaPtr schema, + const FieldVector &output_fields, std::shared_ptr); + + /// Allocate an ArrowData of length 'length'. + Status AllocArrayData(const DataTypePtr &type, int length, arrow::MemoryPool *pool, + ArrayDataPtr *array_data); + + /// Validate that the ArrayData has sufficient capacity to accomodate 'num_records'. + Status ValidateArrayDataCapacity(const arrow::ArrayData &array_data, + const arrow::Field &field, int num_records); + + /// Validate the common args for Evaluate() APIs. + Status ValidateEvaluateArgsCommon(const arrow::RecordBatch &batch); + + const std::unique_ptr llvm_generator_; + const SchemaPtr schema_; + const FieldVector output_fields_; + const std::shared_ptr configuration_; +}; + +} // namespace gandiva + +#endif // GANDIVA_EXPR_PROJECTOR_H diff --git a/cpp/src/gandiva/codegen/projector_cache_key.h b/cpp/src/gandiva/codegen/projector_cache_key.h new file mode 100644 index 00000000000..36fa97eebcc --- /dev/null +++ b/cpp/src/gandiva/codegen/projector_cache_key.h @@ -0,0 +1,68 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GANDIVA_PROJECTOR_CACHE_KEY_H +#define GANDIVA_PROJECTOR_CACHE_KEY_H + +#include "gandiva/arrow.h" +#include "gandiva/projector.h" + +namespace gandiva { +class ProjectorCacheKey { + public: + ProjectorCacheKey(SchemaPtr schema, std::shared_ptr configuration, + ExpressionVector expression_vector) + : schema_(schema), configuration_(configuration) { + static const int kSeedValue = 4; + size_t result = kSeedValue; + for (auto &expr : expression_vector) { + std::string expr_as_string = expr->ToString(); + expressions_as_strings_.push_back(expr_as_string); + boost::hash_combine(result, expr_as_string); + } + boost::hash_combine(result, configuration); + boost::hash_combine(result, schema_->ToString()); + hash_code_ = result; + } + + std::size_t Hash() const { return hash_code_; } + + bool operator==(const ProjectorCacheKey &other) const { + // arrow schema does not overload equality operators. 
+ if (!(schema_->Equals(*other.schema().get(), true))) { + return false; + } + + if (configuration_ != other.configuration_) { + return false; + } + + if (expressions_as_strings_ != other.expressions_as_strings_) { + return false; + } + return true; + } + + bool operator!=(const ProjectorCacheKey &other) const { return !(*this == other); } + + SchemaPtr schema() const { return schema_; } + + private: + const SchemaPtr schema_; + const std::shared_ptr configuration_; + std::vector expressions_as_strings_; + size_t hash_code_; +}; +} // namespace gandiva +#endif // GANDIVA_PROJECTOR_CACHE_KEY_H diff --git a/cpp/src/gandiva/codegen/regex_util.cc b/cpp/src/gandiva/codegen/regex_util.cc new file mode 100644 index 00000000000..f0cf7cbf6ed --- /dev/null +++ b/cpp/src/gandiva/codegen/regex_util.cc @@ -0,0 +1,72 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "codegen/regex_util.h" + +namespace gandiva { + +#ifdef GDV_HELPERS +namespace helpers { +#endif + +const std::set RegexUtil::posix_regex_specials_ = { + '[', ']', '(', ')', '|', '^', '-', '+', '*', '?', '{', '}', '$', '\\'}; + +Status RegexUtil::SqlLikePatternToPosix(const std::string &sql_pattern, char escape_char, + std::string &posix_pattern) { + /// Characters that are considered special by posix regex. These needs to be + /// escaped with '\\'. 
+ posix_pattern.clear(); + for (size_t idx = 0; idx < sql_pattern.size(); ++idx) { + auto cur = sql_pattern.at(idx); + + // Escape any char that is special for posix regex + if (posix_regex_specials_.find(cur) != posix_regex_specials_.end()) { + posix_pattern += "\\"; + } + + if (cur == escape_char) { + // escape char must be followed by '_', '%' or the escape char itself. + ++idx; + if (idx == sql_pattern.size()) { + std::stringstream msg; + msg << "unexpected escape char at the end of pattern " << sql_pattern; + return Status::Invalid(msg.str()); + } + + cur = sql_pattern.at(idx); + if (cur == '_' || cur == '%' || cur == escape_char) { + posix_pattern += cur; + } else { + std::stringstream msg; + msg << "invalid escape sequence in pattern " << sql_pattern << " at offset " + << idx; + return Status::Invalid(msg.str()); + } + } else if (cur == '_') { + posix_pattern += '.'; + } else if (cur == '%') { + posix_pattern += ".*"; + } else { + posix_pattern += cur; + } + } + return Status::OK(); +} + +#ifdef GDV_HELPERS +} // namespace helpers +#endif + +} // namespace gandiva diff --git a/cpp/src/gandiva/codegen/regex_util.h b/cpp/src/gandiva/codegen/regex_util.h new file mode 100644 index 00000000000..698ffba6296 --- /dev/null +++ b/cpp/src/gandiva/codegen/regex_util.h @@ -0,0 +1,49 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef GANDIVA_REGEX_UTIL_H +#define GANDIVA_REGEX_UTIL_H + +#include +#include "gandiva/status.h" + +namespace gandiva { + +#ifdef GDV_HELPERS +namespace helpers { +#endif + +/// \brief Utility class for converting sql patterns to posix patterns. +class RegexUtil { + public: + // Convert an sql pattern to an std::regex pattern + static Status SqlLikePatternToPosix(const std::string &like_pattern, char escape_char, + std::string &posix_pattern); + + static Status SqlLikePatternToPosix(const std::string &like_pattern, + std::string &posix_pattern) { + return SqlLikePatternToPosix(like_pattern, 0 /*escape_char*/, posix_pattern); + } + + private: + static const std::set posix_regex_specials_; +}; + +#ifdef GDV_HELPERS +} // namespace helpers +#endif + +} // namespace gandiva + +#endif // GANDIVA_REGEX_UTIL_H diff --git a/cpp/src/gandiva/codegen/selection_vector.cc b/cpp/src/gandiva/codegen/selection_vector.cc new file mode 100644 index 00000000000..d816b7b8d43 --- /dev/null +++ b/cpp/src/gandiva/codegen/selection_vector.cc @@ -0,0 +1,138 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "gandiva/selection_vector.h" + +#include +#include +#include + +#include "codegen/selection_vector_impl.h" +#include "gandiva/status.h" + +namespace gandiva { + +Status SelectionVector::PopulateFromBitMap(const uint8_t *bitmap, int bitmap_size, + int max_bitmap_index) { + if (bitmap_size % 8 != 0) { + std::stringstream ss; + ss << "bitmap size " << bitmap_size << " must be padded to 64-bit size"; + return Status::Invalid(ss.str()); + } + if (static_cast(max_bitmap_index) > GetMaxSupportedValue()) { + std::stringstream ss; + ss << "max_bitmap_index " << max_bitmap_index << " must be <= maxSupportedValue " + << GetMaxSupportedValue() << " in selection vector"; + return Status::Invalid(ss.str()); + } + + // jump 8-bytes at a time, add the index corresponding to each valid bit to the + // the selection vector. + int selection_idx = 0; + const uint64_t *bitmap_64 = reinterpret_cast(bitmap); + for (int bitmap_idx = 0; bitmap_idx < bitmap_size / 8; ++bitmap_idx) { + uint64_t current_word = bitmap_64[bitmap_idx]; + + while (current_word != 0) { + uint64_t highest_only = current_word & -current_word; + int pos_in_word = __builtin_ctzl(highest_only); + + int pos_in_bitmap = bitmap_idx * 64 + pos_in_word; + if (pos_in_bitmap > max_bitmap_index) { + // the bitmap may be slighly larger for alignment/padding. 
+ break; + } + + if (selection_idx >= GetMaxSlots()) { + return Status::Invalid("selection vector has no remaining slots"); + } + SetIndex(selection_idx, pos_in_bitmap); + ++selection_idx; + + current_word ^= highest_only; + } + } + + SetNumSlots(selection_idx); + return Status::OK(); +} + +Status SelectionVector::MakeInt16(int max_slots, std::shared_ptr buffer, + std::shared_ptr *selection_vector) { + auto status = SelectionVectorInt16::ValidateBuffer(max_slots, buffer); + GANDIVA_RETURN_NOT_OK(status); + + *selection_vector = std::make_shared(max_slots, buffer); + return Status::OK(); +} + +Status SelectionVector::MakeInt16(int max_slots, arrow::MemoryPool *pool, + std::shared_ptr *selection_vector) { + std::shared_ptr buffer; + auto status = SelectionVectorInt16::AllocateBuffer(max_slots, pool, &buffer); + GANDIVA_RETURN_NOT_OK(status); + + *selection_vector = std::make_shared(max_slots, buffer); + return Status::OK(); +} + +Status SelectionVector::MakeInt32(int max_slots, std::shared_ptr buffer, + std::shared_ptr *selection_vector) { + auto status = SelectionVectorInt32::ValidateBuffer(max_slots, buffer); + GANDIVA_RETURN_NOT_OK(status); + + *selection_vector = std::make_shared(max_slots, buffer); + return Status::OK(); +} + +Status SelectionVector::MakeInt32(int max_slots, arrow::MemoryPool *pool, + std::shared_ptr *selection_vector) { + std::shared_ptr buffer; + auto status = SelectionVectorInt32::AllocateBuffer(max_slots, pool, &buffer); + GANDIVA_RETURN_NOT_OK(status); + + *selection_vector = std::make_shared(max_slots, buffer); + return Status::OK(); +} + +template +Status SelectionVectorImpl::AllocateBuffer( + int max_slots, arrow::MemoryPool *pool, std::shared_ptr *buffer) { + auto buffer_len = max_slots * sizeof(C_TYPE); + auto astatus = arrow::AllocateBuffer(pool, buffer_len, buffer); + GANDIVA_RETURN_ARROW_NOT_OK(astatus); + + return Status::OK(); +} + +template +Status SelectionVectorImpl::ValidateBuffer( + int max_slots, std::shared_ptr buffer) { + 
// verify buffer is mutable + if (!buffer->is_mutable()) { + return Status::Invalid("buffer for selection vector must be mutable"); + } + + // verify size of buffer. + int64_t min_len = max_slots * sizeof(C_TYPE); + if (buffer->size() < min_len) { + std::stringstream ss; + ss << "buffer for selection_data has size " << buffer->size() + << ", must have minimum size " << min_len; + return Status::Invalid(ss.str()); + } + return Status::OK(); +} + +} // namespace gandiva diff --git a/cpp/src/gandiva/codegen/selection_vector.h b/cpp/src/gandiva/codegen/selection_vector.h new file mode 100644 index 00000000000..421d9ea3534 --- /dev/null +++ b/cpp/src/gandiva/codegen/selection_vector.h @@ -0,0 +1,94 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GANDIVA_SELECTION_VECTOR__H +#define GANDIVA_SELECTION_VECTOR__H + +#include "gandiva/arrow.h" +#include "gandiva/logging.h" +#include "gandiva/status.h" + +namespace gandiva { + +/// \brief Selection Vector : vector of indices in a row-batch for a selection, +/// backed by an arrow-array. +class SelectionVector { + public: + virtual ~SelectionVector() = default; + + /// Get the value at a given index. + virtual uint GetIndex(int index) const = 0; + + /// Set the value at a given index. + virtual void SetIndex(int index, uint value) = 0; + + // Get the max supported value in the selection vector. 
+ virtual uint GetMaxSupportedValue() const = 0; + + /// The maximum slots (capacity) of the selection vector. + virtual int GetMaxSlots() const = 0; + + /// The number of slots (size) of the selection vector. + virtual int GetNumSlots() const = 0; + + /// Set the number of slots in the selection vector. + virtual void SetNumSlots(int num_slots) = 0; + + /// Convert to arrow-array. + virtual ArrayPtr ToArray() const = 0; + + /// \brief populate selection vector for all the set bits in the bitmap. + /// + /// \param[in] : bitmap the bitmap + /// \param[in] : bitmap_size size of the bitmap in bytes + /// \param[in] : max_bitmap_index max valid index in bitmap (can be lesser than + /// capacity in the bitmap, due to alignment/padding). + Status PopulateFromBitMap(const uint8_t *bitmap, int bitmap_size, int max_bitmap_index); + + /// \brief make selection vector with int16 type records. + /// + /// \param[in] : max_slots max number of slots + /// \param[in] : buffer buffer sized to accomodate max_slots + /// \param[out]: selection_vector selection vector backed by 'buffer' + static Status MakeInt16(int max_slots, std::shared_ptr buffer, + std::shared_ptr *selection_vector); + + /// \param[in] : max_slots max number of slots + /// \param[in] : pool memory pool to allocate buffer + /// \param[out]: selection_vector selection vector backed by a buffer allocated from the + /// pool. + static Status MakeInt16(int max_slots, arrow::MemoryPool *pool, + std::shared_ptr *selection_vector); + + /// \brief make selection vector with int32 type records. + /// + /// \param[in] : max_slots max number of slots + /// \param[in] : buffer buffer sized to accomodate max_slots + /// \param[out]: selection_vector selection vector backed by 'buffer' + static Status MakeInt32(int max_slots, std::shared_ptr buffer, + std::shared_ptr *selection_vector); + + /// \brief make selection vector with int32 type records. 
+ /// + /// \param[in] : max_slots max number of slots + /// \param[in] : pool memory pool to allocate buffer + /// \param[out]: selection_vector selection vector backed by a buffer allocated from the + /// pool. + static Status MakeInt32(int max_slots, arrow::MemoryPool *pool, + std::shared_ptr *selection_vector); +}; + +} // namespace gandiva + +#endif // GANDIVA_SELECTION_VECTOR__H diff --git a/cpp/src/gandiva/codegen/selection_vector_impl.h b/cpp/src/gandiva/codegen/selection_vector_impl.h new file mode 100644 index 00000000000..cfe73aab621 --- /dev/null +++ b/cpp/src/gandiva/codegen/selection_vector_impl.h @@ -0,0 +1,90 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GANDIVA_SELECTION_VECTOR_IMP_H +#define GANDIVA_SELECTION_VECTOR_IMP_H + +#include "gandiva/arrow.h" +#include "gandiva/logging.h" +#include "gandiva/selection_vector.h" +#include "gandiva/status.h" + +namespace gandiva { + +/// \brief template implementation of selection vector with a specific ctype and arrow +/// type. 
+template +class SelectionVectorImpl : public SelectionVector { + public: + SelectionVectorImpl(int max_slots, std::shared_ptr buffer) + : max_slots_(max_slots), num_slots_(0), buffer_(buffer) { + raw_data_ = reinterpret_cast(buffer->mutable_data()); + } + + uint GetIndex(int index) const override { + DCHECK_LE(index, max_slots_); + return raw_data_[index]; + } + + void SetIndex(int index, uint value) override { + DCHECK_LE(index, max_slots_); + DCHECK_LE(value, GetMaxSupportedValue()); + + raw_data_[index] = value; + } + + ArrayPtr ToArray() const override; + + int GetMaxSlots() const override { return max_slots_; } + + int GetNumSlots() const override { return num_slots_; } + + void SetNumSlots(int num_slots) override { + DCHECK_LE(num_slots, max_slots_); + num_slots_ = num_slots; + } + + uint GetMaxSupportedValue() const override { + return std::numeric_limits::max(); + } + + static Status AllocateBuffer(int max_slots, arrow::MemoryPool *pool, + std::shared_ptr *buffer); + + static Status ValidateBuffer(int max_slots, std::shared_ptr buffer); + + protected: + /// maximum slots in the vector + int max_slots_; + + /// number of slots in the vector + int num_slots_; + + std::shared_ptr buffer_; + C_TYPE *raw_data_; +}; + +template +ArrayPtr SelectionVectorImpl::ToArray() const { + auto data_type = arrow::TypeTraits::type_singleton(); + auto array_data = arrow::ArrayData::Make(data_type, num_slots_, {nullptr, buffer_}); + return arrow::MakeArray(array_data); +} + +using SelectionVectorInt16 = SelectionVectorImpl; +using SelectionVectorInt32 = SelectionVectorImpl; + +} // namespace gandiva + +#endif // GANDIVA_SELECTION_VECTOR_IMPL_H diff --git a/cpp/src/gandiva/codegen/selection_vector_test.cc b/cpp/src/gandiva/codegen/selection_vector_test.cc new file mode 100644 index 00000000000..03a9ce95c37 --- /dev/null +++ b/cpp/src/gandiva/codegen/selection_vector_test.cc @@ -0,0 +1,207 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache 
License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gandiva/selection_vector.h" + +#include + +#include + +namespace gandiva { + +class TestSelectionVector : public ::testing::Test { + protected: + virtual void SetUp() { pool_ = arrow::default_memory_pool(); } + + arrow::MemoryPool *pool_; +}; + +static inline uint32_t RoundUpNumi64(uint32_t value) { return (value + 63) >> 6; } + +TEST_F(TestSelectionVector, TestInt16Make) { + int max_slots = 10; + + // Test with pool allocation + std::shared_ptr selection; + auto status = SelectionVector::MakeInt16(max_slots, pool_, &selection); + EXPECT_EQ(status.ok(), true) << status.message(); + EXPECT_EQ(selection->GetMaxSlots(), max_slots); + EXPECT_EQ(selection->GetNumSlots(), 0); + + // Test with pre-alloced buffer + std::shared_ptr selection2; + std::shared_ptr buffer; + auto buffer_len = max_slots * sizeof(int16_t); + auto astatus = arrow::AllocateBuffer(pool_, buffer_len, &buffer); + EXPECT_EQ(astatus.ok(), true); + + status = SelectionVector::MakeInt16(max_slots, buffer, &selection2); + EXPECT_EQ(status.ok(), true) << status.message(); + EXPECT_EQ(selection2->GetMaxSlots(), max_slots); + EXPECT_EQ(selection2->GetNumSlots(), 0); +} + +TEST_F(TestSelectionVector, TestInt16MakeNegative) { + int max_slots = 10; + + std::shared_ptr selection; + std::shared_ptr buffer; + auto buffer_len = max_slots * sizeof(int16_t); + + // alloc a buffer that's insufficient. 
+ auto astatus = arrow::AllocateBuffer(pool_, buffer_len - 16, &buffer); + EXPECT_EQ(astatus.ok(), true); + + auto status = SelectionVector::MakeInt16(max_slots, buffer, &selection); + EXPECT_EQ(status.IsInvalid(), true); +} + +TEST_F(TestSelectionVector, TestInt16Set) { + int max_slots = 10; + + std::shared_ptr selection; + auto status = SelectionVector::MakeInt16(max_slots, pool_, &selection); + EXPECT_EQ(status.ok(), true) << status.message(); + + selection->SetIndex(0, 100); + EXPECT_EQ(selection->GetIndex(0), 100); + + selection->SetIndex(1, 200); + EXPECT_EQ(selection->GetIndex(1), 200); + + selection->SetNumSlots(2); + EXPECT_EQ(selection->GetNumSlots(), 2); + + // TopArray() should return an array with 100,200 + auto array_raw = selection->ToArray(); + const auto &array = dynamic_cast(*array_raw); + EXPECT_EQ(array.length(), 2) << array_raw->ToString(); + EXPECT_EQ(array.Value(0), 100) << array_raw->ToString(); + EXPECT_EQ(array.Value(1), 200) << array_raw->ToString(); +} + +TEST_F(TestSelectionVector, TestInt16PopulateFromBitMap) { + int max_slots = 200; + + std::shared_ptr selection; + auto status = SelectionVector::MakeInt16(max_slots, pool_, &selection); + EXPECT_EQ(status.ok(), true) << status.message(); + + int bitmap_size = RoundUpNumi64(max_slots) * 8; + std::unique_ptr bitmap(new uint8_t[bitmap_size]); + memset(bitmap.get(), 0, bitmap_size); + + arrow::BitUtil::SetBit(bitmap.get(), 0); + arrow::BitUtil::SetBit(bitmap.get(), 5); + arrow::BitUtil::SetBit(bitmap.get(), 121); + arrow::BitUtil::SetBit(bitmap.get(), 220); + + status = selection->PopulateFromBitMap(bitmap.get(), bitmap_size, max_slots - 1); + EXPECT_EQ(status.ok(), true) << status.message(); + + EXPECT_EQ(selection->GetNumSlots(), 3); + EXPECT_EQ(selection->GetIndex(0), 0); + EXPECT_EQ(selection->GetIndex(1), 5); + EXPECT_EQ(selection->GetIndex(2), 121); +} + +TEST_F(TestSelectionVector, TestInt16PopulateFromBitMapNegative) { + int max_slots = 2; + + std::shared_ptr selection; + auto 
status = SelectionVector::MakeInt16(max_slots, pool_, &selection); + EXPECT_EQ(status.ok(), true) << status.message(); + + int bitmap_size = 16; + std::unique_ptr bitmap(new uint8_t[bitmap_size]); + memset(bitmap.get(), 0, bitmap_size); + + arrow::BitUtil::SetBit(bitmap.get(), 0); + arrow::BitUtil::SetBit(bitmap.get(), 1); + arrow::BitUtil::SetBit(bitmap.get(), 2); + + // The bitmap has three set bits, whereas the selection vector has capacity for only 2. + status = selection->PopulateFromBitMap(bitmap.get(), bitmap_size, 2); + EXPECT_EQ(status.IsInvalid(), true); +} + +TEST_F(TestSelectionVector, TestInt32Set) { + int max_slots = 10; + + std::shared_ptr selection; + auto status = SelectionVector::MakeInt32(max_slots, pool_, &selection); + EXPECT_EQ(status.ok(), true) << status.message(); + + selection->SetIndex(0, 100); + EXPECT_EQ(selection->GetIndex(0), 100); + + selection->SetIndex(1, 200); + EXPECT_EQ(selection->GetIndex(1), 200); + + selection->SetIndex(2, 100000); + EXPECT_EQ(selection->GetIndex(2), 100000); + + selection->SetNumSlots(3); + EXPECT_EQ(selection->GetNumSlots(), 3); + + // TopArray() should return an array with 100,200,100000 + auto array_raw = selection->ToArray(); + const auto &array = dynamic_cast(*array_raw); + EXPECT_EQ(array.length(), 3) << array_raw->ToString(); + EXPECT_EQ(array.Value(0), 100) << array_raw->ToString(); + EXPECT_EQ(array.Value(1), 200) << array_raw->ToString(); + EXPECT_EQ(array.Value(2), 100000) << array_raw->ToString(); +} + +TEST_F(TestSelectionVector, TestInt32PopulateFromBitMap) { + int max_slots = 200; + + std::shared_ptr selection; + auto status = SelectionVector::MakeInt32(max_slots, pool_, &selection); + EXPECT_EQ(status.ok(), true) << status.message(); + + int bitmap_size = RoundUpNumi64(max_slots) * 8; + std::unique_ptr bitmap(new uint8_t[bitmap_size]); + memset(bitmap.get(), 0, bitmap_size); + + arrow::BitUtil::SetBit(bitmap.get(), 0); + arrow::BitUtil::SetBit(bitmap.get(), 5); + 
arrow::BitUtil::SetBit(bitmap.get(), 121); + arrow::BitUtil::SetBit(bitmap.get(), 220); + + status = selection->PopulateFromBitMap(bitmap.get(), bitmap_size, max_slots - 1); + EXPECT_EQ(status.ok(), true) << status.message(); + + EXPECT_EQ(selection->GetNumSlots(), 3); + EXPECT_EQ(selection->GetIndex(0), 0); + EXPECT_EQ(selection->GetIndex(1), 5); + EXPECT_EQ(selection->GetIndex(2), 121); +} + +TEST_F(TestSelectionVector, TestInt32MakeNegative) { + int max_slots = 10; + + std::shared_ptr selection; + std::shared_ptr buffer; + auto buffer_len = max_slots * sizeof(int32_t); + + // alloc a buffer that's insufficient. + auto astatus = arrow::AllocateBuffer(pool_, buffer_len - 1, &buffer); + EXPECT_EQ(astatus.ok(), true); + + auto status = SelectionVector::MakeInt32(max_slots, buffer, &selection); + EXPECT_EQ(status.IsInvalid(), true); +} + +} // namespace gandiva diff --git a/cpp/src/gandiva/codegen/status.h b/cpp/src/gandiva/codegen/status.h new file mode 100644 index 00000000000..33c491dd57d --- /dev/null +++ b/cpp/src/gandiva/codegen/status.h @@ -0,0 +1,260 @@ +/* + * Copyright (C) 2017-2018 Dremio Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Adapted from Apache Arrow Status. 
+ */ +#ifndef GANDIVA_STATUS_H +#define GANDIVA_STATUS_H + +#include +#include +#include +#include + +#define GANDIVA_RETURN_NOT_OK(status) \ + do { \ + Status _status = (status); \ + if (!_status.ok()) { \ + std::stringstream ss; \ + ss << __FILE__ << ":" << __LINE__ << " code: " << _status.CodeAsString() << " \n " \ + << _status.message(); \ + return Status(_status.code(), ss.str()); \ + } \ + } while (0) + +#define GANDIVA_RETURN_FAILURE_IF_FALSE(condition, status) \ + do { \ + if (!(condition)) { \ + Status _status = (status); \ + std::stringstream ss; \ + ss << __FILE__ << ":" << __LINE__ << " code: " << _status.CodeAsString() << " \n " \ + << _status.message(); \ + return Status(_status.code(), ss.str()); \ + } \ + } while (0) + +// Check arrow status & convert to gandiva status on error. +#define GANDIVA_RETURN_ARROW_NOT_OK(astatus) \ + do { \ + if (!(astatus).ok()) { \ + return Status(StatusCode::ArrowError, (astatus).message()); \ + } \ + } while (0) + +namespace gandiva { + +enum class StatusCode : char { + OK = 0, + Invalid = 1, + CodeGenError = 2, + ArrowError = 3, + ExpressionValidationError = 4, +}; + +class Status { + public: + // Create a success status. + Status() : state_(NULL) {} + ~Status() { delete state_; } + + Status(StatusCode code, const std::string& msg); + + // Copy the specified status. + Status(const Status& s); + Status& operator=(const Status& s); + + // Move the specified status. + Status(Status&& s); + Status& operator=(Status&& s); + + // AND the statuses. + Status operator&(const Status& s) const; + Status operator&(Status&& s) const; + Status& operator&=(const Status& s); + Status& operator&=(Status&& s); + + // Return a success status. + static Status OK() { return Status(); } + + // Return error status of an appropriate type. 
+ static Status CodeGenError(const std::string& msg) { + return Status(StatusCode::CodeGenError, msg); + } + + static Status Invalid(const std::string& msg) { + return Status(StatusCode::Invalid, msg); + } + + static Status ArrowError(const std::string& msg) { + return Status(StatusCode::ArrowError, msg); + } + + static Status ExpressionValidationError(const std::string& msg) { + return Status(StatusCode::ExpressionValidationError, msg); + } + + // Returns true if the status indicates success. + bool ok() const { return (state_ == NULL); } + + bool IsCodeGenError() const { return code() == StatusCode::CodeGenError; } + + bool IsInvalid() const { return code() == StatusCode::Invalid; } + + bool IsArrowError() const { return code() == StatusCode::ArrowError; } + + bool IsExpressionValidationError() const { + return code() == StatusCode::ExpressionValidationError; + } + + // Return a string representation of this status suitable for printing. + // Returns the string "OK" for success. + std::string ToString() const; + + // Return a string representation of the status code, without the message + // text or posix code information. + std::string CodeAsString() const; + + StatusCode code() const { return ok() ? StatusCode::OK : state_->code; } + + std::string message() const { return ok() ? "" : state_->msg; } + + private: + struct State { + StatusCode code; + std::string msg; + }; + // OK status has a `NULL` state_. Otherwise, `state_` points to + // a `State` structure containing the error code and message(s) + State* state_; + + void CopyFrom(const Status& s); + void MoveFrom(Status& s); +}; + +static inline std::ostream& operator<<(std::ostream& os, const Status& x) { + os << x.ToString(); + return os; +} + +inline Status::Status(const Status& s) + : state_((s.state_ == NULL) ? 
NULL : new State(*s.state_)) {} + +inline Status& Status::operator=(const Status& s) { + // The following condition catches both aliasing (when this == &s), + // and the common case where both s and *this are ok. + if (state_ != s.state_) { + CopyFrom(s); + } + return *this; +} + +inline Status::Status(Status&& s) : state_(s.state_) { s.state_ = NULL; } + +inline Status& Status::operator=(Status&& s) { + MoveFrom(s); + return *this; +} + +inline Status Status::operator&(const Status& s) const { + if (ok()) { + return s; + } else { + return *this; + } +} + +inline Status Status::operator&(Status&& s) const { + if (ok()) { + return std::move(s); + } else { + return *this; + } +} + +inline Status& Status::operator&=(const Status& s) { + if (ok() && !s.ok()) { + CopyFrom(s); + } + return *this; +} + +inline Status& Status::operator&=(Status&& s) { + if (ok() && !s.ok()) { + MoveFrom(s); + } + return *this; +} + +inline Status::Status(StatusCode code, const std::string& msg) { + assert(code != StatusCode::OK); + state_ = new State; + state_->code = code; + state_->msg = msg; +} + +inline void Status::CopyFrom(const Status& s) { + delete state_; + if (s.state_ == nullptr) { + state_ = nullptr; + } else { + state_ = new State(*s.state_); + } +} + +inline std::string Status::CodeAsString() const { + if (state_ == nullptr) { + return "OK"; + } + + const char* type; + switch (code()) { + case StatusCode::OK: + type = "OK"; + break; + case StatusCode::CodeGenError: + type = "CodeGenError"; + break; + case StatusCode::Invalid: + type = "Invalid"; + break; + case StatusCode::ExpressionValidationError: + type = "ExpressionValidationError"; + break; + default: + type = "Unknown"; + break; + } + return std::string(type); +} + +inline void Status::MoveFrom(Status& s) { + delete state_; + state_ = s.state_; + s.state_ = NULL; +} + +inline std::string Status::ToString() const { + std::string result(CodeAsString()); + if (state_ == NULL) { + return result; + } + result += ": "; + 
result += state_->msg; + return result; +} + +} // namespace gandiva + +#endif // GANDIVA_STATUS_H diff --git a/cpp/src/gandiva/codegen/status_test.cc b/cpp/src/gandiva/codegen/status_test.cc new file mode 100644 index 00000000000..7f3ac79cf5f --- /dev/null +++ b/cpp/src/gandiva/codegen/status_test.cc @@ -0,0 +1,70 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Adapted from Apache Arrow Status. + +#include "gandiva/status.h" + +#include + +#include + +namespace gandiva { + +TEST(StatusTest, TestCodeAndMessage) { + Status ok = Status::OK(); + ASSERT_EQ(StatusCode::OK, ok.code()); + Status code_gen_error = Status::CodeGenError("input invalid."); + ASSERT_EQ(StatusCode::CodeGenError, code_gen_error.code()); + ASSERT_EQ("input invalid.", code_gen_error.message()); +} + +TEST(StatusTest, TestToString) { + Status code_gen_error = Status::CodeGenError("input invalid."); + ASSERT_EQ("CodeGenError: input invalid.", code_gen_error.ToString()); + + std::stringstream ss; + ss << code_gen_error; + ASSERT_EQ(code_gen_error.ToString(), ss.str()); +} + +TEST(StatusTest, AndStatus) { + Status a = Status::OK(); + Status b = Status::OK(); + Status c = Status::CodeGenError("invalid value"); + + Status res; + res = a & b; + ASSERT_TRUE(res.ok()); + res = a & c; + ASSERT_TRUE(res.IsCodeGenError()); + + res = Status::OK(); + res &= c; + ASSERT_TRUE(res.IsCodeGenError()); + + // With rvalues + res = Status::OK() & 
Status::CodeGenError("foo"); + ASSERT_TRUE(res.IsCodeGenError()); + res = Status::CodeGenError("foo") & Status::OK(); + ASSERT_TRUE(res.IsCodeGenError()); + + res = Status::OK(); + res &= Status::OK(); + ASSERT_TRUE(res.ok()); + res &= Status::CodeGenError("foo"); + ASSERT_TRUE(res.IsCodeGenError()); +} + +} // namespace gandiva diff --git a/cpp/src/gandiva/codegen/tree_expr_builder.cc b/cpp/src/gandiva/codegen/tree_expr_builder.cc new file mode 100644 index 00000000000..07ad839a62d --- /dev/null +++ b/cpp/src/gandiva/codegen/tree_expr_builder.cc @@ -0,0 +1,174 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "gandiva/tree_expr_builder.h" + +#include + +#include "codegen/node.h" + +namespace gandiva { + +#define MAKE_LITERAL(atype, ctype) \ + NodePtr TreeExprBuilder::MakeLiteral(ctype value) { \ + return std::make_shared(atype, LiteralHolder(value), false); \ + } + +MAKE_LITERAL(arrow::boolean(), bool) +MAKE_LITERAL(arrow::int8(), int8_t) +MAKE_LITERAL(arrow::int16(), int16_t) +MAKE_LITERAL(arrow::int32(), int32_t) +MAKE_LITERAL(arrow::int64(), int64_t) +MAKE_LITERAL(arrow::uint8(), uint8_t) +MAKE_LITERAL(arrow::uint16(), uint16_t) +MAKE_LITERAL(arrow::uint32(), uint32_t) +MAKE_LITERAL(arrow::uint64(), uint64_t) +MAKE_LITERAL(arrow::float32(), float) +MAKE_LITERAL(arrow::float64(), double) + +NodePtr TreeExprBuilder::MakeStringLiteral(const std::string &value) { + return std::make_shared(arrow::utf8(), LiteralHolder(value), false); +} + +NodePtr TreeExprBuilder::MakeBinaryLiteral(const std::string &value) { + return std::make_shared(arrow::binary(), LiteralHolder(value), false); +} + +NodePtr TreeExprBuilder::MakeNull(DataTypePtr data_type) { + static const std::string empty; + + if (data_type == nullptr) { + return nullptr; + } + + switch (data_type->id()) { + case arrow::Type::BOOL: + return std::make_shared(data_type, LiteralHolder(false), true); + case arrow::Type::INT8: + return std::make_shared(data_type, LiteralHolder((int8_t)0), true); + case arrow::Type::INT16: + return std::make_shared(data_type, LiteralHolder((int16_t)0), true); + case arrow::Type::INT32: + return std::make_shared(data_type, LiteralHolder((int32_t)0), true); + case arrow::Type::INT64: + return std::make_shared(data_type, LiteralHolder((int64_t)0), true); + case arrow::Type::UINT8: + return std::make_shared(data_type, LiteralHolder((uint8_t)0), true); + case arrow::Type::UINT16: + return std::make_shared(data_type, LiteralHolder((uint16_t)0), true); + case arrow::Type::UINT32: + return std::make_shared(data_type, LiteralHolder((uint32_t)0), true); + case arrow::Type::UINT64: + 
return std::make_shared(data_type, LiteralHolder((uint64_t)0), true); + case arrow::Type::FLOAT: + return std::make_shared(data_type, LiteralHolder((float_t)0), true); + case arrow::Type::DOUBLE: + return std::make_shared(data_type, LiteralHolder((double_t)0), true); + case arrow::Type::STRING: + case arrow::Type::BINARY: + return std::make_shared(data_type, LiteralHolder(empty), true); + case arrow::Type::DATE64: + return std::make_shared(data_type, LiteralHolder((int64_t)0), true); + case arrow::Type::TIME32: + return std::make_shared(data_type, LiteralHolder((int32_t)0), true); + case arrow::Type::TIME64: + return std::make_shared(data_type, LiteralHolder((int64_t)0), true); + case arrow::Type::TIMESTAMP: + return std::make_shared(data_type, LiteralHolder((int64_t)0), true); + default: + return nullptr; + } +} + +NodePtr TreeExprBuilder::MakeField(FieldPtr field) { + return NodePtr(new FieldNode(field)); +} + +NodePtr TreeExprBuilder::MakeFunction(const std::string &name, const NodeVector ¶ms, + DataTypePtr result_type) { + if (result_type == nullptr) { + return nullptr; + } + return FunctionNode::MakeFunction(name, params, result_type); +} + +NodePtr TreeExprBuilder::MakeIf(NodePtr condition, NodePtr then_node, NodePtr else_node, + DataTypePtr result_type) { + if (condition == nullptr || then_node == nullptr || else_node == nullptr || + result_type == nullptr) { + return nullptr; + } + return std::make_shared(condition, then_node, else_node, result_type); +} + +NodePtr TreeExprBuilder::MakeAnd(const NodeVector &children) { + return std::make_shared(BooleanNode::AND, children); +} + +NodePtr TreeExprBuilder::MakeOr(const NodeVector &children) { + return std::make_shared(BooleanNode::OR, children); +} + +// set this to true to print expressions for debugging purposes +static bool print_expr = false; + +ExpressionPtr TreeExprBuilder::MakeExpression(NodePtr root_node, FieldPtr result_field) { + if (result_field == nullptr) { + return nullptr; + } + if (print_expr) 
{ + std::cout << "Expression: " << root_node->ToString() << "\n"; + } + return ExpressionPtr(new Expression(root_node, result_field)); +} + +ExpressionPtr TreeExprBuilder::MakeExpression(const std::string &function, + const FieldVector &in_fields, + FieldPtr out_field) { + if (out_field == nullptr) { + return nullptr; + } + std::vector field_nodes; + for (auto &field : in_fields) { + auto node = MakeField(field); + field_nodes.push_back(node); + } + auto func_node = FunctionNode::MakeFunction(function, field_nodes, out_field->type()); + return MakeExpression(func_node, out_field); +} + +ConditionPtr TreeExprBuilder::MakeCondition(NodePtr root_node) { + if (root_node == nullptr) { + return nullptr; + } + if (print_expr) { + std::cout << "Condition: " << root_node->ToString() << "\n"; + } + + return ConditionPtr(new Condition(root_node)); +} + +ConditionPtr TreeExprBuilder::MakeCondition(const std::string &function, + const FieldVector &in_fields) { + std::vector field_nodes; + for (auto &field : in_fields) { + auto node = MakeField(field); + field_nodes.push_back(node); + } + + auto func_node = FunctionNode::MakeFunction(function, field_nodes, arrow::boolean()); + return ConditionPtr(new Condition(func_node)); +} + +} // namespace gandiva diff --git a/cpp/src/gandiva/codegen/tree_expr_builder.h b/cpp/src/gandiva/codegen/tree_expr_builder.h new file mode 100644 index 00000000000..87d6608dc1a --- /dev/null +++ b/cpp/src/gandiva/codegen/tree_expr_builder.h @@ -0,0 +1,89 @@ +/* + * Copyright (C) 2017-2018 Dremio Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef GANDIVA_EXPR_TREE_BUILDER_H +#define GANDIVA_EXPR_TREE_BUILDER_H + +#include +#include + +#include "gandiva/condition.h" +#include "gandiva/expression.h" + +namespace gandiva { + +/// \brief Tree Builder for a nested expression. +class TreeExprBuilder { + public: + /// \brief create a node on a literal. + static NodePtr MakeLiteral(bool value); + static NodePtr MakeLiteral(uint8_t value); + static NodePtr MakeLiteral(uint16_t value); + static NodePtr MakeLiteral(uint32_t value); + static NodePtr MakeLiteral(uint64_t value); + static NodePtr MakeLiteral(int8_t value); + static NodePtr MakeLiteral(int16_t value); + static NodePtr MakeLiteral(int32_t value); + static NodePtr MakeLiteral(int64_t value); + static NodePtr MakeLiteral(float value); + static NodePtr MakeLiteral(double value); + static NodePtr MakeStringLiteral(const std::string &value); + static NodePtr MakeBinaryLiteral(const std::string &value); + + /// \brief create a node on a null literal. + /// returns null if data_type is null or if it's not a supported datatype. + static NodePtr MakeNull(DataTypePtr data_type); + + /// \brief create a node on arrow field. + /// returns null if input is null. + static NodePtr MakeField(FieldPtr field); + + /// \brief create a node with a function. + /// returns null if return_type is null + static NodePtr MakeFunction(const std::string &name, const NodeVector ¶ms, + DataTypePtr return_type); + + /// \brief create a node with an if-else expression. + /// returns null if any of the inputs is null. 
+ static NodePtr MakeIf(NodePtr condition, NodePtr then_node, NodePtr else_node, + DataTypePtr result_type); + + /// \brief create a node with a boolean AND expression. + static NodePtr MakeAnd(const NodeVector &children); + + /// \brief create a node with a boolean OR expression. + static NodePtr MakeOr(const NodeVector &children); + + /// \brief create an expression with the specified root_node, and the + /// result written to result_field. + /// returns null if the result_field is null. + static ExpressionPtr MakeExpression(NodePtr root_node, FieldPtr result_field); + + /// \brief convenience function for simple function expressions. + /// returns null if the out_field is null. + static ExpressionPtr MakeExpression(const std::string &function, + const FieldVector &in_fields, FieldPtr out_field); + + /// \brief create a condition with the specified root_node + static ConditionPtr MakeCondition(NodePtr root_node); + + /// \brief convenience function for simple function conditions. + static ConditionPtr MakeCondition(const std::string &function, + const FieldVector &in_fields); +}; + +} // namespace gandiva + +#endif // GANDIVA_EXPR_TREE_BUILDER_H diff --git a/cpp/src/gandiva/codegen/tree_expr_test.cc b/cpp/src/gandiva/codegen/tree_expr_test.cc new file mode 100644 index 00000000000..295ac0186c2 --- /dev/null +++ b/cpp/src/gandiva/codegen/tree_expr_test.cc @@ -0,0 +1,161 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gandiva/tree_expr_builder.h" + +#include +#include "codegen/annotator.h" +#include "codegen/dex.h" +#include "codegen/expr_decomposer.h" +#include "codegen/function_registry.h" +#include "codegen/node.h" +#include "gandiva/function_signature.h" +#include "gandiva/gandiva_aliases.h" + +namespace gandiva { + +using arrow::boolean; +using arrow::int32; + +class TestExprTree : public ::testing::Test { + public: + void SetUp() { + i0_ = field("i0", int32()); + i1_ = field("i1", int32()); + + b0_ = field("b0", boolean()); + } + + protected: + FieldPtr i0_; // int32 + FieldPtr i1_; // int32 + + FieldPtr b0_; // bool + FunctionRegistry registry_; +}; + +TEST_F(TestExprTree, TestField) { + Annotator annotator; + + auto n0 = TreeExprBuilder::MakeField(i0_); + EXPECT_EQ(n0->return_type(), int32()); + + auto n1 = TreeExprBuilder::MakeField(b0_); + EXPECT_EQ(n1->return_type(), boolean()); + + ExprDecomposer decomposer(registry_, annotator); + ValueValidityPairPtr pair; + auto status = decomposer.Decompose(*n1, &pair); + DCHECK_EQ(status.ok(), true) << status.message(); + + auto value = pair->value_expr(); + auto value_dex = std::dynamic_pointer_cast(value); + EXPECT_EQ(value_dex->FieldType(), boolean()); + + EXPECT_EQ(pair->validity_exprs().size(), 1); + auto validity = pair->validity_exprs().at(0); + auto validity_dex = std::dynamic_pointer_cast(validity); + EXPECT_NE(validity_dex->ValidityIdx(), value_dex->DataIdx()); +} + +TEST_F(TestExprTree, TestBinary) { + Annotator annotator; + + auto left = TreeExprBuilder::MakeField(i0_); + auto right = TreeExprBuilder::MakeField(i1_); + + auto n = TreeExprBuilder::MakeFunction("add", {left, right}, int32()); + auto add = std::dynamic_pointer_cast(n); + + auto func_desc = add->descriptor(); + FunctionSignature sign(func_desc->name(), func_desc->params(), + func_desc->return_type()); + + EXPECT_EQ(add->return_type(), 
int32()); + EXPECT_TRUE(sign == FunctionSignature("add", {int32(), int32()}, int32())); + + ExprDecomposer decomposer(registry_, annotator); + ValueValidityPairPtr pair; + auto status = decomposer.Decompose(*n, &pair); + DCHECK_EQ(status.ok(), true) << status.message(); + + auto value = pair->value_expr(); + auto null_if_null = std::dynamic_pointer_cast(value); + + FunctionSignature signature("add", {int32(), int32()}, int32()); + const NativeFunction *fn = registry_.LookupSignature(signature); + EXPECT_EQ(null_if_null->native_function(), fn); +} + +TEST_F(TestExprTree, TestUnary) { + Annotator annotator; + + auto arg = TreeExprBuilder::MakeField(i0_); + auto n = TreeExprBuilder::MakeFunction("isnumeric", {arg}, boolean()); + + auto unaryFn = std::dynamic_pointer_cast(n); + auto func_desc = unaryFn->descriptor(); + FunctionSignature sign(func_desc->name(), func_desc->params(), + func_desc->return_type()); + EXPECT_EQ(unaryFn->return_type(), boolean()); + EXPECT_TRUE(sign == FunctionSignature("isnumeric", {int32()}, boolean())); + + ExprDecomposer decomposer(registry_, annotator); + ValueValidityPairPtr pair; + auto status = decomposer.Decompose(*n, &pair); + DCHECK_EQ(status.ok(), true) << status.message(); + + auto value = pair->value_expr(); + auto never_null = std::dynamic_pointer_cast(value); + + FunctionSignature signature("isnumeric", {int32()}, boolean()); + const NativeFunction *fn = registry_.LookupSignature(signature); + EXPECT_EQ(never_null->native_function(), fn); +} + +TEST_F(TestExprTree, TestExpression) { + Annotator annotator; + auto left = TreeExprBuilder::MakeField(i0_); + auto right = TreeExprBuilder::MakeField(i1_); + + auto n = TreeExprBuilder::MakeFunction("add", {left, right}, int32()); + auto e = TreeExprBuilder::MakeExpression(n, field("r", int32())); + auto root_node = e->root(); + EXPECT_EQ(root_node->return_type(), int32()); + + auto add_node = std::dynamic_pointer_cast(root_node); + auto func_desc = add_node->descriptor(); + 
FunctionSignature sign(func_desc->name(), func_desc->params(), + func_desc->return_type()); + EXPECT_TRUE(sign == FunctionSignature("add", {int32(), int32()}, int32())); + + ExprDecomposer decomposer(registry_, annotator); + ValueValidityPairPtr pair; + auto status = decomposer.Decompose(*root_node, &pair); + DCHECK_EQ(status.ok(), true) << status.message(); + + auto value = pair->value_expr(); + auto null_if_null = std::dynamic_pointer_cast(value); + + FunctionSignature signature("add", {int32(), int32()}, int32()); + const NativeFunction *fn = registry_.LookupSignature(signature); + EXPECT_EQ(null_if_null->native_function(), fn); +} + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} + +} // namespace gandiva diff --git a/cpp/src/gandiva/codegen/value_validity_pair.h b/cpp/src/gandiva/codegen/value_validity_pair.h new file mode 100644 index 00000000000..8f187db7ee3 --- /dev/null +++ b/cpp/src/gandiva/codegen/value_validity_pair.h @@ -0,0 +1,47 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GANDIVA_VALUEVALIDITYPAIR_H +#define GANDIVA_VALUEVALIDITYPAIR_H + +#include + +#include "gandiva/gandiva_aliases.h" + +namespace gandiva { + +/// Pair of vector/validities generated after decomposing an expression tree/subtree. 
+class ValueValidityPair { + public: + ValueValidityPair(const DexVector &validity_exprs, DexPtr value_expr) + : validity_exprs_(validity_exprs), value_expr_(value_expr) {} + + ValueValidityPair(DexPtr validity_expr, DexPtr value_expr) : value_expr_(value_expr) { + validity_exprs_.push_back(validity_expr); + } + + explicit ValueValidityPair(DexPtr value_expr) : value_expr_(value_expr) {} + + const DexVector &validity_exprs() const { return validity_exprs_; } + + const DexPtr &value_expr() const { return value_expr_; } + + private: + DexVector validity_exprs_; + DexPtr value_expr_; +}; + +} // namespace gandiva + +#endif // GANDIVA_VALUEVALIDITYPAIR_H diff --git a/cpp/src/gandiva/integ/CMakeLists.txt b/cpp/src/gandiva/integ/CMakeLists.txt new file mode 100644 index 00000000000..45f592fb06b --- /dev/null +++ b/cpp/src/gandiva/integ/CMakeLists.txt @@ -0,0 +1,30 @@ +# Copyright (C) 2017-2018 Dremio Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+project(gandiva)
+
+# Register each integration test source once per library variant.
+# add_gandiva_integ_test() is defined outside this file; it is called with the
+# test source and the library target name (gandiva_shared / gandiva_static).
+# Registration order matches the original flat list: all sources for "shared",
+# then all sources for "static".
+foreach(lib_type "shared" "static")
+  foreach(test_src
+      filter_test.cc
+      projector_test.cc
+      if_expr_test.cc
+      literal_test.cc
+      projector_build_validation_test.cc
+      boolean_expr_test.cc
+      utf8_test.cc
+      binary_test.cc
+      date_time_test.cc
+      micro_benchmarks.cc
+      to_string_test.cc
+      hash_test.cc)
+    add_gandiva_integ_test(${test_src} gandiva_${lib_type})
+  endforeach()
+endforeach()
diff --git a/cpp/src/gandiva/integ/binary_test.cc b/cpp/src/gandiva/integ/binary_test.cc
new file mode 100644
index 00000000000..3636a96dd57
--- /dev/null
+++ b/cpp/src/gandiva/integ/binary_test.cc
@@ -0,0 +1,85 @@
+// Copyright (C) 2017-2018 Dremio Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include +#include "arrow/memory_pool.h" +#include "gandiva/projector.h" +#include "gandiva/status.h" +#include "gandiva/tree_expr_builder.h" +#include "integ/test_util.h" + +namespace gandiva { + +using arrow::binary; +using arrow::boolean; +using arrow::int32; + +class TestBinary : public ::testing::Test { + public: + void SetUp() { pool_ = arrow::default_memory_pool(); } + + protected: + arrow::MemoryPool* pool_; +}; + +TEST_F(TestBinary, TestSimple) { + // schema for input fields + auto field_a = field("a", binary()); + auto field_b = field("b", binary()); + auto schema = arrow::schema({field_a, field_b}); + + // output fields + auto res = field("res", int32()); + + // build expressions. + // a > b ? octet_length(a) : octet_length(b) + auto node_a = TreeExprBuilder::MakeField(field_a); + auto node_b = TreeExprBuilder::MakeField(field_b); + auto octet_len_a = TreeExprBuilder::MakeFunction("octet_length", {node_a}, int32()); + auto octet_len_b = TreeExprBuilder::MakeFunction("octet_length", {node_b}, int32()); + + auto is_greater = + TreeExprBuilder::MakeFunction("greater_than", {node_a, node_b}, boolean()); + auto if_greater = + TreeExprBuilder::MakeIf(is_greater, octet_len_a, octet_len_b, int32()); + auto expr = TreeExprBuilder::MakeExpression(if_greater, res); + + // Build a projector for the expressions. 
+ std::shared_ptr projector; + Status status = Projector::Make(schema, {expr}, &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Create a row-batch with some sample data + int num_records = 4; + auto array_a = + MakeArrowArrayBinary({"foo", "hello", "hi", "bye"}, {true, true, true, false}); + auto array_b = + MakeArrowArrayBinary({"fo", "hellos", "hi", "bye"}, {true, true, true, true}); + + // expected output + auto exp = MakeArrowArrayInt32({3, 6, 2, 3}, {true, true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a, array_b}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +} // namespace gandiva diff --git a/cpp/src/gandiva/integ/boolean_expr_test.cc b/cpp/src/gandiva/integ/boolean_expr_test.cc new file mode 100644 index 00000000000..66cfb110e6a --- /dev/null +++ b/cpp/src/gandiva/integ/boolean_expr_test.cc @@ -0,0 +1,384 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include "arrow/memory_pool.h" +#include "gandiva/projector.h" +#include "gandiva/status.h" +#include "gandiva/tree_expr_builder.h" +#include "integ/test_util.h" + +namespace gandiva { + +using arrow::boolean; +using arrow::int32; + +class TestBooleanExpr : public ::testing::Test { + public: + void SetUp() { pool_ = arrow::default_memory_pool(); } + + protected: + arrow::MemoryPool* pool_; +}; + +TEST_F(TestBooleanExpr, SimpleAnd) { + // schema for input fields + auto fielda = field("a", int32()); + auto fieldb = field("b", int32()); + auto schema = arrow::schema({fielda, fieldb}); + + // output fields + auto field_result = field("res", boolean()); + + // build expression. + // (a > 0) && (b > 0) + auto node_a = TreeExprBuilder::MakeField(fielda); + auto node_b = TreeExprBuilder::MakeField(fieldb); + auto literal_0 = TreeExprBuilder::MakeLiteral((int32_t)0); + auto a_gt_0 = + TreeExprBuilder::MakeFunction("greater_than", {node_a, literal_0}, boolean()); + auto b_gt_0 = + TreeExprBuilder::MakeFunction("greater_than", {node_b, literal_0}, boolean()); + + auto node_and = TreeExprBuilder::MakeAnd({a_gt_0, b_gt_0}); + auto expr = TreeExprBuilder::MakeExpression(node_and, field_result); + + // Build a projector for the expressions. + std::shared_ptr projector; + Status status = Projector::Make(schema, {expr}, &projector); + EXPECT_TRUE(status.ok()); + + // FALSE_VALID && ? => FALSE_VALID + int num_records = 4; + auto arraya = MakeArrowArrayInt32({-2, -2, -2, -2}, {true, true, true, true}); + auto arrayb = MakeArrowArrayInt32({-2, -2, 2, 2}, {true, false, true, false}); + auto exp = MakeArrowArrayBool({false, false, false, false}, {true, true, true, true}); + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {arraya, arrayb}); + + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); + + // FALSE_INVALID && ? 
+ num_records = 4; + arraya = MakeArrowArrayInt32({-2, -2, -2, -2}, {false, false, false, false}); + arrayb = MakeArrowArrayInt32({-2, -2, 2, 2}, {true, false, true, false}); + exp = MakeArrowArrayBool({false, false, false, false}, {true, false, false, false}); + in_batch = arrow::RecordBatch::Make(schema, num_records, {arraya, arrayb}); + outputs.clear(); + projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); + + // TRUE_VALID && ? + num_records = 4; + arraya = MakeArrowArrayInt32({2, 2, 2, 2}, {true, true, true, true}); + arrayb = MakeArrowArrayInt32({-2, -2, 2, 2}, {true, false, true, false}); + exp = MakeArrowArrayBool({false, false, true, false}, {true, false, true, false}); + in_batch = arrow::RecordBatch::Make(schema, num_records, {arraya, arrayb}); + outputs.clear(); + projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); + + // TRUE_INVALID && ? + num_records = 4; + arraya = MakeArrowArrayInt32({2, 2, 2, 2}, {false, false, false, false}); + arrayb = MakeArrowArrayInt32({-2, -2, 2, 2}, {true, false, true, false}); + exp = MakeArrowArrayBool({false, false, false, false}, {true, false, false, false}); + in_batch = arrow::RecordBatch::Make(schema, num_records, {arraya, arrayb}); + outputs.clear(); + projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestBooleanExpr, SimpleOr) { + // schema for input fields + auto fielda = field("a", int32()); + auto fieldb = field("b", int32()); + auto schema = arrow::schema({fielda, fieldb}); + + // output fields + auto field_result = field("res", boolean()); + + // build expression. 
+ // (a > 0) || (b > 0) + auto node_a = TreeExprBuilder::MakeField(fielda); + auto node_b = TreeExprBuilder::MakeField(fieldb); + auto literal_0 = TreeExprBuilder::MakeLiteral((int32_t)0); + auto a_gt_0 = + TreeExprBuilder::MakeFunction("greater_than", {node_a, literal_0}, boolean()); + auto b_gt_0 = + TreeExprBuilder::MakeFunction("greater_than", {node_b, literal_0}, boolean()); + + auto node_or = TreeExprBuilder::MakeOr({a_gt_0, b_gt_0}); + auto expr = TreeExprBuilder::MakeExpression(node_or, field_result); + + // Build a projector for the expressions. + std::shared_ptr projector; + Status status = Projector::Make(schema, {expr}, &projector); + EXPECT_TRUE(status.ok()); + + // TRUE_VALID && ? => TRUE_VALID + int num_records = 4; + auto arraya = MakeArrowArrayInt32({2, 2, 2, 2}, {true, true, true, true}); + auto arrayb = MakeArrowArrayInt32({-2, -2, 2, 2}, {true, false, true, false}); + auto exp = MakeArrowArrayBool({true, true, true, true}, {true, true, true, true}); + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {arraya, arrayb}); + + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); + + // TRUE_INVALID && ? + num_records = 4; + arraya = MakeArrowArrayInt32({2, 2, 2, 2}, {false, false, false, false}); + arrayb = MakeArrowArrayInt32({-2, -2, 2, 2}, {true, false, true, false}); + exp = MakeArrowArrayBool({false, false, true, false}, {false, false, true, false}); + in_batch = arrow::RecordBatch::Make(schema, num_records, {arraya, arrayb}); + outputs.clear(); + projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); + + // FALSE_VALID && ? 
+ num_records = 4; + arraya = MakeArrowArrayInt32({-2, -2, -2, -2}, {true, true, true, true}); + arrayb = MakeArrowArrayInt32({-2, -2, 2, 2}, {true, false, true, false}); + exp = MakeArrowArrayBool({false, false, true, false}, {true, false, true, false}); + in_batch = arrow::RecordBatch::Make(schema, num_records, {arraya, arrayb}); + outputs.clear(); + projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); + + // FALSE_INVALID && ? + num_records = 4; + arraya = MakeArrowArrayInt32({-2, -2, -2, -2}, {false, false, false, false}); + arrayb = MakeArrowArrayInt32({-2, -2, 2, 2}, {true, false, true, false}); + exp = MakeArrowArrayBool({false, false, true, false}, {false, false, true, false}); + in_batch = arrow::RecordBatch::Make(schema, num_records, {arraya, arrayb}); + outputs.clear(); + projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestBooleanExpr, AndThree) { + // schema for input fields + auto fielda = field("a", int32()); + auto fieldb = field("b", int32()); + auto fieldc = field("c", int32()); + auto schema = arrow::schema({fielda, fieldb, fieldc}); + + // output fields + auto field_result = field("res", boolean()); + + // build expression. 
+ // (a > 0) && (b > 0) && (c > 0) + auto node_a = TreeExprBuilder::MakeField(fielda); + auto node_b = TreeExprBuilder::MakeField(fieldb); + auto node_c = TreeExprBuilder::MakeField(fieldc); + auto literal_0 = TreeExprBuilder::MakeLiteral((int32_t)0); + auto a_gt_0 = + TreeExprBuilder::MakeFunction("greater_than", {node_a, literal_0}, boolean()); + auto b_gt_0 = + TreeExprBuilder::MakeFunction("greater_than", {node_b, literal_0}, boolean()); + auto c_gt_0 = + TreeExprBuilder::MakeFunction("greater_than", {node_c, literal_0}, boolean()); + + auto node_and = TreeExprBuilder::MakeAnd({a_gt_0, b_gt_0, c_gt_0}); + auto expr = TreeExprBuilder::MakeExpression(node_and, field_result); + + // Build a projector for the expressions. + std::shared_ptr projector; + Status status = Projector::Make(schema, {expr}, &projector); + EXPECT_TRUE(status.ok()); + + int num_records = 8; + std::vector validity({true, true, true, true, true, true, true, true}); + auto arraya = MakeArrowArrayInt32({2, 2, 2, 0, 2, 0, 0, 0}, validity); + auto arrayb = MakeArrowArrayInt32({2, 2, 0, 2, 0, 2, 0, 0}, validity); + auto arrayc = MakeArrowArrayInt32({2, 0, 2, 2, 0, 0, 2, 0}, validity); + auto exp = MakeArrowArrayBool({true, false, false, false, false, false, false, false}, + validity); + + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {arraya, arrayb, arrayc}); + + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestBooleanExpr, OrThree) { + // schema for input fields + auto fielda = field("a", int32()); + auto fieldb = field("b", int32()); + auto fieldc = field("c", int32()); + auto schema = arrow::schema({fielda, fieldb, fieldc}); + + // output fields + auto field_result = field("res", boolean()); + + // build expression. 
+ // (a > 0) || (b > 0) || (c > 0) + auto node_a = TreeExprBuilder::MakeField(fielda); + auto node_b = TreeExprBuilder::MakeField(fieldb); + auto node_c = TreeExprBuilder::MakeField(fieldc); + auto literal_0 = TreeExprBuilder::MakeLiteral((int32_t)0); + auto a_gt_0 = + TreeExprBuilder::MakeFunction("greater_than", {node_a, literal_0}, boolean()); + auto b_gt_0 = + TreeExprBuilder::MakeFunction("greater_than", {node_b, literal_0}, boolean()); + auto c_gt_0 = + TreeExprBuilder::MakeFunction("greater_than", {node_c, literal_0}, boolean()); + + auto node_or = TreeExprBuilder::MakeOr({a_gt_0, b_gt_0, c_gt_0}); + auto expr = TreeExprBuilder::MakeExpression(node_or, field_result); + + // Build a projector for the expressions. + std::shared_ptr projector; + Status status = Projector::Make(schema, {expr}, &projector); + EXPECT_TRUE(status.ok()); + + int num_records = 8; + std::vector validity({true, true, true, true, true, true, true, true}); + auto arraya = MakeArrowArrayInt32({2, 2, 2, 0, 2, 0, 0, 0}, validity); + auto arrayb = MakeArrowArrayInt32({2, 2, 0, 2, 0, 2, 0, 0}, validity); + auto arrayc = MakeArrowArrayInt32({2, 0, 2, 2, 0, 0, 2, 0}, validity); + auto exp = + MakeArrowArrayBool({true, true, true, true, true, true, true, false}, validity); + + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {arraya, arrayb, arrayc}); + + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestBooleanExpr, BooleanAndInsideIf) { + // schema for input fields + auto fielda = field("a", int32()); + auto fieldb = field("b", int32()); + auto schema = arrow::schema({fielda, fieldb}); + + // output fields + auto field_result = field("res", boolean()); + + // build expression. 
+ // if (a > 2 && b > 2) + // a > 3 && b > 3 + // else + // a > 1 && b > 1 + auto node_a = TreeExprBuilder::MakeField(fielda); + auto node_b = TreeExprBuilder::MakeField(fieldb); + auto literal_1 = TreeExprBuilder::MakeLiteral((int32_t)1); + auto literal_2 = TreeExprBuilder::MakeLiteral((int32_t)2); + auto literal_3 = TreeExprBuilder::MakeLiteral((int32_t)3); + auto a_gt_1 = + TreeExprBuilder::MakeFunction("greater_than", {node_a, literal_1}, boolean()); + auto a_gt_2 = + TreeExprBuilder::MakeFunction("greater_than", {node_a, literal_2}, boolean()); + auto a_gt_3 = + TreeExprBuilder::MakeFunction("greater_than", {node_a, literal_3}, boolean()); + auto b_gt_1 = + TreeExprBuilder::MakeFunction("greater_than", {node_b, literal_1}, boolean()); + auto b_gt_2 = + TreeExprBuilder::MakeFunction("greater_than", {node_b, literal_2}, boolean()); + auto b_gt_3 = + TreeExprBuilder::MakeFunction("greater_than", {node_b, literal_3}, boolean()); + + auto and_1 = TreeExprBuilder::MakeAnd({a_gt_1, b_gt_1}); + auto and_2 = TreeExprBuilder::MakeAnd({a_gt_2, b_gt_2}); + auto and_3 = TreeExprBuilder::MakeAnd({a_gt_3, b_gt_3}); + + auto node_if = TreeExprBuilder::MakeIf(and_2, and_3, and_1, arrow::boolean()); + auto expr = TreeExprBuilder::MakeExpression(node_if, field_result); + + // Build a projector for the expressions. 
+ std::shared_ptr projector; + Status status = Projector::Make(schema, {expr}, &projector); + EXPECT_TRUE(status.ok()); + + int num_records = 4; + std::vector validity({true, true, true, true}); + auto arraya = MakeArrowArrayInt32({4, 4, 2, 1}, validity); + auto arrayb = MakeArrowArrayInt32({5, 3, 3, 1}, validity); + auto exp = MakeArrowArrayBool({true, false, true, false}, validity); + + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {arraya, arrayb}); + + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestBooleanExpr, IfInsideBooleanAnd) { + // schema for input fields + auto fielda = field("a", int32()); + auto fieldb = field("b", int32()); + auto schema = arrow::schema({fielda, fieldb}); + + // output fields + auto field_result = field("res", boolean()); + + // build expression. + // (if (a > b) a > 3 else b > 3) && (if (a > b) a > 2 else b > 2) + + auto node_a = TreeExprBuilder::MakeField(fielda); + auto node_b = TreeExprBuilder::MakeField(fieldb); + auto literal_2 = TreeExprBuilder::MakeLiteral((int32_t)2); + auto literal_3 = TreeExprBuilder::MakeLiteral((int32_t)3); + auto a_gt_b = + TreeExprBuilder::MakeFunction("greater_than", {node_a, node_b}, boolean()); + auto a_gt_2 = + TreeExprBuilder::MakeFunction("greater_than", {node_a, literal_2}, boolean()); + auto a_gt_3 = + TreeExprBuilder::MakeFunction("greater_than", {node_a, literal_3}, boolean()); + auto b_gt_2 = + TreeExprBuilder::MakeFunction("greater_than", {node_b, literal_2}, boolean()); + auto b_gt_3 = + TreeExprBuilder::MakeFunction("greater_than", {node_b, literal_3}, boolean()); + + auto if_3 = TreeExprBuilder::MakeIf(a_gt_b, a_gt_3, b_gt_3, arrow::boolean()); + auto if_2 = TreeExprBuilder::MakeIf(a_gt_b, a_gt_2, b_gt_2, arrow::boolean()); + auto node_and = TreeExprBuilder::MakeAnd({if_3, if_2}); + auto expr = TreeExprBuilder::MakeExpression(node_and, 
field_result); + + // Build a projector for the expressions. + std::shared_ptr projector; + Status status = Projector::Make(schema, {expr}, &projector); + EXPECT_TRUE(status.ok()); + + int num_records = 4; + std::vector validity({true, true, true, true}); + auto arraya = MakeArrowArrayInt32({4, 3, 3, 2}, validity); + auto arrayb = MakeArrowArrayInt32({3, 4, 2, 3}, validity); + auto exp = MakeArrowArrayBool({true, true, false, false}, validity); + + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {arraya, arrayb}); + + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +} // namespace gandiva diff --git a/cpp/src/gandiva/integ/date_time_test.cc b/cpp/src/gandiva/integ/date_time_test.cc new file mode 100644 index 00000000000..76bde5adc25 --- /dev/null +++ b/cpp/src/gandiva/integ/date_time_test.cc @@ -0,0 +1,322 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include "arrow/memory_pool.h" +#include "gandiva/projector.h" +#include "gandiva/tree_expr_builder.h" +#include "integ/test_util.h" + +namespace gandiva { + +using arrow::boolean; +using arrow::date64; +using arrow::float32; +using arrow::int32; +using arrow::int64; + +class TestProjector : public ::testing::Test { + public: + void SetUp() { pool_ = arrow::default_memory_pool(); } + + protected: + arrow::MemoryPool* pool_; +}; + +int32_t MillisInDay(int32_t hh, int32_t mm, int32_t ss, int32_t millis) { + int32_t mins = hh * 60 + mm; + int32_t secs = mins * 60 + ss; + + return secs * 1000 + millis; +} + +int64_t MillisSince(time_t base_line, int32_t yy, int32_t mm, int32_t dd, int32_t hr, + int32_t min, int32_t sec, int32_t millis) { + struct tm given_ts = {0}; + given_ts.tm_year = (yy - 1900); + given_ts.tm_mon = (mm - 1); + given_ts.tm_mday = dd; + given_ts.tm_hour = hr; + given_ts.tm_min = min; + given_ts.tm_sec = sec; + + return (lround(difftime(mktime(&given_ts), base_line)) * 1000 + millis); +} + +TEST_F(TestProjector, TestIsNull) { + auto d0 = field("d0", date64()); + auto t0 = field("t0", time32(arrow::TimeUnit::MILLI)); + auto schema = arrow::schema({d0, t0}); + + // output fields + auto b0 = field("isnull", boolean()); + + // isnull and isnotnull + auto isnull_expr = TreeExprBuilder::MakeExpression("isnull", {d0}, b0); + auto isnotnull_expr = TreeExprBuilder::MakeExpression("isnotnull", {t0}, b0); + + std::shared_ptr projector; + Status status = Projector::Make(schema, {isnull_expr, isnotnull_expr}, &projector); + ASSERT_TRUE(status.ok()); + + int num_records = 4; + std::vector d0_data = {0, 100, 0, 1000}; + auto t0_data = {0, 100, 0, 1000}; + auto validity = {false, true, false, true}; + auto d0_array = + MakeArrowTypeArray(date64(), d0_data, validity); + auto t0_array = MakeArrowTypeArray( + time32(arrow::TimeUnit::MILLI), t0_data, validity); + + // expected output + auto exp_isnull = + MakeArrowArrayBool({true, false, 
true, false}, {true, true, true, true}); + auto exp_isnotnull = MakeArrowArrayBool(validity, {true, true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {d0_array, t0_array}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp_isnull, outputs.at(0)); + EXPECT_ARROW_ARRAY_EQUALS(exp_isnotnull, outputs.at(1)); +} + +TEST_F(TestProjector, TestDateTime) { + auto field0 = field("f0", date64()); + auto field2 = field("f2", timestamp(arrow::TimeUnit::MILLI)); + auto schema = arrow::schema({field0, field2}); + + // output fields + auto field_year = field("yy", int64()); + auto field_month = field("mm", int64()); + auto field_day = field("dd", int64()); + auto field_hour = field("hh", int64()); + + // extract year and month from date + auto date2year_expr = + TreeExprBuilder::MakeExpression("extractYear", {field0}, field_year); + auto date2month_expr = + TreeExprBuilder::MakeExpression("extractMonth", {field0}, field_month); + + // extract month and day from timestamp + auto ts2month_expr = + TreeExprBuilder::MakeExpression("extractMonth", {field2}, field_month); + auto ts2day_expr = TreeExprBuilder::MakeExpression("extractDay", {field2}, field_day); + + std::shared_ptr projector; + Status status = Projector::Make( + schema, {date2year_expr, date2month_expr, ts2month_expr, ts2day_expr}, &projector); + ASSERT_TRUE(status.ok()); + + struct tm y1970 = {0}; + y1970.tm_year = 70; + y1970.tm_mon = 0; + y1970.tm_mday = 1; + y1970.tm_hour = 0; + y1970.tm_min = 0; + y1970.tm_sec = 0; + time_t epoch = mktime(&y1970); + + // Create a row-batch with some sample data + int num_records = 4; + auto validity = {true, true, true, true}; + std::vector field0_data = {MillisSince(epoch, 2000, 1, 1, 5, 0, 0, 0), + MillisSince(epoch, 1999, 12, 31, 5, 0, 0, 0), + MillisSince(epoch, 
2015, 6, 30, 20, 0, 0, 0), + MillisSince(epoch, 2015, 7, 1, 20, 0, 0, 0)}; + auto array0 = + MakeArrowTypeArray(date64(), field0_data, validity); + + std::vector field2_data = {MillisSince(epoch, 1999, 12, 31, 5, 0, 0, 0), + MillisSince(epoch, 2000, 1, 2, 5, 0, 0, 0), + MillisSince(epoch, 2015, 7, 1, 1, 0, 0, 0), + MillisSince(epoch, 2015, 6, 29, 23, 0, 0, 0)}; + + auto array2 = MakeArrowTypeArray( + arrow::timestamp(arrow::TimeUnit::MILLI), field2_data, validity); + + // expected output + // date 2 year and date 2 month + auto exp_yy_from_date = MakeArrowArrayInt64({2000, 1999, 2015, 2015}, validity); + auto exp_mm_from_date = MakeArrowArrayInt64({1, 12, 6, 7}, validity); + + // ts 2 month and ts 2 day + auto exp_mm_from_ts = MakeArrowArrayInt64({12, 1, 7, 6}, validity); + auto exp_dd_from_ts = MakeArrowArrayInt64({31, 2, 1, 29}, validity); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array2}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp_yy_from_date, outputs.at(0)); + EXPECT_ARROW_ARRAY_EQUALS(exp_mm_from_date, outputs.at(1)); + EXPECT_ARROW_ARRAY_EQUALS(exp_mm_from_ts, outputs.at(2)); + EXPECT_ARROW_ARRAY_EQUALS(exp_dd_from_ts, outputs.at(3)); +} + +TEST_F(TestProjector, TestTime) { + auto field0 = field("f0", time32(arrow::TimeUnit::MILLI)); + auto schema = arrow::schema({field0}); + + auto field_min = field("mm", int64()); + auto field_hour = field("hh", int64()); + + // extract day and hour from time32 + auto time2min_expr = + TreeExprBuilder::MakeExpression("extractMinute", {field0}, field_min); + auto time2hour_expr = + TreeExprBuilder::MakeExpression("extractHour", {field0}, field_hour); + + std::shared_ptr projector; + Status status = Projector::Make(schema, {time2min_expr, time2hour_expr}, &projector); + ASSERT_TRUE(status.ok()); + + // 
create input data + int num_records = 4; + auto validity = {true, true, true, true}; + std::vector field_data = { + MillisInDay(5, 35, 25, 0), // 5:35:25 + MillisInDay(0, 59, 0, 0), // 0:59:12 + MillisInDay(12, 30, 0, 0), // 12:30:0 + MillisInDay(23, 0, 0, 0) // 23:0:0 + }; + auto array = MakeArrowTypeArray( + time32(arrow::TimeUnit::MILLI), field_data, validity); + + // expected output + auto exp_min = MakeArrowArrayInt64({35, 59, 30, 0}, validity); + auto exp_hour = MakeArrowArrayInt64({5, 0, 12, 23}, validity); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp_min, outputs.at(0)); + EXPECT_ARROW_ARRAY_EQUALS(exp_hour, outputs.at(1)); +} + +TEST_F(TestProjector, TestTimestampDiff) { + auto f0 = field("f0", timestamp(arrow::TimeUnit::MILLI)); + auto f1 = field("f1", timestamp(arrow::TimeUnit::MILLI)); + auto schema = arrow::schema({f0, f1}); + + // output fields + auto diff_seconds = field("ss", int32()); + + // get diff + auto diff_secs_expr = + TreeExprBuilder::MakeExpression("timestampdiffSecond", {f0, f1}, diff_seconds); + + auto diff_mins_expr = + TreeExprBuilder::MakeExpression("timestampdiffMinute", {f0, f1}, diff_seconds); + + auto diff_hours_expr = + TreeExprBuilder::MakeExpression("timestampdiffHour", {f0, f1}, diff_seconds); + + auto diff_days_expr = + TreeExprBuilder::MakeExpression("timestampdiffDay", {f0, f1}, diff_seconds); + + auto diff_weeks_expr = + TreeExprBuilder::MakeExpression("timestampdiffWeek", {f0, f1}, diff_seconds); + + auto diff_months_expr = + TreeExprBuilder::MakeExpression("timestampdiffMonth", {f0, f1}, diff_seconds); + + auto diff_quarters_expr = + TreeExprBuilder::MakeExpression("timestampdiffQuarter", {f0, f1}, diff_seconds); + + auto diff_years_expr = + 
TreeExprBuilder::MakeExpression("timestampdiffYear", {f0, f1}, diff_seconds); + + std::shared_ptr projector; + auto exprs = {diff_secs_expr, diff_mins_expr, diff_hours_expr, diff_days_expr, + diff_weeks_expr, diff_months_expr, diff_quarters_expr, diff_years_expr}; + Status status = Projector::Make(schema, exprs, &projector); + ASSERT_TRUE(status.ok()); + + struct tm y1970 = {0}; + y1970.tm_year = 70; + y1970.tm_mon = 0; + y1970.tm_mday = 1; + y1970.tm_hour = 0; + y1970.tm_min = 0; + y1970.tm_sec = 0; + time_t epoch = mktime(&y1970); + + // 2015-09-10T20:49:42.000 + auto start_millis = MillisSince(epoch, 2015, 9, 10, 20, 49, 42, 0); + // 2017-03-30T22:50:59.050 + auto end_millis = MillisSince(epoch, 2017, 3, 30, 22, 50, 59, 50); + std::vector f0_data = {start_millis, end_millis, + // 2015-09-10T20:49:42.999 + start_millis + 999, + // 2015-09-10T20:49:42.999 + MillisSince(epoch, 2015, 9, 10, 20, 49, 42, 999)}; + std::vector f1_data = {end_millis, start_millis, + // 2015-09-10T20:49:42.999 + start_millis + 999, + // 2015-09-9T21:49:42.999 (23 hours behind) + MillisSince(epoch, 2015, 9, 9, 21, 49, 42, 999)}; + + int num_records = f0_data.size(); + std::vector validity(num_records, true); + auto array0 = MakeArrowTypeArray( + arrow::timestamp(arrow::TimeUnit::MILLI), f0_data, validity); + auto array1 = MakeArrowTypeArray( + arrow::timestamp(arrow::TimeUnit::MILLI), f1_data, validity); + + // expected output + std::vector exp_output; + exp_output.push_back( + MakeArrowArrayInt32({48996077, -48996077, 0, -23 * 3600}, validity)); + exp_output.push_back(MakeArrowArrayInt32({816601, -816601, 0, -23 * 60}, validity)); + exp_output.push_back(MakeArrowArrayInt32({13610, -13610, 0, -23}, validity)); + exp_output.push_back(MakeArrowArrayInt32({567, -567, 0, 0}, validity)); + exp_output.push_back(MakeArrowArrayInt32({81, -81, 0, 0}, validity)); + exp_output.push_back(MakeArrowArrayInt32({18, -18, 0, 0}, validity)); + exp_output.push_back(MakeArrowArrayInt32({6, -6, 0, 0}, 
validity)); + exp_output.push_back(MakeArrowArrayInt32({1, -1, 0, 0}, validity)); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + for (uint32_t i = 0; i < exp_output.size(); i++) { + EXPECT_ARROW_ARRAY_EQUALS(exp_output.at(i), outputs.at(i)); + } +} + +} // namespace gandiva diff --git a/cpp/src/gandiva/integ/filter_test.cc b/cpp/src/gandiva/integ/filter_test.cc new file mode 100644 index 00000000000..5deee9d106a --- /dev/null +++ b/cpp/src/gandiva/integ/filter_test.cc @@ -0,0 +1,290 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "gandiva/filter.h" +#include +#include "arrow/memory_pool.h" +#include "gandiva/tree_expr_builder.h" +#include "integ/test_util.h" + +namespace gandiva { + +using arrow::boolean; +using arrow::float32; +using arrow::int32; + +class TestFilter : public ::testing::Test { + public: + void SetUp() { pool_ = arrow::default_memory_pool(); } + + protected: + arrow::MemoryPool* pool_; +}; + +TEST_F(TestFilter, TestFilterCache) { + // schema for input fields + auto field0 = field("f0", int32()); + auto field1 = field("f1", int32()); + auto schema = arrow::schema({field0, field1}); + + // Build condition f0 + f1 < 10 + auto node_f0 = TreeExprBuilder::MakeField(field0); + auto node_f1 = TreeExprBuilder::MakeField(field1); + auto sum_func = + TreeExprBuilder::MakeFunction("add", {node_f0, node_f1}, arrow::int32()); + auto literal_10 = TreeExprBuilder::MakeLiteral((int32_t)10); + auto less_than_10 = TreeExprBuilder::MakeFunction("less_than", {sum_func, literal_10}, + arrow::boolean()); + auto condition = TreeExprBuilder::MakeCondition(less_than_10); + + std::shared_ptr filter; + Status status = Filter::Make(schema, condition, &filter); + EXPECT_TRUE(status.ok()); + + // same schema and condition, should return the same filter as above. + std::shared_ptr cached_filter; + status = Filter::Make(schema, condition, &cached_filter); + EXPECT_TRUE(status.ok()); + EXPECT_TRUE(cached_filter.get() == filter.get()); + + // schema is different should return a new filter. + auto field2 = field("f2", int32()); + auto different_schema = arrow::schema({field0, field1, field2}); + std::shared_ptr should_be_new_filter; + status = Filter::Make(different_schema, condition, &should_be_new_filter); + EXPECT_TRUE(status.ok()); + EXPECT_TRUE(cached_filter.get() != should_be_new_filter.get()); + + // condition is different, should return a new filter. 
+ auto greater_than_10 = TreeExprBuilder::MakeFunction( + "greater_than", {sum_func, literal_10}, arrow::boolean()); + auto new_condition = TreeExprBuilder::MakeCondition(greater_than_10); + std::shared_ptr should_be_new_filter1; + status = Filter::Make(schema, new_condition, &should_be_new_filter1); + EXPECT_TRUE(status.ok()); + EXPECT_TRUE(cached_filter.get() != should_be_new_filter1.get()); +} + +TEST_F(TestFilter, TestSimple) { + // schema for input fields + auto field0 = field("f0", int32()); + auto field1 = field("f1", int32()); + auto schema = arrow::schema({field0, field1}); + + // Build condition f0 + f1 < 10 + auto node_f0 = TreeExprBuilder::MakeField(field0); + auto node_f1 = TreeExprBuilder::MakeField(field1); + auto sum_func = + TreeExprBuilder::MakeFunction("add", {node_f0, node_f1}, arrow::int32()); + auto literal_10 = TreeExprBuilder::MakeLiteral((int32_t)10); + auto less_than_10 = TreeExprBuilder::MakeFunction("less_than", {sum_func, literal_10}, + arrow::boolean()); + auto condition = TreeExprBuilder::MakeCondition(less_than_10); + + std::shared_ptr filter; + Status status = Filter::Make(schema, condition, &filter); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 5; + auto array0 = MakeArrowArrayInt32({1, 2, 3, 4, 6}, {true, true, true, false, true}); + auto array1 = MakeArrowArrayInt32({5, 9, 6, 17, 3}, {true, true, false, true, true}); + // expected output (indices for which condition matches) + auto exp = MakeArrowArrayUint16({0, 4}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1}); + + std::shared_ptr selection_vector; + status = SelectionVector::MakeInt16(num_records, pool_, &selection_vector); + EXPECT_TRUE(status.ok()); + + // Evaluate expression + status = filter->Evaluate(*in_batch, selection_vector); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, selection_vector->ToArray()); +} + 
+TEST_F(TestFilter, TestSimpleCustomConfig) { + // schema for input fields + auto field0 = field("f0", int32()); + auto field1 = field("f1", int32()); + auto schema = arrow::schema({field0, field1}); + + // Build condition f0 != f1 + auto condition = TreeExprBuilder::MakeCondition("not_equal", {field0, field1}); + + ConfigurationBuilder config_builder; + std::shared_ptr config = config_builder.build(); + + std::shared_ptr filter; + Status status = Filter::Make(schema, condition, &filter); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 4; + auto array0 = MakeArrowArrayInt32({1, 2, 3, 4}, {true, true, true, false}); + auto array1 = MakeArrowArrayInt32({11, 2, 3, 17}, {true, true, false, true}); + // expected output + auto exp = MakeArrowArrayUint16({0}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1}); + + std::shared_ptr selection_vector; + status = SelectionVector::MakeInt16(num_records, pool_, &selection_vector); + EXPECT_TRUE(status.ok()); + + // Evaluate expression + status = filter->Evaluate(*in_batch, selection_vector); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, selection_vector->ToArray()); +} + +TEST_F(TestFilter, TestZeroCopy) { + // schema for input fields + auto field0 = field("f0", int32()); + auto schema = arrow::schema({field0}); + + // Build condition + auto condition = TreeExprBuilder::MakeCondition("isnotnull", {field0}); + + std::shared_ptr filter; + Status status = Filter::Make(schema, condition, &filter); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 4; + auto array0 = MakeArrowArrayInt32({1, 2, 3, 4}, {true, true, true, false}); + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0}); + + // expected output + auto exp = MakeArrowArrayUint16({0, 1, 2}); + + // allocate selection buffers + int64_t data_sz = sizeof(int16_t) * 
num_records; + std::unique_ptr data(new uint8_t[data_sz]); + std::shared_ptr data_buf = + std::make_shared(data.get(), data_sz); + + std::shared_ptr selection_vector; + status = SelectionVector::MakeInt16(num_records, data_buf, &selection_vector); + EXPECT_TRUE(status.ok()); + + // Evaluate expression + status = filter->Evaluate(*in_batch, selection_vector); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, selection_vector->ToArray()); +} + +TEST_F(TestFilter, TestZeroCopyNegative) { + ArrayPtr output; + + // schema for input fields + auto field0 = field("f0", int32()); + auto schema = arrow::schema({field0}); + + // Build expression + auto condition = TreeExprBuilder::MakeCondition("isnotnull", {field0}); + + std::shared_ptr filter; + Status status = Filter::Make(schema, condition, &filter); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 4; + auto array0 = MakeArrowArrayInt32({1, 2, 3, 4}, {true, true, true, false}); + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0}); + + // expected output + auto exp = MakeArrowArrayInt16({0, 1, 2}); + + // allocate output buffers + int64_t data_sz = sizeof(int16_t) * num_records; + std::unique_ptr data(new uint8_t[data_sz]); + std::shared_ptr data_buf = + std::make_shared(data.get(), data_sz); + + std::shared_ptr selection_vector; + status = SelectionVector::MakeInt16(num_records, data_buf, &selection_vector); + EXPECT_TRUE(status.ok()); + + // the batch can't be empty. + auto bad_batch = arrow::RecordBatch::Make(schema, 0 /*num_records*/, {array0}); + status = filter->Evaluate(*bad_batch, selection_vector); + EXPECT_EQ(status.code(), StatusCode::Invalid); + + // the selection_vector can't be null. + std::shared_ptr null_selection; + status = filter->Evaluate(*in_batch, null_selection); + EXPECT_EQ(status.code(), StatusCode::Invalid); + + // the selection vector must be suitably sized. 
+ std::shared_ptr bad_selection; + status = SelectionVector::MakeInt16(num_records - 1, data_buf, &bad_selection); + EXPECT_TRUE(status.ok()); + + status = filter->Evaluate(*in_batch, bad_selection); + EXPECT_EQ(status.code(), StatusCode::Invalid); +} + +TEST_F(TestFilter, TestSimpleSVInt32) { + // schema for input fields + auto field0 = field("f0", int32()); + auto field1 = field("f1", int32()); + auto schema = arrow::schema({field0, field1}); + + // Build condition f0 + f1 < 10 + auto node_f0 = TreeExprBuilder::MakeField(field0); + auto node_f1 = TreeExprBuilder::MakeField(field1); + auto sum_func = + TreeExprBuilder::MakeFunction("add", {node_f0, node_f1}, arrow::int32()); + auto literal_10 = TreeExprBuilder::MakeLiteral((int32_t)10); + auto less_than_10 = TreeExprBuilder::MakeFunction("less_than", {sum_func, literal_10}, + arrow::boolean()); + auto condition = TreeExprBuilder::MakeCondition(less_than_10); + + std::shared_ptr filter; + Status status = Filter::Make(schema, condition, &filter); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 5; + auto array0 = MakeArrowArrayInt32({1, 2, 3, 4, 6}, {true, true, true, false, true}); + auto array1 = MakeArrowArrayInt32({5, 9, 6, 17, 3}, {true, true, false, true, true}); + // expected output (indices for which condition matches) + auto exp = MakeArrowArrayUint32({0, 4}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1}); + + std::shared_ptr selection_vector; + status = SelectionVector::MakeInt32(num_records, pool_, &selection_vector); + EXPECT_TRUE(status.ok()); + + // Evaluate expression + status = filter->Evaluate(*in_batch, selection_vector); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, selection_vector->ToArray()); +} + +} // namespace gandiva diff --git a/cpp/src/gandiva/integ/generate_data.h b/cpp/src/gandiva/integ/generate_data.h new file mode 100644 index 
00000000000..fdceece5ec5 --- /dev/null +++ b/cpp/src/gandiva/integ/generate_data.h @@ -0,0 +1,61 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#ifndef GANDIVA_GENERATE_DATA_H +#define GANDIVA_GENERATE_DATA_H + +namespace gandiva { + +template +class DataGenerator { + public: + virtual C_TYPE GenerateData() = 0; +}; + +class Int32DataGenerator : public DataGenerator { + public: + Int32DataGenerator() : seed_(100) {} + + int32_t GenerateData() { return rand_r(&seed_); } + + protected: + unsigned int seed_; +}; + +class BoundedInt32DataGenerator : public Int32DataGenerator { + public: + explicit BoundedInt32DataGenerator(uint32_t upperBound) + : Int32DataGenerator(), upperBound_(upperBound) {} + + int32_t GenerateData() { return (rand_r(&seed_) % upperBound_); } + + protected: + uint32_t upperBound_; +}; + +class Int64DataGenerator : public DataGenerator { + public: + Int64DataGenerator() : seed_(100) {} + + int64_t GenerateData() { return rand_r(&seed_); } + + protected: + unsigned int seed_; +}; + +} // namespace gandiva + +#endif // GANDIVA_GENERATE_DATA_H diff --git a/cpp/src/gandiva/integ/hash_test.cc b/cpp/src/gandiva/integ/hash_test.cc new file mode 100644 index 00000000000..b7b4f470e04 --- /dev/null +++ b/cpp/src/gandiva/integ/hash_test.cc @@ -0,0 +1,142 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may 
not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "arrow/memory_pool.h" +#include "gandiva/projector.h" +#include "gandiva/status.h" +#include "gandiva/tree_expr_builder.h" +#include "integ/test_util.h" + +namespace gandiva { + +using arrow::boolean; +using arrow::int32; +using arrow::int64; +using arrow::utf8; + +class TestHash : public ::testing::Test { + public: + void SetUp() { pool_ = arrow::default_memory_pool(); } + + protected: + arrow::MemoryPool* pool_; +}; + +TEST_F(TestHash, TestSimple) { + // schema for input fields + auto field_a = field("a", int32()); + auto schema = arrow::schema({field_a}); + + // output fields + auto res_0 = field("res0", int32()); + auto res_1 = field("res1", int64()); + + // build expression. + // hash32(a, 10) + // hash64(a) + auto node_a = TreeExprBuilder::MakeField(field_a); + auto literal_10 = TreeExprBuilder::MakeLiteral((int32_t)10); + auto hash32 = TreeExprBuilder::MakeFunction("hash32", {node_a, literal_10}, int32()); + auto hash64 = TreeExprBuilder::MakeFunction("hash64", {node_a}, int64()); + auto expr_0 = TreeExprBuilder::MakeExpression(hash32, res_0); + auto expr_1 = TreeExprBuilder::MakeExpression(hash64, res_1); + + // Build a projector for the expression. 
+ std::shared_ptr projector; + Status status = Projector::Make(schema, {expr_0, expr_1}, &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Create a row-batch with some sample data + int num_records = 4; + auto array_a = MakeArrowArrayInt32({1, 2, 3, 4}, {false, true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + auto int32_arr = std::dynamic_pointer_cast(outputs.at(0)); + EXPECT_EQ(int32_arr->null_count(), 0); + EXPECT_EQ(int32_arr->Value(0), 0); + for (int i = 1; i < num_records; ++i) { + EXPECT_NE(int32_arr->Value(i), int32_arr->Value(i - 1)); + } + + auto int64_arr = std::dynamic_pointer_cast(outputs.at(1)); + EXPECT_EQ(int64_arr->null_count(), 0); + EXPECT_EQ(int64_arr->Value(0), 0); + for (int i = 1; i < num_records; ++i) { + EXPECT_NE(int64_arr->Value(i), int64_arr->Value(i - 1)); + } +} + +TEST_F(TestHash, TestBuf) { + // schema for input fields + auto field_a = field("a", utf8()); + auto schema = arrow::schema({field_a}); + + // output fields + auto res_0 = field("res0", int32()); + auto res_1 = field("res1", int64()); + + // build expressions. + // hash32(a) + // hash64(a, 10) + auto node_a = TreeExprBuilder::MakeField(field_a); + auto literal_10 = TreeExprBuilder::MakeLiteral((int64_t)10); + auto hash32 = TreeExprBuilder::MakeFunction("hash32", {node_a}, int32()); + auto hash64 = TreeExprBuilder::MakeFunction("hash64", {node_a, literal_10}, int64()); + auto expr_0 = TreeExprBuilder::MakeExpression(hash32, res_0); + auto expr_1 = TreeExprBuilder::MakeExpression(hash64, res_1); + + // Build a projector for the expressions. 
+ std::shared_ptr projector; + Status status = Projector::Make(schema, {expr_0, expr_1}, &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Create a row-batch with some sample data + int num_records = 4; + auto array_a = + MakeArrowArrayUtf8({"foo", "hello", "bye", "hi"}, {false, true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + auto int32_arr = std::dynamic_pointer_cast(outputs.at(0)); + EXPECT_EQ(int32_arr->null_count(), 0); + EXPECT_EQ(int32_arr->Value(0), 0); + for (int i = 1; i < num_records; ++i) { + EXPECT_NE(int32_arr->Value(i), int32_arr->Value(i - 1)); + } + + auto int64_arr = std::dynamic_pointer_cast(outputs.at(1)); + EXPECT_EQ(int64_arr->null_count(), 0); + EXPECT_EQ(int64_arr->Value(0), 0); + for (int i = 1; i < num_records; ++i) { + EXPECT_NE(int64_arr->Value(i), int64_arr->Value(i - 1)); + } +} + +} // namespace gandiva diff --git a/cpp/src/gandiva/integ/if_expr_test.cc b/cpp/src/gandiva/integ/if_expr_test.cc new file mode 100644 index 00000000000..1be28f203f0 --- /dev/null +++ b/cpp/src/gandiva/integ/if_expr_test.cc @@ -0,0 +1,310 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include "arrow/memory_pool.h" +#include "gandiva/projector.h" +#include "gandiva/status.h" +#include "gandiva/tree_expr_builder.h" +#include "integ/test_util.h" + +namespace gandiva { + +using arrow::boolean; +using arrow::float32; +using arrow::int32; + +class TestIfExpr : public ::testing::Test { + public: + void SetUp() { pool_ = arrow::default_memory_pool(); } + + protected: + arrow::MemoryPool* pool_; +}; + +TEST_F(TestIfExpr, TestSimple) { + // schema for input fields + auto fielda = field("a", int32()); + auto fieldb = field("b", int32()); + auto schema = arrow::schema({fielda, fieldb}); + + // output fields + auto field_result = field("res", int32()); + + // build expression. + // if (a > b) + // a + // else + // b + auto node_a = TreeExprBuilder::MakeField(fielda); + auto node_b = TreeExprBuilder::MakeField(fieldb); + auto condition = + TreeExprBuilder::MakeFunction("greater_than", {node_a, node_b}, boolean()); + auto if_node = TreeExprBuilder::MakeIf(condition, node_a, node_b, int32()); + + auto expr = TreeExprBuilder::MakeExpression(if_node, field_result); + + // Build a projector for the expressions. 
+ std::shared_ptr projector; + Status status = Projector::Make(schema, {expr}, &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 4; + auto array0 = MakeArrowArrayInt32({10, 12, -20, 5}, {true, true, true, false}); + auto array1 = MakeArrowArrayInt32({5, 15, 15, 17}, {true, true, true, true}); + + // expected output + auto exp = MakeArrowArrayInt32({10, 15, 15, 17}, {true, true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestIfExpr, TestSimpleArithmetic) { + // schema for input fields + auto fielda = field("a", int32()); + auto fieldb = field("b", int32()); + auto schema = arrow::schema({fielda, fieldb}); + + // output fields + auto field_result = field("res", int32()); + + // build expression. + // if (a > b) + // a + b + // else + // a - b + auto node_a = TreeExprBuilder::MakeField(fielda); + auto node_b = TreeExprBuilder::MakeField(fieldb); + auto condition = + TreeExprBuilder::MakeFunction("greater_than", {node_a, node_b}, boolean()); + auto sum = TreeExprBuilder::MakeFunction("add", {node_a, node_b}, int32()); + auto sub = TreeExprBuilder::MakeFunction("subtract", {node_a, node_b}, int32()); + auto if_node = TreeExprBuilder::MakeIf(condition, sum, sub, int32()); + + auto expr = TreeExprBuilder::MakeExpression(if_node, field_result); + + // Build a projector for the expressions. 
+ std::shared_ptr projector; + Status status = Projector::Make(schema, {expr}, &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 4; + auto array0 = MakeArrowArrayInt32({10, 12, -20, 5}, {true, true, true, false}); + auto array1 = MakeArrowArrayInt32({5, 15, 15, 17}, {true, true, true, true}); + + // expected output + auto exp = MakeArrowArrayInt32({15, -3, -35, 0}, {true, true, true, false}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestIfExpr, TestNested) { + // schema for input fields + auto fielda = field("a", int32()); + auto fieldb = field("b", int32()); + auto schema = arrow::schema({fielda, fieldb}); + + // output fields + auto field_result = field("res", int32()); + + // build expression. + // if (a > b) + // a + b + // else if (a < b) + // a - b + // else + // a * b + auto node_a = TreeExprBuilder::MakeField(fielda); + auto node_b = TreeExprBuilder::MakeField(fieldb); + auto condition_gt = + TreeExprBuilder::MakeFunction("greater_than", {node_a, node_b}, boolean()); + auto condition_lt = + TreeExprBuilder::MakeFunction("less_than", {node_a, node_b}, boolean()); + auto sum = TreeExprBuilder::MakeFunction("add", {node_a, node_b}, int32()); + auto sub = TreeExprBuilder::MakeFunction("subtract", {node_a, node_b}, int32()); + auto mult = TreeExprBuilder::MakeFunction("multiply", {node_a, node_b}, int32()); + auto else_node = TreeExprBuilder::MakeIf(condition_lt, sub, mult, int32()); + auto if_node = TreeExprBuilder::MakeIf(condition_gt, sum, else_node, int32()); + + auto expr = TreeExprBuilder::MakeExpression(if_node, field_result); + + // Build a projector for the expressions. 
+ std::shared_ptr projector; + Status status = Projector::Make(schema, {expr}, &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 4; + auto array0 = MakeArrowArrayInt32({10, 12, 15, 5}, {true, true, true, false}); + auto array1 = MakeArrowArrayInt32({5, 15, 15, 17}, {true, true, true, true}); + + // expected output + auto exp = MakeArrowArrayInt32({15, -3, 225, 0}, {true, true, true, false}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestIfExpr, TestNestedInIf) { + // schema for input fields + auto fielda = field("a", int32()); + auto fieldb = field("b", int32()); + auto fieldc = field("c", int32()); + auto schema = arrow::schema({fielda, fieldb, fieldc}); + + // output fields + auto field_result = field("res", int32()); + + // build expression. 
+ // if (a > 10) + // if (a < 20) + // a + b + // else + // b + c + // else + // a + c + auto node_a = TreeExprBuilder::MakeField(fielda); + auto node_b = TreeExprBuilder::MakeField(fieldb); + auto node_c = TreeExprBuilder::MakeField(fieldc); + + auto literal_10 = TreeExprBuilder::MakeLiteral(10); + auto literal_20 = TreeExprBuilder::MakeLiteral(20); + + auto gt_10 = + TreeExprBuilder::MakeFunction("greater_than", {node_a, literal_10}, boolean()); + auto lt_20 = + TreeExprBuilder::MakeFunction("less_than", {node_a, literal_20}, boolean()); + auto sum_ab = TreeExprBuilder::MakeFunction("add", {node_a, node_b}, int32()); + auto sum_bc = TreeExprBuilder::MakeFunction("add", {node_b, node_c}, int32()); + auto sum_ac = TreeExprBuilder::MakeFunction("add", {node_a, node_c}, int32()); + + auto if_lt_20 = TreeExprBuilder::MakeIf(lt_20, sum_ab, sum_bc, int32()); + auto if_gt_10 = TreeExprBuilder::MakeIf(gt_10, if_lt_20, sum_ac, int32()); + + auto expr = TreeExprBuilder::MakeExpression(if_gt_10, field_result); + + // Build a projector for the expressions. 
+ std::shared_ptr projector; + Status status = Projector::Make(schema, {expr}, &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 6; + auto array_a = + MakeArrowArrayInt32({21, 15, 5, 22, 15, 5}, {true, true, true, true, true, true}); + auto array_b = MakeArrowArrayInt32({20, 18, 19, 20, 18, 19}, + {true, true, true, false, false, false}); + auto array_c = MakeArrowArrayInt32({35, 45, 55, 35, 45, 55}, + {true, true, true, false, false, false}); + + // expected output + auto exp = + MakeArrowArrayInt32({55, 33, 60, 0, 0, 0}, {true, true, true, false, false, false}); + + // prepare input record batch + auto in_batch = + arrow::RecordBatch::Make(schema, num_records, {array_a, array_b, array_c}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestIfExpr, TestBigNested) { + // schema for input fields + auto fielda = field("a", int32()); + auto schema = arrow::schema({fielda}); + + // output fields + auto field_result = field("res", int32()); + + // build expression. + // if (a < 10) + // 10 + // else if (a < 20) + // 20 + // .. + // .. + // else if (a < 190) + // 190 + // else + // 200 + auto node_a = TreeExprBuilder::MakeField(fielda); + auto top_node = TreeExprBuilder::MakeLiteral(200); + for (int thresh = 190; thresh > 0; thresh -= 10) { + auto literal = TreeExprBuilder::MakeLiteral(thresh); + auto condition = + TreeExprBuilder::MakeFunction("less_than", {node_a, literal}, boolean()); + auto if_node = TreeExprBuilder::MakeIf(condition, literal, top_node, int32()); + top_node = if_node; + } + auto expr = TreeExprBuilder::MakeExpression(top_node, field_result); + + // Build a projector for the expressions. 
+ std::shared_ptr projector; + Status status = Projector::Make(schema, {expr}, &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 4; + auto array0 = MakeArrowArrayInt32({10, 102, 158, 302}, {true, true, true, true}); + + // expected output + auto exp = MakeArrowArrayInt32({20, 110, 160, 200}, {true, true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +} // namespace gandiva diff --git a/cpp/src/gandiva/integ/literal_test.cc b/cpp/src/gandiva/integ/literal_test.cc new file mode 100644 index 00000000000..6c2d0a9f4a4 --- /dev/null +++ b/cpp/src/gandiva/integ/literal_test.cc @@ -0,0 +1,228 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include "arrow/memory_pool.h" +#include "gandiva/projector.h" +#include "gandiva/status.h" +#include "gandiva/tree_expr_builder.h" +#include "integ/test_util.h" + +namespace gandiva { + +using arrow::boolean; +using arrow::float32; +using arrow::float64; +using arrow::int32; +using arrow::int64; + +class TestLiteral : public ::testing::Test { + public: + void SetUp() { pool_ = arrow::default_memory_pool(); } + + protected: + arrow::MemoryPool* pool_; +}; + +TEST_F(TestLiteral, TestSimpleArithmetic) { + // schema for input fields + auto field_a = field("a", boolean()); + auto field_b = field("b", int32()); + auto field_c = field("c", int64()); + auto field_d = field("d", float32()); + auto field_e = field("e", float64()); + auto schema = arrow::schema({field_a, field_b, field_c, field_d, field_e}); + + // output fields + auto res_a = field("a+1", boolean()); + auto res_b = field("b+1", int32()); + auto res_c = field("c+1", int64()); + auto res_d = field("d+1", float32()); + auto res_e = field("e+1", float64()); + + // build expressions. 
+ // a == true + // b + 1 + // c + 1 + // d + 1 + // e + 1 + auto node_a = TreeExprBuilder::MakeField(field_a); + auto literal_a = TreeExprBuilder::MakeLiteral(true); + auto func_a = TreeExprBuilder::MakeFunction("equal", {node_a, literal_a}, boolean()); + auto expr_a = TreeExprBuilder::MakeExpression(func_a, res_a); + + auto node_b = TreeExprBuilder::MakeField(field_b); + auto literal_b = TreeExprBuilder::MakeLiteral((int32_t)1); + auto func_b = TreeExprBuilder::MakeFunction("add", {node_b, literal_b}, int32()); + auto expr_b = TreeExprBuilder::MakeExpression(func_b, res_b); + + auto node_c = TreeExprBuilder::MakeField(field_c); + auto literal_c = TreeExprBuilder::MakeLiteral((int64_t)1); + auto func_c = TreeExprBuilder::MakeFunction("add", {node_c, literal_c}, int64()); + auto expr_c = TreeExprBuilder::MakeExpression(func_c, res_c); + + auto node_d = TreeExprBuilder::MakeField(field_d); + auto literal_d = TreeExprBuilder::MakeLiteral(static_cast(1)); + auto func_d = TreeExprBuilder::MakeFunction("add", {node_d, literal_d}, float32()); + auto expr_d = TreeExprBuilder::MakeExpression(func_d, res_d); + + auto node_e = TreeExprBuilder::MakeField(field_e); + auto literal_e = TreeExprBuilder::MakeLiteral(static_cast(1)); + auto func_e = TreeExprBuilder::MakeFunction("add", {node_e, literal_e}, float64()); + auto expr_e = TreeExprBuilder::MakeExpression(func_e, res_e); + + // Build a projector for the expressions. 
+ std::shared_ptr projector; + Status status = + Projector::Make(schema, {expr_a, expr_b, expr_c, expr_d, expr_e}, &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 4; + auto array_a = MakeArrowArrayBool({true, true, false, true}, {true, true, true, false}); + auto array_b = MakeArrowArrayInt32({5, 15, -15, 17}, {true, true, true, false}); + auto array_c = MakeArrowArrayInt64({5, 15, -15, 17}, {true, true, true, false}); + auto array_d = MakeArrowArrayFloat32({5.2, 15, -15.6, 17}, {true, true, true, false}); + auto array_e = MakeArrowArrayFloat64({5.6, 15, -15.9, 17}, {true, true, true, false}); + + // expected output + auto exp_a = MakeArrowArrayBool({true, true, false, false}, {true, true, true, false}); + auto exp_b = MakeArrowArrayInt32({6, 16, -14, 0}, {true, true, true, false}); + auto exp_c = MakeArrowArrayInt64({6, 16, -14, 0}, {true, true, true, false}); + auto exp_d = MakeArrowArrayFloat32({6.2, 16, -14.6, 0}, {true, true, true, false}); + auto exp_e = MakeArrowArrayFloat64({6.6, 16, -14.9, 0}, {true, true, true, false}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, + {array_a, array_b, array_c, array_d, array_e}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp_a, outputs.at(0)); + EXPECT_ARROW_ARRAY_EQUALS(exp_b, outputs.at(1)); + EXPECT_ARROW_ARRAY_EQUALS(exp_c, outputs.at(2)); + EXPECT_ARROW_ARRAY_EQUALS(exp_d, outputs.at(3)); + EXPECT_ARROW_ARRAY_EQUALS(exp_e, outputs.at(4)); +} + +TEST_F(TestLiteral, TestLiteralHash) { + auto schema = arrow::schema({}); + // output fields + auto res = field("a", int32()); + auto int_literal = TreeExprBuilder::MakeLiteral((int32_t)2); + auto expr = TreeExprBuilder::MakeExpression(int_literal, res); + + // Build a projector for the expressions. 
+ std::shared_ptr projector; + Status status = Projector::Make(schema, {expr}, &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + auto res1 = field("a", int64()); + auto int_literal1 = TreeExprBuilder::MakeLiteral((int64_t)2); + auto expr1 = TreeExprBuilder::MakeExpression(int_literal1, res1); + + // Build a projector for the expressions. + std::shared_ptr projector1; + status = Projector::Make(schema, {expr1}, &projector1); + EXPECT_TRUE(status.ok()) << status.message(); + EXPECT_TRUE(projector.get() != projector1.get()); +} + +TEST_F(TestLiteral, TestNullLiteral) { + // schema for input fields + auto field_a = field("a", int32()); + auto field_b = field("b", int32()); + auto schema = arrow::schema({field_a, field_b}); + + // output fields + auto res = field("a+b+null", int32()); + + auto node_a = TreeExprBuilder::MakeField(field_a); + auto node_b = TreeExprBuilder::MakeField(field_b); + auto literal_c = TreeExprBuilder::MakeNull(arrow::int32()); + auto add_a_b = TreeExprBuilder::MakeFunction("add", {node_a, node_b}, int32()); + auto add_a_b_c = TreeExprBuilder::MakeFunction("add", {add_a_b, literal_c}, int32()); + auto expr = TreeExprBuilder::MakeExpression(add_a_b_c, res); + + // Build a projector for the expressions. 
+ std::shared_ptr projector; + Status status = Projector::Make(schema, {expr}, &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Create a row-batch with some sample data + int num_records = 4; + auto array_a = MakeArrowArrayInt32({5, 15, -15, 17}, {true, true, true, false}); + auto array_b = MakeArrowArrayInt32({5, 15, -15, 17}, {true, true, true, false}); + + // expected output + auto exp = MakeArrowArrayInt32({0, 0, 0, 0}, {false, false, false, false}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a, array_b}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()) << status.message(); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestLiteral, TestNullLiteralInIf) { + // schema for input fields + auto field_a = field("a", float64()); + auto schema = arrow::schema({field_a}); + + // output fields + auto res = field("res", float64()); + + auto node_a = TreeExprBuilder::MakeField(field_a); + auto literal_5 = TreeExprBuilder::MakeLiteral((double_t)5); + auto a_gt_5 = TreeExprBuilder::MakeFunction("greater_than", {node_a, literal_5}, + arrow::boolean()); + auto literal_null = TreeExprBuilder::MakeNull(arrow::float64()); + auto if_node = + TreeExprBuilder::MakeIf(a_gt_5, literal_5, literal_null, arrow::float64()); + auto expr = TreeExprBuilder::MakeExpression(if_node, res); + + // Build a projector for the expressions. 
+ std::shared_ptr projector; + Status status = Projector::Make(schema, {expr}, &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Create a row-batch with some sample data + int num_records = 4; + auto array_a = MakeArrowArrayFloat64({6, 15, -15, 17}, {true, true, true, false}); + + // expected output + auto exp = MakeArrowArrayFloat64({5, 5, 0, 0}, {true, true, false, false}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()) << status.message(); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +} // namespace gandiva diff --git a/cpp/src/gandiva/integ/micro_benchmarks.cc b/cpp/src/gandiva/integ/micro_benchmarks.cc new file mode 100644 index 00000000000..1568ad0e369 --- /dev/null +++ b/cpp/src/gandiva/integ/micro_benchmarks.cc @@ -0,0 +1,113 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include "arrow/memory_pool.h" +#include "gandiva/projector.h" +#include "gandiva/status.h" +#include "gandiva/tree_expr_builder.h" +#include "integ/test_util.h" +#include "integ/timed_evaluate.h" + +namespace gandiva { + +using arrow::boolean; +using arrow::int32; +using arrow::int64; + +class TestBenchmarks : public ::testing::Test { + public: + void SetUp() { pool_ = arrow::default_memory_pool(); } + + protected: + arrow::MemoryPool* pool_; +}; + +TEST_F(TestBenchmarks, TimedTestAdd3) { + // schema for input fields + auto field0 = field("f0", int64()); + auto field1 = field("f1", int64()); + auto field2 = field("f2", int64()); + auto schema = arrow::schema({field0, field1, field2}); + + // output field + auto field_sum = field("add", int64()); + + // Build expression + auto part_sum = TreeExprBuilder::MakeFunction( + "add", {TreeExprBuilder::MakeField(field1), TreeExprBuilder::MakeField(field2)}, + int64()); + auto sum = TreeExprBuilder::MakeFunction( + "add", {TreeExprBuilder::MakeField(field0), part_sum}, int64()); + + auto sum_expr = TreeExprBuilder::MakeExpression(sum, field_sum); + + std::shared_ptr projector; + Status status = Projector::Make(schema, {sum_expr}, &projector); + EXPECT_TRUE(status.ok()); + + int64_t elapsed_millis; + Int64DataGenerator data_generator; + status = TimedEvaluate(schema, projector, data_generator, + pool_, 100 * MILLION, 16 * THOUSAND, + elapsed_millis); + ASSERT_TRUE(status.ok()); + std::cout << "Time taken for Add3 " << elapsed_millis << " ms\n"; +} + +TEST_F(TestBenchmarks, TimedTestBigNested) { + // schema for input fields + auto fielda = field("a", int32()); + auto schema = arrow::schema({fielda}); + + // output fields + auto field_result = field("res", int32()); + + // build expression. + // if (a < 10) + // 10 + // else if (a < 20) + // 20 + // .. + // .. 
+ // else if (a < 190) + // 190 + // else + // 200 + auto node_a = TreeExprBuilder::MakeField(fielda); + auto top_node = TreeExprBuilder::MakeLiteral(200); + for (int thresh = 190; thresh > 0; thresh -= 10) { + auto literal = TreeExprBuilder::MakeLiteral(thresh); + auto condition = + TreeExprBuilder::MakeFunction("less_than", {node_a, literal}, boolean()); + auto if_node = TreeExprBuilder::MakeIf(condition, literal, top_node, int32()); + top_node = if_node; + } + auto expr = TreeExprBuilder::MakeExpression(top_node, field_result); + + // Build a projector for the expressions. + std::shared_ptr projector; + Status status = Projector::Make(schema, {expr}, &projector); + EXPECT_TRUE(status.ok()); + + int64_t elapsed_millis; + BoundedInt32DataGenerator data_generator(250); + status = TimedEvaluate(schema, projector, data_generator, + pool_, 100 * MILLION, 16 * THOUSAND, + elapsed_millis); + ASSERT_TRUE(status.ok()); + std::cout << "Time taken for BigNestedIf " << elapsed_millis << " ms\n"; +} + +} // namespace gandiva diff --git a/cpp/src/gandiva/integ/projector_build_validation_test.cc b/cpp/src/gandiva/integ/projector_build_validation_test.cc new file mode 100644 index 00000000000..cab9c5a2646 --- /dev/null +++ b/cpp/src/gandiva/integ/projector_build_validation_test.cc @@ -0,0 +1,295 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include "arrow/memory_pool.h" +#include "gandiva/projector.h" +#include "gandiva/tree_expr_builder.h" +#include "integ/test_util.h" + +namespace gandiva { + +using arrow::boolean; +using arrow::float32; +using arrow::int32; + +class TestProjector : public ::testing::Test { + public: + void SetUp() { pool_ = arrow::default_memory_pool(); } + + protected: + arrow::MemoryPool* pool_; +}; + +TEST_F(TestProjector, TestNonExistentFunction) { + // schema for input fields + auto field0 = field("f0", float32()); + auto field1 = field("f2", float32()); + auto schema = arrow::schema({field0, field1}); + + // output fields + auto field_result = field("res", boolean()); + + // Build expression + auto lt_expr = TreeExprBuilder::MakeExpression("non_existent_function", + {field0, field1}, field_result); + + // Build a projector for the expressions. + std::shared_ptr projector; + Status status = Projector::Make(schema, {lt_expr}, &projector); + EXPECT_TRUE(status.IsExpressionValidationError()); + std::string expected_error = + "Function bool non_existent_function(float, float) not supported yet."; + EXPECT_TRUE(status.message().find(expected_error) != std::string::npos); +} + +TEST_F(TestProjector, TestNotMatchingDataType) { + // schema for input fields + auto field0 = field("f0", float32()); + auto schema = arrow::schema({field0}); + + // output fields + auto field_result = field("res", boolean()); + + // Build expression + auto node_f0 = TreeExprBuilder::MakeField(field0); + auto lt_expr = TreeExprBuilder::MakeExpression(node_f0, field_result); + + // Build a projector for the expressions. 
+ std::shared_ptr projector; + Status status = Projector::Make(schema, {lt_expr}, &projector); + EXPECT_TRUE(status.IsExpressionValidationError()); + std::string expected_error = + "Return type of root node float does not match that of expression bool"; + EXPECT_TRUE(status.message().find(expected_error) != std::string::npos); +} + +TEST_F(TestProjector, TestNotSupportedDataType) { + // schema for input fields + auto field0 = field("f0", list(int32())); + auto schema = arrow::schema({field0}); + + // output fields + auto field_result = field("res", list(int32())); + + // Build expression + auto node_f0 = TreeExprBuilder::MakeField(field0); + auto lt_expr = TreeExprBuilder::MakeExpression(node_f0, field_result); + + // Build a projector for the expressions. + std::shared_ptr projector; + Status status = Projector::Make(schema, {lt_expr}, &projector); + EXPECT_TRUE(status.IsExpressionValidationError()); + std::string expected_error = "Field f0 has unsupported data type list"; + EXPECT_TRUE(status.message().find(expected_error) != std::string::npos); +} + +TEST_F(TestProjector, TestIncorrectSchemaMissingField) { + // schema for input fields + auto field0 = field("f0", float32()); + auto field1 = field("f2", float32()); + auto schema = arrow::schema({field0, field0}); + + // output fields + auto field_result = field("res", boolean()); + + // Build expression + auto lt_expr = + TreeExprBuilder::MakeExpression("less_than", {field0, field1}, field_result); + + // Build a projector for the expressions. 
+ std::shared_ptr projector; + Status status = Projector::Make(schema, {lt_expr}, &projector); + EXPECT_TRUE(status.IsExpressionValidationError()); + std::string expected_error = "Field f2 not in schema"; + EXPECT_TRUE(status.message().find(expected_error) != std::string::npos); +} + +TEST_F(TestProjector, TestIncorrectSchemaTypeNotMatching) { + // schema for input fields + auto field0 = field("f0", float32()); + auto field1 = field("f2", float32()); + auto field2 = field("f2", int32()); + auto schema = arrow::schema({field0, field2}); + + // output fields + auto field_result = field("res", boolean()); + + // Build expression + auto lt_expr = + TreeExprBuilder::MakeExpression("less_than", {field0, field1}, field_result); + + // Build a projector for the expressions. + std::shared_ptr projector; + Status status = Projector::Make(schema, {lt_expr}, &projector); + EXPECT_TRUE(status.IsExpressionValidationError()); + std::string expected_error = + "Field definition in schema f2: int32 different from field in expression f2: float"; + EXPECT_TRUE(status.message().find(expected_error) != std::string::npos); +} + +TEST_F(TestProjector, TestIfNotSupportedFunction) { + // schema for input fields + auto fielda = field("a", int32()); + auto fieldb = field("b", int32()); + auto schema = arrow::schema({fielda, fieldb}); + + // output fields + auto field_result = field("res", int32()); + + // build expression. + // if (a > b) + // a + // else + // b + auto node_a = TreeExprBuilder::MakeField(fielda); + auto node_b = TreeExprBuilder::MakeField(fieldb); + auto condition = + TreeExprBuilder::MakeFunction("non_existent_function", {node_a, node_b}, boolean()); + auto if_node = TreeExprBuilder::MakeIf(condition, node_a, node_b, int32()); + + auto expr = TreeExprBuilder::MakeExpression(if_node, field_result); + + // Build a projector for the expressions. 
+ std::shared_ptr projector; + Status status = Projector::Make(schema, {expr}, &projector); + EXPECT_TRUE(status.IsExpressionValidationError()); +} + +TEST_F(TestProjector, TestIfNotMatchingReturnType) { + // schema for input fields + auto fielda = field("a", int32()); + auto fieldb = field("b", int32()); + auto schema = arrow::schema({fielda, fieldb}); + + // output fields + auto field_result = field("res", int32()); + + auto node_a = TreeExprBuilder::MakeField(fielda); + auto node_b = TreeExprBuilder::MakeField(fieldb); + auto condition = + TreeExprBuilder::MakeFunction("less_than", {node_a, node_b}, boolean()); + auto if_node = TreeExprBuilder::MakeIf(condition, node_a, node_b, boolean()); + + auto expr = TreeExprBuilder::MakeExpression(if_node, field_result); + + // Build a projector for the expressions. + std::shared_ptr projector; + Status status = Projector::Make(schema, {expr}, &projector); + EXPECT_TRUE(status.IsExpressionValidationError()); + std::string expected_error = "Return type of if bool and then int32 not matching."; + EXPECT_TRUE(status.message().find(expected_error) != std::string::npos); +} + +TEST_F(TestProjector, TestElseNotMatchingReturnType) { + // schema for input fields + auto fielda = field("a", int32()); + auto fieldb = field("b", int32()); + auto fieldc = field("c", boolean()); + auto schema = arrow::schema({fielda, fieldb, fieldc}); + + // output fields + auto field_result = field("res", int32()); + + auto node_a = TreeExprBuilder::MakeField(fielda); + auto node_b = TreeExprBuilder::MakeField(fieldb); + auto node_c = TreeExprBuilder::MakeField(fieldc); + auto condition = + TreeExprBuilder::MakeFunction("less_than", {node_a, node_b}, boolean()); + auto if_node = TreeExprBuilder::MakeIf(condition, node_a, node_c, int32()); + + auto expr = TreeExprBuilder::MakeExpression(if_node, field_result); + + // Build a projector for the expressions. 
+ std::shared_ptr projector; + Status status = Projector::Make(schema, {expr}, &projector); + EXPECT_TRUE(status.IsExpressionValidationError()); + std::string expected_error = "Return type of if int32 and else bool not matching."; + EXPECT_TRUE(status.message().find(expected_error) != std::string::npos); +} + +TEST_F(TestProjector, TestElseNotSupportedType) { + // schema for input fields + auto fielda = field("a", int32()); + auto fieldb = field("b", int32()); + auto fieldc = field("c", list(int32())); + auto schema = arrow::schema({fielda, fieldb}); + + // output fields + auto field_result = field("res", int32()); + + auto node_a = TreeExprBuilder::MakeField(fielda); + auto node_b = TreeExprBuilder::MakeField(fieldb); + auto node_c = TreeExprBuilder::MakeField(fieldc); + auto condition = + TreeExprBuilder::MakeFunction("less_than", {node_a, node_b}, boolean()); + auto if_node = TreeExprBuilder::MakeIf(condition, node_a, node_c, int32()); + + auto expr = TreeExprBuilder::MakeExpression(if_node, field_result); + + // Build a projector for the expressions. + std::shared_ptr projector; + Status status = Projector::Make(schema, {expr}, &projector); + EXPECT_TRUE(status.IsExpressionValidationError()); + std::string expected_error = "Field c has unsupported data type list"; + EXPECT_TRUE(status.message().find(expected_error) != std::string::npos); +} + +TEST_F(TestProjector, TestAndMinChildren) { + // schema for input fields + auto fielda = field("a", boolean()); + auto schema = arrow::schema({fielda}); + + // output fields + auto field_result = field("res", boolean()); + + auto node_a = TreeExprBuilder::MakeField(fielda); + auto and_node = TreeExprBuilder::MakeAnd({node_a}); + + auto expr = TreeExprBuilder::MakeExpression(and_node, field_result); + + // Build a projector for the expressions. 
+ std::shared_ptr projector; + Status status = Projector::Make(schema, {expr}, &projector); + EXPECT_TRUE(status.IsExpressionValidationError()); + std::string expected_error = "Boolean expression has 1 children, expected atleast two"; + EXPECT_TRUE(status.message().find(expected_error) != std::string::npos); +} + +TEST_F(TestProjector, TestAndBooleanArgType) { + // schema for input fields + auto fielda = field("a", boolean()); + auto fieldb = field("b", int32()); + auto schema = arrow::schema({fielda, fieldb}); + + // output fields + auto field_result = field("res", int32()); + + auto node_a = TreeExprBuilder::MakeField(fielda); + auto node_b = TreeExprBuilder::MakeField(fieldb); + auto and_node = TreeExprBuilder::MakeAnd({node_a, node_b}); + + auto expr = TreeExprBuilder::MakeExpression(and_node, field_result); + + // Build a projector for the expressions. + std::shared_ptr projector; + Status status = Projector::Make(schema, {expr}, &projector); + EXPECT_TRUE(status.IsExpressionValidationError()); + std::string expected_error = + "Boolean expression has a child with return type int32, expected return type " + "boolean"; + EXPECT_TRUE(status.message().find(expected_error) != std::string::npos); +} + +} // namespace gandiva diff --git a/cpp/src/gandiva/integ/projector_test.cc b/cpp/src/gandiva/integ/projector_test.cc new file mode 100644 index 00000000000..d0fe6166378 --- /dev/null +++ b/cpp/src/gandiva/integ/projector_test.cc @@ -0,0 +1,522 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gandiva/projector.h" +#include +#include "arrow/memory_pool.h" +#include "gandiva/tree_expr_builder.h" +#include "integ/test_util.h" + +namespace gandiva { + +using arrow::boolean; +using arrow::float32; +using arrow::int32; + +class TestProjector : public ::testing::Test { + public: + void SetUp() { pool_ = arrow::default_memory_pool(); } + + protected: + arrow::MemoryPool* pool_; +}; + +TEST_F(TestProjector, TestProjectCache) { + // schema for input fields + auto field0 = field("f0", int32()); + auto field1 = field("f2", int32()); + auto schema = arrow::schema({field0, field1}); + + // output fields + auto field_sum = field("add", int32()); + auto field_sub = field("subtract", int32()); + + // Build expression + auto sum_expr = TreeExprBuilder::MakeExpression("add", {field0, field1}, field_sum); + auto sub_expr = + TreeExprBuilder::MakeExpression("subtract", {field0, field1}, field_sub); + + std::shared_ptr projector; + Status status = Projector::Make(schema, {sum_expr, sub_expr}, &projector); + EXPECT_TRUE(status.ok()); + + // everything is same, should return the same projector. + auto schema_same = arrow::schema({field0, field1}); + std::shared_ptr cached_projector; + status = Projector::Make(schema_same, {sum_expr, sub_expr}, &cached_projector); + EXPECT_TRUE(status.ok()); + EXPECT_TRUE(cached_projector.get() == projector.get()); + + // schema is different should return a new projector. + auto field2 = field("f2", int32()); + auto different_schema = arrow::schema({field0, field1, field2}); + std::shared_ptr should_be_new_projector; + status = + Projector::Make(different_schema, {sum_expr, sub_expr}, &should_be_new_projector); + EXPECT_TRUE(status.ok()); + EXPECT_TRUE(cached_projector.get() != should_be_new_projector.get()); + + // expression list is different should return a new projector. 
+ std::shared_ptr should_be_new_projector1; + status = Projector::Make(schema, {sum_expr}, &should_be_new_projector1); + EXPECT_TRUE(status.ok()); + EXPECT_TRUE(cached_projector.get() != should_be_new_projector1.get()); +} + +TEST_F(TestProjector, TestIntSumSub) { + // schema for input fields + auto field0 = field("f0", int32()); + auto field1 = field("f2", int32()); + auto schema = arrow::schema({field0, field1}); + + // output fields + auto field_sum = field("add", int32()); + auto field_sub = field("subtract", int32()); + + // Build expression + auto sum_expr = TreeExprBuilder::MakeExpression("add", {field0, field1}, field_sum); + auto sub_expr = + TreeExprBuilder::MakeExpression("subtract", {field0, field1}, field_sub); + + std::shared_ptr projector; + Status status = Projector::Make(schema, {sum_expr, sub_expr}, &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 4; + auto array0 = MakeArrowArrayInt32({1, 2, 3, 4}, {true, true, true, false}); + auto array1 = MakeArrowArrayInt32({11, 13, 15, 17}, {true, true, false, true}); + // expected output + auto exp_sum = MakeArrowArrayInt32({12, 15, 0, 0}, {true, true, false, false}); + auto exp_sub = MakeArrowArrayInt32({-10, -11, 0, 0}, {true, true, false, false}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp_sum, outputs.at(0)); + EXPECT_ARROW_ARRAY_EQUALS(exp_sub, outputs.at(1)); +} + +TEST_F(TestProjector, TestIntSumSubCustomConfig) { + // schema for input fields + auto field0 = field("f0", int32()); + auto field1 = field("f2", int32()); + auto schema = arrow::schema({field0, field1}); + + // output fields + auto field_sum = field("add", int32()); + auto field_sub = field("subtract", 
int32()); + + // Build expression + auto sum_expr = TreeExprBuilder::MakeExpression("add", {field0, field1}, field_sum); + auto sub_expr = + TreeExprBuilder::MakeExpression("subtract", {field0, field1}, field_sub); + + std::shared_ptr projector; + ConfigurationBuilder config_builder; + std::shared_ptr config = config_builder.build(); + + Status status = Projector::Make(schema, {sum_expr, sub_expr}, config, &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 4; + auto array0 = MakeArrowArrayInt32({1, 2, 3, 4}, {true, true, true, false}); + auto array1 = MakeArrowArrayInt32({11, 13, 15, 17}, {true, true, false, true}); + // expected output + auto exp_sum = MakeArrowArrayInt32({12, 15, 0, 0}, {true, true, false, false}); + auto exp_sub = MakeArrowArrayInt32({-10, -11, 0, 0}, {true, true, false, false}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp_sum, outputs.at(0)); + EXPECT_ARROW_ARRAY_EQUALS(exp_sub, outputs.at(1)); +} + +template +static void TestArithmeticOpsForType(arrow::MemoryPool* pool) { + auto atype = arrow::TypeTraits::type_singleton(); + + // schema for input fields + auto field0 = field("f0", atype); + auto field1 = field("f1", atype); + auto schema = arrow::schema({field0, field1}); + + // output fields + auto field_sum = field("add", atype); + auto field_sub = field("subtract", atype); + auto field_mul = field("multiply", atype); + auto field_div = field("divide", atype); + auto field_eq = field("equal", arrow::boolean()); + auto field_lt = field("less_than", arrow::boolean()); + + // Build expression + auto sum_expr = TreeExprBuilder::MakeExpression("add", {field0, field1}, field_sum); + auto sub_expr = + 
TreeExprBuilder::MakeExpression("subtract", {field0, field1}, field_sub); + auto mul_expr = + TreeExprBuilder::MakeExpression("multiply", {field0, field1}, field_mul); + auto div_expr = TreeExprBuilder::MakeExpression("divide", {field0, field1}, field_div); + auto eq_expr = TreeExprBuilder::MakeExpression("equal", {field0, field1}, field_eq); + auto lt_expr = TreeExprBuilder::MakeExpression("less_than", {field0, field1}, field_lt); + + std::shared_ptr projector; + Status status = Projector::Make( + schema, {sum_expr, sub_expr, mul_expr, div_expr, eq_expr, lt_expr}, &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 4; + std::vector input0 = {1, 2, 53, 84}; + std::vector input1 = {10, 15, 23, 84}; + std::vector validity = {true, true, true, true}; + + auto array0 = MakeArrowArray(input0, validity); + auto array1 = MakeArrowArray(input1, validity); + + // expected output + std::vector sum; + std::vector sub; + std::vector mul; + std::vector div; + std::vector eq; + std::vector lt; + for (int i = 0; i < num_records; i++) { + sum.push_back(input0[i] + input1[i]); + sub.push_back(input0[i] - input1[i]); + mul.push_back(input0[i] * input1[i]); + div.push_back(input0[i] / input1[i]); + eq.push_back(input0[i] == input1[i]); + lt.push_back(input0[i] < input1[i]); + } + auto exp_sum = MakeArrowArray(sum, validity); + auto exp_sub = MakeArrowArray(sub, validity); + auto exp_mul = MakeArrowArray(mul, validity); + auto exp_div = MakeArrowArray(div, validity); + auto exp_eq = MakeArrowArray(eq, validity); + auto exp_lt = MakeArrowArray(lt, validity); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp_sum, outputs.at(0)); + EXPECT_ARROW_ARRAY_EQUALS(exp_sub, 
outputs.at(1)); + EXPECT_ARROW_ARRAY_EQUALS(exp_mul, outputs.at(2)); + EXPECT_ARROW_ARRAY_EQUALS(exp_div, outputs.at(3)); + EXPECT_ARROW_ARRAY_EQUALS(exp_eq, outputs.at(4)); + EXPECT_ARROW_ARRAY_EQUALS(exp_lt, outputs.at(5)); +} + +TEST_F(TestProjector, TestAllIntTypes) { + TestArithmeticOpsForType(pool_); + TestArithmeticOpsForType(pool_); + TestArithmeticOpsForType(pool_); + TestArithmeticOpsForType(pool_); + TestArithmeticOpsForType(pool_); + TestArithmeticOpsForType(pool_); + TestArithmeticOpsForType(pool_); + TestArithmeticOpsForType(pool_); +} + +TEST_F(TestProjector, TestFloatLessThan) { + // schema for input fields + auto field0 = field("f0", float32()); + auto field1 = field("f2", float32()); + auto schema = arrow::schema({field0, field1}); + + // output fields + auto field_result = field("res", boolean()); + + // Build expression + auto lt_expr = + TreeExprBuilder::MakeExpression("less_than", {field0, field1}, field_result); + + // Build a projector for the expressions. + std::shared_ptr projector; + Status status = Projector::Make(schema, {lt_expr}, &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 3; + auto array0 = MakeArrowArrayFloat32({1.0, 8.9, 3.0}, {true, true, false}); + auto array1 = MakeArrowArrayFloat32({4.0, 3.4, 6.8}, {true, true, true}); + // expected output + auto exp = MakeArrowArrayBool({true, false, false}, {true, true, false}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestProjector, TestIsNotNull) { + // schema for input fields + auto field0 = field("f0", float32()); + auto schema = arrow::schema({field0}); + + // output fields + auto field_result = field("res", boolean()); + 
+ // Build expression + auto myexpr = TreeExprBuilder::MakeExpression("isnotnull", {field0}, field_result); + + // Build a projector for the expressions. + std::shared_ptr projector; + Status status = Projector::Make(schema, {myexpr}, &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 3; + auto array0 = MakeArrowArrayFloat32({1.0, 8.9, 3.0}, {true, true, false}); + // expected output + auto exp = MakeArrowArrayBool({true, true, false}, {true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestProjector, TestNullInternal) { + // schema for input fields + auto field0 = field("f0", int32()); + auto schema = arrow::schema({field0}); + + // output fields + auto field_result = field("res", int32()); + + // build expression. + auto myexpr = TreeExprBuilder::MakeExpression("half_or_null", {field0}, field_result); + + // Build a projector for the expressions. 
+ std::shared_ptr projector; + Status status = Projector::Make(schema, {myexpr}, &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 5; + auto array0 = + MakeArrowArrayInt32({10, 10, -20, 5, -7}, {true, false, true, true, true}); + + // expected output + auto exp = MakeArrowArrayInt32({5, 0, -10, 0, 0}, {true, false, true, false, false}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestProjector, TestNestedFunctions) { + // schema for input fields + auto field0 = field("f0", int32()); + auto field1 = field("f1", int32()); + auto schema = arrow::schema({field0, field1}); + + // output fields + auto field_res1 = field("res1", int32()); + auto field_res2 = field("res2", boolean()); + + // build expression. + // expr1 : half_or_null(f0) * f1 + // expr2 : isnull(half_or_null(f0) * f1) + auto node_f0 = TreeExprBuilder::MakeField(field0); + auto node_f1 = TreeExprBuilder::MakeField(field1); + auto half = TreeExprBuilder::MakeFunction("half_or_null", {node_f0}, int32()); + auto mult = TreeExprBuilder::MakeFunction("multiply", {half, node_f1}, int32()); + auto expr1 = TreeExprBuilder::MakeExpression(mult, field_res1); + + auto isnull = TreeExprBuilder::MakeFunction("isnull", {mult}, boolean()); + auto expr2 = TreeExprBuilder::MakeExpression(isnull, field_res2); + + // Build a projector for the expressions. 
+ std::shared_ptr projector; + Status status = Projector::Make(schema, {expr1, expr2}, &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 4; + auto array0 = MakeArrowArrayInt32({10, 10, -20, 5}, {true, false, true, true}); + auto array1 = MakeArrowArrayInt32({11, 13, 15, 17}, {true, true, false, true}); + + // expected output + auto exp1 = MakeArrowArrayInt32({55, 65, -150, 0}, {true, false, false, false}); + auto exp2 = MakeArrowArrayBool({false, true, true, true}, {true, true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp1, outputs.at(0)); + EXPECT_ARROW_ARRAY_EQUALS(exp2, outputs.at(1)); +} + +TEST_F(TestProjector, TestZeroCopy) { + // schema for input fields + auto field0 = field("f0", int32()); + auto schema = arrow::schema({field0}); + + // output fields + auto res = field("res", float32()); + + // Build expression + auto cast_expr = TreeExprBuilder::MakeExpression("castFLOAT4", {field0}, res); + + std::shared_ptr projector; + Status status = Projector::Make(schema, {cast_expr}, &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 4; + auto array0 = MakeArrowArrayInt32({1, 2, 3, 4}, {true, true, true, false}); + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0}); + + // expected output + auto exp = MakeArrowArrayFloat32({1, 2, 3, 0}, {true, true, true, false}); + + // allocate output buffers + int64_t bitmap_sz = arrow::BitUtil::BytesForBits(num_records); + std::unique_ptr bitmap(new uint8_t[bitmap_sz]); + std::shared_ptr bitmap_buf = + std::make_shared(bitmap.get(), bitmap_sz); + + int64_t data_sz = sizeof(float) * num_records; + std::unique_ptr 
data(new uint8_t[data_sz]); + std::shared_ptr data_buf = + std::make_shared(data.get(), data_sz); + + auto array_data = + arrow::ArrayData::Make(float32(), num_records, {bitmap_buf, data_buf}); + + // Evaluate expression + status = projector->Evaluate(*in_batch, {array_data}); + EXPECT_TRUE(status.ok()); + + // Validate results + auto output = arrow::MakeArray(array_data); + EXPECT_ARROW_ARRAY_EQUALS(exp, output); +} + +TEST_F(TestProjector, TestZeroCopyNegative) { + // schema for input fields + auto field0 = field("f0", int32()); + auto schema = arrow::schema({field0}); + + // output fields + auto res = field("res", float32()); + + // Build expression + auto cast_expr = TreeExprBuilder::MakeExpression("castFLOAT4", {field0}, res); + + std::shared_ptr projector; + Status status = Projector::Make(schema, {cast_expr}, &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 4; + auto array0 = MakeArrowArrayInt32({1, 2, 3, 4}, {true, true, true, false}); + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0}); + + // expected output + auto exp = MakeArrowArrayFloat32({1, 2, 3, 0}, {true, true, true, false}); + + // allocate output buffers + int64_t bitmap_sz = arrow::BitUtil::BytesForBits(num_records); + std::unique_ptr bitmap(new uint8_t[bitmap_sz]); + std::shared_ptr bitmap_buf = + std::make_shared(bitmap.get(), bitmap_sz); + + int64_t data_sz = sizeof(float) * num_records; + std::unique_ptr data(new uint8_t[data_sz]); + std::shared_ptr data_buf = + std::make_shared(data.get(), data_sz); + + auto array_data = + arrow::ArrayData::Make(float32(), num_records, {bitmap_buf, data_buf}); + + // the batch can't be empty. + auto bad_batch = arrow::RecordBatch::Make(schema, 0 /*num_records*/, {array0}); + status = projector->Evaluate(*bad_batch, {array_data}); + EXPECT_EQ(status.code(), StatusCode::Invalid); + + // the output array can't be null. 
+ std::shared_ptr null_array_data; + status = projector->Evaluate(*in_batch, {null_array_data}); + EXPECT_EQ(status.code(), StatusCode::Invalid); + + // the output array must have atleast two buffers. + auto bad_array_data = arrow::ArrayData::Make(float32(), num_records, {bitmap_buf}); + status = projector->Evaluate(*in_batch, {bad_array_data}); + EXPECT_EQ(status.code(), StatusCode::Invalid); + + // the output buffers must have sufficiently sized data_buf. + std::shared_ptr bad_data_buf = + std::make_shared(data.get(), data_sz - 1); + auto bad_array_data2 = + arrow::ArrayData::Make(float32(), num_records, {bitmap_buf, bad_data_buf}); + status = projector->Evaluate(*in_batch, {bad_array_data2}); + EXPECT_EQ(status.code(), StatusCode::Invalid); + + // the output buffers must have sufficiently sized bitmap_buf. + std::shared_ptr bad_bitmap_buf = + std::make_shared(bitmap.get(), bitmap_sz - 1); + auto bad_array_data3 = + arrow::ArrayData::Make(float32(), num_records, {bad_bitmap_buf, data_buf}); + status = projector->Evaluate(*in_batch, {bad_array_data3}); + EXPECT_EQ(status.code(), StatusCode::Invalid); +} + +} // namespace gandiva diff --git a/cpp/src/gandiva/integ/test_util.h b/cpp/src/gandiva/integ/test_util.h new file mode 100644 index 00000000000..91c4f1b31f6 --- /dev/null +++ b/cpp/src/gandiva/integ/test_util.h @@ -0,0 +1,75 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include +#include "arrow/test-util.h" +#include "gandiva/arrow.h" + +#ifndef GANDIVA_TEST_UTIL_H +#define GANDIVA_TEST_UTIL_H + +namespace gandiva { + +// Helper function to create an arrow-array of type ARROWTYPE +// from primitive vectors of data & validity. +// +// arrow/test-util.h has good utility classes for this purpose. +// Using those +template +static ArrayPtr MakeArrowArray(std::vector values, std::vector validity) { + ArrayPtr out; + arrow::ArrayFromVector(validity, values, &out); + return out; +} + +template +static ArrayPtr MakeArrowArray(std::vector values) { + ArrayPtr out; + arrow::ArrayFromVector(values, &out); + return out; +} + +template +static ArrayPtr MakeArrowTypeArray(const std::shared_ptr &type, + const std::vector &values, + const std::vector &validity) { + ArrayPtr out; + arrow::ArrayFromVector(type, validity, values, &out); + return out; +} + +#define MakeArrowArrayBool MakeArrowArray +#define MakeArrowArrayInt8 MakeArrowArray +#define MakeArrowArrayInt16 MakeArrowArray +#define MakeArrowArrayInt32 MakeArrowArray +#define MakeArrowArrayInt64 MakeArrowArray +#define MakeArrowArrayUint8 MakeArrowArray +#define MakeArrowArrayUint16 MakeArrowArray +#define MakeArrowArrayUint32 MakeArrowArray +#define MakeArrowArrayUint64 MakeArrowArray +#define MakeArrowArrayFloat32 MakeArrowArray +#define MakeArrowArrayFloat64 MakeArrowArray +#define MakeArrowArrayUtf8 MakeArrowArray +#define MakeArrowArrayBinary MakeArrowArray + +#define EXPECT_ARROW_ARRAY_EQUALS(a, b) \ + EXPECT_TRUE((a)->Equals(b)) << "expected array: " << (a)->ToString() \ + << " actual array: " << (b)->ToString(); + +} // namespace gandiva + +#endif // GANDIVA_TEST_UTIL_H diff --git a/cpp/src/gandiva/integ/timed_evaluate.h b/cpp/src/gandiva/integ/timed_evaluate.h new file mode 100644 index 00000000000..966977b608f --- /dev/null +++ b/cpp/src/gandiva/integ/timed_evaluate.h @@ -0,0 +1,91 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed 
under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "gandiva/arrow.h" +#include "gandiva/projector.h" +#include "integ/generate_data.h" + +#ifndef GANDIVA_TIMED_EVALUATE_H +#define GANDIVA_TIMED_EVALUATE_H + +#define THOUSAND (1024) +#define MILLION (1024 * 1024) + +namespace gandiva { + +template +std::vector GenerateData(int num_records, DataGenerator &data_generator) { + std::vector data; + + for (int i = 0; i < num_records; i++) { + data.push_back(data_generator.GenerateData()); + } + + return data; +} + +template +Status TimedEvaluate(SchemaPtr schema, std::shared_ptr projector, + DataGenerator &data_generator, arrow::MemoryPool *pool, + int num_records, int batch_size, int64_t &num_millis) { + int num_remaining = num_records; + int num_fields = schema->num_fields(); + int num_calls = 0; + Status status; + std::chrono::duration micros(0); + std::chrono::time_point start; + std::chrono::time_point finish; + + while (num_remaining > 0) { + int num_in_batch = batch_size; + if (batch_size > num_remaining) { + num_in_batch = num_remaining; + } + + // generate data for all columns in the schema + std::vector columns; + for (int col = 0; col < num_fields; col++) { + std::vector data = GenerateData(num_in_batch, data_generator); + std::vector validity(num_in_batch, true); + ArrayPtr col_data = MakeArrowArray(data, validity); + + columns.push_back(col_data); + } + + // make the record batch + auto in_batch = arrow::RecordBatch::Make(schema, 
num_in_batch, columns); + + // evaluate + arrow::ArrayVector outputs; + start = std::chrono::high_resolution_clock::now(); + status = projector->Evaluate(*in_batch, pool, &outputs); + finish = std::chrono::high_resolution_clock::now(); + if (!status.ok()) { + return status; + } + + micros += std::chrono::duration_cast(finish - start); + num_calls++; + num_remaining -= num_in_batch; + } + + num_millis = micros.count() / 1000; + return Status::OK(); +} + +} // namespace gandiva + +#endif // GANDIVA_TIMED_EVALUATE_H diff --git a/cpp/src/gandiva/integ/to_string_test.cc b/cpp/src/gandiva/integ/to_string_test.cc new file mode 100644 index 00000000000..a5d1260b633 --- /dev/null +++ b/cpp/src/gandiva/integ/to_string_test.cc @@ -0,0 +1,81 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include "arrow/memory_pool.h" +#include "gandiva/projector.h" +#include "gandiva/tree_expr_builder.h" +#include "integ/test_util.h" + +namespace gandiva { + +using arrow::boolean; +using arrow::float64; +using arrow::int32; +using arrow::int64; + +class TestToString : public ::testing::Test { + public: + void SetUp() { pool_ = arrow::default_memory_pool(); } + + protected: + arrow::MemoryPool* pool_; +}; + +#define CHECK_EXPR_TO_STRING(e, str) EXPECT_STREQ(e->ToString().c_str(), str) + +TEST_F(TestToString, TestAll) { + auto literal_node = TreeExprBuilder::MakeLiteral((uint64_t)100); + auto literal_expr = + TreeExprBuilder::MakeExpression(literal_node, arrow::field("r", int64())); + CHECK_EXPR_TO_STRING(literal_expr, "(uint64) 100"); + + auto f0 = arrow::field("f0", float64()); + auto f0_node = TreeExprBuilder::MakeField(f0); + auto f0_expr = TreeExprBuilder::MakeExpression(f0_node, f0); + CHECK_EXPR_TO_STRING(f0_expr, "double"); + + auto f1 = arrow::field("f1", int64()); + auto f2 = arrow::field("f2", int64()); + auto f1_node = TreeExprBuilder::MakeField(f1); + auto f2_node = TreeExprBuilder::MakeField(f2); + auto add_node = TreeExprBuilder::MakeFunction("add", {f1_node, f2_node}, int64()); + auto add_expr = TreeExprBuilder::MakeExpression(add_node, f1); + CHECK_EXPR_TO_STRING(add_expr, "int64 add(int64, int64)"); + + auto cond_node = TreeExprBuilder::MakeFunction( + "lesser_than", {f0_node, TreeExprBuilder::MakeLiteral((float)0)}, boolean()); + auto then_node = TreeExprBuilder::MakeField(f1); + auto else_node = TreeExprBuilder::MakeField(f2); + + auto if_node = TreeExprBuilder::MakeIf(cond_node, then_node, else_node, int64()); + auto if_expr = TreeExprBuilder::MakeExpression(if_node, f1); + CHECK_EXPR_TO_STRING( + if_expr, "if (bool lesser_than(double, (float) 0)) { int64 } else { int64 }"); + + auto f1_gt_100 = + TreeExprBuilder::MakeFunction("greater_than", {f1_node, literal_node}, boolean()); + auto f2_equals_100 = + 
TreeExprBuilder::MakeFunction("equals", {f2_node, literal_node}, boolean()); + auto and_node = TreeExprBuilder::MakeAnd({f1_gt_100, f2_equals_100}); + auto and_expr = + TreeExprBuilder::MakeExpression(and_node, arrow::field("f0", boolean())); + CHECK_EXPR_TO_STRING( + and_expr, + "bool greater_than(int64, (uint64) 100) && bool equals(int64, (uint64) 100)"); +} + +} // namespace gandiva diff --git a/cpp/src/gandiva/integ/utf8_test.cc b/cpp/src/gandiva/integ/utf8_test.cc new file mode 100644 index 00000000000..e373f0285cf --- /dev/null +++ b/cpp/src/gandiva/integ/utf8_test.cc @@ -0,0 +1,211 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "arrow/memory_pool.h" +#include "gandiva/projector.h" +#include "gandiva/status.h" +#include "gandiva/tree_expr_builder.h" +#include "integ/test_util.h" + +namespace gandiva { + +using arrow::boolean; +using arrow::int32; +using arrow::utf8; + +class TestUtf8 : public ::testing::Test { + public: + void SetUp() { pool_ = arrow::default_memory_pool(); } + + protected: + arrow::MemoryPool* pool_; +}; + +TEST_F(TestUtf8, TestSimple) { + // schema for input fields + auto field_a = field("a", utf8()); + auto schema = arrow::schema({field_a}); + + // output fields + auto res_1 = field("res1", int32()); + auto res_2 = field("res2", boolean()); + + // build expressions. 
+ // octet_length(a) + // octet_length(a) == bit_length(a) / 8 + auto expr_a = TreeExprBuilder::MakeExpression("octet_length", {field_a}, res_1); + + auto node_a = TreeExprBuilder::MakeField(field_a); + auto octet_length = TreeExprBuilder::MakeFunction("octet_length", {node_a}, int32()); + auto literal_8 = TreeExprBuilder::MakeLiteral((int32_t)8); + auto bit_length = TreeExprBuilder::MakeFunction("bit_length", {node_a}, int32()); + auto div_8 = TreeExprBuilder::MakeFunction("divide", {bit_length, literal_8}, int32()); + auto is_equal = + TreeExprBuilder::MakeFunction("equal", {octet_length, div_8}, boolean()); + auto expr_b = TreeExprBuilder::MakeExpression(is_equal, res_2); + + // Build a projector for the expressions. + std::shared_ptr projector; + Status status = Projector::Make(schema, {expr_a, expr_b}, &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Create a row-batch with some sample data + int num_records = 4; + auto array_a = + MakeArrowArrayUtf8({"foo", "hello", "bye", "hi"}, {true, true, false, true}); + + // expected output + auto exp_1 = MakeArrowArrayInt32({3, 5, 0, 2}, {true, true, false, true}); + auto exp_2 = MakeArrowArrayBool({true, true, false, true}, {true, true, false, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp_1, outputs.at(0)); + EXPECT_ARROW_ARRAY_EQUALS(exp_2, outputs.at(1)); +} + +TEST_F(TestUtf8, TestLiteral) { + // schema for input fields + auto field_a = field("a", utf8()); + auto schema = arrow::schema({field_a}); + + // output fields + auto res = field("res", boolean()); + + // build expressions. 
+ // a == literal(s) + + auto node_a = TreeExprBuilder::MakeField(field_a); + auto literal_s = TreeExprBuilder::MakeStringLiteral("hello"); + auto is_equal = TreeExprBuilder::MakeFunction("equal", {node_a, literal_s}, boolean()); + auto expr = TreeExprBuilder::MakeExpression(is_equal, res); + + // Build a projector for the expressions. + std::shared_ptr projector; + Status status = Projector::Make(schema, {expr}, &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Create a row-batch with some sample data + int num_records = 4; + auto array_a = + MakeArrowArrayUtf8({"foo", "hello", "bye", "hi"}, {true, true, true, false}); + + // expected output + auto exp = MakeArrowArrayBool({false, true, false, false}, {true, true, true, false}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestUtf8, TestNullLiteral) { + // schema for input fields + auto field_a = field("a", utf8()); + auto schema = arrow::schema({field_a}); + + // output fields + auto res = field("res", boolean()); + + // build expressions. + // a == literal(null) + + auto node_a = TreeExprBuilder::MakeField(field_a); + auto literal_null = TreeExprBuilder::MakeNull(arrow::utf8()); + auto is_equal = + TreeExprBuilder::MakeFunction("equal", {node_a, literal_null}, boolean()); + auto expr = TreeExprBuilder::MakeExpression(is_equal, res); + + // Build a projector for the expressions. 
+ std::shared_ptr projector; + Status status = Projector::Make(schema, {expr}, &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Create a row-batch with some sample data + int num_records = 4; + auto array_a = + MakeArrowArrayUtf8({"foo", "hello", "bye", "hi"}, {true, true, true, false}); + + // expected output + auto exp = + MakeArrowArrayBool({false, false, false, false}, {false, false, false, false}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestUtf8, TestLike) { + // schema for input fields + auto field_a = field("a", utf8()); + auto schema = arrow::schema({field_a}); + + // output fields + auto res = field("res", boolean()); + + // build expressions. + // like(literal(s), a) + + auto node_a = TreeExprBuilder::MakeField(field_a); + auto literal_s = TreeExprBuilder::MakeStringLiteral("%spark%"); + auto is_like = TreeExprBuilder::MakeFunction("like", {node_a, literal_s}, boolean()); + auto expr = TreeExprBuilder::MakeExpression(is_like, res); + + // Build a projector for the expressions. 
+ std::shared_ptr projector; + Status status = Projector::Make(schema, {expr}, &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Create a row-batch with some sample data + int num_records = 4; + auto array_a = MakeArrowArrayUtf8({"park", "sparkle", "bright spark and fire", "spark"}, + {true, true, true, true}); + + // expected output + auto exp = MakeArrowArrayBool({false, true, true, true}, {true, true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()) << status.message(); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +} // namespace gandiva diff --git a/cpp/src/gandiva/jni/CMakeLists.txt b/cpp/src/gandiva/jni/CMakeLists.txt new file mode 100644 index 00000000000..e119b663947 --- /dev/null +++ b/cpp/src/gandiva/jni/CMakeLists.txt @@ -0,0 +1,61 @@ +# Copyright (C) 2017-2018 Dremio Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +project(gandiva_jni) + +# Find protobuf +set(Protobuf_USE_STATIC_LIBS "ON") +find_package(Protobuf REQUIRED) + +# Find JNI +find_package(JNI REQUIRED) + +# generate the protobuf files from the proto definition. +protobuf_generate_cpp(PROTO_SRCS PROTO_HDRS ${CMAKE_SOURCE_DIR}/../proto/Types.proto) + +# Create the jni header file (from the java class). 
+set(JNI_HEADERS_DIR "${CMAKE_CURRENT_BINARY_DIR}/java") +add_subdirectory(../../../java ./java) + +add_library(gandiva_jni SHARED + config_builder.cc + config_holder.cc + expression_registry_helper.cc + jni_common.cc + ${PROTO_SRCS} + ${PROTO_HDRS}) +add_dependencies(gandiva_jni gandiva_java) + +# For users of gandiva_jni library (including integ tests), include-dir is : +# /usr/**/include dir after install, +# cpp/include during build +# For building gandiva_jni library itself, include-dir (in addition to above) is : +# cpp/src +target_include_directories(gandiva_jni + PUBLIC + $ + $ + ${JNI_HEADERS_DIR} + PRIVATE + ${JNI_INCLUDE_DIRS} + ${CMAKE_CURRENT_BINARY_DIR} + ${CMAKE_SOURCE_DIR}/src +) + +# PROTOBUF is a private dependency i.e users of gandiva also will not have a dependency on protobuf. +target_link_libraries(gandiva_jni + PRIVATE + protobuf::libprotobuf + gandiva_static +) diff --git a/cpp/src/gandiva/jni/config_builder.cc b/cpp/src/gandiva/jni/config_builder.cc new file mode 100644 index 00000000000..5cc9c7ce1aa --- /dev/null +++ b/cpp/src/gandiva/jni/config_builder.cc @@ -0,0 +1,63 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#include + +#include "gandiva/configuration.h" +#include "jni/config_holder.h" +#include "jni/env_helper.h" +#include "jni/org_apache_arrow_gandiva_evaluator_ConfigurationBuilder.h" + +using gandiva::ConfigHolder; +using gandiva::Configuration; +using gandiva::ConfigurationBuilder; + +/* + * Class: org_apache_arrow_gandiva_evaluator_ConfigBuilder + * Method: buildConfigInstance + * Signature: ()J + */ +JNIEXPORT jlong JNICALL +Java_org_apache_arrow_gandiva_evaluator_ConfigurationBuilder_buildConfigInstance( + JNIEnv *env, jobject configuration) { + jstring byte_code_file_path = + (jstring)env->CallObjectMethod(configuration, byte_code_accessor_method_id_, 0); + jstring helper_library_file_path = (jstring)env->CallObjectMethod( + configuration, helper_library_accessor_method_id_, 0); + ConfigurationBuilder configuration_builder; + if (byte_code_file_path != nullptr) { + const char *byte_code_file_path_cpp = env->GetStringUTFChars(byte_code_file_path, 0); + configuration_builder.set_byte_code_file_path(byte_code_file_path_cpp); + env->ReleaseStringUTFChars(byte_code_file_path, byte_code_file_path_cpp); + } + if (helper_library_file_path != nullptr) { + const char *helper_library_file_path_cpp = + env->GetStringUTFChars(helper_library_file_path, 0); + configuration_builder.set_helper_lib_file_path(helper_library_file_path_cpp); + env->ReleaseStringUTFChars(helper_library_file_path, helper_library_file_path_cpp); + } + std::shared_ptr config = configuration_builder.build(); + env->DeleteLocalRef(byte_code_file_path); + return ConfigHolder::MapInsert(config); +} + +/* + * Class: org_apache_arrow_gandiva_evaluator_ConfigBuilder + * Method: releaseConfigInstance + * Signature: (J)V + */ +JNIEXPORT void JNICALL +Java_org_apache_arrow_gandiva_evaluator_ConfigurationBuilder_releaseConfigInstance( + JNIEnv *env, jobject configuration, jlong config_id) { + ConfigHolder::MapErase(config_id); +} diff --git a/cpp/src/gandiva/jni/config_holder.cc 
b/cpp/src/gandiva/jni/config_holder.cc new file mode 100644 index 00000000000..a0938003b00 --- /dev/null +++ b/cpp/src/gandiva/jni/config_holder.cc @@ -0,0 +1,26 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "jni/config_holder.h" + +#include + +namespace gandiva { +int64_t ConfigHolder::config_id_ = 1; + +// map of configuration objects created so far +std::unordered_map> + ConfigHolder::configuration_map_; + +std::mutex ConfigHolder::g_mtx_; +} // namespace gandiva diff --git a/cpp/src/gandiva/jni/config_holder.h b/cpp/src/gandiva/jni/config_holder.h new file mode 100644 index 00000000000..26fbc7711bb --- /dev/null +++ b/cpp/src/gandiva/jni/config_holder.h @@ -0,0 +1,66 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#ifndef CONFIG_HOLDER_H +#define CONFIG_HOLDER_H + +#include +#include +#include +#include + +#include "gandiva/configuration.h" + +namespace gandiva { + +class ConfigHolder { + public: + static int64_t MapInsert(std::shared_ptr config) { + g_mtx_.lock(); + + int64_t result = config_id_++; + configuration_map_.insert( + std::pair>(result, config)); + + g_mtx_.unlock(); + return result; + } + + static void MapErase(int64_t config_id_) { + g_mtx_.lock(); + configuration_map_.erase(config_id_); + g_mtx_.unlock(); + } + + static std::shared_ptr MapLookup(int64_t config_id_) { + std::shared_ptr result = nullptr; + + try { + result = configuration_map_.at(config_id_); + } catch (const std::out_of_range& e) { + } + + return result; + } + + private: + // map of configuration objects created so far + static std::unordered_map> configuration_map_; + + static std::mutex g_mtx_; + + // atomic counter for projector module ids + static int64_t config_id_; +}; +} // namespace gandiva +#endif // CONFIG_HOLDER_H diff --git a/cpp/src/gandiva/jni/env_helper.h b/cpp/src/gandiva/jni/env_helper.h new file mode 100644 index 00000000000..2fe96670284 --- /dev/null +++ b/cpp/src/gandiva/jni/env_helper.h @@ -0,0 +1,26 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#ifndef ENV_HELPER_H +#define ENV_HELPER_H + +#include + +// class references +extern jclass configuration_builder_class_; + +// method references +extern jmethodID byte_code_accessor_method_id_; +extern jmethodID helper_library_accessor_method_id_; + +#endif // ENV_HELPER_H diff --git a/cpp/src/gandiva/jni/expression_registry_helper.cc b/cpp/src/gandiva/jni/expression_registry_helper.cc new file mode 100644 index 00000000000..9270f87eec2 --- /dev/null +++ b/cpp/src/gandiva/jni/expression_registry_helper.cc @@ -0,0 +1,178 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "jni/org_apache_arrow_gandiva_evaluator_ExpressionRegistryJniHelper.h" + +#include + +#include "Types.pb.h" +#include "gandiva/arrow.h" +#include "gandiva/expression_registry.h" + +using gandiva::DataTypePtr; +using gandiva::ExpressionRegistry; + +types::TimeUnit MapTimeUnit(arrow::TimeUnit::type &unit) { + switch (unit) { + case arrow::TimeUnit::MILLI: + return types::TimeUnit::MILLISEC; + case arrow::TimeUnit::SECOND: + return types::TimeUnit::SEC; + case arrow::TimeUnit::MICRO: + return types::TimeUnit::MICROSEC; + case arrow::TimeUnit::NANO: + return types::TimeUnit::NANOSEC; + } + // satifsy gcc. should be unreachable. 
+ return types::TimeUnit::SEC; +} + +void ArrowToProtobuf(DataTypePtr type, types::ExtGandivaType *gandiva_data_type) { + switch (type->id()) { + case arrow::Type::type::BOOL: + gandiva_data_type->set_type(types::GandivaType::BOOL); + break; + case arrow::Type::type::UINT8: + gandiva_data_type->set_type(types::GandivaType::UINT8); + break; + case arrow::Type::type::INT8: + gandiva_data_type->set_type(types::GandivaType::INT8); + break; + case arrow::Type::type::UINT16: + gandiva_data_type->set_type(types::GandivaType::UINT16); + break; + case arrow::Type::type::INT16: + gandiva_data_type->set_type(types::GandivaType::INT16); + break; + case arrow::Type::type::UINT32: + gandiva_data_type->set_type(types::GandivaType::UINT32); + break; + case arrow::Type::type::INT32: + gandiva_data_type->set_type(types::GandivaType::INT32); + break; + case arrow::Type::type::UINT64: + gandiva_data_type->set_type(types::GandivaType::UINT64); + break; + case arrow::Type::type::INT64: + gandiva_data_type->set_type(types::GandivaType::INT64); + break; + case arrow::Type::type::HALF_FLOAT: + gandiva_data_type->set_type(types::GandivaType::HALF_FLOAT); + break; + case arrow::Type::type::FLOAT: + gandiva_data_type->set_type(types::GandivaType::FLOAT); + break; + case arrow::Type::type::DOUBLE: + gandiva_data_type->set_type(types::GandivaType::DOUBLE); + break; + case arrow::Type::type::STRING: + gandiva_data_type->set_type(types::GandivaType::UTF8); + break; + case arrow::Type::type::BINARY: + gandiva_data_type->set_type(types::GandivaType::BINARY); + break; + case arrow::Type::type::DATE32: + gandiva_data_type->set_type(types::GandivaType::DATE32); + break; + case arrow::Type::type::DATE64: + gandiva_data_type->set_type(types::GandivaType::DATE64); + break; + case arrow::Type::type::TIMESTAMP: { + gandiva_data_type->set_type(types::GandivaType::TIMESTAMP); + std::shared_ptr cast_time_stamp_type = + std::dynamic_pointer_cast(type); + arrow::TimeUnit::type unit = 
cast_time_stamp_type->unit(); + types::TimeUnit time_unit = MapTimeUnit(unit); + gandiva_data_type->set_timeunit(time_unit); + break; + } + case arrow::Type::type::TIME32: { + gandiva_data_type->set_type(types::GandivaType::TIME32); + std::shared_ptr cast_time_32_type = + std::dynamic_pointer_cast(type); + arrow::TimeUnit::type unit = cast_time_32_type->unit(); + types::TimeUnit time_unit = MapTimeUnit(unit); + gandiva_data_type->set_timeunit(time_unit); + break; + } + case arrow::Type::type::TIME64: { + gandiva_data_type->set_type(types::GandivaType::TIME64); + std::shared_ptr cast_time_64_type = + std::dynamic_pointer_cast(type); + arrow::TimeUnit::type unit = cast_time_64_type->unit(); + types::TimeUnit time_unit = MapTimeUnit(unit); + gandiva_data_type->set_timeunit(time_unit); + break; + } + case arrow::Type::type::NA: + gandiva_data_type->set_type(types::GandivaType::NONE); + break; + case arrow::Type::type::FIXED_SIZE_BINARY: + case arrow::Type::type::MAP: + case arrow::Type::type::INTERVAL: + case arrow::Type::type::DECIMAL: + case arrow::Type::type::LIST: + case arrow::Type::type::STRUCT: + case arrow::Type::type::UNION: + case arrow::Type::type::DICTIONARY: + // un-supported types. test ensures that + // when one of these are added build breaks.
+ DCHECK(false); + } +} + +JNIEXPORT jbyteArray JNICALL +Java_org_apache_arrow_gandiva_evaluator_ExpressionRegistryJniHelper_getGandivaSupportedDataTypes( + JNIEnv *env, jobject types_helper) { + types::GandivaDataTypes gandiva_data_types; + auto supported_types = ExpressionRegistry::supported_types(); + for (auto const &type : supported_types) { + types::ExtGandivaType *gandiva_data_type = gandiva_data_types.add_datatype(); + ArrowToProtobuf(type, gandiva_data_type); + } + size_t size = gandiva_data_types.ByteSizeLong(); + std::unique_ptr buffer{new jbyte[size]}; + gandiva_data_types.SerializeToArray((void *)buffer.get(), size); + jbyteArray ret = env->NewByteArray(size); + env->SetByteArrayRegion(ret, 0, size, buffer.get()); + return ret; +} + +/* + * Class: org_apache_arrow_gandiva_types_ExpressionRegistryJniHelper + * Method: getGandivaSupportedFunctions + * Signature: ()[B + */ +JNIEXPORT jbyteArray JNICALL +Java_org_apache_arrow_gandiva_evaluator_ExpressionRegistryJniHelper_getGandivaSupportedFunctions( + JNIEnv *env, jobject types_helper) { + ExpressionRegistry expr_registry; + types::GandivaFunctions gandiva_functions; + for (auto function = expr_registry.function_signature_begin(); + function != expr_registry.function_signature_end(); function++) { + types::FunctionSignature *function_signature = gandiva_functions.add_function(); + function_signature->set_name((*function).base_name()); + types::ExtGandivaType *return_type = function_signature->mutable_returntype(); + ArrowToProtobuf((*function).ret_type(), return_type); + for (auto ¶m_type : (*function).param_types()) { + types::ExtGandivaType *proto_param_type = function_signature->add_paramtypes(); + ArrowToProtobuf(param_type, proto_param_type); + } + } + size_t size = gandiva_functions.ByteSizeLong(); + std::unique_ptr buffer{new jbyte[size]}; + gandiva_functions.SerializeToArray((void *)buffer.get(), size); + jbyteArray ret = env->NewByteArray(size); + env->SetByteArrayRegion(ret, 0, size, 
buffer.get()); + return ret; +} diff --git a/cpp/src/gandiva/jni/id_to_module_map.h b/cpp/src/gandiva/jni/id_to_module_map.h new file mode 100644 index 00000000000..83da6083953 --- /dev/null +++ b/cpp/src/gandiva/jni/id_to_module_map.h @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + */ + +#ifndef JNI_ID_TO_MODULE_MAP_H +#define JNI_ID_TO_MODULE_MAP_H + +#include +#include + +namespace gandiva { + +template +class IdToModuleMap { + public: + IdToModuleMap() : module_id_(kInitModuleId) {} + + jlong Insert(HOLDER holder) { + mtx_.lock(); + jlong result = module_id_++; + map_.insert(std::pair(result, holder)); + mtx_.unlock(); + return result; + } + + void Erase(jlong module_id) { + mtx_.lock(); + map_.erase(module_id); + mtx_.unlock(); + } + + HOLDER Lookup(jlong module_id) { + // Returns the holder registered under module_id, or nullptr when the id + // is unknown. + // The map is always read under mtx_: an unlocked read would race with a + // concurrent Insert/Erase, which may rehash the underlying map. + HOLDER result = nullptr; + + mtx_.lock(); + try { + result = map_.at(module_id); + } catch (const std::out_of_range &e) { + // unknown id: leave result as nullptr + } + mtx_.unlock(); + + return result; + } + + private: + static const int kInitModuleId = 4; + + long module_id_; + std::mutex mtx_; + // map from module ids returned to Java and module pointers
std::unordered_map map_; +}; + +} // namespace gandiva + +#endif // JNI_ID_TO_MODULE_MAP_H diff --git a/cpp/src/gandiva/jni/jni_common.cc b/cpp/src/gandiva/jni/jni_common.cc new file mode 100644 index 00000000000..1abecf1b9b7 --- /dev/null +++ b/cpp/src/gandiva/jni/jni_common.cc @@ -0,0 +1,822 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "Types.pb.h" +#include "gandiva/configuration.h" +#include "gandiva/filter.h" +#include "gandiva/projector.h" +#include "gandiva/tree_expr_builder.h" +#include "jni/config_holder.h" +#include "jni/env_helper.h" +#include "jni/id_to_module_map.h" +#include "jni/module_holder.h" +#include "jni/org_apache_arrow_gandiva_evaluator_JniWrapper.h" + +using gandiva::ConditionPtr; +using gandiva::DataTypePtr; +using gandiva::ExpressionPtr; +using gandiva::ExpressionVector; +using gandiva::FieldPtr; +using gandiva::FieldVector; +using gandiva::Filter; +using gandiva::NodePtr; +using gandiva::NodeVector; +using gandiva::Projector; +using gandiva::SchemaPtr; +using gandiva::Status; +using gandiva::TreeExprBuilder; + +using gandiva::ArrayDataVector; +using gandiva::ConfigHolder; +using gandiva::Configuration; +using gandiva::ConfigurationBuilder; +using gandiva::FilterHolder; +using gandiva::ProjectorHolder; + +// forward declarations +NodePtr ProtoTypeToNode(const 
types::TreeNode &node); + +static jint JNI_VERSION = JNI_VERSION_1_6; + +// extern refs - initialized for other modules. +jclass configuration_builder_class_; +jmethodID byte_code_accessor_method_id_; +jmethodID helper_library_accessor_method_id_; + +// refs for self. +static jclass gandiva_exception_; + +// module maps +gandiva::IdToModuleMap> projector_modules_; +gandiva::IdToModuleMap> filter_modules_; + +jint JNI_OnLoad(JavaVM *vm, void *reserved) { + JNIEnv *env; + if (vm->GetEnv(reinterpret_cast(&env), JNI_VERSION) != JNI_OK) { + return JNI_ERR; + } + jclass local_configuration_builder_class_ = + env->FindClass("org/apache/arrow/gandiva/evaluator/ConfigurationBuilder"); + configuration_builder_class_ = + (jclass)env->NewGlobalRef(local_configuration_builder_class_); + env->DeleteLocalRef(local_configuration_builder_class_); + + jclass localExceptionClass = + env->FindClass("org/apache/arrow/gandiva/exceptions/GandivaException"); + gandiva_exception_ = (jclass)env->NewGlobalRef(localExceptionClass); + env->DeleteLocalRef(localExceptionClass); + + const char method_name[] = "getByteCodeFilePath"; + const char return_type[] = "()Ljava/lang/String;"; + byte_code_accessor_method_id_ = + env->GetMethodID(configuration_builder_class_, method_name, return_type); + + const char helper_method_name[] = "getHelperLibraryFilePath"; + helper_library_accessor_method_id_ = + env->GetMethodID(configuration_builder_class_, helper_method_name, return_type); + env->ExceptionDescribe(); + + return JNI_VERSION; +} + +void JNI_OnUnload(JavaVM *vm, void *reserved) { + JNIEnv *env; + vm->GetEnv(reinterpret_cast(&env), JNI_VERSION); + env->DeleteGlobalRef(configuration_builder_class_); + env->DeleteGlobalRef(gandiva_exception_); +} + +DataTypePtr ProtoTypeToTime32(const types::ExtGandivaType &ext_type) { + switch (ext_type.timeunit()) { + case types::SEC: + return arrow::time32(arrow::TimeUnit::SECOND); + case types::MILLISEC: + return arrow::time32(arrow::TimeUnit::MILLI); + default: 
+ std::cerr << "Unknown time unit: " << ext_type.timeunit() << " for time32\n"; + return nullptr; + } +} + +DataTypePtr ProtoTypeToTime64(const types::ExtGandivaType &ext_type) { + switch (ext_type.timeunit()) { + case types::MICROSEC: + return arrow::time64(arrow::TimeUnit::MICRO); + case types::NANOSEC: + return arrow::time64(arrow::TimeUnit::NANO); + default: + std::cerr << "Unknown time unit: " << ext_type.timeunit() << " for time64\n"; + return nullptr; + } +} + +DataTypePtr ProtoTypeToTimestamp(const types::ExtGandivaType &ext_type) { + switch (ext_type.timeunit()) { + case types::SEC: + return arrow::timestamp(arrow::TimeUnit::SECOND); + case types::MILLISEC: + return arrow::timestamp(arrow::TimeUnit::MILLI); + case types::MICROSEC: + return arrow::timestamp(arrow::TimeUnit::MICRO); + case types::NANOSEC: + return arrow::timestamp(arrow::TimeUnit::NANO); + default: + std::cerr << "Unknown time unit: " << ext_type.timeunit() << " for timestamp\n"; + return nullptr; + } +} + +DataTypePtr ProtoTypeToDataType(const types::ExtGandivaType &ext_type) { + switch (ext_type.type()) { + case types::NONE: + return arrow::null(); + case types::BOOL: + return arrow::boolean(); + case types::UINT8: + return arrow::uint8(); + case types::INT8: + return arrow::int8(); + case types::UINT16: + return arrow::uint16(); + case types::INT16: + return arrow::int16(); + case types::UINT32: + return arrow::uint32(); + case types::INT32: + return arrow::int32(); + case types::UINT64: + return arrow::uint64(); + case types::INT64: + return arrow::int64(); + case types::HALF_FLOAT: + return arrow::float16(); + case types::FLOAT: + return arrow::float32(); + case types::DOUBLE: + return arrow::float64(); + case types::UTF8: + return arrow::utf8(); + case types::BINARY: + return arrow::binary(); + case types::DATE32: + return arrow::date32(); + case types::DATE64: + return arrow::date64(); + case types::DECIMAL: + // TODO: error handling + return arrow::decimal(ext_type.precision(), 
ext_type.scale()); + case types::TIME32: + return ProtoTypeToTime32(ext_type); + case types::TIME64: + return ProtoTypeToTime64(ext_type); + case types::TIMESTAMP: + return ProtoTypeToTimestamp(ext_type); + + case types::FIXED_SIZE_BINARY: + case types::INTERVAL: + case types::LIST: + case types::STRUCT: + case types::UNION: + case types::DICTIONARY: + case types::MAP: + std::cerr << "Unhandled data type: " << ext_type.type() << "\n"; + return nullptr; + + default: + std::cerr << "Unknown data type: " << ext_type.type() << "\n"; + return nullptr; + } +} + +FieldPtr ProtoTypeToField(const types::Field &f) { + // An unsupported/unknown protobuf type yields a null DataTypePtr; report it as + // a null field so the nullptr checks in the callers actually fire. + DataTypePtr type = ProtoTypeToDataType(f.type()); + if (type == nullptr) { + return nullptr; + } + bool nullable = f.has_nullable() ? f.nullable() : true; + return field(f.name(), type, nullable); +} + +NodePtr ProtoTypeToFieldNode(const types::FieldNode &node) { + FieldPtr field_ptr = ProtoTypeToField(node.field()); + if (field_ptr == nullptr) { + std::cerr << "Unable to create field node from protobuf\n"; + return nullptr; + } + + return TreeExprBuilder::MakeField(field_ptr); +} + +NodePtr ProtoTypeToFnNode(const types::FunctionNode &node) { + const std::string &name = node.functionname(); + NodeVector children; + + for (int i = 0; i < node.inargs_size(); i++) { + const types::TreeNode &arg = node.inargs(i); + + NodePtr n = ProtoTypeToNode(arg); + if (n == nullptr) { + std::cerr << "Unable to create argument for function: " << name << "\n"; + return nullptr; + } + + children.push_back(n); + } + + DataTypePtr return_type = ProtoTypeToDataType(node.returntype()); + if (return_type == nullptr) { + std::cerr << "Unknown return type for function: " << name << "\n"; + return nullptr; + } + + return TreeExprBuilder::MakeFunction(name, children, return_type); +} + +NodePtr ProtoTypeToIfNode(const types::IfNode &node) { + NodePtr cond = ProtoTypeToNode(node.cond()); + if (cond == nullptr) { + std::cerr << "Unable to create cond node for if node\n"; +
return nullptr; + } + + NodePtr then_node = ProtoTypeToNode(node.thennode()); + if (then_node == nullptr) { + std::cerr << "Unable to create then node for if node\n"; + return nullptr; + } + + NodePtr else_node = ProtoTypeToNode(node.elsenode()); + if (else_node == nullptr) { + std::cerr << "Unable to create else node for if node\n"; + return nullptr; + } + + DataTypePtr return_type = ProtoTypeToDataType(node.returntype()); + if (return_type == nullptr) { + std::cerr << "Unknown return type for if node\n"; + return nullptr; + } + + return TreeExprBuilder::MakeIf(cond, then_node, else_node, return_type); +} + +NodePtr ProtoTypeToAndNode(const types::AndNode &node) { + NodeVector children; + + for (int i = 0; i < node.args_size(); i++) { + const types::TreeNode &arg = node.args(i); + + NodePtr n = ProtoTypeToNode(arg); + if (n == nullptr) { + std::cerr << "Unable to create argument for boolean and\n"; + return nullptr; + } + children.push_back(n); + } + return TreeExprBuilder::MakeAnd(children); +} + +NodePtr ProtoTypeToOrNode(const types::OrNode &node) { + NodeVector children; + + for (int i = 0; i < node.args_size(); i++) { + const types::TreeNode &arg = node.args(i); + + NodePtr n = ProtoTypeToNode(arg); + if (n == nullptr) { + std::cerr << "Unable to create argument for boolean or\n"; + return nullptr; + } + children.push_back(n); + } + return TreeExprBuilder::MakeOr(children); +} + +NodePtr ProtoTypeToNullNode(const types::NullNode &node) { + DataTypePtr data_type = ProtoTypeToDataType(node.type()); + if (data_type == nullptr) { + // data_type is null on this path, so log the raw protobuf type id instead + std::cerr << "Unknown type " << node.type().type() << " for null node\n"; + return nullptr; + } + return TreeExprBuilder::MakeNull(data_type); +} + +NodePtr ProtoTypeToNode(const types::TreeNode &node) { + if (node.has_fieldnode()) { + return ProtoTypeToFieldNode(node.fieldnode()); + } + + if (node.has_fnnode()) { + return ProtoTypeToFnNode(node.fnnode()); + } + + if (node.has_ifnode()) { + return ProtoTypeToIfNode(node.ifnode());
+ } + + if (node.has_andnode()) { + return ProtoTypeToAndNode(node.andnode()); + } + + if (node.has_ornode()) { + return ProtoTypeToOrNode(node.ornode()); + } + + if (node.has_nullnode()) { + return ProtoTypeToNullNode(node.nullnode()); + } + + if (node.has_intnode()) { + return TreeExprBuilder::MakeLiteral(node.intnode().value()); + } + + if (node.has_floatnode()) { + return TreeExprBuilder::MakeLiteral(node.floatnode().value()); + } + + if (node.has_longnode()) { + return TreeExprBuilder::MakeLiteral(node.longnode().value()); + } + + if (node.has_booleannode()) { + return TreeExprBuilder::MakeLiteral(node.booleannode().value()); + } + + if (node.has_doublenode()) { + return TreeExprBuilder::MakeLiteral(node.doublenode().value()); + } + + if (node.has_stringnode()) { + return TreeExprBuilder::MakeStringLiteral(node.stringnode().value()); + } + + if (node.has_binarynode()) { + return TreeExprBuilder::MakeBinaryLiteral(node.binarynode().value()); + } + + std::cerr << "Unknown node type in protobuf\n"; + return nullptr; +} + +ExpressionPtr ProtoTypeToExpression(const types::ExpressionRoot &root) { + NodePtr root_node = ProtoTypeToNode(root.root()); + if (root_node == nullptr) { + std::cerr << "Unable to create expression node from expression protobuf\n"; + return nullptr; + } + + FieldPtr field = ProtoTypeToField(root.resulttype()); + if (field == nullptr) { + std::cerr << "Unable to extra return field from expression protobuf\n"; + return nullptr; + } + + return TreeExprBuilder::MakeExpression(root_node, field); +} + +ConditionPtr ProtoTypeToCondition(const types::Condition &condition) { + NodePtr root_node = ProtoTypeToNode(condition.root()); + if (root_node == nullptr) { + return nullptr; + } + + return TreeExprBuilder::MakeCondition(root_node); +} + +SchemaPtr ProtoTypeToSchema(const types::Schema &schema) { + std::vector fields; + + for (int i = 0; i < schema.columns_size(); i++) { + FieldPtr field = ProtoTypeToField(schema.columns(i)); + if (field == nullptr) { 
+ std::cerr << "Unable to extract arrow field from schema\n"; + return nullptr; + } + + fields.push_back(field); + } + + return arrow::schema(fields); +} + +// Common for both projector and filters. + +bool ParseProtobuf(uint8_t *buf, int bufLen, google::protobuf::Message *msg) { + google::protobuf::io::CodedInputStream cis(buf, bufLen); + cis.SetRecursionLimit(1000); + return msg->ParseFromCodedStream(&cis); +} + +Status make_record_batch_with_buf_addrs(SchemaPtr schema, int num_rows, + jlong *in_buf_addrs, jlong *in_buf_sizes, + int in_bufs_len, + std::shared_ptr *batch) { + std::vector> columns; + auto num_fields = schema->num_fields(); + int buf_idx = 0; + int sz_idx = 0; + + for (int i = 0; i < num_fields; i++) { + auto field = schema->field(i); + std::vector> buffers; + + if (buf_idx >= in_bufs_len) { + return Status::Invalid("insufficient number of in_buf_addrs"); + } + jlong validity_addr = in_buf_addrs[buf_idx++]; + jlong validity_size = in_buf_sizes[sz_idx++]; + auto validity = std::shared_ptr( + new arrow::Buffer(reinterpret_cast(validity_addr), validity_size)); + buffers.push_back(validity); + + if (buf_idx >= in_bufs_len) { + return Status::Invalid("insufficient number of in_buf_addrs"); + } + jlong value_addr = in_buf_addrs[buf_idx++]; + jlong value_size = in_buf_sizes[sz_idx++]; + auto data = std::shared_ptr( + new arrow::Buffer(reinterpret_cast(value_addr), value_size)); + buffers.push_back(data); + + if (arrow::is_binary_like(field->type()->id())) { + if (buf_idx >= in_bufs_len) { + return Status::Invalid("insufficient number of in_buf_addrs"); + } + + // add offsets buffer for variable-len fields. 
+ jlong offsets_addr = in_buf_addrs[buf_idx++]; + jlong offsets_size = in_buf_sizes[sz_idx++]; + auto offsets = std::shared_ptr( + new arrow::Buffer(reinterpret_cast(offsets_addr), offsets_size)); + buffers.push_back(offsets); + } + + auto array_data = arrow::ArrayData::Make(field->type(), num_rows, std::move(buffers)); + columns.push_back(array_data); + } + *batch = arrow::RecordBatch::Make(schema, num_rows, columns); + return Status::OK(); +} + +// projector related functions. +void releaseProjectorInput(jbyteArray schema_arr, jbyte *schema_bytes, + jbyteArray exprs_arr, jbyte *exprs_bytes, JNIEnv *env) { + env->ReleaseByteArrayElements(schema_arr, schema_bytes, JNI_ABORT); + env->ReleaseByteArrayElements(exprs_arr, exprs_bytes, JNI_ABORT); +} + +JNIEXPORT jlong JNICALL Java_org_apache_arrow_gandiva_evaluator_JniWrapper_buildProjector( + JNIEnv *env, jobject obj, jbyteArray schema_arr, jbyteArray exprs_arr, + jlong configuration_id) { + jlong module_id = 0LL; + std::shared_ptr projector; + std::shared_ptr holder; + + types::Schema schema; + jsize schema_len = env->GetArrayLength(schema_arr); + jbyte *schema_bytes = env->GetByteArrayElements(schema_arr, 0); + + types::ExpressionList exprs; + jsize exprs_len = env->GetArrayLength(exprs_arr); + jbyte *exprs_bytes = env->GetByteArrayElements(exprs_arr, 0); + + ExpressionVector expr_vector; + SchemaPtr schema_ptr; + FieldVector ret_types; + gandiva::Status status; + + std::shared_ptr config = ConfigHolder::MapLookup(configuration_id); + std::stringstream ss; + + if (config == nullptr) { + ss << "configuration is mandatory."; + releaseProjectorInput(schema_arr, schema_bytes, exprs_arr, exprs_bytes, env); + goto err_out; + } + + if (!ParseProtobuf(reinterpret_cast(schema_bytes), schema_len, &schema)) { + ss << "Unable to parse schema protobuf\n"; + releaseProjectorInput(schema_arr, schema_bytes, exprs_arr, exprs_bytes, env); + goto err_out; + } + + if (!ParseProtobuf(reinterpret_cast(exprs_bytes), exprs_len, &exprs)) { 
+ releaseProjectorInput(schema_arr, schema_bytes, exprs_arr, exprs_bytes, env); + ss << "Unable to parse expressions protobuf\n"; + goto err_out; + } + + // convert types::Schema to arrow::Schema + schema_ptr = ProtoTypeToSchema(schema); + if (schema_ptr == nullptr) { + ss << "Unable to construct arrow schema object from schema protobuf\n"; + releaseProjectorInput(schema_arr, schema_bytes, exprs_arr, exprs_bytes, env); + goto err_out; + } + + // create Expression out of the list of exprs + for (int i = 0; i < exprs.exprs_size(); i++) { + ExpressionPtr root = ProtoTypeToExpression(exprs.exprs(i)); + + if (root == nullptr) { + ss << "Unable to construct expression object from expression protobuf\n"; + releaseProjectorInput(schema_arr, schema_bytes, exprs_arr, exprs_bytes, env); + goto err_out; + } + + expr_vector.push_back(root); + ret_types.push_back(root->result()); + } + + // good to invoke the evaluator now + status = Projector::Make(schema_ptr, expr_vector, config, &projector); + + if (!status.ok()) { + ss << "Failed to make LLVM module due to " << status.message() << "\n"; + releaseProjectorInput(schema_arr, schema_bytes, exprs_arr, exprs_bytes, env); + goto err_out; + } + + // store the result in a map + holder = std::shared_ptr( + new ProjectorHolder(schema_ptr, ret_types, std::move(projector))); + module_id = projector_modules_.Insert(holder); + releaseProjectorInput(schema_arr, schema_bytes, exprs_arr, exprs_bytes, env); + return module_id; + +err_out: + env->ThrowNew(gandiva_exception_, ss.str().c_str()); + return module_id; +} + +#define CHECK_OUT_BUFFER_IDX_AND_BREAK(idx, len) \ + if (idx >= len) { \ + status = gandiva::Status::Invalid("insufficient number of out_buf_addrs"); \ + break; \ + } + +JNIEXPORT void JNICALL +Java_org_apache_arrow_gandiva_evaluator_JniWrapper_evaluateProjector( + JNIEnv *env, jobject cls, jlong module_id, jint num_rows, jlongArray buf_addrs, + jlongArray buf_sizes, jlongArray out_buf_addrs, jlongArray out_buf_sizes) { + Status 
status; + std::shared_ptr holder = projector_modules_.Lookup(module_id); + if (holder == nullptr) { + std::stringstream ss; + ss << "Unknown module id " << module_id; + env->ThrowNew(gandiva_exception_, ss.str().c_str()); + return; + } + + int in_bufs_len = env->GetArrayLength(buf_addrs); + if (in_bufs_len != env->GetArrayLength(buf_sizes)) { + env->ThrowNew(gandiva_exception_, "mismatch in arraylen of buf_addrs and buf_sizes"); + return; + } + + int out_bufs_len = env->GetArrayLength(out_buf_addrs); + if (out_bufs_len != env->GetArrayLength(out_buf_sizes)) { + env->ThrowNew(gandiva_exception_, + "mismatch in arraylen of out_buf_addrs and out_buf_sizes"); + return; + } + + jlong *in_buf_addrs = env->GetLongArrayElements(buf_addrs, 0); + jlong *in_buf_sizes = env->GetLongArrayElements(buf_sizes, 0); + + jlong *out_bufs = env->GetLongArrayElements(out_buf_addrs, 0); + jlong *out_sizes = env->GetLongArrayElements(out_buf_sizes, 0); + + do { + std::shared_ptr in_batch; + status = make_record_batch_with_buf_addrs(holder->schema(), num_rows, in_buf_addrs, + in_buf_sizes, in_bufs_len, &in_batch); + if (!status.ok()) { + break; + } + + auto ret_types = holder->rettypes(); + ArrayDataVector output; + int buf_idx = 0; + int sz_idx = 0; + for (FieldPtr field : ret_types) { + CHECK_OUT_BUFFER_IDX_AND_BREAK(buf_idx, out_bufs_len); + uint8_t *validity_buf = reinterpret_cast(out_bufs[buf_idx++]); + jlong bitmap_sz = out_sizes[sz_idx++]; + std::shared_ptr bitmap_buf = + std::make_shared(validity_buf, bitmap_sz); + + CHECK_OUT_BUFFER_IDX_AND_BREAK(buf_idx, out_bufs_len); + uint8_t *value_buf = reinterpret_cast(out_bufs[buf_idx++]); + jlong data_sz = out_sizes[sz_idx++]; + std::shared_ptr data_buf = + std::make_shared(value_buf, data_sz); + + auto array_data = + arrow::ArrayData::Make(field->type(), num_rows, {bitmap_buf, data_buf}); + output.push_back(array_data); + } + if (!status.ok()) { + break; + } + + status = holder->projector()->Evaluate(*in_batch, output); + } while (0); + 
+  env->ReleaseLongArrayElements(buf_addrs, in_buf_addrs, JNI_ABORT);
+  env->ReleaseLongArrayElements(buf_sizes, in_buf_sizes, JNI_ABORT);
+  env->ReleaseLongArrayElements(out_buf_addrs, out_bufs, JNI_ABORT);
+  env->ReleaseLongArrayElements(out_buf_sizes, out_sizes, JNI_ABORT);
+
+  if (!status.ok()) {
+    std::stringstream ss;
+    ss << "Evaluate returned " << status.message() << "\n";
+    env->ThrowNew(gandiva_exception_, status.message().c_str());
+    return;
+  }
+}
+
+// Removes the entry registered under module_id from projector_modules_.
+JNIEXPORT void JNICALL Java_org_apache_arrow_gandiva_evaluator_JniWrapper_closeProjector(
+    JNIEnv *env, jobject cls, jlong module_id) {
+  projector_modules_.Erase(module_id);
+}
+
+// filter related functions.
+
+// Hands the schema and condition element buffers back to the JVM. JNI_ABORT
+// releases them without copying their contents back into the Java arrays.
+void releaseFilterInput(jbyteArray schema_arr, jbyte *schema_bytes,
+                        jbyteArray condition_arr, jbyte *condition_bytes, JNIEnv *env) {
+  env->ReleaseByteArrayElements(schema_arr, schema_bytes, JNI_ABORT);
+  env->ReleaseByteArrayElements(condition_arr, condition_bytes, JNI_ABORT);
+}
+
+// Builds a filter from a serialized schema and condition.
+//
+// The two byte arrays are parsed as protobufs, converted with
+// ProtoTypeToSchema / ProtoTypeToCondition and passed to Filter::Make together
+// with the configuration looked up from configuration_id. The resulting
+// FilterHolder is stored in filter_modules_ and the id returned by Insert() is
+// handed back to Java.
+//
+// On any failure a gandiva_exception_ is thrown and 0 is returned.
+JNIEXPORT jlong JNICALL Java_org_apache_arrow_gandiva_evaluator_JniWrapper_buildFilter(
+    JNIEnv *env, jobject obj, jbyteArray schema_arr, jbyteArray condition_arr,
+    jlong configuration_id) {
+  // Everything is declared up front because the error paths below use goto.
+  jlong module_id = 0LL;
+  std::shared_ptr filter;
+  std::shared_ptr holder;
+
+  types::Schema schema;
+  jsize schema_len = env->GetArrayLength(schema_arr);
+  jbyte *schema_bytes = env->GetByteArrayElements(schema_arr, 0);
+
+  types::Condition condition;
+  jsize condition_len = env->GetArrayLength(condition_arr);
+  jbyte *condition_bytes = env->GetByteArrayElements(condition_arr, 0);
+
+  ConditionPtr condition_ptr;
+  SchemaPtr schema_ptr;
+  gandiva::Status status;
+
+  std::shared_ptr config = ConfigHolder::MapLookup(configuration_id);
+  std::stringstream ss;
+
+  // GetByteArrayElements may return NULL; never parse or release a NULL buffer.
+  if (schema_bytes == nullptr || condition_bytes == nullptr) {
+    if (schema_bytes != nullptr) {
+      env->ReleaseByteArrayElements(schema_arr, schema_bytes, JNI_ABORT);
+    }
+    if (condition_bytes != nullptr) {
+      env->ReleaseByteArrayElements(condition_arr, condition_bytes, JNI_ABORT);
+    }
+    if (env->ExceptionCheck()) {
+      // An exception is already pending in the JVM; let that one propagate.
+      return module_id;
+    }
+    ss << "Unable to access schema or condition bytes\n";
+    goto err_out;
+  }
+
+  if (config == nullptr) {
+    ss << "configuration is mandatory.";
+    releaseFilterInput(schema_arr, schema_bytes, condition_arr, condition_bytes, env);
+    goto err_out;
+  }
+
+  if (!ParseProtobuf(reinterpret_cast(schema_bytes), schema_len, &schema)) {
+    ss << "Unable to parse schema protobuf\n";
+    releaseFilterInput(schema_arr, schema_bytes, condition_arr, condition_bytes, env);
+    goto err_out;
+  }
+
+  if (!ParseProtobuf(reinterpret_cast(condition_bytes), condition_len,
+                     &condition)) {
+    ss << "Unable to parse condition protobuf\n";
+    releaseFilterInput(schema_arr, schema_bytes, condition_arr, condition_bytes, env);
+    goto err_out;
+  }
+
+  // convert types::Schema to arrow::Schema
+  schema_ptr = ProtoTypeToSchema(schema);
+  if (schema_ptr == nullptr) {
+    ss << "Unable to construct arrow schema object from schema protobuf\n";
+    releaseFilterInput(schema_arr, schema_bytes, condition_arr, condition_bytes, env);
+    goto err_out;
+  }
+
+  condition_ptr = ProtoTypeToCondition(condition);
+  if (condition_ptr == nullptr) {
+    ss << "Unable to construct condition object from condition protobuf\n";
+    releaseFilterInput(schema_arr, schema_bytes, condition_arr, condition_bytes, env);
+    goto err_out;
+  }
+
+  // good to invoke the filter builder now
+  status = Filter::Make(schema_ptr, condition_ptr, config, &filter);
+  if (!status.ok()) {
+    ss << "Failed to make LLVM module due to " << status.message() << "\n";
+    releaseFilterInput(schema_arr, schema_bytes, condition_arr, condition_bytes, env);
+    goto err_out;
+  }
+
+  // store the result in a map
+  holder = std::shared_ptr(new FilterHolder(schema_ptr, std::move(filter)));
+  module_id = filter_modules_.Insert(holder);
+  releaseFilterInput(schema_arr, schema_bytes, condition_arr, condition_bytes, env);
+  return module_id;
+
+err_out:
+  // ss holds the reason collected by whichever step failed.
+  env->ThrowNew(gandiva_exception_, ss.str().c_str());
+  return module_id;
+}
+
+// Evaluates the filter registered under module_id against one batch.
+//
+// buf_addrs / buf_sizes describe the input buffers (both arrays must have the
+// same length); out_buf_addr / out_buf_size describe the caller-provided
+// buffer that backs the selection vector. jselection_vector_type selects the
+// 16-bit or 32-bit selection vector.
+//
+// Returns the selection vector's GetNumSlots() on success. On failure a
+// gandiva_exception_ is thrown and -1 is returned.
+JNIEXPORT jint JNICALL Java_org_apache_arrow_gandiva_evaluator_JniWrapper_evaluateFilter(
+    JNIEnv *env, jobject cls, jlong module_id, jint num_rows, jlongArray buf_addrs,
+    jlongArray buf_sizes, jint jselection_vector_type, jlong out_buf_addr,
+    jlong out_buf_size) {
+  gandiva::Status status;
+  std::shared_ptr holder = filter_modules_.Lookup(module_id);
+  if (holder == nullptr) {
+    env->ThrowNew(gandiva_exception_, "Unknown module id\n");
+    return -1;
+  }
+
+  int in_bufs_len = env->GetArrayLength(buf_addrs);
+  if (in_bufs_len != env->GetArrayLength(buf_sizes)) {
+    env->ThrowNew(gandiva_exception_, "mismatch in arraylen of buf_addrs and buf_sizes");
+    return -1;
+  }
+
+  jlong *in_buf_addrs = env->GetLongArrayElements(buf_addrs, 0);
+  jlong *in_buf_sizes = env->GetLongArrayElements(buf_sizes, 0);
+
+  // GetLongArrayElements may return NULL; do not hand NULL to the batch builder.
+  if (in_buf_addrs == nullptr || in_buf_sizes == nullptr) {
+    if (in_buf_addrs != nullptr) {
+      env->ReleaseLongArrayElements(buf_addrs, in_buf_addrs, JNI_ABORT);
+    }
+    if (in_buf_sizes != nullptr) {
+      env->ReleaseLongArrayElements(buf_sizes, in_buf_sizes, JNI_ABORT);
+    }
+    if (!env->ExceptionCheck()) {
+      env->ThrowNew(gandiva_exception_, "unable to access buf_addrs or buf_sizes");
+    }
+    return -1;
+  }
+
+  std::shared_ptr selection_vector;
+
+  // do/while(0) gives every failing step a single exit (break) so the array
+  // elements are always released below.
+  do {
+    std::shared_ptr in_batch;
+
+    status = make_record_batch_with_buf_addrs(holder->schema(), num_rows, in_buf_addrs,
+                                              in_buf_sizes, in_bufs_len, &in_batch);
+    if (!status.ok()) {
+      break;
+    }
+
+    auto selection_vector_type =
+        static_cast(jselection_vector_type);
+    auto out_buffer = std::make_shared(
+        reinterpret_cast(out_buf_addr), out_buf_size);
+    switch (selection_vector_type) {
+      case types::SV_INT16:
+        status =
+            gandiva::SelectionVector::MakeInt16(num_rows, out_buffer, &selection_vector);
+        break;
+      case types::SV_INT32:
+        status =
+            gandiva::SelectionVector::MakeInt32(num_rows, out_buffer, &selection_vector);
+        break;
+      default:
+        status = gandiva::Status::Invalid("unknown selection vector type");
+    }
+    if (!status.ok()) {
+      break;
+    }
+
+    status = holder->filter()->Evaluate(*in_batch, selection_vector);
+  } while (0);
+
+  env->ReleaseLongArrayElements(buf_addrs, in_buf_addrs, JNI_ABORT);
+  env->ReleaseLongArrayElements(buf_sizes, in_buf_sizes, JNI_ABORT);
+
+  if (!status.ok()) {
+    // The exception carries the bare status message, exactly as before; the
+    // string stream that used to be built here was never used.
+    env->ThrowNew(gandiva_exception_, status.message().c_str());
+    return -1;
+  }
+  return selection_vector->GetNumSlots();
+}
+
+// Removes the entry registered under module_id from filter_modules_.
+JNIEXPORT void JNICALL Java_org_apache_arrow_gandiva_evaluator_JniWrapper_closeFilter(
+    JNIEnv *env, jobject cls, jlong module_id) {
+  filter_modules_.Erase(module_id);
+}
diff --git a/cpp/src/gandiva/jni/module_holder.h b/cpp/src/gandiva/jni/module_holder.h
new file mode 100644
index
00000000000..e75cd332b94
--- /dev/null
+++ b/cpp/src/gandiva/jni/module_holder.h
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef JNI_MODULE_HOLDER_H
+#define JNI_MODULE_HOLDER_H
+
+#include
+#include
+
+#include "gandiva/arrow.h"
+
+namespace gandiva {
+
+class Projector;
+class Filter;
+
+// Bundles a projector with the schema and return-type fields it was built
+// with, so the three can be stored and looked up as a single entry.
+class ProjectorHolder {
+ public:
+  // All arguments are taken by value and moved into the members.
+  ProjectorHolder(SchemaPtr schema, FieldVector ret_types,
+                  std::shared_ptr projector)
+      : schema_(std::move(schema)),
+        ret_types_(std::move(ret_types)),
+        projector_(std::move(projector)) {}
+
+  // Accessors return copies of the stored members and do not modify the holder.
+  SchemaPtr schema() const { return schema_; }
+  FieldVector rettypes() const { return ret_types_; }
+  std::shared_ptr projector() const { return projector_; }
+
+ private:
+  SchemaPtr schema_;
+  FieldVector ret_types_;
+  std::shared_ptr projector_;
+};
+
+// Bundles a filter with the schema it was built with.
+class FilterHolder {
+ public:
+  // Both arguments are taken by value and moved into the members.
+  FilterHolder(SchemaPtr schema, std::shared_ptr filter)
+      : schema_(std::move(schema)), filter_(std::move(filter)) {}
+
+  // Accessors return copies of the stored members and do not modify the holder.
+  SchemaPtr schema() const { return schema_; }
+  std::shared_ptr filter() const { return filter_; }
+
+ private:
+  SchemaPtr schema_;
+  std::shared_ptr filter_;
+};
+
+}  // namespace gandiva
+
+#endif  // JNI_MODULE_HOLDER_H
diff --git a/cpp/src/gandiva/precompiled/CMakeLists.txt b/cpp/src/gandiva/precompiled/CMakeLists.txt
new file mode 100644
index
00000000000..19b25f196db
--- /dev/null
+++ b/cpp/src/gandiva/precompiled/CMakeLists.txt
@@ -0,0 +1,56 @@
+# Copyright (C) 2017-2018 Dremio Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+project(gandiva)
+
+set(PRECOMPILED_SRCS
+    arithmetic_ops.cc
+    bitmap.cc
+    hash.cc
+    print.cc
+    sample.cc
+    string_ops.cc
+    time.cc
+    timestamp_arithmetic.cc)
+
+# Header pulled in by the precompiled sources (e.g. arithmetic_ops.cc includes
+# "./types.h"). It is listed as a dependency of every bitcode rule so that
+# editing it rebuilds the bitcode instead of leaving stale .bc files behind.
+set(PRECOMPILED_TYPES_HEADER ${CMAKE_CURRENT_SOURCE_DIR}/types.h)
+
+# Start from an empty list so that a BC_FILES value inherited from an enclosing
+# scope cannot leak into the link step below.
+set(BC_FILES)
+
+# Create bitcode for each of the source files.
+foreach(SRC_FILE ${PRECOMPILED_SRCS})
+  get_filename_component(SRC_BASE ${SRC_FILE} NAME_WE)
+  get_filename_component(ABSOLUTE_SRC ${SRC_FILE} ABSOLUTE)
+  set(BC_FILE ${CMAKE_CURRENT_BINARY_DIR}/${SRC_BASE}.bc)
+  add_custom_command(
+    OUTPUT ${BC_FILE}
+    COMMAND ${CLANG_EXECUTABLE}
+            -std=c++11 -emit-llvm -O2 -c ${ABSOLUTE_SRC} -o ${BC_FILE}
+    DEPENDS ${ABSOLUTE_SRC} ${PRECOMPILED_TYPES_HEADER}
+    VERBATIM)
+  list(APPEND BC_FILES ${BC_FILE})
+endforeach()
+
+# link all of the bitcode files into a single bitcode file.
+add_custom_command(
+  OUTPUT ${GANDIVA_BC_OUTPUT_PATH}
+  COMMAND ${LINK_EXECUTABLE}
+          -o ${GANDIVA_BC_OUTPUT_PATH}
+          ${BC_FILES}
+  DEPENDS ${BC_FILES}
+  VERBATIM)
+
+# Built as part of the default target so the combined bitcode always exists.
+add_custom_target(precompiled ALL DEPENDS ${GANDIVA_BC_OUTPUT_PATH})
+
+# testing
+add_precompiled_unit_test(bitmap_test.cc bitmap.cc)
+add_precompiled_unit_test(time_test.cc time.cc timestamp_arithmetic.cc)
+add_precompiled_unit_test(hash_test.cc hash.cc)
+add_precompiled_unit_test(sample_test.cc sample.cc)
+add_precompiled_unit_test(string_ops_test.cc string_ops.cc)
+add_precompiled_unit_test(arithmetic_ops_test.cc arithmetic_ops.cc)
diff --git a/cpp/src/gandiva/precompiled/arithmetic_ops.cc b/cpp/src/gandiva/precompiled/arithmetic_ops.cc
new file mode 100644
index 00000000000..9d13498857d
--- /dev/null
+++ b/cpp/src/gandiva/precompiled/arithmetic_ops.cc
@@ -0,0 +1,152 @@
+// Copyright (C) 2017-2018 Dremio Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+extern "C" {
+
+#include "./types.h"
+
+// Expand inner macro for all numeric types.
+#define NUMERIC_TYPES(INNER, NAME, OP) \
+  INNER(NAME, int8, OP) \
+  INNER(NAME, int16, OP) \
+  INNER(NAME, int32, OP) \
+  INNER(NAME, int64, OP) \
+  INNER(NAME, uint8, OP) \
+  INNER(NAME, uint16, OP) \
+  INNER(NAME, uint32, OP) \
+  INNER(NAME, uint64, OP) \
+  INNER(NAME, float32, OP) \
+  INNER(NAME, float64, OP)
+
+// Expand inner macros for all date/time types.
+#define DATE_TYPES(INNER, NAME, OP) \
+  INNER(NAME, date64, OP) \
+  INNER(NAME, timestamp, OP) \
+  INNER(NAME, time32, OP)
+
+// Numeric types followed by the date/time types.
+#define NUMERIC_DATE_TYPES(INNER, NAME, OP) \
+  NUMERIC_TYPES(INNER, NAME, OP) \
+  DATE_TYPES(INNER, NAME, OP)
+
+// Numeric types, then the date/time types, then boolean.
+#define NUMERIC_BOOL_DATE_TYPES(INNER, NAME, OP) \
+  NUMERIC_DATE_TYPES(INNER, NAME, OP) \
+  INNER(NAME, boolean, OP)
+
+// Modulo that yields the left operand unchanged when the divisor is zero.
+#define MOD_OP(NAME, IN_TYPE1, IN_TYPE2, OUT_TYPE) \
+  FORCE_INLINE \
+  OUT_TYPE NAME##_##IN_TYPE1##_##IN_TYPE2(IN_TYPE1 left, IN_TYPE2 right) { \
+    if (right == 0) { \
+      return left; \
+    } \
+    return left % right; \
+  }
+
+// Symmetric binary fns : both operands and the result share one type.
+#define BINARY_SYMMETRIC(NAME, TYPE, OP) \
+  FORCE_INLINE \
+  TYPE NAME##_##TYPE##_##TYPE(TYPE left, TYPE right) { return left OP right; }
+
+NUMERIC_TYPES(BINARY_SYMMETRIC, add, +)
+NUMERIC_TYPES(BINARY_SYMMETRIC, subtract, -)
+NUMERIC_TYPES(BINARY_SYMMETRIC, multiply, *)
+NUMERIC_TYPES(BINARY_SYMMETRIC, divide, /)
+
+MOD_OP(mod, int64, int32, int32)
+MOD_OP(mod, int64, int64, int64)
+
+// Relational binary fns : both operands share one type, the result is bool.
+#define BINARY_RELATIONAL(NAME, TYPE, OP) \
+  FORCE_INLINE \
+  bool NAME##_##TYPE##_##TYPE(TYPE left, TYPE right) { return left OP right; }
+
+NUMERIC_BOOL_DATE_TYPES(BINARY_RELATIONAL, equal, ==)
+NUMERIC_BOOL_DATE_TYPES(BINARY_RELATIONAL, not_equal, !=)
+NUMERIC_DATE_TYPES(BINARY_RELATIONAL, less_than, <)
+NUMERIC_DATE_TYPES(BINARY_RELATIONAL, less_than_or_equal_to, <=)
+NUMERIC_DATE_TYPES(BINARY_RELATIONAL, greater_than, >)
+NUMERIC_DATE_TYPES(BINARY_RELATIONAL, greater_than_or_equal_to, >=)
+
+// cast fns : takes one param type, returns another type.
+#define CAST_UNARY(NAME, IN_TYPE, OUT_TYPE) \
+  FORCE_INLINE \
+  OUT_TYPE NAME##_##IN_TYPE(IN_TYPE in) { return static_cast<OUT_TYPE>(in); }
+
+CAST_UNARY(castBIGINT, int32, int64)
+CAST_UNARY(castFLOAT4, int32, float32)
+CAST_UNARY(castFLOAT4, int64, float32)
+CAST_UNARY(castFLOAT8, int32, float64)
+CAST_UNARY(castFLOAT8, int64, float64)
+CAST_UNARY(castFLOAT8, float32, float64)
+
+// Validity-only functions : the result depends on the validity flag alone,
+// the input value is ignored.
+#define VALIDITY_OP(NAME, TYPE, OP) \
+  FORCE_INLINE \
+  bool NAME##_##TYPE(TYPE in, boolean is_valid) { return OP is_valid; }
+
+NUMERIC_BOOL_DATE_TYPES(VALIDITY_OP, isnull, !)
+NUMERIC_BOOL_DATE_TYPES(VALIDITY_OP, isnotnull, +)
+NUMERIC_TYPES(VALIDITY_OP, isnumeric, +)
+
+// Single-argument variants of the type lists above.
+#define NUMERIC_FUNCTION(INNER) \
+  INNER(int8) \
+  INNER(int16) \
+  INNER(int32) \
+  INNER(int64) \
+  INNER(uint8) \
+  INNER(uint16) \
+  INNER(uint32) \
+  INNER(uint64) \
+  INNER(float32) \
+  INNER(float64)
+
+#define DATE_FUNCTION(INNER) \
+  INNER(date64) \
+  INNER(timestamp) \
+  INNER(time32)
+
+#define NUMERIC_BOOL_DATE_FUNCTION(INNER) \
+  NUMERIC_FUNCTION(INNER) \
+  DATE_FUNCTION(INNER) \
+  INNER(boolean)
+
+// is_distinct_from : true when exactly one side is null, or when both sides
+// are valid and the values differ. Two nulls are not distinct.
+#define IS_DISTINCT_FROM(TYPE) \
+  FORCE_INLINE \
+  bool is_distinct_from_##TYPE##_##TYPE(TYPE in1, boolean is_valid1, TYPE in2, \
+                                        boolean is_valid2) { \
+    return (is_valid1 != is_valid2) || (is_valid1 && in1 != in2); \
+  }
+
+// is_not_distinct_from : the exact negation of the above. Two nulls match,
+// two valid values match when they compare equal.
+#define IS_NOT_DISTINCT_FROM(TYPE) \
+  FORCE_INLINE \
+  bool is_not_distinct_from_##TYPE##_##TYPE(TYPE in1, boolean is_valid1, TYPE in2, \
+                                            boolean is_valid2) { \
+    return (is_valid1 == is_valid2) && (!is_valid1 || in1 == in2); \
+  }
+
+NUMERIC_BOOL_DATE_FUNCTION(IS_DISTINCT_FROM)
+NUMERIC_BOOL_DATE_FUNCTION(IS_NOT_DISTINCT_FROM)
+
+}  // extern "C"
diff --git a/cpp/src/gandiva/precompiled/arithmetic_ops_test.cc b/cpp/src/gandiva/precompiled/arithmetic_ops_test.cc
new
file mode 100644
index 00000000000..fc46464f4a2
--- /dev/null
+++ b/cpp/src/gandiva/precompiled/arithmetic_ops_test.cc
@@ -0,0 +1,38 @@
+// Copyright (C) 2017-2018 Dremio Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include "precompiled/types.h"
+
+namespace gandiva {
+
+TEST(TestArithmeticOps, TestIsDistinctFrom) {
+  const int64 ts = 1000;
+
+  // exactly one side null => distinct.
+  EXPECT_TRUE(is_distinct_from_timestamp_timestamp(ts, true, ts, false));
+  EXPECT_TRUE(is_distinct_from_timestamp_timestamp(ts, false, ts, true));
+  // both null, or both valid and equal => not distinct.
+  EXPECT_FALSE(is_distinct_from_timestamp_timestamp(ts, false, ts, false));
+  EXPECT_FALSE(is_distinct_from_timestamp_timestamp(ts, true, ts, true));
+
+  // the negated function gives the opposite answer for the same four cases.
+  EXPECT_FALSE(is_not_distinct_from_int32_int32(1000, true, 1000, false));
+  EXPECT_FALSE(is_not_distinct_from_int32_int32(1000, false, 1000, true));
+  EXPECT_TRUE(is_not_distinct_from_int32_int32(1000, false, 1000, false));
+  EXPECT_TRUE(is_not_distinct_from_int32_int32(1000, true, 1000, true));
+}
+
+TEST(TestArithmeticOps, TestMod) {
+  // a zero divisor returns the left operand.
+  EXPECT_EQ(mod_int64_int32(10, 0), 10);
+}
+
+}  // namespace gandiva
diff --git a/cpp/src/gandiva/precompiled/bitmap.cc b/cpp/src/gandiva/precompiled/bitmap.cc
new file mode 100644
index 00000000000..195b41983dc
--- /dev/null
+++ b/cpp/src/gandiva/precompiled/bitmap.cc
@@ -0,0 +1,55 @@
+// Copyright (C) 2017-2018 Dremio Corporation
+//
+// Licensed under the Apache License, Version
2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// BitMap functions
+
+extern "C" {
+
+#include "./types.h"
+
+// Macro parameters are parenthesised so that an expression argument such as
+// POS_TO_BYTE_INDEX(a + b) divides the whole sum, not just its last term.
+#define BITS_TO_BYTES(x) (((x) + 7) / 8)
+#define BITS_TO_WORDS(x) (((x) + 63) / 64)
+
+#define POS_TO_BYTE_INDEX(p) ((p) / 8)
+#define POS_TO_BIT_INDEX(p) ((p) % 8)
+
+// Returns true when the bit at 'position' is set. Bit i of the map lives in
+// byte i / 8 at bit offset i % 8, counted from the least significant bit.
+FORCE_INLINE
+bool bitMapGetBit(const unsigned char *bmap, int position) {
+  int byteIdx = POS_TO_BYTE_INDEX(position);
+  int bitIdx = POS_TO_BIT_INDEX(position);
+  return ((bmap[byteIdx] & (1 << bitIdx)) > 0);
+}
+
+// Sets the bit at 'position' when value is true, clears it otherwise.
+FORCE_INLINE
+void bitMapSetBit(unsigned char *bmap, int position, bool value) {
+  int byteIdx = POS_TO_BYTE_INDEX(position);
+  int bitIdx = POS_TO_BIT_INDEX(position);
+  if (value) {
+    bmap[byteIdx] |= (1 << bitIdx);
+  } else {
+    bmap[byteIdx] &= ~(1 << bitIdx);
+  }
+}
+
+// Clear the bit if value = false. Does nothing if value = true.
+FORCE_INLINE
+void bitMapClearBitIfFalse(unsigned char *bmap, int position, bool value) {
+  if (!value) {
+    int byteIdx = POS_TO_BYTE_INDEX(position);
+    int bitIdx = POS_TO_BIT_INDEX(position);
+    bmap[byteIdx] &= ~(1 << bitIdx);
+  }
+}
+
+}  // extern "C"
diff --git a/cpp/src/gandiva/precompiled/bitmap_test.cc b/cpp/src/gandiva/precompiled/bitmap_test.cc
new file mode 100644
index 00000000000..1f9b395122c
--- /dev/null
+++ b/cpp/src/gandiva/precompiled/bitmap_test.cc
@@ -0,0 +1,59 @@
+// Copyright (C) 2017-2018 Dremio Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include "precompiled/types.h"
+
+namespace gandiva {
+
+TEST(TestBitMap, TestSimple) {
+  static const int kNumBytes = 16;
+  uint8_t bit_map[kNumBytes];
+  memset(bit_map, 0, kNumBytes);
+
+  EXPECT_EQ(bitMapGetBit(bit_map, 100), false);
+
+  // set 100th bit and verify
+  bitMapSetBit(bit_map, 100, true);
+  EXPECT_EQ(bitMapGetBit(bit_map, 100), true);
+
+  // clear 100th bit and verify
+  bitMapSetBit(bit_map, 100, false);
+  EXPECT_EQ(bitMapGetBit(bit_map, 100), false);
+}
+
+TEST(TestBitMap, TestClearIfFalse) {
+  static const int kNumBytes = 32;
+  uint8_t bit_map[kNumBytes];
+  memset(bit_map, 0, kNumBytes);
+
+  bitMapSetBit(bit_map, 24, true);
+
+  // bit should remain unchanged.
+  bitMapClearBitIfFalse(bit_map, 24, true);
+  EXPECT_EQ(bitMapGetBit(bit_map, 24), true);
+
+  // bit should be cleared.
+  bitMapClearBitIfFalse(bit_map, 24, false);
+  EXPECT_EQ(bitMapGetBit(bit_map, 24), false);
+
+  // this function should have no impact if the bit is already clear.
+  bitMapClearBitIfFalse(bit_map, 24, true);
+  EXPECT_EQ(bitMapGetBit(bit_map, 24), false);
+
+  bitMapClearBitIfFalse(bit_map, 24, false);
+  EXPECT_EQ(bitMapGetBit(bit_map, 24), false);
+}
+
+}  // namespace gandiva
diff --git a/cpp/src/gandiva/precompiled/hash.cc b/cpp/src/gandiva/precompiled/hash.cc
new file mode 100644
index 00000000000..72d52d0b043
--- /dev/null
+++ b/cpp/src/gandiva/precompiled/hash.cc
@@ -0,0 +1,280 @@
+// Copyright (C) 2017-2018 Dremio Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+extern "C" {
+
+#include
+#include "./types.h"
+
+// Rotates the 64-bit value left by 'distance' bits; 'distance' is expected to
+// be in [0, 63]. The right-shift count is masked to 0..63 so that a distance
+// of 0 returns 'val' unchanged instead of shifting by 64, which is undefined.
+// Results for distances 1..63 are unchanged.
+static inline uint64 rotate_left(uint64 val, int distance) {
+  return (val << distance) | (val >> ((64 - distance) & 63));
+}
+
+//
+// MurmurHash3 was written by Austin Appleby, and is placed in the public
+// domain.
+// See http://smhasher.googlecode.com/svn/trunk/MurmurHash3.cpp
+// MurmurHash3_x64_128
+//
+
+// Final mixing step applied to each 64-bit half of the hash state.
+static inline uint64 fmix64(uint64 k) {
+  k ^= k >> 33;
+  k *= 0xff51afd7ed558ccduLL;
+  k ^= k >> 33;
+  k *= 0xc4ceb9fe1a85ec53uLL;
+  k ^= k >> 33;
+  return k;
+}
+
+// Hashes a single 64-bit value (treated as an 8-byte input) with the given
+// seed and returns the first 64 bits of the 128-bit result.
+static inline uint64 murmur3_64(uint64 val, int32 seed) {
+  uint64 h1 = seed;
+  uint64 h2 = seed;
+
+  uint64 c1 = 0x87c37b91114253d5ull;
+  uint64 c2 = 0x4cf5ad432745937full;
+
+  int length = 8;
+  uint64 k1 = 0;
+
+  k1 = val;
+  k1 *= c1;
+  k1 = rotate_left(k1, 31);
+  k1 *= c2;
+  h1 ^= k1;
+
+  h1 ^= length;
+  h2 ^= length;
+
+  h1 += h2;
+  h2 += h1;
+
+  h1 = fmix64(h1);
+  h2 = fmix64(h2);
+
+  h1 += h2;
+
+  // h2 += h1;
+  // murmur3_128 should return 128 bit (h1,h2), now we return only 64bits,
+  return h1;
+}
+
+// Returns the bit pattern of 'value' as an unsigned 64-bit integer.
+static inline uint64 double_to_long_bits(double value) {
+  uint64 result;
+  memcpy(&result, &value, sizeof(result));
+  return result;
+}
+
+// 64-bit hash of a double; only the low 32 bits of the seed are used.
+FORCE_INLINE int64 hash64(double val, int64 seed) {
+  return (int64)murmur3_64(double_to_long_bits(val), (int32)seed);
+}
+
+// 32-bit hash of a double: the 64-bit hash truncated to int32.
+FORCE_INLINE int32 hash32(double val, int32 seed) {
+  return (int32)murmur3_64(double_to_long_bits(val), seed);
+}
+
+// Wrappers for all the numeric/data/time arrow types
+// Every wrapper converts its input to double before hashing and returns 0
+// when the input (or, for the seeded variants, the seed) is null.
+
+#define HASH64_WITH_SEED_OP(NAME, TYPE) \
+  FORCE_INLINE \
+  int64 NAME##_##TYPE(TYPE in, boolean is_valid, int64 seed, boolean seed_isvalid) { \
+    return is_valid && seed_isvalid ? hash64((double)in, seed) : 0; \
+  }
+
+#define HASH32_WITH_SEED_OP(NAME, TYPE) \
+  FORCE_INLINE \
+  int32 NAME##_##TYPE(TYPE in, boolean is_valid, int32 seed, boolean seed_isvalid) { \
+    return is_valid && seed_isvalid ? hash32((double)in, seed) : 0; \
+  }
+
+#define HASH64_OP(NAME, TYPE) \
+  FORCE_INLINE \
+  int64 NAME##_##TYPE(TYPE in, boolean is_valid) { \
+    return is_valid ? hash64((double)in, 0) : 0; \
+  }
+
+#define HASH32_OP(NAME, TYPE) \
+  FORCE_INLINE \
+  int32 NAME##_##TYPE(TYPE in, boolean is_valid) { \
+    return is_valid ? hash32((double)in, 0) : 0; \
+  }
+
+// Expand inner macro for all numeric types.
+#define NUMERIC_BOOL_DATE_TYPES(INNER, NAME) \
+  INNER(NAME, int8) \
+  INNER(NAME, int16) \
+  INNER(NAME, int32) \
+  INNER(NAME, int64) \
+  INNER(NAME, uint8) \
+  INNER(NAME, uint16) \
+  INNER(NAME, uint32) \
+  INNER(NAME, uint64) \
+  INNER(NAME, float32) \
+  INNER(NAME, float64) \
+  INNER(NAME, boolean) \
+  INNER(NAME, date64) \
+  INNER(NAME, time32) \
+  INNER(NAME, timestamp)
+
+NUMERIC_BOOL_DATE_TYPES(HASH32_OP, hash)
+NUMERIC_BOOL_DATE_TYPES(HASH32_OP, hash32)
+NUMERIC_BOOL_DATE_TYPES(HASH32_OP, hash32AsDouble)
+NUMERIC_BOOL_DATE_TYPES(HASH32_WITH_SEED_OP, hash32WithSeed)
+NUMERIC_BOOL_DATE_TYPES(HASH32_WITH_SEED_OP, hash32AsDoubleWithSeed)
+
+NUMERIC_BOOL_DATE_TYPES(HASH64_OP, hash64)
+NUMERIC_BOOL_DATE_TYPES(HASH64_OP, hash64AsDouble)
+NUMERIC_BOOL_DATE_TYPES(HASH64_WITH_SEED_OP, hash64WithSeed)
+NUMERIC_BOOL_DATE_TYPES(HASH64_WITH_SEED_OP, hash64AsDoubleWithSeed)
+
+// Hashes 'len' bytes starting at 'key' with the given seed and returns the
+// first 64 bits of the 128-bit result.
+static inline uint64 murmur3_64_buf(const uint8 *key, int32 len, int32 seed) {
+  uint64 h1 = seed;
+  uint64 h2 = seed;
+  uint64 c1 = 0x87c37b91114253d5ull;
+  uint64 c2 = 0x4cf5ad432745937full;
+
+  int nblocks = len / 16;
+  for (int i = 0; i < nblocks; i++) {
+    // 'key' is an arbitrary byte pointer, so it may not be 8-byte aligned.
+    // Copy each word out with memcpy instead of dereferencing a casted
+    // uint64 pointer; the loaded values are the same.
+    uint64 k1;
+    uint64 k2;
+    memcpy(&k1, key + i * 16, sizeof(k1));
+    memcpy(&k2, key + i * 16 + 8, sizeof(k2));
+
+    k1 *= c1;
+    k1 = rotate_left(k1, 31);
+    k1 *= c2;
+    h1 ^= k1;
+    h1 = rotate_left(h1, 27);
+    h1 += h2;
+    h1 = h1 * 5 + 0x52dce729;
+    k2 *= c2;
+    k2 = rotate_left(k2, 33);
+    k2 *= c1;
+    h2 ^= k2;
+    h2 = rotate_left(h2, 31);
+    h2 += h1;
+    h2 = h2 * 5 + 0x38495ab5;
+  }
+
+  // tail
+  uint64 k1 = 0;
+  uint64 k2 = 0;
+
+  const uint8 *tail = (const uint8 *)(key + nblocks * 16);
+  // Every case deliberately falls through to the next one so that all
+  // remaining tail bytes are folded in.
+  switch (len & 15) {
+    case 15:
+      k2 = (uint64)(tail[14]) << 48;
+    case 14:
+      k2 ^= (uint64)(tail[13]) << 40;
+    case 13:
+      k2 ^= (uint64)(tail[12]) << 32;
+    case 12:
+      k2 ^= (uint64)(tail[11]) << 24;
+    case 11:
+      k2 ^= (uint64)(tail[10]) << 16;
+    case 10:
+      k2 ^= (uint64)(tail[9]) << 8;
+    case 9:
+      k2 ^= (uint64)(tail[8]);
+      k2 *= c2;
+      k2 = rotate_left(k2, 33);
+      k2 *= c1;
+      h2 ^= k2;
+    case 8:
+      k1 ^= (uint64)(tail[7]) << 56;
+    case 7:
+      k1 ^= (uint64)(tail[6]) << 48;
+    case 6:
+      k1 ^= (uint64)(tail[5]) << 40;
+    case 5:
+      k1 ^= (uint64)(tail[4]) << 32;
+    case 4:
+      k1 ^= (uint64)(tail[3]) << 24;
+    case 3:
+      k1 ^= (uint64)(tail[2]) << 16;
+    case 2:
+      k1 ^= (uint64)(tail[1]) << 8;
+    case 1:
+      k1 ^= (uint64)(tail[0]) << 0;
+      k1 *= c1;
+      k1 = rotate_left(k1, 31);
+      k1 *= c2;
+      h1 ^= k1;
+  }
+
+  h1 ^= len;
+  h2 ^= len;
+
+  h1 += h2;
+  h2 += h1;
+
+  h1 = fmix64(h1);
+  h2 = fmix64(h2);
+
+  h1 += h2;
+  // h2 += h1;
+  // returning 64-bits of the 128-bit hash.
+  return h1;
+}
+
+// 64-bit hash of a byte buffer; only the low 32 bits of the seed are used.
+FORCE_INLINE int64 hash64_buf(const uint8 *buf, int len, int64 seed) {
+  return (int64)murmur3_64_buf(buf, len, (int32)seed);
+}
+
+// 32-bit hash of a byte buffer: the 64-bit hash truncated to int32.
+FORCE_INLINE int32 hash32_buf(const uint8 *buf, int len, int32 seed) {
+  return (int32)murmur3_64_buf(buf, len, seed);
+}
+
+// Wrappers for the varlen types
+// Each wrapper returns 0 when the input (or, for the seeded variants, the
+// seed) is null.
+
+#define HASH64_BUF_WITH_SEED_OP(NAME, TYPE) \
+  FORCE_INLINE \
+  int64 NAME##_##TYPE(TYPE in, int32 len, boolean is_valid, int64 seed, \
+                      boolean seed_isvalid) { \
+    return is_valid && seed_isvalid ? hash64_buf((const uint8 *)in, len, seed) : 0; \
+  }
+
+#define HASH32_BUF_WITH_SEED_OP(NAME, TYPE) \
+  FORCE_INLINE \
+  int32 NAME##_##TYPE(TYPE in, int32 len, boolean is_valid, int32 seed, \
+                      boolean seed_isvalid) { \
+    return is_valid && seed_isvalid ? hash32_buf((const uint8 *)in, len, seed) : 0; \
+  }
+
+#define HASH64_BUF_OP(NAME, TYPE) \
+  FORCE_INLINE \
+  int64 NAME##_##TYPE(TYPE in, int32 len, boolean is_valid) { \
+    return is_valid ? hash64_buf((const uint8 *)in, len, 0) : 0; \
+  }
+
+#define HASH32_BUF_OP(NAME, TYPE) \
+  FORCE_INLINE \
+  int32 NAME##_##TYPE(TYPE in, int32 len, boolean is_valid) { \
+    return is_valid ? hash32_buf((const uint8 *)in, len, 0) : 0; \
+  }
+
+// Expand inner macro for all varlen types.
+#define VAR_LEN_TYPES(INNER, NAME) \
+  INNER(NAME, utf8) \
+  INNER(NAME, binary)
+
+VAR_LEN_TYPES(HASH32_BUF_OP, hash)
+VAR_LEN_TYPES(HASH32_BUF_OP, hash32)
+VAR_LEN_TYPES(HASH32_BUF_OP, hash32AsDouble)
+VAR_LEN_TYPES(HASH32_BUF_WITH_SEED_OP, hash32WithSeed)
+VAR_LEN_TYPES(HASH32_BUF_WITH_SEED_OP, hash32AsDoubleWithSeed)
+
+VAR_LEN_TYPES(HASH64_BUF_OP, hash64)
+VAR_LEN_TYPES(HASH64_BUF_OP, hash64AsDouble)
+VAR_LEN_TYPES(HASH64_BUF_WITH_SEED_OP, hash64WithSeed)
+VAR_LEN_TYPES(HASH64_BUF_WITH_SEED_OP, hash64AsDoubleWithSeed)
+
+}  // extern "C"
diff --git a/cpp/src/gandiva/precompiled/hash_test.cc b/cpp/src/gandiva/precompiled/hash_test.cc
new file mode 100644
index 00000000000..7e45b196949
--- /dev/null
+++ b/cpp/src/gandiva/precompiled/hash_test.cc
@@ -0,0 +1,119 @@
+// Copyright (C) 2017-2018 Dremio Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+
+#include
+#include "precompiled/types.h"
+
+namespace gandiva {
+
+TEST(TestHash, TestHash32) {
+  int8 s8 = 0;
+  uint8 u8 = 0;
+  int16 s16 = 0;
+  uint16 u16 = 0;
+  int32 s32 = 0;
+  uint32 u32 = 0;
+  int64 s64 = 0;
+  uint64 u64 = 0;
+  float32 f32 = 0;
+  float64 f64 = 0;
+
+  // hash of 0 should be non-zero (zero is the hash value for nulls).
+  int32 zero_hash = hash32(s8, 0);
+  EXPECT_NE(zero_hash, 0);
+
+  // for a given value, all numeric types must have the same hash.
+  EXPECT_EQ(hash32(u8, 0), zero_hash);
+  EXPECT_EQ(hash32(s16, 0), zero_hash);
+  EXPECT_EQ(hash32(u16, 0), zero_hash);
+  EXPECT_EQ(hash32(s32, 0), zero_hash);
+  EXPECT_EQ(hash32(u32, 0), zero_hash);
+  EXPECT_EQ(hash32(s64, 0), zero_hash);
+  EXPECT_EQ(hash32(u64, 0), zero_hash);
+  EXPECT_EQ(hash32(f32, 0), zero_hash);
+  EXPECT_EQ(hash32(f64, 0), zero_hash);
+
+  // hash must change with a change in seed.
+  EXPECT_NE(hash32(s8, 1), zero_hash);
+
+  // for a given value and seed, all numeric types must have the same hash.
+  EXPECT_EQ(hash32(s8, 1), hash32(s16, 1));
+  EXPECT_EQ(hash32(s8, 1), hash32(u32, 1));
+  EXPECT_EQ(hash32(s8, 1), hash32(f32, 1));
+  EXPECT_EQ(hash32(s8, 1), hash32(f64, 1));
+}
+
+TEST(TestHash, TestHash64) {
+  int8 s8 = 0;
+  uint8 u8 = 0;
+  int16 s16 = 0;
+  uint16 u16 = 0;
+  int32 s32 = 0;
+  uint32 u32 = 0;
+  int64 s64 = 0;
+  uint64 u64 = 0;
+  float32 f32 = 0;
+  float64 f64 = 0;
+
+  // hash of 0 should be non-zero (zero is the hash value for nulls).
+  int64 zero_hash = hash64(s8, 0);
+  EXPECT_NE(zero_hash, 0);
+  EXPECT_NE(hash64(u8, 0), hash32(u8, 0));
+
+  // for a given value, all numeric types must have the same hash.
+  EXPECT_EQ(hash64(u8, 0), zero_hash);
+  EXPECT_EQ(hash64(s16, 0), zero_hash);
+  EXPECT_EQ(hash64(u16, 0), zero_hash);
+  EXPECT_EQ(hash64(s32, 0), zero_hash);
+  EXPECT_EQ(hash64(u32, 0), zero_hash);
+  EXPECT_EQ(hash64(s64, 0), zero_hash);
+  EXPECT_EQ(hash64(u64, 0), zero_hash);
+  EXPECT_EQ(hash64(f32, 0), zero_hash);
+  EXPECT_EQ(hash64(f64, 0), zero_hash);
+
+  // hash must change with a change in seed.
+  EXPECT_NE(hash64(s8, 1), zero_hash);
+
+  // for a given value and seed, all numeric types must have the same hash.
+  EXPECT_EQ(hash64(s8, 1), hash64(s16, 1));
+  EXPECT_EQ(hash64(s8, 1), hash64(u32, 1));
+  EXPECT_EQ(hash64(s8, 1), hash64(f32, 1));
+}
+
+TEST(TestHash, TestHashBuf) {
+  const char *buf = "hello";
+  int buf_len = 5;
+
+  // hash should be non-zero (zero is the hash value for nulls).
+  EXPECT_NE(hash32_buf((const uint8 *)buf, buf_len, 0), 0);
+  EXPECT_NE(hash64_buf((const uint8 *)buf, buf_len, 0), 0);
+
+  // hash must change if the string is changed.
+  EXPECT_NE(hash32_buf((const uint8 *)buf, buf_len, 0),
+            hash32_buf((const uint8 *)buf, buf_len - 1, 0));
+
+  EXPECT_NE(hash64_buf((const uint8 *)buf, buf_len, 0),
+            hash64_buf((const uint8 *)buf, buf_len - 1, 0));
+
+  // hash must change if the seed is changed.
+  EXPECT_NE(hash32_buf((const uint8 *)buf, buf_len, 0),
+            hash32_buf((const uint8 *)buf, buf_len, 1));
+
+  EXPECT_NE(hash64_buf((const uint8 *)buf, buf_len, 0),
+            hash64_buf((const uint8 *)buf, buf_len, 1));
+}
+
+}  // namespace gandiva
diff --git a/cpp/src/gandiva/precompiled/print.cc b/cpp/src/gandiva/precompiled/print.cc
new file mode 100644
index 00000000000..f497b957dad
--- /dev/null
+++ b/cpp/src/gandiva/precompiled/print.cc
@@ -0,0 +1,25 @@
+// Copyright (C) 2017-2018 Dremio Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+extern "C" {
+
+#include
+
+#include "./types.h"
+
+// Prints 'val' using 'msg' as the printf format string and returns printf's
+// result.
+int print_double(char *msg, double val) { return printf(msg, val); }
+
+// Same as print_double, for a float argument.
+int print_float(char *msg, float val) { return printf(msg, val); }
+
+}  // extern "C"
diff --git a/cpp/src/gandiva/precompiled/sample.cc b/cpp/src/gandiva/precompiled/sample.cc
new file mode 100644
index 00000000000..3af2beff450
--- /dev/null
+++ b/cpp/src/gandiva/precompiled/sample.cc
@@ -0,0 +1,35 @@
+// Copyright (C) 2017-2018 Dremio Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+extern "C" {
+
+#include "./types.h"
+
+// Dummy function to test NULL_INTERNAL (most valid ones need varchar).
+
+// Halves an even, valid input. For a null or odd input the result is null:
+// *out_valid is set to false and 0 is returned.
+FORCE_INLINE
+int half_or_null_int32(int32 val, bool in_valid, bool *out_valid) {
+  // the output is valid only for a valid, even input.
+  const bool has_result = in_valid && (val % 2 == 0);
+  *out_valid = has_result;
+  return has_result ? val / 2 : 0;
+}
+
+}  // extern "C"
diff --git a/cpp/src/gandiva/precompiled/sample_test.cc b/cpp/src/gandiva/precompiled/sample_test.cc
new file mode 100644
index 00000000000..a2b2830f1bb
--- /dev/null
+++ b/cpp/src/gandiva/precompiled/sample_test.cc
@@ -0,0 +1,48 @@
+// Copyright (C) 2017-2018 Dremio Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+
+extern "C" int half_or_null_int32(int val, bool in_valid, bool *out_valid);
+
+namespace gandiva {
+
+TEST(TestSample, half_or_null) {
+  bool is_valid = false;
+
+  // 4 is a multiple of 2, so expect 2.
+  int ret = half_or_null_int32(4, true, &is_valid);
+  EXPECT_EQ(ret, 2);
+  EXPECT_TRUE(is_valid);
+
+  // a null input gives a null output, even for an even value.
+  half_or_null_int32(4, false, &is_valid);
+  EXPECT_FALSE(is_valid);
+
+  // -16 is a multiple of 2, so expect -8.
+  ret = half_or_null_int32(-16, true, &is_valid);
+  EXPECT_EQ(ret, -8);
+  EXPECT_TRUE(is_valid);
+
+  // 5 is odd, so expect null.
+  half_or_null_int32(5, true, &is_valid);
+  EXPECT_FALSE(is_valid);
+
+  // -31 is odd, so expect null.
+  half_or_null_int32(-31, true, &is_valid);
+  EXPECT_FALSE(is_valid);
+}
+
+}  // namespace gandiva
diff --git a/cpp/src/gandiva/precompiled/string_ops.cc b/cpp/src/gandiva/precompiled/string_ops.cc
new file mode 100644
index 00000000000..2df5881840f
--- /dev/null
+++ b/cpp/src/gandiva/precompiled/string_ops.cc
@@ -0,0 +1,69 @@
+// Copyright (C) 2017-2018 Dremio Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// String functions
+
+extern "C" {
+
+#include <string.h>
+#include "./types.h"
+
+// Length of a utf8 value in bytes; the data pointer is not inspected.
+FORCE_INLINE
+int32 octet_length_utf8(const utf8 input, int32 length) { return length; }
+
+// Length of a utf8 value in bits (8 * byte length).
+FORCE_INLINE
+int32 bit_length_utf8(const utf8 input, int32 length) { return length * 8; }
+
+// Length of a binary value in bytes; the data pointer is not inspected.
+FORCE_INLINE
+int32 octet_length_binary(const binary input, int32 length) { return length; }
+
+// Length of a binary value in bits (8 * byte length).
+FORCE_INLINE
+int32 bit_length_binary(const binary input, int32 length) { return length * 8; }
+
+// Three-way comparison of two byte ranges.
+//
+// The common prefix (the first min(left_len, right_len) bytes) is compared
+// with memcmp; if it differs, memcmp's result is returned. If the prefixes are
+// equal, the shorter range orders first and the result is left_len - right_len.
+//
+// Returns <0, 0 or >0 when left is less than, equal to or greater than right.
+FORCE_INLINE
+int32 mem_compare(const char *left, int32 left_len, const char *right, int32 right_len) {
+  // Only the bytes present in both inputs can be compared directly.
+  int min = left_len;
+  if (right_len < min) {
+    min = right_len;
+  }
+
+  int cmp_ret = memcmp(left, right, min);
+  if (cmp_ret != 0) {
+    return cmp_ret;
+  } else {
+    // Equal prefixes: fall back to comparing the lengths.
+    return left_len - right_len;
+  }
+}
+
+// Expand inner macro for all varlen types.
+#define VAR_LEN_TYPES(INNER, NAME, OP) \
+  INNER(NAME, utf8, OP)                \
+  INNER(NAME, binary, OP)
+
+// Relational binary fns : left, right params are same, return is bool.
+// Each expansion defines NAME_TYPE_TYPE(left, left_len, right, right_len),
+// which returns `mem_compare(left, left_len, right, right_len) OP 0`.
+#define BINARY_RELATIONAL(NAME, TYPE, OP)                                        \
+  FORCE_INLINE                                                                   \
+  bool NAME##_##TYPE##_##TYPE(const TYPE left, int32 left_len, const TYPE right, \
+                              int32 right_len) {                                 \
+    return mem_compare(left, left_len, right, right_len) OP 0;                   \
+  }
+
+// Instantiate the six comparison functions for every type listed in
+// VAR_LEN_TYPES (e.g. equal_utf8_utf8, less_than_binary_binary).
+VAR_LEN_TYPES(BINARY_RELATIONAL, equal, ==)
+VAR_LEN_TYPES(BINARY_RELATIONAL, not_equal, !=)
+VAR_LEN_TYPES(BINARY_RELATIONAL, less_than, <)
+VAR_LEN_TYPES(BINARY_RELATIONAL, less_than_or_equal_to, <=)
+VAR_LEN_TYPES(BINARY_RELATIONAL, greater_than, >)
+VAR_LEN_TYPES(BINARY_RELATIONAL, greater_than_or_equal_to, >=)
+
+}  // extern "C"
diff --git a/cpp/src/gandiva/precompiled/string_ops_test.cc b/cpp/src/gandiva/precompiled/string_ops_test.cc
new file mode 100644
index 00000000000..ddc66c90ff3
--- /dev/null
+++ b/cpp/src/gandiva/precompiled/string_ops_test.cc
@@ -0,0 +1,37 @@
+// Copyright (C) 2017-2018 Dremio Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gtest/gtest.h>
+#include "precompiled/types.h"
+
+namespace gandiva {
+
+// Exercises mem_compare on two strings that share the 4-byte prefix "abcd"
+// and differ from the 5th byte on ('7' > '1').
+TEST(TestStringOps, TestCompare) {
+  const char *left = "abcd789";
+  const char *right = "abcd123";
+
+  // 0 for equal
+  EXPECT_EQ(mem_compare(left, 4, right, 4), 0);
+
+  // compare lengths if the prefixes match
+  EXPECT_GT(mem_compare(left, 5, right, 4), 0);
+  EXPECT_LT(mem_compare(left, 4, right, 5), 0);
+
+  // compare bytes if the prefixes don't match
+  EXPECT_GT(mem_compare(left, 5, right, 5), 0);
+  EXPECT_GT(mem_compare(left, 5, right, 7), 0);
+  EXPECT_GT(mem_compare(left, 7, right, 5), 0);
+}
+
+}  // namespace gandiva
diff --git a/cpp/src/gandiva/precompiled/time.cc b/cpp/src/gandiva/precompiled/time.cc
new file mode 100644
index 00000000000..2816255fefb
--- /dev/null
+++ b/cpp/src/gandiva/precompiled/time.cc
@@ -0,0 +1,474 @@
+// Copyright (C) 2017-2018 Dremio Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+extern "C" {
+
+// NOTE(review): the names of these two system headers were missing.
+// <time.h> is required for struct tm / gmtime_r / timegm used below;
+// <stdlib.h> is presumed - confirm against the upstream file.
+#include <stdlib.h>
+#include <time.h>
+
+#include "./time_constants.h"
+#include "./types.h"
+
+#define MINS_IN_HOUR 60
+#define SECONDS_IN_MINUTE 60
+
+// Expand inner macro for all date types.
+#define DATE_TYPES(INNER) \
+  INNER(date64)           \
+  INNER(timestamp)
+
+// Epoch millis -> whole seconds, rounded toward negative infinity.
+// MILLIS_TO_SEC is a plain integer division, which rounds toward zero, so a
+// pre-epoch instant such as -1 ms (1969-12-31 23:59:59.999) would be mapped to
+// second 0 (1970-01-01 00:00:00) and every calendar field derived from it
+// would be wrong. For non-negative input both macros give the same result.
+#define MILLIS_TO_SEC_FLOOR(millis) \
+  (MILLIS_TO_SEC(millis) - ((((millis) % MILLIS_IN_SEC) < 0) ? 1 : 0))
+
+// All extract* functions below take epoch milliseconds, break them down in
+// UTC with gmtime_r and return one calendar field as int64.
+
+// Extract millennium (years 1-1000 are millennium 1, 1001-2000 are 2, ...)
+#define EXTRACT_MILLENNIUM(TYPE)                       \
+  FORCE_INLINE                                         \
+  int64 extractMillennium##_##TYPE(TYPE millis) {      \
+    time_t tsec = (time_t)MILLIS_TO_SEC_FLOOR(millis); \
+    struct tm tm;                                      \
+    gmtime_r(&tsec, &tm);                              \
+    return (1900 + tm.tm_year - 1) / 1000 + 1;         \
+  }
+
+DATE_TYPES(EXTRACT_MILLENNIUM)
+
+// Extract century (years 1901-2000 are century 20, 2001-2100 are 21, ...)
+#define EXTRACT_CENTURY(TYPE)                          \
+  FORCE_INLINE                                         \
+  int64 extractCentury##_##TYPE(TYPE millis) {         \
+    time_t tsec = (time_t)MILLIS_TO_SEC_FLOOR(millis); \
+    struct tm tm;                                      \
+    gmtime_r(&tsec, &tm);                              \
+    return (1900 + tm.tm_year - 1) / 100 + 1;          \
+  }
+
+DATE_TYPES(EXTRACT_CENTURY)
+
+// Extract decade (year / 10)
+#define EXTRACT_DECADE(TYPE)                           \
+  FORCE_INLINE                                         \
+  int64 extractDecade##_##TYPE(TYPE millis) {          \
+    time_t tsec = (time_t)MILLIS_TO_SEC_FLOOR(millis); \
+    struct tm tm;                                      \
+    gmtime_r(&tsec, &tm);                              \
+    return (1900 + tm.tm_year) / 10;                   \
+  }
+
+DATE_TYPES(EXTRACT_DECADE)
+
+// Extract year.
+#define EXTRACT_YEAR(TYPE)                             \
+  FORCE_INLINE                                         \
+  int64 extractYear##_##TYPE(TYPE millis) {            \
+    time_t tsec = (time_t)MILLIS_TO_SEC_FLOOR(millis); \
+    struct tm tm;                                      \
+    gmtime_r(&tsec, &tm);                              \
+    return 1900 + tm.tm_year;                          \
+  }
+
+DATE_TYPES(EXTRACT_YEAR)
+
+// Extract day of year (Jan 1 is 1).
+#define EXTRACT_DOY(TYPE)                              \
+  FORCE_INLINE                                         \
+  int64 extractDoy##_##TYPE(TYPE millis) {             \
+    time_t tsec = (time_t)MILLIS_TO_SEC_FLOOR(millis); \
+    struct tm tm;                                      \
+    gmtime_r(&tsec, &tm);                              \
+    return 1 + tm.tm_yday;                             \
+  }
+
+DATE_TYPES(EXTRACT_DOY)
+
+// Extract quarter (1-4).
+#define EXTRACT_QUARTER(TYPE)                          \
+  FORCE_INLINE                                         \
+  int64 extractQuarter##_##TYPE(TYPE millis) {         \
+    time_t tsec = (time_t)MILLIS_TO_SEC_FLOOR(millis); \
+    struct tm tm;                                      \
+    gmtime_r(&tsec, &tm);                              \
+    return tm.tm_mon / 3 + 1;                          \
+  }
+
+DATE_TYPES(EXTRACT_QUARTER)
+
+// Extract month (1-12).
+#define EXTRACT_MONTH(TYPE)                            \
+  FORCE_INLINE                                         \
+  int64 extractMonth##_##TYPE(TYPE millis) {           \
+    time_t tsec = (time_t)MILLIS_TO_SEC_FLOOR(millis); \
+    struct tm tm;                                      \
+    gmtime_r(&tsec, &tm);                              \
+    return 1 + tm.tm_mon;                              \
+  }
+
+DATE_TYPES(EXTRACT_MONTH)
+
+// Weekday (0 = Sunday ... 6 = Saturday) of Jan 1 in the year of *ptm,
+// derived from the weekday and day-of-year of *ptm itself.
+#define JAN1_WDAY(ptm) (((ptm)->tm_wday - ((ptm)->tm_yday % 7) + 7) % 7)
+
+// Gregorian leap-year rule: divisible by 4, except centuries not divisible
+// by 400.
+bool IsLeapYear(int yy) {
+  if ((yy % 4) != 0) {
+    // not divisible by 4
+    return false;
+  }
+
+  // yy = 4x
+  if ((yy % 400) == 0) {
+    // yy = 400x
+    return true;
+  }
+
+  // yy = 4x, return true if yy != 100x
+  return ((yy % 100) != 0);
+}
+
+// Day belongs to current year
+// Note that tm_yday is 0 for Jan 1 (subtract 1 from day in the below examples)
+//
+// If Jan 1 is Mon, (ptm->tm_yday) / 7 + 1 (Jan 1->WK1, Jan 8->WK2, etc)
+// If Jan 1 is Tues, (ptm->tm_yday + 1) / 7 + 1 (Jan 1->WK1, Jan 7->WK2, etc)
+// If Jan 1 is Wed, (ptm->tm_yday + 2) / 7 + 1
+// If Jan 1 is Thu, (ptm->tm_yday + 3) / 7 + 1
+//
+// If Jan 1 is Fri, Sat or Sun, the first few days belong to the previous year
+// If Jan 1 is Fri, (ptm->tm_yday - 3) / 7 + 1 (Jan 4->WK1, Jan 11->WK2)
+// If Jan 1 is Sat, (ptm->tm_yday - 2) / 7 + 1 (Jan 3->WK1, Jan 10->WK2)
+// If Jan 1 is Sun, (ptm->tm_yday - 1) / 7 + 1 (Jan 2->WK1, Jan 9->WK2)
+int weekOfCurrentYear(struct tm *ptm) {
+  int jan1_wday = JAN1_WDAY(ptm);
+  switch (jan1_wday) {
+    // Monday
+    case 1:
+    // Tuesday
+    case 2:
+    // Wednesday
+    case 3:
+    // Thursday
+    case 4: {
+      return (ptm->tm_yday + jan1_wday - 1) / 7 + 1;
+    }
+    // Friday
+    case 5:
+    // Saturday
+    case 6: {
+      return (ptm->tm_yday - (8 - jan1_wday)) / 7 + 1;
+    }
+    // Sunday
+    case 0: {
+      return (ptm->tm_yday - 1) / 7 + 1;
+    }
+  }
+
+  // cannot reach here
+  // keep compiler happy
+  return 0;
+}
+
+// Jan 1-3
+// If Jan 1 is one of Mon, Tue, Wed, Thu - belongs to week of current year
+// If Jan 1 is Fri/Sat/Sun - belongs to previous year
+int getJanWeekOfYear(struct tm *ptm) {
+  int jan1_wday = JAN1_WDAY(ptm);
+
+  if ((jan1_wday >= 1) && (jan1_wday <= 4)) {
+    // Jan 1-3 with the week belonging to this year
+    return 1;
+  }
+
+  if (jan1_wday == 5) {
+    // Jan 1 is a Fri
+    // Jan 1-3 belong to previous year. Dec 31 of previous year same week # as Jan 1-3
+    // previous year is a leap year:
+    // Prev Jan 1 is a Wed. Jan 6th is Mon
+    // Dec 31 - Jan 6 = 366 - 5 = 361
+    // week from Jan 6 = (361 - 1) / 7 + 1 = 52
+    // week # in previous year = 52 + 1 = 53
+    //
+    // previous year is not a leap year. Jan 1 is Thu. Jan 5th is Mon
+    // Dec 31 - Jan 5 = 365 - 4 = 361
+    // week from Jan 5 = (361 - 1) / 7 + 1 = 52
+    // week # in previous year = 52 + 1 = 53
+    return 53;
+  }
+
+  if (jan1_wday == 0) {
+    // Jan 1 is a Sun
+    if (ptm->tm_mday > 1) {
+      // Jan 2 and 3 belong to current year
+      return 1;
+    }
+
+    // day belongs to previous year. Same as Dec 31
+    // Same as the case where Jan 1 is a Fri, except that previous year
+    // does not have an extra week
+    // Hence, return 52
+    return 52;
+  }
+
+  // Jan 1 is a Sat
+  // Jan 1-2 belong to previous year
+  if (ptm->tm_mday == 3) {
+    // Jan 3, return 1
+    return 1;
+  }
+
+  // prev Jan 1 is leap year
+  // prev Jan 1 is a Thu
+  // return 53 (extra week)
+  if (IsLeapYear(1900 + ptm->tm_year - 1)) {
+    return 53;
+  }
+
+  // prev Jan 1 is not a leap year
+  // prev Jan 1 is a Fri
+  // return 52 (no extra week)
+  return 52;
+}
+
+// Dec 29-31
+int getDecWeekOfYear(struct tm *ptm) {
+  // Weekday of the following Jan 1: advance from the current weekday by the
+  // number of days left in December plus one.
+  int next_jan1_wday = (ptm->tm_wday + (31 - ptm->tm_mday) + 1) % 7;
+
+  if (next_jan1_wday == 4) {
+    // next Jan 1 is a Thu
+    // day belongs to week 1 of next year
+    return 1;
+  }
+
+  if (next_jan1_wday == 3) {
+    // next Jan 1 is a Wed
+    // Dec 31 and 30 belong to next year - return 1
+    if (ptm->tm_mday != 29) {
+      return 1;
+    }
+
+    // Dec 29 belongs to current year
+    return weekOfCurrentYear(ptm);
+  }
+
+  if (next_jan1_wday == 2) {
+    // next Jan 1 is a Tue
+    // Dec 31 belongs to next year - return 1
+    if (ptm->tm_mday == 31) {
+      return 1;
+    }
+
+    // Dec 29 and 30 belong to current year
+    return weekOfCurrentYear(ptm);
+  }
+
+  // next Jan 1 is a Fri/Sat/Sun. No day from this year belongs to that week
+  // next Jan 1 is a Mon. No day from this year belongs to that week
+  return weekOfCurrentYear(ptm);
+}
+
+// Week of year is determined by ISO 8601 standard
+// Take a look at: https://en.wikipedia.org/wiki/ISO_week_date
+//
+// Important points to note:
+// Week starts with a Monday and ends with a Sunday
+// A week can have some days in this year and some days in the previous/next year
+// This is true for the first and last weeks
+//
+// The first week of the year should have at-least 4 days in the current year
+// The last week of the year should have at-least 4 days in the current year
+//
+// A given day might belong to the first week of the next year - e.g Dec 29, 30 and 31
+// A given day might belong to the last week of the previous year - e.g. Jan 1, 2 and 3
+//
+// Algorithm:
+// If day belongs to week in current year, weekOfCurrentYear
+//
+// If day is Jan 1-3, see getJanWeekOfYear
+// If day is Dec 29-31, see getDecWeekOfYear
+//
+int64 weekOfYear(struct tm *ptm) {
+  if (ptm->tm_yday < 3) {
+    // Jan 1-3
+    return getJanWeekOfYear(ptm);
+  }
+
+  if ((ptm->tm_mon == 11) && (ptm->tm_mday >= 29)) {
+    // Dec 29-31
+    return getDecWeekOfYear(ptm);
+  }
+
+  return weekOfCurrentYear(ptm);
+}
+
+// Extract ISO 8601 week of year (see weekOfYear).
+#define EXTRACT_WEEK(TYPE)                             \
+  FORCE_INLINE                                         \
+  int64 extractWeek##_##TYPE(TYPE millis) {            \
+    time_t tsec = (time_t)MILLIS_TO_SEC_FLOOR(millis); \
+    struct tm tm;                                      \
+    gmtime_r(&tsec, &tm);                              \
+    return weekOfYear(&tm);                            \
+  }
+
+DATE_TYPES(EXTRACT_WEEK)
+
+// Extract day of week (1 = Sunday ... 7 = Saturday).
+#define EXTRACT_DOW(TYPE)                              \
+  FORCE_INLINE                                         \
+  int64 extractDow##_##TYPE(TYPE millis) {             \
+    time_t tsec = (time_t)MILLIS_TO_SEC_FLOOR(millis); \
+    struct tm tm;                                      \
+    gmtime_r(&tsec, &tm);                              \
+    return 1 + tm.tm_wday;                             \
+  }
+
+DATE_TYPES(EXTRACT_DOW)
+
+// Extract day of month (1-31).
+#define EXTRACT_DAY(TYPE)                              \
+  FORCE_INLINE                                         \
+  int64 extractDay##_##TYPE(TYPE millis) {             \
+    time_t tsec = (time_t)MILLIS_TO_SEC_FLOOR(millis); \
+    struct tm tm;                                      \
+    gmtime_r(&tsec, &tm);                              \
+    return tm.tm_mday;                                 \
+  }
+
+DATE_TYPES(EXTRACT_DAY)
+
+// Extract hour of day (0-23).
+#define EXTRACT_HOUR(TYPE)                             \
+  FORCE_INLINE                                         \
+  int64 extractHour##_##TYPE(TYPE millis) {            \
+    time_t tsec = (time_t)MILLIS_TO_SEC_FLOOR(millis); \
+    struct tm tm;                                      \
+    gmtime_r(&tsec, &tm);                              \
+    return tm.tm_hour;                                 \
+  }
+
+DATE_TYPES(EXTRACT_HOUR)
+
+// Extract minute of hour (0-59).
+#define EXTRACT_MINUTE(TYPE)                           \
+  FORCE_INLINE                                         \
+  int64 extractMinute##_##TYPE(TYPE millis) {          \
+    time_t tsec = (time_t)MILLIS_TO_SEC_FLOOR(millis); \
+    struct tm tm;                                      \
+    gmtime_r(&tsec, &tm);                              \
+    return tm.tm_min;                                  \
+  }
+
+DATE_TYPES(EXTRACT_MINUTE)
+
+// Extract second of minute.
+#define EXTRACT_SECOND(TYPE)                           \
+  FORCE_INLINE                                         \
+  int64 extractSecond##_##TYPE(TYPE millis) {          \
+    time_t tsec = (time_t)MILLIS_TO_SEC_FLOOR(millis); \
+    struct tm tm;                                      \
+    gmtime_r(&tsec, &tm);                              \
+    return tm.tm_sec;                                  \
+  }
+
+DATE_TYPES(EXTRACT_SECOND)
+
+// Extract whole seconds since the epoch (plain integer division of millis).
+#define EXTRACT_EPOCH(TYPE) \
+  FORCE_INLINE              \
+  int64 extractEpoch##_##TYPE(TYPE millis) { return MILLIS_TO_SEC(millis); }
+
+DATE_TYPES(EXTRACT_EPOCH)
+
+// Functions that work on millis in a day
+#define EXTRACT_SECOND_TIME(TYPE)                   \
+  FORCE_INLINE                                      \
+  int64 extractSecond##_##TYPE(TYPE millis) {       \
+    int64 seconds_of_day = MILLIS_TO_SEC(millis);   \
+    int64 sec = seconds_of_day % SECONDS_IN_MINUTE; \
+    return sec;                                     \
+  }
+
+EXTRACT_SECOND_TIME(time32)
+
+#define EXTRACT_MINUTE_TIME(TYPE)             \
+  FORCE_INLINE                                \
+  int64 extractMinute##_##TYPE(TYPE millis) { \
+    TYPE mins = MILLIS_TO_MINS(millis);       \
+    return (mins % (MINS_IN_HOUR));           \
+  }
+
+EXTRACT_MINUTE_TIME(time32)
+
+#define EXTRACT_HOUR_TIME(TYPE) \
+  FORCE_INLINE                  \
+  int64 extractHour##_##TYPE(TYPE millis) { return MILLIS_TO_HOUR(millis); }
+
+EXTRACT_HOUR_TIME(time32)
+
+// Truncate epoch millis down to a multiple of a fixed-length unit.
+// The remainder is normalised to be non-negative so that pre-epoch values are
+// truncated toward the past (e.g. -1 ms truncated to a day gives the start of
+// 1969-12-31, not 1970-01-01). Non-negative input gives (millis / N) * N.
+#define DATE_TRUNC_FIXED_UNIT(NAME, TYPE, NMILLIS_IN_UNIT) \
+  FORCE_INLINE                                             \
+  TYPE NAME##_##TYPE(TYPE millis) {                        \
+    TYPE rem = millis % (NMILLIS_IN_UNIT);                 \
+    if (rem < 0) {                                         \
+      /* C remainder takes the sign of the dividend */     \
+      rem += (NMILLIS_IN_UNIT);                            \
+    }                                                      \
+    return millis - rem;                                   \
+  }
+
+// Truncate to 00:00:00 of the Monday that starts the week containing millis.
+// tm_mday may go below 1 here; timegm normalises it into the previous month.
+#define DATE_TRUNC_WEEK(TYPE)                          \
+  FORCE_INLINE                                         \
+  TYPE date_trunc_Week_##TYPE(TYPE millis) {           \
+    time_t tsec = (time_t)MILLIS_TO_SEC_FLOOR(millis); \
+    struct tm tm;                                      \
+    gmtime_r(&tsec, &tm);                              \
+    tm.tm_sec = 0;                                     \
+    tm.tm_min = 0;                                     \
+    tm.tm_hour = 0;                                    \
+    if (tm.tm_wday == 0) {                             \
+      /* Sunday */                                     \
+      tm.tm_mday -= 6;                                 \
+    } else {                                           \
+      /* All other days */                             \
+      tm.tm_mday -= (tm.tm_wday - 1);                  \
+    }                                                  \
+    return (TYPE)timegm(&tm) * MILLIS_IN_SEC;          \
+  }
+
+// Truncate to 00:00:00 on the first day of the enclosing block of
+// NMONTHS_IN_UNIT months (1 = month, 3 = quarter, 12 = year).
+#define DATE_TRUNC_MONTH_UNITS(NAME, TYPE, NMONTHS_IN_UNIT)      \
+  FORCE_INLINE                                                   \
+  TYPE NAME##_##TYPE(TYPE millis) {                              \
+    time_t tsec = (time_t)MILLIS_TO_SEC_FLOOR(millis);           \
+    struct tm tm;                                                \
+    gmtime_r(&tsec, &tm);                                        \
+    tm.tm_sec = 0;                                               \
+    tm.tm_min = 0;                                               \
+    tm.tm_hour = 0;                                              \
+    tm.tm_mday = 1;                                              \
+    tm.tm_mon = (tm.tm_mon / NMONTHS_IN_UNIT) * NMONTHS_IN_UNIT; \
+    return (TYPE)timegm(&tm) * MILLIS_IN_SEC;                    \
+  }
+
+// Truncate to Jan 1 00:00:00 of the first year of the enclosing block of
+// NYEARS_IN_UNIT years. OFF_BY shifts the block start: 0 for decades
+// (2010, 2020, ...), 1 for centuries and millennia (2001, 2101, ...).
+#define DATE_TRUNC_YEAR_UNITS(NAME, TYPE, NYEARS_IN_UNIT, OFF_BY)        \
+  FORCE_INLINE                                                           \
+  TYPE NAME##_##TYPE(TYPE millis) {                                      \
+    time_t tsec = (time_t)MILLIS_TO_SEC_FLOOR(millis);                   \
+    struct tm tm;                                                        \
+    gmtime_r(&tsec, &tm);                                                \
+    tm.tm_sec = 0;                                                       \
+    tm.tm_min = 0;                                                       \
+    tm.tm_hour = 0;                                                      \
+    tm.tm_mday = 1;                                                      \
+    tm.tm_mon = 0;                                                       \
+    int year = 1900 + tm.tm_year;                                        \
+    year = ((year - OFF_BY) / NYEARS_IN_UNIT) * NYEARS_IN_UNIT + OFF_BY; \
+    tm.tm_year = year - 1900;                                            \
+    return (TYPE)timegm(&tm) * MILLIS_IN_SEC;                            \
+  }
+
+// Define the full date_trunc_<Unit>_<TYPE> family for one date type.
+#define DATE_TRUNC_FUNCTIONS(TYPE)                              \
+  DATE_TRUNC_FIXED_UNIT(date_trunc_Second, TYPE, MILLIS_IN_SEC) \
+  DATE_TRUNC_FIXED_UNIT(date_trunc_Minute, TYPE, MILLIS_IN_MIN) \
+  DATE_TRUNC_FIXED_UNIT(date_trunc_Hour, TYPE, MILLIS_IN_HOUR)  \
+  DATE_TRUNC_FIXED_UNIT(date_trunc_Day, TYPE, MILLIS_IN_DAY)    \
+  DATE_TRUNC_WEEK(TYPE)                                         \
+  DATE_TRUNC_MONTH_UNITS(date_trunc_Month, TYPE, 1)             \
+  DATE_TRUNC_MONTH_UNITS(date_trunc_Quarter, TYPE, 3)           \
+  DATE_TRUNC_MONTH_UNITS(date_trunc_Year, TYPE, 12)             \
+  DATE_TRUNC_YEAR_UNITS(date_trunc_Decade, TYPE, 10, 0)         \
+  DATE_TRUNC_YEAR_UNITS(date_trunc_Century, TYPE, 100, 1)       \
+  DATE_TRUNC_YEAR_UNITS(date_trunc_Millennium, TYPE, 1000, 1)
+
+DATE_TRUNC_FUNCTIONS(date64)
+DATE_TRUNC_FUNCTIONS(timestamp)
+
+}  // extern "C"
diff --git a/cpp/src/gandiva/precompiled/time_constants.h b/cpp/src/gandiva/precompiled/time_constants.h
new file mode 100644
index 00000000000..8afd4ae8a9c
--- /dev/null
+++
b/cpp/src/gandiva/precompiled/time_constants.h
@@ -0,0 +1,30 @@
+// Copyright (C) 2017-2018 Dremio Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TIME_CONSTANTS_H
+#define TIME_CONSTANTS_H
+
+// Number of milliseconds in each fixed-length time unit. Every constant is
+// built from the previous one, so they stay consistent with each other.
+#define MILLIS_IN_SEC (1000)
+#define MILLIS_IN_MIN (60 * MILLIS_IN_SEC)
+#define MILLIS_IN_HOUR (60 * MILLIS_IN_MIN)
+#define MILLIS_IN_DAY (24 * MILLIS_IN_HOUR)
+#define MILLIS_IN_WEEK (7 * MILLIS_IN_DAY)
+
+// Convert a millisecond count to whole units. These are plain integer
+// divisions, so any fractional unit is discarded.
+#define MILLIS_TO_SEC(millis) ((millis) / MILLIS_IN_SEC)
+#define MILLIS_TO_MINS(millis) ((millis) / MILLIS_IN_MIN)
+#define MILLIS_TO_HOUR(millis) ((millis) / MILLIS_IN_HOUR)
+#define MILLIS_TO_DAY(millis) ((millis) / MILLIS_IN_DAY)
+#define MILLIS_TO_WEEK(millis) ((millis) / MILLIS_IN_WEEK)
+
+#endif  // TIME_CONSTANTS_H
diff --git a/cpp/src/gandiva/precompiled/time_test.cc b/cpp/src/gandiva/precompiled/time_test.cc
new file mode 100644
index 00000000000..7f9fb0f6e0c
--- /dev/null
+++ b/cpp/src/gandiva/precompiled/time_test.cc
@@ -0,0 +1,493 @@
+// Copyright (C) 2017-2018 Dremio Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <time.h>
+
+#include <gtest/gtest.h>
+#include "precompiled/types.h"
+
+namespace gandiva {
+
+// Parses "YYYY-MM-DD HH:MM:SS", interprets it as UTC (timegm) and returns the
+// corresponding epoch milliseconds.
+timestamp StringToTimestamp(const char *buf) {
+  // Zero-initialise so that no field is left indeterminate if strptime stops
+  // early; strptime only writes the fields it actually parses.
+  struct tm tm = {};
+  strptime(buf, "%Y-%m-%d %H:%M:%S", &tm);
+  return timegm(&tm) * 1000;  // to millis
+}
+
+TEST(TestTime, TestExtractTime) {
+  // 10:20:33
+  int32 time_as_millis_in_day = 37233000;
+
+  EXPECT_EQ(extractHour_time32(time_as_millis_in_day), 10);
+  EXPECT_EQ(extractMinute_time32(time_as_millis_in_day), 20);
+  EXPECT_EQ(extractSecond_time32(time_as_millis_in_day), 33);
+}
+
+TEST(TestTime, TestExtractTimestamp) {
+  timestamp ts = StringToTimestamp("1970-05-02 10:20:33");
+
+  EXPECT_EQ(extractMillennium_timestamp(ts), 2);
+  EXPECT_EQ(extractCentury_timestamp(ts), 20);
+  EXPECT_EQ(extractDecade_timestamp(ts), 197);
+  EXPECT_EQ(extractYear_timestamp(ts), 1970);
+  EXPECT_EQ(extractDoy_timestamp(ts), 122);
+  EXPECT_EQ(extractMonth_timestamp(ts), 5);
+  EXPECT_EQ(extractDow_timestamp(ts), 7);
+  EXPECT_EQ(extractDay_timestamp(ts), 2);
+  EXPECT_EQ(extractHour_timestamp(ts), 10);
+  EXPECT_EQ(extractMinute_timestamp(ts), 20);
+  EXPECT_EQ(extractSecond_timestamp(ts), 33);
+}
+
+TEST(TestTime, TimeStampTrunc) {
+  EXPECT_EQ(date_trunc_Second_date64(StringToTimestamp("2015-05-05 10:20:34")),
+            StringToTimestamp("2015-05-05 10:20:34"));
+  EXPECT_EQ(date_trunc_Minute_date64(StringToTimestamp("2015-05-05 10:20:34")),
+            StringToTimestamp("2015-05-05 10:20:00"));
+  EXPECT_EQ(date_trunc_Hour_date64(StringToTimestamp("2015-05-05 10:20:34")),
+            StringToTimestamp("2015-05-05 10:00:00"));
+  EXPECT_EQ(date_trunc_Day_date64(StringToTimestamp("2015-05-05 10:20:34")),
+            StringToTimestamp("2015-05-05 00:00:00"));
+  EXPECT_EQ(date_trunc_Month_date64(StringToTimestamp("2015-05-05 10:20:34")),
+            StringToTimestamp("2015-05-01 00:00:00"));
+  EXPECT_EQ(date_trunc_Quarter_date64(StringToTimestamp("2015-05-05 10:20:34")),
+            StringToTimestamp("2015-04-01 00:00:00"));
+  EXPECT_EQ(date_trunc_Year_date64(StringToTimestamp("2015-05-05 10:20:34")),
+            StringToTimestamp("2015-01-01 00:00:00"));
+  EXPECT_EQ(date_trunc_Decade_date64(StringToTimestamp("2015-05-05 10:20:34")),
+            StringToTimestamp("2010-01-01 00:00:00"));
+  EXPECT_EQ(date_trunc_Century_date64(StringToTimestamp("2115-05-05 10:20:34")),
+            StringToTimestamp("2101-01-01 00:00:00"));
+  EXPECT_EQ(date_trunc_Millennium_date64(StringToTimestamp("2115-05-05 10:20:34")),
+            StringToTimestamp("2001-01-01 00:00:00"));
+
+  // truncate week going to previous year
+  EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2011-01-01 10:10:10")),
+            StringToTimestamp("2010-12-27 00:00:00"));
+  EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2011-01-02 10:10:10")),
+            StringToTimestamp("2010-12-27 00:00:00"));
+  EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2011-01-03 10:10:10")),
+            StringToTimestamp("2011-01-03 00:00:00"));
+  EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2011-01-04 10:10:10")),
+            StringToTimestamp("2011-01-03 00:00:00"));
+  EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2011-01-05 10:10:10")),
+            StringToTimestamp("2011-01-03 00:00:00"));
+  EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2011-01-06 10:10:10")),
+            StringToTimestamp("2011-01-03 00:00:00"));
+  EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2011-01-07 10:10:10")),
+            StringToTimestamp("2011-01-03 00:00:00"));
+  EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2011-01-08 10:10:10")),
+            StringToTimestamp("2011-01-03 00:00:00"));
+  EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2011-01-09 10:10:10")),
+            StringToTimestamp("2011-01-03 00:00:00"));
+
+  // truncate week for Feb in a leap year
+  EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2000-02-28 10:10:10")),
+            StringToTimestamp("2000-02-28 00:00:00"));
+  EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2000-02-29 10:10:10")),
+            StringToTimestamp("2000-02-28 00:00:00"));
+  EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2000-03-01 10:10:10")),
+            StringToTimestamp("2000-02-28 00:00:00"));
+  EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2000-03-02 10:10:10")),
+            StringToTimestamp("2000-02-28 00:00:00"));
+  EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2000-03-03 10:10:10")),
+            StringToTimestamp("2000-02-28 00:00:00"));
+  EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2000-03-04 10:10:10")),
+            StringToTimestamp("2000-02-28 00:00:00"));
+  EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2000-03-05 10:10:10")),
+            StringToTimestamp("2000-02-28 00:00:00"));
+  EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2000-03-06 10:10:10")),
+            StringToTimestamp("2000-03-06 00:00:00"));
+}
+
+TEST(TestTime, TimeStampAdd) {
+  EXPECT_EQ(
+      timestampaddSecond_timestamp_int32(StringToTimestamp("2000-05-01 10:20:34"), 30),
+      StringToTimestamp("2000-05-01 10:21:04"));
+
+  EXPECT_EQ(timestampaddMinute_timestamp_int64(StringToTimestamp("2000-05-01 10:20:34"),
+                                               (int64)-30),
+            StringToTimestamp("2000-05-01 09:50:34"));
+
+  EXPECT_EQ(
+      timestampaddHour_timestamp_int32(StringToTimestamp("2000-05-01 10:20:34"), 20),
+      StringToTimestamp("2000-05-02 06:20:34"));
+
+  EXPECT_EQ(timestampaddDay_timestamp_int64(StringToTimestamp("2000-05-01 10:20:34"),
+                                            (int64)-35),
+            StringToTimestamp("2000-03-27 10:20:34"));
+
+  EXPECT_EQ(timestampaddWeek_timestamp_int32(StringToTimestamp("2000-05-01 10:20:34"), 4),
+            StringToTimestamp("2000-05-29 10:20:34"));
+
+  EXPECT_EQ(timestampaddMonth_timestamp_int64(StringToTimestamp("2000-05-01 10:20:34"),
+                                              (int64)10),
+            StringToTimestamp("2001-03-01 10:20:34"));
+
+  EXPECT_EQ(
+      timestampaddQuarter_timestamp_int32(StringToTimestamp("2000-05-01 10:20:34"), -2),
+      StringToTimestamp("1999-11-01 10:20:34"));
+
+  EXPECT_EQ(timestampaddYear_timestamp_int64(StringToTimestamp("2000-05-01 10:20:34"),
+                                             (int64)2),
+            StringToTimestamp("2002-05-01 10:20:34"));
+
+  // date_add
+  EXPECT_EQ(date_add_timestamp_int32(StringToTimestamp("2000-05-01 00:00:00"), 7),
+            StringToTimestamp("2000-05-08 00:00:00"));
+
+  EXPECT_EQ(add_int32_timestamp(4, StringToTimestamp("2000-05-01 00:00:00")),
+            StringToTimestamp("2000-05-05 00:00:00"));
+
+  EXPECT_EQ(add_timestamp_int64(StringToTimestamp("2000-05-01 00:00:00"), (int64)7),
+            StringToTimestamp("2000-05-08 00:00:00"));
+
+  EXPECT_EQ(date_add_int64_timestamp((int64)4, StringToTimestamp("2000-05-01 00:00:00")),
+            StringToTimestamp("2000-05-05 00:00:00"));
+
+  EXPECT_EQ(date_add_int64_timestamp((int64)4, StringToTimestamp("2000-02-27 00:00:00")),
+            StringToTimestamp("2000-03-02 00:00:00"));
+
+  // date_sub
+  EXPECT_EQ(date_sub_timestamp_int32(StringToTimestamp("2000-05-01 00:00:00"), 7),
+            StringToTimestamp("2000-04-24 00:00:00"));
+
+  EXPECT_EQ(subtract_timestamp_int32(StringToTimestamp("2000-05-01 00:00:00"), -7),
+            StringToTimestamp("2000-05-08 00:00:00"));
+
+  EXPECT_EQ(
+      date_diff_timestamp_int64(StringToTimestamp("2000-05-01 00:00:00"), (int64)365),
+      StringToTimestamp("1999-05-02 00:00:00"));
+
+  EXPECT_EQ(date_diff_timestamp_int64(StringToTimestamp("2000-03-01 00:00:00"), (int64)1),
+            StringToTimestamp("2000-02-29 00:00:00"));
+
+  EXPECT_EQ(
+      date_diff_timestamp_int64(StringToTimestamp("2000-02-29 00:00:00"), (int64)365),
+      StringToTimestamp("1999-03-01 00:00:00"));
+}
+
+// test cases from http://www.staff.science.uu.nl/~gent0113/calendar/isocalendar.htm
+TEST(TestTime, TestExtractWeek) {
+  // One input timestamp and the ISO week number expected for it.
+  struct WeekCase {
+    const char *ts;
+    int64 week;
+  };
+
+  // For every year type: Jan 1-3, the Monday and Sunday of one mid-year week,
+  // and Dec 29-31.
+  const WeekCase cases[] = {
+      // A type
+      {"2006-01-01 10:10:10", 52},
+      {"2006-01-02 10:10:10", 1},
+      {"2006-01-03 10:10:10", 1},
+      {"2006-04-24 10:10:10", 17},
+      {"2006-04-30 10:10:10", 17},
+      {"2006-12-29 10:10:10", 52},
+      {"2006-12-30 10:10:10", 52},
+      {"2006-12-31 10:10:10", 52},
+      // B(C) type
+      {"2011-01-01 10:10:10", 52},
+      {"2011-01-02 10:10:10", 52},
+      {"2011-01-03 10:10:10", 1},
+      {"2011-07-18 10:10:10", 29},
+      {"2011-07-24 10:10:10", 29},
+      {"2011-12-29 10:10:10", 52},
+      {"2011-12-30 10:10:10", 52},
+      {"2011-12-31 10:10:10", 52},
+      // B(DC) type
+      {"2005-01-01 10:10:10", 53},
+      {"2005-01-02 10:10:10", 53},
+      {"2005-01-03 10:10:10", 1},
+      {"2005-11-07 10:10:10", 45},
+      {"2005-11-13 10:10:10", 45},
+      {"2005-12-29 10:10:10", 52},
+      {"2005-12-30 10:10:10", 52},
+      {"2005-12-31 10:10:10", 52},
+      // C type
+      {"2010-01-01 10:10:10", 53},
+      {"2010-01-02 10:10:10", 53},
+      {"2010-01-03 10:10:10", 53},
+      {"2010-09-13 10:10:10", 37},
+      {"2010-09-19 10:10:10", 37},
+      {"2010-12-29 10:10:10", 52},
+      {"2010-12-30 10:10:10", 52},
+      {"2010-12-31 10:10:10", 52},
+      // D type
+      {"2037-01-01 10:10:10", 1},
+      {"2037-01-02 10:10:10", 1},
+      {"2037-01-03 10:10:10", 1},
+      {"2037-08-17 10:10:10", 34},
+      {"2037-08-23 10:10:10", 34},
+      {"2037-12-29 10:10:10", 53},
+      {"2037-12-30 10:10:10", 53},
+      {"2037-12-31 10:10:10", 53},
+      // E type
+      {"2014-01-01 10:10:10", 1},
+      {"2014-01-02 10:10:10", 1},
+      {"2014-01-03 10:10:10", 1},
+      {"2014-01-13 10:10:10", 3},
+      {"2014-01-19 10:10:10", 3},
+      {"2014-12-29 10:10:10", 1},
+      {"2014-12-30 10:10:10", 1},
+      {"2014-12-31 10:10:10", 1},
+      // F type
+      {"2019-01-01 10:10:10", 1},
+      {"2019-01-02 10:10:10", 1},
+      {"2019-01-03 10:10:10", 1},
+      {"2019-02-11 10:10:10", 7},
+      {"2019-02-17 10:10:10", 7},
+      {"2019-12-29 10:10:10", 52},
+      {"2019-12-30 10:10:10", 1},
+      {"2019-12-31 10:10:10", 1},
+      // G type
+      {"2001-01-01 10:10:10", 1},
+      {"2001-01-02 10:10:10", 1},
+      {"2001-01-03 10:10:10", 1},
+      {"2001-03-19 10:10:10", 12},
+      {"2001-03-25 10:10:10", 12},
+      {"2001-12-29 10:10:10", 52},
+      {"2001-12-30 10:10:10", 52},
+      {"2001-12-31 10:10:10", 1},
+      // AG type
+      {"2012-01-01 10:10:10", 52},
+      {"2012-01-02 10:10:10", 1},
+      {"2012-01-03 10:10:10", 1},
+      {"2012-04-02 10:10:10", 14},
+      {"2012-04-08 10:10:10", 14},
+      {"2012-12-29 10:10:10", 52},
+      {"2012-12-30 10:10:10", 52},
+      {"2012-12-31 10:10:10", 1},
+      // BA type
+      {"2000-01-01 10:10:10", 52},
+      {"2000-01-02 10:10:10", 52},
+      {"2000-01-03 10:10:10", 1},
+      {"2000-05-22 10:10:10", 21},
+      {"2000-05-28 10:10:10", 21},
+      {"2000-12-29 10:10:10", 52},
+      {"2000-12-30 10:10:10", 52},
+      {"2000-12-31 10:10:10", 52},
+      // CB type
+      {"2016-01-01 10:10:10", 53},
+      {"2016-01-02 10:10:10", 53},
+      {"2016-01-03 10:10:10", 53},
+      {"2016-06-20 10:10:10", 25},
+      {"2016-06-26 10:10:10", 25},
+      {"2016-12-29 10:10:10", 52},
+      {"2016-12-30 10:10:10", 52},
+      {"2016-12-31 10:10:10", 52},
+      // DC type
+      {"2004-01-01 10:10:10", 1},
+      {"2004-01-02 10:10:10", 1},
+      {"2004-01-03 10:10:10", 1},
+      {"2004-07-19 10:10:10", 30},
+      {"2004-07-25 10:10:10", 30},
+      {"2004-12-29 10:10:10", 53},
+      {"2004-12-30 10:10:10", 53},
+      {"2004-12-31 10:10:10", 53},
+      // ED type
+      {"2020-01-01 10:10:10", 1},
+      {"2020-01-02 10:10:10", 1},
+      {"2020-01-03 10:10:10", 1},
+      {"2020-08-17 10:10:10", 34},
+      {"2020-08-23 10:10:10", 34},
+      {"2020-12-29 10:10:10", 53},
+      {"2020-12-30 10:10:10", 53},
+      {"2020-12-31 10:10:10", 53},
+      // FE type
+      {"2008-01-01 10:10:10", 1},
+      {"2008-01-02 10:10:10", 1},
+      {"2008-01-03 10:10:10", 1},
+      {"2008-09-15 10:10:10", 38},
+      {"2008-09-21 10:10:10", 38},
+      {"2008-12-29 10:10:10", 1},
+      {"2008-12-30 10:10:10", 1},
+      {"2008-12-31 10:10:10", 1},
+      // GF type
+      {"2024-01-01 10:10:10", 1},
+      {"2024-01-02 10:10:10", 1},
+      {"2024-01-03 10:10:10", 1},
+      {"2024-10-07 10:10:10", 41},
+      {"2024-10-13 10:10:10", 41},
+      {"2024-12-29 10:10:10", 52},
+      {"2024-12-30 10:10:10", 1},
+      {"2024-12-31 10:10:10", 1},
+  };
+
+  for (const WeekCase &c : cases) {
+    // Report the offending input on failure.
+    EXPECT_EQ(extractWeek_timestamp(StringToTimestamp(c.ts)), c.week) << c.ts;
+  }
+}
+
+}  // namespace gandiva
diff --git a/cpp/src/gandiva/precompiled/timestamp_arithmetic.cc b/cpp/src/gandiva/precompiled/timestamp_arithmetic.cc
new file mode 100644
index 00000000000..aee2031ddf7
--- /dev/null
+++ b/cpp/src/gandiva/precompiled/timestamp_arithmetic.cc
@@ -0,0 +1,214 @@
+// Copyright (C) 2017-2018 Dremio Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+extern "C" {
+
+#include <time.h>
+#include "./time_constants.h"
+#include "./types.h"
+
+// Defines NAME_TYPE_TYPE(start_millis, end_millis): the difference between two
+// millisecond values, converted to a fixed-length unit by the FROM_MILLIS macro.
+// NOTE(review): the 64-bit difference is narrowed to int32 on return, so very large
+// gaps may not fit - confirm the expected input range with the callers.
+#define TIMESTAMP_DIFF_FIXED_UNITS(TYPE, NAME, FROM_MILLIS) \
+  FORCE_INLINE \
+  int32 NAME##_##TYPE##_##TYPE(TYPE start_millis, TYPE end_millis) { \
+    return (int32)(FROM_MILLIS(end_millis - start_millis)); \
+  }
+
+// Negates diff when the operands had to be swapped (is_positive == false).
+#define SIGN_ADJUST_DIFF(is_positive, diff) ((is_positive) ? (diff) : -(diff))
+// Converts a month count into a unit spanning num_months months (integer division).
+#define MONTHS_TO_TIMEUNIT(diff, num_months) ((diff) / (num_months))
+
+// Assuming end_millis > start_millis, the algorithm to find the diff in months is:
+//   diff_in_months = year_diff * 12 + month_diff
+// This is approximately correct, except when the last month has not fully elapsed
+//
+// a) If end_day > start_day, return diff_in_months     e.g. diff(2015-09-10, 2017-03-31)
+// b) If end_day < start_day, return diff_in_months - 1 e.g. diff(2015-09-30, 2017-03-10)
+// c) If end_day = start_day, check for millis          e.g. diff(2015-03-10, 2017-03-10)
+//    Need to compare the millis elapsed since midnight on both sides
+//    c1) If end_millis_in_day >= start_millis_in_day, return diff_in_months
+//    c2) else return diff_in_months - 1
+//
+// The month count is then divided by N_MONTHS (1 = month, 3 = quarter, 12 = year)
+// and negated if the two arguments had to be swapped.
+#define TIMESTAMP_DIFF_MONTH_UNITS(TYPE, NAME, N_MONTHS) \
+  FORCE_INLINE \
+  int32 NAME##_##TYPE##_##TYPE(TYPE start_millis, TYPE end_millis) { \
+    bool is_positive = (end_millis > start_millis); \
+    if (!is_positive) { \
+      /* if end_millis < start_millis, swap and multiply by -1 at the end */ \
+      TYPE tmp = start_millis; \
+      start_millis = end_millis; \
+      end_millis = tmp; \
+    } \
+    time_t start_tsec = (time_t)MILLIS_TO_SEC(start_millis); \
+    struct tm start_tm; \
+    gmtime_r(&start_tsec, &start_tm); \
+    time_t end_tsec = (time_t)MILLIS_TO_SEC(end_millis); \
+    struct tm end_tm; \
+    gmtime_r(&end_tsec, &end_tm); \
+    int32 months_diff = \
+        12 * (end_tm.tm_year - start_tm.tm_year) + (end_tm.tm_mon - start_tm.tm_mon); \
+    /* millis since midnight: struct tm only carries whole seconds, so the */ \
+    /* sub-second remainder is recovered from the original millis value */ \
+    int32 start_day_millis = \
+        start_tm.tm_hour * MILLIS_IN_HOUR + start_tm.tm_min * MILLIS_IN_MIN + \
+        start_tm.tm_sec * MILLIS_IN_SEC + \
+        (int32)(start_millis - (TYPE)start_tsec * MILLIS_IN_SEC); \
+    int32 end_day_millis = \
+        end_tm.tm_hour * MILLIS_IN_HOUR + end_tm.tm_min * MILLIS_IN_MIN + \
+        end_tm.tm_sec * MILLIS_IN_SEC + \
+        (int32)(end_millis - (TYPE)end_tsec * MILLIS_IN_SEC); \
+    /* cases b and c2: the last month has not fully elapsed */ \
+    if (end_tm.tm_mday < start_tm.tm_mday || \
+        (end_tm.tm_mday == start_tm.tm_mday && end_day_millis < start_day_millis)) { \
+      months_diff -= 1; \
+    } \
+    /* cases a and c1 fall through with months_diff untouched */ \
+    int32 diff = MONTHS_TO_TIMEUNIT(months_diff, N_MONTHS); \
+    return SIGN_ADJUST_DIFF(is_positive, diff); \
+  }
+
+// Instantiates every timestampdiff<Unit>_TYPE_TYPE function for TYPE.
+#define TIMESTAMP_DIFF(TYPE) \
+  TIMESTAMP_DIFF_FIXED_UNITS(TYPE, timestampdiffSecond, MILLIS_TO_SEC) \
+  TIMESTAMP_DIFF_FIXED_UNITS(TYPE, timestampdiffMinute, MILLIS_TO_MINS) \
+  TIMESTAMP_DIFF_FIXED_UNITS(TYPE, timestampdiffHour, MILLIS_TO_HOUR) \
+  TIMESTAMP_DIFF_FIXED_UNITS(TYPE, timestampdiffDay, MILLIS_TO_DAY) \
+  TIMESTAMP_DIFF_FIXED_UNITS(TYPE, timestampdiffWeek, MILLIS_TO_WEEK) \
+  TIMESTAMP_DIFF_MONTH_UNITS(TYPE, timestampdiffMonth, 1) \
+  TIMESTAMP_DIFF_MONTH_UNITS(TYPE, timestampdiffQuarter, 3) \
+  TIMESTAMP_DIFF_MONTH_UNITS(TYPE, timestampdiffYear, 12)
+
+TIMESTAMP_DIFF(timestamp)
+
+// Defines NAME_TYPE_int32(millis, count): millis shifted by count units, where one
+// unit is TO_MILLIS milliseconds (TO_MILLIS may be negative to subtract).
+#define ADD_INT32_TO_TIMESTAMP_FIXED_UNITS(TYPE, NAME, TO_MILLIS) \
+  FORCE_INLINE \
+  TYPE NAME##_##TYPE##_int32(TYPE millis, int32 count) { \
+    return millis + (TO_MILLIS) * (TYPE)count; \
+  }
+
+// Documentation of mktime suggests that it handles
+// tm_mon being negative, and also tm_mon being >= 12 by
+// adjusting tm_year accordingly
+//
+// Using gmtime_r() and timegm() instead of localtime_r() and mktime()
+// since the input millis are since epoch
+//
+// The round trip through struct tm only has whole-second resolution, so the
+// sub-second part of the input is set aside first and added back to the result.
+// NOTE(review): tm_mday is left untouched, so a day that does not exist in the
+// target month (e.g. Jan 31 + 1 month) presumably rolls over into the following
+// month via timegm() normalisation - confirm that this is the intended semantics.
+#define ADD_INT32_TO_TIMESTAMP_MONTH_UNITS(TYPE, NAME, N_MONTHS) \
+  FORCE_INLINE \
+  TYPE NAME##_##TYPE##_int32(TYPE millis, int32 count) { \
+    time_t tsec = (time_t)MILLIS_TO_SEC(millis); \
+    TYPE sub_sec_millis = millis - (TYPE)tsec * MILLIS_IN_SEC; \
+    struct tm tm; \
+    gmtime_r(&tsec, &tm); \
+    tm.tm_mon += count * N_MONTHS; \
+    return (TYPE)timegm(&tm) * MILLIS_IN_SEC + sub_sec_millis; \
+  }
+
+// Same as ADD_INT32_TO_TIMESTAMP_FIXED_UNITS, with a 64-bit count.
+// TODO: Handle overflow while converting int64 to millis
+#define ADD_INT64_TO_TIMESTAMP_FIXED_UNITS(TYPE, NAME, TO_MILLIS) \
+  FORCE_INLINE \
+  TYPE NAME##_##TYPE##_int64(TYPE millis, int64 count) { \
+    return millis + (TO_MILLIS) * (TYPE)count; \
+  }
+
+// Same as ADD_INT32_TO_TIMESTAMP_MONTH_UNITS, with a 64-bit count.
+// The month offset is narrowed to int because tm_mon is an int.
+#define ADD_INT64_TO_TIMESTAMP_MONTH_UNITS(TYPE, NAME, N_MONTHS) \
+  FORCE_INLINE \
+  TYPE NAME##_##TYPE##_int64(TYPE millis, int64 count) { \
+    time_t tsec = (time_t)MILLIS_TO_SEC(millis); \
+    TYPE sub_sec_millis = millis - (TYPE)tsec * MILLIS_IN_SEC; \
+    struct tm tm; \
+    gmtime_r(&tsec, &tm); \
+    tm.tm_mon += (int)(count * N_MONTHS); \
+    return (TYPE)timegm(&tm) * MILLIS_IN_SEC + sub_sec_millis; \
+  }
+
+// Instantiates every timestampadd<Unit>_TYPE_int32 function for TYPE.
+#define TIMESTAMP_ADD_INT32(TYPE) \
+  ADD_INT32_TO_TIMESTAMP_FIXED_UNITS(TYPE, timestampaddSecond, MILLIS_IN_SEC) \
+  ADD_INT32_TO_TIMESTAMP_FIXED_UNITS(TYPE, timestampaddMinute, MILLIS_IN_MIN) \
+  ADD_INT32_TO_TIMESTAMP_FIXED_UNITS(TYPE, timestampaddHour, MILLIS_IN_HOUR) \
+  ADD_INT32_TO_TIMESTAMP_FIXED_UNITS(TYPE, timestampaddDay, MILLIS_IN_DAY) \
+  ADD_INT32_TO_TIMESTAMP_FIXED_UNITS(TYPE, timestampaddWeek, MILLIS_IN_WEEK) \
+  ADD_INT32_TO_TIMESTAMP_MONTH_UNITS(TYPE, timestampaddMonth, 1) \
+  ADD_INT32_TO_TIMESTAMP_MONTH_UNITS(TYPE, timestampaddQuarter, 3) \
+  ADD_INT32_TO_TIMESTAMP_MONTH_UNITS(TYPE, timestampaddYear, 12)
+
+// Instantiates every timestampadd<Unit>_TYPE_int64 function for TYPE.
+#define TIMESTAMP_ADD_INT64(TYPE) \
+  ADD_INT64_TO_TIMESTAMP_FIXED_UNITS(TYPE, timestampaddSecond, MILLIS_IN_SEC) \
+  ADD_INT64_TO_TIMESTAMP_FIXED_UNITS(TYPE, timestampaddMinute, MILLIS_IN_MIN) \
+  ADD_INT64_TO_TIMESTAMP_FIXED_UNITS(TYPE, timestampaddHour, MILLIS_IN_HOUR) \
+  ADD_INT64_TO_TIMESTAMP_FIXED_UNITS(TYPE, timestampaddDay, MILLIS_IN_DAY) \
+  ADD_INT64_TO_TIMESTAMP_FIXED_UNITS(TYPE, timestampaddWeek, MILLIS_IN_WEEK) \
+  ADD_INT64_TO_TIMESTAMP_MONTH_UNITS(TYPE, timestampaddMonth, 1) \
+  ADD_INT64_TO_TIMESTAMP_MONTH_UNITS(TYPE, timestampaddQuarter, 3) \
+  ADD_INT64_TO_TIMESTAMP_MONTH_UNITS(TYPE, timestampaddYear, 12)
+
+// Both count widths for one TYPE.
+#define TIMESTAMP_ADD_INT(TYPE) \
+  TIMESTAMP_ADD_INT32(TYPE) \
+  TIMESTAMP_ADD_INT64(TYPE)
+
+TIMESTAMP_ADD_INT(date64)
+TIMESTAMP_ADD_INT(timestamp)
+
+// add int32 to timestamp (count is a number of days)
+ADD_INT32_TO_TIMESTAMP_FIXED_UNITS(date64, date_add, MILLIS_IN_DAY)
+ADD_INT32_TO_TIMESTAMP_FIXED_UNITS(date64, add, MILLIS_IN_DAY)
+ADD_INT32_TO_TIMESTAMP_FIXED_UNITS(timestamp, date_add, MILLIS_IN_DAY)
+ADD_INT32_TO_TIMESTAMP_FIXED_UNITS(timestamp, add, MILLIS_IN_DAY)
+
+// add int64 to timestamp (count is a number of days)
+ADD_INT64_TO_TIMESTAMP_FIXED_UNITS(date64, date_add, MILLIS_IN_DAY)
+ADD_INT64_TO_TIMESTAMP_FIXED_UNITS(date64, add, MILLIS_IN_DAY)
+ADD_INT64_TO_TIMESTAMP_FIXED_UNITS(timestamp, date_add, MILLIS_IN_DAY)
+ADD_INT64_TO_TIMESTAMP_FIXED_UNITS(timestamp, add, MILLIS_IN_DAY)
+
+// date_sub, subtract, date_diff on int32 (all three subtract count days)
+ADD_INT32_TO_TIMESTAMP_FIXED_UNITS(date64, date_sub, -1 * MILLIS_IN_DAY)
+ADD_INT32_TO_TIMESTAMP_FIXED_UNITS(date64, subtract, -1 * MILLIS_IN_DAY)
+ADD_INT32_TO_TIMESTAMP_FIXED_UNITS(date64, date_diff, -1 * MILLIS_IN_DAY)
+ADD_INT32_TO_TIMESTAMP_FIXED_UNITS(timestamp, date_sub, -1 * MILLIS_IN_DAY)
+ADD_INT32_TO_TIMESTAMP_FIXED_UNITS(timestamp, subtract, -1 * MILLIS_IN_DAY)
+ADD_INT32_TO_TIMESTAMP_FIXED_UNITS(timestamp, date_diff, -1 * MILLIS_IN_DAY)
+
+// date_sub, subtract, date_diff on int64 (all three subtract count days)
+ADD_INT64_TO_TIMESTAMP_FIXED_UNITS(date64, date_sub, -1 * MILLIS_IN_DAY)
+ADD_INT64_TO_TIMESTAMP_FIXED_UNITS(date64, subtract, -1 * MILLIS_IN_DAY)
+ADD_INT64_TO_TIMESTAMP_FIXED_UNITS(date64, date_diff, -1 * MILLIS_IN_DAY)
+ADD_INT64_TO_TIMESTAMP_FIXED_UNITS(timestamp, date_sub, -1 * MILLIS_IN_DAY)
+ADD_INT64_TO_TIMESTAMP_FIXED_UNITS(timestamp, subtract, -1 * MILLIS_IN_DAY)
+ADD_INT64_TO_TIMESTAMP_FIXED_UNITS(timestamp, date_diff, -1 * MILLIS_IN_DAY)
+
+// Mirror of ADD_INT32_TO_TIMESTAMP_FIXED_UNITS with the arguments reversed:
+// defines NAME_int32_TYPE(count, millis).
+#define ADD_TIMESTAMP_TO_INT32_FIXED_UNITS(TYPE, NAME, TO_MILLIS) \
+  FORCE_INLINE \
+  TYPE NAME##_int32_##TYPE(int32 count, TYPE millis) { \
+    return millis + (TO_MILLIS) * (TYPE)count; \
+  }
+
+// Mirror of ADD_INT64_TO_TIMESTAMP_FIXED_UNITS with the arguments reversed:
+// defines NAME_int64_TYPE(count, millis).
+#define ADD_TIMESTAMP_TO_INT64_FIXED_UNITS(TYPE, NAME, TO_MILLIS) \
+  FORCE_INLINE \
+  TYPE NAME##_int64_##TYPE(int64 count, TYPE millis) { \
+    return millis + (TO_MILLIS) * (TYPE)count; \
+  }
+
+// add timestamp to int32
+ADD_TIMESTAMP_TO_INT32_FIXED_UNITS(date64, date_add, MILLIS_IN_DAY)
+ADD_TIMESTAMP_TO_INT32_FIXED_UNITS(date64, add, MILLIS_IN_DAY)
+ADD_TIMESTAMP_TO_INT32_FIXED_UNITS(timestamp, date_add, MILLIS_IN_DAY)
+ADD_TIMESTAMP_TO_INT32_FIXED_UNITS(timestamp, add, MILLIS_IN_DAY)
+
+// add timestamp to int64
+ADD_TIMESTAMP_TO_INT64_FIXED_UNITS(date64, date_add, MILLIS_IN_DAY)
+ADD_TIMESTAMP_TO_INT64_FIXED_UNITS(date64, add, MILLIS_IN_DAY)
+ADD_TIMESTAMP_TO_INT64_FIXED_UNITS(timestamp, date_add, MILLIS_IN_DAY)
+ADD_TIMESTAMP_TO_INT64_FIXED_UNITS(timestamp, add, MILLIS_IN_DAY)
+
+}  // extern "C"
diff --git a/cpp/src/gandiva/precompiled/types.h b/cpp/src/gandiva/precompiled/types.h
new file mode 100644
index 00000000000..cd97dc745ef
--- /dev/null
+++ b/cpp/src/gandiva/precompiled/types.h
@@ -0,0 +1,124 @@
+// Copyright (C) 2017-2018 Dremio Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+using boolean = bool; +using int8 = int8_t; +using int16 = int16_t; +using int32 = int32_t; +using int64 = int64_t; +using uint8 = uint8_t; +using uint16 = uint16_t; +using uint32 = uint32_t; +using uint64 = uint64_t; +using float32 = float; +using float64 = double; +using date64 = int64_t; +using time32 = int32_t; +using timestamp = int64_t; +using utf8 = char *; +using binary = char *; + +#ifdef GANDIVA_UNIT_TEST +// unit tests may be compiled without O2, so inlining may not happen. +#define FORCE_INLINE +#else +#define FORCE_INLINE __attribute__((always_inline)) +#endif + +// Declarations : used in testing + +extern "C" { + +bool bitMapGetBit(const unsigned char *bmap, int position); +void bitMapSetBit(unsigned char *bmap, int position, bool value); +void bitMapClearBitIfFalse(unsigned char *bmap, int position, bool value); + +int64 extractMillennium_timestamp(timestamp millis); +int64 extractCentury_timestamp(timestamp millis); +int64 extractDecade_timestamp(timestamp millis); +int64 extractYear_timestamp(timestamp millis); +int64 extractDoy_timestamp(timestamp millis); +int64 extractQuarter_timestamp(timestamp millis); +int64 extractMonth_timestamp(timestamp millis); +int64 extractWeek_timestamp(timestamp millis); +int64 extractDow_timestamp(timestamp millis); +int64 extractDay_timestamp(timestamp millis); +int64 extractHour_timestamp(timestamp millis); +int64 extractMinute_timestamp(timestamp millis); +int64 extractSecond_timestamp(timestamp millis); +int64 extractHour_time32(int32 millis_in_day); +int64 extractMinute_time32(int32 millis_in_day); +int64 extractSecond_time32(int32 millis_in_day); + +int32 hash32(double val, int32 seed); +int32 hash32_buf(const uint8 *buf, int len, int32 seed); +int64 hash64(double val, int64 seed); +int64 hash64_buf(const uint8 *buf, int len, int64 seed); + +int64 timestampaddSecond_timestamp_int32(timestamp, int32); +int64 timestampaddMinute_timestamp_int32(timestamp, int32); +int64 
timestampaddHour_timestamp_int32(timestamp, int32); +int64 timestampaddDay_timestamp_int32(timestamp, int32); +int64 timestampaddWeek_timestamp_int32(timestamp, int32); +int64 timestampaddMonth_timestamp_int32(timestamp, int32); +int64 timestampaddQuarter_timestamp_int32(timestamp, int32); +int64 timestampaddYear_timestamp_int32(timestamp, int32); + +int64 timestampaddSecond_timestamp_int64(timestamp, int64); +int64 timestampaddMinute_timestamp_int64(timestamp, int64); +int64 timestampaddHour_timestamp_int64(timestamp, int64); +int64 timestampaddDay_timestamp_int64(timestamp, int64); +int64 timestampaddWeek_timestamp_int64(timestamp, int64); +int64 timestampaddMonth_timestamp_int64(timestamp, int64); +int64 timestampaddQuarter_timestamp_int64(timestamp, int64); +int64 timestampaddYear_timestamp_int64(timestamp, int64); + +int64 date_add_timestamp_int32(timestamp, int32); +int64 add_timestamp_int64(timestamp, int64); +int64 add_int32_timestamp(int32, timestamp); +int64 date_add_int64_timestamp(int64, timestamp); + +int64 date_sub_timestamp_int32(timestamp, int32); +int64 subtract_timestamp_int32(timestamp, int32); +int64 date_diff_timestamp_int64(timestamp, int64); + +bool is_distinct_from_timestamp_timestamp(int64, bool, int64, bool); +bool is_not_distinct_from_int32_int32(int32, bool, int32, bool); + +int64 date_trunc_Second_date64(date64); +int64 date_trunc_Minute_date64(date64); +int64 date_trunc_Hour_date64(date64); +int64 date_trunc_Day_date64(date64); +int64 date_trunc_Month_date64(date64); +int64 date_trunc_Quarter_date64(date64); +int64 date_trunc_Year_date64(date64); +int64 date_trunc_Decade_date64(date64); +int64 date_trunc_Century_date64(date64); +int64 date_trunc_Millennium_date64(date64); + +int64 date_trunc_Week_timestamp(timestamp); + +int32 mem_compare(const char *left, int32 left_len, const char *right, int32 right_len); + +int32 mod_int64_int32(int64 left, int32 right); + +} // extern "C" + +#endif // PRECOMPILED_TYPES_H