Skip to content

Commit

Permalink
Replace GetFieldAccess with indexing function in SqlToRel (#10375)
Browse files Browse the repository at this point in the history
* use func in parser

Signed-off-by: jayzhan211 <[email protected]>

* add tests

Signed-off-by: jayzhan211 <[email protected]>

* add test

Signed-off-by: jayzhan211 <[email protected]>

* rm test1

Signed-off-by: jayzhan211 <[email protected]>

* parser done

Signed-off-by: jayzhan211 <[email protected]>

* fmt

Signed-off-by: jayzhan211 <[email protected]>

* fix exprapi test

Signed-off-by: jayzhan211 <[email protected]>

* fix test

Signed-off-by: jayzhan211 <[email protected]>

* fix conflicts

Signed-off-by: jayzhan211 <[email protected]>

---------

Signed-off-by: jayzhan211 <[email protected]>
  • Loading branch information
jayzhan211 authored May 14, 2024
1 parent 18fc376 commit b8fab5c
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 50 deletions.
14 changes: 4 additions & 10 deletions datafusion/core/tests/expr_api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,8 @@ fn test_eq_with_coercion() {

#[test]
fn test_get_field() {
// field access Expr::field() requires a rewrite to work
evaluate_expr_test(
col("props").field("a"),
get_field(col("props"), lit("a")),
vec![
"+------------+",
"| expr |",
Expand All @@ -77,11 +76,8 @@ fn test_get_field() {

#[test]
fn test_nested_get_field() {
// field access Expr::field() requires a rewrite to work, test when it is
// not the root expression
evaluate_expr_test(
col("props")
.field("a")
get_field(col("props"), lit("a"))
.eq(lit("2021-02-02"))
.or(col("id").eq(lit(1))),
vec![
Expand All @@ -98,9 +94,8 @@ fn test_nested_get_field() {

#[test]
fn test_list() {
// list access also requires a rewrite to work
evaluate_expr_test(
col("list").index(lit(1i64)),
array_element(col("list"), lit(1i64)),
vec![
"+------+", "| expr |", "+------+", "| one |", "| two |", "| five |",
"+------+",
Expand All @@ -110,9 +105,8 @@ fn test_list() {

#[test]
fn test_list_range() {
// range access also requires a rewrite to work
evaluate_expr_test(
col("list").range(lit(1i64), lit(2i64)),
array_slice(col("list"), lit(1i64), lit(2i64), None),
vec![
"+--------------+",
"| expr |",
Expand Down
29 changes: 1 addition & 28 deletions datafusion/functions-array/src/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,14 @@

use crate::array_has::array_has_all;
use crate::concat::{array_append, array_concat, array_prepend};
use crate::extract::{array_element, array_slice};
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::Transformed;
use datafusion_common::utils::list_ndims;
use datafusion_common::Result;
use datafusion_common::{Column, DFSchema};
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::expr_rewriter::FunctionRewrite;
use datafusion_expr::{BinaryExpr, Expr, GetFieldAccess, GetIndexedField, Operator};
use datafusion_functions::expr_fn::get_field;
use datafusion_expr::{BinaryExpr, Expr, Operator};

/// Rewrites expressions into function calls to array functions
pub(crate) struct ArrayFunctionRewriter {}
Expand Down Expand Up @@ -148,31 +146,6 @@ impl FunctionRewrite for ArrayFunctionRewriter {
Transformed::yes(array_prepend(*left, *right))
}

Expr::GetIndexedField(GetIndexedField {
expr,
field: GetFieldAccess::NamedStructField { name },
}) => {
let name = Expr::Literal(name);
Transformed::yes(get_field(*expr, name))
}

// expr[idx] ==> array_element(expr, idx)
Expr::GetIndexedField(GetIndexedField {
expr,
field: GetFieldAccess::ListIndex { key },
}) => Transformed::yes(array_element(*expr, *key)),

// expr[start, stop, stride] ==> array_slice(expr, start, stop, stride)
Expr::GetIndexedField(GetIndexedField {
expr,
field:
GetFieldAccess::ListRange {
start,
stop,
stride,
},
}) => Transformed::yes(array_slice(*expr, *start, *stop, Some(*stride))),

_ => Transformed::no(expr),
};
Ok(transformed)
Expand Down
17 changes: 14 additions & 3 deletions datafusion/sql/src/expr/identifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
use arrow_schema::Field;
use datafusion_common::{
internal_err, plan_datafusion_err, Column, DFSchema, DataFusionError, Result,
TableReference,
ScalarValue, TableReference,
};
use datafusion_expr::{Case, Expr};
use datafusion_expr::{expr::ScalarFunction, lit, Case, Expr};
use sqlparser::ast::{Expr as SQLExpr, Ident};

impl<'a, S: ContextProvider> SqlToRel<'a, S> {
Expand Down Expand Up @@ -133,7 +133,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
);
}
let nested_name = nested_names[0].to_string();
Ok(Expr::Column(Column::from((qualifier, field))).field(nested_name))

let col = Expr::Column(Column::from((qualifier, field)));
if let Some(udf) =
self.context_provider.get_function_meta("get_field")
{
Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
udf,
vec![col, lit(ScalarValue::from(nested_name))],
)))
} else {
internal_err!("get_field not found")
}
}
// found matching field with no spare identifier(s)
Some((field, qualifier, _nested_names)) => {
Expand Down
48 changes: 43 additions & 5 deletions datafusion/sql/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use datafusion_expr::expr::InList;
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::{
col, expr, lit, AggregateFunction, Between, BinaryExpr, Cast, Expr, ExprSchemable,
GetFieldAccess, GetIndexedField, Like, Literal, Operator, TryCast,
GetFieldAccess, Like, Literal, Operator, TryCast,
};

use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
Expand Down Expand Up @@ -1019,10 +1019,48 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
expr
};

Ok(Expr::GetIndexedField(GetIndexedField::new(
Box::new(expr),
self.plan_indices(indices, schema, planner_context)?,
)))
let field = self.plan_indices(indices, schema, planner_context)?;
match field {
GetFieldAccess::NamedStructField { name } => {
if let Some(udf) = self.context_provider.get_function_meta("get_field") {
Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
udf,
vec![expr, lit(name)],
)))
} else {
internal_err!("get_field not found")
}
}
// expr[idx] ==> array_element(expr, idx)
GetFieldAccess::ListIndex { key } => {
if let Some(udf) =
self.context_provider.get_function_meta("array_element")
{
Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
udf,
vec![expr, *key],
)))
} else {
internal_err!("get_field not found")
}
}
// expr[start, stop, stride] ==> array_slice(expr, start, stop, stride)
GetFieldAccess::ListRange {
start,
stop,
stride,
} => {
if let Some(udf) = self.context_provider.get_function_meta("array_slice")
{
Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
udf,
vec![expr, *start, *stop, *stride],
)))
} else {
internal_err!("array_slice not found")
}
}
}
}
}

