Add support for Substrait List/EmptyList literals #10615

Merged: 6 commits, May 22, 2024
63 changes: 60 additions & 3 deletions datafusion/substrait/src/logical_plan/consumer.rs
@@ -61,6 +61,7 @@ use substrait::proto::{
};
use substrait::proto::{FunctionArgument, SortField};

use datafusion::arrow::array::GenericListArray;
use datafusion::common::plan_err;
use datafusion::logical_expr::expr::{InList, InSubquery, Sort};
use std::collections::HashMap;
@@ -1058,7 +1059,7 @@ pub async fn from_substrait_rex(
}
}

-fn from_substrait_type(dt: &substrait::proto::Type) -> Result<DataType> {
+pub(crate) fn from_substrait_type(dt: &substrait::proto::Type) -> Result<DataType> {
match &dt.kind {
Some(s_kind) => match s_kind {
r#type::Kind::Bool(_) => Ok(DataType::Boolean),
@@ -1138,7 +1139,7 @@ fn from_substrait_type(dt: &substrait::proto::Type) -> Result<DataType> {
from_substrait_type(list.r#type.as_ref().ok_or_else(|| {
substrait_datafusion_err!("List type must have inner type")
})?)?;
-let field = Arc::new(Field::new("list_item", inner_type, true));
+let field = Arc::new(Field::new_list_field(inner_type, true));
Contributor Author

This is a breaking change in the sense that the new field name is just "item", to align with the Arrow default.

match list.type_variation_reference {
DEFAULT_CONTAINER_TYPE_REF => Ok(DataType::List(field)),
LARGE_CONTAINER_TYPE_REF => Ok(DataType::LargeList(field)),
@@ -1278,6 +1279,45 @@ pub(crate) fn from_substrait_literal(lit: &Literal) -> Result<ScalarValue> {
s,
)
}
Some(LiteralType::List(l)) => {
let elements = l
.values
.iter()
.map(from_substrait_literal)
.collect::<Result<Vec<_>>>()?;
if elements.is_empty() {
return substrait_err!(
"Empty list must be encoded as EmptyList literal type, not List"
);
}
let element_type = elements[0].data_type();
Member

Should we check if elements are empty and report an error? The literal input might come from systems other than DataFusion, and they might not be properly implemented.

Contributor Author

Yup, done: cdc525c

match lit.type_variation_reference {
DEFAULT_CONTAINER_TYPE_REF => ScalarValue::List(ScalarValue::new_list(
elements.as_slice(),
&element_type,
)),
LARGE_CONTAINER_TYPE_REF => ScalarValue::LargeList(
ScalarValue::new_large_list(elements.as_slice(), &element_type),
),
others => {
return substrait_err!("Unknown type variation reference {others}");
}
}
}
Some(LiteralType::EmptyList(l)) => {
let element_type = from_substrait_type(l.r#type.clone().unwrap().as_ref())?;
Member

Removing the unwrap would make this more robust. 🤔

Contributor Author

Hm, without the type specified we don't know what it should be. I guess we could default to NullType (which I think is what DataFusion does if you just do a "SELECT [] FROM .."); do you think that'd make sense?

I feel like Substrait probably intends this field to always exist, though I'm not sure. For example, the Java library has it as required: https://github.com/substrait-io/substrait-java/blob/79decd20e85d6a1a5623890042ebcf1415cf784a/core/src/main/java/io/substrait/expression/Expression.java#L451

Member

I think we can return an error like "invalid parameter", but it may not be necessary to do so. Let's keep it as it is for now until someone requests this behavior.

match lit.type_variation_reference {
DEFAULT_CONTAINER_TYPE_REF => {
ScalarValue::List(ScalarValue::new_list(&[], &element_type))
}
LARGE_CONTAINER_TYPE_REF => ScalarValue::LargeList(
ScalarValue::new_large_list(&[], &element_type),
),
others => {
return substrait_err!("Unknown type variation reference {others}");
}
}
}
Some(LiteralType::Null(ntype)) => from_substrait_null(ntype)?,
_ => return not_impl_err!("Unsupported literal_type: {:?}", lit.literal_type),
};
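The list-decoding rules in this hunk, together with the review suggestion to return an error instead of unwrapping a missing element type, can be sketched in isolation. The following is a minimal, dependency-free model and not the DataFusion or Substrait API: `ElemType`, `Literal`, `Value`, and `decode` are hypothetical stand-ins, and the error for a missing `EmptyList` type illustrates the alternative raised in the review, not what the PR merged (the PR keeps the `unwrap`).

```rust
// Stand-in types for illustration only; these are NOT the substrait or
// DataFusion definitions.
#[derive(Debug, Clone, PartialEq)]
pub enum ElemType {
    I32,
    List(Box<ElemType>),
}

#[derive(Debug, Clone, PartialEq)]
pub enum Literal {
    I32(i32),
    /// Non-empty list: the element type is inferred from the first value.
    List(Vec<Literal>),
    /// Empty list: the element type has to be carried explicitly.
    EmptyList(Option<ElemType>),
}

#[derive(Debug, Clone, PartialEq)]
pub enum Value {
    I32(i32),
    List { elem: ElemType, items: Vec<Value> },
}

impl Value {
    /// Type of this value, used to infer a list's element type.
    pub fn data_type(&self) -> ElemType {
        match self {
            Value::I32(_) => ElemType::I32,
            Value::List { elem, .. } => ElemType::List(Box::new(elem.clone())),
        }
    }
}

pub fn decode(lit: &Literal) -> Result<Value, String> {
    match lit {
        Literal::I32(n) => Ok(Value::I32(*n)),
        Literal::List(values) => {
            let items = values
                .iter()
                .map(decode)
                .collect::<Result<Vec<_>, _>>()?;
            // With no elements there is nothing to infer the type from,
            // so an empty `List` is rejected instead of indexing `items[0]`.
            let first = items.first().ok_or_else(|| {
                "Empty list must be encoded as EmptyList literal type, not List"
                    .to_string()
            })?;
            let elem = first.data_type();
            Ok(Value::List { elem, items })
        }
        Literal::EmptyList(elem) => {
            // Assumption of this sketch: a missing type is an error, not a panic.
            let elem = elem
                .clone()
                .ok_or_else(|| "EmptyList literal must have an element type".to_string())?;
            Ok(Value::List { elem, items: vec![] })
        }
    }
}
```

For example, `decode(&Literal::List(vec![]))` returns an error, while `decode(&Literal::EmptyList(Some(ElemType::I32)))` yields an empty list typed as `I32`.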
@@ -1361,7 +1401,24 @@ fn from_substrait_null(null_type: &Type) -> Result<ScalarValue> {
d.precision as u8,
d.scale as i8,
)),
_ => not_impl_err!("Unsupported Substrait type: {kind:?}"),
r#type::Kind::List(l) => {
let field = Field::new_list_field(
from_substrait_type(l.r#type.clone().unwrap().as_ref())?,
true,
);
match l.type_variation_reference {
DEFAULT_CONTAINER_TYPE_REF => Ok(ScalarValue::List(Arc::new(
GenericListArray::new_null(field.into(), 1),
Contributor Author

Is this the correct way to create null lists, or is there something better? The list-of-lists structure that ScalarValue::List uses is a bit confusing to me.

Member

I think it is correct.

))),
LARGE_CONTAINER_TYPE_REF => Ok(ScalarValue::LargeList(Arc::new(
GenericListArray::new_null(field.into(), 1),
))),
v => not_impl_err!(
"Unsupported Substrait type variation {v} of type {kind:?}"
),
}
}
_ => not_impl_err!("Unsupported Substrait type for null: {kind:?}"),
}
} else {
not_impl_err!("Null type without kind is not supported")
145 changes: 127 additions & 18 deletions datafusion/substrait/src/logical_plan/producer.rs
@@ -30,6 +30,7 @@ use datafusion::{
scalar::ScalarValue,
};

use datafusion::arrow::array::{Array, GenericListArray, OffsetSizeTrait};
use datafusion::common::{exec_err, internal_err, not_impl_err};
use datafusion::common::{substrait_err, DFSchemaRef};
#[allow(unused_imports)]
@@ -42,6 +43,7 @@ use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Opera
use datafusion::prelude::Expr;
use prost_types::Any as ProtoAny;
use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields};
use substrait::proto::expression::literal::List;
use substrait::proto::expression::subquery::InPredicate;
use substrait::proto::expression::window_function::BoundsType;
use substrait::proto::{CrossRel, ExchangeRel};
@@ -1100,7 +1102,7 @@ pub fn to_substrait_rex(
))),
})
}
-Expr::Literal(value) => to_substrait_literal(value),
+Expr::Literal(value) => to_substrait_literal_expr(value),
Expr::Alias(Alias { expr, .. }) => {
to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)
}
@@ -1526,8 +1528,9 @@ fn make_substrait_like_expr(
};
let expr = to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?;
let pattern = to_substrait_rex(ctx, pattern, schema, col_ref_offset, extension_info)?;
-let escape_char =
-to_substrait_literal(&ScalarValue::Utf8(escape_char.map(|c| c.to_string())))?;
+let escape_char = to_substrait_literal_expr(&ScalarValue::Utf8(
+escape_char.map(|c| c.to_string()),
+))?;
let arguments = vec![
FunctionArgument {
arg_type: Some(ArgType::Value(expr)),
@@ -1683,7 +1686,7 @@ fn to_substrait_bounds(window_frame: &WindowFrame) -> Result<(Bound, Bound)> {
))
}

-fn to_substrait_literal(value: &ScalarValue) -> Result<Expression> {
+fn to_substrait_literal(value: &ScalarValue) -> Result<Literal> {
let (literal_type, type_variation_reference) = match value {
ScalarValue::Boolean(Some(b)) => (LiteralType::Boolean(*b), DEFAULT_TYPE_REF),
ScalarValue::Int8(Some(n)) => (LiteralType::I8(*n as i32), DEFAULT_TYPE_REF),
@@ -1741,15 +1744,50 @@ fn to_substrait_literal(value: &ScalarValue) -> Result<Expression> {
}),
DECIMAL_128_TYPE_REF,
),
ScalarValue::List(l) if !value.is_null() => (
convert_array_to_literal_list(l)?,
DEFAULT_CONTAINER_TYPE_REF,
),
ScalarValue::LargeList(l) if !value.is_null() => {
(convert_array_to_literal_list(l)?, LARGE_CONTAINER_TYPE_REF)
}
_ => (try_to_substrait_null(value)?, DEFAULT_TYPE_REF),
};

Ok(Literal {
nullable: true,
type_variation_reference,
literal_type: Some(literal_type),
})
}

fn convert_array_to_literal_list<T: OffsetSizeTrait>(
array: &GenericListArray<T>,
) -> Result<LiteralType> {
assert_eq!(array.len(), 1);
let nested_array = array.value(0);

let values = (0..nested_array.len())
.map(|i| to_substrait_literal(&ScalarValue::try_from_array(&nested_array, i)?))
.collect::<Result<Vec<_>>>()?;

if values.is_empty() {
let et = match to_substrait_type(array.data_type())? {
substrait::proto::Type {
kind: Some(r#type::Kind::List(lt)),
} => lt.as_ref().to_owned(),
_ => unreachable!(),
};
Ok(LiteralType::EmptyList(et))
} else {
Ok(LiteralType::List(List { values }))
}
}

fn to_substrait_literal_expr(value: &ScalarValue) -> Result<Expression> {
let literal = to_substrait_literal(value)?;
Ok(Expression {
-rex_type: Some(RexType::Literal(Literal {
-nullable: true,
-type_variation_reference,
-literal_type: Some(literal_type),
-})),
+rex_type: Some(RexType::Literal(literal)),
})
}
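On the producer side, `convert_array_to_literal_list` and `try_to_substrait_null` together choose between three encodings for a list scalar: a typed null, an `EmptyList`, or a `List`. A minimal sketch of that decision follows, using only standard-library types. `Encoded` and `encode` are hypothetical names, the one-row slice stands in for the single-row list array wrapped by `ScalarValue::List`, and returning an error for a wrong row count is an assumption of this sketch (the PR uses `assert_eq!`).

```rust
/// Hypothetical stand-in for the literal kinds a list scalar can map to.
#[derive(Debug, Clone, PartialEq)]
pub enum Encoded {
    /// The whole list is NULL (encoded as a typed null literal).
    Null,
    /// The list exists but has no elements.
    EmptyList,
    /// A non-empty list; individual elements may still be NULL.
    List(Vec<Option<i32>>),
}

/// `rows` models the one-row list array behind a list scalar:
/// `None` is a NULL list, `Some(vec![])` is an empty list.
pub fn encode(rows: &[Option<Vec<Option<i32>>>]) -> Result<Encoded, String> {
    // A scalar must be backed by exactly one row.
    if rows.len() != 1 {
        return Err(format!("expected exactly 1 row, got {}", rows.len()));
    }
    Ok(match &rows[0] {
        None => Encoded::Null,
        Some(items) if items.is_empty() => Encoded::EmptyList,
        Some(items) => Encoded::List(items.clone()),
    })
}
```

The four inner values of the `SELECT [[1,2,3], [], NULL, [NULL]]` round-trip test correspond to `List`, `EmptyList`, `Null`, and a one-element `List` holding a NULL in this model.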

@@ -1937,6 +1975,10 @@ fn try_to_substrait_null(v: &ScalarValue) -> Result<LiteralType> {
})),
}))
}
ScalarValue::List(l) => Ok(LiteralType::Null(to_substrait_type(l.data_type())?)),
ScalarValue::LargeList(l) => {
Ok(LiteralType::Null(to_substrait_type(l.data_type())?))
}
// TODO: Extend support for remaining data types
_ => not_impl_err!("Unsupported literal: {v:?}"),
}
@@ -2016,7 +2058,9 @@ fn substrait_field_ref(index: usize) -> Result<Expression> {

#[cfg(test)]
mod test {
-use crate::logical_plan::consumer::from_substrait_literal;
+use crate::logical_plan::consumer::{from_substrait_literal, from_substrait_type};
use datafusion::arrow::array::GenericListArray;
use datafusion::arrow::datatypes::Field;

use super::*;

@@ -2054,22 +2098,87 @@ mod test {
round_trip_literal(ScalarValue::UInt64(Some(u64::MIN)))?;
round_trip_literal(ScalarValue::UInt64(Some(u64::MAX)))?;

round_trip_literal(ScalarValue::List(ScalarValue::new_list(
&[ScalarValue::Float32(Some(1.0))],
&DataType::Float32,
)))?;
round_trip_literal(ScalarValue::List(ScalarValue::new_list(
&[],
&DataType::Float32,
)))?;
round_trip_literal(ScalarValue::List(Arc::new(GenericListArray::new_null(
Field::new_list_field(DataType::Float32, true).into(),
1,
))))?;
round_trip_literal(ScalarValue::LargeList(ScalarValue::new_large_list(
&[ScalarValue::Float32(Some(1.0))],
&DataType::Float32,
)))?;
round_trip_literal(ScalarValue::LargeList(ScalarValue::new_large_list(
&[],
&DataType::Float32,
)))?;
round_trip_literal(ScalarValue::LargeList(Arc::new(
GenericListArray::new_null(
Field::new_list_field(DataType::Float32, true).into(),
1,
),
)))?;

Ok(())
}

fn round_trip_literal(scalar: ScalarValue) -> Result<()> {
println!("Checking round trip of {scalar:?}");

-let substrait = to_substrait_literal(&scalar)?;
-let Expression {
-rex_type: Some(RexType::Literal(substrait_literal)),
-} = substrait
-else {
-panic!("Expected Literal expression, got {substrait:?}");
-};
-
+let substrait_literal = to_substrait_literal(&scalar)?;
let roundtrip_scalar = from_substrait_literal(&substrait_literal)?;
assert_eq!(scalar, roundtrip_scalar);
Ok(())
}

#[test]
fn round_trip_types() -> Result<()> {
round_trip_type(DataType::Boolean)?;
round_trip_type(DataType::Int8)?;
round_trip_type(DataType::UInt8)?;
round_trip_type(DataType::Int16)?;
round_trip_type(DataType::UInt16)?;
round_trip_type(DataType::Int32)?;
round_trip_type(DataType::UInt32)?;
round_trip_type(DataType::Int64)?;
round_trip_type(DataType::UInt64)?;
round_trip_type(DataType::Float32)?;
round_trip_type(DataType::Float64)?;
round_trip_type(DataType::Timestamp(TimeUnit::Second, None))?;
round_trip_type(DataType::Timestamp(TimeUnit::Millisecond, None))?;
round_trip_type(DataType::Timestamp(TimeUnit::Microsecond, None))?;
round_trip_type(DataType::Timestamp(TimeUnit::Nanosecond, None))?;
round_trip_type(DataType::Date32)?;
round_trip_type(DataType::Date64)?;
round_trip_type(DataType::Binary)?;
round_trip_type(DataType::FixedSizeBinary(10))?;
round_trip_type(DataType::LargeBinary)?;
round_trip_type(DataType::Utf8)?;
round_trip_type(DataType::LargeUtf8)?;
round_trip_type(DataType::Decimal128(10, 2))?;
round_trip_type(DataType::Decimal256(30, 2))?;
round_trip_type(DataType::List(
Field::new_list_field(DataType::Int32, true).into(),
))?;
round_trip_type(DataType::LargeList(
Field::new_list_field(DataType::Int32, true).into(),
))?;

Ok(())
}

fn round_trip_type(dt: DataType) -> Result<()> {
println!("Checking round trip of {dt:?}");

let substrait = to_substrait_type(&dt)?;
let roundtrip_dt = from_substrait_type(&substrait)?;
assert_eq!(dt, roundtrip_dt);
Ok(())
}
}
10 changes: 10 additions & 0 deletions datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
@@ -665,6 +665,16 @@ async fn all_type_literal() -> Result<()> {
.await
}

#[tokio::test]
async fn roundtrip_literal_list() -> Result<()> {
assert_expected_plan(
"SELECT [[1,2,3], [], NULL, [NULL]] FROM data",
"Projection: List([[1, 2, 3], [], , []])\
\n TableScan: data projection=[]",
)
.await
}

/// Construct a plan that cast columns. Only those SQL types are supported for now.
#[tokio::test]
async fn new_test_grammar() -> Result<()> {