diff --git a/cpp/src/gandiva/CMakeLists.txt b/cpp/src/gandiva/CMakeLists.txt index 99c23f99cd9..0ae5a193f53 100644 --- a/cpp/src/gandiva/CMakeLists.txt +++ b/cpp/src/gandiva/CMakeLists.txt @@ -224,6 +224,7 @@ add_gandiva_test(internals-test like_holder_test.cc decimal_type_util_test.cc random_generator_holder_test.cc + gdv_function_stubs_test.cc EXTRA_DEPENDENCIES LLVM::LLVM_INTERFACE EXTRA_INCLUDES diff --git a/cpp/src/gandiva/function_registry_string.cc b/cpp/src/gandiva/function_registry_string.cc index 2c71126aafe..ea3af5b45c9 100644 --- a/cpp/src/gandiva/function_registry_string.cc +++ b/cpp/src/gandiva/function_registry_string.cc @@ -16,6 +16,7 @@ // under the License. #include "gandiva/function_registry_string.h" + #include "gandiva/function_registry_common.h" namespace gandiva { @@ -61,17 +62,26 @@ std::vector GetStringFunctionRegistry() { UNARY_SAFE_NULL_NEVER_BOOL_FN(isnull, {}), UNARY_SAFE_NULL_NEVER_BOOL_FN(isnotnull, {}), - UNARY_UNSAFE_NULL_IF_NULL(castINT, {}, utf8, int32), - UNARY_UNSAFE_NULL_IF_NULL(castBIGINT, {}, utf8, int64), - UNARY_UNSAFE_NULL_IF_NULL(castFLOAT4, {}, utf8, float32), - UNARY_UNSAFE_NULL_IF_NULL(castFLOAT8, {}, utf8, float64), - NativeFunction("upper", {}, DataTypeVector{utf8()}, utf8(), kResultNullIfNull, "upper_utf8", NativeFunction::kNeedsContext), NativeFunction("lower", {}, DataTypeVector{utf8()}, utf8(), kResultNullIfNull, "lower_utf8", NativeFunction::kNeedsContext), + NativeFunction("castINT", {}, DataTypeVector{utf8()}, int32(), kResultNullIfNull, + "gdv_fn_castINT_utf8", NativeFunction::kNeedsContext), + + NativeFunction("castBIGINT", {}, DataTypeVector{utf8()}, int64(), kResultNullIfNull, + "gdv_fn_castBIGINT_utf8", NativeFunction::kNeedsContext), + + NativeFunction("castFLOAT4", {}, DataTypeVector{utf8()}, float32(), + kResultNullIfNull, "gdv_fn_castFLOAT4_utf8", + NativeFunction::kNeedsContext), + + NativeFunction("castFLOAT8", {}, DataTypeVector{utf8()}, float64(), + kResultNullIfNull, "gdv_fn_castFLOAT8_utf8", + NativeFunction::kNeedsContext), + NativeFunction("castVARCHAR", {}, DataTypeVector{utf8(), int64()}, utf8(), kResultNullIfNull, "castVARCHAR_utf8_int64", NativeFunction::kNeedsContext), diff --git a/cpp/src/gandiva/gdv_function_stubs.cc b/cpp/src/gandiva/gdv_function_stubs.cc index ad3036f96b5..ad93ce8c412 100644 --- a/cpp/src/gandiva/gdv_function_stubs.cc +++ b/cpp/src/gandiva/gdv_function_stubs.cc @@ -20,6 +20,7 @@ #include #include +#include "arrow/util/value_parsing.h" #include "gandiva/engine.h" #include "gandiva/exported_funcs.h" #include "gandiva/in_holder.h" @@ -150,6 +151,37 @@ char* gdv_fn_dec_to_string(int64_t context, int64_t x_high, uint64_t x_low, memcpy(ret, dec_str.data(), *dec_str_len); return ret; } + +#define CAST_NUMERIC_FROM_STRING(OUT_TYPE, ARROW_TYPE, TYPE_NAME) \ + GANDIVA_EXPORT \ + OUT_TYPE gdv_fn_cast##TYPE_NAME##_utf8(int64_t context, const char* data, \ + int32_t len) { \ + OUT_TYPE val = 0; \ + /* trim leading and trailing spaces */ \ + int32_t trimmed_len; \ + int32_t start = 0, end = len - 1; \ + while (start <= end && data[start] == ' ') { \ + ++start; \ + } \ + while (end >= start && data[end] == ' ') { \ + --end; \ + } \ + trimmed_len = end - start + 1; \ + const char* trimmed_data = data + start; \ + if (!arrow::internal::ParseValue(trimmed_data, trimmed_len, &val)) { \ + std::string err = \ + "Failed to cast the string " + std::string(data, len) + " to " #OUT_TYPE; \ + gdv_fn_context_set_error_msg(context, err.c_str()); \ + } \ + return val; \ + } + +CAST_NUMERIC_FROM_STRING(int32_t, arrow::Int32Type, INT) +CAST_NUMERIC_FROM_STRING(int64_t, arrow::Int64Type, BIGINT) +CAST_NUMERIC_FROM_STRING(float, arrow::FloatType, FLOAT4) +CAST_NUMERIC_FROM_STRING(double, arrow::DoubleType, FLOAT8) + +#undef CAST_NUMERIC_FROM_STRING } namespace gandiva { @@ -277,6 +309,34 @@ void ExportedStubFunctions::AddMappings(Engine* engine) const { args = {types->i64_type(), types->i32_type(), types->i1_type()}; engine->AddGlobalMappingForFunc("gdv_fn_random_with_seed", types->double_type(), args, reinterpret_cast(gdv_fn_random_with_seed)); + + args = {types->i64_type(), // int64_t context_ptr + types->i8_ptr_type(), // const char* data + types->i32_type()}; // int32_t lenr + + engine->AddGlobalMappingForFunc("gdv_fn_castINT_utf8", types->i32_type(), args, + reinterpret_cast(gdv_fn_castINT_utf8)); + + args = {types->i64_type(), // int64_t context_ptr + types->i8_ptr_type(), // const char* data + types->i32_type()}; // int32_t lenr + + engine->AddGlobalMappingForFunc("gdv_fn_castBIGINT_utf8", types->i64_type(), args, + reinterpret_cast(gdv_fn_castBIGINT_utf8)); + + args = {types->i64_type(), // int64_t context_ptr + types->i8_ptr_type(), // const char* data + types->i32_type()}; // int32_t lenr + + engine->AddGlobalMappingForFunc("gdv_fn_castFLOAT4_utf8", types->float_type(), args, + reinterpret_cast(gdv_fn_castFLOAT4_utf8)); + + args = {types->i64_type(), // int64_t context_ptr + types->i8_ptr_type(), // const char* data + types->i32_type()}; // int32_t lenr + + engine->AddGlobalMappingForFunc("gdv_fn_castFLOAT8_utf8", types->double_type(), args, + reinterpret_cast(gdv_fn_castFLOAT8_utf8)); } } // namespace gandiva diff --git a/cpp/src/gandiva/gdv_function_stubs.h b/cpp/src/gandiva/gdv_function_stubs.h index 4d66aa3e987..457f42511cc 100644 --- a/cpp/src/gandiva/gdv_function_stubs.h +++ b/cpp/src/gandiva/gdv_function_stubs.h @@ -19,6 +19,8 @@ #include +#include "gandiva/visibility.h" + /// Stub functions that can be accessed from LLVM. extern "C" { @@ -52,4 +54,16 @@ int32_t gdv_fn_dec_from_string(int64_t context, const char* in, int32_t in_lengt char* gdv_fn_dec_to_string(int64_t context, int64_t x_high, uint64_t x_low, int32_t x_scale, int32_t* dec_str_len); + +GANDIVA_EXPORT +int32_t gdv_fn_castINT_utf8(int64_t context, const char* data, int32_t data_len); + +GANDIVA_EXPORT +int64_t gdv_fn_castBIGINT_utf8(int64_t context, const char* data, int32_t data_len); + +GANDIVA_EXPORT +float gdv_fn_castFLOAT4_utf8(int64_t context, const char* data, int32_t data_len); + +GANDIVA_EXPORT +double gdv_fn_castFLOAT8_utf8(int64_t context, const char* data, int32_t data_len); } diff --git a/cpp/src/gandiva/gdv_function_stubs_test.cc b/cpp/src/gandiva/gdv_function_stubs_test.cc new file mode 100644 index 00000000000..90ac1dfa540 --- /dev/null +++ b/cpp/src/gandiva/gdv_function_stubs_test.cc @@ -0,0 +1,163 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/gdv_function_stubs.h" + +#include +#include + +#include "gandiva/execution_context.h" + +namespace gandiva { + +TEST(TestGdvFnStubs, TestCastINT) { + gandiva::ExecutionContext ctx; + + int64_t ctx_ptr = reinterpret_cast(&ctx); + + EXPECT_EQ(gdv_fn_castINT_utf8(ctx_ptr, "-45", 3), -45); + EXPECT_EQ(gdv_fn_castINT_utf8(ctx_ptr, "0", 1), 0); + EXPECT_EQ(gdv_fn_castINT_utf8(ctx_ptr, "2147483647", 10), 2147483647); + EXPECT_EQ(gdv_fn_castINT_utf8(ctx_ptr, "02147483647", 11), 2147483647); + EXPECT_EQ(gdv_fn_castINT_utf8(ctx_ptr, "-2147483648", 11), -2147483648LL); + EXPECT_EQ(gdv_fn_castINT_utf8(ctx_ptr, "-02147483648", 12), -2147483648LL); + EXPECT_EQ(gdv_fn_castINT_utf8(ctx_ptr, " 12 ", 4), 12); + + gdv_fn_castINT_utf8(ctx_ptr, "2147483648", 10); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string 2147483648 to int32")); + ctx.Reset(); + + gdv_fn_castINT_utf8(ctx_ptr, "-2147483649", 11); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string -2147483649 to int32")); + ctx.Reset(); + + gdv_fn_castINT_utf8(ctx_ptr, "12.34", 5); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string 12.34 to int32")); + ctx.Reset(); + + gdv_fn_castINT_utf8(ctx_ptr, "abc", 3); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string abc to int32")); + ctx.Reset(); + + gdv_fn_castINT_utf8(ctx_ptr, "", 0); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string to int32")); + ctx.Reset(); + + gdv_fn_castINT_utf8(ctx_ptr, "-", 1); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string - to int32")); + ctx.Reset(); +} + +TEST(TestGdvFnStubs, TestCastBIGINT) { + gandiva::ExecutionContext ctx; + + int64_t ctx_ptr = reinterpret_cast(&ctx); + + EXPECT_EQ(gdv_fn_castBIGINT_utf8(ctx_ptr, "-45", 3), -45); + EXPECT_EQ(gdv_fn_castBIGINT_utf8(ctx_ptr, "0", 1), 0); + EXPECT_EQ(gdv_fn_castBIGINT_utf8(ctx_ptr, "9223372036854775807", 19), + 9223372036854775807LL); + EXPECT_EQ(gdv_fn_castBIGINT_utf8(ctx_ptr, "09223372036854775807", 20), + 9223372036854775807LL); + EXPECT_EQ(gdv_fn_castBIGINT_utf8(ctx_ptr, "-9223372036854775808", 20), + -9223372036854775807LL - 1); + EXPECT_EQ(gdv_fn_castBIGINT_utf8(ctx_ptr, "-009223372036854775808", 22), + -9223372036854775807LL - 1); + EXPECT_EQ(gdv_fn_castBIGINT_utf8(ctx_ptr, " 12 ", 4), 12); + + gdv_fn_castBIGINT_utf8(ctx_ptr, "9223372036854775808", 19); + EXPECT_THAT( + ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string 9223372036854775808 to int64")); + ctx.Reset(); + + gdv_fn_castBIGINT_utf8(ctx_ptr, "-9223372036854775809", 20); + EXPECT_THAT( + ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string -9223372036854775809 to int64")); + ctx.Reset(); + + gdv_fn_castBIGINT_utf8(ctx_ptr, "12.34", 5); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string 12.34 to int64")); + ctx.Reset(); + + gdv_fn_castBIGINT_utf8(ctx_ptr, "abc", 3); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string abc to int64")); + ctx.Reset(); + + gdv_fn_castBIGINT_utf8(ctx_ptr, "", 0); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string to int64")); + ctx.Reset(); + + gdv_fn_castBIGINT_utf8(ctx_ptr, "-", 1); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string - to int64")); + ctx.Reset(); +} + +TEST(TestGdvFnStubs, TestCastFloat4) { + gandiva::ExecutionContext ctx; + + int64_t ctx_ptr = reinterpret_cast(&ctx); + + EXPECT_EQ(gdv_fn_castFLOAT4_utf8(ctx_ptr, "-45.34", 6), -45.34f); + EXPECT_EQ(gdv_fn_castFLOAT4_utf8(ctx_ptr, "0", 1), 0.0f); + EXPECT_EQ(gdv_fn_castFLOAT4_utf8(ctx_ptr, "5", 1), 5.0f); + EXPECT_EQ(gdv_fn_castFLOAT4_utf8(ctx_ptr, " 3.4 ", 5), 3.4f); + + gdv_fn_castFLOAT4_utf8(ctx_ptr, "", 0); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string to float")); + ctx.Reset(); + + gdv_fn_castFLOAT4_utf8(ctx_ptr, "e", 1); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string e to float")); + ctx.Reset(); +} + +TEST(TestGdvFnStubs, TestCastFloat8) { + gandiva::ExecutionContext ctx; + + int64_t ctx_ptr = reinterpret_cast(&ctx); + + EXPECT_EQ(gdv_fn_castFLOAT8_utf8(ctx_ptr, "-45.34", 6), -45.34); + EXPECT_EQ(gdv_fn_castFLOAT8_utf8(ctx_ptr, "0", 1), 0.0); + EXPECT_EQ(gdv_fn_castFLOAT8_utf8(ctx_ptr, "5", 1), 5.0); + EXPECT_EQ(gdv_fn_castFLOAT8_utf8(ctx_ptr, " 3.4 ", 5), 3.4); + + gdv_fn_castFLOAT8_utf8(ctx_ptr, "", 0); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string to double")); + ctx.Reset(); + + gdv_fn_castFLOAT8_utf8(ctx_ptr, "e", 1); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string e to double")); + ctx.Reset(); +} + +} // namespace gandiva diff --git a/cpp/src/gandiva/precompiled/string_ops.cc b/cpp/src/gandiva/precompiled/string_ops.cc index 34dd011ffb3..0432d6c761c 100644 --- a/cpp/src/gandiva/precompiled/string_ops.cc +++ b/cpp/src/gandiva/precompiled/string_ops.cc @@ -23,6 +23,7 @@ extern "C" { #include #include #include + #include "./types.h" FORCE_INLINE @@ -1439,27 +1440,4 @@ const char* binary_string(gdv_int64 context, const char* text, gdv_int32 text_le return ret; } -#define CAST_NUMERIC_FROM_STRING(OUT_TYPE, ARROW_TYPE, TYPE_NAME) \ - FORCE_INLINE \ - gdv_##OUT_TYPE cast##TYPE_NAME##_utf8(int64_t context, const char* data, \ - int32_t len) { \ - gdv_##OUT_TYPE val = 0; \ - int32_t trimmed_len; \ - data = btrim_utf8(context, data, len, &trimmed_len); \ - if (!arrow::internal::ParseValue(data, trimmed_len, &val)) { \ - std::string err = "Failed to cast the string " + std::string(data, trimmed_len) + \ - " to " #OUT_TYPE; \ - gdv_fn_context_set_error_msg(context, err.c_str()); \ - } \ - return val; \ - } - -CAST_NUMERIC_FROM_STRING(int32, arrow::Int32Type, INT) -CAST_NUMERIC_FROM_STRING(int64, arrow::Int64Type, BIGINT) -CAST_NUMERIC_FROM_STRING(float32, arrow::FloatType, FLOAT4) -CAST_NUMERIC_FROM_STRING(float64, arrow::DoubleType, FLOAT8) - -#undef CAST_INT_FROM_STRING -#undef CAST_FLOAT_FROM_STRING - } // extern "C" diff --git a/cpp/src/gandiva/precompiled/string_ops_test.cc b/cpp/src/gandiva/precompiled/string_ops_test.cc index 9bb44af9a1b..b1836d877ab 100644 --- a/cpp/src/gandiva/precompiled/string_ops_test.cc +++ b/cpp/src/gandiva/precompiled/string_ops_test.cc @@ -17,6 +17,7 @@ #include #include + #include "gandiva/execution_context.h" #include "gandiva/precompiled/types.h" @@ -1002,138 +1003,4 @@ TEST(TestStringOps, TestSplitPart) { EXPECT_EQ(std::string(out_str, out_len), "ååçåå"); } -TEST(TestArithmeticOps, TestCastINT) { - gandiva::ExecutionContext ctx; - - int64_t ctx_ptr = reinterpret_cast(&ctx); - - EXPECT_EQ(castINT_utf8(ctx_ptr, "-45", 3), -45); - EXPECT_EQ(castINT_utf8(ctx_ptr, "0", 1), 0); - EXPECT_EQ(castINT_utf8(ctx_ptr, "2147483647", 10), 2147483647); - EXPECT_EQ(castINT_utf8(ctx_ptr, "02147483647", 11), 2147483647); - EXPECT_EQ(castINT_utf8(ctx_ptr, "-2147483648", 11), -2147483648LL); - EXPECT_EQ(castINT_utf8(ctx_ptr, "-02147483648", 12), -2147483648LL); - EXPECT_EQ(castINT_utf8(ctx_ptr, " 12 ", 4), 12); - - castINT_utf8(ctx_ptr, "2147483648", 10); - EXPECT_THAT(ctx.get_error(), - ::testing::HasSubstr("Failed to cast the string 2147483648 to int32")); - ctx.Reset(); - - castINT_utf8(ctx_ptr, "-2147483649", 11); - EXPECT_THAT(ctx.get_error(), - ::testing::HasSubstr("Failed to cast the string -2147483649 to int32")); - ctx.Reset(); - - castINT_utf8(ctx_ptr, "12.34", 5); - EXPECT_THAT(ctx.get_error(), - ::testing::HasSubstr("Failed to cast the string 12.34 to int32")); - ctx.Reset(); - - castINT_utf8(ctx_ptr, "abc", 3); - EXPECT_THAT(ctx.get_error(), - ::testing::HasSubstr("Failed to cast the string abc to int32")); - ctx.Reset(); - - castINT_utf8(ctx_ptr, "", 0); - EXPECT_THAT(ctx.get_error(), - ::testing::HasSubstr("Failed to cast the string to int32")); - ctx.Reset(); - - castINT_utf8(ctx_ptr, "-", 1); - EXPECT_THAT(ctx.get_error(), - ::testing::HasSubstr("Failed to cast the string - to int32")); - ctx.Reset(); -} - -TEST(TestArithmeticOps, TestCastBIGINT) { - gandiva::ExecutionContext ctx; - - int64_t ctx_ptr = reinterpret_cast(&ctx); - - EXPECT_EQ(castBIGINT_utf8(ctx_ptr, "-45", 3), -45); - EXPECT_EQ(castBIGINT_utf8(ctx_ptr, "0", 1), 0); - EXPECT_EQ(castBIGINT_utf8(ctx_ptr, "9223372036854775807", 19), 9223372036854775807LL); - EXPECT_EQ(castBIGINT_utf8(ctx_ptr, "09223372036854775807", 20), 9223372036854775807LL); - EXPECT_EQ(castBIGINT_utf8(ctx_ptr, "-9223372036854775808", 20), - -9223372036854775807LL - 1); - EXPECT_EQ(castBIGINT_utf8(ctx_ptr, "-009223372036854775808", 22), - -9223372036854775807LL - 1); - EXPECT_EQ(castBIGINT_utf8(ctx_ptr, " 12 ", 4), 12); - - castBIGINT_utf8(ctx_ptr, "9223372036854775808", 19); - EXPECT_THAT( - ctx.get_error(), - ::testing::HasSubstr("Failed to cast the string 9223372036854775808 to int64")); - ctx.Reset(); - - castBIGINT_utf8(ctx_ptr, "-9223372036854775809", 20); - EXPECT_THAT( - ctx.get_error(), - ::testing::HasSubstr("Failed to cast the string -9223372036854775809 to int64")); - ctx.Reset(); - - castBIGINT_utf8(ctx_ptr, "12.34", 5); - EXPECT_THAT(ctx.get_error(), - ::testing::HasSubstr("Failed to cast the string 12.34 to int64")); - ctx.Reset(); - - castBIGINT_utf8(ctx_ptr, "abc", 3); - EXPECT_THAT(ctx.get_error(), - ::testing::HasSubstr("Failed to cast the string abc to int64")); - ctx.Reset(); - - castBIGINT_utf8(ctx_ptr, "", 0); - EXPECT_THAT(ctx.get_error(), - ::testing::HasSubstr("Failed to cast the string to int64")); - ctx.Reset(); - - castBIGINT_utf8(ctx_ptr, "-", 1); - EXPECT_THAT(ctx.get_error(), - ::testing::HasSubstr("Failed to cast the string - to int64")); - ctx.Reset(); -} - -TEST(TestArithmeticOps, TestCastFloat4) { - gandiva::ExecutionContext ctx; - - int64_t ctx_ptr = reinterpret_cast(&ctx); - - EXPECT_EQ(castFLOAT4_utf8(ctx_ptr, "-45.34", 6), -45.34f); - EXPECT_EQ(castFLOAT4_utf8(ctx_ptr, "0", 1), 0.0f); - EXPECT_EQ(castFLOAT4_utf8(ctx_ptr, "5", 1), 5.0f); - EXPECT_EQ(castFLOAT4_utf8(ctx_ptr, " 3.4 ", 5), 3.4f); - - castFLOAT4_utf8(ctx_ptr, "", 0); - EXPECT_THAT(ctx.get_error(), - ::testing::HasSubstr("Failed to cast the string to float32")); - ctx.Reset(); - - castFLOAT4_utf8(ctx_ptr, "e", 1); - EXPECT_THAT(ctx.get_error(), - ::testing::HasSubstr("Failed to cast the string e to float32")); - ctx.Reset(); -} - -TEST(TestParseStringHolder, TestCastFloat8) { - gandiva::ExecutionContext ctx; - - int64_t ctx_ptr = reinterpret_cast(&ctx); - - EXPECT_EQ(castFLOAT8_utf8(ctx_ptr, "-45.34", 6), -45.34); - EXPECT_EQ(castFLOAT8_utf8(ctx_ptr, "0", 1), 0.0); - EXPECT_EQ(castFLOAT8_utf8(ctx_ptr, "5", 1), 5.0); - EXPECT_EQ(castFLOAT8_utf8(ctx_ptr, " 3.4 ", 5), 3.4); - - castFLOAT8_utf8(ctx_ptr, "", 0); - EXPECT_THAT(ctx.get_error(), - ::testing::HasSubstr("Failed to cast the string to float64")); - ctx.Reset(); - - castFLOAT8_utf8(ctx_ptr, "e", 1); - EXPECT_THAT(ctx.get_error(), - ::testing::HasSubstr("Failed to cast the string e to float64")); - ctx.Reset(); -} - } // namespace gandiva diff --git a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java index 753cdf6a10a..85ac83b42da 100644 --- a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java +++ b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java @@ -1741,4 +1741,191 @@ public void testCaseInsensitiveFunctions() throws Exception { releaseValueVectors(output); } + @Test + public void testCastInt() throws Exception { + Field inField = Field.nullable("input", new ArrowType.Utf8()); + TreeNode inNode = TreeBuilder.makeField(inField); + TreeNode castINTFn = TreeBuilder.makeFunction("castINT", Lists.newArrayList(inNode), + int32); + Field resultField = Field.nullable("result", int32); + List exprs = + Lists.newArrayList( + TreeBuilder.makeExpression(castINTFn, resultField)); + Schema schema = new Schema(Lists.newArrayList(inField)); + Projector eval = Projector.make(schema, exprs); + int numRows = 5; + byte[] validity = new byte[] {(byte) 255}; + String[] values = + new String[] { + "0", "123", "-123", "-1", "1" + }; + int[] expValues = + new int[] { + 0, 123, -123, -1, 1 + }; + ArrowBuf bufValidity = buf(validity); + List bufData = stringBufs(values); + ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0); + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, + Lists.newArrayList(fieldNode), + Lists.newArrayList(bufValidity, bufData.get(0), bufData.get(1))); + List output = new ArrayList<>(); + for (int i = 0; i < exprs.size(); i++) { + IntVector intVector = new IntVector(EMPTY_SCHEMA_PATH, allocator); + intVector.allocateNew(numRows); + output.add(intVector); + } + eval.evaluate(batch, output); + eval.close(); + for (ValueVector valueVector : output) { + IntVector intVector = (IntVector) valueVector; + for (int j = 0; j < numRows; j++) { + assertFalse(intVector.isNull(j)); + assertTrue(expValues[j] == intVector.get(j)); + } + } + releaseRecordBatch(batch); + releaseValueVectors(output); + } + + @Test(expected = GandivaException.class) + public void testCastIntInvalidValue() throws Exception { + Field inField = Field.nullable("input", new ArrowType.Utf8()); + TreeNode inNode = TreeBuilder.makeField(inField); + TreeNode castINTFn = TreeBuilder.makeFunction("castINT", Lists.newArrayList(inNode), + int32); + Field resultField = Field.nullable("result", int32); + List exprs = + Lists.newArrayList( + TreeBuilder.makeExpression(castINTFn, resultField)); + Schema schema = new Schema(Lists.newArrayList(inField)); + Projector eval = Projector.make(schema, exprs); + int numRows = 1; + byte[] validity = new byte[] {(byte) 255}; + String[] values = + new String[] { + "abc" + }; + ArrowBuf bufValidity = buf(validity); + List bufData = stringBufs(values); + ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0); + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, + Lists.newArrayList(fieldNode), + Lists.newArrayList(bufValidity, bufData.get(0), bufData.get(1))); + List output = new ArrayList<>(); + for (int i = 0; i < exprs.size(); i++) { + IntVector intVector = new IntVector(EMPTY_SCHEMA_PATH, allocator); + intVector.allocateNew(numRows); + output.add(intVector); + } + try { + eval.evaluate(batch, output); + } finally { + eval.close(); + releaseRecordBatch(batch); + releaseValueVectors(output); + } + } + + @Test + public void testCastFloat() throws Exception { + Field inField = Field.nullable("input", new ArrowType.Utf8()); + TreeNode inNode = TreeBuilder.makeField(inField); + TreeNode castFLOAT8Fn = TreeBuilder.makeFunction("castFLOAT8", Lists.newArrayList(inNode), + float64); + Field resultField = Field.nullable("result", float64); + List exprs = + Lists.newArrayList( + TreeBuilder.makeExpression(castFLOAT8Fn, resultField)); + Schema schema = new Schema(Lists.newArrayList(inField)); + Projector eval = Projector.make(schema, exprs); + int numRows = 5; + byte[] validity = new byte[] {(byte) 255}; + String[] values = + new String[] { + "2.3", + "-11.11", + "0", + "111", + "12345.67" + }; + double[] expValues = + new double[] { + 2.3, -11.11, 0, 111, 12345.67 + }; + ArrowBuf bufValidity = buf(validity); + List bufData = stringBufs(values); + ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0); + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, + Lists.newArrayList(fieldNode), + Lists.newArrayList(bufValidity, bufData.get(0), bufData.get(1))); + List output = new ArrayList<>(); + for (int i = 0; i < exprs.size(); i++) { + Float8Vector float8Vector = new Float8Vector(EMPTY_SCHEMA_PATH, allocator); + float8Vector.allocateNew(numRows); + output.add(float8Vector); + } + eval.evaluate(batch, output); + eval.close(); + for (ValueVector valueVector : output) { + Float8Vector float8Vector = (Float8Vector) valueVector; + for (int j = 0; j < numRows; j++) { + assertFalse(float8Vector.isNull(j)); + assertTrue(expValues[j] == float8Vector.get(j)); + } + } + releaseRecordBatch(batch); + releaseValueVectors(output); + } + + @Test(expected = GandivaException.class) + public void testCastFloatInvalidValue() throws Exception { + Field inField = Field.nullable("input", new ArrowType.Utf8()); + TreeNode inNode = TreeBuilder.makeField(inField); + TreeNode castFLOAT8Fn = TreeBuilder.makeFunction("castFLOAT8", Lists.newArrayList(inNode), + float64); + Field resultField = Field.nullable("result", float64); + List exprs = + Lists.newArrayList( + TreeBuilder.makeExpression(castFLOAT8Fn, resultField)); + Schema schema = new Schema(Lists.newArrayList(inField)); + Projector eval = Projector.make(schema, exprs); + int numRows = 5; + byte[] validity = new byte[] {(byte) 255}; + String[] values = + new String[] { + "2.3", + "-11.11", + "abc", + "111", + "12345.67" + }; + ArrowBuf bufValidity = buf(validity); + List bufData = stringBufs(values); + ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0); + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, + Lists.newArrayList(fieldNode), + Lists.newArrayList(bufValidity, bufData.get(0), bufData.get(1))); + List output = new ArrayList<>(); + for (int i = 0; i < exprs.size(); i++) { + Float8Vector float8Vector = new Float8Vector(EMPTY_SCHEMA_PATH, allocator); + float8Vector.allocateNew(numRows); + output.add(float8Vector); + } + try { + eval.evaluate(batch, output); + } finally { + eval.close(); + releaseRecordBatch(batch); + releaseValueVectors(output); + } + } }