diff --git a/rust/arrow/src/util/pretty.rs b/rust/arrow/src/util/pretty.rs index dc564ddb4aa..3c23e28f7ed 100644 --- a/rust/arrow/src/util/pretty.rs +++ b/rust/arrow/src/util/pretty.rs @@ -19,9 +19,13 @@ use crate::array; use crate::array::{Array, PrimitiveArrayOps}; -use crate::datatypes::{DataType, TimeUnit}; +use crate::datatypes::{ + ArrowNativeType, ArrowPrimitiveType, DataType, Int16Type, Int32Type, Int64Type, + Int8Type, TimeUnit, UInt16Type, UInt32Type, UInt64Type, UInt8Type, +}; use crate::record_batch::RecordBatch; +use array::DictionaryArray; use prettytable::format; use prettytable::{Cell, Row, Table}; @@ -60,7 +64,7 @@ fn create_table(results: &[RecordBatch]) -> Result { let mut cells = Vec::new(); for col in 0..batch.num_columns() { let column = batch.column(col); - cells.push(Cell::new(&array_value_to_string(column.clone(), row)?)); + cells.push(Cell::new(&array_value_to_string(&column, row)?)); } table.add_row(Row::new(cells)); } @@ -83,8 +87,8 @@ macro_rules! make_string { }}; } -/// Get the value at the given row in an array as a string -fn array_value_to_string(column: array::ArrayRef, row: usize) -> Result { +/// Get the value at the given row in an array as a String +pub fn array_value_to_string(column: &array::ArrayRef, row: usize) -> Result { match column.data_type() { DataType::Utf8 => make_string!(array::StringArray, column, row), DataType::Boolean => make_string!(array::BooleanArray, column, row), @@ -124,15 +128,55 @@ fn array_value_to_string(column: array::ArrayRef, row: usize) -> Result DataType::Time64(unit) if *unit == TimeUnit::Nanosecond => { make_string!(array::Time64NanosecondArray, column, row) } + DataType::Dictionary(index_type, _value_type) => match **index_type { + DataType::Int8 => dict_array_value_to_string::(column, row), + DataType::Int16 => dict_array_value_to_string::(column, row), + DataType::Int32 => dict_array_value_to_string::(column, row), + DataType::Int64 => dict_array_value_to_string::(column, row), + DataType::UInt8 => dict_array_value_to_string::(column, row), + DataType::UInt16 => dict_array_value_to_string::(column, row), + DataType::UInt32 => dict_array_value_to_string::(column, row), + DataType::UInt64 => dict_array_value_to_string::(column, row), + _ => Err(ArrowError::InvalidArgumentError(format!( + "Pretty printing not supported for {:?} due to index type", + column.data_type() + ))), + }, _ => Err(ArrowError::InvalidArgumentError(format!( - "Unsupported {:?} type for repl.", + "Pretty printing not implemented for {:?} type", column.data_type() ))), } } +/// Converts the value of the dictionary array at `row` to a String +fn dict_array_value_to_string( + colum: &array::ArrayRef, + row: usize, +) -> Result { + let dict_array = colum.as_any().downcast_ref::>().unwrap(); + + let keys_array = dict_array.keys_array(); + + if keys_array.is_null(row) { + return Ok(String::from("")); + } + + let dict_index = keys_array.value(row).to_usize().ok_or_else(|| { + ArrowError::InvalidArgumentError(format!( + "Can not convert value {:?} at index {:?} to usize for repl.", + keys_array.value(row), + row + )) + })?; + + array_value_to_string(&dict_array.values(), dict_index) +} + #[cfg(test)] mod tests { + use array::{PrimitiveBuilder, StringBuilder, StringDictionaryBuilder}; + use super::*; use crate::datatypes::{Field, Schema}; use std::sync::Arc; @@ -183,4 +227,41 @@ mod tests { Ok(()) } + + #[test] + fn test_pretty_format_dictionary() -> Result<()> { + // define a schema. + let field_type = + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); + let schema = Arc::new(Schema::new(vec![Field::new("d1", field_type, true)])); + + let keys_builder = PrimitiveBuilder::::new(10); + let values_builder = StringBuilder::new(10); + let mut builder = StringDictionaryBuilder::new(keys_builder, values_builder); + + builder.append("one")?; + builder.append_null()?; + builder.append("three")?; + let array = Arc::new(builder.finish()); + + let batch = RecordBatch::try_new(schema.clone(), vec![array])?; + + let table = pretty_format_batches(&[batch])?; + + let expected = vec![ + "+-------+", + "| d1 |", + "+-------+", + "| one |", + "| |", + "| three |", + "+-------+", + ]; + + let actual: Vec<&str> = table.lines().collect(); + + assert_eq!(expected, actual, "Actual result:\n{}", table); + + Ok(()) + } }