diff --git a/arrow/src/csv/reader.rs b/arrow/src/csv/reader.rs
index 985c88b4978f..9fafc38a09ba 100644
--- a/arrow/src/csv/reader.rs
+++ b/arrow/src/csv/reader.rs
@@ -51,7 +51,9 @@ use std::sync::Arc;
 
 use csv as csv_crate;
 
-use crate::array::{ArrayRef, BooleanArray, PrimitiveArray, StringArray};
+use crate::array::{
+    ArrayRef, BooleanArray, DictionaryArray, PrimitiveArray, StringArray,
+};
 use crate::datatypes::*;
 use crate::error::{ArrowError, Result};
 use crate::record_batch::RecordBatch;
@@ -385,7 +387,7 @@ impl<R: Read> Iterator for Reader<R> {
                         "Error parsing line {}: {:?}",
                         self.line_number + i,
                         e
-                    ))))
+                    ))));
                 }
             }
         }
@@ -429,56 +431,102 @@ fn parse(
             let i = *i;
             let field = &fields[i];
             match field.data_type() {
-                &DataType::Boolean => build_boolean_array(line_number, rows, i),
-                &DataType::Int8 => {
-                    build_primitive_array::<Int8Type>(line_number, rows, i)
-                }
-                &DataType::Int16 => {
+                DataType::Boolean => build_boolean_array(line_number, rows, i),
+                DataType::Int8 => build_primitive_array::<Int8Type>(line_number, rows, i),
+                DataType::Int16 => {
                     build_primitive_array::<Int16Type>(line_number, rows, i)
                 }
-                &DataType::Int32 => {
+                DataType::Int32 => {
                     build_primitive_array::<Int32Type>(line_number, rows, i)
                 }
-                &DataType::Int64 => {
+                DataType::Int64 => {
                     build_primitive_array::<Int64Type>(line_number, rows, i)
                 }
-                &DataType::UInt8 => {
+                DataType::UInt8 => {
                     build_primitive_array::<UInt8Type>(line_number, rows, i)
                 }
-                &DataType::UInt16 => {
+                DataType::UInt16 => {
                     build_primitive_array::<UInt16Type>(line_number, rows, i)
                 }
-                &DataType::UInt32 => {
+                DataType::UInt32 => {
                     build_primitive_array::<UInt32Type>(line_number, rows, i)
                 }
-                &DataType::UInt64 => {
+                DataType::UInt64 => {
                     build_primitive_array::<UInt64Type>(line_number, rows, i)
                 }
-                &DataType::Float32 => {
+                DataType::Float32 => {
                     build_primitive_array::<Float32Type>(line_number, rows, i)
                 }
-                &DataType::Float64 => {
+                DataType::Float64 => {
                     build_primitive_array::<Float64Type>(line_number, rows, i)
                 }
-                &DataType::Date32 => {
+                DataType::Date32 => {
                     build_primitive_array::<Date32Type>(line_number, rows, i)
                 }
-                &DataType::Date64 => {
+                DataType::Date64 => {
                     build_primitive_array::<Date64Type>(line_number, rows, i)
                 }
-                &DataType::Timestamp(TimeUnit::Microsecond, _) => {
-                    build_primitive_array::<TimestampMicrosecondType>(
-                        line_number,
-                        rows,
-                        i,
-                    )
-                }
-                &DataType::Timestamp(TimeUnit::Nanosecond, _) => {
+                DataType::Timestamp(TimeUnit::Microsecond, _) => build_primitive_array::<
+                    TimestampMicrosecondType,
+                >(
+                    line_number, rows, i
+                ),
+                DataType::Timestamp(TimeUnit::Nanosecond, _) => {
                     build_primitive_array::<TimestampNanosecondType>(line_number, rows, i)
                 }
-                &DataType::Utf8 => Ok(Arc::new(
+                DataType::Utf8 => Ok(Arc::new(
                     rows.iter().map(|row| row.get(i)).collect::<StringArray>(),
                 ) as ArrayRef),
+                DataType::Dictionary(key_type, value_type)
+                    if value_type.as_ref() == &DataType::Utf8 =>
+                {
+                    match key_type.as_ref() {
+                        DataType::Int8 => Ok(Arc::new(
+                            rows.iter()
+                                .map(|row| row.get(i))
+                                .collect::<DictionaryArray<Int8Type>>(),
+                        ) as ArrayRef),
+                        DataType::Int16 => Ok(Arc::new(
+                            rows.iter()
+                                .map(|row| row.get(i))
+                                .collect::<DictionaryArray<Int16Type>>(),
+                        ) as ArrayRef),
+                        DataType::Int32 => Ok(Arc::new(
+                            rows.iter()
+                                .map(|row| row.get(i))
+                                .collect::<DictionaryArray<Int32Type>>(),
+                        ) as ArrayRef),
+                        DataType::Int64 => Ok(Arc::new(
+                            rows.iter()
+                                .map(|row| row.get(i))
+                                .collect::<DictionaryArray<Int64Type>>(),
+                        ) as ArrayRef),
+                        DataType::UInt8 => Ok(Arc::new(
+                            rows.iter()
+                                .map(|row| row.get(i))
+                                .collect::<DictionaryArray<UInt8Type>>(),
+                        ) as ArrayRef),
+                        DataType::UInt16 => Ok(Arc::new(
+                            rows.iter()
+                                .map(|row| row.get(i))
+                                .collect::<DictionaryArray<UInt16Type>>(),
+                        ) as ArrayRef),
+                        DataType::UInt32 => Ok(Arc::new(
+                            rows.iter()
+                                .map(|row| row.get(i))
+                                .collect::<DictionaryArray<UInt32Type>>(),
+                        ) as ArrayRef),
+                        DataType::UInt64 => Ok(Arc::new(
+                            rows.iter()
+                                .map(|row| row.get(i))
+                                .collect::<DictionaryArray<UInt64Type>>(),
+                        ) as ArrayRef),
+                        _ => Err(ArrowError::ParseError(format!(
+                            "Unsupported dictionary key type {:?}",
+                            key_type
+                        ))),
+                    }
+                }
                 other => Err(ArrowError::ParseError(format!(
                     "Unsupported data type {:?}",
                     other
@@ -510,6 +558,7 @@ impl Parser for Float32Type {
         lexical_core::parse(string.as_bytes()).ok()
     }
 }
+
 impl Parser for Float64Type {
     fn parse(string: &str) -> Option<f64> {
         lexical_core::parse(string.as_bytes()).ok()
@@ -814,6 +863,7 @@ mod tests {
     use tempfile::NamedTempFile;
 
     use crate::array::*;
+    use crate::compute::cast;
    use crate::datatypes::Field;
 
     #[test]
@@ -1021,6 +1071,51 @@ mod tests {
         assert_eq!(2, batch.num_columns());
     }
 
+    #[test]
+    fn test_csv_with_dictionary() {
+        let schema = Schema::new(vec![
+            Field::new(
+                "city",
+                DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
+                false,
+            ),
+            Field::new("lat", DataType::Float64, false),
+            Field::new("lng", DataType::Float64, false),
+        ]);
+
+        let file = File::open("test/data/uk_cities.csv").unwrap();
+
+        let mut csv = Reader::new(
+            file,
+            Arc::new(schema),
+            false,
+            None,
+            1024,
+            None,
+            Some(vec![0, 1]),
+        );
+        let projected_schema = Arc::new(Schema::new(vec![
+            Field::new(
+                "city",
+                DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
+                false,
+            ),
+            Field::new("lat", DataType::Float64, false),
+        ]));
+        assert_eq!(projected_schema, csv.schema());
+        let batch = csv.next().unwrap().unwrap();
+        assert_eq!(projected_schema, batch.schema());
+        assert_eq!(37, batch.num_rows());
+        assert_eq!(2, batch.num_columns());
+
+        let strings = cast(batch.column(0), &DataType::Utf8).unwrap();
+        let strings = strings.as_any().downcast_ref::<StringArray>().unwrap();
+
+        assert_eq!(strings.value(0), "Elgin, Scotland, the UK");
+        assert_eq!(strings.value(4), "Eastbourne, East Sussex, UK");
+        assert_eq!(strings.value(29), "Uckfield, East Sussex, UK");
+    }
+
     #[test]
     fn test_nulls() {
         let schema = Schema::new(vec![