Skip to content

Commit

Permalink
Use struct instead of named_struct when there are no aliases (#9897)
Browse files Browse the repository at this point in the history
* Revert "use alias (#9894)"

This reverts commit 9487ca0.

* Use `struct` instead of `named_struct` when there are no aliases

* Update docs

* fmt
  • Loading branch information
alamb committed Apr 2, 2024
1 parent f51fda5 commit a6ff1fe
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 10 deletions.
48 changes: 48 additions & 0 deletions datafusion/sql/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}
}

/// Parses a struct(..) expression
fn parse_struct(
&self,
values: Vec<SQLExpr>,
Expand All @@ -599,6 +600,25 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
if !fields.is_empty() {
return not_impl_err!("Struct fields are not supported yet");
}

if values
.iter()
.any(|value| matches!(value, SQLExpr::Named { .. }))
{
self.create_named_struct(values, input_schema, planner_context)
} else {
self.create_struct(values, input_schema, planner_context)
}
}

// Handles a call to struct(...) where the arguments are named. For example
// `struct (v as foo, v2 as bar)` by creating a call to the `named_struct` function
fn create_named_struct(
&self,
values: Vec<SQLExpr>,
input_schema: &DFSchema,
planner_context: &mut PlannerContext,
) -> Result<Expr> {
let args = values
.into_iter()
.enumerate()
Expand Down Expand Up @@ -643,6 +663,34 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
)))
}

// Handles a call to struct(...) where the arguments are not named. For example
// `struct (v, v2)` by creating a call to the `struct` function
// which will create a struct with fields named `c0`, `c1`, etc.
fn create_struct(
&self,
values: Vec<SQLExpr>,
input_schema: &DFSchema,
planner_context: &mut PlannerContext,
) -> Result<Expr> {
let args = values
.into_iter()
.map(|value| {
self.sql_expr_to_logical_expr(value, input_schema, planner_context)
})
.collect::<Result<Vec<_>>>()?;
let struct_func = self
.context_provider
.get_function_meta("struct")
.ok_or_else(|| {
internal_datafusion_err!("Unable to find expected 'struct' function")
})?;

Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
struct_func,
args,
)))
}

fn parse_array_agg(
&self,
array_agg: ArrayAgg,
Expand Down
4 changes: 2 additions & 2 deletions datafusion/sqllogictest/test_files/explain.slt
Original file line number Diff line number Diff line change
Expand Up @@ -390,8 +390,8 @@ query TT
explain select struct(1, 2.3, 'abc');
----
logical_plan
Projection: Struct({c0:1,c1:2.3,c2:abc}) AS named_struct(Utf8("c0"),Int64(1),Utf8("c1"),Float64(2.3),Utf8("c2"),Utf8("abc"))
Projection: Struct({c0:1,c1:2.3,c2:abc}) AS struct(Int64(1),Float64(2.3),Utf8("abc"))
--EmptyRelation
physical_plan
ProjectionExec: expr=[{c0:1,c1:2.3,c2:abc} as named_struct(Utf8("c0"),Int64(1),Utf8("c1"),Float64(2.3),Utf8("c2"),Utf8("abc"))]
ProjectionExec: expr=[{c0:1,c1:2.3,c2:abc} as struct(Int64(1),Float64(2.3),Utf8("abc"))]
--PlaceholderRowExec
10 changes: 5 additions & 5 deletions datafusion/sqllogictest/test_files/expr.slt
Original file line number Diff line number Diff line change
Expand Up @@ -2288,39 +2288,39 @@ select struct(time,load1,load2,host) from t1;

# can have an aggregate function with an inner coalesce
query TR
select t2.info['c3'] as host, sum(coalesce(t2.info)['c1']) from (select struct(time,load1,load2,host) as info from t1) t2 where t2.info['c3'] IS NOT NULL group by t2.info['c3'] order by host;
select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c1']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;
----
host1 1.1
host2 2.2
host3 3.3

# can have an aggregate function with an inner CASE WHEN
query TR
select t2.info['c3'] as host, sum((case when t2.info['c3'] is not null then t2.info end)['c2']) from (select struct(time,load1,load2,host) as info from t1) t2 where t2.info['c3'] IS NOT NULL group by t2.info['c3'] order by host;
select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;
----
host1 101
host2 202
host3 303

# can have 2 projections with aggr(short_circuited), with different short-circuited expr
query TRR
select t2.info['c3'] as host, sum(coalesce(t2.info)['c1']), sum((case when t2.info['c3'] is not null then t2.info end)['c2']) from (select struct(time,load1,load2,host) as info from t1) t2 where t2.info['c3'] IS NOT NULL group by t2.info['c3'] order by host;
select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c1']), sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;
----
host1 1.1 101
host2 2.2 202
host3 3.3 303

# can have 2 projections with aggr(short_circuited), with the same short-circuited expr (e.g. CASE WHEN)
query TRR
select t2.info['c3'] as host, sum((case when t2.info['c3'] is not null then t2.info end)['c1']), sum((case when t2.info['c3'] is not null then t2.info end)['c2']) from (select struct(time,load1,load2,host) as info from t1) t2 where t2.info['c3'] IS NOT NULL group by t2.info['c3'] order by host;
select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c1']), sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;
----
host1 1.1 101
host2 2.2 202
host3 3.3 303

# can have 2 projections with aggr(short_circuited), with the same short-circuited expr (e.g. coalesce)
query TRR
select t2.info['c3'] as host, sum(coalesce(t2.info)['c1']), sum(coalesce(t2.info)['c2']) from (select struct(time,load1,load2,host) as info from t1) t2 where t2.info['c3'] IS NOT NULL group by t2.info['c3'] order by host;
select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c1']), sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;
----
host1 1.1 101
host2 2.2 202
Expand Down
6 changes: 3 additions & 3 deletions datafusion/sqllogictest/test_files/struct.slt
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,10 @@ query TT
explain select struct(a, b, c) from values;
----
logical_plan
Projection: named_struct(Utf8("c0"), values.a, Utf8("c1"), values.b, Utf8("c2"), values.c)
Projection: struct(values.a, values.b, values.c)
--TableScan: values projection=[a, b, c]
physical_plan
ProjectionExec: expr=[named_struct(c0, a@0, c1, b@1, c2, c@2) as named_struct(Utf8("c0"),values.a,Utf8("c1"),values.b,Utf8("c2"),values.c)]
ProjectionExec: expr=[struct(a@0, b@1, c@2) as struct(values.a,values.b,values.c)]
--MemoryExec: partitions=1, partition_sizes=[1]

# error on 0 arguments
Expand Down Expand Up @@ -179,4 +179,4 @@ drop table values;
query T
select arrow_typeof(named_struct('first', 1, 'second', 2, 'third', 3));
----
Struct([Field { name: "first", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "second", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "third", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])
Struct([Field { name: "first", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "second", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "third", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])
10 changes: 10 additions & 0 deletions docs/source/user-guide/sql/scalar_functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -3336,6 +3336,16 @@ select * from t;
| 3 | 4 |
+---+---+
-- use default names `c0`, `c1`
❯ select struct(a, b) from t;
+-----------------+
| struct(t.a,t.b) |
+-----------------+
| {c0: 1, c1: 2} |
| {c0: 3, c1: 4} |
+-----------------+
-- name the first field `field_a`
select struct(a as field_a, b) from t;
+--------------------------------------------------+
| named_struct(Utf8("field_a"),t.a,Utf8("c1"),t.b) |
Expand Down

0 comments on commit a6ff1fe

Please sign in to comment.