Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use arrow::array::{Array, ArrayRef, Int64Array, LargeStringArray, StringArray, U
use arrow_schema::DataType;
use datafusion_common::{exec_err, plan_err, Result as DataFusionResult, ScalarValue};
use datafusion_expr::ColumnarValue;
use jiter::{Jiter, JiterError, Peek};
use jiter::{Jiter, JiterError, JsonError, Peek};

use crate::common_union::{is_json_union, json_from_union_scalar, nested_json_array};

Expand Down Expand Up @@ -230,6 +230,12 @@ impl From<JiterError> for GetError {
}
}

impl From<JsonError> for GetError {
fn from(_: JsonError) -> Self {
GetError
}
}

impl From<Utf8Error> for GetError {
fn from(_: Utf8Error) -> Self {
GetError
Expand Down
14 changes: 8 additions & 6 deletions src/json_get_float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ impl ScalarUDFImpl for JsonGetFloat {

fn jiter_json_get_float(json_data: Option<&str>, path: &[JsonPath]) -> Result<f64, GetError> {
if let Some((mut jiter, peek)) = jiter_json_find(json_data, path) {
match peek {
let n = match peek {
// Peek::String => NumberAny::try_from(jiter.next_bytes()?)?,
// numbers are represented by everything else in peek, hence doing it this way
Peek::Null
| Peek::True
Expand All @@ -75,11 +76,12 @@ fn jiter_json_get_float(json_data: Option<&str>, path: &[JsonPath]) -> Result<f6
| Peek::NaN
| Peek::String
| Peek::Array
| Peek::Object => get_err!(),
_ => match jiter.known_number(peek)? {
NumberAny::Float(f) => Ok(f),
NumberAny::Int(int) => Ok(int.into()),
},
| Peek::Object => return get_err!(),
_ => jiter.known_number(peek)?,
};
match n {
NumberAny::Float(f) => Ok(f),
NumberAny::Int(int) => Ok(int.into()),
}
} else {
get_err!()
Expand Down
15 changes: 8 additions & 7 deletions src/json_get_int.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,21 +65,22 @@ impl ScalarUDFImpl for JsonGetInt {

fn jiter_json_get_int(json_data: Option<&str>, path: &[JsonPath]) -> Result<i64, GetError> {
if let Some((mut jiter, peek)) = jiter_json_find(json_data, path) {
match peek {
let n = match peek {
Peek::String => NumberInt::try_from(jiter.next_bytes()?)?,
// numbers are represented by everything else in peek, hence doing it this way
Peek::Null
| Peek::True
| Peek::False
| Peek::Minus
| Peek::Infinity
| Peek::NaN
| Peek::String
| Peek::Array
| Peek::Object => get_err!(),
_ => match jiter.known_int(peek)? {
NumberInt::Int(i) => Ok(i),
NumberInt::BigInt(_) => get_err!(),
},
| Peek::Object => return get_err!(),
_ => jiter.known_int(peek)?,
};
match n {
NumberInt::Int(i) => Ok(i),
NumberInt::BigInt(_) => get_err!(),
}
} else {
get_err!()
Expand Down
2 changes: 1 addition & 1 deletion src/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ fn unnest_json_calls(func: &ScalarFunction) -> Option<Transformed<Expr>> {
fn extract_scalar_function(expr: &Expr) -> Option<&ScalarFunction> {
match expr {
Expr::ScalarFunction(func) => Some(func),
Expr::Alias(alias) => extract_scalar_function(&*alias.expr),
Expr::Alias(alias) => extract_scalar_function(&alias.expr),
_ => None,
}
}
Expand Down
22 changes: 22 additions & 0 deletions tests/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1131,8 +1131,30 @@ async fn test_long_arrow_cast() {
assert_batches_eq!(expected, &batches);
}

#[tokio::test]
async fn test_arrow_cast_numeric() {
let sql = r#"select ('{"foo": 420}'->'foo')::numeric = 420"#;
let batches = run_query(sql).await.unwrap();
assert_eq!(display_val(batches).await, (DataType::Boolean, "true".to_string()));
}

#[tokio::test]
async fn test_json_get_int_string() {
let sql = r#"select json_get_int('{"foo": "420"}'->'foo')"#;
let batches = run_query(sql).await.unwrap();
assert_eq!(display_val(batches).await, (DataType::Int64, "420".to_string()));
}

// #[tokio::test]
// async fn test_json_get_float_string() {
// let sql = r#"select json_get_float('{"foo": "420.123"}'->'foo')"#;
// let batches = run_query(sql).await.unwrap();
// assert_eq!(display_val(batches).await, (DataType::Int64, "420.123".to_string()));
// }
//
// #[tokio::test]
// async fn test_json_get_float_string_2() {
// let sql = r#"select json_get_float('{"foo": "420"}'->'foo')"#;
// let batches = run_query(sql).await.unwrap();
// assert_eq!(display_val(batches).await, (DataType::Int64, "420".to_string()));
// }