diff --git a/c_glib/test/test-decimal128.rb b/c_glib/test/test-decimal128.rb index 0e4bc8264d5..98789d3812e 100644 --- a/c_glib/test/test-decimal128.rb +++ b/c_glib/test/test-decimal128.rb @@ -214,7 +214,7 @@ def test_rescale_fail decimal = Arrow::Decimal128.new(10) message = "[decimal128][rescale]: Invalid: " + - "Rescaling decimal value would cause data loss" + "Rescaling Decimal128 value would cause data loss" assert_raise(Arrow::Error::Invalid.new(message)) do decimal.rescale(1, -1) end diff --git a/cpp/src/arrow/array/array_base.cc b/cpp/src/arrow/array/array_base.cc index 900e8d2b38f..b2524afe4f8 100644 --- a/cpp/src/arrow/array/array_base.cc +++ b/cpp/src/arrow/array/array_base.cc @@ -73,6 +73,10 @@ struct ScalarFromArraySlotImpl { return Finish(Decimal128(a.GetValue(index_))); } + Status Visit(const Decimal256Array& a) { + return Finish(Decimal256(a.GetValue(index_))); + } + template Status Visit(const BaseBinaryArray& a) { return Finish(a.GetString(index_)); diff --git a/cpp/src/arrow/array/array_decimal.cc b/cpp/src/arrow/array/array_decimal.cc index 1e813f2e515..d65f6ee5356 100644 --- a/cpp/src/arrow/array/array_decimal.cc +++ b/cpp/src/arrow/array/array_decimal.cc @@ -33,11 +33,11 @@ namespace arrow { using internal::checked_cast; // ---------------------------------------------------------------------- -// Decimal +// Decimal128 Decimal128Array::Decimal128Array(const std::shared_ptr& data) : FixedSizeBinaryArray(data) { - ARROW_CHECK_EQ(data->type->id(), Type::DECIMAL); + ARROW_CHECK_EQ(data->type->id(), Type::DECIMAL128); } std::string Decimal128Array::FormatValue(int64_t i) const { @@ -46,4 +46,18 @@ std::string Decimal128Array::FormatValue(int64_t i) const { return value.ToString(type_.scale()); } +// ---------------------------------------------------------------------- +// Decimal256 + +Decimal256Array::Decimal256Array(const std::shared_ptr& data) + : FixedSizeBinaryArray(data) { + ARROW_CHECK_EQ(data->type->id(), Type::DECIMAL256); +} + +std::string Decimal256Array::FormatValue(int64_t i) const { + const auto& type_ = checked_cast(*type()); + const Decimal256 value(GetValue(i)); + return value.ToString(type_.scale()); +} + } // namespace arrow diff --git a/cpp/src/arrow/array/array_decimal.h b/cpp/src/arrow/array/array_decimal.h index 6d5e884118b..8d7d1c59cd0 100644 --- a/cpp/src/arrow/array/array_decimal.h +++ b/cpp/src/arrow/array/array_decimal.h @@ -47,4 +47,20 @@ class ARROW_EXPORT Decimal128Array : public FixedSizeBinaryArray { // Backward compatibility using DecimalArray = Decimal128Array; +// ---------------------------------------------------------------------- +// Decimal256Array + +/// Concrete Array class for 256-bit decimal data +class ARROW_EXPORT Decimal256Array : public FixedSizeBinaryArray { + public: + using TypeClass = Decimal256Type; + + using FixedSizeBinaryArray::FixedSizeBinaryArray; + + /// \brief Construct Decimal256Array from ArrayData instance + explicit Decimal256Array(const std::shared_ptr& data); + + std::string FormatValue(int64_t i) const; +}; + } // namespace arrow diff --git a/cpp/src/arrow/array/array_dict_test.cc b/cpp/src/arrow/array/array_dict_test.cc index d029838c4fc..7bf51fa8931 100644 --- a/cpp/src/arrow/array/array_dict_test.cc +++ b/cpp/src/arrow/array/array_dict_test.cc @@ -835,13 +835,13 @@ TEST(TestFixedSizeBinaryDictionaryBuilder, AppendArrayInvalidType) { } #endif -TEST(TestDecimalDictionaryBuilder, Basic) { +template +void TestDecimalDictionaryBuilderBasic(std::shared_ptr decimal_type) { // Build the dictionary Array - auto decimal_type = arrow::decimal(2, 0); DictionaryBuilder builder(decimal_type); // Test data - std::vector test{12, 12, 11, 12}; + std::vector test{12, 12, 11, 12}; for (const auto& value : test) { ASSERT_OK(builder.Append(value.ToBytes().data())); } @@ -857,40 +857,48 @@ TEST(TestDecimalDictionaryBuilder, Basic) { ASSERT_TRUE(expected.Equals(result)); } -TEST(TestDecimalDictionaryBuilder, DoubleTableSize) { - const auto& decimal_type = arrow::decimal(21, 0); +TEST(TestDecimal128DictionaryBuilder, Basic) { + TestDecimalDictionaryBuilderBasic(arrow::decimal128(2, 0)); +} + +TEST(TestDecimal256DictionaryBuilder, Basic) { + TestDecimalDictionaryBuilderBasic(arrow::decimal256(76, 0)); +} +void TestDecimalDictionaryBuilderDoubleTableSize( + std::shared_ptr decimal_type, FixedSizeBinaryBuilder& decimal_builder) { // Build the dictionary Array DictionaryBuilder dict_builder(decimal_type); // Build expected data - Decimal128Builder decimal_builder(decimal_type); Int16Builder int_builder; // Fill with 1024 different values for (int64_t i = 0; i < 1024; i++) { - const uint8_t bytes[] = {0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 12, - 12, - static_cast(i / 128), - static_cast(i % 128)}; + // Decimal256Builder takes 32 bytes, while Decimal128Builder takes only the first 16 + // bytes. + const uint8_t bytes[32] = {0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 12, + 12, + static_cast(i / 128), + static_cast(i % 128)}; ASSERT_OK(dict_builder.Append(bytes)); ASSERT_OK(decimal_builder.Append(bytes)); ASSERT_OK(int_builder.Append(static_cast(i))); } // Fill with an already existing value - const uint8_t known_value[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 12, 12, 0, 1}; + const uint8_t known_value[32] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 12, 12, 0, 1}; for (int64_t i = 0; i < 1024; i++) { ASSERT_OK(dict_builder.Append(known_value)); ASSERT_OK(int_builder.Append(1)); @@ -911,6 +919,18 @@ TEST(TestDecimalDictionaryBuilder, DoubleTableSize) { ASSERT_TRUE(expected.Equals(result)); } +TEST(TestDecimal128DictionaryBuilder, DoubleTableSize) { + const auto& decimal_type = arrow::decimal128(21, 0); + Decimal128Builder decimal_builder(decimal_type); + TestDecimalDictionaryBuilderDoubleTableSize(decimal_type, decimal_builder); +} + +TEST(TestDecimal256DictionaryBuilder, DoubleTableSize) { + const auto& decimal_type = arrow::decimal256(21, 0); + Decimal256Builder decimal_builder(decimal_type); + TestDecimalDictionaryBuilderDoubleTableSize(decimal_type, decimal_builder); +} + TEST(TestNullDictionaryBuilder, Basic) { // MakeBuilder auto dict_type = dictionary(int8(), null()); diff --git a/cpp/src/arrow/array/array_test.cc b/cpp/src/arrow/array/array_test.cc index c256d26d88a..9d4f00d9435 100644 --- a/cpp/src/arrow/array/array_test.cc +++ b/cpp/src/arrow/array/array_test.cc @@ -429,6 +429,7 @@ TEST_F(TestArray, TestMakeArrayFromScalar) { std::make_shared( hello, fixed_size_binary(static_cast(hello->size()))), std::make_shared(Decimal128(10), decimal(16, 4)), + std::make_shared(Decimal256(10), decimal(76, 38)), std::make_shared(hello), std::make_shared(hello), std::make_shared(ArrayFromJSON(int8(), "[1, 2, 3]")), @@ -2390,10 +2391,14 @@ TEST(TestAdaptiveUIntBuilderWithStartIntSize, TestReset) { // ---------------------------------------------------------------------- // Test Decimal arrays -using DecimalVector = std::vector; - +template class DecimalTest : public ::testing::TestWithParam { public: + using DecimalBuilder = typename TypeTraits::BuilderType; + using DecimalValue = typename TypeTraits::ScalarType::ValueType; + using DecimalArray = typename TypeTraits::ArrayType; + using DecimalVector = std::vector; + DecimalTest() {} template @@ -2409,8 +2414,8 @@ class DecimalTest : public ::testing::TestWithParam { template void TestCreate(int32_t precision, const DecimalVector& draw, const std::vector& valid_bytes, int64_t offset) const { - auto type = std::make_shared(precision, 4); - auto builder = std::make_shared(type); + auto type = std::make_shared(precision, 4); + auto builder = std::make_shared(type); size_t null_count = 0; @@ -2441,7 +2446,7 @@ class DecimalTest : public ::testing::TestWithParam { ASSERT_OK_AND_ASSIGN(expected_null_bitmap, internal::BytesToBits(valid_bytes)); int64_t expected_null_count = CountNulls(valid_bytes); - auto expected = std::make_shared( + auto expected = std::make_shared( type, size, expected_data, expected_null_bitmap, expected_null_count); std::shared_ptr lhs = out->Slice(offset); @@ -2450,7 +2455,9 @@ class DecimalTest : public ::testing::TestWithParam { } }; -TEST_P(DecimalTest, NoNulls) { +using Decimal128Test = DecimalTest; + +TEST_P(Decimal128Test, NoNulls) { int32_t precision = GetParam(); std::vector draw = {Decimal128(1), Decimal128(-2), Decimal128(2389), Decimal128(4), Decimal128(-12348)}; @@ -2459,7 +2466,7 @@ TEST_P(DecimalTest, NoNulls) { this->TestCreate(precision, draw, valid_bytes, 2); } -TEST_P(DecimalTest, WithNulls) { +TEST_P(Decimal128Test, WithNulls) { int32_t precision = GetParam(); std::vector draw = {Decimal128(1), Decimal128(2), Decimal128(-1), Decimal128(4), Decimal128(-1), Decimal128(1), @@ -2478,7 +2485,44 @@ TEST_P(DecimalTest, WithNulls) { this->TestCreate(precision, draw, valid_bytes, 2); } -INSTANTIATE_TEST_SUITE_P(DecimalTest, DecimalTest, ::testing::Range(1, 38)); +INSTANTIATE_TEST_SUITE_P(Decimal128Test, Decimal128Test, ::testing::Range(1, 38)); + +using Decimal256Test = DecimalTest; + +TEST_P(Decimal256Test, NoNulls) { + int32_t precision = GetParam(); + std::vector draw = {Decimal256(1), Decimal256(-2), Decimal256(2389), + Decimal256(4), Decimal256(-12348)}; + std::vector valid_bytes = {true, true, true, true, true}; + this->TestCreate(precision, draw, valid_bytes, 0); + this->TestCreate(precision, draw, valid_bytes, 2); +} + +TEST_P(Decimal256Test, WithNulls) { + int32_t precision = GetParam(); + std::vector draw = {Decimal256(1), Decimal256(2), Decimal256(-1), + Decimal256(4), Decimal256(-1), Decimal256(1), + Decimal256(2)}; + Decimal256 big; // (pow(2, 255) - 1) / pow(10, 38) + ASSERT_OK_AND_ASSIGN(big, + Decimal256::FromString("578960446186580977117854925043439539266." + "34992332820282019728792003956564819967")); + draw.push_back(big); + + Decimal256 big_negative; // -pow(2, 255) / pow(10, 38) + ASSERT_OK_AND_ASSIGN(big_negative, + Decimal256::FromString("-578960446186580977117854925043439539266." + "34992332820282019728792003956564819968")); + draw.push_back(big_negative); + + std::vector valid_bytes = {true, true, false, true, false, + true, true, true, true}; + this->TestCreate(precision, draw, valid_bytes, 0); + this->TestCreate(precision, draw, valid_bytes, 2); +} + +INSTANTIATE_TEST_SUITE_P(Decimal256Test, Decimal256Test, + ::testing::Values(1, 2, 5, 10, 38, 39, 40, 75, 76)); // ---------------------------------------------------------------------- // Test rechunking diff --git a/cpp/src/arrow/array/builder_decimal.cc b/cpp/src/arrow/array/builder_decimal.cc index ea5c9ebd0c3..bd7615a7309 100644 --- a/cpp/src/arrow/array/builder_decimal.cc +++ b/cpp/src/arrow/array/builder_decimal.cc @@ -67,4 +67,39 @@ Status Decimal128Builder::FinishInternal(std::shared_ptr* out) { return Status::OK(); } +// ---------------------------------------------------------------------- +// Decimal256Builder + +Decimal256Builder::Decimal256Builder(const std::shared_ptr& type, + MemoryPool* pool) + : FixedSizeBinaryBuilder(type, pool), + decimal_type_(internal::checked_pointer_cast(type)) {} + +Status Decimal256Builder::Append(const Decimal256& value) { + RETURN_NOT_OK(FixedSizeBinaryBuilder::Reserve(1)); + UnsafeAppend(value); + return Status::OK(); +} + +void Decimal256Builder::UnsafeAppend(const Decimal256& value) { + value.ToBytes(GetMutableValue(length())); + byte_builder_.UnsafeAdvance(32); + UnsafeAppendToBitmap(true); +} + +void Decimal256Builder::UnsafeAppend(util::string_view value) { + FixedSizeBinaryBuilder::UnsafeAppend(value); +} + +Status Decimal256Builder::FinishInternal(std::shared_ptr* out) { + std::shared_ptr data; + RETURN_NOT_OK(byte_builder_.Finish(&data)); + std::shared_ptr null_bitmap; + RETURN_NOT_OK(null_bitmap_builder_.Finish(&null_bitmap)); + + *out = ArrayData::Make(type(), length_, {null_bitmap, data}, null_count_); + capacity_ = length_ = null_count_ = 0; + return Status::OK(); +} + } // namespace arrow diff --git a/cpp/src/arrow/array/builder_decimal.h b/cpp/src/arrow/array/builder_decimal.h index 8f0ff83288c..8c75e7dd674 100644 --- a/cpp/src/arrow/array/builder_decimal.h +++ b/cpp/src/arrow/array/builder_decimal.h @@ -58,6 +58,35 @@ class ARROW_EXPORT Decimal128Builder : public FixedSizeBinaryBuilder { std::shared_ptr decimal_type_; }; +class ARROW_EXPORT Decimal256Builder : public FixedSizeBinaryBuilder { + public: + using TypeClass = Decimal256Type; + + explicit Decimal256Builder(const std::shared_ptr& type, + MemoryPool* pool = default_memory_pool()); + + using FixedSizeBinaryBuilder::Append; + using FixedSizeBinaryBuilder::AppendValues; + using FixedSizeBinaryBuilder::Reset; + + Status Append(const Decimal256& val); + void UnsafeAppend(const Decimal256& val); + void UnsafeAppend(util::string_view val); + + Status FinishInternal(std::shared_ptr* out) override; + + /// \cond FALSE + using ArrayBuilder::Finish; + /// \endcond + + Status Finish(std::shared_ptr* out) { return FinishTyped(out); } + + std::shared_ptr type() const override { return decimal_type_; } + + protected: + std::shared_ptr decimal_type_; +}; + using DecimalBuilder = Decimal128Builder; } // namespace arrow diff --git a/cpp/src/arrow/array/builder_dict.cc b/cpp/src/arrow/array/builder_dict.cc index e2a758f4419..b13f6a2db34 100644 --- a/cpp/src/arrow/array/builder_dict.cc +++ b/cpp/src/arrow/array/builder_dict.cc @@ -45,7 +45,7 @@ class DictionaryMemoTable::DictionaryMemoTableImpl { template enable_if_no_memoize Visit(const T&) { - return Status::NotImplemented("Initialization of ", value_type_, + return Status::NotImplemented("Initialization of ", value_type_->ToString(), " memo table is not implemented"); } diff --git a/cpp/src/arrow/array/builder_dict.h b/cpp/src/arrow/array/builder_dict.h index d15855b0a89..40d6ce1ba9a 100644 --- a/cpp/src/arrow/array/builder_dict.h +++ b/cpp/src/arrow/array/builder_dict.h @@ -240,12 +240,20 @@ class DictionaryBuilderBase : public ArrayBuilder { /// \brief Append a decimal (only for Decimal128Type) template - enable_if_decimal Append(const Decimal128& value) { + enable_if_decimal128 Append(const Decimal128& value) { uint8_t data[16]; value.ToBytes(data); return Append(data, 16); } + /// \brief Append a decimal (only for Decimal128Type) + template + enable_if_decimal256 Append(const Decimal256& value) { + uint8_t data[32]; + value.ToBytes(data); + return Append(data, 32); + } + /// \brief Append a scalar null value Status AppendNull() final { length_ += 1; diff --git a/cpp/src/arrow/array/concatenate.cc b/cpp/src/arrow/array/concatenate.cc index dcfb3f53004..30eeeee2a2d 100644 --- a/cpp/src/arrow/array/concatenate.cc +++ b/cpp/src/arrow/array/concatenate.cc @@ -201,7 +201,7 @@ class ConcatenateImpl { } Status Visit(const FixedWidthType& fixed) { - // Handles numbers, decimal128, fixed_size_binary + // Handles numbers, decimal128, decimal256, fixed_size_binary ARROW_ASSIGN_OR_RAISE(auto buffers, Buffers(1, fixed)); return ConcatenateBuffers(buffers, pool_).Value(&out_->buffers[1]); } diff --git a/cpp/src/arrow/array/validate.cc b/cpp/src/arrow/array/validate.cc index 3063f5580cd..5bc0bf31d07 100644 --- a/cpp/src/arrow/array/validate.cc +++ b/cpp/src/arrow/array/validate.cc @@ -64,6 +64,13 @@ struct ValidateArrayVisitor { return Status::OK(); } + Status Visit(const Decimal256Array& array) { + if (array.length() > 0 && array.values() == nullptr) { + return Status::Invalid("values is null"); + } + return Status::OK(); + } + Status Visit(const StringArray& array) { return ValidateBinaryArray(array); } Status Visit(const BinaryArray& array) { return ValidateBinaryArray(array); } diff --git a/cpp/src/arrow/builder.cc b/cpp/src/arrow/builder.cc index 1dcbf7851ab..f22228a4588 100644 --- a/cpp/src/arrow/builder.cc +++ b/cpp/src/arrow/builder.cc @@ -51,6 +51,7 @@ struct DictionaryBuilderCase { } Status Visit(const FixedSizeBinaryType&) { return CreateFor(); } Status Visit(const Decimal128Type&) { return CreateFor(); } + Status Visit(const Decimal256Type&) { return CreateFor(); } Status Visit(const DataType& value_type) { return NotImplemented(value_type); } Status Visit(const HalfFloatType& value_type) { return NotImplemented(value_type); } @@ -138,6 +139,7 @@ Status MakeBuilder(MemoryPool* pool, const std::shared_ptr& type, BUILDER_CASE(LargeBinary); BUILDER_CASE(FixedSizeBinary); BUILDER_CASE(Decimal128); + BUILDER_CASE(Decimal256); case Type::DICTIONARY: { const auto& dict_type = static_cast(*type); diff --git a/cpp/src/arrow/c/bridge.cc b/cpp/src/arrow/c/bridge.cc index 5b360abc48c..5cb3e577235 100644 --- a/cpp/src/arrow/c/bridge.cc +++ b/cpp/src/arrow/c/bridge.cc @@ -304,9 +304,16 @@ struct SchemaExporter { return SetFormat("w:" + std::to_string(type.byte_width())); } - Status Visit(const Decimal128Type& type) { - return SetFormat("d:" + std::to_string(type.precision()) + "," + - std::to_string(type.scale())); + Status Visit(const DecimalType& type) { + if (type.bit_width() == 128) { + // 128 is the default bit-width + return SetFormat("d:" + std::to_string(type.precision()) + "," + + std::to_string(type.scale())); + } else { + return SetFormat("d:" + std::to_string(type.precision()) + "," + + std::to_string(type.scale()) + "," + + std::to_string(type.bit_width())); + } } Status Visit(const BinaryType& type) { return SetFormat("z"); } @@ -973,13 +980,20 @@ struct SchemaImporter { Status ProcessDecimal() { RETURN_NOT_OK(f_parser_.CheckNext(':')); ARROW_ASSIGN_OR_RAISE(auto prec_scale, f_parser_.ParseInts(f_parser_.Rest())); - if (prec_scale.size() != 2) { + // 3 elements indicates bit width was communicated as well. + if (prec_scale.size() != 2 && prec_scale.size() != 3) { return f_parser_.Invalid(); } if (prec_scale[0] <= 0 || prec_scale[1] <= 0) { return f_parser_.Invalid(); } - type_ = decimal(prec_scale[0], prec_scale[1]); + if (prec_scale.size() == 2 || prec_scale[2] == 128) { + type_ = decimal(prec_scale[0], prec_scale[1]); + } else if (prec_scale[2] == 256) { + type_ = decimal256(prec_scale[0], prec_scale[1]); + } else { + return f_parser_.Invalid(); + } return Status::OK(); } diff --git a/cpp/src/arrow/c/bridge_test.cc b/cpp/src/arrow/c/bridge_test.cc index 3f84edfc2a6..fc11f126e72 100644 --- a/cpp/src/arrow/c/bridge_test.cc +++ b/cpp/src/arrow/c/bridge_test.cc @@ -281,6 +281,7 @@ TEST_F(TestSchemaExport, Primitive) { TestPrimitive(large_utf8(), "U"); TestPrimitive(decimal(16, 4), "d:16,4"); + TestPrimitive(decimal256(16, 4), "d:16,4,256"); } TEST_F(TestSchemaExport, Temporal) { @@ -740,6 +741,7 @@ TEST_F(TestArrayExport, Primitive) { TestPrimitive(large_utf8(), R"(["foo", "bar", null])"); TestPrimitive(decimal(16, 4), R"(["1234.5670", null])"); + TestPrimitive(decimal256(16, 4), R"(["1234.5670", null])"); } TEST_F(TestArrayExport, PrimitiveSliced) { @@ -1186,6 +1188,13 @@ TEST_F(TestSchemaImport, Primitive) { CheckImport(field("", float32())); FillPrimitive("g"); CheckImport(field("", float64())); + + FillPrimitive("d:16,4"); + CheckImport(field("", decimal128(16, 4))); + FillPrimitive("d:16,4,128"); + CheckImport(field("", decimal128(16, 4))); + FillPrimitive("d:16,4,256"); + CheckImport(field("", decimal256(16, 4))); } TEST_F(TestSchemaImport, Temporal) { @@ -2373,6 +2382,8 @@ TEST_F(TestSchemaRoundtrip, Primitive) { TestWithTypeFactory(float16); TestWithTypeFactory(std::bind(decimal, 19, 4)); + TestWithTypeFactory(std::bind(decimal128, 19, 4)); + TestWithTypeFactory(std::bind(decimal256, 19, 4)); TestWithTypeFactory(std::bind(fixed_size_binary, 3)); TestWithTypeFactory(binary); TestWithTypeFactory(large_utf8); @@ -2430,7 +2441,7 @@ TEST_F(TestSchemaRoundtrip, Map) { TEST_F(TestSchemaRoundtrip, Schema) { auto f1 = field("f1", utf8(), /*nullable=*/false); - auto f2 = field("f2", list(decimal(19, 4))); + auto f2 = field("f2", list(decimal256(19, 4))); auto md1 = key_value_metadata(kMetadataKeys1, kMetadataValues1); auto md2 = key_value_metadata(kMetadataKeys2, kMetadataValues2); @@ -2574,8 +2585,13 @@ TEST_F(TestArrayRoundtrip, Primitive) { TestWithJSON(int32(), "[]"); TestWithJSON(int32(), "[4, 5, null]"); + TestWithJSON(decimal128(16, 4), R"(["0.4759", "1234.5670", null])"); + TestWithJSON(decimal256(16, 4), R"(["0.4759", "1234.5670", null])"); + TestWithJSONSliced(int32(), "[4, 5]"); TestWithJSONSliced(int32(), "[4, 5, 6, null]"); + TestWithJSONSliced(decimal128(16, 4), R"(["0.4759", "1234.5670", null])"); + TestWithJSONSliced(decimal256(16, 4), R"(["0.4759", "1234.5670", null])"); } TEST_F(TestArrayRoundtrip, UnknownNullCount) { diff --git a/cpp/src/arrow/compare.cc b/cpp/src/arrow/compare.cc index 421ec139242..622f5cb5c5f 100644 --- a/cpp/src/arrow/compare.cc +++ b/cpp/src/arrow/compare.cc @@ -353,6 +353,10 @@ class RangeEqualsVisitor { return Visit(checked_cast(left)); } + Status Visit(const Decimal256Array& left) { + return Visit(checked_cast(left)); + } + Status Visit(const NullArray& left) { ARROW_UNUSED(left); result_ = true; @@ -806,6 +810,12 @@ class TypeEqualsVisitor { return Status::OK(); } + Status Visit(const Decimal256Type& left) { + const auto& right = checked_cast(right_); + result_ = left.precision() == right.precision() && left.scale() == right.scale(); + return Status::OK(); + } + template enable_if_t::value || is_struct_type::value, Status> Visit( const T& left) { @@ -919,6 +929,12 @@ class ScalarEqualsVisitor { return Status::OK(); } + Status Visit(const Decimal256Scalar& left) { + const auto& right = checked_cast(right_); + result_ = left.value == right.value; + return Status::OK(); + } + Status Visit(const ListScalar& left) { const auto& right = checked_cast(right_); result_ = internal::SharedPtrEquals(left.value, right.value); diff --git a/cpp/src/arrow/compute/kernel_test.cc b/cpp/src/arrow/compute/kernel_test.cc index df18fceaa20..a5ef9d44e18 100644 --- a/cpp/src/arrow/compute/kernel_test.cc +++ b/cpp/src/arrow/compute/kernel_test.cc @@ -38,7 +38,7 @@ TEST(TypeMatcher, SameTypeId) { ASSERT_TRUE(matcher->Matches(*decimal(12, 2))); ASSERT_FALSE(matcher->Matches(*int8())); - ASSERT_EQ("Type::DECIMAL", matcher->ToString()); + ASSERT_EQ("Type::DECIMAL128", matcher->ToString()); ASSERT_TRUE(matcher->Equals(*matcher)); ASSERT_TRUE(matcher->Equals(*match::SameTypeId(Type::DECIMAL))); @@ -103,7 +103,7 @@ TEST(InputType, Constructors) { // Same type id constructor InputType ty2(Type::DECIMAL); ASSERT_EQ(InputType::USE_TYPE_MATCHER, ty2.kind()); - ASSERT_EQ("any[Type::DECIMAL]", ty2.ToString()); + ASSERT_EQ("any[Type::DECIMAL128]", ty2.ToString()); ASSERT_TRUE(ty2.type_matcher().Matches(*decimal(12, 2))); ASSERT_FALSE(ty2.type_matcher().Matches(*int16())); @@ -135,9 +135,9 @@ TEST(InputType, Constructors) { ASSERT_EQ("array[int8]", ty1_array.ToString()); ASSERT_EQ("scalar[int8]", ty1_scalar.ToString()); - ASSERT_EQ("any[Type::DECIMAL]", ty2.ToString()); - ASSERT_EQ("array[Type::DECIMAL]", ty2_array.ToString()); - ASSERT_EQ("scalar[Type::DECIMAL]", ty2_scalar.ToString()); + ASSERT_EQ("any[Type::DECIMAL128]", ty2.ToString()); + ASSERT_EQ("array[Type::DECIMAL128]", ty2_array.ToString()); + ASSERT_EQ("scalar[Type::DECIMAL128]", ty2_scalar.ToString()); InputType ty7(match::TimestampTypeUnit(TimeUnit::MICRO)); ASSERT_EQ("any[timestamp(us)]", ty7.ToString()); @@ -484,14 +484,14 @@ TEST(KernelSignature, ToString) { InputType(Type::DECIMAL, ValueDescr::ARRAY), InputType(utf8())}; KernelSignature sig(in_types, utf8()); - ASSERT_EQ("(scalar[int8], array[Type::DECIMAL], any[string]) -> string", + ASSERT_EQ("(scalar[int8], array[Type::DECIMAL128], any[string]) -> string", sig.ToString()); OutputType out_type([](KernelContext*, const std::vector& args) { return Status::Invalid("NYI"); }); KernelSignature sig2({int8(), InputType(Type::DECIMAL)}, out_type); - ASSERT_EQ("(any[int8], any[Type::DECIMAL]) -> computed", sig2.ToString()); + ASSERT_EQ("(any[int8], any[Type::DECIMAL128]) -> computed", sig2.ToString()); } TEST(KernelSignature, VarArgsToString) { diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc index bea9a0ef8dc..bbd5ce07412 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc @@ -65,7 +65,7 @@ struct TestCType> { }; template -struct TestCType> { +struct TestCType> { using type = Decimal128; }; diff --git a/cpp/src/arrow/compute/kernels/vector_hash.cc b/cpp/src/arrow/compute/kernels/vector_hash.cc index 64a1849153f..0009fe53346 100644 --- a/cpp/src/arrow/compute/kernels/vector_hash.cc +++ b/cpp/src/arrow/compute/kernels/vector_hash.cc @@ -498,7 +498,8 @@ KernelInit GetHashInit(Type::type type_id) { case Type::LARGE_STRING: return HashInit; case Type::FIXED_SIZE_BINARY: - case Type::DECIMAL: + case Type::DECIMAL128: + case Type::DECIMAL256: return HashInit; default: DCHECK(false); diff --git a/cpp/src/arrow/dataset/filter.cc b/cpp/src/arrow/dataset/filter.cc index 114708532c3..44cc370c8e3 100644 --- a/cpp/src/arrow/dataset/filter.cc +++ b/cpp/src/arrow/dataset/filter.cc @@ -136,6 +136,7 @@ struct CompareVisitor { } Status Visit(const Decimal128Type&) { return CompareValues(); } + Status Visit(const Decimal256Type&) { return CompareValues(); } // Explicit because it falls under `physical_unsigned_integer`. // TODO(bkietz) whenever we vendor a float16, this can be implemented diff --git a/cpp/src/arrow/ipc/json_simple.cc b/cpp/src/arrow/ipc/json_simple.cc index 39d19d6ec28..61c93cda7cc 100644 --- a/cpp/src/arrow/ipc/json_simple.cc +++ b/cpp/src/arrow/ipc/json_simple.cc @@ -289,12 +289,14 @@ class FloatConverter final : public ConcreteConverter::BuilderType> -class DecimalConverter final : public ConcreteConverter> { +template +class DecimalConverter final + : public ConcreteConverter< + DecimalConverter> { public: explicit DecimalConverter(const std::shared_ptr& type) { this->type_ = type; - decimal_type_ = &checked_cast(*this->value_type()); + decimal_type_ = &checked_cast(*this->value_type()); } Status Init() override { return this->MakeConcreteBuilder(&builder_); } @@ -305,9 +307,9 @@ class DecimalConverter final : public ConcreteConverterscale()) { return Status::Invalid("Invalid scale for decimal: expected ", decimal_type_->scale(), ", got ", scale); @@ -321,9 +323,14 @@ class DecimalConverter final : public ConcreteConverter builder_; - const Decimal128Type* decimal_type_; + const DecimalSubtype* decimal_type_; }; +template ::BuilderType> +using Decimal128Converter = DecimalConverter; +template ::BuilderType> +using Decimal256Converter = DecimalConverter; + // ------------------------------------------------------------------------ // Converter for timestamp arrays @@ -773,7 +780,8 @@ Status GetDictConverter(const std::shared_ptr& type, PARAM_CONVERTER_CASE(Type::LARGE_BINARY, StringConverter, LargeBinaryType) SIMPLE_CONVERTER_CASE(Type::FIXED_SIZE_BINARY, FixedSizeBinaryConverter, FixedSizeBinaryType) - SIMPLE_CONVERTER_CASE(Type::DECIMAL, DecimalConverter, Decimal128Type) + SIMPLE_CONVERTER_CASE(Type::DECIMAL128, Decimal128Converter, Decimal128Type) + SIMPLE_CONVERTER_CASE(Type::DECIMAL256, Decimal256Converter, Decimal256Type) default: return ConversionNotImplemented(type); } @@ -829,7 +837,8 @@ Status GetConverter(const std::shared_ptr& type, SIMPLE_CONVERTER_CASE(Type::LARGE_STRING, StringConverter) SIMPLE_CONVERTER_CASE(Type::LARGE_BINARY, StringConverter) SIMPLE_CONVERTER_CASE(Type::FIXED_SIZE_BINARY, FixedSizeBinaryConverter<>) - SIMPLE_CONVERTER_CASE(Type::DECIMAL, DecimalConverter<>) + SIMPLE_CONVERTER_CASE(Type::DECIMAL128, Decimal128Converter<>) + SIMPLE_CONVERTER_CASE(Type::DECIMAL256, Decimal256Converter<>) SIMPLE_CONVERTER_CASE(Type::SPARSE_UNION, UnionConverter) SIMPLE_CONVERTER_CASE(Type::DENSE_UNION, UnionConverter) SIMPLE_CONVERTER_CASE(Type::INTERVAL_MONTHS, IntegerConverter) diff --git a/cpp/src/arrow/ipc/json_simple_test.cc b/cpp/src/arrow/ipc/json_simple_test.cc index f6a6a92c5f7..98bf6e46211 100644 --- a/cpp/src/arrow/ipc/json_simple_test.cc +++ b/cpp/src/arrow/ipc/json_simple_test.cc @@ -494,14 +494,14 @@ TEST(TestFixedSizeBinary, Dictionary) { ASSERT_RAISES(Invalid, ArrayFromJSON(dictionary(int8(), type), R"(["x"])", &array)); } -TEST(TestDecimal, Basics) { - std::shared_ptr type = decimal(10, 4); +template +void TestDecimalBasic(std::shared_ptr type) { std::shared_ptr expected, actual; ASSERT_OK(ArrayFromJSON(type, "[]", &actual)); ASSERT_OK(actual->ValidateFull()); { - Decimal128Builder builder(type); + DecimalBuilder builder(type); ASSERT_OK(builder.Finish(&expected)); } AssertArraysEqual(*expected, *actual); @@ -509,9 +509,9 @@ TEST(TestDecimal, Basics) { ASSERT_OK(ArrayFromJSON(type, "[\"123.4567\", \"-78.9000\"]", &actual)); ASSERT_OK(actual->ValidateFull()); { - Decimal128Builder builder(type); - ASSERT_OK(builder.Append(Decimal128(1234567))); - ASSERT_OK(builder.Append(Decimal128(-789000))); + DecimalBuilder builder(type); + ASSERT_OK(builder.Append(DecimalValue(1234567))); + ASSERT_OK(builder.Append(DecimalValue(-789000))); ASSERT_OK(builder.Finish(&expected)); } AssertArraysEqual(*expected, *actual); @@ -519,31 +519,41 @@ TEST(TestDecimal, Basics) { ASSERT_OK(ArrayFromJSON(type, "[\"123.4567\", null]", &actual)); ASSERT_OK(actual->ValidateFull()); { - Decimal128Builder builder(type); - ASSERT_OK(builder.Append(Decimal128(1234567))); + DecimalBuilder builder(type); + ASSERT_OK(builder.Append(DecimalValue(1234567))); ASSERT_OK(builder.AppendNull()); ASSERT_OK(builder.Finish(&expected)); } AssertArraysEqual(*expected, *actual); } -TEST(TestDecimal, Errors) { - std::shared_ptr type = decimal(10, 4); - std::shared_ptr array; +TEST(TestDecimal128, Basics) { + TestDecimalBasic(decimal128(10, 4)); +} - ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[0]", &array)); - ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[12.3456]", &array)); - // Bad scale - ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[\"12.345\"]", &array)); - ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[\"12.34560\"]", &array)); +TEST(TestDecimal256, Basics) { + TestDecimalBasic(decimal256(10, 4)); } -TEST(TestDecimal, Dictionary) { - std::shared_ptr type = decimal(10, 2); +TEST(TestDecimal, Errors) { + for (std::shared_ptr type : {decimal128(10, 4), decimal256(10, 4)}) { + std::shared_ptr array; + + ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[0]", &array)); + ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[12.3456]", &array)); + // Bad scale + ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[\"12.345\"]", &array)); + ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[\"12.34560\"]", &array)); + } +} - AssertJSONDictArray(int32(), type, R"(["123.45", "-78.90", "-78.90", null, "123.45"])", - /*indices=*/"[0, 1, 1, null, 0]", - /*values=*/R"(["123.45", "-78.90"])"); +TEST(TestDecimal, Dictionary) { + for (std::shared_ptr type : {decimal128(10, 2), decimal256(10, 2)}) { + AssertJSONDictArray(int32(), type, + R"(["123.45", "-78.90", "-78.90", null, "123.45"])", + /*indices=*/"[0, 1, 1, null, 0]", + /*values=*/R"(["123.45", "-78.90"])"); + } } TEST(TestList, IntegerList) { diff --git a/cpp/src/arrow/ipc/metadata_internal.cc b/cpp/src/arrow/ipc/metadata_internal.cc index a82aef328d6..10cee49963e 100644 --- a/cpp/src/arrow/ipc/metadata_internal.cc +++ b/cpp/src/arrow/ipc/metadata_internal.cc @@ -236,8 +236,6 @@ static inline TimeUnit::type FromFlatbufferUnit(flatbuf::TimeUnit unit) { return TimeUnit::SECOND; } -constexpr int32_t kDecimalBitWidth = 128; - Status ConcreteTypeFromFlatbuffer(flatbuf::Type type, const void* type_data, const std::vector>& children, std::shared_ptr* out) { @@ -273,10 +271,13 @@ Status ConcreteTypeFromFlatbuffer(flatbuf::Type type, const void* type_data, return Status::OK(); case flatbuf::Type::Decimal: { auto dec_type = static_cast(type_data); - if (dec_type->bitWidth() != kDecimalBitWidth) { - return Status::Invalid("Library only supports 128-bit decimal values"); + if (dec_type->bitWidth() == 128) { + return Decimal128Type::Make(dec_type->precision(), dec_type->scale()).Value(out); + } else if (dec_type->bitWidth() == 256) { + return Decimal256Type::Make(dec_type->precision(), dec_type->scale()).Value(out); + } else { + return Status::Invalid("Library only supports 128-bit or 256-bit decimal values"); } - return Decimal128Type::Make(dec_type->precision(), dec_type->scale()).Value(out); } case flatbuf::Type::Date: { auto date_type = static_cast(type_data); @@ -594,11 +595,21 @@ class FieldToFlatbufferVisitor { return Status::OK(); } - Status Visit(const DecimalType& type) { + Status Visit(const Decimal128Type& type) { const auto& dec_type = checked_cast(type); fb_type_ = flatbuf::Type::Decimal; - type_offset_ = - flatbuf::CreateDecimal(fbb_, dec_type.precision(), dec_type.scale()).Union(); + type_offset_ = flatbuf::CreateDecimal(fbb_, dec_type.precision(), dec_type.scale(), + /*bitWidth=*/128) + .Union(); + return Status::OK(); + } + + Status Visit(const Decimal256Type& type) { + const auto& dec_type = checked_cast(type); + fb_type_ = flatbuf::Type::Decimal; + type_offset_ = flatbuf::CreateDecimal(fbb_, dec_type.precision(), dec_type.scale(), + /*bitWith=*/256) + .Union(); return Status::OK(); } diff --git a/cpp/src/arrow/pretty_print.cc b/cpp/src/arrow/pretty_print.cc index 9223ce7fba6..8c2ac376d1e 100644 --- a/cpp/src/arrow/pretty_print.cc +++ b/cpp/src/arrow/pretty_print.cc @@ -227,6 +227,11 @@ class ArrayPrinter : public PrettyPrinter { return Status::OK(); } + Status WriteDataValues(const Decimal256Array& array) { + WriteValues(array, [&](int64_t i) { (*sink_) << array.FormatValue(i); }); + return Status::OK(); + } + template enable_if_list_like WriteDataValues(const T& array) { bool skip_comma = true; diff --git a/cpp/src/arrow/pretty_print_test.cc b/cpp/src/arrow/pretty_print_test.cc index 9e58e46fe94..feac583f495 100644 --- a/cpp/src/arrow/pretty_print_test.cc +++ b/cpp/src/arrow/pretty_print_test.cc @@ -499,15 +499,16 @@ TEST_F(TestPrettyPrint, FixedSizeBinaryType) { CheckArray(*array, {2, 1}, ex_2); } -TEST_F(TestPrettyPrint, Decimal128Type) { +TEST_F(TestPrettyPrint, DecimalTypes) { int32_t p = 19; int32_t s = 4; - auto type = decimal(p, s); - auto array = ArrayFromJSON(type, "[\"123.4567\", \"456.7891\", null]"); + for (auto type : {decimal128(p, s), decimal256(p, s)}) { + auto array = ArrayFromJSON(type, "[\"123.4567\", \"456.7891\", null]"); - static const char* ex = "[\n 123.4567,\n 456.7891,\n null\n]"; - CheckArray(*array, {0}, ex); + static const char* ex = "[\n 123.4567,\n 456.7891,\n null\n]"; + CheckArray(*array, {0}, ex); + } } TEST_F(TestPrettyPrint, DictionaryType) { diff --git a/cpp/src/arrow/python/arrow_to_pandas.cc b/cpp/src/arrow/python/arrow_to_pandas.cc index be27a108dd1..09245285030 100644 --- a/cpp/src/arrow/python/arrow_to_pandas.cc +++ b/cpp/src/arrow/python/arrow_to_pandas.cc @@ -167,7 +167,8 @@ static inline bool ListTypeSupported(const DataType& type) { case Type::UINT64: case Type::FLOAT: case Type::DOUBLE: - case Type::DECIMAL: + case Type::DECIMAL128: + case Type::DECIMAL256: case Type::BINARY: case Type::LARGE_BINARY: case Type::STRING: @@ -1118,6 +1119,31 @@ struct ObjectWriterVisitor { return Status::OK(); } + Status Visit(const Decimal256Type& type) { + OwnedRef decimal; + OwnedRef Decimal; + RETURN_NOT_OK(internal::ImportModule("decimal", &decimal)); + RETURN_NOT_OK(internal::ImportFromModule(decimal.obj(), "Decimal", &Decimal)); + PyObject* decimal_constructor = Decimal.obj(); + + for (int c = 0; c < data.num_chunks(); c++) { + const auto& arr = checked_cast(*data.chunk(c)); + + for (int64_t i = 0; i < arr.length(); ++i) { + if (arr.IsNull(i)) { + Py_INCREF(Py_None); + *out_values++ = Py_None; + } else { + *out_values++ = + internal::DecimalFromString(decimal_constructor, arr.FormatValue(i)); + RETURN_IF_PYERROR(); + } + } + } + + return Status::OK(); + } + template enable_if_t::value || is_var_length_list_type::value, Status> @@ -1845,7 +1871,8 @@ static Status GetPandasWriterType(const ChunkedArray& data, const PandasOptions& case Type::STRUCT: // fall through case Type::TIME32: // fall through case Type::TIME64: // fall through - case Type::DECIMAL: // fall through + case Type::DECIMAL128: // fall through + case Type::DECIMAL256: // fall through *output_type = PandasWriter::OBJECT; break; case Type::DATE32: // fall through diff --git a/cpp/src/arrow/python/decimal.cc b/cpp/src/arrow/python/decimal.cc index 18712015df8..a624f5a073a 100644 --- a/cpp/src/arrow/python/decimal.cc +++ b/cpp/src/arrow/python/decimal.cc @@ -109,13 +109,14 @@ PyObject* DecimalFromString(PyObject* decimal_constructor, namespace { +template Status DecimalFromStdString(const std::string& decimal_string, - const DecimalType& arrow_type, Decimal128* out) { + const DecimalType& arrow_type, ArrowDecimal* out) { int32_t inferred_precision; int32_t inferred_scale; - RETURN_NOT_OK( - Decimal128::FromString(decimal_string, out, &inferred_precision, &inferred_scale)); + RETURN_NOT_OK(ArrowDecimal::FromString(decimal_string, out, &inferred_precision, + &inferred_scale)); const int32_t precision = arrow_type.precision(); const int32_t scale = arrow_type.scale(); @@ -133,10 +134,10 @@ Status DecimalFromStdString(const std::string& decimal_string, return Status::OK(); } -} // namespace - -Status DecimalFromPythonDecimal(PyObject* python_decimal, const DecimalType& arrow_type, - Decimal128* out) { +template +Status InternalDecimalFromPythonDecimal(PyObject* python_decimal, + const DecimalType& arrow_type, + ArrowDecimal* out) { DCHECK_NE(python_decimal, NULLPTR); DCHECK_NE(out, NULLPTR); @@ -145,8 +146,9 @@ Status DecimalFromPythonDecimal(PyObject* python_decimal, const DecimalType& arr return DecimalFromStdString(string, arrow_type, out); } -Status DecimalFromPyObject(PyObject* obj, const DecimalType& arrow_type, - Decimal128* out) { +template +Status InternalDecimalFromPyObject(PyObject* obj, const DecimalType& arrow_type, + ArrowDecimal* out) { DCHECK_NE(obj, NULLPTR); DCHECK_NE(out, NULLPTR); @@ -156,13 +158,35 @@ Status DecimalFromPyObject(PyObject* obj, const DecimalType& arrow_type, RETURN_NOT_OK(PyObject_StdStringStr(obj, &string)); return DecimalFromStdString(string, arrow_type, out); } else if (PyDecimal_Check(obj)) { - return DecimalFromPythonDecimal(obj, arrow_type, out); + return InternalDecimalFromPythonDecimal(obj, arrow_type, out); } else { return Status::TypeError("int or Decimal object expected, got ", Py_TYPE(obj)->tp_name); } } +} // namespace + +Status DecimalFromPythonDecimal(PyObject* python_decimal, const DecimalType& arrow_type, + Decimal128* out) { + return InternalDecimalFromPythonDecimal(python_decimal, arrow_type, out); +} + +Status DecimalFromPyObject(PyObject* obj, const DecimalType& arrow_type, + Decimal128* out) { + return InternalDecimalFromPyObject(obj, arrow_type, out); +} + +Status DecimalFromPythonDecimal(PyObject* python_decimal, const DecimalType& arrow_type, + Decimal256* out) { + return InternalDecimalFromPythonDecimal(python_decimal, arrow_type, out); +} + +Status DecimalFromPyObject(PyObject* obj, const DecimalType& arrow_type, + Decimal256* out) { + return InternalDecimalFromPyObject(obj, arrow_type, out); +} + bool PyDecimal_Check(PyObject* obj) { static OwnedRef decimal_type; if (!decimal_type.obj()) { diff --git a/cpp/src/arrow/python/decimal.h b/cpp/src/arrow/python/decimal.h index 3d20b014010..1187037aed2 100644 --- a/cpp/src/arrow/python/decimal.h +++ b/cpp/src/arrow/python/decimal.h @@ -25,6 +25,7 @@ namespace arrow { class Decimal128; +class Decimal256; namespace py { @@ -72,6 +73,23 @@ Status DecimalFromPythonDecimal(PyObject* python_decimal, const DecimalType& arr ARROW_PYTHON_EXPORT Status DecimalFromPyObject(PyObject* obj, const DecimalType& arrow_type, Decimal128* out); +// \brief Convert a Python decimal to an Arrow Decimal256 object +// \param[in] python_decimal A Python decimal.Decimal instance +// \param[in] arrow_type An instance of arrow::DecimalType +// \param[out] out A pointer to a Decimal256 +// \return The status of the operation +ARROW_PYTHON_EXPORT +Status DecimalFromPythonDecimal(PyObject* python_decimal, const DecimalType& arrow_type, + Decimal256* out); + +// \brief Convert a Python object to an Arrow Decimal256 object +// \param[in] python_decimal A Python int or decimal.Decimal instance +// \param[in] arrow_type An instance of arrow::DecimalType +// \param[out] out A pointer to a Decimal256 +// \return The status of the operation +ARROW_PYTHON_EXPORT +Status DecimalFromPyObject(PyObject* obj, const DecimalType& arrow_type, Decimal256* out); + // \brief Check whether obj is an instance of Decimal ARROW_PYTHON_EXPORT bool PyDecimal_Check(PyObject* obj); diff --git a/cpp/src/arrow/python/inference.cc b/cpp/src/arrow/python/inference.cc index a75a887693c..9d6707aa11d 100644 --- a/cpp/src/arrow/python/inference.cc +++ b/cpp/src/arrow/python/inference.cc @@ -450,9 +450,16 @@ class TypeInferrer { } else if (struct_count_) { RETURN_NOT_OK(GetStructType(out)); } else if (decimal_count_) { - // the default constructor does not validate the precision and scale - ARROW_ASSIGN_OR_RAISE(*out, Decimal128Type::Make(max_decimal_metadata_.precision(), - max_decimal_metadata_.scale())); + if (max_decimal_metadata_.precision() > Decimal128Type::kMaxPrecision) { + // the default constructor does not validate the precision and scale + ARROW_ASSIGN_OR_RAISE(*out, + Decimal256Type::Make(max_decimal_metadata_.precision(), + max_decimal_metadata_.scale())); + } else { + ARROW_ASSIGN_OR_RAISE(*out, + Decimal128Type::Make(max_decimal_metadata_.precision(), + max_decimal_metadata_.scale())); + } } else if (float_count_) { // Prioritize floats before integers *out = float64(); diff --git a/cpp/src/arrow/python/python_test.cc b/cpp/src/arrow/python/python_test.cc index 80bda384bde..037a85875a5 100644 --- a/cpp/src/arrow/python/python_test.cc +++ b/cpp/src/arrow/python/python_test.cc @@ -28,6 +28,7 @@ #include "arrow/table.h" #include "arrow/testing/gtest_util.h" #include "arrow/util/decimal.h" +#include "arrow/util/optional.h" #include "arrow/python/arrow_to_pandas.h" #include "arrow/python/decimal.h" @@ -332,54 +333,62 @@ TEST(BuiltinConversionTest, TestMixedTypeFails) { ASSERT_RAISES(TypeError, ConvertPySequence(list, nullptr, {})); } +template +void DecimalTestFromPythonDecimalRescale(std::shared_ptr type, + OwnedRef python_decimal, + ::arrow::util::optional expected) { + DecimalValue value; + const auto& decimal_type = checked_cast(*type); + + if (expected.has_value()) { + ASSERT_OK( + internal::DecimalFromPythonDecimal(python_decimal.obj(), decimal_type, &value)); + ASSERT_EQ(expected.value(), value); + + ASSERT_OK(internal::DecimalFromPyObject(python_decimal.obj(), decimal_type, &value)); + ASSERT_EQ(expected.value(), value); + } else { + ASSERT_RAISES(Invalid, internal::DecimalFromPythonDecimal(python_decimal.obj(), + decimal_type, &value)); + ASSERT_RAISES(Invalid, internal::DecimalFromPyObject(python_decimal.obj(), + decimal_type, &value)); + } +} + TEST_F(DecimalTest, FromPythonDecimalRescaleNotTruncateable) { // We fail when truncating values that would lose data if cast to a decimal type with // lower scale - Decimal128 value; - OwnedRef python_decimal(this->CreatePythonDecimal("1.001")); - auto type = ::arrow::decimal(10, 2); - const auto& decimal_type = checked_cast(*type); - ASSERT_RAISES(Invalid, internal::DecimalFromPythonDecimal(python_decimal.obj(), - decimal_type, &value)); + DecimalTestFromPythonDecimalRescale(::arrow::decimal128(10, 2), + this->CreatePythonDecimal("1.001"), {}); + // TODO: Test Decimal256 after implementing scaling. } TEST_F(DecimalTest, FromPythonDecimalRescaleTruncateable) { // We allow truncation of values that do not lose precision when dividing by 10 * the // difference between the scales, e.g., 1.000 -> 1.00 - Decimal128 value; - OwnedRef python_decimal(this->CreatePythonDecimal("1.000")); - auto type = ::arrow::decimal(10, 2); - const auto& decimal_type = checked_cast(*type); - ASSERT_OK( - internal::DecimalFromPythonDecimal(python_decimal.obj(), decimal_type, &value)); - ASSERT_EQ(100, value.low_bits()); - ASSERT_EQ(0, value.high_bits()); - - ASSERT_OK(internal::DecimalFromPyObject(python_decimal.obj(), decimal_type, &value)); - ASSERT_EQ(100, value.low_bits()); - ASSERT_EQ(0, value.high_bits()); + DecimalTestFromPythonDecimalRescale( + ::arrow::decimal128(10, 2), this->CreatePythonDecimal("1.000"), 100); + // TODO: Test Decimal256 after implementing scaling. } TEST_F(DecimalTest, FromPythonNegativeDecimalRescale) { - Decimal128 value; - OwnedRef python_decimal(this->CreatePythonDecimal("-1.000")); - auto type = ::arrow::decimal(10, 9); - const auto& decimal_type = checked_cast(*type); - ASSERT_OK( - internal::DecimalFromPythonDecimal(python_decimal.obj(), decimal_type, &value)); - ASSERT_EQ(-1000000000, value); + DecimalTestFromPythonDecimalRescale( + ::arrow::decimal128(10, 9), this->CreatePythonDecimal("-1.000"), -1000000000); + // TODO: Test Decimal256 after implementing scaling. } -TEST_F(DecimalTest, FromPythonInteger) { +TEST_F(DecimalTest, Decimal128FromPythonInteger) { Decimal128 value; OwnedRef python_long(PyLong_FromLong(42)); - auto type = ::arrow::decimal(10, 2); + auto type = ::arrow::decimal128(10, 2); const auto& decimal_type = checked_cast(*type); ASSERT_OK(internal::DecimalFromPyObject(python_long.obj(), decimal_type, &value)); ASSERT_EQ(4200, value); } -TEST_F(DecimalTest, TestOverflowFails) { +// TODO: Test Decimal256 from python after implementing scaling. + +TEST_F(DecimalTest, TestDecimal128OverflowFails) { Decimal128 value; OwnedRef python_decimal( this->CreatePythonDecimal("9999999999999999999999999999999999999.9")); @@ -394,6 +403,8 @@ TEST_F(DecimalTest, TestOverflowFails) { decimal_type, &value)); } +// TODO: Test Decimal256 overflow after implementing scaling. + TEST_F(DecimalTest, TestNoneAndNaN) { OwnedRef list_ref(PyList_New(4)); PyObject* list = list_ref.obj(); diff --git a/cpp/src/arrow/python/python_to_arrow.cc b/cpp/src/arrow/python/python_to_arrow.cc index d3b9d6b225b..37006219e7e 100644 --- a/cpp/src/arrow/python/python_to_arrow.cc +++ b/cpp/src/arrow/python/python_to_arrow.cc @@ -167,6 +167,12 @@ class PyValue { return value; } + static Result Convert(const Decimal256Type* type, const O&, I obj) { + Decimal256 value; + RETURN_NOT_OK(internal::DecimalFromPyObject(obj, *type, &value)); + return value; + } + static Result Convert(const Date32Type*, const O&, I obj) { int32_t value; if (PyDate_Check(obj)) { diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc index 14472fc8d12..9e038024e06 100644 --- a/cpp/src/arrow/scalar.cc +++ b/cpp/src/arrow/scalar.cc @@ -69,6 +69,14 @@ struct ScalarHashImpl { return StdHash(s.value.low_bits()) & StdHash(s.value.high_bits()); } + Status Visit(const Decimal256Scalar& s) { + Status status = Status::OK(); + for (uint64_t elem : s.value.little_endian_array()) { + status &= StdHash(elem); + } + return status; + } + Status Visit(const BaseListScalar& s) { return ArrayHash(*s.value); } Status Visit(const StructScalar& s) { diff --git a/cpp/src/arrow/scalar.h b/cpp/src/arrow/scalar.h index da7114c05a4..80157a750cb 100644 --- a/cpp/src/arrow/scalar.h +++ b/cpp/src/arrow/scalar.h @@ -347,6 +347,17 @@ struct ARROW_EXPORT Decimal128Scalar : public Scalar { Decimal128 value; }; +struct ARROW_EXPORT Decimal256Scalar : public Scalar { + using Scalar::Scalar; + using TypeClass = Decimal256Type; + using ValueType = Decimal256; + + Decimal256Scalar(Decimal256 value, std::shared_ptr type) + : Scalar(std::move(type), true), value(value) {} + + Decimal256 value; +}; + struct ARROW_EXPORT BaseListScalar : public Scalar { using Scalar::Scalar; using ValueType = std::shared_ptr; diff --git a/cpp/src/arrow/scalar_test.cc b/cpp/src/arrow/scalar_test.cc index dc8708f689e..71f1ae04ce2 100644 --- a/cpp/src/arrow/scalar_test.cc +++ b/cpp/src/arrow/scalar_test.cc @@ -127,8 +127,8 @@ TYPED_TEST(TestNumericScalar, MakeScalar) { ASSERT_EQ(ScalarType(3), *three); } -TEST(TestDecimalScalar, Basics) { - auto ty = decimal(3, 2); +TEST(TestDecimal128Scalar, Basics) { + auto ty = decimal128(3, 2); auto pi = Decimal128Scalar(Decimal128("3.14"), ty); auto null = MakeNullScalar(ty); @@ -144,6 +144,23 @@ TEST(TestDecimalScalar, Basics) { ASSERT_FALSE(second->Equals(null)); } +TEST(TestDecimal256Scalar, Basics) { + auto ty = decimal256(3, 2); + auto pi = Decimal256Scalar(Decimal256("3.14"), ty); + auto null = MakeNullScalar(ty); + + ASSERT_EQ(pi.value, Decimal256("3.14")); + + // test Array.GetScalar + auto arr = ArrayFromJSON(ty, "[null, \"3.14\"]"); + ASSERT_OK_AND_ASSIGN(auto first, arr->GetScalar(0)); + ASSERT_OK_AND_ASSIGN(auto second, arr->GetScalar(1)); + ASSERT_TRUE(first->Equals(null)); + ASSERT_FALSE(first->Equals(pi)); + ASSERT_TRUE(second->Equals(pi)); + ASSERT_FALSE(second->Equals(null)); +} + TEST(TestBinaryScalar, Basics) { std::string data = "test data"; auto buf = std::make_shared(data); diff --git a/cpp/src/arrow/testing/gtest_util.cc b/cpp/src/arrow/testing/gtest_util.cc index 7f0838392e8..41ab7a4e8b5 100644 --- a/cpp/src/arrow/testing/gtest_util.cc +++ b/cpp/src/arrow/testing/gtest_util.cc @@ -69,7 +69,8 @@ std::vector AllTypeIds() { Type::HALF_FLOAT, Type::FLOAT, Type::DOUBLE, - Type::DECIMAL, + Type::DECIMAL128, + Type::DECIMAL256, Type::DATE32, Type::DATE64, Type::TIME32, diff --git a/cpp/src/arrow/testing/json_internal.cc b/cpp/src/arrow/testing/json_internal.cc index 9bd2f14ed3b..fae0e35b676 100644 --- a/cpp/src/arrow/testing/json_internal.cc +++ b/cpp/src/arrow/testing/json_internal.cc @@ -303,6 +303,13 @@ class SchemaWriter { writer_->Int(type.scale()); } + void WriteTypeMetadata(const Decimal256Type& type) { + writer_->Key("precision"); + writer_->Int(type.precision()); + writer_->Key("scale"); + writer_->Int(type.scale()); + } + void WriteTypeMetadata(const UnionType& type) { writer_->Key("mode"); switch (type.mode()) { @@ -376,6 +383,7 @@ class SchemaWriter { } Status Visit(const Decimal128Type& type) { return WritePrimitive("decimal", type); } + Status Visit(const Decimal256Type& type) { return WritePrimitive("decimal256", type); } Status Visit(const TimestampType& type) { return WritePrimitive("timestamp", type); } Status Visit(const DurationType& type) { return WritePrimitive(kDuration, type); } Status Visit(const MonthIntervalType& type) { return WritePrimitive("interval", type); } @@ -546,6 +554,18 @@ class ArrayWriter { } } + void WriteDataValues(const Decimal256Array& arr) { + static const char null_string[] = "0"; + for (int64_t i = 0; i < arr.length(); ++i) { + if (arr.IsValid(i)) { + const Decimal256 value(arr.GetValue(i)); + writer_->String(value.ToIntegerString()); + } else { + writer_->String(null_string, sizeof(null_string)); + } + } + } + void WriteDataValues(const BooleanArray& arr) { for (int64_t i = 0; i < arr.length(); ++i) { if (arr.IsValid(i)) { @@ -819,8 +839,20 @@ Status GetDecimal(const RjObject& json_type, std::shared_ptr* type) { ARROW_ASSIGN_OR_RAISE(const int32_t precision, GetMemberInt(json_type, "precision")); ARROW_ASSIGN_OR_RAISE(const int32_t scale, GetMemberInt(json_type, "scale")); + int32_t bit_width = 128; + Result maybe_bit_width = GetMemberInt(json_type, "bitWidth"); + if (maybe_bit_width.ok()) { + bit_width = maybe_bit_width.ValueOrDie(); + } - *type = decimal(precision, scale); + if (bit_width == 128) { + *type = decimal128(precision, scale); + } else if (bit_width == 256) { + *type = decimal256(precision, scale); + } else { + return Status::Invalid("Only 128 bit and 256 Decimals are supported. Received", + bit_width); + } return Status::OK(); } @@ -1296,8 +1328,9 @@ class ArrayReader { DCHECK_GT(val.GetStringLength(), 0) << "Empty string found when parsing Decimal128 value"; - Decimal128 value; - ARROW_ASSIGN_OR_RAISE(value, Decimal128::FromString(val.GetString())); + using Value = typename TypeTraits::ScalarType::ValueType; + Value value; + ARROW_ASSIGN_OR_RAISE(value, Value::FromString(val.GetString())); RETURN_NOT_OK(builder.Append(value)); } } diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index 5482deaccac..cbf18a08734 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -70,6 +70,8 @@ constexpr Type::type StructType::type_id; constexpr Type::type Decimal128Type::type_id; +constexpr Type::type Decimal256Type::type_id; + constexpr Type::type SparseUnionType::type_id; constexpr Type::type DenseUnionType::type_id; @@ -130,7 +132,8 @@ std::string ToString(Type::type id) { TO_STRING_CASE(HALF_FLOAT) TO_STRING_CASE(FLOAT) TO_STRING_CASE(DOUBLE) - TO_STRING_CASE(DECIMAL) + TO_STRING_CASE(DECIMAL128) + TO_STRING_CASE(DECIMAL256) TO_STRING_CASE(DATE32) TO_STRING_CASE(DATE64) TO_STRING_CASE(TIME32) @@ -748,7 +751,7 @@ std::vector> StructType::GetAllFieldsByName( // Decimal128 type Decimal128Type::Decimal128Type(int32_t precision, int32_t scale) - : DecimalType(16, precision, scale) { + : DecimalType(type_id, 16, precision, scale) { ARROW_CHECK_GE(precision, kMinPrecision); ARROW_CHECK_LE(precision, kMaxPrecision); } @@ -760,6 +763,22 @@ Result> Decimal128Type::Make(int32_t precision, int32_ return std::make_shared(precision, scale); } +// ---------------------------------------------------------------------- +// Decimal256 type + +Decimal256Type::Decimal256Type(int32_t precision, int32_t scale) + : DecimalType(type_id, 32, precision, scale) { + ARROW_CHECK_GE(precision, kMinPrecision); + ARROW_CHECK_LE(precision, kMaxPrecision); +} + +Result> Decimal256Type::Make(int32_t precision, int32_t scale) { + if (precision < kMinPrecision || precision > kMaxPrecision) { + return Status::Invalid("Decimal precision out of range: ", precision); + } + return std::make_shared(precision, scale); +} + // ---------------------------------------------------------------------- // Dictionary-encoded type @@ -2138,13 +2157,28 @@ std::shared_ptr field(std::string name, std::shared_ptr type, } std::shared_ptr decimal(int32_t precision, int32_t scale) { + return precision <= Decimal128Type::kMaxPrecision ? decimal128(precision, scale) + : decimal256(precision, scale); +} + +std::shared_ptr decimal128(int32_t precision, int32_t scale) { return std::make_shared(precision, scale); } +std::shared_ptr decimal256(int32_t precision, int32_t scale) { + return std::make_shared(precision, scale); +} + std::string Decimal128Type::ToString() const { std::stringstream s; s << "decimal(" << precision_ << ", " << scale_ << ")"; return s.str(); } +std::string Decimal256Type::ToString() const { + std::stringstream s; + s << "decimal256(" << precision_ << ", " << scale_ << ")"; + return s.str(); +} + } // namespace arrow diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index e67cf284760..c8a71ab9c13 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -861,10 +861,9 @@ class ARROW_EXPORT StructType : public NestedType { /// \brief Base type class for (fixed-size) decimal data class ARROW_EXPORT DecimalType : public FixedSizeBinaryType { public: - explicit DecimalType(int32_t byte_width, int32_t precision, int32_t scale) - : FixedSizeBinaryType(byte_width, Type::DECIMAL), - precision_(precision), - scale_(scale) {} + explicit DecimalType(Type::type type_id, int32_t byte_width, int32_t precision, + int32_t scale) + : FixedSizeBinaryType(byte_width, type_id), precision_(precision), scale_(scale) {} int32_t precision() const { return precision_; } int32_t scale() const { return scale_; } @@ -879,7 +878,7 @@ class ARROW_EXPORT DecimalType : public FixedSizeBinaryType { /// \brief Concrete type class for 128-bit decimal data class ARROW_EXPORT Decimal128Type : public DecimalType { public: - static constexpr Type::type type_id = Type::DECIMAL; + static constexpr Type::type type_id = Type::DECIMAL128; static constexpr const char* type_name() { return "decimal"; } @@ -896,6 +895,26 @@ class ARROW_EXPORT Decimal128Type : public DecimalType { static constexpr int32_t kMaxPrecision = 38; }; +/// \brief Concrete type class for 256-bit decimal data +class ARROW_EXPORT Decimal256Type : public DecimalType { + public: + static constexpr Type::type type_id = Type::DECIMAL256; + + static constexpr const char* type_name() { return "decimal256"; } + + /// Decimal256Type constructor that aborts on invalid input. + explicit Decimal256Type(int32_t precision, int32_t scale); + + /// Decimal256Type constructor that returns an error on invalid input. + static Result> Make(int32_t precision, int32_t scale); + + std::string ToString() const override; + std::string name() const override { return "decimal256"; } + + static constexpr int32_t kMinPrecision = 1; + static constexpr int32_t kMaxPrecision = 76; +}; + /// \brief Concrete type class for union data class ARROW_EXPORT UnionType : public NestedType { public: diff --git a/cpp/src/arrow/type_fwd.h b/cpp/src/arrow/type_fwd.h index fc25b27238c..e62a8ca0082 100644 --- a/cpp/src/arrow/type_fwd.h +++ b/cpp/src/arrow/type_fwd.h @@ -143,11 +143,16 @@ class StructBuilder; struct StructScalar; class Decimal128; +class Decimal256; class DecimalType; class Decimal128Type; +class Decimal256Type; class Decimal128Array; +class Decimal256Array; class Decimal128Builder; +class Decimal256Builder; struct Decimal128Scalar; +struct Decimal256Scalar; struct UnionMode { enum type { SPARSE, DENSE }; @@ -326,9 +331,14 @@ struct Type { /// DAY_TIME interval in SQL style INTERVAL_DAY_TIME, - /// Precision- and scale-based decimal type. Storage type depends on the - /// parameters. - DECIMAL, + /// Precision- and scale-based decimal type with 128 bits. + DECIMAL128, + + /// Defined for backward-compatibility. + DECIMAL = DECIMAL128, + + /// Precision- and scale-based decimal type with 256 bits. + DECIMAL256, /// A list of some logical data type LIST, @@ -423,10 +433,18 @@ std::shared_ptr ARROW_EXPORT date64(); ARROW_EXPORT std::shared_ptr fixed_size_binary(int32_t byte_width); -/// \brief Create a Decimal128Type instance +/// \brief Create a Decimal128Type or Decimal256Type instance depending on the precision ARROW_EXPORT std::shared_ptr decimal(int32_t precision, int32_t scale); +/// \brief Create a Decimal128Type instance +ARROW_EXPORT +std::shared_ptr decimal128(int32_t precision, int32_t scale); + +/// \brief Create a Decimal256Type instance +ARROW_EXPORT +std::shared_ptr decimal256(int32_t precision, int32_t scale); + /// \brief Create a ListType instance from its child Field type ARROW_EXPORT std::shared_ptr list(const std::shared_ptr& value_type); diff --git a/cpp/src/arrow/type_test.cc b/cpp/src/arrow/type_test.cc index e53d259c0fb..d5ece2eea8e 100644 --- a/cpp/src/arrow/type_test.cc +++ b/cpp/src/arrow/type_test.cc @@ -1770,43 +1770,85 @@ TEST(TestDictionaryType, UnifyLarge) { TEST(TypesTest, TestDecimal128Small) { Decimal128Type t1(8, 4); - ASSERT_EQ(t1.id(), Type::DECIMAL); - ASSERT_EQ(t1.precision(), 8); - ASSERT_EQ(t1.scale(), 4); + EXPECT_EQ(t1.id(), Type::DECIMAL128); + EXPECT_EQ(t1.precision(), 8); + EXPECT_EQ(t1.scale(), 4); - ASSERT_EQ(t1.ToString(), std::string("decimal(8, 4)")); + EXPECT_EQ(t1.ToString(), std::string("decimal(8, 4)")); // Test properties - ASSERT_EQ(t1.byte_width(), 16); - ASSERT_EQ(t1.bit_width(), 128); + EXPECT_EQ(t1.byte_width(), 16); + EXPECT_EQ(t1.bit_width(), 128); } TEST(TypesTest, TestDecimal128Medium) { Decimal128Type t1(12, 5); - ASSERT_EQ(t1.id(), Type::DECIMAL); - ASSERT_EQ(t1.precision(), 12); - ASSERT_EQ(t1.scale(), 5); + EXPECT_EQ(t1.id(), Type::DECIMAL128); + EXPECT_EQ(t1.precision(), 12); + EXPECT_EQ(t1.scale(), 5); - ASSERT_EQ(t1.ToString(), std::string("decimal(12, 5)")); + EXPECT_EQ(t1.ToString(), std::string("decimal(12, 5)")); // Test properties - ASSERT_EQ(t1.byte_width(), 16); - ASSERT_EQ(t1.bit_width(), 128); + EXPECT_EQ(t1.byte_width(), 16); + EXPECT_EQ(t1.bit_width(), 128); } TEST(TypesTest, TestDecimal128Large) { Decimal128Type t1(27, 7); - ASSERT_EQ(t1.id(), Type::DECIMAL); - ASSERT_EQ(t1.precision(), 27); - ASSERT_EQ(t1.scale(), 7); + EXPECT_EQ(t1.id(), Type::DECIMAL128); + EXPECT_EQ(t1.precision(), 27); + EXPECT_EQ(t1.scale(), 7); - ASSERT_EQ(t1.ToString(), std::string("decimal(27, 7)")); + EXPECT_EQ(t1.ToString(), std::string("decimal(27, 7)")); // Test properties - ASSERT_EQ(t1.byte_width(), 16); - ASSERT_EQ(t1.bit_width(), 128); + EXPECT_EQ(t1.byte_width(), 16); + EXPECT_EQ(t1.bit_width(), 128); +} + +TEST(TypesTest, TestDecimal256Small) { + Decimal256Type t1(8, 4); + + EXPECT_EQ(t1.id(), Type::DECIMAL256); + EXPECT_EQ(t1.precision(), 8); + EXPECT_EQ(t1.scale(), 4); + + EXPECT_EQ(t1.ToString(), std::string("decimal256(8, 4)")); + + // Test properties + EXPECT_EQ(t1.byte_width(), 32); + EXPECT_EQ(t1.bit_width(), 256); +} + +TEST(TypesTest, TestDecimal256Medium) { + Decimal256Type t1(12, 5); + + EXPECT_EQ(t1.id(), Type::DECIMAL256); + EXPECT_EQ(t1.precision(), 12); + EXPECT_EQ(t1.scale(), 5); + + EXPECT_EQ(t1.ToString(), std::string("decimal256(12, 5)")); + + // Test properties + EXPECT_EQ(t1.byte_width(), 32); + EXPECT_EQ(t1.bit_width(), 256); +} + +TEST(TypesTest, TestDecimal256Large) { + Decimal256Type t1(76, 38); + + EXPECT_EQ(t1.id(), Type::DECIMAL256); + EXPECT_EQ(t1.precision(), 76); + EXPECT_EQ(t1.scale(), 38); + + EXPECT_EQ(t1.ToString(), std::string("decimal256(76, 38)")); + + // Test properties + EXPECT_EQ(t1.byte_width(), 32); + EXPECT_EQ(t1.bit_width(), 256); } TEST(TypesTest, TestDecimalEquals) { @@ -1815,12 +1857,24 @@ TEST(TypesTest, TestDecimalEquals) { Decimal128Type t3(8, 5); Decimal128Type t4(27, 5); + Decimal256Type t5(8, 4); + Decimal256Type t6(8, 4); + Decimal256Type t7(8, 5); + Decimal256Type t8(27, 5); + FixedSizeBinaryType t9(16); + FixedSizeBinaryType t10(32); AssertTypeEqual(t1, t2); AssertTypeNotEqual(t1, t3); AssertTypeNotEqual(t1, t4); AssertTypeNotEqual(t1, t9); + + AssertTypeEqual(t5, t6); + AssertTypeNotEqual(t5, t1); + AssertTypeNotEqual(t5, t7); + AssertTypeNotEqual(t5, t8); + AssertTypeNotEqual(t5, t10); } } // namespace arrow diff --git a/cpp/src/arrow/type_traits.h b/cpp/src/arrow/type_traits.h index d2abe573cd5..2dcfc77c437 100644 --- a/cpp/src/arrow/type_traits.h +++ b/cpp/src/arrow/type_traits.h @@ -66,7 +66,8 @@ TYPE_ID_TRAIT(TIMESTAMP, TimestampType) TYPE_ID_TRAIT(INTERVAL_DAY_TIME, DayTimeIntervalType) TYPE_ID_TRAIT(INTERVAL_MONTHS, MonthIntervalType) TYPE_ID_TRAIT(DURATION, DurationType) -TYPE_ID_TRAIT(DECIMAL, Decimal128Type) // XXX or DecimalType? +TYPE_ID_TRAIT(DECIMAL128, Decimal128Type) +TYPE_ID_TRAIT(DECIMAL256, Decimal256Type) TYPE_ID_TRAIT(STRUCT, StructType) TYPE_ID_TRAIT(LIST, ListType) TYPE_ID_TRAIT(LARGE_LIST, LargeListType) @@ -288,6 +289,14 @@ struct TypeTraits { constexpr static bool is_parameter_free = false; }; +template <> +struct TypeTraits { + using ArrayType = Decimal256Array; + using BuilderType = Decimal256Builder; + using ScalarType = Decimal256Scalar; + constexpr static bool is_parameter_free = false; +}; + template <> struct TypeTraits { using ArrayType = BinaryArray; @@ -577,6 +586,18 @@ using is_decimal_type = std::is_base_of; template using enable_if_decimal = enable_if_t::value, R>; +template +using is_decimal128_type = std::is_base_of; + +template +using enable_if_decimal128 = enable_if_t::value, R>; + +template +using is_decimal256_type = std::is_base_of; + +template +using enable_if_decimal256 = enable_if_t::value, R>; + // Nested Types template @@ -614,7 +635,7 @@ template using is_list_type = std::integral_constant::value || std::is_same::value || - std::is_same::valuae>; + std::is_same::value>; template using enable_if_list_type = enable_if_t::value, R>; @@ -894,7 +915,8 @@ static inline bool is_dictionary(Type::type type_id) { static inline bool is_fixed_size_binary(Type::type type_id) { switch (type_id) { - case Type::DECIMAL: + case Type::DECIMAL128: + case Type::DECIMAL256: case Type::FIXED_SIZE_BINARY: return true; default: diff --git a/cpp/src/arrow/util/basic_decimal.cc b/cpp/src/arrow/util/basic_decimal.cc index a65ff8f8552..d69334e8e68 100644 --- a/cpp/src/arrow/util/basic_decimal.cc +++ b/cpp/src/arrow/util/basic_decimal.cc @@ -123,7 +123,7 @@ static const BasicDecimal128 ScaleMultipliersHalf[] = { #ifdef ARROW_USE_NATIVE_INT128 static constexpr uint64_t kInt64Mask = 0xFFFFFFFFFFFFFFFF; #else -static constexpr uint64_t kIntMask = 0xFFFFFFFF; +static constexpr uint64_t kInt32Mask = 0xFFFFFFFF; #endif // same as ScaleMultipliers[38] - 1 @@ -254,67 +254,127 @@ BasicDecimal128& BasicDecimal128::operator>>=(uint32_t bits) { namespace { -// TODO: Remove this guard once it's used by BasicDecimal256 -#ifndef ARROW_USE_NATIVE_INT128 -// This method losslessly multiplies x and y into a 128 bit unsigned integer -// whose high bits will be stored in hi and low bits in lo. -void ExtendAndMultiplyUint64(uint64_t x, uint64_t y, uint64_t* hi, uint64_t* lo) { +// Convenience wrapper type over 128 bit unsigned integers. We opt not to +// replace the uint128_t type in int128_internal.h because it would require +// significantly more implementation work to be done. This class merely +// provides the minimum necessary set of functions to perform 128+ bit +// multiplication operations when there may or may not be native support. #ifdef ARROW_USE_NATIVE_INT128 - const __uint128_t r = static_cast<__uint128_t>(x) * y; - *lo = r & kInt64Mask; - *hi = r >> 64; +struct uint128_t { + uint128_t() {} + uint128_t(uint64_t hi, uint64_t lo) : val_((static_cast<__uint128_t>(hi) << 64) | lo) {} + explicit uint128_t(const BasicDecimal128& decimal) { + val_ = (static_cast<__uint128_t>(decimal.high_bits()) << 64) | decimal.low_bits(); + } + + explicit uint128_t(uint64_t value) : val_(value) {} + + uint64_t hi() { return val_ >> 64; } + uint64_t lo() { return val_ & kInt64Mask; } + + uint128_t& operator+=(const uint128_t& other) { + val_ += other.val_; + return *this; + } + + uint128_t& operator*=(const uint128_t& other) { + val_ *= other.val_; + return *this; + } + + __uint128_t val_; +}; + #else - // If we can't use a native fallback, perform multiplication +// Multiply two 64 bit word components into a 128 bit result, with high bits +// stored in hi and low bits in lo. +inline void ExtendAndMultiply(uint64_t x, uint64_t y, uint64_t* hi, uint64_t* lo) { + // Perform multiplication on two 64 bit words x and y into a 128 bit result // by splitting up x and y into 32 bit high/low bit components, // allowing us to represent the multiplication as // x * y = x_lo * y_lo + x_hi * y_lo * 2^32 + y_hi * x_lo * 2^32 - // + x_hi * y_hi * 2^64. + // + x_hi * y_hi * 2^64 // - // Now, consider the final output as lo_lo || lo_hi || hi_lo || hi_hi. + // Now, consider the final output as lo_lo || lo_hi || hi_lo || hi_hi // Therefore, // lo_lo is (x_lo * y_lo)_lo, // lo_hi is ((x_lo * y_lo)_hi + (x_hi * y_lo)_lo + (x_lo * y_hi)_lo)_lo, // hi_lo is ((x_hi * y_hi)_lo + (x_hi * y_lo)_hi + (x_lo * y_hi)_hi)_hi, // hi_hi is (x_hi * y_hi)_hi - const uint64_t x_lo = x & kIntMask; - const uint64_t y_lo = y & kIntMask; + const uint64_t x_lo = x & kInt32Mask; + const uint64_t y_lo = y & kInt32Mask; const uint64_t x_hi = x >> 32; const uint64_t y_hi = y >> 32; const uint64_t t = x_lo * y_lo; - const uint64_t t_lo = t & kIntMask; + const uint64_t t_lo = t & kInt32Mask; const uint64_t t_hi = t >> 32; const uint64_t u = x_hi * y_lo + t_hi; - const uint64_t u_lo = u & kIntMask; + const uint64_t u_lo = u & kInt32Mask; const uint64_t u_hi = u >> 32; const uint64_t v = x_lo * y_hi + u_lo; const uint64_t v_hi = v >> 32; *hi = x_hi * y_hi + u_hi + v_hi; - *lo = (v << 32) | t_lo; -#endif + *lo = (v << 32) + t_lo; } -#endif -void MultiplyUint128(uint64_t x_hi, uint64_t x_lo, uint64_t y_hi, uint64_t y_lo, - uint64_t* hi, uint64_t* lo) { -#ifdef ARROW_USE_NATIVE_INT128 - const __uint128_t x = (static_cast<__uint128_t>(x_hi) << 64) | x_lo; - const __uint128_t y = (static_cast<__uint128_t>(y_hi) << 64) | y_lo; - const __uint128_t r = x * y; - *lo = r & kInt64Mask; - *hi = r >> 64; -#else - // To perform 128 bit multiplication without a native fallback - // we first perform lossless 64 bit multiplication of the low - // bits, and then add x_hi * y_lo and x_lo * y_hi to the high - // bits. Note that we can skip adding x_hi * y_hi because it - // always will be over 128 bits. - ExtendAndMultiplyUint64(x_lo, y_lo, hi, lo); - *hi += (x_hi * y_lo) + (x_lo * y_hi); +struct uint128_t { + uint128_t() {} + uint128_t(uint64_t hi, uint64_t lo) : hi_(hi), lo_(lo) {} + explicit uint128_t(const BasicDecimal128& decimal) { + hi_ = decimal.high_bits(); + lo_ = decimal.low_bits(); + } + + uint64_t hi() const { return hi_; } + uint64_t lo() const { return lo_; } + + uint128_t& operator+=(const uint128_t& other) { + // To deduce the carry bit, we perform "65 bit" addition on the low bits and + // seeing if the resulting high bit is 1. This is accomplished by shifting the + // low bits to the right by 1 (chopping off the lowest bit), then adding 1 if the + // result of adding the two chopped bits would have produced a carry. + uint64_t carry = (((lo_ & other.lo_) & 1) + (lo_ >> 1) + (other.lo_ >> 1)) >> 63; + hi_ += other.hi_ + carry; + lo_ += other.lo_; + return *this; + } + + uint128_t& operator*=(const uint128_t& other) { + uint128_t r; + ExtendAndMultiply(lo_, other.lo_, &r.hi_, &r.lo_); + r.hi_ += (hi_ * other.lo_) + (lo_ * other.hi_); + *this = r; + return *this; + } + + uint64_t hi_; + uint64_t lo_; +}; #endif + +// Multiplies two N * 64 bit unsigned integer types, represented by a uint64_t +// array into a same sized output. Elements in the array should be in +// little endian order, and output will be the same. Overflow in multiplication +// will result in the lower N * 64 bits of the result being set. +template +inline void MultiplyUnsignedArray(const std::array& lh, + const std::array& rh, + std::array* result) { + for (int j = 0; j < N; ++j) { + uint64_t carry = 0; + for (int i = 0; i < N - j; ++i) { + uint128_t tmp(lh[i]); + tmp *= uint128_t(rh[j]); + tmp += uint128_t((*result)[i + j]); + tmp += uint128_t(carry); + (*result)[i + j] = tmp.lo(); + carry = tmp.hi(); + } + } } } // namespace @@ -325,10 +385,10 @@ BasicDecimal128& BasicDecimal128::operator*=(const BasicDecimal128& right) { const bool negate = Sign() != right.Sign(); BasicDecimal128 x = BasicDecimal128::Abs(*this); BasicDecimal128 y = BasicDecimal128::Abs(right); - uint64_t hi; - MultiplyUint128(x.high_bits(), x.low_bits(), y.high_bits(), y.low_bits(), &hi, - &low_bits_); - high_bits_ = hi; + uint128_t r(x); + r *= uint128_t{y}; + high_bits_ = r.hi(); + low_bits_ = r.lo(); if (negate) { Negate(); } @@ -775,4 +835,99 @@ int32_t BasicDecimal128::CountLeadingBinaryZeros() const { } } +#if ARROW_LITTLE_ENDIAN +BasicDecimal256::BasicDecimal256(const uint8_t* bytes) + : little_endian_array_( + std::array({reinterpret_cast(bytes)[0], + reinterpret_cast(bytes)[1], + reinterpret_cast(bytes)[2], + reinterpret_cast(bytes)[3]})) {} +#else +BasicDecimal256::BasicDecimal256(const uint8_t* bytes) + : little_endian_array_( + std::array({reinterpret_cast(bytes)[3], + reinterpret_cast(bytes)[2], + reinterpret_cast(bytes)[1], + reinterpret_cast(bytes)[0]})) { +#endif + +BasicDecimal256& BasicDecimal256::Negate() { + uint64_t carry = 1; + for (uint64_t& elem : little_endian_array_) { + elem = ~elem + carry; + carry &= (elem == 0); + } + return *this; +} + +BasicDecimal256& BasicDecimal256::Abs() { return *this < 0 ? Negate() : *this; } + +BasicDecimal256 BasicDecimal256::Abs(const BasicDecimal256& in) { + BasicDecimal256 result(in); + return result.Abs(); +} + +std::array BasicDecimal256::ToBytes() const { + std::array out{{0}}; + ToBytes(out.data()); + return out; +} + +void BasicDecimal256::ToBytes(uint8_t* out) const { + DCHECK_NE(out, nullptr); +#if ARROW_LITTLE_ENDIAN + reinterpret_cast(out)[0] = little_endian_array_[0]; + reinterpret_cast(out)[1] = little_endian_array_[1]; + reinterpret_cast(out)[2] = little_endian_array_[2]; + reinterpret_cast(out)[3] = little_endian_array_[3]; +#else + reinterpret_cast(out)[0] = little_endian_array_[3]; + reinterpret_cast(out)[1] = little_endian_array_[2]; + reinterpret_cast(out)[2] = little_endian_array_[1]; + reinterpret_cast(out)[3] = little_endian_array_[0]; +#endif +} + +BasicDecimal256& BasicDecimal256::operator*=(const BasicDecimal256& right) { + // Since the max value of BasicDecimal256 is supposed to be 1e76 - 1 and the + // min the negation taking the absolute values here should always be safe. + const bool negate = Sign() != right.Sign(); + BasicDecimal256 x = BasicDecimal256::Abs(*this); + BasicDecimal256 y = BasicDecimal256::Abs(right); + + uint128_t r_hi; + uint128_t r_lo; + std::array res{0, 0, 0, 0}; + MultiplyUnsignedArray<4>(x.little_endian_array_, y.little_endian_array_, &res); + little_endian_array_ = res; + if (negate) { + Negate(); + } + return *this; +} + +DecimalStatus BasicDecimal256::Rescale(int32_t original_scale, int32_t new_scale, + BasicDecimal256* out) const { + if (original_scale == new_scale) { + return DecimalStatus::kSuccess; + } + // TODO: implement. + return DecimalStatus::kRescaleDataLoss; +} + +BasicDecimal256 operator*(const BasicDecimal256& left, const BasicDecimal256& right) { + BasicDecimal256 result = left; + result *= right; + return result; +} + +bool operator<(const BasicDecimal256& left, const BasicDecimal256& right) { + const std::array& lhs = left.little_endian_array(); + const std::array& rhs = right.little_endian_array(); + return lhs[3] != rhs[3] + ? static_cast(lhs[3]) < static_cast(rhs[3]) + : lhs[2] != rhs[2] ? lhs[2] < rhs[2] + : lhs[1] != rhs[1] ? lhs[1] < rhs[1] : lhs[0] < rhs[0]; +} + } // namespace arrow diff --git a/cpp/src/arrow/util/basic_decimal.h b/cpp/src/arrow/util/basic_decimal.h index 23c38bbb9d3..55a23183830 100644 --- a/cpp/src/arrow/util/basic_decimal.h +++ b/cpp/src/arrow/util/basic_decimal.h @@ -109,10 +109,10 @@ class ARROW_EXPORT BasicDecimal128 { BasicDecimal128& operator>>=(uint32_t bits); /// \brief Get the high bits of the two's complement representation of the number. - inline int64_t high_bits() const { return high_bits_; } + inline constexpr int64_t high_bits() const { return high_bits_; } /// \brief Get the low bits of the two's complement representation of the number. - inline uint64_t low_bits() const { return low_bits_; } + inline constexpr uint64_t low_bits() const { return low_bits_; } /// \brief Return the raw bytes of the value in native-endian byte order. std::array ToBytes() const; @@ -178,4 +178,104 @@ ARROW_EXPORT BasicDecimal128 operator/(const BasicDecimal128& left, ARROW_EXPORT BasicDecimal128 operator%(const BasicDecimal128& left, const BasicDecimal128& right); +class ARROW_EXPORT BasicDecimal256 { + private: + // Due to a bug in clang, we have to declare the extend method prior to its + // usage. + template + inline static constexpr uint64_t extend(T low_bits) noexcept { + return low_bits >= T() ? uint64_t{0} : ~uint64_t{0}; + } + + public: + /// \brief Create a BasicDecimal256 from the two's complement representation. + constexpr BasicDecimal256(const std::array& little_endian_array) noexcept + : little_endian_array_(little_endian_array) {} + + /// \brief Empty constructor creates a BasicDecimal256 with a value of 0. + constexpr BasicDecimal256() noexcept : little_endian_array_({0, 0, 0, 0}) {} + + /// \brief Convert any integer value into a BasicDecimal256. + template ::value && (sizeof(T) <= sizeof(uint64_t)), T>::type> + constexpr BasicDecimal256(T value) noexcept + : little_endian_array_({static_cast(value), extend(value), extend(value), + extend(value)}) {} + + constexpr BasicDecimal256(const BasicDecimal128& value) noexcept + : little_endian_array_({value.low_bits(), static_cast(value.high_bits()), + extend(value.high_bits()), extend(value.high_bits())}) {} + + /// \brief Create a BasicDecimal256 from an array of bytes. Bytes are assumed to be in + /// native-endian byte order. + explicit BasicDecimal256(const uint8_t* bytes); + + /// \brief Negate the current value (in-place) + BasicDecimal256& Negate(); + + /// \brief Absolute value (in-place) + BasicDecimal256& Abs(); + + /// \brief Absolute value + static BasicDecimal256 Abs(const BasicDecimal256& left); + + /// \brief Get the bits of the two's complement representation of the number. The 4 + /// elements are in little endian order. The bits within each uint64_t element are in + /// native endian order. For example, + /// BasicDecimal256(123).little_endian_array() = {123, 0, 0, 0}; + /// BasicDecimal256(-2).little_endian_array() = {0xFF...FE, 0xFF...FF, 0xFF...FF, + /// 0xFF...FF}. + inline const std::array& little_endian_array() const { + return little_endian_array_; + } + + /// \brief Return the raw bytes of the value in native-endian byte order. + std::array ToBytes() const; + void ToBytes(uint8_t* out) const; + + /// \brief Convert BasicDecimal128 from one scale to another + DecimalStatus Rescale(int32_t original_scale, int32_t new_scale, + BasicDecimal256* out) const; + + inline int64_t Sign() const { + return 1 | (static_cast(little_endian_array_[3]) >> 63); + } + + /// \brief Multiply this number by another number. The result is truncated to 256 bits. + BasicDecimal256& operator*=(const BasicDecimal256& right); + + private: + std::array little_endian_array_; +}; + +ARROW_EXPORT inline bool operator==(const BasicDecimal256& left, + const BasicDecimal256& right) { + return left.little_endian_array() == right.little_endian_array(); +} + +ARROW_EXPORT inline bool operator!=(const BasicDecimal256& left, + const BasicDecimal256& right) { + return left.little_endian_array() != right.little_endian_array(); +} + +ARROW_EXPORT bool operator<(const BasicDecimal256& left, const BasicDecimal256& right); + +ARROW_EXPORT inline bool operator<=(const BasicDecimal256& left, + const BasicDecimal256& right) { + return !operator<(right, left); +} + +ARROW_EXPORT inline bool operator>(const BasicDecimal256& left, + const BasicDecimal256& right) { + return operator<(right, left); +} + +ARROW_EXPORT inline bool operator>=(const BasicDecimal256& left, + const BasicDecimal256& right) { + return !operator<(left, right); +} + +ARROW_EXPORT BasicDecimal256 operator*(const BasicDecimal256& left, + const BasicDecimal256& right); } // namespace arrow diff --git a/cpp/src/arrow/util/decimal.cc b/cpp/src/arrow/util/decimal.cc index 52a1da4fca3..c38e66ca810 100644 --- a/cpp/src/arrow/util/decimal.cc +++ b/cpp/src/arrow/util/decimal.cc @@ -445,6 +445,24 @@ bool ParseDecimalComponents(const char* s, size_t size, DecimalComponents* out) return pos == size; } +inline Status ToArrowStatus(DecimalStatus dstatus, int num_bits) { + switch (dstatus) { + case DecimalStatus::kSuccess: + return Status::OK(); + + case DecimalStatus::kDivideByZero: + return Status::Invalid("Division by 0 in Decimal", num_bits); + + case DecimalStatus::kOverflow: + return Status::Invalid("Overflow occurred during Decimal", num_bits, " operation."); + + case DecimalStatus::kRescaleDataLoss: + return Status::Invalid("Rescaling Decimal", num_bits, + " value would cause data loss"); + } + return Status::OK(); +} + } // namespace Status Decimal128::FromString(const util::string_view& s, Decimal128* out, @@ -598,31 +616,114 @@ Result Decimal128::FromBigEndian(const uint8_t* bytes, int32_t lengt } Status Decimal128::ToArrowStatus(DecimalStatus dstatus) const { - Status status; + return arrow::ToArrowStatus(dstatus, 128); +} - switch (dstatus) { - case DecimalStatus::kSuccess: - status = Status::OK(); - break; +std::ostream& operator<<(std::ostream& os, const Decimal128& decimal) { + os << decimal.ToIntegerString(); + return os; +} - case DecimalStatus::kDivideByZero: - status = Status::Invalid("Division by 0 in Decimal128"); - break; +Decimal256::Decimal256(const std::string& str) : Decimal256() { + *this = Decimal256::FromString(str).ValueOrDie(); +} - case DecimalStatus::kOverflow: - status = Status::Invalid("Overflow occurred during Decimal128 operation."); - break; +std::string Decimal256::ToIntegerString() const { + std::string result; + if (static_cast(little_endian_array()[3]) < 0) { + result.push_back('-'); + Decimal256 abs = *this; + abs.Negate(); + AppendLittleEndianArrayToString(abs.little_endian_array(), &result); + } else { + AppendLittleEndianArrayToString(little_endian_array(), &result); + } + return result; +} - case DecimalStatus::kRescaleDataLoss: - status = Status::Invalid("Rescaling decimal value would cause data loss"); - break; +std::string Decimal256::ToString(int32_t scale) const { + std::string str(ToIntegerString()); + AdjustIntegerStringWithScale(scale, &str); + return str; +} + +Status Decimal256::FromString(const util::string_view& s, Decimal256* out, + int32_t* precision, int32_t* scale) { + if (s.empty()) { + return Status::Invalid("Empty string cannot be converted to decimal"); + } + + DecimalComponents dec; + if (!ParseDecimalComponents(s.data(), s.size(), &dec)) { + return Status::Invalid("The string '", s, "' is not a valid decimal number"); + } + + // Count number of significant digits (without leading zeros) + size_t first_non_zero = dec.whole_digits.find_first_not_of('0'); + size_t significant_digits = dec.fractional_digits.size(); + if (first_non_zero != std::string::npos) { + significant_digits += dec.whole_digits.size() - first_non_zero; + } + + if (precision != nullptr) { + *precision = static_cast(significant_digits); + } + + if (scale != nullptr) { + if (dec.has_exponent) { + auto adjusted_exponent = dec.exponent; + auto len = static_cast(significant_digits); + *scale = -adjusted_exponent + len - 1; + } else { + *scale = static_cast(dec.fractional_digits.size()); + } + } + + if (out != nullptr) { + std::array little_endian_array = {0, 0, 0, 0}; + ShiftAndAdd(dec.whole_digits, little_endian_array.data(), little_endian_array.size()); + ShiftAndAdd(dec.fractional_digits, little_endian_array.data(), + little_endian_array.size()); + *out = Decimal256(little_endian_array); + + if (dec.sign == '-') { + out->Negate(); + } } - return status; + + return Status::OK(); } -std::ostream& operator<<(std::ostream& os, const Decimal128& decimal) { +Status Decimal256::FromString(const std::string& s, Decimal256* out, int32_t* precision, + int32_t* scale) { + return FromString(util::string_view(s), out, precision, scale); +} + +Status Decimal256::FromString(const char* s, Decimal256* out, int32_t* precision, + int32_t* scale) { + return FromString(util::string_view(s), out, precision, scale); +} + +Result Decimal256::FromString(const util::string_view& s) { + Decimal256 out; + RETURN_NOT_OK(FromString(s, &out, nullptr, nullptr)); + return std::move(out); +} + +Result Decimal256::FromString(const std::string& s) { + return FromString(util::string_view(s)); +} + +Result Decimal256::FromString(const char* s) { + return FromString(util::string_view(s)); +} + +Status Decimal256::ToArrowStatus(DecimalStatus dstatus) const { + return arrow::ToArrowStatus(dstatus, 256); +} + +std::ostream& operator<<(std::ostream& os, const Decimal256& decimal) { os << decimal.ToIntegerString(); return os; } - } // namespace arrow diff --git a/cpp/src/arrow/util/decimal.h b/cpp/src/arrow/util/decimal.h index 1f727057c13..3b159bc8d88 100644 --- a/cpp/src/arrow/util/decimal.h +++ b/cpp/src/arrow/util/decimal.h @@ -172,4 +172,67 @@ struct Decimal128::ToRealConversion { } }; +/// Represents a signed 256-bit integer in two's complement. +/// The max decimal precision that can be safely represented is +/// 76 significant digits. +/// +/// The implementation is split into two parts : +/// +/// 1. BasicDecimal256 +/// - can be safely compiled to IR without references to libstdc++. +/// 2. Decimal256 +/// - (TODO) has additional functionality on top of BasicDecimal256 to deal with +/// strings and streams. +class ARROW_EXPORT Decimal256 : public BasicDecimal256 { + public: + /// \cond FALSE + // (need to avoid a duplicate definition in Sphinx) + using BasicDecimal256::BasicDecimal256; + /// \endcond + + /// \brief constructor creates a Decimal256 from a BasicDecimal128. + constexpr Decimal256(const BasicDecimal256& value) noexcept : BasicDecimal256(value) {} + + /// \brief Parse the number from a base 10 string representation. + explicit Decimal256(const std::string& value); + + /// \brief Empty constructor creates a Decimal256 with a value of 0. + // This is required on some older compilers. + constexpr Decimal256() noexcept : BasicDecimal256() {} + + /// \brief Convert the Decimal256 value to a base 10 decimal string with the given + /// scale. + std::string ToString(int32_t scale) const; + + /// \brief Convert the value to an integer string + std::string ToIntegerString() const; + + /// \brief Convert a decimal string to a Decimal256 value, optionally including + /// precision and scale if they're passed in and not null. + static Status FromString(const util::string_view& s, Decimal256* out, + int32_t* precision, int32_t* scale = NULLPTR); + static Status FromString(const std::string& s, Decimal256* out, int32_t* precision, + int32_t* scale = NULLPTR); + static Status FromString(const char* s, Decimal256* out, int32_t* precision, + int32_t* scale = NULLPTR); + static Result FromString(const util::string_view& s); + static Result FromString(const std::string& s); + static Result FromString(const char* s); + + /// \brief Convert Decimal256 from one scale to another + Result Rescale(int32_t original_scale, int32_t new_scale) const { + Decimal256 out; + auto dstatus = BasicDecimal256::Rescale(original_scale, new_scale, &out); + ARROW_RETURN_NOT_OK(ToArrowStatus(dstatus)); + return std::move(out); + } + + friend ARROW_EXPORT std::ostream& operator<<(std::ostream& os, + const Decimal256& decimal); + + private: + /// Converts internal error code to Status + Status ToArrowStatus(DecimalStatus dstatus) const; +}; + } // namespace arrow diff --git a/cpp/src/arrow/util/decimal_benchmark.cc b/cpp/src/arrow/util/decimal_benchmark.cc index c1acefc268e..8e2a63dcf9d 100644 --- a/cpp/src/arrow/util/decimal_benchmark.cc +++ b/cpp/src/arrow/util/decimal_benchmark.cc @@ -129,7 +129,7 @@ static void BinaryMathOpAggregate( state.SetItemsProcessed(state.iterations() * kValueSize); } -static void BinaryMathOp(benchmark::State& state) { // NOLINT non-const reference +static void BinaryMathOp128(benchmark::State& state) { // NOLINT non-const reference std::vector v1, v2; for (int x = 0; x < kValueSize; x++) { v1.emplace_back(100 + x, 100 + x); @@ -148,6 +148,21 @@ static void BinaryMathOp(benchmark::State& state) { // NOLINT non-const referen state.SetItemsProcessed(state.iterations() * kValueSize); } +static void BinaryMathOp256(benchmark::State& state) { // NOLINT non-const reference + std::vector v1, v2; + for (uint64_t x = 0; x < kValueSize; x++) { + v1.push_back(BasicDecimal256({100 + x, 100 + x, 100 + x, 100 + x})); + v2.push_back(BasicDecimal256({200 + x, 200 + x, 200 + x, 200 + x})); + } + + for (auto _ : state) { + for (int x = 0; x < kValueSize; x += 5) { + benchmark::DoNotOptimize(v1[x + 2] * v2[x + 2]); + } + } + state.SetItemsProcessed(state.iterations() * kValueSize); +} + static void UnaryOp(benchmark::State& state) { // NOLINT non-const reference std::vector v; for (int x = 0; x < kValueSize; x++) { @@ -190,7 +205,8 @@ static void BinaryBitOp(benchmark::State& state) { // NOLINT non-const referenc BENCHMARK(FromString); BENCHMARK(ToString); -BENCHMARK(BinaryMathOp); +BENCHMARK(BinaryMathOp128); +BENCHMARK(BinaryMathOp256); BENCHMARK(BinaryMathOpAggregate); BENCHMARK(BinaryCompareOp); BENCHMARK(BinaryCompareOpConstant); diff --git a/cpp/src/arrow/util/decimal_test.cc b/cpp/src/arrow/util/decimal_test.cc index c78cd7c3f35..3c2f5335e74 100644 --- a/cpp/src/arrow/util/decimal_test.cc +++ b/cpp/src/arrow/util/decimal_test.cc @@ -26,6 +26,7 @@ #include #include +#include #include "arrow/status.h" #include "arrow/testing/gtest_util.h" @@ -37,6 +38,7 @@ namespace arrow { using internal::int128_t; +using internal::uint128_t; class DecimalTestFixture : public ::testing::Test { public: @@ -1154,4 +1156,156 @@ TEST(Decimal128Test, FitsInPrecision) { Decimal128("-100000000000000000000000000000000000000").FitsInPrecision(38)); } +static constexpr std::array kSortedDecimal256Bits[] = { + {0, 0, 0, 0x8000000000000000ULL}, // min + {0xFFFFFFFFFFFFFFFEULL, 0xFFFFFFFFFFFFFFFFULL, 0xFFFFFFFFFFFFFFFFULL, + 0xFFFFFFFFFFFFFFFFULL}, // -2 + {0xFFFFFFFFFFFFFFFFULL, 0xFFFFFFFFFFFFFFFFULL, 0xFFFFFFFFFFFFFFFFULL, + 0xFFFFFFFFFFFFFFFFULL}, // -1 + {0, 0, 0, 0}, + {1, 0, 0, 0}, + {2, 0, 0, 0}, + {0xFFFFFFFFFFFFFFFFULL, 0, 0, 0}, + {0, 1, 0, 0}, + {0xFFFFFFFFFFFFFFFFULL, 0xFFFFFFFFFFFFFFFFULL, 0, 0}, + {0, 0, 1, 0}, + {0xFFFFFFFFFFFFFFFFULL, 0xFFFFFFFFFFFFFFFFULL, 0xFFFFFFFFFFFFFFFFULL, 0}, + {0, 0, 0, 1}, + {0xFFFFFFFFFFFFFFFFULL, 0xFFFFFFFFFFFFFFFFULL, 0xFFFFFFFFFFFFFFFFULL, + 0x7FFFFFFFFFFFFFFFULL}, // max +}; + +TEST(Decimal256Test, TestComparators) { + constexpr size_t num_values = + sizeof(kSortedDecimal256Bits) / sizeof(kSortedDecimal256Bits[0]); + for (size_t i = 0; i < num_values; ++i) { + Decimal256 left(kSortedDecimal256Bits[i]); + for (size_t j = 0; j < num_values; ++j) { + Decimal256 right(kSortedDecimal256Bits[j]); + EXPECT_EQ(i == j, left == right); + EXPECT_EQ(i != j, left != right); + EXPECT_EQ(i < j, left < right); + EXPECT_EQ(i > j, left > right); + EXPECT_EQ(i <= j, left <= right); + EXPECT_EQ(i >= j, left >= right); + } + } +} + +TEST(Decimal256Test, TestToBytesRoundTrip) { + for (const std::array& bits : kSortedDecimal256Bits) { + Decimal256 decimal(bits); + EXPECT_EQ(decimal, Decimal256(decimal.ToBytes().data())); + } +} + +template +class Decimal256Test : public ::testing::Test { + public: + Decimal256Test() {} +}; + +using Decimal256Types = + ::testing::Types; + +TYPED_TEST_SUITE(Decimal256Test, Decimal256Types); + +TYPED_TEST(Decimal256Test, ConstructibleFromAnyIntegerType) { + using UInt64Array = std::array; + Decimal256 value(TypeParam{42}); + EXPECT_EQ(UInt64Array({42, 0, 0, 0}), value.little_endian_array()); + + TypeParam max = std::numeric_limits::max(); + Decimal256 max_value(max); + EXPECT_EQ(UInt64Array({static_cast(max), 0, 0, 0}), + max_value.little_endian_array()); + + TypeParam min = std::numeric_limits::min(); + Decimal256 min_value(min); + uint64_t high_bits = std::is_signed::value ? ~uint64_t{0} : uint64_t{0}; + EXPECT_EQ(UInt64Array({static_cast(min), high_bits, high_bits, high_bits}), + min_value.little_endian_array()); +} + +TEST(Decimal256Test, ConstructibleFromBool) { + EXPECT_EQ(Decimal256(0), Decimal256(false)); + EXPECT_EQ(Decimal256(1), Decimal256(true)); +} + +Decimal256 Decimal256FromInt128(int128_t value) { + return Decimal256(Decimal128(static_cast(value >> 64), + static_cast(value & 0xFFFFFFFFFFFFFFFFULL))); +} + +TEST(Decimal256Test, Multiply) { + using boost::multiprecision::int256_t; + using boost::multiprecision::uint256_t; + + ASSERT_EQ(Decimal256(60501), Decimal256(301) * Decimal256(201)); + + ASSERT_EQ(Decimal256(-60501), Decimal256(-301) * Decimal256(201)); + + ASSERT_EQ(Decimal256(-60501), Decimal256(301) * Decimal256(-201)); + + ASSERT_EQ(Decimal256(60501), Decimal256(-301) * Decimal256(-201)); + + // Test some random numbers. + std::vector left; + std::vector right; + for (auto x : GetRandomNumbers(16)) { + for (auto y : GetRandomNumbers(16)) { + for (auto z : GetRandomNumbers(16)) { + for (auto w : GetRandomNumbers(16)) { + // Test two 128 bit numbers which have a large amount of bits set. + int128_t l = static_cast(x) << 96 | static_cast(y) << 64 | + static_cast(z) << 32 | static_cast(w); + int128_t r = static_cast(w) << 96 | static_cast(z) << 64 | + static_cast(y) << 32 | static_cast(x); + int256_t expected = int256_t(l) * r; + Decimal256 actual = Decimal256FromInt128(l) * Decimal256FromInt128(r); + ASSERT_EQ(expected.str(), actual.ToIntegerString()) + << " " << int256_t(l).str() << " * " << int256_t(r).str(); + // Test a 96 bit number against a 160 bit number. + int128_t s = l >> 32; + uint256_t b = uint256_t(r) << 32; + Decimal256 b_dec = + Decimal256FromInt128(r) * Decimal256(static_cast(1) << 32); + ASSERT_EQ(b.str(), b_dec.ToIntegerString()) << int256_t(r).str(); + expected = int256_t(s) * b; + actual = Decimal256FromInt128(s) * b_dec; + ASSERT_EQ(expected.str(), actual.ToIntegerString()) + << " " << int256_t(s).str() << " * " << int256_t(b).str(); + } + } + } + } + + // Test some edge cases + for (auto x : std::vector{-INT64_MAX, -INT32_MAX, 0, INT32_MAX, INT64_MAX}) { + for (auto y : + std::vector{-INT32_MAX, -32, -2, -1, 0, 1, 2, 32, INT32_MAX}) { + Decimal256 decimal_x = Decimal256FromInt128(x); + Decimal256 decimal_y = Decimal256FromInt128(y); + Decimal256 result = decimal_x * decimal_y; + EXPECT_EQ(Decimal256FromInt128(x * y), result) + << " x: " << decimal_x << " y: " << decimal_y; + } + } +} + +class Decimal256ToStringTest : public ::testing::TestWithParam {}; + +TEST_P(Decimal256ToStringTest, ToString) { + const ToStringTestParam& data = GetParam(); + const Decimal256 value(data.test_value); + const std::string printed_value = value.ToString(data.scale); + ASSERT_EQ(data.expected_string, printed_value); +} + +INSTANTIATE_TEST_SUITE_P(Decimal256ToStringTest, Decimal256ToStringTest, + ::testing::ValuesIn(kToStringTestData)); + } // namespace arrow diff --git a/cpp/src/arrow/util/hashing.h b/cpp/src/arrow/util/hashing.h index f1c4b1e6318..00fb745f529 100644 --- a/cpp/src/arrow/util/hashing.h +++ b/cpp/src/arrow/util/hashing.h @@ -842,8 +842,8 @@ struct HashTraits::value && using MemoTableType = BinaryMemoTable; }; -template <> -struct HashTraits { +template +struct HashTraits> { using MemoTableType = BinaryMemoTable; }; diff --git a/cpp/src/arrow/visitor.cc b/cpp/src/arrow/visitor.cc index 0a452d5c594..851785081c7 100644 --- a/cpp/src/arrow/visitor.cc +++ b/cpp/src/arrow/visitor.cc @@ -67,6 +67,7 @@ ARRAY_VISITOR_DEFAULT(SparseUnionArray) ARRAY_VISITOR_DEFAULT(DenseUnionArray) ARRAY_VISITOR_DEFAULT(DictionaryArray) ARRAY_VISITOR_DEFAULT(Decimal128Array) +ARRAY_VISITOR_DEFAULT(Decimal256Array) ARRAY_VISITOR_DEFAULT(ExtensionArray) #undef ARRAY_VISITOR_DEFAULT @@ -106,6 +107,7 @@ TYPE_VISITOR_DEFAULT(DayTimeIntervalType) TYPE_VISITOR_DEFAULT(MonthIntervalType) TYPE_VISITOR_DEFAULT(DurationType) TYPE_VISITOR_DEFAULT(Decimal128Type) +TYPE_VISITOR_DEFAULT(Decimal256Type) TYPE_VISITOR_DEFAULT(ListType) TYPE_VISITOR_DEFAULT(LargeListType) TYPE_VISITOR_DEFAULT(MapType) @@ -154,6 +156,7 @@ SCALAR_VISITOR_DEFAULT(DayTimeIntervalScalar) SCALAR_VISITOR_DEFAULT(MonthIntervalScalar) SCALAR_VISITOR_DEFAULT(DurationScalar) SCALAR_VISITOR_DEFAULT(Decimal128Scalar) +SCALAR_VISITOR_DEFAULT(Decimal256Scalar) SCALAR_VISITOR_DEFAULT(ListScalar) SCALAR_VISITOR_DEFAULT(LargeListScalar) SCALAR_VISITOR_DEFAULT(MapScalar) diff --git a/cpp/src/arrow/visitor.h b/cpp/src/arrow/visitor.h index 7ab136c066f..0382e461199 100644 --- a/cpp/src/arrow/visitor.h +++ b/cpp/src/arrow/visitor.h @@ -54,6 +54,7 @@ class ARROW_EXPORT ArrayVisitor { virtual Status Visit(const MonthIntervalArray& array); virtual Status Visit(const DurationArray& array); virtual Status Visit(const Decimal128Array& array); + virtual Status Visit(const Decimal256Array& array); virtual Status Visit(const ListArray& array); virtual Status Visit(const LargeListArray& array); virtual Status Visit(const MapArray& array); @@ -96,6 +97,7 @@ class ARROW_EXPORT TypeVisitor { virtual Status Visit(const DayTimeIntervalType& type); virtual Status Visit(const DurationType& type); virtual Status Visit(const Decimal128Type& type); + virtual Status Visit(const Decimal256Type& type); virtual Status Visit(const ListType& type); virtual Status Visit(const LargeListType& type); virtual Status Visit(const MapType& type); @@ -138,6 +140,7 @@ class ARROW_EXPORT ScalarVisitor { virtual Status Visit(const MonthIntervalScalar& scalar); virtual Status Visit(const DurationScalar& scalar); virtual Status Visit(const Decimal128Scalar& scalar); + virtual Status Visit(const Decimal256Scalar& scalar); virtual Status Visit(const ListScalar& scalar); virtual Status Visit(const LargeListScalar& scalar); virtual Status Visit(const MapScalar& scalar); diff --git a/cpp/src/arrow/visitor_inline.h b/cpp/src/arrow/visitor_inline.h index bff97fcd9eb..45193f20413 100644 --- a/cpp/src/arrow/visitor_inline.h +++ b/cpp/src/arrow/visitor_inline.h @@ -68,6 +68,7 @@ namespace arrow { ACTION(MonthInterval); \ ACTION(DayTimeInterval); \ ACTION(Decimal128); \ + ACTION(Decimal256); \ ACTION(List); \ ACTION(LargeList); \ ACTION(Map); \ diff --git a/cpp/src/parquet/arrow/reader_internal.cc b/cpp/src/parquet/arrow/reader_internal.cc index 9ce37f31579..c74d9f0567c 100644 --- a/cpp/src/parquet/arrow/reader_internal.cc +++ b/cpp/src/parquet/arrow/reader_internal.cc @@ -645,7 +645,9 @@ static Status DecimalIntegerTransfer(RecordReader* reader, MemoryPool* pool, template Status TransferDecimal(RecordReader* reader, MemoryPool* pool, const std::shared_ptr& type, Datum* out) { - DCHECK_EQ(type->id(), ::arrow::Type::DECIMAL); + if (type->id() != ::arrow::Type::DECIMAL128) { + return Status::NotImplemented("Only reading decimal128 types is currently supported"); + } auto binary_reader = dynamic_cast(reader); DCHECK(binary_reader); diff --git a/cpp/src/parquet/arrow/schema.cc b/cpp/src/parquet/arrow/schema.cc index 91b2f451314..555c7a85f1a 100644 --- a/cpp/src/parquet/arrow/schema.cc +++ b/cpp/src/parquet/arrow/schema.cc @@ -299,8 +299,7 @@ Status FieldToNode(const std::string& name, const std::shared_ptr& field, } break; case ArrowTypeId::DECIMAL: { type = ParquetType::FIXED_LEN_BYTE_ARRAY; - const auto& decimal_type = - static_cast(*field->type()); + const auto& decimal_type = static_cast(*field->type()); precision = decimal_type.precision(); scale = decimal_type.scale(); length = DecimalSize(precision); diff --git a/cpp/src/parquet/arrow/writer.cc b/cpp/src/parquet/arrow/writer.cc index 6115c027ff2..a1018d2c32c 100644 --- a/cpp/src/parquet/arrow/writer.cc +++ b/cpp/src/parquet/arrow/writer.cc @@ -50,7 +50,6 @@ using arrow::BinaryArray; using arrow::BooleanArray; using arrow::ChunkedArray; using arrow::DataType; -using arrow::Decimal128Array; using arrow::DictionaryArray; using arrow::ExtensionArray; using arrow::ExtensionType; diff --git a/cpp/src/parquet/column_writer.cc b/cpp/src/parquet/column_writer.cc index eb3942f29c1..ed7058f82eb 100644 --- a/cpp/src/parquet/column_writer.cc +++ b/cpp/src/parquet/column_writer.cc @@ -1860,7 +1860,8 @@ using ::arrow::internal::checked_pointer_cast; // Requires a custom serializer because decimal128 in parquet are in big-endian // format. Thus, a temporary local buffer is required. template -struct SerializeFunctor> { +struct SerializeFunctor> { Status Serialize(const ::arrow::Decimal128Array& array, ArrowWriteContext* ctx, FLBA* out) { AllocateScratch(array, ctx); @@ -1908,13 +1909,23 @@ struct SerializeFunctor +struct SerializeFunctor> { + Status Serialize(const ::arrow::Decimal256Array& array, ArrowWriteContext* ctx, + FLBA* out) { + return Status::NotImplemented("Decimal256 serialization isn't implemented"); + } +}; + template <> Status TypedColumnWriterImpl::WriteArrowDense( const int16_t* def_levels, const int16_t* rep_levels, int64_t num_levels, const ::arrow::Array& array, ArrowWriteContext* ctx, bool maybe_parent_nulls) { switch (array.type()->id()) { WRITE_SERIALIZE_CASE(FIXED_SIZE_BINARY, FixedSizeBinaryType, FLBAType) - WRITE_SERIALIZE_CASE(DECIMAL, Decimal128Type, FLBAType) + WRITE_SERIALIZE_CASE(DECIMAL128, Decimal128Type, FLBAType) + WRITE_SERIALIZE_CASE(DECIMAL256, Decimal256Type, FLBAType) default: break; } diff --git a/dev/archery/archery/integration/datagen.py b/dev/archery/archery/integration/datagen.py index 32ecbb6430f..3d50381f0d3 100644 --- a/dev/archery/archery/integration/datagen.py +++ b/dev/archery/archery/integration/datagen.py @@ -400,14 +400,15 @@ def generate_column(self, size, name=None): DECIMAL_PRECISION_TO_VALUE = { key: (1 << (8 * i - 1)) - 1 for i, key in enumerate( - [1, 3, 5, 7, 10, 12, 15, 17, 19, 22, 24, 27, 29, 32, 34, 36], + [1, 3, 5, 7, 10, 12, 15, 17, 19, 22, 24, 27, 29, 32, 34, 36, + 40, 42, 44, 50, 60, 70], start=1, ) } def decimal_range_from_precision(precision): - assert 1 <= precision <= 38 + assert 1 <= precision <= 76 try: max_value = DECIMAL_PRECISION_TO_VALUE[precision] except KeyError: @@ -417,7 +418,7 @@ def decimal_range_from_precision(precision): class DecimalField(PrimitiveField): - def __init__(self, name, precision, scale, bit_width=128, *, + def __init__(self, name, precision, scale, bit_width, *, nullable=True, metadata=None): super().__init__(name, nullable=True, metadata=metadata) @@ -434,6 +435,7 @@ def _get_type(self): ('name', 'decimal'), ('precision', self.precision), ('scale', self.scale), + ('bitWidth', self.bit_width), ]) def generate_column(self, size, name=None): @@ -448,7 +450,7 @@ def generate_column(self, size, name=None): class DecimalColumn(PrimitiveColumn): - def __init__(self, name, count, is_valid, values, bit_width=128): + def __init__(self, name, count, is_valid, values, bit_width): super().__init__(name, count, is_valid, values) self.bit_width = bit_width @@ -1272,17 +1274,33 @@ def generate_null_trivial_case(batch_sizes): return _generate_file('null_trivial', fields, batch_sizes) -def generate_decimal_case(): +def generate_decimal128_case(): fields = [ - DecimalField(name='f{}'.format(i), precision=precision, scale=2) + DecimalField(name='f{}'.format(i), precision=precision, scale=2, + bit_width=128) for i, precision in enumerate(range(3, 39)) ] possible_batch_sizes = 7, 10 batch_sizes = [possible_batch_sizes[i % 2] for i in range(len(fields))] + # 'decimal' is the original name for the test, and it must match + # provide "gold" files that test backwards compatibility, so they + # can be appropriately skipped. return _generate_file('decimal', fields, batch_sizes) +def generate_decimal256_case(): + fields = [ + DecimalField(name='f{}'.format(i), precision=precision, scale=5, + bit_width=256) + for i, precision in enumerate(range(37, 70)) + ] + + possible_batch_sizes = 7, 10 + batch_sizes = [possible_batch_sizes[i % 2] for i in range(len(fields))] + return _generate_file('decimal256', fields, batch_sizes) + + def generate_datetime_case(): fields = [ DateField('f0', DateField.DAY), @@ -1508,10 +1526,15 @@ def _temp_path(): .skip_category('JS') # TODO(ARROW-7900) .skip_category('Go'), # TODO(ARROW-7901) - generate_decimal_case() + generate_decimal128_case() .skip_category('Go') # TODO(ARROW-7948): Decimal + Go .skip_category('Rust'), + generate_decimal256_case() + .skip_category('Go') # TODO(ARROW-7948): Decimal + Go + .skip_category('JS') + .skip_category('Rust'), + generate_datetime_case(), generate_interval_case() diff --git a/java/adapter/avro/src/main/java/org/apache/arrow/AvroToArrowUtils.java b/java/adapter/avro/src/main/java/org/apache/arrow/AvroToArrowUtils.java index 29e44dad3a5..80293c8b85c 100644 --- a/java/adapter/avro/src/main/java/org/apache/arrow/AvroToArrowUtils.java +++ b/java/adapter/avro/src/main/java/org/apache/arrow/AvroToArrowUtils.java @@ -298,7 +298,7 @@ private static ArrowType createDecimalArrowType(LogicalTypes.Decimal logicalType Preconditions.checkArgument(scale <= precision, "Invalid decimal scale: %s (greater than precision: %s)", scale, precision); - return new ArrowType.Decimal(precision, scale); + return new ArrowType.Decimal(precision, scale, 128); } diff --git a/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcToArrowUtils.java b/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcToArrowUtils.java index e534d2060c5..f64f178c6f4 100644 --- a/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcToArrowUtils.java +++ b/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcToArrowUtils.java @@ -251,7 +251,7 @@ public static ArrowType getArrowTypeForJdbcField(JdbcFieldInfo fieldInfo, Calend case Types.DECIMAL: int precision = fieldInfo.getPrecision(); int scale = fieldInfo.getScale(); - return new ArrowType.Decimal(precision, scale); + return new ArrowType.Decimal(precision, scale, 128); case Types.REAL: case Types.FLOAT: return new ArrowType.FloatingPoint(SINGLE); diff --git a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/DecimalTypeUtil.java b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/DecimalTypeUtil.java index f6b76f98188..e0c072cfbe5 100644 --- a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/DecimalTypeUtil.java +++ b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/DecimalTypeUtil.java @@ -87,7 +87,7 @@ private static Decimal adjustScaleIfNeeded(int precision, int scale) { precision = MAX_PRECISION; scale = Math.max(scale - delta, minScale); } - return new Decimal(precision, scale); + return new Decimal(precision, scale, 128); } } diff --git a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ExpressionRegistry.java b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ExpressionRegistry.java index 6b2610ff3d2..0155af08234 100644 --- a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ExpressionRegistry.java +++ b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ExpressionRegistry.java @@ -175,7 +175,7 @@ private static ArrowType getArrowType(ExtGandivaType type) { case GandivaType.NONE_VALUE: return new ArrowType.Null(); case GandivaType.DECIMAL_VALUE: - return new ArrowType.Decimal(0, 0); + return new ArrowType.Decimal(0, 0, 128); case GandivaType.INTERVAL_VALUE: return new ArrowType.Interval(mapArrowIntervalUnit(type.getIntervalType())); case GandivaType.FIXED_SIZE_BINARY_VALUE: diff --git a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/DecimalTypeUtilTest.java b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/DecimalTypeUtilTest.java index 96bffe3f285..fe51c09e33d 100644 --- a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/DecimalTypeUtilTest.java +++ b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/DecimalTypeUtilTest.java @@ -83,7 +83,7 @@ public void testOutputTypesForMod() { } private ArrowType.Decimal getDecimal(int precision, int scale) { - return new ArrowType.Decimal(precision, scale); + return new ArrowType.Decimal(precision, scale, 128); } } diff --git a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorDecimalTest.java b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorDecimalTest.java index f3de03b66f1..28a57c9f8a4 100644 --- a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorDecimalTest.java +++ b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorDecimalTest.java @@ -56,7 +56,7 @@ public class ProjectorDecimalTest extends org.apache.arrow.gandiva.evaluator.Bas public void test_add() throws GandivaException { int precision = 38; int scale = 8; - ArrowType.Decimal decimal = new ArrowType.Decimal(precision, scale); + ArrowType.Decimal decimal = new ArrowType.Decimal(precision, scale, 128); Field a = Field.nullable("a", decimal); Field b = Field.nullable("b", decimal); List args = Lists.newArrayList(a, b); @@ -115,8 +115,8 @@ public void test_add() throws GandivaException { public void test_add_literal() throws GandivaException { int precision = 2; int scale = 0; - ArrowType.Decimal decimal = new ArrowType.Decimal(precision, scale); - ArrowType.Decimal literalType = new ArrowType.Decimal(2, 1); + ArrowType.Decimal decimal = new ArrowType.Decimal(precision, scale, 128); + ArrowType.Decimal literalType = new ArrowType.Decimal(2, 1, 128); Field a = Field.nullable("a", decimal); ArrowType.Decimal outputType = DecimalTypeUtil.getResultTypeForOperation(DecimalTypeUtil @@ -169,7 +169,7 @@ public void test_add_literal() throws GandivaException { public void test_multiply() throws GandivaException { int precision = 38; int scale = 8; - ArrowType.Decimal decimal = new ArrowType.Decimal(precision, scale); + ArrowType.Decimal decimal = new ArrowType.Decimal(precision, scale, 128); Field a = Field.nullable("a", decimal); Field b = Field.nullable("b", decimal); List args = Lists.newArrayList(a, b); @@ -226,8 +226,8 @@ public void test_multiply() throws GandivaException { @Test public void testCompare() throws GandivaException { - Decimal aType = new Decimal(38, 3); - Decimal bType = new Decimal(38, 2); + Decimal aType = new Decimal(38, 3, 128); + Decimal bType = new Decimal(38, 2, 128); Field a = Field.nullable("a", aType); Field b = Field.nullable("b", bType); List args = Lists.newArrayList(a, b); @@ -315,9 +315,9 @@ public void testCompare() throws GandivaException { @Test public void testRound() throws GandivaException { - Decimal aType = new Decimal(38, 2); - Decimal aWithScaleZero = new Decimal(38, 0); - Decimal aWithScaleOne = new Decimal(38, 1); + Decimal aType = new Decimal(38, 2, 128); + Decimal aWithScaleZero = new Decimal(38, 0, 128); + Decimal aWithScaleOne = new Decimal(38, 1, 128); Field a = Field.nullable("a", aType); List args = Lists.newArrayList(a); @@ -419,8 +419,8 @@ public void testRound() throws GandivaException { @Test public void testCastToDecimal() throws GandivaException { - Decimal decimalType = new Decimal(38, 2); - Decimal decimalWithScaleOne = new Decimal(38, 1); + Decimal decimalType = new Decimal(38, 2, 128); + Decimal decimalWithScaleOne = new Decimal(38, 1, 128); Field dec = Field.nullable("dec", decimalType); Field int64f = Field.nullable("int64", int64); Field doublef = Field.nullable("float64", float64); @@ -517,7 +517,7 @@ public void testCastToDecimal() throws GandivaException { @Test public void testCastToLong() throws GandivaException { - Decimal decimalType = new Decimal(38, 2); + Decimal decimalType = new Decimal(38, 2, 128); Field dec = Field.nullable("dec", decimalType); Schema schema = new Schema(Lists.newArrayList(dec)); @@ -575,7 +575,7 @@ public void testCastToLong() throws GandivaException { @Test public void testCastToDouble() throws GandivaException { - Decimal decimalType = new Decimal(38, 2); + Decimal decimalType = new Decimal(38, 2, 128); Field dec = Field.nullable("dec", decimalType); Schema schema = new Schema(Lists.newArrayList(dec)); @@ -633,7 +633,7 @@ public void testCastToDouble() throws GandivaException { @Test public void testCastToString() throws GandivaException { - Decimal decimalType = new Decimal(38, 2); + Decimal decimalType = new Decimal(38, 2, 128); Field dec = Field.nullable("dec", decimalType); Field str = Field.nullable("str", new ArrowType.Utf8()); TreeNode field = TreeBuilder.makeField(dec); @@ -695,7 +695,7 @@ public void testCastToString() throws GandivaException { @Test public void testCastStringToDecimal() throws GandivaException { - Decimal decimalType = new Decimal(4, 2); + Decimal decimalType = new Decimal(4, 2, 128); Field dec = Field.nullable("dec", decimalType); Field str = Field.nullable("str", new ArrowType.Utf8()); @@ -761,7 +761,7 @@ public void testInvalidDecimal() throws GandivaException { exception.expect(IllegalArgumentException.class); exception.expectMessage("Gandiva only supports decimals of upto 38 precision. Input precision" + " : 0"); - Decimal decimalType = new Decimal(0, 0); + Decimal decimalType = new Decimal(0, 0, 128); Field int64f = Field.nullable("int64", int64); Schema schema = new Schema(Lists.newArrayList(int64f)); @@ -780,7 +780,7 @@ public void testInvalidDecimalGt38() throws GandivaException { exception.expect(IllegalArgumentException.class); exception.expectMessage("Gandiva only supports decimals of upto 38 precision. Input precision" + " : 42"); - Decimal decimalType = new Decimal(42, 0); + Decimal decimalType = new Decimal(42, 0, 128); Field int64f = Field.nullable("int64", int64); Schema schema = new Schema(Lists.newArrayList(int64f)); diff --git a/java/vector/src/main/codegen/data/ArrowTypes.tdd b/java/vector/src/main/codegen/data/ArrowTypes.tdd index 4d2a540f572..3cf9a968791 100644 --- a/java/vector/src/main/codegen/data/ArrowTypes.tdd +++ b/java/vector/src/main/codegen/data/ArrowTypes.tdd @@ -92,7 +92,7 @@ }, { name: "Decimal", - fields: [{name: "precision", type: int}, {name: "scale", type: int}], + fields: [{name: "precision", type: int}, {name: "scale", type: int}, {name: "bitWidth", type: int}], complex: false }, { diff --git a/java/vector/src/main/codegen/data/ValueVectorTypes.tdd b/java/vector/src/main/codegen/data/ValueVectorTypes.tdd index 7612d3690b9..574b065662e 100644 --- a/java/vector/src/main/codegen/data/ValueVectorTypes.tdd +++ b/java/vector/src/main/codegen/data/ValueVectorTypes.tdd @@ -113,6 +113,22 @@ { class: "IntervalDay", millisecondsOffset: 4, friendlyType: "Duration", fields: [ {name: "days", type:"int"}, {name: "milliseconds", type:"int"}] } ] }, + { + major: "Fixed", + width: 32, + javaType: "ArrowBuf", + boxedType: "ArrowBuf", + + minor: [ + { + class: "Decimal256", + maxPrecisionDigits: 76, nDecimalDigits: 4, friendlyType: "BigDecimal", + typeParams: [ {name: "scale", type: "int"}, { name: "precision", type: "int"}], + arrowType: "org.apache.arrow.vector.types.pojo.ArrowType.Decimal", + fields: [{name: "start", type: "long"}, {name: "buffer", type: "ArrowBuf"}] + } + ] + }, { major: "Fixed", width: 16, @@ -129,6 +145,7 @@ } ] }, + { major: "Fixed", width: -1, diff --git a/java/vector/src/main/codegen/templates/AbstractFieldWriter.java b/java/vector/src/main/codegen/templates/AbstractFieldWriter.java index 4f6d5ea1aee..bce842d5911 100644 --- a/java/vector/src/main/codegen/templates/AbstractFieldWriter.java +++ b/java/vector/src/main/codegen/templates/AbstractFieldWriter.java @@ -76,20 +76,20 @@ public void write(${name}Holder holder) { fail("${name}"); } - <#if minor.class == "Decimal"> + <#if minor.class?starts_with("Decimal")> public void write${minor.class}(${friendlyType} value) { fail("${name}"); } - public void write${minor.class}(<#list fields as field>${field.type} ${field.name}<#if field_has_next>, <#if minor.class == "Decimal">, ArrowType arrowType) { + public void write${minor.class}(<#list fields as field>${field.type} ${field.name}<#if field_has_next>, , ArrowType arrowType) { fail("${name}"); } - public void writeBigEndianBytesToDecimal(byte[] value) { + public void writeBigEndianBytesTo${minor.class}(byte[] value) { fail("${name}"); } - public void writeBigEndianBytesToDecimal(byte[] value, ArrowType arrowType) { + public void writeBigEndianBytesTo${minor.class}(byte[] value, ArrowType arrowType) { fail("${name}"); } diff --git a/java/vector/src/main/codegen/templates/AbstractPromotableFieldWriter.java b/java/vector/src/main/codegen/templates/AbstractPromotableFieldWriter.java index 5566c808258..6b14dbf2a57 100644 --- a/java/vector/src/main/codegen/templates/AbstractPromotableFieldWriter.java +++ b/java/vector/src/main/codegen/templates/AbstractPromotableFieldWriter.java @@ -75,7 +75,7 @@ public void endList() { <#list vv.types as type><#list type.minor as minor><#assign name = minor.class?cap_first /> <#assign fields = minor.fields!type.fields /> - <#if minor.class != "Decimal"> + <#if minor.class != "Decimal" && minor.class != "Decimal256"> @Override public void write(${name}Holder holder) { getWriter(MinorType.${name?upper_case}).write(holder); @@ -85,7 +85,7 @@ public void write(${name}Holder holder) { getWriter(MinorType.${name?upper_case}).write${minor.class}(<#list fields as field>${field.name}<#if field_has_next>, ); } - <#else> + <#elseif minor.class == "Decimal"> @Override public void write(DecimalHolder holder) { getWriter(MinorType.DECIMAL).write(holder); @@ -106,6 +106,28 @@ public void writeBigEndianBytesToDecimal(byte[] value, ArrowType arrowType) { public void writeBigEndianBytesToDecimal(byte[] value) { getWriter(MinorType.DECIMAL).writeBigEndianBytesToDecimal(value); } + <#elseif minor.class == "Decimal256"> + @Override + public void write(Decimal256Holder holder) { + getWriter(MinorType.DECIMAL256).write(holder); + } + + public void writeDecimal256(long start, ArrowBuf buffer, ArrowType arrowType) { + getWriter(MinorType.DECIMAL256).writeDecimal256(start, buffer, arrowType); + } + + public void writeDecimal256(long start, ArrowBuf buffer) { + getWriter(MinorType.DECIMAL256).writeDecimal256(start, buffer); + } + public void writeBigEndianBytesToDecimal256(byte[] value, ArrowType arrowType) { + getWriter(MinorType.DECIMAL256).writeBigEndianBytesToDecimal256(value, arrowType); + } + + public void writeBigEndianBytesToDecimal256(byte[] value) { + getWriter(MinorType.DECIMAL256).writeBigEndianBytesToDecimal256(value); + } + + diff --git a/java/vector/src/main/codegen/templates/ArrowType.java b/java/vector/src/main/codegen/templates/ArrowType.java index 77894af2365..f8f0e20c940 100644 --- a/java/vector/src/main/codegen/templates/ArrowType.java +++ b/java/vector/src/main/codegen/templates/ArrowType.java @@ -165,7 +165,20 @@ public static class ${name} extends <#if type.complex>ComplexType<#else>Primitiv ${fieldType} ${field.name}; + + <#if type.name == "Decimal"> + // Needed to support golden file integration tests. + @JsonCreator + public static Decimal createDecimal( + @JsonProperty("precision") int precision, + @JsonProperty("scale") int scale, + @JsonProperty("bitWidth") Integer bitWidth) { + + return new Decimal(precision, scale, bitWidth == null ? 128 : bitWidth); + } + <#else> @JsonCreator + public ${type.name}( <#list type.fields as field> <#assign fieldType = field.valueType!field.type> @@ -327,9 +340,8 @@ public static org.apache.arrow.vector.types.pojo.ArrowType getTypeForField(org.a <#if type.name == "Decimal"> - int bitWidth = ${nameLower}Type.bitWidth(); - if (bitWidth != defaultDecimalBitWidth) { - throw new IllegalArgumentException("Library only supports 128-bit decimal values"); + if (bitWidth != defaultDecimalBitWidth && bitWidth != 256) { + throw new IllegalArgumentException("Library only supports 128-bit and 256-bit decimal values"); } return new ${name}(<#list type.fields as field><#if field.valueType??>${field.valueType}.fromFlatbufID(${field.name})<#else>${field.name}<#if field_has_next>, ); diff --git a/java/vector/src/main/codegen/templates/ComplexCopier.java b/java/vector/src/main/codegen/templates/ComplexCopier.java index 1189e8e04e0..39a84041e7e 100644 --- a/java/vector/src/main/codegen/templates/ComplexCopier.java +++ b/java/vector/src/main/codegen/templates/ComplexCopier.java @@ -124,7 +124,7 @@ private static void writeValue(FieldReader reader, FieldWriter writer) { Nullable${name}Holder ${uncappedName}Holder = new Nullable${name}Holder(); reader.read(${uncappedName}Holder); if (${uncappedName}Holder.isSet == 1) { - writer.write${name}(<#list fields as field>${uncappedName}Holder.${field.name}<#if field_has_next>, <#if minor.class == "Decimal">, new ArrowType.Decimal(decimalHolder.precision, decimalHolder.scale)); + writer.write${name}(<#list fields as field>${uncappedName}Holder.${field.name}<#if field_has_next>, <#if minor.class?starts_with("Decimal")>, new ArrowType.Decimal(${uncappedName}Holder.precision, ${uncappedName}Holder.scale, ${name}Holder.WIDTH * 8)); } } else { writer.writeNull(); @@ -145,7 +145,7 @@ private static FieldWriter getStructWriterForReader(FieldReader reader, StructWr case ${name?upper_case}: return (FieldWriter) writer.<#if name == "Int">integer<#else>${uncappedName}(name); - <#if minor.class == "Decimal"> + <#if minor.class?starts_with("Decimal")> case ${name?upper_case}: if (reader.getField().getType() instanceof ArrowType.Decimal) { ArrowType.Decimal type = (ArrowType.Decimal) reader.getField().getType(); @@ -154,6 +154,7 @@ private static FieldWriter getStructWriterForReader(FieldReader reader, StructWr return (FieldWriter) writer.${uncappedName}(name); } + case STRUCT: return (FieldWriter) writer.struct(name); diff --git a/java/vector/src/main/codegen/templates/ComplexWriters.java b/java/vector/src/main/codegen/templates/ComplexWriters.java index 5f5025ff59e..0381e5559e4 100644 --- a/java/vector/src/main/codegen/templates/ComplexWriters.java +++ b/java/vector/src/main/codegen/templates/ComplexWriters.java @@ -99,7 +99,7 @@ public void setPosition(int idx) { <#else> - <#if minor.class != "Decimal"> + <#if !minor.class?starts_with("Decimal")> public void write(${minor.class}Holder h) { vector.setSafe(idx(), h); vector.setValueCount(idx()+1); @@ -123,15 +123,15 @@ public void write(Nullable${minor.class}Holder h) { } - <#if minor.class == "Decimal"> + <#if minor.class?starts_with("Decimal")> - public void write(DecimalHolder h){ + public void write(${minor.class}Holder h){ DecimalUtility.checkPrecisionAndScale(h.precision, h.scale, vector.getPrecision(), vector.getScale()); vector.setSafe(idx(), h); vector.setValueCount(idx() + 1); } - public void write(NullableDecimalHolder h){ + public void write(Nullable${minor.class}Holder h){ if (h.isSet == 1) { DecimalUtility.checkPrecisionAndScale(h.precision, h.scale, vector.getPrecision(), vector.getScale()); } @@ -139,37 +139,38 @@ public void write(NullableDecimalHolder h){ vector.setValueCount(idx() + 1); } - public void writeDecimal(long start, ArrowBuf buffer){ + public void write${minor.class}(long start, ArrowBuf buffer){ vector.setSafe(idx(), 1, start, buffer); vector.setValueCount(idx() + 1); } - public void writeDecimal(long start, ArrowBuf buffer, ArrowType arrowType){ + public void write${minor.class}(long start, ArrowBuf buffer, ArrowType arrowType){ DecimalUtility.checkPrecisionAndScale(((ArrowType.Decimal) arrowType).getPrecision(), ((ArrowType.Decimal) arrowType).getScale(), vector.getPrecision(), vector.getScale()); vector.setSafe(idx(), 1, start, buffer); vector.setValueCount(idx() + 1); } - public void writeDecimal(BigDecimal value){ + public void write${minor.class}(BigDecimal value){ // vector.setSafe already does precision and scale checking vector.setSafe(idx(), value); vector.setValueCount(idx() + 1); } - public void writeBigEndianBytesToDecimal(byte[] value, ArrowType arrowType){ + public void writeBigEndianBytesTo${minor.class}(byte[] value, ArrowType arrowType){ DecimalUtility.checkPrecisionAndScale(((ArrowType.Decimal) arrowType).getPrecision(), ((ArrowType.Decimal) arrowType).getScale(), vector.getPrecision(), vector.getScale()); vector.setBigEndianSafe(idx(), value); vector.setValueCount(idx() + 1); } - public void writeBigEndianBytesToDecimal(byte[] value){ + public void writeBigEndianBytesTo${minor.class}(byte[] value){ vector.setBigEndianSafe(idx(), value); vector.setValueCount(idx() + 1); } + public void writeNull() { vector.setNull(idx()); vector.setValueCount(idx()+1); @@ -190,18 +191,18 @@ public void writeNull() { public interface ${eName}Writer extends BaseWriter { public void write(${minor.class}Holder h); - <#if minor.class == "Decimal">@Deprecated + <#if minor.class?starts_with("Decimal")>@Deprecated public void write${minor.class}(<#list fields as field>${field.type} ${field.name}<#if field_has_next>, ); -<#if minor.class == "Decimal"> +<#if minor.class?starts_with("Decimal")> public void write${minor.class}(<#list fields as field>${field.type} ${field.name}<#if field_has_next>, , ArrowType arrowType); public void write${minor.class}(${friendlyType} value); - public void writeBigEndianBytesToDecimal(byte[] value, ArrowType arrowType); + public void writeBigEndianBytesTo${minor.class}(byte[] value, ArrowType arrowType); @Deprecated - public void writeBigEndianBytesToDecimal(byte[] value); + public void writeBigEndianBytesTo${minor.class}(byte[] value); } diff --git a/java/vector/src/main/codegen/templates/DenseUnionReader.java b/java/vector/src/main/codegen/templates/DenseUnionReader.java index 51bd7d172de..f7e161ac86f 100644 --- a/java/vector/src/main/codegen/templates/DenseUnionReader.java +++ b/java/vector/src/main/codegen/templates/DenseUnionReader.java @@ -92,7 +92,7 @@ private FieldReader getReaderForIndex(int index) { <#list type.minor as minor> <#assign name = minor.class?cap_first /> <#assign uncappedName = name?uncap_first/> - <#if !minor.typeParams?? || minor.class == "Decimal"> + <#if !minor.typeParams?? || minor.class?starts_with("Decimal")> case ${name?upper_case}: reader = (FieldReader) get${name}(typeId); break; @@ -165,7 +165,7 @@ public int size() { <#assign friendlyType = (minor.friendlyType!minor.boxedType!type.boxedType) /> <#assign safeType=friendlyType /> <#if safeType=="byte[]"><#assign safeType="ByteArray" /> - <#if !minor.typeParams?? || minor.class == "Decimal"> + <#if !minor.typeParams?? || minor.class?starts_with("Decimal")> private ${name}ReaderImpl get${name}(byte typeId) { ${name}ReaderImpl reader = (${name}ReaderImpl) readers[typeId]; diff --git a/java/vector/src/main/codegen/templates/DenseUnionVector.java b/java/vector/src/main/codegen/templates/DenseUnionVector.java index d2154a3bcd6..de4d60d553b 100644 --- a/java/vector/src/main/codegen/templates/DenseUnionVector.java +++ b/java/vector/src/main/codegen/templates/DenseUnionVector.java @@ -305,13 +305,13 @@ public StructVector getStruct(byte typeId) { <#assign fields = minor.fields!type.fields /> <#assign uncappedName = name?uncap_first/> <#assign lowerCaseName = name?lower_case/> - <#if !minor.typeParams?? || minor.class == "Decimal"> + <#if !minor.typeParams?? || minor.class?starts_with("Decimal")> - public ${name}Vector get${name}Vector(byte typeId<#if minor.class == "Decimal">, ArrowType arrowType) { + public ${name}Vector get${name}Vector(byte typeId<#if minor.class?starts_with("Decimal")>, ArrowType arrowType) { ValueVector vector = typeId < 0 ? null : childVectors[typeId]; if (vector == null) { int vectorCount = internalStruct.size(); - vector = addOrGet(typeId, MinorType.${name?upper_case}<#if minor.class == "Decimal">, arrowType, ${name}Vector.class); + vector = addOrGet(typeId, MinorType.${name?upper_case}<#if minor.class?starts_with("Decimal")>, arrowType, ${name}Vector.class); childVectors[typeId] = vector; if (internalStruct.size() > vectorCount) { vector.allocateNew(); @@ -809,7 +809,7 @@ public void setSafe(int index, DenseUnionHolder holder) { <#assign name = minor.class?cap_first /> <#assign fields = minor.fields!type.fields /> <#assign uncappedName = name?uncap_first/> - <#if !minor.typeParams?? || minor.class == "Decimal"> + <#if !minor.typeParams?? || minor.class?starts_with("Decimal")> case ${name?upper_case}: Nullable${name}Holder ${uncappedName}Holder = new Nullable${name}Holder(); reader.read(${uncappedName}Holder); @@ -833,13 +833,13 @@ public void setSafe(int index, DenseUnionHolder holder) { <#assign name = minor.class?cap_first /> <#assign fields = minor.fields!type.fields /> <#assign uncappedName = name?uncap_first/> - <#if !minor.typeParams?? || minor.class == "Decimal"> + <#if !minor.typeParams?? || minor.class?starts_with("Decimal")> public void setSafe(int index, Nullable${name}Holder holder) { while (index >= getOffsetBufferValueCapacity()) { reallocOffsetBuffer(); } byte typeId = getTypeId(index); - ${name}Vector vector = get${name}Vector(typeId<#if minor.class == "Decimal">, new ArrowType.Decimal(holder.precision, holder.scale)); + ${name}Vector vector = get${name}Vector(typeId<#if minor.class?starts_with("Decimal")>, new ArrowType.Decimal(holder.precision, holder.scale, holder.WIDTH * 8)); int offset = vector.getValueCount(); vector.setValueCount(offset + 1); vector.setSafe(offset, holder); diff --git a/java/vector/src/main/codegen/templates/DenseUnionWriter.java b/java/vector/src/main/codegen/templates/DenseUnionWriter.java index ee6f614c8f8..769b84268af 100644 --- a/java/vector/src/main/codegen/templates/DenseUnionWriter.java +++ b/java/vector/src/main/codegen/templates/DenseUnionWriter.java @@ -123,7 +123,7 @@ BaseWriter getWriter(byte typeId) { <#assign name = minor.class?cap_first /> <#assign fields = minor.fields!type.fields /> <#assign uncappedName = name?uncap_first/> - <#if !minor.typeParams?? || minor.class == "Decimal"> + <#if !minor.typeParams?? || minor.class?starts_with("Decimal")> case ${name?upper_case}: return get${name}Writer(typeId); @@ -138,7 +138,7 @@ BaseWriter getWriter(byte typeId) { <#assign name = minor.class?cap_first /> <#assign fields = minor.fields!type.fields /> <#assign uncappedName = name?uncap_first/> - <#if !minor.typeParams?? || minor.class == "Decimal"> + <#if !minor.typeParams?? || minor.class?starts_with("Decimal")> private ${name}Writer get${name}Writer(byte typeId) { ${name}Writer writer = (${name}Writer) writers[typeId]; @@ -159,10 +159,10 @@ public void write(${name}Holder holder) { throw new UnsupportedOperationException(); } - public void write${minor.class}(<#list fields as field>${field.type} ${field.name}<#if field_has_next>, , byte typeId<#if minor.class == "Decimal">, ArrowType arrowType) { + public void write${minor.class}(<#list fields as field>${field.type} ${field.name}<#if field_has_next>, , byte typeId<#if minor.class?starts_with("Decimal")>, ArrowType arrowType) { data.setTypeId(idx(), typeId); get${name}Writer(typeId).setPosition(data.getOffset(idx())); - get${name}Writer(typeId).write${name}(<#list fields as field>${field.name}<#if field_has_next>, <#if minor.class == "Decimal">, arrowType); + get${name}Writer(typeId).write${name}(<#list fields as field>${field.name}<#if field_has_next>, <#if minor.class?starts_with("Decimal")>, arrowType); } @@ -208,7 +208,7 @@ public StructWriter struct(String name) { <#if lowerName == "int" ><#assign lowerName = "integer" /> <#assign upperName = minor.class?upper_case /> <#assign capName = minor.class?cap_first /> - <#if !minor.typeParams?? || minor.class == "Decimal" > + <#if !minor.typeParams?? || minor.class?starts_with("Decimal") > @Override public ${capName}Writer ${lowerName}(String name) { byte typeId = data.getTypeId(idx()); @@ -225,7 +225,7 @@ public StructWriter struct(String name) { return getListWriter(typeId).${lowerName}(); } - <#if minor.class == "Decimal"> + <#if minor.class?starts_with("Decimal")> public ${capName}Writer ${lowerName}(String name<#list minor.typeParams as typeParam>, ${typeParam.type} ${typeParam.name}) { byte typeId = data.getTypeId(idx()); data.setTypeId(idx(), typeId); diff --git a/java/vector/src/main/codegen/templates/HolderReaderImpl.java b/java/vector/src/main/codegen/templates/HolderReaderImpl.java index fa7e83a9f8b..e41c7db2f2a 100644 --- a/java/vector/src/main/codegen/templates/HolderReaderImpl.java +++ b/java/vector/src/main/codegen/templates/HolderReaderImpl.java @@ -129,6 +129,11 @@ public void read(Nullable${name}Holder h) { holder.buffer.getBytes(holder.start, bytes, 0, ${type.width}); ${friendlyType} value = new BigDecimal(new BigInteger(bytes), holder.scale); return value; + <#elseif minor.class == "Decimal256"> + byte[] bytes = new byte[${type.width}]; + holder.buffer.getBytes(holder.start, bytes, 0, ${type.width}); + ${friendlyType} value = new BigDecimal(new BigInteger(bytes), holder.scale); + return value; <#elseif minor.class == "FixedSizeBinary"> byte[] value = new byte [holder.byteWidth]; holder.buffer.getBytes(0, value, 0, holder.byteWidth); diff --git a/java/vector/src/main/codegen/templates/StructWriters.java b/java/vector/src/main/codegen/templates/StructWriters.java index 7df22179eec..b908d1058fb 100644 --- a/java/vector/src/main/codegen/templates/StructWriters.java +++ b/java/vector/src/main/codegen/templates/StructWriters.java @@ -255,7 +255,7 @@ public void end() { <#assign constructorParams = constructorParams + [ typeParam.name ] /> - new ${minor.arrowType}(${constructorParams?join(", ")}) + new ${minor.arrowType}(${constructorParams?join(", ")}<#if minor.class?starts_with("Decimal")>, ${vectName}Vector.TYPE_WIDTH * 8) <#else> MinorType.${upperName}.getType() @@ -274,7 +274,7 @@ public void end() { } else { if (writer instanceof PromotableWriter) { // ensure writers are initialized - ((PromotableWriter)writer).getWriter(MinorType.${upperName}<#if minor.class == "Decimal">, new ${minor.arrowType}(precision, scale)); + ((PromotableWriter)writer).getWriter(MinorType.${upperName}<#if minor.class?starts_with("Decimal")>, new ${minor.arrowType}(precision, scale, ${vectName}Vector.TYPE_WIDTH * 8)); } } return writer; diff --git a/java/vector/src/main/codegen/templates/UnionFixedSizeListWriter.java b/java/vector/src/main/codegen/templates/UnionFixedSizeListWriter.java index 94c7d8f6490..f04b4db3208 100644 --- a/java/vector/src/main/codegen/templates/UnionFixedSizeListWriter.java +++ b/java/vector/src/main/codegen/templates/UnionFixedSizeListWriter.java @@ -16,9 +16,12 @@ */ import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.vector.complex.writer.Decimal256Writer; import org.apache.arrow.vector.complex.writer.DecimalWriter; +import org.apache.arrow.vector.holders.Decimal256Holder; import org.apache.arrow.vector.holders.DecimalHolder; + import java.lang.UnsupportedOperationException; import java.math.BigDecimal; @@ -127,6 +130,22 @@ public DecimalWriter decimal(String name) { return writer.decimal(name); } + + @Override + public Decimal256Writer decimal256() { + return this; + } + + @Override + public Decimal256Writer decimal256(String name, int scale, int precision) { + return writer.decimal256(name, scale, precision); + } + + @Override + public Decimal256Writer decimal256(String name) { + return writer.decimal256(name); + } + @Override public StructWriter struct() { inStruct = true; @@ -180,6 +199,16 @@ public void write(DecimalHolder holder) { writer.write(holder); writer.setPosition(writer.idx() + 1); } + + @Override + public void write(Decimal256Holder holder) { + if (writer.idx() >= (idx() + 1) * listSize) { + throw new IllegalStateException(String.format("values at index %s is greater than listSize %s", idx(), listSize)); + } + writer.write(holder); + writer.setPosition(writer.idx() + 1); + } + @Override public void writeNull() { @@ -213,6 +242,31 @@ public void writeBigEndianBytesToDecimal(byte[] value, ArrowType arrowType) { writer.setPosition(writer.idx() + 1); } + public void writeDecimal256(long start, ArrowBuf buffer, ArrowType arrowType) { + if (writer.idx() >= (idx() + 1) * listSize) { + throw new IllegalStateException(String.format("values at index %s is greater than listSize %s", idx(), listSize)); + } + writer.writeDecimal256(start, buffer, arrowType); + writer.setPosition(writer.idx() + 1); + } + + public void writeDecimal256(BigDecimal value) { + if (writer.idx() >= (idx() + 1) * listSize) { + throw new IllegalStateException(String.format("values at index %s is greater than listSize %s", idx(), listSize)); + } + writer.writeDecimal256(value); + writer.setPosition(writer.idx() + 1); + } + + public void writeBigEndianBytesToDecimal256(byte[] value, ArrowType arrowType) { + if (writer.idx() >= (idx() + 1) * listSize) { + throw new IllegalStateException(String.format("values at index %s is greater than listSize %s", idx(), listSize)); + } + writer.writeBigEndianBytesToDecimal256(value, arrowType); + writer.setPosition(writer.idx() + 1); + } + + <#list vv.types as type> <#list type.minor as minor> <#assign name = minor.class?cap_first /> diff --git a/java/vector/src/main/codegen/templates/UnionListWriter.java b/java/vector/src/main/codegen/templates/UnionListWriter.java index bb0cff4e06c..155895d8932 100644 --- a/java/vector/src/main/codegen/templates/UnionListWriter.java +++ b/java/vector/src/main/codegen/templates/UnionListWriter.java @@ -16,9 +16,12 @@ */ import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.vector.complex.writer.Decimal256Writer; import org.apache.arrow.vector.complex.writer.DecimalWriter; +import org.apache.arrow.vector.holders.Decimal256Holder; import org.apache.arrow.vector.holders.DecimalHolder; + import java.lang.UnsupportedOperationException; import java.math.BigDecimal; @@ -133,6 +136,22 @@ public DecimalWriter decimal(String name) { return writer.decimal(name); } + @Override + public Decimal256Writer decimal256() { + return this; + } + + @Override + public Decimal256Writer decimal256(String name, int scale, int precision) { + return writer.decimal256(name, scale, precision); + } + + @Override + public Decimal256Writer decimal256(String name) { + return writer.decimal256(name); + } + + @Override public StructWriter struct() { inStruct = true; @@ -199,6 +218,12 @@ public void write(DecimalHolder holder) { writer.setPosition(writer.idx()+1); } + @Override + public void write(Decimal256Holder holder) { + writer.write(holder); + writer.setPosition(writer.idx()+1); + } + @Override public void writeNull() { writer.writeNull(); @@ -224,6 +249,27 @@ public void writeBigEndianBytesToDecimal(byte[] value, ArrowType arrowType){ writer.setPosition(writer.idx() + 1); } + public void writeDecimal256(long start, ArrowBuf buffer, ArrowType arrowType) { + writer.writeDecimal256(start, buffer, arrowType); + writer.setPosition(writer.idx()+1); + } + + public void writeDecimal256(long start, ArrowBuf buffer) { + writer.writeDecimal256(start, buffer); + writer.setPosition(writer.idx()+1); + } + + public void writeDecimal256(BigDecimal value) { + writer.writeDecimal256(value); + writer.setPosition(writer.idx()+1); + } + + public void writeBigEndianBytesToDecimal256(byte[] value, ArrowType arrowType){ + writer.writeBigEndianBytesToDecimal256(value, arrowType); + writer.setPosition(writer.idx() + 1); + } + + <#list vv.types as type> <#list type.minor as minor> <#assign name = minor.class?cap_first /> diff --git a/java/vector/src/main/codegen/templates/UnionMapWriter.java b/java/vector/src/main/codegen/templates/UnionMapWriter.java index 01b371329a9..cec73c45f5c 100644 --- a/java/vector/src/main/codegen/templates/UnionMapWriter.java +++ b/java/vector/src/main/codegen/templates/UnionMapWriter.java @@ -16,7 +16,9 @@ */ import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.vector.complex.writer.Decimal256Writer; import org.apache.arrow.vector.complex.writer.DecimalWriter; +import org.apache.arrow.vector.holders.Decimal256Holder; import org.apache.arrow.vector.holders.DecimalHolder; import java.lang.UnsupportedOperationException; @@ -169,6 +171,19 @@ public DecimalWriter decimal() { } } + @Override + public Decimal256Writer decimal256() { + switch (mode) { + case KEY: + return entryWriter.decimal256(MapVector.KEY_NAME); + case VALUE: + return entryWriter.decimal256(MapVector.VALUE_NAME); + default: + return this; + } + } + + @Override public StructWriter struct() { switch (mode) { diff --git a/java/vector/src/main/codegen/templates/UnionReader.java b/java/vector/src/main/codegen/templates/UnionReader.java index 20fdb41d4af..6ed03fa2117 100644 --- a/java/vector/src/main/codegen/templates/UnionReader.java +++ b/java/vector/src/main/codegen/templates/UnionReader.java @@ -34,7 +34,7 @@ @SuppressWarnings("unused") public class UnionReader extends AbstractFieldReader { - private BaseReader[] readers = new BaseReader[43]; + private BaseReader[] readers = new BaseReader[44]; public UnionVector data; public UnionReader(UnionVector data) { @@ -45,7 +45,7 @@ public MinorType getMinorType() { return TYPES[data.getTypeValue(idx())]; } - private static MinorType[] TYPES = new MinorType[43]; + private static MinorType[] TYPES = new MinorType[44]; static { for (MinorType minorType : MinorType.values()) { @@ -88,7 +88,7 @@ private FieldReader getReaderForIndex(int index) { <#list type.minor as minor> <#assign name = minor.class?cap_first /> <#assign uncappedName = name?uncap_first/> - <#if !minor.typeParams?? || minor.class == "Decimal"> + <#if !minor.typeParams?? || minor.class?starts_with("Decimal")> case ${name?upper_case}: return (FieldReader) get${name}(); @@ -157,7 +157,7 @@ public int size() { <#assign friendlyType = (minor.friendlyType!minor.boxedType!type.boxedType) /> <#assign safeType=friendlyType /> <#if safeType=="byte[]"><#assign safeType="ByteArray" /> - <#if !minor.typeParams?? || minor.class == "Decimal" > + <#if !minor.typeParams?? || minor.class?starts_with("Decimal") > private ${name}ReaderImpl ${uncappedName}Reader; diff --git a/java/vector/src/main/codegen/templates/UnionVector.java b/java/vector/src/main/codegen/templates/UnionVector.java index 59a90cedee6..f33f44bbc60 100644 --- a/java/vector/src/main/codegen/templates/UnionVector.java +++ b/java/vector/src/main/codegen/templates/UnionVector.java @@ -272,18 +272,18 @@ public StructVector getStruct() { <#assign fields = minor.fields!type.fields /> <#assign uncappedName = name?uncap_first/> <#assign lowerCaseName = name?lower_case/> - <#if !minor.typeParams?? || minor.class == "Decimal" > + <#if !minor.typeParams?? || minor.class?starts_with("Decimal") > private ${name}Vector ${uncappedName}Vector; - public ${name}Vector get${name}Vector(<#if minor.class == "Decimal"> ArrowType arrowType) { - return get${name}Vector(null<#if minor.class == "Decimal">, arrowType); + public ${name}Vector get${name}Vector(<#if minor.class?starts_with("Decimal")> ArrowType arrowType) { + return get${name}Vector(null<#if minor.class?starts_with("Decimal")>, arrowType); } - public ${name}Vector get${name}Vector(String name<#if minor.class == "Decimal">, ArrowType arrowType) { + public ${name}Vector get${name}Vector(String name<#if minor.class?starts_with("Decimal")>, ArrowType arrowType) { if (${uncappedName}Vector == null) { int vectorCount = internalStruct.size(); - ${uncappedName}Vector = addOrGet(name, MinorType.${name?upper_case},<#if minor.class == "Decimal"> arrowType, ${name}Vector.class); + ${uncappedName}Vector = addOrGet(name, MinorType.${name?upper_case},<#if minor.class?starts_with("Decimal")> arrowType, ${name}Vector.class); if (internalStruct.size() > vectorCount) { ${uncappedName}Vector.allocateNew(); if (callBack != null) { @@ -293,10 +293,10 @@ public StructVector getStruct() { } return ${uncappedName}Vector; } - <#if minor.class == "Decimal"> + <#if minor.class?starts_with("Decimal")> public ${name}Vector get${name}Vector() { if (${uncappedName}Vector == null) { - throw new IllegalArgumentException("No Decimal Vector present. Provide ArrowType argument to create a new vector"); + throw new IllegalArgumentException("No ${uncappedName} present. Provide ArrowType argument to create a new vector"); } return ${uncappedName}Vector; } @@ -637,9 +637,9 @@ public ValueVector getVectorByType(int typeId, ArrowType arrowType) { <#assign name = minor.class?cap_first /> <#assign fields = minor.fields!type.fields /> <#assign uncappedName = name?uncap_first/> - <#if !minor.typeParams?? || minor.class == "Decimal" > + <#if !minor.typeParams?? || minor.class?starts_with("Decimal") > case ${name?upper_case}: - return get${name}Vector(name<#if minor.class == "Decimal">, arrowType); + return get${name}Vector(name<#if minor.class?starts_with("Decimal")>, arrowType); @@ -722,11 +722,11 @@ public void setSafe(int index, UnionHolder holder, ArrowType arrowType) { <#assign name = minor.class?cap_first /> <#assign fields = minor.fields!type.fields /> <#assign uncappedName = name?uncap_first/> - <#if !minor.typeParams?? || minor.class == "Decimal" > + <#if !minor.typeParams?? || minor.class?starts_with("Decimal") > case ${name?upper_case}: Nullable${name}Holder ${uncappedName}Holder = new Nullable${name}Holder(); reader.read(${uncappedName}Holder); - setSafe(index, ${uncappedName}Holder<#if minor.class == "Decimal">, arrowType); + setSafe(index, ${uncappedName}Holder<#if minor.class?starts_with("Decimal")>, arrowType); break; @@ -748,10 +748,10 @@ public void setSafe(int index, UnionHolder holder, ArrowType arrowType) { <#assign name = minor.class?cap_first /> <#assign fields = minor.fields!type.fields /> <#assign uncappedName = name?uncap_first/> - <#if !minor.typeParams?? || minor.class == "Decimal" > - public void setSafe(int index, Nullable${name}Holder holder<#if minor.class == "Decimal">, ArrowType arrowType) { + <#if !minor.typeParams?? || minor.class?starts_with("Decimal") > + public void setSafe(int index, Nullable${name}Holder holder<#if minor.class?starts_with("Decimal")>, ArrowType arrowType) { setType(index, MinorType.${name?upper_case}); - get${name}Vector(null<#if minor.class == "Decimal">, arrowType).setSafe(index, holder); + get${name}Vector(null<#if minor.class?starts_with("Decimal")>, arrowType).setSafe(index, holder); } diff --git a/java/vector/src/main/codegen/templates/UnionWriter.java b/java/vector/src/main/codegen/templates/UnionWriter.java index 6f2b2e1bf0e..59322d42fde 100644 --- a/java/vector/src/main/codegen/templates/UnionWriter.java +++ b/java/vector/src/main/codegen/templates/UnionWriter.java @@ -125,9 +125,9 @@ BaseWriter getWriter(MinorType minorType, ArrowType arrowType) { <#assign name = minor.class?cap_first /> <#assign fields = minor.fields!type.fields /> <#assign uncappedName = name?uncap_first/> - <#if !minor.typeParams?? || minor.class == "Decimal"> + <#if !minor.typeParams?? || minor.class?starts_with("Decimal")> case ${name?upper_case}: - return get${name}Writer(<#if minor.class == "Decimal" >arrowType); + return get${name}Writer(<#if minor.class?starts_with("Decimal") >arrowType); @@ -141,49 +141,49 @@ BaseWriter getWriter(MinorType minorType, ArrowType arrowType) { <#assign fields = minor.fields!type.fields /> <#assign uncappedName = name?uncap_first/> <#assign friendlyType = (minor.friendlyType!minor.boxedType!type.boxedType) /> - <#if !minor.typeParams?? || minor.class == "Decimal" > + <#if !minor.typeParams?? || minor.class?starts_with("Decimal") > private ${name}Writer ${name?uncap_first}Writer; - private ${name}Writer get${name}Writer(<#if minor.class == "Decimal">ArrowType arrowType) { + private ${name}Writer get${name}Writer(<#if minor.class?starts_with("Decimal")>ArrowType arrowType) { if (${uncappedName}Writer == null) { - ${uncappedName}Writer = new ${name}WriterImpl(data.get${name}Vector(<#if minor.class == "Decimal">arrowType)); + ${uncappedName}Writer = new ${name}WriterImpl(data.get${name}Vector(<#if minor.class?starts_with("Decimal")>arrowType)); ${uncappedName}Writer.setPosition(idx()); writers.add(${uncappedName}Writer); } return ${uncappedName}Writer; } - public ${name}Writer as${name}(<#if minor.class == "Decimal">ArrowType arrowType) { + public ${name}Writer as${name}(<#if minor.class?starts_with("Decimal")>ArrowType arrowType) { data.setType(idx(), MinorType.${name?upper_case}); - return get${name}Writer(<#if minor.class == "Decimal">arrowType); + return get${name}Writer(<#if minor.class?starts_with("Decimal")>arrowType); } @Override public void write(${name}Holder holder) { data.setType(idx(), MinorType.${name?upper_case}); - <#if minor.class == "Decimal">ArrowType arrowType = new ArrowType.Decimal(holder.precision, holder.scale); - get${name}Writer(<#if minor.class == "Decimal">arrowType).setPosition(idx()); - get${name}Writer(<#if minor.class == "Decimal">arrowType).write${name}(<#list fields as field>holder.${field.name}<#if field_has_next>, <#if minor.class == "Decimal">, arrowType); + <#if minor.class?starts_with("Decimal")>ArrowType arrowType = new ArrowType.Decimal(holder.precision, holder.scale, ${name}Holder.WIDTH * 8); + get${name}Writer(<#if minor.class?starts_with("Decimal")>arrowType).setPosition(idx()); + get${name}Writer(<#if minor.class?starts_with("Decimal")>arrowType).write${name}(<#list fields as field>holder.${field.name}<#if field_has_next>, <#if minor.class?starts_with("Decimal")>, arrowType); } - public void write${minor.class}(<#list fields as field>${field.type} ${field.name}<#if field_has_next>, <#if minor.class == "Decimal">, ArrowType arrowType) { + public void write${minor.class}(<#list fields as field>${field.type} ${field.name}<#if field_has_next>, <#if minor.class?starts_with("Decimal")>, ArrowType arrowType) { data.setType(idx(), MinorType.${name?upper_case}); - get${name}Writer(<#if minor.class == "Decimal">arrowType).setPosition(idx()); - get${name}Writer(<#if minor.class == "Decimal">arrowType).write${name}(<#list fields as field>${field.name}<#if field_has_next>, <#if minor.class == "Decimal">, arrowType); + get${name}Writer(<#if minor.class?starts_with("Decimal")>arrowType).setPosition(idx()); + get${name}Writer(<#if minor.class?starts_with("Decimal")>arrowType).write${name}(<#list fields as field>${field.name}<#if field_has_next>, <#if minor.class?starts_with("Decimal")>, arrowType); } - <#if minor.class == "Decimal"> - public void write${minor.class}(${friendlyType} value) { - data.setType(idx(), MinorType.DECIMAL); - ArrowType arrowType = new ArrowType.Decimal(value.precision(), value.scale()); - getDecimalWriter(arrowType).setPosition(idx()); - getDecimalWriter(arrowType).writeDecimal(value); + <#if minor.class?starts_with("Decimal")> + public void write${name}(${friendlyType} value) { + data.setType(idx(), MinorType.${name?upper_case}); + ArrowType arrowType = new ArrowType.Decimal(value.precision(), value.scale(), ${name}Vector.TYPE_WIDTH * 8); + get${name}Writer(arrowType).setPosition(idx()); + get${name}Writer(arrowType).write${name}(value); } - public void writeBigEndianBytesToDecimal(byte[] value, ArrowType arrowType) { - data.setType(idx(), MinorType.DECIMAL); - getDecimalWriter(arrowType).setPosition(idx()); - getDecimalWriter(arrowType).writeBigEndianBytesToDecimal(value, arrowType); + public void writeBigEndianBytesTo${name}(byte[] value, ArrowType arrowType) { + data.setType(idx(), MinorType.${name?upper_case}); + get${name}Writer(arrowType).setPosition(idx()); + get${name}Writer(arrowType).writeBigEndianBytesTo${name}(value, arrowType); } @@ -226,7 +226,7 @@ public StructWriter struct(String name) { <#if lowerName == "int" ><#assign lowerName = "integer" /> <#assign upperName = minor.class?upper_case /> <#assign capName = minor.class?cap_first /> - <#if !minor.typeParams?? || minor.class == "Decimal" > + <#if !minor.typeParams?? || minor.class?starts_with("Decimal") > @Override public ${capName}Writer ${lowerName}(String name) { data.setType(idx(), MinorType.STRUCT); @@ -241,7 +241,7 @@ public StructWriter struct(String name) { return getListWriter().${lowerName}(); } - <#if minor.class == "Decimal"> + <#if minor.class?starts_with("Decimal")> @Override public ${capName}Writer ${lowerName}(String name<#list minor.typeParams as typeParam>, ${typeParam.type} ${typeParam.name}) { data.setType(idx(), MinorType.STRUCT); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/BufferLayout.java b/java/vector/src/main/java/org/apache/arrow/vector/BufferLayout.java index 0bd64c06eab..09c874e3980 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/BufferLayout.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/BufferLayout.java @@ -51,6 +51,7 @@ public String getName() { private static final BufferLayout LARGE_OFFSET_BUFFER = new BufferLayout(BufferType.OFFSET, 64); private static final BufferLayout TYPE_BUFFER = new BufferLayout(BufferType.TYPE, 32); private static final BufferLayout BIT_BUFFER = new BufferLayout(BufferType.DATA, 1); + private static final BufferLayout VALUES_256 = new BufferLayout(BufferType.DATA, 256); private static final BufferLayout VALUES_128 = new BufferLayout(BufferType.DATA, 128); private static final BufferLayout VALUES_64 = new BufferLayout(BufferType.DATA, 64); private static final BufferLayout VALUES_32 = new BufferLayout(BufferType.DATA, 32); @@ -85,8 +86,10 @@ public static BufferLayout dataBuffer(int typeBitWidth) { return VALUES_64; case 128: return VALUES_128; + case 256: + return VALUES_256; default: - throw new IllegalArgumentException("only 8, 16, 32, 64, or 128 bits supported"); + throw new IllegalArgumentException("only 8, 16, 32, 64, 128, or 256 bits supported"); } } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/Decimal256Vector.java b/java/vector/src/main/java/org/apache/arrow/vector/Decimal256Vector.java new file mode 100644 index 00000000000..ed10468095b --- /dev/null +++ b/java/vector/src/main/java/org/apache/arrow/vector/Decimal256Vector.java @@ -0,0 +1,554 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.vector; + +import static org.apache.arrow.vector.NullCheckingForGet.NULL_CHECKING_ENABLED; + +import java.math.BigDecimal; + +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.complex.impl.Decimal256ReaderImpl; +import org.apache.arrow.vector.complex.reader.FieldReader; +import org.apache.arrow.vector.holders.Decimal256Holder; +import org.apache.arrow.vector.holders.NullableDecimal256Holder; +import org.apache.arrow.vector.types.Types.MinorType; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.util.DecimalUtility; +import org.apache.arrow.vector.util.TransferPair; + +import io.netty.util.internal.PlatformDependent; + +/** + * Decimal256Vector implements a fixed width vector (32 bytes) of + * decimal values which could be null. A validity buffer (bit vector) is + * maintained to track which elements in the vector are null. + */ +public final class Decimal256Vector extends BaseFixedWidthVector { + public static final byte TYPE_WIDTH = 32; + private final FieldReader reader; + + private final int precision; + private final int scale; + + /** + * Instantiate a Decimal256Vector. This doesn't allocate any memory for + * the data in vector. + * + * @param name name of the vector + * @param allocator allocator for memory management. + */ + public Decimal256Vector(String name, BufferAllocator allocator, + int precision, int scale) { + this(name, FieldType.nullable(new ArrowType.Decimal(precision, scale, /*bitWidth=*/TYPE_WIDTH * 8)), allocator); + } + + /** + * Instantiate a Decimal256Vector. This doesn't allocate any memory for + * the data in vector. + * + * @param name name of the vector + * @param fieldType type of Field materialized by this vector + * @param allocator allocator for memory management. + */ + public Decimal256Vector(String name, FieldType fieldType, BufferAllocator allocator) { + this(new Field(name, fieldType, null), allocator); + } + + /** + * Instantiate a Decimal256Vector. This doesn't allocate any memory for + * the data in vector. + * + * @param field field materialized by this vector + * @param allocator allocator for memory management. + */ + public Decimal256Vector(Field field, BufferAllocator allocator) { + super(field, allocator, TYPE_WIDTH); + ArrowType.Decimal arrowType = (ArrowType.Decimal) field.getFieldType().getType(); + reader = new Decimal256ReaderImpl(Decimal256Vector.this); + this.precision = arrowType.getPrecision(); + this.scale = arrowType.getScale(); + } + + /** + * Get a reader that supports reading values from this vector. + * + * @return Field Reader for this vector + */ + @Override + public FieldReader getReader() { + return reader; + } + + /** + * Get minor type for this vector. The vector holds values belonging + * to a particular type. + * + * @return {@link org.apache.arrow.vector.types.Types.MinorType} + */ + @Override + public MinorType getMinorType() { + return MinorType.DECIMAL256; + } + + + /*----------------------------------------------------------------* + | | + | vector value retrieval methods | + | | + *----------------------------------------------------------------*/ + + + /** + * Get the element at the given index from the vector. + * + * @param index position of element + * @return element at given index + */ + public ArrowBuf get(int index) throws IllegalStateException { + if (NULL_CHECKING_ENABLED && isSet(index) == 0) { + throw new IllegalStateException("Value at index is null"); + } + return valueBuffer.slice((long) index * TYPE_WIDTH, TYPE_WIDTH); + } + + /** + * Get the element at the given index from the vector and + * sets the state in holder. If element at given index + * is null, holder.isSet will be zero. + * + * @param index position of element + */ + public void get(int index, NullableDecimal256Holder holder) { + if (isSet(index) == 0) { + holder.isSet = 0; + return; + } + holder.isSet = 1; + holder.buffer = valueBuffer; + holder.precision = precision; + holder.scale = scale; + holder.start = ((long) index) * TYPE_WIDTH; + } + + /** + * Same as {@link #get(int)}. + * + * @param index position of element + * @return element at given index + */ + public BigDecimal getObject(int index) { + if (isSet(index) == 0) { + return null; + } else { + return DecimalUtility.getBigDecimalFromArrowBuf(valueBuffer, index, scale, TYPE_WIDTH); + } + } + + /** + * Return precision for the decimal value. + */ + public int getPrecision() { + return precision; + } + + /** + * Return scale for the decimal value. + */ + public int getScale() { + return scale; + } + + + /*----------------------------------------------------------------* + | | + | vector value setter methods | + | | + *----------------------------------------------------------------*/ + + + /** + * Set the element at the given index to the given value. + * + * @param index position of element + * @param buffer ArrowBuf containing decimal value. + */ + public void set(int index, ArrowBuf buffer) { + BitVectorHelper.setBit(validityBuffer, index); + valueBuffer.setBytes((long) index * TYPE_WIDTH, buffer, 0, TYPE_WIDTH); + } + + /** + * Set the decimal element at given index to the provided array of bytes. + * Decimal256 is now implemented as Little Endian. This API allows the user + * to pass a decimal value in the form of byte array in BE byte order. + * + *

Consumers of Arrow code can use this API instead of first swapping + * the source bytes (doing a write and read) and then finally writing to + * ArrowBuf of decimal vector. + * + *

This method takes care of adding the necessary padding if the length + * of byte array is less then 32 (length of decimal type). + * + * @param index position of element + * @param value array of bytes containing decimal in big endian byte order. + */ + public void setBigEndian(int index, byte[] value) { + BitVectorHelper.setBit(validityBuffer, index); + final int length = value.length; + + // do the bound check. + valueBuffer.checkBytes((long) index * TYPE_WIDTH, (long) (index + 1) * TYPE_WIDTH); + + long outAddress = valueBuffer.memoryAddress() + (long) index * TYPE_WIDTH; + // swap bytes to convert BE to LE + for (int byteIdx = 0; byteIdx < length; ++byteIdx) { + PlatformDependent.putByte(outAddress + byteIdx, value[length - 1 - byteIdx]); + } + + if (length == TYPE_WIDTH) { + return; + } + + if (length == 0) { + PlatformDependent.setMemory(outAddress, Decimal256Vector.TYPE_WIDTH, (byte) 0); + } else if (length < TYPE_WIDTH) { + // sign extend + final byte pad = (byte) (value[0] < 0 ? 0xFF : 0x00); + PlatformDependent.setMemory(outAddress + length, Decimal256Vector.TYPE_WIDTH - length, pad); + } else { + throw new IllegalArgumentException( + "Invalid decimal value length. Valid length in [1 - 32], got " + length); + } + } + + /** + * Set the element at the given index to the given value. + * + * @param index position of element + * @param start start index of data in the buffer + * @param buffer ArrowBuf containing decimal value. + */ + public void set(int index, long start, ArrowBuf buffer) { + BitVectorHelper.setBit(validityBuffer, index); + valueBuffer.setBytes((long) index * TYPE_WIDTH, buffer, start, TYPE_WIDTH); + } + + /** + * Sets the element at given index using the buffer whose size maybe <= 32 bytes. + * @param index index to write the decimal to + * @param start start of value in the buffer + * @param buffer contains the decimal in little endian bytes + * @param length length of the value in the buffer + */ + public void setSafe(int index, long start, ArrowBuf buffer, int length) { + handleSafe(index); + BitVectorHelper.setBit(validityBuffer, index); + + // do the bound checks. + buffer.checkBytes(start, start + length); + valueBuffer.checkBytes((long) index * TYPE_WIDTH, (long) (index + 1) * TYPE_WIDTH); + + long inAddress = buffer.memoryAddress() + start; + long outAddress = valueBuffer.memoryAddress() + (long) index * TYPE_WIDTH; + PlatformDependent.copyMemory(inAddress, outAddress, length); + // sign extend + if (length < 32) { + byte msb = PlatformDependent.getByte(inAddress + length - 1); + final byte pad = (byte) (msb < 0 ? 0xFF : 0x00); + PlatformDependent.setMemory(outAddress + length, Decimal256Vector.TYPE_WIDTH - length, pad); + } + } + + + /** + * Sets the element at given index using the buffer whose size maybe <= 32 bytes. + * @param index index to write the decimal to + * @param start start of value in the buffer + * @param buffer contains the decimal in big endian bytes + * @param length length of the value in the buffer + */ + public void setBigEndianSafe(int index, long start, ArrowBuf buffer, int length) { + handleSafe(index); + BitVectorHelper.setBit(validityBuffer, index); + + // do the bound checks. + buffer.checkBytes(start, start + length); + valueBuffer.checkBytes((long) index * TYPE_WIDTH, (long) (index + 1) * TYPE_WIDTH); + + // not using buffer.getByte() to avoid boundary checks for every byte. + long inAddress = buffer.memoryAddress() + start; + long outAddress = valueBuffer.memoryAddress() + (long) index * TYPE_WIDTH; + // swap bytes to convert BE to LE + for (int byteIdx = 0; byteIdx < length; ++byteIdx) { + byte val = PlatformDependent.getByte((inAddress + length - 1) - byteIdx); + PlatformDependent.putByte(outAddress + byteIdx, val); + } + // sign extend + if (length < 32) { + byte msb = PlatformDependent.getByte(inAddress); + final byte pad = (byte) (msb < 0 ? 0xFF : 0x00); + PlatformDependent.setMemory(outAddress + length, Decimal256Vector.TYPE_WIDTH - length, pad); + } + } + + /** + * Set the element at the given index to the given value. + * + * @param index position of element + * @param value BigDecimal containing decimal value. + */ + public void set(int index, BigDecimal value) { + BitVectorHelper.setBit(validityBuffer, index); + DecimalUtility.checkPrecisionAndScale(value, precision, scale); + DecimalUtility.writeBigDecimalToArrowBuf(value, valueBuffer, index, TYPE_WIDTH); + } + + /** + * Set the element at the given index to the given value. + * + * @param index position of element + * @param value long value. + */ + public void set(int index, long value) { + BitVectorHelper.setBit(validityBuffer, index); + final long addressOfValue = valueBuffer.memoryAddress() + (long) index * TYPE_WIDTH; + PlatformDependent.putLong(addressOfValue, value); + final long padValue = Long.signum(value) == -1 ? -1L : 0L; + PlatformDependent.putLong(addressOfValue + Long.BYTES, padValue); + PlatformDependent.putLong(addressOfValue + 2 * Long.BYTES, padValue); + PlatformDependent.putLong(addressOfValue + 3 * Long.BYTES, padValue); + } + + /** + * Set the element at the given index to the value set in data holder. + * If the value in holder is not indicated as set, element in the + * at the given index will be null. + * + * @param index position of element + * @param holder nullable data holder for value of element + */ + public void set(int index, NullableDecimal256Holder holder) throws IllegalArgumentException { + if (holder.isSet < 0) { + throw new IllegalArgumentException(); + } else if (holder.isSet > 0) { + BitVectorHelper.setBit(validityBuffer, index); + valueBuffer.setBytes((long) index * TYPE_WIDTH, holder.buffer, holder.start, TYPE_WIDTH); + } else { + BitVectorHelper.unsetBit(validityBuffer, index); + } + } + + /** + * Set the element at the given index to the value set in data holder. + * + * @param index position of element + * @param holder data holder for value of element + */ + public void set(int index, Decimal256Holder holder) { + BitVectorHelper.setBit(validityBuffer, index); + valueBuffer.setBytes((long) index * TYPE_WIDTH, holder.buffer, holder.start, TYPE_WIDTH); + } + + /** + * Same as {@link #set(int, ArrowBuf)} except that it handles the + * case when index is greater than or equal to existing + * value capacity {@link #getValueCapacity()}. + * + * @param index position of element + * @param buffer ArrowBuf containing decimal value. + */ + public void setSafe(int index, ArrowBuf buffer) { + handleSafe(index); + set(index, buffer); + } + + /** + * Same as {@link #setBigEndian(int, byte[])} except that it handles the + * case when index is greater than or equal to existing + * value capacity {@link #getValueCapacity()}. + */ + public void setBigEndianSafe(int index, byte[] value) { + handleSafe(index); + setBigEndian(index, value); + } + + /** + * Same as {@link #set(int, int, ArrowBuf)} except that it handles the + * case when index is greater than or equal to existing + * value capacity {@link #getValueCapacity()}. + * + * @param index position of element + * @param start start index of data in the buffer + * @param buffer ArrowBuf containing decimal value. + */ + public void setSafe(int index, long start, ArrowBuf buffer) { + handleSafe(index); + set(index, start, buffer); + } + + /** + * Same as {@link #set(int, BigDecimal)} except that it handles the + * case when index is greater than or equal to existing + * value capacity {@link #getValueCapacity()}. + * + * @param index position of element + * @param value BigDecimal containing decimal value. + */ + public void setSafe(int index, BigDecimal value) { + handleSafe(index); + set(index, value); + } + + /** + * Same as {@link #set(int, long)} except that it handles the + * case when index is greater than or equal to existing + * value capacity {@link #getValueCapacity()}. + * + * @param index position of element + * @param value long value. + */ + public void setSafe(int index, long value) { + handleSafe(index); + set(index, value); + } + + /** + * Same as {@link #set(int, NullableDecimalHolder)} except that it handles the + * case when index is greater than or equal to existing + * value capacity {@link #getValueCapacity()}. + * + * @param index position of element + * @param holder nullable data holder for value of element + */ + public void setSafe(int index, NullableDecimal256Holder holder) throws IllegalArgumentException { + handleSafe(index); + set(index, holder); + } + + /** + * Same as {@link #set(int, Decimal256Holder)} except that it handles the + * case when index is greater than or equal to existing + * value capacity {@link #getValueCapacity()}. + * + * @param index position of element + * @param holder data holder for value of element + */ + public void setSafe(int index, Decimal256Holder holder) { + handleSafe(index); + set(index, holder); + } + + /** + * Store the given value at a particular position in the vector. isSet indicates + * whether the value is NULL or not. + * + * @param index position of the new value + * @param isSet 0 for NULL value, 1 otherwise + * @param start start position of the value in the buffer + * @param buffer buffer containing the value to be stored in the vector + */ + public void set(int index, int isSet, long start, ArrowBuf buffer) { + if (isSet > 0) { + set(index, start, buffer); + } else { + BitVectorHelper.unsetBit(validityBuffer, index); + } + } + + /** + * Same as {@link #setSafe(int, int, int, ArrowBuf)} except that it handles + * the case when the position of new value is beyond the current value + * capacity of the vector. + * + * @param index position of the new value + * @param isSet 0 for NULL value, 1 otherwise + * @param start start position of the value in the buffer + * @param buffer buffer containing the value to be stored in the vector + */ + public void setSafe(int index, int isSet, long start, ArrowBuf buffer) { + handleSafe(index); + set(index, isSet, start, buffer); + } + + /*----------------------------------------------------------------* + | | + | vector transfer | + | | + *----------------------------------------------------------------*/ + + + /** + * Construct a TransferPair comprising of this and a target vector of + * the same type. + * + * @param ref name of the target vector + * @param allocator allocator for the target vector + * @return {@link TransferPair} + */ + @Override + public TransferPair getTransferPair(String ref, BufferAllocator allocator) { + return new TransferImpl(ref, allocator); + } + + /** + * Construct a TransferPair with a desired target vector of the same type. + * + * @param to target vector + * @return {@link TransferPair} + */ + @Override + public TransferPair makeTransferPair(ValueVector to) { + return new TransferImpl((Decimal256Vector) to); + } + + private class TransferImpl implements TransferPair { + Decimal256Vector to; + + public TransferImpl(String ref, BufferAllocator allocator) { + to = new Decimal256Vector(ref, allocator, Decimal256Vector.this.precision, + Decimal256Vector.this.scale); + } + + public TransferImpl(Decimal256Vector to) { + this.to = to; + } + + @Override + public Decimal256Vector getTo() { + return to; + } + + @Override + public void transfer() { + transferTo(to); + } + + @Override + public void splitAndTransfer(int startIndex, int length) { + splitAndTransferTo(startIndex, length, to); + } + + @Override + public void copyValueSafe(int fromIndex, int toIndex) { + to.copyFromSafe(fromIndex, toIndex, Decimal256Vector.this); + } + } +} diff --git a/java/vector/src/main/java/org/apache/arrow/vector/DecimalVector.java b/java/vector/src/main/java/org/apache/arrow/vector/DecimalVector.java index 04344c35e34..3fc54ba5ccc 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/DecimalVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/DecimalVector.java @@ -57,7 +57,7 @@ public final class DecimalVector extends BaseFixedWidthVector { */ public DecimalVector(String name, BufferAllocator allocator, int precision, int scale) { - this(name, FieldType.nullable(new ArrowType.Decimal(precision, scale)), allocator); + this(name, FieldType.nullable(new ArrowType.Decimal(precision, scale, TYPE_WIDTH * 8)), allocator); } /** @@ -158,7 +158,7 @@ public BigDecimal getObject(int index) { if (isSet(index) == 0) { return null; } else { - return DecimalUtility.getBigDecimalFromArrowBuf(valueBuffer, index, scale); + return DecimalUtility.getBigDecimalFromArrowBuf(valueBuffer, index, scale, TYPE_WIDTH); } } @@ -318,7 +318,7 @@ public void setBigEndianSafe(int index, long start, ArrowBuf buffer, int length) public void set(int index, BigDecimal value) { BitVectorHelper.setBit(validityBuffer, index); DecimalUtility.checkPrecisionAndScale(value, precision, scale); - DecimalUtility.writeBigDecimalToArrowBuf(value, valueBuffer, index); + DecimalUtility.writeBigDecimalToArrowBuf(value, valueBuffer, index, TYPE_WIDTH); } /** diff --git a/java/vector/src/main/java/org/apache/arrow/vector/TypeLayout.java b/java/vector/src/main/java/org/apache/arrow/vector/TypeLayout.java index 501ca98c0a4..1004ce1a74a 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/TypeLayout.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/TypeLayout.java @@ -155,7 +155,7 @@ public TypeLayout visit(FloatingPoint type) { @Override public TypeLayout visit(Decimal type) { - return newFixedWidthTypeLayout(BufferLayout.dataBuffer(128)); + return newFixedWidthTypeLayout(BufferLayout.dataBuffer(type.getBitWidth())); } @Override diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/PromotableWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/PromotableWriter.java index 51decee39fd..b87281dbc14 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/PromotableWriter.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/PromotableWriter.java @@ -30,6 +30,7 @@ import org.apache.arrow.vector.complex.StructVector; import org.apache.arrow.vector.complex.UnionVector; import org.apache.arrow.vector.complex.writer.FieldWriter; +import org.apache.arrow.vector.holders.Decimal256Holder; import org.apache.arrow.vector.holders.DecimalHolder; import org.apache.arrow.vector.types.Types.MinorType; import org.apache.arrow.vector.types.pojo.ArrowType; @@ -54,6 +55,7 @@ public class PromotableWriter extends AbstractPromotableFieldWriter { private final NullableStructWriterFactory nullableStructWriterFactory; private int position; private static final int MAX_DECIMAL_PRECISION = 38; + private static final int MAX_DECIMAL256_PRECISION = 76; private enum State { UNTYPED, SINGLE, UNION @@ -316,26 +318,54 @@ private FieldWriter promoteToUnion() { @Override public void write(DecimalHolder holder) { - getWriter(MinorType.DECIMAL, new ArrowType.Decimal(MAX_DECIMAL_PRECISION, holder.scale)).write(holder); + getWriter(MinorType.DECIMAL, + new ArrowType.Decimal(MAX_DECIMAL_PRECISION, holder.scale, /*bitWidth=*/128)).write(holder); } @Override public void writeDecimal(long start, ArrowBuf buffer, ArrowType arrowType) { getWriter(MinorType.DECIMAL, new ArrowType.Decimal(MAX_DECIMAL_PRECISION, - ((ArrowType.Decimal) arrowType).getScale())).writeDecimal(start, buffer, arrowType); + ((ArrowType.Decimal) arrowType).getScale(), /*bitWidth=*/128)).writeDecimal(start, buffer, arrowType); } @Override public void writeDecimal(BigDecimal value) { - getWriter(MinorType.DECIMAL, new ArrowType.Decimal(MAX_DECIMAL_PRECISION, value.scale())).writeDecimal(value); + getWriter(MinorType.DECIMAL, + new ArrowType.Decimal(MAX_DECIMAL_PRECISION, value.scale(), /*bitWidth=*/128)).writeDecimal(value); } @Override public void writeBigEndianBytesToDecimal(byte[] value, ArrowType arrowType) { getWriter(MinorType.DECIMAL, new ArrowType.Decimal(MAX_DECIMAL_PRECISION, - ((ArrowType.Decimal) arrowType).getScale())).writeBigEndianBytesToDecimal(value, arrowType); + ((ArrowType.Decimal) arrowType).getScale(), /*bitWidth=*/128)).writeBigEndianBytesToDecimal(value, arrowType); } + @Override + public void write(Decimal256Holder holder) { + getWriter(MinorType.DECIMAL256, + new ArrowType.Decimal(MAX_DECIMAL256_PRECISION, holder.scale, /*bitWidth=*/256)).write(holder); + } + + @Override + public void writeDecimal256(long start, ArrowBuf buffer, ArrowType arrowType) { + getWriter(MinorType.DECIMAL256, new ArrowType.Decimal(MAX_DECIMAL256_PRECISION, + ((ArrowType.Decimal) arrowType).getScale(), /*bitWidth=*/256)).writeDecimal256(start, buffer, arrowType); + } + + @Override + public void writeDecimal256(BigDecimal value) { + getWriter(MinorType.DECIMAL256, + new ArrowType.Decimal(MAX_DECIMAL256_PRECISION, value.scale(), /*bitWidth=*/256)).writeDecimal256(value); + } + + @Override + public void writeBigEndianBytesToDecimal256(byte[] value, ArrowType arrowType) { + getWriter(MinorType.DECIMAL256, new ArrowType.Decimal(MAX_DECIMAL256_PRECISION, + ((ArrowType.Decimal) arrowType).getScale(), + /*bitWidth=*/256)).writeBigEndianBytesToDecimal256(value, arrowType); + } + + @Override public void allocate() { getWriter().allocate(); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileReader.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileReader.java index 13935ef4f10..40f83c04fe5 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileReader.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileReader.java @@ -45,6 +45,7 @@ import org.apache.arrow.vector.BigIntVector; import org.apache.arrow.vector.BitVectorHelper; import org.apache.arrow.vector.BufferLayout.BufferType; +import org.apache.arrow.vector.Decimal256Vector; import org.apache.arrow.vector.DecimalVector; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.Float4Vector; @@ -438,7 +439,7 @@ protected ArrowBuf read(BufferAllocator allocator, int count) throws IOException for (int i = 0; i < count; i++) { parser.nextToken(); BigDecimal decimalValue = new BigDecimal(parser.readValueAs(String.class)); - DecimalUtility.writeBigDecimalToArrowBuf(decimalValue, buf, i); + DecimalUtility.writeBigDecimalToArrowBuf(decimalValue, buf, i, DecimalVector.TYPE_WIDTH); } buf.writerIndex(size); @@ -446,6 +447,24 @@ protected ArrowBuf read(BufferAllocator allocator, int count) throws IOException } }; + BufferReader DECIMAL256 = new BufferReader() { + @Override + protected ArrowBuf read(BufferAllocator allocator, int count) throws IOException { + final int size = count * Decimal256Vector.TYPE_WIDTH; + ArrowBuf buf = allocator.buffer(size); + + for (int i = 0; i < count; i++) { + parser.nextToken(); + BigDecimal decimalValue = new BigDecimal(parser.readValueAs(String.class)); + DecimalUtility.writeBigDecimalToArrowBuf(decimalValue, buf, i, Decimal256Vector.TYPE_WIDTH); + } + + buf.writerIndex(size); + return buf; + } + }; + + BufferReader FIXEDSIZEBINARY = new BufferReader() { @Override protected ArrowBuf read(BufferAllocator allocator, int count) throws IOException { @@ -615,6 +634,9 @@ private ArrowBuf readIntoBuffer(BufferAllocator allocator, BufferType bufferType case DECIMAL: reader = helper.DECIMAL; break; + case DECIMAL256: + reader = helper.DECIMAL256; + break; case FIXEDSIZEBINARY: reader = helper.FIXEDSIZEBINARY; break; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileWriter.java index e210b002890..f2854c95c30 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileWriter.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileWriter.java @@ -36,6 +36,7 @@ import org.apache.arrow.vector.BufferLayout.BufferType; import org.apache.arrow.vector.DateDayVector; import org.apache.arrow.vector.DateMilliVector; +import org.apache.arrow.vector.Decimal256Vector; import org.apache.arrow.vector.DecimalVector; import org.apache.arrow.vector.DurationVector; import org.apache.arrow.vector.FieldVector; @@ -377,11 +378,21 @@ private void writeValueToGenerator( } case DECIMAL: { int scale = ((DecimalVector) vector).getScale(); - BigDecimal decimalValue = DecimalUtility.getBigDecimalFromArrowBuf(buffer, index, scale); + BigDecimal decimalValue = DecimalUtility.getBigDecimalFromArrowBuf(buffer, index, scale, + DecimalVector.TYPE_WIDTH); // We write the unscaled value, because the scale is stored in the type metadata. generator.writeString(decimalValue.unscaledValue().toString()); break; } + case DECIMAL256: { + int scale = ((Decimal256Vector) vector).getScale(); + BigDecimal decimalValue = DecimalUtility.getBigDecimalFromArrowBuf(buffer, index, scale, + Decimal256Vector.TYPE_WIDTH); + // We write the unscaled value, because the scale is stored in the type metadata. + generator.writeString(decimalValue.unscaledValue().toString()); + break; + } + default: throw new UnsupportedOperationException("minor type: " + vector.getMinorType()); } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/types/Types.java b/java/vector/src/main/java/org/apache/arrow/vector/types/Types.java index 886478ce403..e93bc695917 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/types/Types.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/types/Types.java @@ -27,6 +27,7 @@ import org.apache.arrow.vector.BitVector; import org.apache.arrow.vector.DateDayVector; import org.apache.arrow.vector.DateMilliVector; +import org.apache.arrow.vector.Decimal256Vector; import org.apache.arrow.vector.DecimalVector; import org.apache.arrow.vector.DurationVector; import org.apache.arrow.vector.ExtensionTypeVector; @@ -72,6 +73,7 @@ import org.apache.arrow.vector.complex.impl.BitWriterImpl; import org.apache.arrow.vector.complex.impl.DateDayWriterImpl; import org.apache.arrow.vector.complex.impl.DateMilliWriterImpl; +import org.apache.arrow.vector.complex.impl.Decimal256WriterImpl; import org.apache.arrow.vector.complex.impl.DecimalWriterImpl; import org.apache.arrow.vector.complex.impl.DenseUnionWriter; import org.apache.arrow.vector.complex.impl.DurationWriterImpl; @@ -528,6 +530,20 @@ public FieldWriter getNewFieldWriter(ValueVector vector) { return new DecimalWriterImpl((DecimalVector) vector); } }, + DECIMAL256(null) { + @Override + public FieldVector getNewVector( + Field field, + BufferAllocator allocator, + CallBack schemaChangeCallback) { + return new Decimal256Vector(field, allocator); + } + + @Override + public FieldWriter getNewFieldWriter(ValueVector vector) { + return new Decimal256WriterImpl((Decimal256Vector) vector); + } + }, FIXEDSIZEBINARY(null) { @Override public FieldVector getNewVector( @@ -899,6 +915,9 @@ public MinorType visit(Bool type) { @Override public MinorType visit(Decimal type) { + if (type.getBitWidth() == 256) { + return MinorType.DECIMAL256; + } return MinorType.DECIMAL; } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/util/DecimalUtility.java b/java/vector/src/main/java/org/apache/arrow/vector/util/DecimalUtility.java index 36c988fac7e..6f707648a75 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/util/DecimalUtility.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/util/DecimalUtility.java @@ -32,24 +32,26 @@ public class DecimalUtility { private DecimalUtility() {} public static final int DECIMAL_BYTE_LENGTH = 16; - public static final byte [] zeroes = new byte[] {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - public static final byte [] minus_one = new byte[] {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}; + public static final byte [] zeroes = new byte[] {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + public static final byte [] minus_one = new byte[] {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}; /** * Read an ArrowType.Decimal at the given value index in the ArrowBuf and convert to a BigDecimal * with the given scale. */ - public static BigDecimal getBigDecimalFromArrowBuf(ArrowBuf bytebuf, int index, int scale) { - byte[] value = new byte[DECIMAL_BYTE_LENGTH]; + public static BigDecimal getBigDecimalFromArrowBuf(ArrowBuf bytebuf, int index, int scale, int byteWidth) { + byte[] value = new byte[byteWidth]; byte temp; - final long startIndex = (long) index * DECIMAL_BYTE_LENGTH; + final long startIndex = (long) index * byteWidth; // Decimal stored as little endian, need to swap bytes to make BigDecimal - bytebuf.getBytes(startIndex, value, 0, DECIMAL_BYTE_LENGTH); - int stop = DECIMAL_BYTE_LENGTH / 2; + bytebuf.getBytes(startIndex, value, 0, byteWidth); + int stop = byteWidth / 2; for (int i = 0, j; i < stop; i++) { temp = value[i]; - j = (DECIMAL_BYTE_LENGTH - 1) - i; + j = (byteWidth - 1) - i; value[i] = value[j]; value[j] = temp; } @@ -61,8 +63,8 @@ public static BigDecimal getBigDecimalFromArrowBuf(ArrowBuf bytebuf, int index, * Read an ArrowType.Decimal from the ByteBuffer and convert to a BigDecimal with the given * scale. */ - public static BigDecimal getBigDecimalFromByteBuffer(ByteBuffer bytebuf, int scale) { - byte[] value = new byte[DECIMAL_BYTE_LENGTH]; + public static BigDecimal getBigDecimalFromByteBuffer(ByteBuffer bytebuf, int scale, int byteWidth) { + byte[] value = new byte[byteWidth]; bytebuf.get(value); BigInteger unscaledValue = new BigInteger(value); return new BigDecimal(unscaledValue, scale); @@ -72,10 +74,10 @@ public static BigDecimal getBigDecimalFromByteBuffer(ByteBuffer bytebuf, int sca * Read an ArrowType.Decimal from the ArrowBuf at the given value index and return it as a byte * array. */ - public static byte[] getByteArrayFromArrowBuf(ArrowBuf bytebuf, int index) { - final byte[] value = new byte[DECIMAL_BYTE_LENGTH]; - final long startIndex = (long) index * DECIMAL_BYTE_LENGTH; - bytebuf.getBytes(startIndex, value, 0, DECIMAL_BYTE_LENGTH); + public static byte[] getByteArrayFromArrowBuf(ArrowBuf bytebuf, int index, int byteWidth) { + final byte[] value = new byte[byteWidth]; + final long startIndex = (long) index * byteWidth; + bytebuf.getBytes(startIndex, value, 0, byteWidth); return value; } @@ -119,9 +121,9 @@ public static boolean checkPrecisionAndScale(int decimalPrecision, int decimalSc * UnsupportedOperationException if the decimal size is greater than the Decimal vector byte * width. */ - public static void writeBigDecimalToArrowBuf(BigDecimal value, ArrowBuf bytebuf, int index) { + public static void writeBigDecimalToArrowBuf(BigDecimal value, ArrowBuf bytebuf, int index, int byteWidth) { final byte[] bytes = value.unscaledValue().toByteArray(); - writeByteArrayToArrowBufHelper(bytes, bytebuf, index); + writeByteArrayToArrowBufHelper(bytes, bytebuf, index, byteWidth); } /** @@ -139,14 +141,14 @@ public static void writeLongToArrowBuf(long value, ArrowBuf bytebuf, int index) * UnsupportedOperationException if the decimal size is greater than the Decimal vector byte * width. */ - public static void writeByteArrayToArrowBuf(byte[] bytes, ArrowBuf bytebuf, int index) { - writeByteArrayToArrowBufHelper(bytes, bytebuf, index); + public static void writeByteArrayToArrowBuf(byte[] bytes, ArrowBuf bytebuf, int index, int byteWidth) { + writeByteArrayToArrowBufHelper(bytes, bytebuf, index, byteWidth); } - private static void writeByteArrayToArrowBufHelper(byte[] bytes, ArrowBuf bytebuf, int index) { - final long startIndex = (long) index * DECIMAL_BYTE_LENGTH; - if (bytes.length > DECIMAL_BYTE_LENGTH) { - throw new UnsupportedOperationException("Decimal size greater than 16 bytes"); + private static void writeByteArrayToArrowBufHelper(byte[] bytes, ArrowBuf bytebuf, int index, int byteWidth) { + final long startIndex = (long) index * byteWidth; + if (bytes.length > byteWidth) { + throw new UnsupportedOperationException("Decimal size greater than " + byteWidth + " bytes: " + bytes.length); } // Decimal stored as little endian, need to swap data bytes before writing to ArrowBuf @@ -156,8 +158,8 @@ private static void writeByteArrayToArrowBufHelper(byte[] bytes, ArrowBuf bytebu } // Write LE data - byte [] padByes = bytes[0] < 0 ? minus_one : zeroes; + byte [] padBytes = bytes[0] < 0 ? minus_one : zeroes; bytebuf.setBytes(startIndex, bytesLE, 0, bytes.length); - bytebuf.setBytes(startIndex + bytes.length, padByes, 0, DECIMAL_BYTE_LENGTH - bytes.length); + bytebuf.setBytes(startIndex + bytes.length, padBytes, 0, byteWidth - bytes.length); } } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/validate/ValidateVectorTypeVisitor.java b/java/vector/src/main/java/org/apache/arrow/vector/validate/ValidateVectorTypeVisitor.java index d9b0e7b9429..de00c6d8ff3 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/validate/ValidateVectorTypeVisitor.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/validate/ValidateVectorTypeVisitor.java @@ -26,6 +26,7 @@ import org.apache.arrow.vector.BitVector; import org.apache.arrow.vector.DateDayVector; import org.apache.arrow.vector.DateMilliVector; +import org.apache.arrow.vector.Decimal256Vector; import org.apache.arrow.vector.DecimalVector; import org.apache.arrow.vector.DurationVector; import org.apache.arrow.vector.FixedSizeBinaryVector; @@ -172,7 +173,7 @@ public Void visit(BaseFixedWidthVector vector, Void value) { validateIntVector(vector, 64, false); } else if (vector instanceof BitVector) { validateVectorCommon(vector, ArrowType.Bool.class); - } else if (vector instanceof DecimalVector) { + } else if (vector instanceof DecimalVector || vector instanceof Decimal256Vector) { validateVectorCommon(vector, ArrowType.Decimal.class); ArrowType.Decimal arrowType = (ArrowType.Decimal) vector.getField().getType(); validateOrThrow(arrowType.getScale() > 0, "The scale of decimal %s is not positive.", arrowType.getScale()); diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestDecimal256Vector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestDecimal256Vector.java new file mode 100644 index 00000000000..7aa48f4bfb4 --- /dev/null +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestDecimal256Vector.java @@ -0,0 +1,356 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.vector; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.math.BigDecimal; +import java.math.BigInteger; + +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +public class TestDecimal256Vector { + + private static long[] intValues; + + static { + intValues = new long[60]; + for (int i = 0; i < intValues.length / 2; i++) { + intValues[i] = 1 << i + 1; + intValues[2 * i] = -1 * (1 << i + 1); + } + } + + private int scale = 3; + + private BufferAllocator allocator; + + @Before + public void init() { + allocator = new DirtyRootAllocator(Long.MAX_VALUE, (byte) 100); + } + + @After + public void terminate() throws Exception { + allocator.close(); + } + + @Test + public void testValuesWriteRead() { + try (Decimal256Vector decimalVector = TestUtils.newVector(Decimal256Vector.class, "decimal", + new ArrowType.Decimal(10, scale, 256), allocator);) { + + try (Decimal256Vector oldConstructor = new Decimal256Vector("decimal", allocator, 10, scale);) { + assertEquals(decimalVector.getField().getType(), oldConstructor.getField().getType()); + } + + decimalVector.allocateNew(); + BigDecimal[] values = new BigDecimal[intValues.length]; + for (int i = 0; i < intValues.length; i++) { + BigDecimal decimal = new BigDecimal(BigInteger.valueOf(intValues[i]), scale); + values[i] = decimal; + decimalVector.setSafe(i, decimal); + } + + decimalVector.setValueCount(intValues.length); + + for (int i = 0; i < intValues.length; i++) { + BigDecimal value = decimalVector.getObject(i); + assertEquals("unexpected data at index: " + i, values[i], value); + } + } + } + + @Test + public void testDecimal256DifferentScaleAndPrecision() { + try (Decimal256Vector decimalVector = TestUtils.newVector(Decimal256Vector.class, "decimal", + new ArrowType.Decimal(4, 2, 256), allocator)) { + decimalVector.allocateNew(); + + // test Decimal256 with different scale + { + BigDecimal decimal = new BigDecimal(BigInteger.valueOf(0), 3); + UnsupportedOperationException ue = + assertThrows(UnsupportedOperationException.class, () -> decimalVector.setSafe(0, decimal)); + assertEquals("BigDecimal scale must equal that in the Arrow vector: 3 != 2", ue.getMessage()); + } + + // test BigDecimal with larger precision than initialized + { + BigDecimal decimal = new BigDecimal(BigInteger.valueOf(12345), 2); + UnsupportedOperationException ue = + assertThrows(UnsupportedOperationException.class, () -> decimalVector.setSafe(0, decimal)); + assertEquals("BigDecimal precision can not be greater than that in the Arrow vector: 5 > 4", ue.getMessage()); + } + } + } + + @Test + public void testWriteBigEndian() { + try (Decimal256Vector decimalVector = TestUtils.newVector(Decimal256Vector.class, "decimal", + new ArrowType.Decimal(38, 18, 256), allocator);) { + decimalVector.allocateNew(); + BigDecimal decimal1 = new BigDecimal("123456789.000000000000000000"); + BigDecimal decimal2 = new BigDecimal("11.123456789123456789"); + BigDecimal decimal3 = new BigDecimal("1.000000000000000000"); + BigDecimal decimal4 = new BigDecimal("0.111111111000000000"); + BigDecimal decimal5 = new BigDecimal("987654321.123456789000000000"); + BigDecimal decimal6 = new BigDecimal("222222222222.222222222000000000"); + BigDecimal decimal7 = new BigDecimal("7777777777777.666666667000000000"); + BigDecimal decimal8 = new BigDecimal("1212121212.343434343000000000"); + + byte[] decimalValue1 = decimal1.unscaledValue().toByteArray(); + byte[] decimalValue2 = decimal2.unscaledValue().toByteArray(); + byte[] decimalValue3 = decimal3.unscaledValue().toByteArray(); + byte[] decimalValue4 = decimal4.unscaledValue().toByteArray(); + byte[] decimalValue5 = decimal5.unscaledValue().toByteArray(); + byte[] decimalValue6 = decimal6.unscaledValue().toByteArray(); + byte[] decimalValue7 = decimal7.unscaledValue().toByteArray(); + byte[] decimalValue8 = decimal8.unscaledValue().toByteArray(); + + decimalVector.setBigEndian(0, decimalValue1); + decimalVector.setBigEndian(1, decimalValue2); + decimalVector.setBigEndian(2, decimalValue3); + decimalVector.setBigEndian(3, decimalValue4); + decimalVector.setBigEndian(4, decimalValue5); + decimalVector.setBigEndian(5, decimalValue6); + decimalVector.setBigEndian(6, decimalValue7); + decimalVector.setBigEndian(7, decimalValue8); + + decimalVector.setValueCount(8); + assertEquals(8, decimalVector.getValueCount()); + assertEquals(decimal1, decimalVector.getObject(0)); + assertEquals(decimal2, decimalVector.getObject(1)); + assertEquals(decimal3, decimalVector.getObject(2)); + assertEquals(decimal4, decimalVector.getObject(3)); + assertEquals(decimal5, decimalVector.getObject(4)); + assertEquals(decimal6, decimalVector.getObject(5)); + assertEquals(decimal7, decimalVector.getObject(6)); + assertEquals(decimal8, decimalVector.getObject(7)); + } + } + + @Test + public void testLongReadWrite() { + try (Decimal256Vector decimalVector = TestUtils.newVector(Decimal256Vector.class, "decimal", + new ArrowType.Decimal(38, 0, 256), allocator)) { + decimalVector.allocateNew(); + + long[] longValues = {0L, -2L, Long.MAX_VALUE, Long.MIN_VALUE, 187L}; + + for (int i = 0; i < longValues.length; ++i) { + decimalVector.set(i, longValues[i]); + } + + decimalVector.setValueCount(longValues.length); + + for (int i = 0; i < longValues.length; ++i) { + assertEquals(new BigDecimal(longValues[i]), decimalVector.getObject(i)); + } + } + } + + + @Test + public void testBigDecimalReadWrite() { + try (Decimal256Vector decimalVector = TestUtils.newVector(Decimal256Vector.class, "decimal", + new ArrowType.Decimal(38, 9, 256), allocator);) { + decimalVector.allocateNew(); + BigDecimal decimal1 = new BigDecimal("123456789.000000000"); + BigDecimal decimal2 = new BigDecimal("11.123456789"); + BigDecimal decimal3 = new BigDecimal("1.000000000"); + BigDecimal decimal4 = new BigDecimal("-0.111111111"); + BigDecimal decimal5 = new BigDecimal("-987654321.123456789"); + BigDecimal decimal6 = new BigDecimal("-222222222222.222222222"); + BigDecimal decimal7 = new BigDecimal("7777777777777.666666667"); + BigDecimal decimal8 = new BigDecimal("1212121212.343434343"); + + decimalVector.set(0, decimal1); + decimalVector.set(1, decimal2); + decimalVector.set(2, decimal3); + decimalVector.set(3, decimal4); + decimalVector.set(4, decimal5); + decimalVector.set(5, decimal6); + decimalVector.set(6, decimal7); + decimalVector.set(7, decimal8); + + decimalVector.setValueCount(8); + assertEquals(8, decimalVector.getValueCount()); + assertEquals(decimal1, decimalVector.getObject(0)); + assertEquals(decimal2, decimalVector.getObject(1)); + assertEquals(decimal3, decimalVector.getObject(2)); + assertEquals(decimal4, decimalVector.getObject(3)); + assertEquals(decimal5, decimalVector.getObject(4)); + assertEquals(decimal6, decimalVector.getObject(5)); + assertEquals(decimal7, decimalVector.getObject(6)); + assertEquals(decimal8, decimalVector.getObject(7)); + } + } + + /** + * Test {@link Decimal256Vector#setBigEndian(int, byte[])} which takes BE layout input and stores in LE layout. + * Cases to cover: input byte array in different lengths in range [1-16] and negative values. + */ + @Test + public void decimalBE2LE() { + try (Decimal256Vector decimalVector = TestUtils.newVector(Decimal256Vector.class, "decimal", + new ArrowType.Decimal(23, 2, 256), allocator)) { + decimalVector.allocateNew(); + + BigInteger[] testBigInts = new BigInteger[] { + new BigInteger("0"), + new BigInteger("-1"), + new BigInteger("23"), + new BigInteger("234234"), + new BigInteger("-234234234"), + new BigInteger("234234234234"), + new BigInteger("-56345345345345"), + new BigInteger("2982346298346289346293467923465345634500"), // converts to 16+ byte array + new BigInteger("-389457298347598237459832459823434653600"), // converts to 16+ byte array + new BigInteger("-345345"), + new BigInteger("754533") + }; + + int insertionIdx = 0; + insertionIdx++; // insert a null + for (BigInteger val : testBigInts) { + decimalVector.setBigEndian(insertionIdx++, val.toByteArray()); + } + insertionIdx++; // insert a null + // insert a zero length buffer + decimalVector.setBigEndian(insertionIdx++, new byte[0]); + + // Try inserting a buffer larger than 33 bytes and expect a failure + final int insertionIdxCapture = insertionIdx; + IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, + () -> decimalVector.setBigEndian(insertionIdxCapture, new byte[33])); + assertTrue(ex.getMessage().equals("Invalid decimal value length. Valid length in [1 - 32], got 33")); + decimalVector.setValueCount(insertionIdx); + + // retrieve values and check if they are correct + int outputIdx = 0; + assertTrue(decimalVector.isNull(outputIdx++)); + for (BigInteger expected : testBigInts) { + final BigDecimal actual = decimalVector.getObject(outputIdx++); + assertEquals(expected, actual.unscaledValue()); + } + assertTrue(decimalVector.isNull(outputIdx++)); + assertEquals(BigInteger.valueOf(0), decimalVector.getObject(outputIdx).unscaledValue()); + } + } + + @Test + public void setUsingArrowBufOfLEInts() { + try (Decimal256Vector decimalVector = TestUtils.newVector(Decimal256Vector.class, "decimal", + new ArrowType.Decimal(5, 2, 256), allocator); + ArrowBuf buf = allocator.buffer(8);) { + decimalVector.allocateNew(); + + // add a positive value equivalent to 705.32 + int val = 70532; + buf.setInt(0, val); + decimalVector.setSafe(0, 0, buf, 4); + + // add a -ve value equivalent to -705.32 + val = -70532; + buf.setInt(4, val); + decimalVector.setSafe(1, 4, buf, 4); + + decimalVector.setValueCount(2); + + BigDecimal [] expectedValues = new BigDecimal[] {BigDecimal.valueOf(705.32), BigDecimal + .valueOf(-705.32)}; + for (int i = 0; i < 2; i ++) { + BigDecimal value = decimalVector.getObject(i); + assertEquals(expectedValues[i], value); + } + } + + } + + @Test + public void setUsingArrowLongLEBytes() { + try (Decimal256Vector decimalVector = TestUtils.newVector(Decimal256Vector.class, "decimal", + new ArrowType.Decimal(18, 0, 256), allocator); + ArrowBuf buf = allocator.buffer(16);) { + decimalVector.allocateNew(); + + long val = Long.MAX_VALUE; + buf.setLong(0, val); + decimalVector.setSafe(0, 0, buf, 8); + + val = Long.MIN_VALUE; + buf.setLong(8, val); + decimalVector.setSafe(1, 8, buf, 8); + + decimalVector.setValueCount(2); + + BigDecimal [] expectedValues = new BigDecimal[] {BigDecimal.valueOf(Long.MAX_VALUE), BigDecimal + .valueOf(Long.MIN_VALUE)}; + for (int i = 0; i < 2; i ++) { + BigDecimal value = decimalVector.getObject(i); + assertEquals(expectedValues[i], value); + } + } + } + + @Test + public void setUsingArrowBufOfBEBytes() { + try (Decimal256Vector decimalVector = TestUtils.newVector(Decimal256Vector.class, "decimal", + new ArrowType.Decimal(5, 2, 256), allocator); + ArrowBuf buf = allocator.buffer(9);) { + BigDecimal [] expectedValues = new BigDecimal[] {BigDecimal.valueOf(705.32), BigDecimal + .valueOf(-705.32), BigDecimal.valueOf(705.32)}; + verifyWritingArrowBufWithBigEndianBytes(decimalVector, buf, expectedValues, 3); + } + + try (Decimal256Vector decimalVector = TestUtils.newVector(Decimal256Vector.class, "decimal", + new ArrowType.Decimal(43, 2, 256), allocator); + ArrowBuf buf = allocator.buffer(45);) { + BigDecimal[] expectedValues = new BigDecimal[] {new BigDecimal("29823462983462893462934679234653450000000.63"), + new BigDecimal("-2982346298346289346293467923465345.63"), + new BigDecimal("2982346298346289346293467923465345.63")}; + verifyWritingArrowBufWithBigEndianBytes(decimalVector, buf, expectedValues, 15); + } + } + + private void verifyWritingArrowBufWithBigEndianBytes(Decimal256Vector decimalVector, + ArrowBuf buf, BigDecimal[] expectedValues, + int length) { + decimalVector.allocateNew(); + for (int i = 0; i < expectedValues.length; i++) { + byte[] bigEndianBytes = expectedValues[i].unscaledValue().toByteArray(); + buf.setBytes(length * i , bigEndianBytes, 0 , bigEndianBytes.length); + decimalVector.setBigEndianSafe(i, length * i, buf, bigEndianBytes.length); + } + + decimalVector.setValueCount(3); + + for (int i = 0; i < expectedValues.length; i ++) { + BigDecimal value = decimalVector.getObject(i); + assertEquals(expectedValues[i], value); + } + } +} diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestDecimalVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestDecimalVector.java index 28d799bb1ca..25f480119db 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestDecimalVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestDecimalVector.java @@ -60,7 +60,7 @@ public void terminate() throws Exception { @Test public void testValuesWriteRead() { try (DecimalVector decimalVector = TestUtils.newVector(DecimalVector.class, "decimal", - new ArrowType.Decimal(10, scale), allocator);) { + new ArrowType.Decimal(10, scale, 128), allocator);) { try (DecimalVector oldConstructor = new DecimalVector("decimal", allocator, 10, scale);) { assertEquals(decimalVector.getField().getType(), oldConstructor.getField().getType()); @@ -86,7 +86,7 @@ public void testValuesWriteRead() { @Test public void testBigDecimalDifferentScaleAndPrecision() { try (DecimalVector decimalVector = TestUtils.newVector(DecimalVector.class, "decimal", - new ArrowType.Decimal(4, 2), allocator);) { + new ArrowType.Decimal(4, 2, 128), allocator);) { decimalVector.allocateNew(); // test BigDecimal with different scale @@ -116,7 +116,7 @@ public void testBigDecimalDifferentScaleAndPrecision() { @Test public void testWriteBigEndian() { try (DecimalVector decimalVector = TestUtils.newVector(DecimalVector.class, "decimal", - new ArrowType.Decimal(38, 9), allocator);) { + new ArrowType.Decimal(38, 9, 128), allocator);) { decimalVector.allocateNew(); BigDecimal decimal1 = new BigDecimal("123456789.000000000"); BigDecimal decimal2 = new BigDecimal("11.123456789"); @@ -161,7 +161,7 @@ public void testWriteBigEndian() { @Test public void testLongReadWrite() { try (DecimalVector decimalVector = TestUtils.newVector(DecimalVector.class, "decimal", - new ArrowType.Decimal(38, 0), allocator)) { + new ArrowType.Decimal(38, 0, 128), allocator)) { decimalVector.allocateNew(); long[] longValues = {0L, -2L, Long.MAX_VALUE, Long.MIN_VALUE, 187L}; @@ -182,7 +182,7 @@ public void testLongReadWrite() { @Test public void testBigDecimalReadWrite() { try (DecimalVector decimalVector = TestUtils.newVector(DecimalVector.class, "decimal", - new ArrowType.Decimal(38, 9), allocator);) { + new ArrowType.Decimal(38, 9, 128), allocator);) { decimalVector.allocateNew(); BigDecimal decimal1 = new BigDecimal("123456789.000000000"); BigDecimal decimal2 = new BigDecimal("11.123456789"); @@ -222,7 +222,7 @@ public void testBigDecimalReadWrite() { @Test public void decimalBE2LE() { try (DecimalVector decimalVector = TestUtils.newVector(DecimalVector.class, "decimal", - new ArrowType.Decimal(21, 2), allocator)) { + new ArrowType.Decimal(21, 2, 128), allocator)) { decimalVector.allocateNew(); BigInteger[] testBigInts = new BigInteger[] { @@ -272,7 +272,7 @@ public void decimalBE2LE() { @Test public void setUsingArrowBufOfLEInts() { try (DecimalVector decimalVector = TestUtils.newVector(DecimalVector.class, "decimal", - new ArrowType.Decimal(5, 2), allocator); + new ArrowType.Decimal(5, 2, 128), allocator); ArrowBuf buf = allocator.buffer(8);) { decimalVector.allocateNew(); @@ -301,7 +301,7 @@ public void setUsingArrowBufOfLEInts() { @Test public void setUsingArrowLongLEBytes() { try (DecimalVector decimalVector = TestUtils.newVector(DecimalVector.class, "decimal", - new ArrowType.Decimal(18, 0), allocator); + new ArrowType.Decimal(18, 0, 128), allocator); ArrowBuf buf = allocator.buffer(16);) { decimalVector.allocateNew(); @@ -327,7 +327,7 @@ public void setUsingArrowLongLEBytes() { @Test public void setUsingArrowBufOfBEBytes() { try (DecimalVector decimalVector = TestUtils.newVector(DecimalVector.class, "decimal", - new ArrowType.Decimal(5, 2), allocator); + new ArrowType.Decimal(5, 2, 128), allocator); ArrowBuf buf = allocator.buffer(9);) { BigDecimal [] expectedValues = new BigDecimal[] {BigDecimal.valueOf(705.32), BigDecimal .valueOf(-705.32), BigDecimal.valueOf(705.32)}; @@ -335,7 +335,7 @@ public void setUsingArrowBufOfBEBytes() { } try (DecimalVector decimalVector = TestUtils.newVector(DecimalVector.class, "decimal", - new ArrowType.Decimal(36, 2), allocator); + new ArrowType.Decimal(36, 2, 128), allocator); ArrowBuf buf = allocator.buffer(45);) { BigDecimal[] expectedValues = new BigDecimal[] {new BigDecimal("2982346298346289346293467923465345.63"), new BigDecimal("-2982346298346289346293467923465345.63"), diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestTypeLayout.java b/java/vector/src/test/java/org/apache/arrow/vector/TestTypeLayout.java index 18175276737..97930f433d3 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestTypeLayout.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestTypeLayout.java @@ -61,9 +61,13 @@ public void testTypeBufferCount() { type = new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE); assertEquals(TypeLayout.getTypeBufferCount(type), TypeLayout.getTypeLayout(type).getBufferLayouts().size()); - type = new ArrowType.Decimal(10, 10); + type = new ArrowType.Decimal(10, 10, 128); assertEquals(TypeLayout.getTypeBufferCount(type), TypeLayout.getTypeLayout(type).getBufferLayouts().size()); + type = new ArrowType.Decimal(10, 10, 256); + assertEquals(TypeLayout.getTypeBufferCount(type), TypeLayout.getTypeLayout(type).getBufferLayouts().size()); + + type = new ArrowType.FixedSizeBinary(5); assertEquals(TypeLayout.getTypeBufferCount(type), TypeLayout.getTypeLayout(type).getBufferLayouts().size()); diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestVectorAlloc.java b/java/vector/src/test/java/org/apache/arrow/vector/TestVectorAlloc.java index 089f1f84ff8..b9e7c8661a7 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestVectorAlloc.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestVectorAlloc.java @@ -72,7 +72,7 @@ public void testVectorAllocWithField() { field("UTF8", MinorType.VARCHAR.getType()), field("VARBINARY", MinorType.VARBINARY.getType()), field("BIT", MinorType.BIT.getType()), - field("DECIMAL", new Decimal(38, 5)), + field("DECIMAL", new Decimal(38, 5, 128)), field("FIXEDSIZEBINARY", new FixedSizeBinary(50)), field("DATEDAY", MinorType.DATEDAY.getType()), field("DATEMILLI", MinorType.DATEMILLI.getType()), diff --git a/java/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestComplexCopier.java b/java/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestComplexCopier.java index a0f35052634..0992ffc7317 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestComplexCopier.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestComplexCopier.java @@ -522,12 +522,12 @@ public void testMapWithListValue() throws Exception { public void testCopyFixedSizedListOfDecimalsVector() { try (FixedSizeListVector from = FixedSizeListVector.empty("v", 4, allocator); FixedSizeListVector to = FixedSizeListVector.empty("v", 4, allocator)) { - from.addOrGetVector(FieldType.nullable(new ArrowType.Decimal(3, 0))); - to.addOrGetVector(FieldType.nullable(new ArrowType.Decimal(3, 0))); + from.addOrGetVector(FieldType.nullable(new ArrowType.Decimal(3, 0, 128))); + to.addOrGetVector(FieldType.nullable(new ArrowType.Decimal(3, 0, 128))); DecimalHolder holder = new DecimalHolder(); holder.buffer = allocator.buffer(DecimalUtility.DECIMAL_BYTE_LENGTH); - ArrowType arrowType = new ArrowType.Decimal(3, 0); + ArrowType arrowType = new ArrowType.Decimal(3, 0, 128); // populate from vector UnionFixedSizeListWriter writer = from.getWriter(); @@ -535,13 +535,13 @@ public void testCopyFixedSizedListOfDecimalsVector() { writer.startList(); writer.decimal().writeDecimal(BigDecimal.valueOf(i)); - DecimalUtility.writeBigDecimalToArrowBuf(new BigDecimal(i * 2), holder.buffer, 0); + DecimalUtility.writeBigDecimalToArrowBuf(new BigDecimal(i * 2), holder.buffer, 0, /*byteWidth=*/16); holder.start = 0; holder.scale = 0; holder.precision = 3; writer.decimal().write(holder); - DecimalUtility.writeBigDecimalToArrowBuf(new BigDecimal(i * 3), holder.buffer, 0); + DecimalUtility.writeBigDecimalToArrowBuf(new BigDecimal(i * 3), holder.buffer, 0, /*byteWidth=*/16); writer.decimal().writeDecimal(0, holder.buffer, arrowType); writer.decimal().writeBigEndianBytesToDecimal(BigDecimal.valueOf(i * 4).unscaledValue().toByteArray(), @@ -582,7 +582,7 @@ public void testCopyUnionListWithDecimal() { listWriter.decimal().writeDecimal(BigDecimal.valueOf(i * 2)); listWriter.integer().writeInt(i); listWriter.decimal().writeBigEndianBytesToDecimal(BigDecimal.valueOf(i * 3).unscaledValue().toByteArray(), - new ArrowType.Decimal(3, 0)); + new ArrowType.Decimal(3, 0, 128)); listWriter.endList(); } @@ -623,7 +623,7 @@ public void testCopyStructVector() { innerStructWriter.integer("innerint").writeInt(i * 3); innerStructWriter.decimal("innerdec", 0, 38).writeDecimal(BigDecimal.valueOf(i * 4)); innerStructWriter.decimal("innerdec", 0, 38).writeBigEndianBytesToDecimal(BigDecimal.valueOf(i * 4) - .unscaledValue().toByteArray(), new ArrowType.Decimal(3, 0)); + .unscaledValue().toByteArray(), new ArrowType.Decimal(3, 0, 128)); innerStructWriter.end(); structWriter.end(); } @@ -649,8 +649,8 @@ public void testCopyStructVector() { public void testCopyDecimalVectorWrongScale() { try (FixedSizeListVector from = FixedSizeListVector.empty("v", 3, allocator); FixedSizeListVector to = FixedSizeListVector.empty("v", 3, allocator)) { - from.addOrGetVector(FieldType.nullable(new ArrowType.Decimal(3, 2))); - to.addOrGetVector(FieldType.nullable(new ArrowType.Decimal(3, 1))); + from.addOrGetVector(FieldType.nullable(new ArrowType.Decimal(3, 2, 128))); + to.addOrGetVector(FieldType.nullable(new ArrowType.Decimal(3, 1, 128))); // populate from vector UnionFixedSizeListWriter writer = from.getWriter(); diff --git a/java/vector/src/test/java/org/apache/arrow/vector/complex/writer/TestComplexWriter.java b/java/vector/src/test/java/org/apache/arrow/vector/complex/writer/TestComplexWriter.java index 769a94f50f2..1e6fe495750 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/complex/writer/TestComplexWriter.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/complex/writer/TestComplexWriter.java @@ -311,20 +311,20 @@ public void listDecimalType() { UnionListWriter listWriter = new UnionListWriter(listVector); DecimalHolder holder = new DecimalHolder(); holder.buffer = allocator.buffer(DecimalUtility.DECIMAL_BYTE_LENGTH); - ArrowType arrowType = new ArrowType.Decimal(10, 0); + ArrowType arrowType = new ArrowType.Decimal(10, 0, 128); for (int i = 0; i < COUNT; i++) { listWriter.startList(); for (int j = 0; j < i % 7; j++) { if (j % 4 == 0) { listWriter.writeDecimal(new BigDecimal(j)); } else if (j % 4 == 1) { - DecimalUtility.writeBigDecimalToArrowBuf(new BigDecimal(j), holder.buffer, 0); + DecimalUtility.writeBigDecimalToArrowBuf(new BigDecimal(j), holder.buffer, 0, 16); holder.start = 0; holder.scale = 0; holder.precision = 10; listWriter.write(holder); } else if (j % 4 == 2) { - DecimalUtility.writeBigDecimalToArrowBuf(new BigDecimal(j), holder.buffer, 0); + DecimalUtility.writeBigDecimalToArrowBuf(new BigDecimal(j), holder.buffer, 0, 16); listWriter.writeDecimal(0, holder.buffer, arrowType); } else { byte[] value = BigDecimal.valueOf(j).unscaledValue().toByteArray(); diff --git a/java/vector/src/test/java/org/apache/arrow/vector/types/pojo/TestSchema.java b/java/vector/src/test/java/org/apache/arrow/vector/types/pojo/TestSchema.java index 3ca8d0af6a6..3d93407ac64 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/types/pojo/TestSchema.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/types/pojo/TestSchema.java @@ -90,7 +90,7 @@ public void testAll() throws IOException { field("g", new Utf8()), field("h", new Binary()), field("i", new Bool()), - field("j", new Decimal(5, 5)), + field("j", new Decimal(5, 5, 128)), field("k", new Date(DateUnit.DAY)), field("l", new Date(DateUnit.MILLISECOND)), field("m", new Time(TimeUnit.SECOND, 32)), diff --git a/java/vector/src/test/java/org/apache/arrow/vector/util/DecimalUtilityTest.java b/java/vector/src/test/java/org/apache/arrow/vector/util/DecimalUtilityTest.java index 667e9624ed8..2e255e6dda2 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/util/DecimalUtilityTest.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/util/DecimalUtilityTest.java @@ -27,76 +27,81 @@ import org.junit.Test; public class DecimalUtilityTest { - private static final BigInteger MAX_BIG_INT = java.math.BigInteger.valueOf(10).pow(38) - .subtract(java.math.BigInteger.ONE); - private static final BigDecimal MAX_DECIMAL = new java.math.BigDecimal(MAX_BIG_INT, 0); - private static final BigInteger MIN_BIG_INT = MAX_BIG_INT.multiply(BigInteger.valueOf(-1)); - private static final BigDecimal MIN_DECIMAL = new java.math.BigDecimal(MIN_BIG_INT, 0); + private static final BigInteger[] MAX_BIG_INT = new BigInteger[]{BigInteger.valueOf(10).pow(38) + .subtract(java.math.BigInteger.ONE), java.math.BigInteger.valueOf(10).pow(76)}; + private static final BigInteger[] MIN_BIG_INT = new BigInteger[]{MAX_BIG_INT[0].multiply(BigInteger.valueOf(-1)), + MAX_BIG_INT[1].multiply(BigInteger.valueOf(-1))}; @Test public void testSetByteArrayInDecimalArrowBuf() { - try (BufferAllocator allocator = new RootAllocator(128); - ArrowBuf buf = allocator.buffer(16); - ) { - int [] intValues = new int [] {Integer.MAX_VALUE, Integer.MIN_VALUE, 0}; - for (int val : intValues) { - buf.clear(); - DecimalUtility.writeByteArrayToArrowBuf(BigInteger.valueOf(val).toByteArray(), buf, 0); - BigDecimal actual = DecimalUtility.getBigDecimalFromArrowBuf(buf, 0, 0); - BigDecimal expected = BigDecimal.valueOf(val); - Assert.assertEquals(expected, actual); - } + int[] byteLengths = new int[]{16, 32}; + for (int x = 0; x < 2; x++) { + try (BufferAllocator allocator = new RootAllocator(128); + ArrowBuf buf = allocator.buffer(byteLengths[x]); + ) { + int [] intValues = new int [] {Integer.MAX_VALUE, Integer.MIN_VALUE, 0}; + for (int val : intValues) { + buf.clear(); + DecimalUtility.writeByteArrayToArrowBuf(BigInteger.valueOf(val).toByteArray(), buf, 0, byteLengths[x]); + BigDecimal actual = DecimalUtility.getBigDecimalFromArrowBuf(buf, 0, 0, byteLengths[x]); + BigDecimal expected = BigDecimal.valueOf(val); + Assert.assertEquals(expected, actual); + } - long [] longValues = new long[] {Long.MIN_VALUE, 0 , Long.MAX_VALUE}; - for (long val : longValues) { - buf.clear(); - DecimalUtility.writeByteArrayToArrowBuf(BigInteger.valueOf(val).toByteArray(), buf, 0); - BigDecimal actual = DecimalUtility.getBigDecimalFromArrowBuf(buf, 0, 0); - BigDecimal expected = BigDecimal.valueOf(val); - Assert.assertEquals(expected, actual); - } + long [] longValues = new long[] {Long.MIN_VALUE, 0 , Long.MAX_VALUE}; + for (long val : longValues) { + buf.clear(); + DecimalUtility.writeByteArrayToArrowBuf(BigInteger.valueOf(val).toByteArray(), buf, 0, byteLengths[x]); + BigDecimal actual = DecimalUtility.getBigDecimalFromArrowBuf(buf, 0, 0, byteLengths[x]); + BigDecimal expected = BigDecimal.valueOf(val); + Assert.assertEquals(expected, actual); + } - BigInteger [] decimals = new BigInteger[] {MAX_BIG_INT, new BigInteger("0"), MIN_BIG_INT}; - for (BigInteger val : decimals) { - buf.clear(); - DecimalUtility.writeByteArrayToArrowBuf(val.toByteArray(), buf, 0); - BigDecimal actual = DecimalUtility.getBigDecimalFromArrowBuf(buf, 0, 0); - BigDecimal expected = new BigDecimal(val); - Assert.assertEquals(expected, actual); + BigInteger [] decimals = new BigInteger[] {MAX_BIG_INT[x], new BigInteger("0"), MIN_BIG_INT[x]}; + for (BigInteger val : decimals) { + buf.clear(); + DecimalUtility.writeByteArrayToArrowBuf(val.toByteArray(), buf, 0, byteLengths[x]); + BigDecimal actual = DecimalUtility.getBigDecimalFromArrowBuf(buf, 0, 0, byteLengths[x]); + BigDecimal expected = new BigDecimal(val); + Assert.assertEquals(expected, actual); + } } } } @Test public void testSetBigDecimalInDecimalArrowBuf() { - try (BufferAllocator allocator = new RootAllocator(128); - ArrowBuf buf = allocator.buffer(16); - ) { - int [] intValues = new int [] {Integer.MAX_VALUE, Integer.MIN_VALUE, 0}; - for (int val : intValues) { - buf.clear(); - DecimalUtility.writeBigDecimalToArrowBuf(BigDecimal.valueOf(val), buf, 0); - BigDecimal actual = DecimalUtility.getBigDecimalFromArrowBuf(buf, 0, 0); - BigDecimal expected = BigDecimal.valueOf(val); - Assert.assertEquals(expected, actual); - } + int[] byteLengths = new int[]{16, 32}; + for (int x = 0; x < 2; x++) { + try (BufferAllocator allocator = new RootAllocator(128); + ArrowBuf buf = allocator.buffer(byteLengths[x]); + ) { + int [] intValues = new int [] {Integer.MAX_VALUE, Integer.MIN_VALUE, 0}; + for (int val : intValues) { + buf.clear(); + DecimalUtility.writeBigDecimalToArrowBuf(BigDecimal.valueOf(val), buf, 0, byteLengths[x]); + BigDecimal actual = DecimalUtility.getBigDecimalFromArrowBuf(buf, 0, 0, byteLengths[x]); + BigDecimal expected = BigDecimal.valueOf(val); + Assert.assertEquals(expected, actual); + } - long [] longValues = new long[] {Long.MIN_VALUE, 0 , Long.MAX_VALUE}; - for (long val : longValues) { - buf.clear(); - DecimalUtility.writeBigDecimalToArrowBuf(BigDecimal.valueOf(val), buf, 0); - BigDecimal actual = DecimalUtility.getBigDecimalFromArrowBuf(buf, 0, 0); - BigDecimal expected = BigDecimal.valueOf(val); - Assert.assertEquals(expected, actual); - } + long [] longValues = new long[] {Long.MIN_VALUE, 0 , Long.MAX_VALUE}; + for (long val : longValues) { + buf.clear(); + DecimalUtility.writeBigDecimalToArrowBuf(BigDecimal.valueOf(val), buf, 0, byteLengths[x]); + BigDecimal actual = DecimalUtility.getBigDecimalFromArrowBuf(buf, 0, 0, byteLengths[x]); + BigDecimal expected = BigDecimal.valueOf(val); + Assert.assertEquals(expected, actual); + } - BigInteger [] decimals = new BigInteger[] {MAX_BIG_INT, new BigInteger("0"), MIN_BIG_INT}; - for (BigInteger val : decimals) { - buf.clear(); - DecimalUtility.writeBigDecimalToArrowBuf(new BigDecimal(val), buf, 0); - BigDecimal actual = DecimalUtility.getBigDecimalFromArrowBuf(buf, 0, 0); - BigDecimal expected = new BigDecimal(val); - Assert.assertEquals(expected, actual); + BigInteger [] decimals = new BigInteger[] {MAX_BIG_INT[x], new BigInteger("0"), MIN_BIG_INT[x]}; + for (BigInteger val : decimals) { + buf.clear(); + DecimalUtility.writeBigDecimalToArrowBuf(new BigDecimal(val), buf, 0, byteLengths[x]); + BigDecimal actual = DecimalUtility.getBigDecimalFromArrowBuf(buf, 0, 0, byteLengths[x]); + BigDecimal expected = new BigDecimal(val); + Assert.assertEquals(expected, actual); + } } } } diff --git a/python/pyarrow/__init__.py b/python/pyarrow/__init__.py index fd09e09fdcc..bc669325c2e 100644 --- a/python/pyarrow/__init__.py +++ b/python/pyarrow/__init__.py @@ -95,7 +95,7 @@ def show_versions(): float16, float32, float64, binary, string, utf8, large_binary, large_string, large_utf8, - decimal128, + decimal128, decimal256, list_, large_list, map_, struct, union, dictionary, field, type_for_alias, @@ -103,7 +103,7 @@ def show_versions(): ListType, LargeListType, MapType, FixedSizeListType, UnionType, TimestampType, Time32Type, Time64Type, DurationType, - FixedSizeBinaryType, Decimal128Type, + FixedSizeBinaryType, Decimal128Type, Decimal256Type, BaseExtensionType, ExtensionType, PyExtensionType, UnknownExtensionType, register_extension_type, unregister_extension_type, @@ -133,13 +133,13 @@ def show_versions(): DictionaryArray, Date32Array, Date64Array, TimestampArray, Time32Array, Time64Array, DurationArray, - Decimal128Array, StructArray, ExtensionArray, + Decimal128Array, Decimal256Array, StructArray, ExtensionArray, scalar, NA, _NULL as NULL, Scalar, NullScalar, BooleanScalar, Int8Scalar, Int16Scalar, Int32Scalar, Int64Scalar, UInt8Scalar, UInt16Scalar, UInt32Scalar, UInt64Scalar, HalfFloatScalar, FloatScalar, DoubleScalar, - Decimal128Scalar, + Decimal128Scalar, Decimal256Scalar, ListScalar, LargeListScalar, FixedSizeListScalar, Date32Scalar, Date64Scalar, Time32Scalar, Time64Scalar, @@ -344,6 +344,7 @@ def _deprecate_scalar(ty, symbol): FixedSizeBinaryValue = _deprecate_scalar("FixedSizeBinary", FixedSizeBinaryScalar) Decimal128Value = _deprecate_scalar("Decimal128", Decimal128Scalar) +Decimal256Value = _deprecate_scalar("Decimal256", Decimal256Scalar) UnionValue = _deprecate_scalar("Union", UnionScalar) StructValue = _deprecate_scalar("Struct", StructScalar) DictionaryValue = _deprecate_scalar("Dictionary", DictionaryScalar) diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi index e4bfc36c5ec..006e35150bf 100644 --- a/python/pyarrow/array.pxi +++ b/python/pyarrow/array.pxi @@ -1486,6 +1486,12 @@ cdef class Decimal128Array(FixedSizeBinaryArray): Concrete class for Arrow arrays of decimal128 data type. """ + +cdef class Decimal256Array(FixedSizeBinaryArray): + """ + Concrete class for Arrow arrays of decimal256 data type. + """ + cdef class BaseListArray(Array): def flatten(self): @@ -2276,7 +2282,8 @@ cdef dict _array_classes = { _Type_LARGE_STRING: LargeStringArray, _Type_DICTIONARY: DictionaryArray, _Type_FIXED_SIZE_BINARY: FixedSizeBinaryArray, - _Type_DECIMAL: Decimal128Array, + _Type_DECIMAL128: Decimal128Array, + _Type_DECIMAL256: Decimal256Array, _Type_STRUCT: StructArray, _Type_EXTENSION: ExtensionArray, } diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 0b8181baa7a..ddec351ba6b 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -49,6 +49,11 @@ cdef extern from "arrow/util/decimal.h" namespace "arrow" nogil: c_string ToString(int32_t scale) const +cdef extern from "arrow/util/decimal.h" namespace "arrow" nogil: + cdef cppclass CDecimal256" arrow::Decimal256": + c_string ToString(int32_t scale) const + + cdef extern from "arrow/api.h" namespace "arrow" nogil: cdef cppclass CBuildInfo "arrow::BuildInfo": @@ -86,7 +91,8 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: _Type_FLOAT" arrow::Type::FLOAT" _Type_DOUBLE" arrow::Type::DOUBLE" - _Type_DECIMAL" arrow::Type::DECIMAL" + _Type_DECIMAL128" arrow::Type::DECIMAL128" + _Type_DECIMAL256" arrow::Type::DECIMAL256" _Type_DATE32" arrow::Type::DATE32" _Type_DATE64" arrow::Type::DATE64" @@ -356,6 +362,12 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: int precision() int scale() + cdef cppclass CDecimal256Type \ + " arrow::Decimal256Type"(CFixedSizeBinaryType): + CDecimal256Type(int precision, int scale) + int precision() + int scale() + cdef cppclass CField" arrow::Field": cppclass CMergeOptions "arrow::Field::MergeOptions": c_bool promote_nullability @@ -541,6 +553,11 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: ): c_string FormatValue(int i) + cdef cppclass CDecimal256Array" arrow::Decimal256Array"( + CFixedSizeBinaryArray + ): + c_string FormatValue(int i) + cdef cppclass CListArray" arrow::ListArray"(CArray): @staticmethod CResult[shared_ptr[CArray]] FromArrays( @@ -935,6 +952,9 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: cdef cppclass CDecimal128Scalar" arrow::Decimal128Scalar"(CScalar): CDecimal128 value + cdef cppclass CDecimal256Scalar" arrow::Decimal256Scalar"(CScalar): + CDecimal256 value + cdef cppclass CDate32Scalar" arrow::Date32Scalar"(CScalar): int32_t value diff --git a/python/pyarrow/lib.pxd b/python/pyarrow/lib.pxd index 5b2958a0647..fb390e1af42 100644 --- a/python/pyarrow/lib.pxd +++ b/python/pyarrow/lib.pxd @@ -135,6 +135,11 @@ cdef class Decimal128Type(FixedSizeBinaryType): const CDecimal128Type* decimal128_type +cdef class Decimal256Type(FixedSizeBinaryType): + cdef: + const CDecimal256Type* decimal256_type + + cdef class BaseExtensionType(DataType): cdef: const CExtensionType* ext_type @@ -345,6 +350,10 @@ cdef class Decimal128Array(FixedSizeBinaryArray): pass +cdef class Decimal256Array(FixedSizeBinaryArray): + pass + + cdef class StructArray(Array): pass diff --git a/python/pyarrow/lib.pyx b/python/pyarrow/lib.pyx index 0f57ced3745..eba0f5a47af 100644 --- a/python/pyarrow/lib.pyx +++ b/python/pyarrow/lib.pyx @@ -73,7 +73,8 @@ Type_INT64 = _Type_INT64 Type_HALF_FLOAT = _Type_HALF_FLOAT Type_FLOAT = _Type_FLOAT Type_DOUBLE = _Type_DOUBLE -Type_DECIMAL = _Type_DECIMAL +Type_DECIMAL128 = _Type_DECIMAL128 +Type_DECIMAL256 = _Type_DECIMAL256 Type_DATE32 = _Type_DATE32 Type_DATE64 = _Type_DATE64 Type_TIMESTAMP = _Type_TIMESTAMP diff --git a/python/pyarrow/public-api.pxi b/python/pyarrow/public-api.pxi index 81c57b11926..aa738f9aaea 100644 --- a/python/pyarrow/public-api.pxi +++ b/python/pyarrow/public-api.pxi @@ -103,8 +103,10 @@ cdef api object pyarrow_wrap_data_type( out = DurationType.__new__(DurationType) elif type.get().id() == _Type_FIXED_SIZE_BINARY: out = FixedSizeBinaryType.__new__(FixedSizeBinaryType) - elif type.get().id() == _Type_DECIMAL: + elif type.get().id() == _Type_DECIMAL128: out = Decimal128Type.__new__(Decimal128Type) + elif type.get().id() == _Type_DECIMAL256: + out = Decimal256Type.__new__(Decimal256Type) elif type.get().id() == _Type_EXTENSION: ext_type = type.get() cpy_ext_type = dynamic_cast[_CPyExtensionTypePtr](ext_type) diff --git a/python/pyarrow/scalar.pxi b/python/pyarrow/scalar.pxi index 3e72d060d69..effe60c73b2 100644 --- a/python/pyarrow/scalar.pxi +++ b/python/pyarrow/scalar.pxi @@ -306,6 +306,26 @@ cdef class Decimal128Scalar(Scalar): return None +cdef class Decimal256Scalar(Scalar): + """ + Concrete class for decimal256 scalars. + """ + + def as_py(self): + """ + Return this value as a Python Decimal. + """ + cdef: + CDecimal256Scalar* sp = self.wrapped.get() + CDecimal256Type* dtype = sp.type.get() + if sp.is_valid: + return _pydecimal.Decimal( + frombytes(sp.value.ToString(dtype.scale())) + ) + else: + return None + + cdef class Date32Scalar(Scalar): """ Concrete class for date32 scalars. @@ -805,7 +825,8 @@ cdef dict _scalar_classes = { _Type_HALF_FLOAT: HalfFloatScalar, _Type_FLOAT: FloatScalar, _Type_DOUBLE: DoubleScalar, - _Type_DECIMAL: Decimal128Scalar, + _Type_DECIMAL128: Decimal128Scalar, + _Type_DECIMAL256: Decimal256Scalar, _Type_DATE32: Date32Scalar, _Type_DATE64: Date64Scalar, _Type_TIME32: Time32Scalar, diff --git a/python/pyarrow/tests/strategies.py b/python/pyarrow/tests/strategies.py index cb9d9434fbe..92b0d3617c0 100644 --- a/python/pyarrow/tests/strategies.py +++ b/python/pyarrow/tests/strategies.py @@ -73,12 +73,18 @@ pa.float32(), pa.float64() ]) -decimal_type = st.builds( +decimal128_type = st.builds( pa.decimal128, precision=st.integers(min_value=1, max_value=38), scale=st.integers(min_value=1, max_value=38) ) -numeric_types = st.one_of(integer_types, floating_types, decimal_type) +decimal256_type = st.builds( + pa.decimal256, + precision=st.integers(min_value=1, max_value=76), + scale=st.integers(min_value=1, max_value=76) +) +numeric_types = st.one_of(integer_types, floating_types, + decimal128_type, decimal256_type) date_types = st.sampled_from([ pa.date32(), @@ -359,7 +365,7 @@ def tables(draw, type, rows=None, max_fields=None): bool_type, integer_types, st.sampled_from([pa.float32(), pa.float64()]), - decimal_type, + decimal128_type, date_types, time_types, # Need to exclude timestamp and duration types otherwise hypothesis diff --git a/python/pyarrow/tests/test_array.py b/python/pyarrow/tests/test_array.py index 5c49dcd4937..17d52188602 100644 --- a/python/pyarrow/tests/test_array.py +++ b/python/pyarrow/tests/test_array.py @@ -1283,7 +1283,7 @@ def test_decimal_to_int_non_integer(): for case in non_integer_cases: # test safe casting raises - msg_regexp = 'Rescaling decimal value would cause data loss' + msg_regexp = 'Rescaling Decimal128 value would cause data loss' with pytest.raises(pa.ArrowInvalid, match=msg_regexp): _check_cast_case(case) @@ -1302,8 +1302,8 @@ def test_decimal_to_decimal(): ) assert result.equals(expected) - with pytest.raises(pa.ArrowInvalid, - match='Rescaling decimal value would cause data loss'): + msg_regexp = 'Rescaling Decimal128 value would cause data loss' + with pytest.raises(pa.ArrowInvalid, match=msg_regexp): result = arr.cast(pa.decimal128(9, 1)) result = arr.cast(pa.decimal128(9, 1), safe=False) diff --git a/python/pyarrow/tests/test_convert_builtin.py b/python/pyarrow/tests/test_convert_builtin.py index cb6b4b3b133..6edf049075f 100644 --- a/python/pyarrow/tests/test_convert_builtin.py +++ b/python/pyarrow/tests/test_convert_builtin.py @@ -1425,61 +1425,62 @@ def test_sequence_mixed_types_with_specified_type_fails(): def test_sequence_decimal(): data = [decimal.Decimal('1234.183'), decimal.Decimal('8094.234')] - type = pa.decimal128(precision=7, scale=3) - arr = pa.array(data, type=type) - assert arr.to_pylist() == data + for type in [pa.decimal128, pa.decimal256]: + arr = pa.array(data, type=type(precision=7, scale=3)) + assert arr.to_pylist() == data def test_sequence_decimal_different_precisions(): data = [ decimal.Decimal('1234234983.183'), decimal.Decimal('80943244.234') ] - type = pa.decimal128(precision=13, scale=3) - arr = pa.array(data, type=type) - assert arr.to_pylist() == data + for type in [pa.decimal128, pa.decimal256]: + arr = pa.array(data, type=type(precision=13, scale=3)) + assert arr.to_pylist() == data def test_sequence_decimal_no_scale(): data = [decimal.Decimal('1234234983'), decimal.Decimal('8094324')] - type = pa.decimal128(precision=10) - arr = pa.array(data, type=type) - assert arr.to_pylist() == data + for type in [pa.decimal128, pa.decimal256]: + arr = pa.array(data, type=type(precision=10)) + assert arr.to_pylist() == data def test_sequence_decimal_negative(): data = [decimal.Decimal('-1234.234983'), decimal.Decimal('-8.094324')] - type = pa.decimal128(precision=10, scale=6) - arr = pa.array(data, type=type) - assert arr.to_pylist() == data + for type in [pa.decimal128, pa.decimal256]: + arr = pa.array(data, type=type(precision=10, scale=6)) + assert arr.to_pylist() == data def test_sequence_decimal_no_whole_part(): data = [decimal.Decimal('-.4234983'), decimal.Decimal('.0103943')] - type = pa.decimal128(precision=7, scale=7) - arr = pa.array(data, type=type) - assert arr.to_pylist() == data + for type in [pa.decimal128, pa.decimal256]: + arr = pa.array(data, type=type(precision=7, scale=7)) + assert arr.to_pylist() == data def test_sequence_decimal_large_integer(): data = [decimal.Decimal('-394029506937548693.42983'), decimal.Decimal('32358695912932.01033')] - type = pa.decimal128(precision=23, scale=5) - arr = pa.array(data, type=type) - assert arr.to_pylist() == data + for type in [pa.decimal128, pa.decimal256]: + arr = pa.array(data, type=type(precision=23, scale=5)) + assert arr.to_pylist() == data def test_sequence_decimal_from_integers(): data = [0, 1, -39402950693754869342983] expected = [decimal.Decimal(x) for x in data] + # TODO: update this test after scaling implementation. type = pa.decimal128(precision=28, scale=5) arr = pa.array(data, type=type) assert arr.to_pylist() == expected def test_sequence_decimal_too_high_precision(): - # ARROW-6989 python decimal created from float has too high precision + # ARROW-6989 python decimal has too high precision with pytest.raises(ValueError, match="precision out of range"): - pa.array([decimal.Decimal(123.234)]) + pa.array([decimal.Decimal('1' * 80)]) def test_range_types(): diff --git a/python/pyarrow/tests/test_scalars.py b/python/pyarrow/tests/test_scalars.py index fa48ad8b5f2..f516afdf2fe 100644 --- a/python/pyarrow/tests/test_scalars.py +++ b/python/pyarrow/tests/test_scalars.py @@ -43,6 +43,8 @@ (np.float16(1.0), pa.float16(), pa.HalfFloatScalar, pa.HalfFloatValue), (1.0, pa.float32(), pa.FloatScalar, pa.FloatValue), (decimal.Decimal("1.123"), None, pa.Decimal128Scalar, pa.Decimal128Value), + (decimal.Decimal("1.1234567890123456789012345678901234567890"), + None, pa.Decimal256Scalar, pa.Decimal256Value), ("string", None, pa.StringScalar, pa.StringValue), (b"bytes", None, pa.BinaryScalar, pa.BinaryValue), ("largestring", pa.large_string(), pa.LargeStringScalar, @@ -176,7 +178,7 @@ def test_numerics(): assert s.as_py() == 0.5 -def test_decimal(): +def test_decimal128(): v = decimal.Decimal("1.123") s = pa.scalar(v) assert isinstance(s, pa.Decimal128Scalar) @@ -194,6 +196,25 @@ def test_decimal(): assert s.as_py() == v +def test_decimal256(): + v = decimal.Decimal("1234567890123456789012345678901234567890.123") + s = pa.scalar(v) + assert isinstance(s, pa.Decimal256Scalar) + assert s.as_py() == v + assert s.type == pa.decimal256(43, 3) + + v = decimal.Decimal("1.1234") + with pytest.raises(pa.ArrowInvalid): + pa.scalar(v, type=pa.decimal256(4, scale=3)) + # TODO: Add the following after implementing Decimal256 scaling. + # with pytest.raises(pa.ArrowInvalid): + # pa.scalar(v, type=pa.decimal256(5, scale=3)) + + s = pa.scalar(v, type=pa.decimal256(5, scale=4)) + assert isinstance(s, pa.Decimal256Scalar) + assert s.as_py() == v + + def test_date(): # ARROW-5125 d1 = datetime.date(3200, 1, 1) diff --git a/python/pyarrow/tests/test_schema.py b/python/pyarrow/tests/test_schema.py index 3ba9b7bbe4c..da67aaa19aa 100644 --- a/python/pyarrow/tests/test_schema.py +++ b/python/pyarrow/tests/test_schema.py @@ -604,6 +604,7 @@ def test_type_schema_pickling(): pa.timestamp('ms'), pa.timestamp('ns'), pa.decimal128(12, 2), + pa.decimal256(76, 38), pa.field('a', 'string', metadata={b'foo': b'bar'}) ] diff --git a/python/pyarrow/tests/test_types.py b/python/pyarrow/tests/test_types.py index 4de5ffabfad..e5c11415c05 100644 --- a/python/pyarrow/tests/test_types.py +++ b/python/pyarrow/tests/test_types.py @@ -53,6 +53,7 @@ def get_many_types(): pa.float32(), pa.float64(), pa.decimal128(19, 4), + pa.decimal256(76, 38), pa.string(), pa.binary(), pa.binary(10), @@ -124,8 +125,21 @@ def test_null_field_may_not_be_non_nullable(): def test_is_decimal(): - assert types.is_decimal(pa.decimal128(19, 4)) - assert not types.is_decimal(pa.int32()) + decimal128 = pa.decimal128(19, 4) + decimal256 = pa.decimal256(76, 38) + int32 = pa.int32() + + assert types.is_decimal(decimal128) + assert types.is_decimal(decimal256) + assert not types.is_decimal(int32) + + assert types.is_decimal128(decimal128) + assert not types.is_decimal128(decimal256) + assert not types.is_decimal128(int32) + + assert not types.is_decimal256(decimal128) + assert types.is_decimal256(decimal256) + assert not types.is_decimal256(int32) def test_is_list(): @@ -695,6 +709,7 @@ def test_bit_width(): (pa.uint32(), 32), (pa.float16(), 16), (pa.decimal128(19, 4), 128), + (pa.decimal256(76, 38), 256), (pa.binary(42), 42 * 8)]: assert ty.bit_width == expected for ty in [pa.binary(), pa.string(), pa.list_(pa.int16())]: @@ -712,6 +727,10 @@ def test_decimal_properties(): assert ty.byte_width == 16 assert ty.precision == 19 assert ty.scale == 4 + ty = pa.decimal256(76, 38) + assert ty.byte_width == 32 + assert ty.precision == 76 + assert ty.scale == 38 def test_decimal_overflow(): @@ -719,7 +738,13 @@ def test_decimal_overflow(): pa.decimal128(38, 0) for i in (0, -1, 39): with pytest.raises(ValueError): - pa.decimal128(39, 0) + pa.decimal128(i, 0) + + pa.decimal256(1, 0) + pa.decimal256(76, 0) + for i in (0, -1, 77): + with pytest.raises(ValueError): + pa.decimal256(i, 0) def test_type_equality_operators(): diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi index 19d00ac33e4..b337f2428c7 100644 --- a/python/pyarrow/types.pxi +++ b/python/pyarrow/types.pxi @@ -46,8 +46,8 @@ cdef dict _pandas_type_map = { _Type_FIXED_SIZE_BINARY: np.object_, _Type_STRING: np.object_, _Type_LIST: np.object_, - _Type_DECIMAL: np.object_, _Type_MAP: np.object_, + _Type_DECIMAL128: np.object_, } cdef dict _pep3118_type_map = { @@ -624,6 +624,33 @@ cdef class Decimal128Type(FixedSizeBinaryType): return self.decimal128_type.scale() +cdef class Decimal256Type(FixedSizeBinaryType): + """ + Concrete class for Decimal256 data types. + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + FixedSizeBinaryType.init(self, type) + self.decimal256_type = type.get() + + def __reduce__(self): + return decimal256, (self.precision, self.scale) + + @property + def precision(self): + """ + The decimal precision, in number of decimal digits (an integer). + """ + return self.decimal256_type.precision() + + @property + def scale(self): + """ + The decimal scale (an integer). + """ + return self.decimal256_type.scale() + + cdef class BaseExtensionType(DataType): """ Concrete base class for extension types. @@ -2093,6 +2120,26 @@ cpdef DataType decimal128(int precision, int scale=0): return pyarrow_wrap_data_type(decimal_type) +cpdef DataType decimal256(int precision, int scale=0): + """ + Create decimal type with precision and scale and 256bit width. + + Parameters + ---------- + precision : int + scale : int + + Returns + ------- + decimal_type : Decimal256Type + """ + cdef shared_ptr[CDataType] decimal_type + if precision < 1 or precision > 76: + raise ValueError("precision should be between 1 and 76") + decimal_type.reset(new CDecimal256Type(precision, scale)) + return pyarrow_wrap_data_type(decimal_type) + + def string(): """ Create UTF8 variable-length string type. diff --git a/python/pyarrow/types.py b/python/pyarrow/types.py index 66791543fec..708e2bc4643 100644 --- a/python/pyarrow/types.py +++ b/python/pyarrow/types.py @@ -31,6 +31,7 @@ lib.Type_UINT64} _INTEGER_TYPES = _SIGNED_INTEGER_TYPES | _UNSIGNED_INTEGER_TYPES _FLOATING_TYPES = {lib.Type_HALF_FLOAT, lib.Type_FLOAT, lib.Type_DOUBLE} +_DECIMAL_TYPES = {lib.Type_DECIMAL128, lib.Type_DECIMAL256} _DATE_TYPES = {lib.Type_DATE32, lib.Type_DATE64} _TIME_TYPES = {lib.Type_TIME32, lib.Type_TIME64} _TEMPORAL_TYPES = {lib.Type_TIMESTAMP, @@ -325,7 +326,21 @@ def is_decimal(t): """ Return True if value is an instance of a decimal type. """ - return t.id == lib.Type_DECIMAL + return t.id in _DECIMAL_TYPES + + +def is_decimal128(t): + """ + Return True if value is an instance of a decimal type. + """ + return t.id == lib.Type_DECIMAL128 + + +def is_decimal256(t): + """ + Return True if value is an instance of a decimal type. + """ + return t.id == lib.Type_DECIMAL256 def is_dictionary(t): diff --git a/r/R/enums.R b/r/R/enums.R index 05905710231..14910bc92e0 100644 --- a/r/R/enums.R +++ b/r/R/enums.R @@ -66,18 +66,19 @@ Type <- enum("Type::type", INTERVAL_MONTHS = 21L, INTERVAL_DAY_TIME = 22L, DECIMAL = 23L, - LIST = 24L, - STRUCT = 25L, - SPARSE_UNION = 26L, - DENSE_UNION = 27L, - DICTIONARY = 28L, - MAP = 29L, - EXTENSION = 30L, - FIXED_SIZE_LIST = 31L, - DURATION = 32L, - LARGE_STRING = 33L, - LARGE_BINARY = 34L, - LARGE_LIST = 35L + DECIMAL256 = 24L, + LIST = 25L, + STRUCT = 26L, + SPARSE_UNION = 27L, + DENSE_UNION = 28L, + DICTIONARY = 29L, + MAP = 30L, + EXTENSION = 31L, + FIXED_SIZE_LIST = 32L, + DURATION = 33L, + LARGE_STRING = 34L, + LARGE_BINARY = 35L, + LARGE_LIST = 36L ) #' @rdname enums diff --git a/testing b/testing index 90e15c6bd4f..860376d4e58 160000 --- a/testing +++ b/testing @@ -1 +1 @@ -Subproject commit 90e15c6bd4fc50eb8cac6ee35cc3ab43807cfabe +Subproject commit 860376d4e586a3ac34ec93089889da624ead6c2a