From ee60c0098e0999f5d01b383fd6fc238fb80a9b3c Mon Sep 17 00:00:00 2001 From: Pindikura Ravindra Date: Wed, 13 Feb 2019 17:59:14 +0530 Subject: [PATCH] ARROW-4204: [Gandiva] add support for decimal subtract --- cpp/src/gandiva/decimal_ir.cc | 51 +++++++++++++++- cpp/src/gandiva/decimal_ir.h | 3 + .../gandiva/function_registry_arithmetic.cc | 1 + cpp/src/gandiva/precompiled/decimal_ops.cc | 5 ++ cpp/src/gandiva/precompiled/decimal_ops.h | 5 ++ .../gandiva/precompiled/decimal_ops_test.cc | 59 +++++++++++++++++-- cpp/src/gandiva/tests/decimal_single_test.cc | 54 +++++++++++++---- 7 files changed, 160 insertions(+), 18 deletions(-) diff --git a/cpp/src/gandiva/decimal_ir.cc b/cpp/src/gandiva/decimal_ir.cc index d10158a6f04..f51f51262dc 100644 --- a/cpp/src/gandiva/decimal_ir.cc +++ b/cpp/src/gandiva/decimal_ir.cc @@ -307,6 +307,52 @@ Status DecimalIR::BuildAdd() { return Status::OK(); } +Status DecimalIR::BuildSubtract() { + // Create fn prototype : + // int128_t + // subtract_decimal128_decimal128(int128_t x_value, int32_t x_precision, int32_t + // x_scale, + // int128_t y_value, int32_t y_precision, int32_t y_scale + // int32_t out_precision, int32_t out_scale) + auto i32 = types()->i32_type(); + auto i128 = types()->i128_type(); + auto function = BuildFunction("subtract_decimal128_decimal128", i128, + { + {"x_value", i128}, + {"x_precision", i32}, + {"x_scale", i32}, + {"y_value", i128}, + {"y_precision", i32}, + {"y_scale", i32}, + {"out_precision", i32}, + {"out_scale", i32}, + }); + + auto entry = llvm::BasicBlock::Create(*context(), "entry", function); + ir_builder()->SetInsertPoint(entry); + + // reuse add function after negating y_value. i.e + // add(x_value, x_precision, x_scale, -y_value, y_precision, y_scale, + // out_precision, out_scale) + std::vector args; + int i = 0; + for (auto& in_arg : function->args()) { + if (i == 3) { + auto y_neg_value = ir_builder()->CreateNeg(&in_arg); + args.push_back(y_neg_value); + } else { + args.push_back(&in_arg); + } + ++i; + } + auto value = + ir_builder()->CreateCall(module()->getFunction("add_decimal128_decimal128"), args); + + // store result to out + ir_builder()->CreateRet(value); + return Status::OK(); +} + Status DecimalIR::AddFunctions(Engine* engine) { auto decimal_ir = std::make_shared(engine); @@ -317,7 +363,10 @@ Status DecimalIR::AddFunctions(Engine* engine) { decimal_ir->InitializeIntrinsics(); // build "add" - return decimal_ir->BuildAdd(); + ARROW_RETURN_NOT_OK(decimal_ir->BuildAdd()); + + // build "subtract" + return decimal_ir->BuildSubtract(); } // Do an bitwise-or of all the overflow bits. diff --git a/cpp/src/gandiva/decimal_ir.h b/cpp/src/gandiva/decimal_ir.h index fae762c362d..fb9fe704f6e 100644 --- a/cpp/src/gandiva/decimal_ir.h +++ b/cpp/src/gandiva/decimal_ir.h @@ -143,6 +143,9 @@ class DecimalIR : public FunctionIRBuilder { // Build the function for adding decimals. Status BuildAdd(); + // Build the function for decimal subtraction. + Status BuildSubtract(); + // Add a trace in IR code. void AddTrace(const std::string& fmt, std::vector args); diff --git a/cpp/src/gandiva/function_registry_arithmetic.cc b/cpp/src/gandiva/function_registry_arithmetic.cc index c5a798cb4e2..0a2ac93dd80 100644 --- a/cpp/src/gandiva/function_registry_arithmetic.cc +++ b/cpp/src/gandiva/function_registry_arithmetic.cc @@ -58,6 +58,7 @@ std::vector GetArithmeticFunctionRegistry() { BINARY_GENERIC_SAFE_NULL_IF_NULL(mod, int64, int64, int64), BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(add, decimal128), + BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(subtract, decimal128), BINARY_RELATIONAL_BOOL_FN(equal), BINARY_RELATIONAL_BOOL_FN(not_equal), diff --git a/cpp/src/gandiva/precompiled/decimal_ops.cc b/cpp/src/gandiva/precompiled/decimal_ops.cc index 99231fe537f..887f42df13d 100644 --- a/cpp/src/gandiva/precompiled/decimal_ops.cc +++ b/cpp/src/gandiva/precompiled/decimal_ops.cc @@ -221,5 +221,10 @@ BasicDecimal128 Add(const BasicDecimalScalar128& x, const BasicDecimalScalar128& } } +BasicDecimal128 Subtract(const BasicDecimalScalar128& x, const BasicDecimalScalar128& y, + int32_t out_precision, int32_t out_scale) { + return Add(x, {-y.value(), y.precision(), y.scale()}, out_precision, out_scale); +} + } // namespace decimalops } // namespace gandiva diff --git a/cpp/src/gandiva/precompiled/decimal_ops.h b/cpp/src/gandiva/precompiled/decimal_ops.h index 1e202b88a25..5a6c94b9d43 100644 --- a/cpp/src/gandiva/precompiled/decimal_ops.h +++ b/cpp/src/gandiva/precompiled/decimal_ops.h @@ -30,5 +30,10 @@ namespace decimalops { arrow::BasicDecimal128 Add(const BasicDecimalScalar128& x, const BasicDecimalScalar128& y, int32_t out_precision, int32_t out_scale); +/// Subtract 'y' from 'x', and return the result. +arrow::BasicDecimal128 Subtract(const BasicDecimalScalar128& x, + const BasicDecimalScalar128& y, int32_t out_precision, + int32_t out_scale); + } // namespace decimalops } // namespace gandiva diff --git a/cpp/src/gandiva/precompiled/decimal_ops_test.cc b/cpp/src/gandiva/precompiled/decimal_ops_test.cc index e16f2021f2e..ef2c4023caa 100644 --- a/cpp/src/gandiva/precompiled/decimal_ops_test.cc +++ b/cpp/src/gandiva/precompiled/decimal_ops_test.cc @@ -29,8 +29,18 @@ namespace gandiva { class TestDecimalSql : public ::testing::Test { protected: - static void AddAndVerify(const DecimalScalar128& x, const DecimalScalar128& y, - const DecimalScalar128& expected); + static void Verify(DecimalTypeUtil::Op op, const DecimalScalar128& x, + const DecimalScalar128& y, const DecimalScalar128& expected); + + void AddAndVerify(const DecimalScalar128& x, const DecimalScalar128& y, + const DecimalScalar128& expected) { + return Verify(DecimalTypeUtil::kOpAdd, x, y, expected); + } + + void SubtractAndVerify(const DecimalScalar128& x, const DecimalScalar128& y, + const DecimalScalar128& expected) { + return Verify(DecimalTypeUtil::kOpSubtract, x, y, expected); + } }; #define EXPECT_DECIMAL_EQ(x, y, expected, actual) \ @@ -38,15 +48,28 @@ class TestDecimalSql : public ::testing::Test { << " expected : " << expected.ToString() << " actual " \ << actual.ToString() -void TestDecimalSql::AddAndVerify(const DecimalScalar128& x, const DecimalScalar128& y, - const DecimalScalar128& expected) { +void TestDecimalSql::Verify(DecimalTypeUtil::Op op, const DecimalScalar128& x, + const DecimalScalar128& y, const DecimalScalar128& expected) { auto t1 = std::make_shared(x.precision(), x.scale()); auto t2 = std::make_shared(y.precision(), y.scale()); Decimal128TypePtr out_type; - EXPECT_OK(DecimalTypeUtil::GetResultType(DecimalTypeUtil::kOpAdd, {t1, t2}, &out_type)); + EXPECT_OK(DecimalTypeUtil::GetResultType(op, {t1, t2}, &out_type)); + + arrow::BasicDecimal128 out_value; + switch (op) { + case DecimalTypeUtil::kOpAdd: + out_value = decimalops::Add(x, y, out_type->precision(), out_type->scale()); + break; + + case DecimalTypeUtil::kOpSubtract: + out_value = decimalops::Subtract(x, y, out_type->precision(), out_type->scale()); + break; - auto out_value = decimalops::Add(x, y, out_type->precision(), out_type->scale()); + default: + // not implemented. + ASSERT_FALSE(true); + } EXPECT_DECIMAL_EQ( x, y, expected, DecimalScalar128(out_value, out_type->precision(), out_type->scale())); @@ -74,4 +97,28 @@ TEST_F(TestDecimalSql, Add) { DecimalScalar128{"-99999999999999999999999999999990000010", 38, 6}); } +TEST_F(TestDecimalSql, Subtract) { + // fast-path + SubtractAndVerify(DecimalScalar128{"201", 30, 3}, // x + DecimalScalar128{"301", 30, 3}, // y + DecimalScalar128{"-100", 31, 3}); // expected + + // max precision + SubtractAndVerify( + DecimalScalar128{"09999999999999999999999999999999000000", 38, 5}, // x + DecimalScalar128{"100", 38, 7}, // y + DecimalScalar128{"99999999999999999999999999999989999990", 38, 6}); + + // Both -ve + SubtractAndVerify(DecimalScalar128{"-201", 30, 3}, // x + DecimalScalar128{"-301", 30, 2}, // y + DecimalScalar128{"2809", 32, 3}); // expected + + // -ve and max precision + SubtractAndVerify( + DecimalScalar128{"-09999999999999999999999999999999000000", 38, 5}, // x + DecimalScalar128{"-100", 38, 7}, // y + DecimalScalar128{"-99999999999999999999999999999989999990", 38, 6}); +} + } // namespace gandiva diff --git a/cpp/src/gandiva/tests/decimal_single_test.cc b/cpp/src/gandiva/tests/decimal_single_test.cc index 776ef6efbd0..a83137f678c 100644 --- a/cpp/src/gandiva/tests/decimal_single_test.cc +++ b/cpp/src/gandiva/tests/decimal_single_test.cc @@ -31,9 +31,10 @@ using arrow::Decimal128; namespace gandiva { -#define EXPECT_DECIMAL_SUM_EQUALS(x, y, expected, actual) \ - EXPECT_EQ(expected, actual) << (x).ToString() << " + " << (y).ToString() \ - << " expected : " << (expected).ToString() \ +#define EXPECT_DECIMAL_RESULT(op, x, y, expected, actual) \ + EXPECT_EQ(expected, actual) << op << " (" << (x).ToString() << "),(" << (y).ToString() \ + << ")" \ + << " expected : " << (expected).ToString() \ << " actual : " << (actual).ToString(); DecimalScalar128 decimal_literal(const char* value, int precision, int scale) { @@ -46,8 +47,19 @@ class TestDecimalOps : public ::testing::Test { void SetUp() { pool_ = arrow::default_memory_pool(); } ArrayPtr MakeDecimalVector(const DecimalScalar128& in); + + void Verify(DecimalTypeUtil::Op, const std::string& function, const DecimalScalar128& x, + const DecimalScalar128& y, const DecimalScalar128& expected); + void AddAndVerify(const DecimalScalar128& x, const DecimalScalar128& y, - const DecimalScalar128& expected); + const DecimalScalar128& expected) { + Verify(DecimalTypeUtil::kOpAdd, "add", x, y, expected); + } + + void SubtractAndVerify(const DecimalScalar128& x, const DecimalScalar128& y, + const DecimalScalar128& expected) { + Verify(DecimalTypeUtil::kOpSubtract, "subtract", x, y, expected); + } protected: arrow::MemoryPool* pool_; @@ -62,8 +74,9 @@ ArrayPtr TestDecimalOps::MakeDecimalVector(const DecimalScalar128& in) { return MakeArrowArrayDecimal(decimal_type, {decimal_value}, {true}); } -void TestDecimalOps::AddAndVerify(const DecimalScalar128& x, const DecimalScalar128& y, - const DecimalScalar128& expected) { +void TestDecimalOps::Verify(DecimalTypeUtil::Op op, const std::string& function, + const DecimalScalar128& x, const DecimalScalar128& y, + const DecimalScalar128& expected) { auto x_type = std::make_shared(x.precision(), x.scale()); auto y_type = std::make_shared(y.precision(), y.scale()); auto field_x = field("x", x_type); @@ -71,15 +84,14 @@ void TestDecimalOps::AddAndVerify(const DecimalScalar128& x, const DecimalScalar auto schema = arrow::schema({field_x, field_y}); Decimal128TypePtr output_type; - auto status = DecimalTypeUtil::GetResultType(DecimalTypeUtil::kOpAdd, {x_type, y_type}, - &output_type); + auto status = DecimalTypeUtil::GetResultType(op, {x_type, y_type}, &output_type); EXPECT_OK(status); // output fields auto res = field("res", output_type); - // build expression : x + y - auto expr = TreeExprBuilder::MakeExpression("add", {field_x, field_y}, res); + // build expression : x op y + auto expr = TreeExprBuilder::MakeExpression(function, {field_x, field_y}, res); // Build a projector for the expression. std::shared_ptr projector; @@ -106,7 +118,7 @@ void TestDecimalOps::AddAndVerify(const DecimalScalar128& x, const DecimalScalar std::string value_string = out_value.ToString(0); DecimalScalar128 actual{value_string, dtype->precision(), dtype->scale()}; - EXPECT_DECIMAL_SUM_EQUALS(x, y, expected, actual); + EXPECT_DECIMAL_RESULT(function, x, y, expected, actual); } TEST_F(TestDecimalOps, TestAdd) { @@ -221,4 +233,24 @@ TEST_F(TestDecimalOps, TestAdd) { decimal_literal("-10000992", 38, 7), // y decimal_literal("-2001098", 38, 6)); } + +// subtract is a wrapper over add. so, minimal tests are sufficient. +TEST_F(TestDecimalOps, TestSubtract) { + // fast-path + SubtractAndVerify(decimal_literal("201", 30, 3), // x + decimal_literal("301", 30, 3), // y + decimal_literal("-100", 31, 3)); // expected + + // max precision + SubtractAndVerify( + decimal_literal("09999999999999999999999999999999000000", 38, 5), // x + decimal_literal("100", 38, 7), // y + decimal_literal("99999999999999999999999999999989999990", 38, 6)); + + // Mix of +ve and -ve + SubtractAndVerify(decimal_literal("-201", 30, 3), // x + decimal_literal("301", 30, 2), // y + decimal_literal("-3211", 32, 3)); // expected +} + } // namespace gandiva