Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 15 additions & 14 deletions rust/arrow/src/compute/kernels/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -378,20 +378,21 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result<ArrayRef> {
Date32(DateUnit::Day) => {
use chrono::Datelike;
let string_array = array.as_any().downcast_ref::<StringArray>().unwrap();
let mut builder = PrimitiveBuilder::<Date32Type>::new(string_array.len());
for i in 0..string_array.len() {
if string_array.is_null(i) {
builder.append_null()?;
} else {
match string_array.value(i).parse::<chrono::NaiveDate>() {
Ok(date) => builder.append_value(
date.num_days_from_ce() - EPOCH_DAYS_FROM_CE,
)?,
Err(_) => builder.append_null()?, // not a valid date
};
}
}
Ok(Arc::new(builder.finish()) as ArrayRef)
let array = (0..string_array.len())
.map(|i| {
if string_array.is_null(i) {
None
} else {
match string_array.value(i).parse::<chrono::NaiveDate>() {
Ok(date) => {
Some(date.num_days_from_ce() - EPOCH_DAYS_FROM_CE)
}
Err(_) => None, // not a valid date
}
}
})
.collect::<Date32Array>();
Ok(Arc::new(array) as ArrayRef)
}
Date64(DateUnit::Millisecond) => {
let string_array = array.as_any().downcast_ref::<StringArray>().unwrap();
Expand Down
236 changes: 22 additions & 214 deletions rust/arrow/src/csv/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,13 @@ use std::sync::Arc;

use csv as csv_crate;

use crate::array::{ArrayRef, BooleanArray, PrimitiveArray, StringArray};
use crate::datatypes::*;
use crate::error::{ArrowError, Result};
use crate::record_batch::RecordBatch;
use crate::{
array::{ArrayRef, BooleanArray, StringArray},
compute::cast,
};

use self::csv_crate::{ByteRecord, StringRecord};

Expand Down Expand Up @@ -413,51 +416,24 @@ fn parse(
let field = &fields[i];
match field.data_type() {
&DataType::Boolean => build_boolean_array(line_number, rows, i),
&DataType::Int8 => {
build_primitive_array::<Int8Type>(line_number, rows, i)
}
&DataType::Int16 => {
build_primitive_array::<Int16Type>(line_number, rows, i)
}
&DataType::Int32 => {
build_primitive_array::<Int32Type>(line_number, rows, i)
}
&DataType::Int64 => {
build_primitive_array::<Int64Type>(line_number, rows, i)
}
&DataType::UInt8 => {
build_primitive_array::<UInt8Type>(line_number, rows, i)
}
&DataType::UInt16 => {
build_primitive_array::<UInt16Type>(line_number, rows, i)
}
&DataType::UInt32 => {
build_primitive_array::<UInt32Type>(line_number, rows, i)
}
&DataType::UInt64 => {
build_primitive_array::<UInt64Type>(line_number, rows, i)
}
&DataType::Float32 => {
build_primitive_array::<Float32Type>(line_number, rows, i)
}
&DataType::Float64 => {
build_primitive_array::<Float64Type>(line_number, rows, i)
}
&DataType::Date32(_) => {
build_primitive_array::<Date32Type>(line_number, rows, i)
}
&DataType::Date64(_) => {
build_primitive_array::<Date64Type>(line_number, rows, i)
}
&DataType::Timestamp(TimeUnit::Microsecond, _) => {
build_primitive_array::<TimestampMicrosecondType>(
line_number,
rows,
i,
)
}
&DataType::Timestamp(TimeUnit::Nanosecond, _) => {
build_primitive_array::<TimestampNanosecondType>(line_number, rows, i)
&DataType::Int8
| &DataType::Int16
| &DataType::Int32
| &DataType::Int64
| &DataType::UInt8
| &DataType::UInt16
| &DataType::UInt32
| &DataType::UInt64
| &DataType::Timestamp(TimeUnit::Nanosecond, _)
| &DataType::Timestamp(TimeUnit::Microsecond, _)
| &DataType::Date64(_)
| &DataType::Date32(_)
| &DataType::Float64
| &DataType::Float32 => {
let string_array = Arc::new(
rows.iter().map(|row| row.get(i)).collect::<StringArray>(),
) as ArrayRef;
cast(&string_array, field.data_type())
}
&DataType::Utf8 => Ok(Arc::new(
rows.iter().map(|row| row.get(i)).collect::<StringArray>(),
Expand All @@ -478,97 +454,6 @@ fn parse(
arrays.and_then(|arr| RecordBatch::try_new(projected_schema, arr))
}

/// Text-to-native-value parsing for Arrow primitive types.
///
/// The provided `parse` relies on the native type's own string parsing and
/// maps any failure to `None`; implementors override it when a type needs a
/// specialized parser.
trait Parser: ArrowPrimitiveType {
    fn parse(string: &str) -> Option<Self::Native> {
        str::parse::<Self::Native>(string).ok()
    }
}

impl Parser for Float32Type {
    /// Parses an `f32` from the raw bytes of `string` using `lexical_core`;
    /// any parse error is reported as `None`.
    fn parse(string: &str) -> Option<f32> {
        let bytes = string.as_bytes();
        lexical_core::parse::<f32>(bytes).ok()
    }
}
impl Parser for Float64Type {
    /// Parses an `f64` from the raw bytes of `string` using `lexical_core`;
    /// any parse error is reported as `None`.
    fn parse(string: &str) -> Option<f64> {
        let bytes = string.as_bytes();
        lexical_core::parse::<f64>(bytes).ok()
    }
}

// The integer types below have empty impl bodies, so they all use the
// `Parser` trait's provided `parse` method without any specialization.
impl Parser for UInt64Type {}

impl Parser for UInt32Type {}

impl Parser for UInt16Type {}

impl Parser for UInt8Type {}

impl Parser for Int64Type {}

impl Parser for Int32Type {}

impl Parser for Int16Type {}

impl Parser for Int8Type {}

/// Value of `chrono::Datelike::num_days_from_ce()` for 1970-01-01 (the Unix epoch).
///
/// Subtracting this from a date's `num_days_from_ce()` yields the number of
/// days since the epoch. Note that chrono counts 0001-01-01 as day 1, so the
/// epoch is 719_162 days *after* 0001-01-01.
const EPOCH_DAYS_FROM_CE: i32 = 719_163;

impl Parser for Date32Type {
    /// Parses a calendar date (whatever `chrono::NaiveDate`'s `FromStr`
    /// accepts) into the number of days relative to 1970-01-01.
    ///
    /// Returns `None` for text that is not a valid date, or if this type's
    /// `DATA_TYPE` is not `Date32(DateUnit::Day)`.
    fn parse(string: &str) -> Option<i32> {
        use chrono::Datelike;

        if let DataType::Date32(DateUnit::Day) = Self::DATA_TYPE {
            let date = string.parse::<chrono::NaiveDate>().ok()?;
            // Shift from chrono's common-era day count to an epoch-relative one.
            let days_since_epoch = date.num_days_from_ce() - EPOCH_DAYS_FROM_CE;
            Self::Native::from_i32(days_since_epoch)
        } else {
            None
        }
    }
}

impl Parser for Date64Type {
    /// Parses a timezone-less date-time (whatever `chrono::NaiveDateTime`'s
    /// `FromStr` accepts) into its `timestamp_millis()` value.
    ///
    /// Returns `None` for text that is not a valid date-time, or if this
    /// type's `DATA_TYPE` is not `Date64(DateUnit::Millisecond)`.
    fn parse(string: &str) -> Option<i64> {
        if let DataType::Date64(DateUnit::Millisecond) = Self::DATA_TYPE {
            let millis = string
                .parse::<chrono::NaiveDateTime>()
                .ok()?
                .timestamp_millis();
            Self::Native::from_i64(millis)
        } else {
            None
        }
    }
}

impl Parser for TimestampNanosecondType {
    /// Parses a timezone-less date-time (whatever `chrono::NaiveDateTime`'s
    /// `FromStr` accepts) into nanoseconds relative to the Unix epoch.
    ///
    /// Returns `None` when:
    /// * the text is not a valid date-time,
    /// * the instant cannot be represented as an `i64` nanosecond count, or
    /// * this type's `DATA_TYPE` is not a timezone-less nanosecond timestamp.
    fn parse(string: &str) -> Option<i64> {
        match Self::DATA_TYPE {
            DataType::Timestamp(TimeUnit::Nanosecond, None) => {
                let date_time = string.parse::<chrono::NaiveDateTime>().ok()?;
                // `NaiveDateTime::timestamp_nanos()` overflows `i64` for dates
                // outside roughly 1677..2262 (panicking or wrapping depending on
                // the build). Use checked arithmetic so such input is rejected
                // as unparseable instead of crashing or yielding a bogus value.
                let nanos = date_time
                    .timestamp()
                    .checked_mul(1_000_000_000)?
                    .checked_add(i64::from(date_time.timestamp_subsec_nanos()))?;
                Self::Native::from_i64(nanos)
            }
            _ => None,
        }
    }
}

impl Parser for TimestampMicrosecondType {
    /// Parses a timezone-less date-time (whatever `chrono::NaiveDateTime`'s
    /// `FromStr` accepts) into microseconds relative to the Unix epoch.
    ///
    /// Returns `None` when the text is not a valid date-time, or when this
    /// type's `DATA_TYPE` is not a timezone-less microsecond timestamp.
    fn parse(string: &str) -> Option<i64> {
        match Self::DATA_TYPE {
            DataType::Timestamp(TimeUnit::Microsecond, None) => {
                let date_time = string.parse::<chrono::NaiveDateTime>().ok()?;
                // Build the value from whole seconds plus sub-second
                // microseconds rather than `timestamp_nanos() / 1000`: the
                // nanosecond intermediate overflows `i64` for dates outside
                // roughly 1677..2262 even though the microsecond result fits.
                // Like `timestamp_millis()`, this floors sub-microsecond
                // fractions (including for instants before the epoch).
                let micros = date_time
                    .timestamp()
                    .checked_mul(1_000_000)?
                    .checked_add(i64::from(date_time.timestamp_subsec_micros()))?;
                Self::Native::from_i64(micros)
            }
            _ => None,
        }
    }
}

fn parse_item<T: Parser>(string: &str) -> Option<T::Native> {
T::parse(string)
}

fn parse_bool(string: &str) -> Option<bool> {
if string.eq_ignore_ascii_case("false") {
Some(false)
Expand All @@ -579,40 +464,6 @@ fn parse_bool(string: &str) -> Option<bool> {
}
}

// Builds the Arrow array for column `col_idx` by parsing that field of every
// row as a value of type `T`.
//
// A field that is missing or empty becomes a null entry. The first field that
// fails to parse aborts the whole column with an `ArrowError::ParseError`
// naming the value, the column index and the line (`line_number` plus the
// row's offset within `rows`).
fn build_primitive_array<T: ArrowPrimitiveType + Parser>(
    line_number: usize,
    rows: &[StringRecord],
    col_idx: usize,
) -> Result<ArrayRef> {
    let column = rows
        .iter()
        .enumerate()
        .map(|(row_index, row)| {
            // Absent fields and empty strings are both treated as null.
            let text = match row.get(col_idx) {
                Some(text) if !text.is_empty() => text,
                _ => return Ok(None),
            };

            parse_item::<T>(text).map(Some).ok_or_else(|| {
                ArrowError::ParseError(format!(
                    // TODO: we should surface the underlying error here.
                    "Error while parsing value {} for column {} at line {}",
                    text,
                    col_idx,
                    line_number + row_index
                ))
            })
        })
        .collect::<Result<PrimitiveArray<T>>>()?;

    Ok(Arc::new(column) as ArrayRef)
}

// parses a specific column (col_idx) into an Arrow Array.
fn build_boolean_array(
line_number: usize,
Expand Down Expand Up @@ -1098,30 +949,6 @@ mod tests {
);
}

#[test]
fn parse_date32() {
    // (input text, expected days relative to 1970-01-01)
    let cases: [(&str, i32); 3] = [
        ("1970-01-01", 0),
        ("2020-03-15", 18336),
        ("1945-05-08", -9004),
    ];

    for (text, expected) in cases.iter() {
        assert_eq!(parse_item::<Date32Type>(text).unwrap(), *expected);
    }
}

#[test]
fn parse_date64() {
    // (input text, expected milliseconds relative to the Unix epoch)
    let cases: [(&str, i64); 4] = [
        ("1970-01-01T00:00:00", 0),
        ("2018-11-13T17:11:10", 1542129070000),
        ("2018-11-13T17:11:10.011", 1542129070011),
        ("1900-02-28T12:34:56", -2203932304000),
    ];

    for (text, expected) in cases.iter() {
        assert_eq!(parse_item::<Date64Type>(text).unwrap(), *expected);
    }
}

#[test]
fn test_infer_schema_from_multiple_files() -> Result<()> {
let mut csv1 = NamedTempFile::new()?;
Expand Down Expand Up @@ -1230,23 +1057,4 @@ mod tests {
assert_eq!(None, parse_bool("F"));
assert_eq!(None, parse_bool(""));
}

#[test]
fn test_parsing_float() {
    let parse = |text: &str| parse_item::<Float64Type>(text);

    // Ordinary decimal and integer literals.
    assert_eq!(Some(12.34), parse("12.34"));
    assert_eq!(Some(-12.34), parse("-12.34"));
    assert_eq!(Some(12.0), parse("12"));
    assert_eq!(Some(0.0), parse("0"));

    // Special values: NaN in either casing, and signed infinities.
    assert!(parse("nan").unwrap().is_nan());
    assert!(parse("NaN").unwrap().is_nan());
    assert!(parse("inf").unwrap().is_infinite());
    assert!(parse("inf").unwrap().is_sign_positive());
    assert!(parse("-inf").unwrap().is_infinite());
    assert!(parse("-inf").unwrap().is_sign_negative());

    // Inputs that must be rejected.
    assert_eq!(None, parse(""));
    assert_eq!(None, parse("dd"));
    assert_eq!(None, parse("12.34.56"));
}
}