Multi-field structs from builtin scalar fn incompatible with UDAF #7012

Open
alexwilcoxson-rel opened this issue Jul 18, 2023 · 4 comments
Labels
bug Something isn't working

Comments

@alexwilcoxson-rel

alexwilcoxson-rel commented Jul 18, 2023

Describe the bug

We have a use case that requires passing multiple column values to a UDAF. UDAFs support only a single input column (unless I'm mistaken; I'm looking at this, which supports only one input data type). This has since been resolved by #7096.


To work around this, we tried packing the columns into a struct column and passing that as the UDAF's input, but we see an error with both the SQL API struct() builtin and the Expr API BuiltinScalarFunction::Struct.
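For context on what the workaround expects: an Arrow struct column is a set of equal-length child columns plus one field (name and type) per child, so packing columns should keep each input's own type. The failing runs below instead report Utf8 for every packed field. The std-only sketch models the expected behavior; Column, DataType and pack are hypothetical toy types, not the arrow-rs or DataFusion API, and the c0, c1, ... field names are borrowed from the error output.

```rust
// Toy, std-only model of Arrow-style columns (hypothetical; NOT the arrow-rs API).
#[derive(Debug, Clone, PartialEq)]
enum DataType {
    Int32,
    Boolean,
    Utf8,
    Struct(Vec<(String, DataType)>),
}

#[derive(Debug, Clone)]
enum Column {
    Int32(Vec<i32>),
    Boolean(Vec<bool>),
    Utf8(Vec<String>),
    // A struct column stores one child column per field (struct-of-arrays).
    Struct(Vec<(String, Column)>),
}

impl Column {
    fn len(&self) -> usize {
        match self {
            Column::Int32(v) => v.len(),
            Column::Boolean(v) => v.len(),
            Column::Utf8(v) => v.len(),
            // All children share one length, so the first child is representative.
            Column::Struct(children) => children.first().map_or(0, |(_, c)| c.len()),
        }
    }

    // The struct's type is derived from each child's actual type.
    fn data_type(&self) -> DataType {
        match self {
            Column::Int32(_) => DataType::Int32,
            Column::Boolean(_) => DataType::Boolean,
            Column::Utf8(_) => DataType::Utf8,
            Column::Struct(children) => DataType::Struct(
                children.iter().map(|(n, c)| (n.clone(), c.data_type())).collect(),
            ),
        }
    }
}

/// Hypothetical helper: packs columns into one struct column with fields c0, c1, ...
fn pack(columns: Vec<Column>) -> Result<Column, String> {
    if let Some(first) = columns.first() {
        let len = first.len();
        if columns.iter().any(|c| c.len() != len) {
            return Err("all packed columns must have the same length".to_string());
        }
    }
    Ok(Column::Struct(
        columns
            .into_iter()
            .enumerate()
            .map(|(i, c)| (format!("c{i}"), c))
            .collect(),
    ))
}
```

Under this model, packing columns a and b yields Struct(c0: Int32, c1: Boolean), which is the type the UDAF in the expr test declares.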

To Reproduce

Run the tests below to see the following output:

Failures

failures:

---- tests::test_udaf_pack_many_col_struct_sql stdout ----
Error: type_coercion
caused by
Error during planning: Coercion from [Struct([Field { name: "c0", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c1", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c2", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])] to the signature Exact([Struct([Field { name: "c0", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c1", data_type: Boolean, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c2", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])]) failed.

Caused by:
    Error during planning: Coercion from [Struct([Field { name: "c0", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c1", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c2", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])] to the signature Exact([Struct([Field { name: "c0", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c1", data_type: Boolean, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c2", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])]) failed.

---- tests::test_udaf_pack_many_col_struct_expr stdout ----
Error: type_coercion
caused by
Error during planning: Coercion from [Struct([Field { name: "c0", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c1", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])] to the signature Exact([Struct([Field { name: "c0", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c1", data_type: Boolean, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])]) failed.

Caused by:
    Error during planning: Coercion from [Struct([Field { name: "c0", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c1", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])] to the signature Exact([Struct([Field { name: "c0", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c1", data_type: Boolean, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])]) failed.
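Both failures say the argument type Struct(c0: Utf8, c1: Utf8, ...) could not be coerced to the UDAF's Exact signature. A rough way to see which part disagrees is to walk the two types field by field. The sketch below is a simplified, hypothetical model (not DataFusion's type_coercion pass) that reports the first mismatching struct field.

```rust
// Toy data type (hypothetical; NOT arrow's DataType).
#[derive(Debug, Clone, PartialEq)]
enum DataType {
    Int32,
    Boolean,
    Utf8,
    Struct(Vec<(String, DataType)>),
}

/// Returns a description of the first difference between `actual` and `expected`,
/// or None if they match exactly (names, order and types of every struct field).
fn first_mismatch(actual: &DataType, expected: &DataType) -> Option<String> {
    match (actual, expected) {
        (DataType::Struct(a), DataType::Struct(e)) => {
            if a.len() != e.len() {
                return Some(format!("struct has {} fields, expected {}", a.len(), e.len()));
            }
            // Compare fields pairwise; recurse into nested types.
            a.iter().zip(e).find_map(|((an, at), (en, et))| {
                if an != en {
                    Some(format!("field name {an} != {en}"))
                } else {
                    first_mismatch(at, et).map(|m| format!("{an}: {m}"))
                }
            })
        }
        (a, e) if a == e => None,
        (a, e) => Some(format!("{a:?} != {e:?}")),
    }
}
```

For the expr test's types, this reports the mismatch on c0 (Utf8 versus Int32). If struct() derived field types from its arguments, the comparison would find no mismatch.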

Table

Running cargo test -- --nocapture prints the table (from test_show):

+----+-------+------------+-------------------+---------+
| a  | b     | c          | d                 | e       |
+----+-------+------------+-------------------+---------+
| 12 | true  | hi         | {i: 12, j: true}  | {i: 12} |
| 11 | false | datafusion | {i: 11, j: false} | {i: 11} |
+----+-------+------------+-------------------+---------+

Tests

use datafusion::{physical_plan::Accumulator, scalar::ScalarValue};

#[tokio::main]
async fn main() {}

#[derive(Default, Debug)]
struct SumUdaf {
    sum: u32,
}

impl Accumulator for SumUdaf {
    fn update_batch(&mut self, values: &[arrow::array::ArrayRef]) -> datafusion::error::Result<()> {
        if values.is_empty() {
            return Ok(());
        }

        let arr = &values[0];
        (0..arr.len()).try_for_each(|index| {
            let sv = ScalarValue::try_from_array(&arr, index)?;
            if let ScalarValue::Struct(Some(values), _) = sv {
                for v in values {
                    if let ScalarValue::Int32(Some(v)) = v {
                        self.sum += v as u32;
                    }
                }
            } else if let ScalarValue::Int32(Some(v)) = sv {
                self.sum += v as u32;
            }
            Ok(())
        })
    }

    fn evaluate(&self) -> datafusion::error::Result<ScalarValue> {
        Ok(ScalarValue::from(self.sum))
    }

    fn size(&self) -> usize {
        std::mem::size_of_val(self)
    }

    fn state(&self) -> datafusion::error::Result<Vec<ScalarValue>> {
        Ok(vec![ScalarValue::from(self.sum)])
    }

    fn merge_batch(&mut self, states: &[arrow::array::ArrayRef]) -> datafusion::error::Result<()> {
        if states.is_empty() {
            return Ok(());
        }

        let arr = &states[0];

        (0..arr.len()).try_for_each(|index| {
            if let ScalarValue::UInt32(Some(v)) = ScalarValue::try_from_array(arr, index)? {
                self.sum += v;
            } else {
                unreachable!("merge_batch expects non-null UInt32 states")
            }
            Ok(())
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use arrow::{
        array::{
            downcast_array, ArrayBuilder, BooleanBuilder, Int32Builder, StringBuilder,
            StructBuilder, UInt32Array,
        },
        datatypes::{
            DataType as ArrowDataType, Field as ArrowField, FieldRef as ArrowFieldRef,
            Fields as ArrowFields, Schema as ArrowSchema,
        },
        record_batch::RecordBatch,
    };
    use datafusion::{
        logical_expr::{expr::ScalarFunction, AggregateUDF, BuiltinScalarFunction},
        prelude::*,
    };

    fn test_data() -> anyhow::Result<RecordBatch> {
        let d_fields: Vec<ArrowFieldRef> = vec![
            Arc::new(ArrowField::new("i", ArrowDataType::Int32, false)),
            Arc::new(ArrowField::new("j", ArrowDataType::Boolean, false)),
        ];
        let e_fields: Vec<ArrowFieldRef> =
            vec![Arc::new(ArrowField::new("i", ArrowDataType::Int32, false))];
        let schema = ArrowSchema::new(vec![
            ArrowField::new("a", ArrowDataType::Int32, false),
            ArrowField::new("b", ArrowDataType::Boolean, false),
            ArrowField::new("c", ArrowDataType::Utf8, false),
            ArrowField::new_struct("d", &*d_fields, false),
            ArrowField::new_struct("e", &*e_fields, false),
        ]);

        let mut a_builder = Int32Builder::new();
        let mut b_builder = BooleanBuilder::new();
        let mut c_builder = StringBuilder::new();

        a_builder.append_values(&[12, 11], &[true, true]);
        b_builder.append_values(&[true, false], &[true, true])?;
        c_builder.append_value("hi");
        c_builder.append_value("datafusion");

        let struct_builders: Vec<Box<dyn ArrayBuilder>> = vec![
            Box::new(Int32Builder::new()),
            Box::new(BooleanBuilder::new()),
        ];
        let mut d_builder = StructBuilder::new(d_fields, struct_builders);

        d_builder.append(true);
        d_builder.append(true);

        let i_builder = d_builder
            .field_builder::<Int32Builder>(0)
            .ok_or_else(|| anyhow::anyhow!("bad builder"))?;
        i_builder.append_value(12);
        i_builder.append_value(11);
        let j_builder = d_builder
            .field_builder::<BooleanBuilder>(1)
            .ok_or_else(|| anyhow::anyhow!("bad builder"))?;
        j_builder.append_value(true);
        j_builder.append_value(false);

        let mut e_builder = StructBuilder::new(e_fields, vec![Box::new(Int32Builder::new())]);
        e_builder.append(true);
        e_builder.append(true);
        let i_builder = e_builder
            .field_builder::<Int32Builder>(0)
            .ok_or_else(|| anyhow::anyhow!("bad builder"))?;
        i_builder.append_value(12);
        i_builder.append_value(11);

        let mut builders: Vec<Box<dyn ArrayBuilder>> = vec![
            Box::new(a_builder),
            Box::new(b_builder),
            Box::new(c_builder),
            Box::new(d_builder),
            Box::new(e_builder),
        ];
        let arrays = builders.iter_mut().map(|b| b.finish()).collect::<Vec<_>>();

        let batch = RecordBatch::try_new(Arc::new(schema), arrays)?;
        Ok(batch)
    }

    async fn sql(
        sql: impl AsRef<str>,
        udaf_input_type: ArrowDataType,
    ) -> anyhow::Result<DataFrame> {
        let ctx = SessionContext::default();
        let batch = test_data()?;
        ctx.register_batch("batch", batch)?;
        ctx.register_udaf(udaf(udaf_input_type));
        let df = ctx.sql(sql.as_ref()).await?;
        Ok(df)
    }

    fn dataframe() -> anyhow::Result<DataFrame> {
        let ctx = SessionContext::default();
        let batch = test_data()?;
        let df = ctx.read_batch(batch)?;
        Ok(df)
    }

    fn udaf(input_type: ArrowDataType) -> AggregateUDF {
        create_udaf(
            "my_sum",
            input_type,
            Arc::new(ArrowDataType::UInt32),
            datafusion::logical_expr::Volatility::Immutable,
            Arc::new(|_| Ok(Box::new(SumUdaf::default()))),
            Arc::new(vec![ArrowDataType::UInt32]),
        )
    }

    fn pack_cols(cols: Vec<impl Into<Column>>) -> Expr {
        Expr::ScalarFunction(ScalarFunction {
            fun: BuiltinScalarFunction::Struct,
            args: cols.into_iter().map(|c| col(c)).collect::<Vec<_>>(),
        })
    }
    
    async fn assert(df: DataFrame, expected: u32) -> anyhow::Result<()> {
        let result = df.collect().await?;
        let result_arr = result[0].column(0);
        let result_arr = downcast_array::<UInt32Array>(result_arr);
        let actual = result_arr.value(0);
        assert_eq!(expected, actual);
        Ok(())
    }

    #[tokio::test]
    async fn test_show() -> anyhow::Result<()> {
        let df = dataframe()?;
        df.show().await?;
        Ok(())
    }

    #[tokio::test]
    async fn test_udaf_existing_struct_many_col_sql() -> anyhow::Result<()> {
        let udaf_input_type = ArrowDataType::Struct(ArrowFields::from_iter([
            ArrowField::new("i", ArrowDataType::Int32, false),
            ArrowField::new("j", ArrowDataType::Boolean, false),
        ]));
        let df = sql("SELECT my_sum(d) FROM batch", udaf_input_type).await?;
        assert(df, 23).await?;
        Ok(())
    }

    #[tokio::test]
    async fn test_udaf_existing_struct_many_col_expr() -> anyhow::Result<()> {
        let udaf_input_type = ArrowDataType::Struct(ArrowFields::from_iter([
            ArrowField::new("i", ArrowDataType::Int32, false),
            ArrowField::new("j", ArrowDataType::Boolean, false),
        ]));
        let df = dataframe()?;
        let udaf = udaf(udaf_input_type);

        let df = df.aggregate(vec![], vec![(udaf.call(vec![col("d")]))])?;
        assert(df, 23).await?;
        Ok(())
    }

    #[tokio::test]
    async fn test_udaf_existing_struct_one_col_sql() -> anyhow::Result<()> {
        let udaf_input_type = ArrowDataType::Struct(ArrowFields::from_iter([
            ArrowField::new("i", ArrowDataType::Int32, false),
        ]));
        let df = sql("SELECT my_sum(e) FROM batch", udaf_input_type).await?;
        assert(df, 23).await?;
        Ok(())
    }

    #[tokio::test]
    async fn test_udaf_existing_struct_one_col_expr() -> anyhow::Result<()> {
        let udaf_input_type = ArrowDataType::Struct(ArrowFields::from_iter([
            ArrowField::new("i", ArrowDataType::Int32, false),
        ]));
        let df = dataframe()?;
        let udaf = udaf(udaf_input_type);

        let df = df.aggregate(vec![], vec![(udaf.call(vec![col("e")]))])?;
        assert(df, 23).await?;
        Ok(())
    }
    #[tokio::test]
    async fn test_udaf_pack_one_col_struct_sql() -> anyhow::Result<()> {
        let udaf_input_type = ArrowDataType::Struct(ArrowFields::from_iter([
            ArrowField::new("c0", ArrowDataType::Int32, true),
            // ArrowField::new("c1", ArrowDataType::Boolean, true),
            //ArrowField::new("c2", ArrowDataType::Utf8, true),
        ]));
        let df = sql("SELECT my_sum(struct(a)) FROM batch", udaf_input_type).await?;
        assert(df, 23).await?;
        Ok(())
    }

    // FAILS - Treats all struct fields as Utf8
    #[tokio::test]
    async fn test_udaf_pack_many_col_struct_sql() -> anyhow::Result<()> {
        let udaf_input_type = ArrowDataType::Struct(ArrowFields::from_iter([
            ArrowField::new("c0", ArrowDataType::Int32, true),
            ArrowField::new("c1", ArrowDataType::Boolean, true),
            ArrowField::new("c2", ArrowDataType::Utf8, true),
        ]));
        let df = sql("SELECT my_sum(struct(a, b, c)) FROM batch", udaf_input_type).await?;
        assert(df, 23).await?;
        Ok(())
    }

    #[tokio::test]
    async fn test_udaf_pack_one_col_struct_expr() -> anyhow::Result<()> {
        let udaf_input_type = ArrowDataType::Struct(ArrowFields::from_iter([ArrowField::new(
            "c0",
            ArrowDataType::Int32,
            true,
        )]));
        let df = dataframe()?;
        let udaf = udaf(udaf_input_type);
        let packed_expr = pack_cols(vec!["a"]);

        let df = df.aggregate(vec![], vec![udaf.call(vec![packed_expr])])?;
        assert(df, 23).await?;

        Ok(())
    }

    // FAILS - Treats all struct fields as Utf8
    #[tokio::test]
    async fn test_udaf_pack_many_col_struct_expr() -> anyhow::Result<()> {
        let udaf_input_type = ArrowDataType::Struct(ArrowFields::from_iter([
            ArrowField::new("c0", ArrowDataType::Int32, true),
            ArrowField::new("c1", ArrowDataType::Boolean, true),
        ]));
        let df = dataframe()?;
        let udaf = udaf(udaf_input_type);
        let packed_expr = pack_cols(vec!["a", "b"]);
        let df = df.aggregate(vec![], vec![udaf.call(vec![packed_expr])])?;

        assert(df, 23).await?;

        Ok(())
    }
}
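The SumUdaf contract (update_batch over raw rows, state as a single UInt32, merge_batch over partial states) can be traced without DataFusion. The sketch below is a hypothetical stand-in: a toy Value enum replaces ScalarValue and plain slices replace ArrayRef. Splitting the table's struct(a, b, c) rows across two partial accumulators and merging their states gives the 23 the tests expect.

```rust
// Toy stand-in for ScalarValue (hypothetical; NOT DataFusion's enum).
#[derive(Debug, Clone, PartialEq)]
enum Value {
    Int32(Option<i32>),
    Boolean(Option<bool>),
    Utf8(Option<String>),
    Struct(Vec<Value>),
}

/// Mirrors SumUdaf: sums every non-null Int32, including Int32 struct fields.
#[derive(Default, Debug)]
struct SumAcc {
    sum: u32,
}

impl SumAcc {
    fn update_batch(&mut self, rows: &[Value]) {
        for row in rows {
            match row {
                Value::Struct(fields) => {
                    for f in fields {
                        // Like the original, `as u32` wraps negative values.
                        if let Value::Int32(Some(v)) = f {
                            self.sum += *v as u32;
                        }
                    }
                }
                Value::Int32(Some(v)) => self.sum += *v as u32,
                // Non-Int32 values (Boolean, Utf8, nulls) are ignored.
                _ => {}
            }
        }
    }

    /// Partial state: a single UInt32, as in SumUdaf::state.
    fn state(&self) -> u32 {
        self.sum
    }

    /// Folds partial states from other accumulators into this one.
    fn merge_batch(&mut self, states: &[u32]) {
        self.sum += states.iter().sum::<u32>();
    }

    fn evaluate(&self) -> u32 {
        self.sum
    }
}
```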

Expected behavior

We should be able to create a struct with multiple fields, using either the SQL API struct() builtin or the Expr API's BuiltinScalarFunction::Struct, and pass it as input to a UDAF.

Additional context

The UDAF here is deliberately simple, for illustration only.

Is this a limitation of UDAFs, or could we open an enhancement request to support multiple input columns?

@alexwilcoxson-rel alexwilcoxson-rel added the bug Something isn't working label Jul 18, 2023
@alamb alamb changed the title Multi-field strucdts from builtin scalar fn incompatible with UDAF Multi-field structs from builtin scalar fn incompatible with UDAF Jul 25, 2023
@2010YOUY01
Contributor

I think it makes sense to let UDAFs support multiple input columns; there are already built-in aggregate functions, such as correlation and covariance, that take multi-column input.
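To illustrate what a multi-input aggregate in the spirit of covariance looks like, here is a std-only accumulator that consumes two columns per batch and keeps a small, mergeable partial state. It uses Welford-style online updates and the standard pairwise merge formula; that choice is ours for the sketch and is not necessarily how DataFusion implements covariance. CovarAcc and its methods are hypothetical names.

```rust
/// Hypothetical two-input accumulator for sample covariance.
#[derive(Default, Debug, Clone, Copy)]
struct CovarAcc {
    count: u64,
    mean_x: f64,
    mean_y: f64,
    co_moment: f64, // running sum of (x - mean_x) * (y - mean_y)
}

impl CovarAcc {
    /// `xs` and `ys` are the two input columns for one batch.
    fn update_batch(&mut self, xs: &[f64], ys: &[f64]) {
        assert_eq!(xs.len(), ys.len(), "input columns must have equal length");
        for (&x, &y) in xs.iter().zip(ys) {
            self.count += 1;
            let n = self.count as f64;
            let dx = x - self.mean_x; // uses the mean before this row
            self.mean_x += dx / n;
            self.mean_y += (y - self.mean_y) / n;
            self.co_moment += dx * (y - self.mean_y); // uses the updated y mean
        }
    }

    /// Combines another partial state (e.g. from a different partition).
    fn merge(&mut self, other: &CovarAcc) {
        if other.count == 0 {
            return;
        }
        if self.count == 0 {
            *self = *other;
            return;
        }
        let (na, nb) = (self.count as f64, other.count as f64);
        let n = na + nb;
        let dx = other.mean_x - self.mean_x;
        let dy = other.mean_y - self.mean_y;
        self.co_moment += other.co_moment + dx * dy * na * nb / n;
        self.mean_x += dx * nb / n;
        self.mean_y += dy * nb / n;
        self.count += other.count;
    }

    /// Sample covariance; undefined for fewer than two rows.
    fn evaluate(&self) -> Option<f64> {
        if self.count < 2 {
            None
        } else {
            Some(self.co_moment / (self.count - 1) as f64)
        }
    }
}
```

For x = [1, 2, 3, 4] and y = [2, 4, 6, 8], the sample covariance is 10/3, whether the rows are fed in one batch or split across two accumulators that are then merged.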

@alamb
Contributor

alamb commented Jul 30, 2023

I believe this was fixed in #7096

Can you confirm @alexwilcoxson-rel ?

@alexwilcoxson-rel
Author

@alamb, this fixes our initial use case of just needing to provide multiple inputs. There still appears to be an issue with the latest code on main: you can't create a struct and pass it as a single argument to a UDAF, e.g. SELECT my_udaf(struct(col_A, col_B)). This was just our workaround, though, and is more of an edge case IMO.

So perhaps just keep it with the other "improve struct" issues.

@alamb
Contributor

alamb commented Aug 1, 2023

This was just our workaround though and is more of an edge case IMO.

So perhaps just keep it with the other "improve struct" issues.

Yes, that makes sense to me. It is probably worth making a new ticket for just that use case (especially since this ticket has such a nice reproducer) ❤️
