diff --git a/rust/parquet/src/arrow/arrow_writer.rs b/rust/parquet/src/arrow/arrow_writer.rs index d1f84d8d47b..298ab34e008 100644 --- a/rust/parquet/src/arrow/arrow_writer.rs +++ b/rust/parquet/src/arrow/arrow_writer.rs @@ -25,7 +25,9 @@ use arrow::record_batch::RecordBatch; use arrow_array::Array; use super::levels::LevelInfo; -use super::schema::add_encoded_arrow_schema_to_metadata; +use super::schema::{ + add_encoded_arrow_schema_to_metadata, decimal_length_from_precision, +}; use crate::column::writer::ColumnWriter; use crate::errors::{ParquetError, Result}; @@ -143,7 +145,8 @@ fn write_leaves( | ArrowDataType::LargeBinary | ArrowDataType::Binary | ArrowDataType::Utf8 - | ArrowDataType::LargeUtf8 => { + | ArrowDataType::LargeUtf8 + | ArrowDataType::Decimal(_, _) => { let mut col_writer = get_col_writer(&mut row_group_writer)?; write_leaf( &mut col_writer, @@ -188,7 +191,6 @@ fn write_leaves( )), ArrowDataType::FixedSizeList(_, _) | ArrowDataType::FixedSizeBinary(_) - | ArrowDataType::Decimal(_, _) | ArrowDataType::Union(_) => Err(ParquetError::NYI( "Attempting to write an Arrow type that is not yet implemented".to_string(), )), @@ -315,6 +317,13 @@ fn write_leaf( .unwrap(); get_fsb_array_slice(&array, &indices) } + ArrowDataType::Decimal(_, _) => { + let array = column + .as_any() + .downcast_ref::() + .unwrap(); + get_decimal_array_slice(&array, &indices) + } _ => { return Err(ParquetError::NYI( "Attempting to write an Arrow type that is not yet implemented" @@ -416,6 +425,20 @@ fn get_interval_dt_array_slice( values } +fn get_decimal_array_slice( + array: &arrow_array::DecimalArray, + indices: &[usize], +) -> Vec { + let mut values = Vec::with_capacity(indices.len()); + let size = decimal_length_from_precision(array.precision()); + for i in indices { + let as_be_bytes = array.value(*i).to_be_bytes(); + let resized_value = as_be_bytes[(16 - size)..].to_vec(); + values.push(FixedLenByteArray::from(ByteArray::from(resized_value))); + } + values +} + fn get_fsb_array_slice( array: &arrow_array::FixedSizeBinaryArray, indices: &[usize], @@ -633,6 +656,49 @@ mod tests { } } + #[test] + fn arrow_writer_decimal() { + let decimal_field = Field::new("a", DataType::Decimal(5, 2), false); + let schema = Schema::new(vec![decimal_field]); + + let mut dec_builder = DecimalBuilder::new(4, 5, 2); + dec_builder.append_value(10_000).unwrap(); + dec_builder.append_value(50_000).unwrap(); + dec_builder.append_value(0).unwrap(); + dec_builder.append_value(-100).unwrap(); + + let raw_decimal_i128_values: Vec = vec![10_000, 50_000, 0, -100]; + let decimal_values = dec_builder.finish(); + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(decimal_values)], + ) + .unwrap(); + + let mut file = get_temp_file("test_arrow_writer_decimal.parquet", &[]); + let mut writer = + ArrowWriter::try_new(file.try_clone().unwrap(), Arc::new(schema), None) + .unwrap(); + writer.write(&batch).unwrap(); + writer.close().unwrap(); + + file.seek(std::io::SeekFrom::Start(0)).unwrap(); + let file_reader = SerializedFileReader::new(file).unwrap(); + let mut arrow_reader = ParquetFileArrowReader::new(Arc::new(file_reader)); + let mut record_batch_reader = arrow_reader.get_record_reader(1024).unwrap(); + + let batch = record_batch_reader.next().unwrap().unwrap(); + let decimal_col = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + + for i in 0..batch.num_rows() { + assert_eq!(decimal_col.value(i), raw_decimal_i128_values[i]); + } + } + #[test] #[ignore = "See ARROW-11294, data is correct but list field name is incorrect"] fn arrow_writer_complex() { diff --git a/rust/parquet/src/arrow/levels.rs b/rust/parquet/src/arrow/levels.rs index 7a26b05204f..4ea1811d29b 100644 --- a/rust/parquet/src/arrow/levels.rs +++ b/rust/parquet/src/arrow/levels.rs @@ -135,7 +135,8 @@ impl LevelInfo { | DataType::Duration(_) | DataType::Interval(_) | DataType::Binary - | DataType::LargeBinary => { + | DataType::LargeBinary + | DataType::Decimal(_, _) => { // we return a vector of 1 value to represent the primitive vec![self.calculate_child_levels( array_offsets, @@ -145,7 +146,6 @@ impl LevelInfo { )] } DataType::FixedSizeBinary(_) => unimplemented!(), - DataType::Decimal(_, _) => unimplemented!(), DataType::List(list_field) | DataType::LargeList(list_field) => { // Calculate the list level let list_level = self.calculate_child_levels( @@ -188,7 +188,8 @@ impl LevelInfo { | DataType::LargeBinary | DataType::Utf8 | DataType::LargeUtf8 - | DataType::Dictionary(_, _) => { + | DataType::Dictionary(_, _) + | DataType::Decimal(_, _) => { vec![list_level.calculate_child_levels( child_offsets, child_mask, @@ -197,7 +198,6 @@ impl LevelInfo { )] } DataType::FixedSizeBinary(_) => unimplemented!(), - DataType::Decimal(_, _) => unimplemented!(), DataType::List(_) | DataType::LargeList(_) | DataType::Struct(_) => { list_level.calculate_array_levels(&child_array, list_field) } diff --git a/rust/parquet/src/arrow/schema.rs b/rust/parquet/src/arrow/schema.rs index 97e04c5b948..fa973b5cc0e 100644 --- a/rust/parquet/src/arrow/schema.rs +++ b/rust/parquet/src/arrow/schema.rs @@ -306,6 +306,10 @@ pub fn parquet_to_arrow_field(parquet_column: &ColumnDescriptor) -> Result usize { + (10.0_f64.powi(precision as i32).log2() / 8.0).ceil() as usize +} + /// Convert an arrow field to a parquet `Type` fn arrow_to_parquet_type(field: &Field) -> Result { let name = field.name().as_str(); @@ -409,13 +413,15 @@ fn arrow_to_parquet_type(field: &Field) -> Result { .with_length(*length) .build() } - DataType::Decimal(precision, _) => Type::primitive_type_builder( - name, - PhysicalType::FIXED_LEN_BYTE_ARRAY, - ) - .with_repetition(repetition) - .with_length((10.0_f64.powi(*precision as i32).log2() / 8.0).ceil() as i32) - .build(), + DataType::Decimal(precision, scale) => { + Type::primitive_type_builder(name, PhysicalType::FIXED_LEN_BYTE_ARRAY) + .with_repetition(repetition) + .with_length(decimal_length_from_precision(*precision) as i32) + .with_logical_type(LogicalType::DECIMAL) + .with_precision(*precision as i32) + .with_scale(*scale as i32) + .build() + } DataType::Utf8 | DataType::LargeUtf8 => { Type::primitive_type_builder(name, PhysicalType::BYTE_ARRAY) .with_logical_type(LogicalType::UTF8)