diff --git a/python/datafusion/tests/test_dataframe.py b/python/datafusion/tests/test_dataframe.py index 9040b6f807f93..9a97c25f296a1 100644 --- a/python/datafusion/tests/test_dataframe.py +++ b/python/datafusion/tests/test_dataframe.py @@ -35,6 +35,19 @@ def df(): return ctx.create_dataframe([[batch]]) +@pytest.fixture +def struct_df(): + ctx = ExecutionContext() + + # create a RecordBatch and a new DataFrame from it + batch = pa.RecordBatch.from_arrays( + [pa.array([{"c": 1}, {"c": 2}, {"c": 3}]), pa.array([4, 5, 6])], + names=["a", "b"], + ) + + return ctx.create_dataframe([[batch]]) + + def test_select(df): df = df.select( column("a") + column("b"), @@ -153,3 +166,16 @@ def test_get_dataframe(tmp_path): df = ctx.table("csv") assert isinstance(df, DataFrame) + + +def test_struct_select(struct_df): + df = struct_df.select( + column("a")["c"] + column("b"), + column("a")["c"] - column("b"), + ) + + # execute and collect the first (and only) batch + result = df.collect()[0] + + assert result.column(0) == pa.array([5, 7, 9]) + assert result.column(1) == pa.array([-3, -3, -3]) diff --git a/python/src/expression.rs b/python/src/expression.rs index 5e1cad246bf87..d646d6b58d861 100644 --- a/python/src/expression.rs +++ b/python/src/expression.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use pyo3::PyMappingProtocol; use pyo3::{basic::CompareOp, prelude::*, PyNumberProtocol, PyObjectProtocol}; use std::convert::{From, Into}; @@ -133,3 +134,14 @@ impl PyExpr { expr.into() } } + +#[pyproto] +impl PyMappingProtocol for PyExpr { + fn __getitem__(&self, key: &str) -> PyResult { + Ok(Expr::GetIndexedField { + expr: Box::new(self.expr.clone()), + key: ScalarValue::Utf8(Some(key.to_string()).to_owned()), + } + .into()) + } +}