Expand Down
114 changes: 110 additions & 4 deletions datafusion/sqllogictest/test_files/expr.slt
Original file line number Diff line number Diff line change
Expand Up @@ -2324,28 +2324,134 @@ host3 3.3

# can have an aggregate function with an inner CASE WHEN
query TR
select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;
select
t2.server_host as host,
sum((
case when t2.server_host is not null
then t2.server_load2
end
))
from (
select
struct(time,load1,load2,host)['c2'] as server_load2,
struct(time,load1,load2,host)['c3'] as server_host
from t1
) t2
where server_host IS NOT NULL
group by server_host order by host;
----
host1 101
host2 202
host3 303

# TODO: Issue tracked in https://github.com/apache/datafusion/issues/10364
query error
select
t2.server['c3'] as host,
sum((
case when t2.server['c3'] is not null
then t2.server['c2']
end
))
from (
select
struct(time,load1,load2,host) as server
from t1
) t2
where t2.server['c3'] IS NOT NULL
group by t2.server['c3'] order by host;

# can have 2 projections with aggr(short_circuited), with different short-circuited expr
query TRR
select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c1']), sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;
select
t2.server_host as host,
sum(coalesce(server_load1)),
sum((
case when t2.server_host is not null
then t2.server_load2
end
))
from (
select
struct(time,load1,load2,host)['c1'] as server_load1,
struct(time,load1,load2,host)['c2'] as server_load2,
struct(time,load1,load2,host)['c3'] as server_host
from t1
) t2
where server_host IS NOT NULL
group by server_host order by host;
----
host1 1.1 101
host2 2.2 202
host3 3.3 303

# can have 2 projections with aggr(short_circuited), with the same short-circuited expr (e.g. CASE WHEN)
# TODO: Issue tracked in https://github.com/apache/datafusion/issues/10364
query error
select
t2.server['c3'] as host,
sum(coalesce(server['c1'])),
sum((
case when t2.server['c3'] is not null
then t2.server['c2']
end
))
from (
select
struct(time,load1,load2,host) as server,
from t1
) t2
where server_host IS NOT NULL
group by server_host order by host;

query TRR
select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c1']), sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;
select
t2.server_host as host,
sum((
case when t2.server_host is not null
then server_load1
end
)),
sum((
case when server_host is not null
then server_load2
end
))
from (
select
struct(time,load1,load2,host)['c1'] as server_load1,
struct(time,load1,load2,host)['c2'] as server_load2,
struct(time,load1,load2,host)['c3'] as server_host
from t1
) t2
where server_host IS NOT NULL
group by server_host order by host;
----
host1 1.1 101
host2 2.2 202
host3 3.3 303

# TODO: Issue tracked in https://github.com/apache/datafusion/issues/10364
query error
select
t2.server['c3'] as host,
sum((
case when t2.server['c3'] is not null
then t2.server['c1']
end
)),
sum((
case when t2.server['c3'] is not null
then t2.server['c2']
end
))
from (
select
struct(time,load1,load2,host) as server
from t1
) t2
where t2.server['c3'] IS NOT NULL
group by t2.server['c3'] order by host;

# can have 2 projections with aggr(short_circuited), with the same short-circuited expr (e.g. coalesce)
query TRR
select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c1']), sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;
Expand Down

0 comments on commit b8fab5c

Please sign in to comment.