diff --git a/rust/arrow/src/array/array_binary.rs b/rust/arrow/src/array/array_binary.rs
index a8fca67197c..fc8cf0ae9ec 100644
--- a/rust/arrow/src/array/array_binary.rs
+++ b/rust/arrow/src/array/array_binary.rs
@@ -572,6 +572,13 @@ impl DecimalArray {
         let data = builder.build();
         Self::from(data)
     }
+    pub fn precision(&self) -> usize {
+        self.precision
+    }
+
+    pub fn scale(&self) -> usize {
+        self.scale
+    }
 }
 
 impl From<ArrayDataRef> for DecimalArray {
diff --git a/rust/parquet/src/arrow/array_reader.rs b/rust/parquet/src/arrow/array_reader.rs
index c0e05d82bbf..18fb478b70c 100644
--- a/rust/parquet/src/arrow/array_reader.rs
+++ b/rust/parquet/src/arrow/array_reader.rs
@@ -25,10 +25,10 @@ use std::vec::Vec;
 
 use arrow::array::{
     Array, ArrayData, ArrayDataBuilder, ArrayDataRef, ArrayRef, BinaryArray,
-    BinaryBuilder, BooleanArray, BooleanBufferBuilder, FixedSizeBinaryArray,
-    FixedSizeBinaryBuilder, GenericListArray, Int16BufferBuilder, ListBuilder,
-    OffsetSizeTrait, PrimitiveArray, PrimitiveBuilder, StringArray, StringBuilder,
-    StructArray,
+    BinaryBuilder, BooleanArray, BooleanBufferBuilder, DecimalBuilder,
+    FixedSizeBinaryArray, FixedSizeBinaryBuilder, GenericListArray, Int16BufferBuilder,
+    Int32Array, Int64Array, ListBuilder, OffsetSizeTrait, PrimitiveArray,
+    PrimitiveBuilder, StringArray, StringBuilder, StructArray,
 };
 use arrow::buffer::{Buffer, MutableBuffer};
 use arrow::datatypes::{
@@ -350,6 +350,36 @@ impl<T: DataType> ArrayReader for PrimitiveArrayReader<T> {
                 let a = arrow::compute::cast(&array, &ArrowType::Date32(DateUnit::Day))?;
                 arrow::compute::cast(&a, target_type)?
             }
+            ArrowType::Decimal(p, s) => {
+                let mut builder = DecimalBuilder::new(array.len(), *p, *s);
+                match array.data_type() {
+                    ArrowType::Int32 => {
+                        let values = array.as_any().downcast_ref::<Int32Array>().unwrap();
+                        for maybe_value in values.iter() {
+                            match maybe_value {
+                                Some(value) => builder.append_value(value as i128)?,
+                                None => builder.append_null()?,
+                            }
+                        }
+                    }
+                    ArrowType::Int64 => {
+                        let values = array.as_any().downcast_ref::<Int64Array>().unwrap();
+                        for maybe_value in values.iter() {
+                            match maybe_value {
+                                Some(value) => builder.append_value(value as i128)?,
+                                None => builder.append_null()?,
+                            }
+                        }
+                    }
+                    _ => {
+                        return Err(ArrowError(format!(
+                            "Cannot convert {:?} to decimal",
+                            array.data_type()
+                        )))
+                    }
+                }
+                Arc::new(builder.finish()) as ArrayRef
+            }
             _ => arrow::compute::cast(&array, target_type)?,
         };
 
@@ -1550,20 +1580,10 @@ impl<'a> ArrayReaderBuilder {
             PhysicalType::FIXED_LEN_BYTE_ARRAY
                 if cur_type.get_basic_info().logical_type() == LogicalType::DECIMAL =>
            {
-                let (precision, scale) = match *cur_type {
-                    Type::PrimitiveType {
-                        ref precision,
-                        ref scale,
-                        ..
-                    } => (*precision, *scale),
-                    _ => {
-                        return Err(ArrowError(
-                            "Expected a physical type, not a group type".to_string(),
-                        ))
-                    }
-                };
-                let converter =
-                    DecimalConverter::new(DecimalArrayConverter::new(precision, scale));
+                let converter = DecimalConverter::new(DecimalArrayConverter::new(
+                    cur_type.get_precision(),
+                    cur_type.get_scale(),
+                ));
                 Ok(Box::new(ComplexObjectArrayReader::<
                     FixedLenByteArrayType,
                     DecimalConverter,
diff --git a/rust/parquet/src/arrow/arrow_reader.rs b/rust/parquet/src/arrow/arrow_reader.rs
index 304ba18bc37..1559c97e4cf 100644
--- a/rust/parquet/src/arrow/arrow_reader.rs
+++ b/rust/parquet/src/arrow/arrow_reader.rs
@@ -406,25 +406,31 @@ mod tests {
     fn test_read_decimal_file() {
         use arrow::array::DecimalArray;
         let testdata = arrow::util::test_util::parquet_test_data();
-        let path = format!("{}/fixed_length_decimal.parquet", testdata);
-        let parquet_reader =
-            SerializedFileReader::try_from(File::open(&path).unwrap()).unwrap();
-        let mut arrow_reader = ParquetFileArrowReader::new(Arc::new(parquet_reader));
-
-        let mut record_reader = arrow_reader.get_record_reader(32).unwrap();
-
-        let batch = record_reader.next().unwrap().unwrap();
-        assert_eq!(batch.num_rows(), 24);
-        let col = batch
-            .column(0)
-            .as_any()
-            .downcast_ref::<DecimalArray>()
-            .unwrap();
-
-        let expected = 1..25;
-
-        for (i, v) in expected.enumerate() {
-            assert_eq!(col.value(i), v * 100_i128);
+        let file_variants = vec![("fixed_length", 25), ("int32", 4), ("int64", 10)];
+        for (prefix, target_precision) in file_variants {
+            let path = format!("{}/{}_decimal.parquet", testdata, prefix);
+            let parquet_reader =
+                SerializedFileReader::try_from(File::open(&path).unwrap()).unwrap();
+            let mut arrow_reader = ParquetFileArrowReader::new(Arc::new(parquet_reader));
+
+            let mut record_reader = arrow_reader.get_record_reader(32).unwrap();
+
+            let batch = record_reader.next().unwrap().unwrap();
+            assert_eq!(batch.num_rows(), 24);
+            let col = batch
+                .column(0)
+                .as_any()
+                .downcast_ref::<DecimalArray>()
+                .unwrap();
+
+            let expected = 1..25;
+
+            assert_eq!(col.precision(), target_precision);
+            assert_eq!(col.scale(), 2);
+
+            for (i, v) in expected.enumerate() {
+                assert_eq!(col.value(i), v * 100_i128);
+            }
         }
     }
 
diff --git a/rust/parquet/src/arrow/schema.rs b/rust/parquet/src/arrow/schema.rs
index 22213d4f0db..fc5861af1b2 100644
--- a/rust/parquet/src/arrow/schema.rs
+++ b/rust/parquet/src/arrow/schema.rs
@@ -591,6 +591,7 @@ impl ParquetTypeConverter<'_> {
             LogicalType::INT_32 => Ok(DataType::Int32),
             LogicalType::DATE => Ok(DataType::Date32(DateUnit::Day)),
             LogicalType::TIME_MILLIS => Ok(DataType::Time32(TimeUnit::Millisecond)),
+            LogicalType::DECIMAL => Ok(self.to_decimal()),
             other => Err(ArrowError(format!(
                 "Unable to convert parquet INT32 logical type {}",
                 other
@@ -610,6 +611,7 @@ impl ParquetTypeConverter<'_> {
             LogicalType::TIMESTAMP_MICROS => {
                 Ok(DataType::Timestamp(TimeUnit::Microsecond, None))
             }
+            LogicalType::DECIMAL => Ok(self.to_decimal()),
             other => Err(ArrowError(format!(
                 "Unable to convert parquet INT64 logical type {}",
                 other
@@ -619,21 +621,7 @@
 
     fn from_fixed_len_byte_array(&self) -> Result<DataType> {
         match self.schema.get_basic_info().logical_type() {
-            LogicalType::DECIMAL => {
-                let (precision, scale) = match self.schema {
-                    Type::PrimitiveType {
-                        ref precision,
-                        ref scale,
-                        ..
-                    } => (*precision, *scale),
-                    _ => {
-                        return Err(ArrowError(
-                            "Expected a physical type, not a group type".to_string(),
-                        ))
-                    }
-                };
-                Ok(DataType::Decimal(precision as usize, scale as usize))
-            }
+            LogicalType::DECIMAL => Ok(self.to_decimal()),
             LogicalType::INTERVAL => {
                 // There is currently no reliable way of determining which IntervalUnit
                 // to return. Thus without the original Arrow schema, the results
@@ -657,6 +645,14 @@
         }
     }
 
+    fn to_decimal(&self) -> DataType {
+        assert!(self.schema.is_primitive());
+        DataType::Decimal(
+            self.schema.get_precision() as usize,
+            self.schema.get_scale() as usize,
+        )
+    }
+
     fn from_byte_array(&self) -> Result<DataType> {
         match self.schema.get_basic_info().logical_type() {
             LogicalType::NONE => Ok(DataType::Binary),
diff --git a/rust/parquet/src/schema/types.rs b/rust/parquet/src/schema/types.rs
index c9eeaa0f901..27768fbb63e 100644
--- a/rust/parquet/src/schema/types.rs
+++ b/rust/parquet/src/schema/types.rs
@@ -103,6 +103,24 @@ impl Type {
         }
     }
 
+    /// Gets precision of this primitive type.
+    /// Note that this will panic if called on a non-primitive type.
+    pub fn get_precision(&self) -> i32 {
+        match *self {
+            Type::PrimitiveType { precision, .. } => precision,
+            _ => panic!("Cannot call get_precision() on non-primitive type"),
+        }
+    }
+
+    /// Gets scale of this primitive type.
+    /// Note that this will panic if called on a non-primitive type.
+    pub fn get_scale(&self) -> i32 {
+        match *self {
+            Type::PrimitiveType { scale, .. } => scale,
+            _ => panic!("Cannot call get_scale() on non-primitive type"),
+        }
+    }
+
     /// Checks if `sub_type` schema is part of current schema.
     /// This method can be used to check if projected columns are part of the root schema.
     pub fn check_contains(&self, sub_type: &Type) -> bool {
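For reference, a minimal standalone sketch (not part of the diff) of the INT32-to-decimal conversion performed by the new `ArrowType::Decimal` branch in `PrimitiveArrayReader`, exercised together with the `DecimalArray::precision()`/`scale()` accessors added above. The helper name `int32_to_decimal` and the sample values are illustrative only; the sketch assumes the `arrow` crate from this tree with the changes above applied.

// Sketch only: `int32_to_decimal` and the sample data are illustrative, not part of the PR.
use std::sync::Arc;

use arrow::array::{Array, ArrayRef, DecimalArray, DecimalBuilder, Int32Array};
use arrow::error::Result;

/// Converts an INT32-backed Parquet decimal column (already read as an
/// `Int32Array` of unscaled values) into a `DecimalArray`, mirroring the
/// `ArrowType::Decimal` branch added to the primitive array reader.
fn int32_to_decimal(values: &Int32Array, precision: usize, scale: usize) -> Result<ArrayRef> {
    let mut builder = DecimalBuilder::new(values.len(), precision, scale);
    for maybe_value in values.iter() {
        match maybe_value {
            // The i32 holds the unscaled value, so widening to i128 is lossless.
            Some(value) => builder.append_value(value as i128)?,
            None => builder.append_null()?,
        }
    }
    Ok(Arc::new(builder.finish()) as ArrayRef)
}

fn main() -> Result<()> {
    // Unscaled values 100, NULL, 300 with scale 2 represent 1.00, NULL, 3.00.
    let ints = Int32Array::from(vec![Some(100), None, Some(300)]);
    let decimals = int32_to_decimal(&ints, 4, 2)?;
    let decimals = decimals.as_any().downcast_ref::<DecimalArray>().unwrap();

    // The new accessors expose the decimal type parameters.
    assert_eq!(decimals.precision(), 4);
    assert_eq!(decimals.scale(), 2);
    assert_eq!(decimals.value(0), 100_i128);
    assert!(decimals.is_null(1));
    Ok(())
}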