Let Table.filter accept a compute Expression as well as a boolean mask.
A mask is still forwarded to pyarrow.compute.filter; an Expression is routed
through _exec_plan._filter_table. Adds tests for both Table and Dataset.

diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi
index 17e787149a7..27e0144d758 100644
--- a/python/pyarrow/table.pxi
+++ b/python/pyarrow/table.pxi
@@ -2886,20 +2886,23 @@ cdef class Table(_PandasConvertible):
         """
         Select rows from the table.
 
-        See :func:`pyarrow.compute.filter` for full usage.
+        The Table can be filtered based on a mask, which will be passed to
+        :func:`pyarrow.compute.filter` to perform the filtering, or it can
+        be filtered through a boolean :class:`.Expression`.
 
         Parameters
         ----------
-        mask : Array or array-like
-            The boolean mask to filter the table with.
+        mask : Array or array-like or .Expression
+            The boolean mask or the :class:`.Expression` to filter the table with.
         null_selection_behavior
-            How nulls in the mask should be handled.
+            How nulls in the mask should be handled; has no effect when
+            an :class:`.Expression` is used.
 
         Returns
         -------
         filtered : Table
             A table of the same schema, with only the rows selected
-            by the boolean mask.
+            by the applied filter.
 
         Examples
         --------
@@ -2932,7 +2935,11 @@ cdef class Table(_PandasConvertible):
         n_legs: [[2,4,null]]
         animals: [["Flamingo","Horse",null]]
         """
-        return _pc().filter(self, mask, null_selection_behavior)
+        if isinstance(mask, _pc().Expression):
+            return _pc()._exec_plan._filter_table(self, mask,
+                                                  output_type=Table)
+        else:
+            return _pc().filter(self, mask, null_selection_behavior)
 
     def take(self, object indices):
         """
diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py
index 0520bdedc30..3bc905f2f97 100644
--- a/python/pyarrow/tests/test_dataset.py
+++ b/python/pyarrow/tests/test_dataset.py
@@ -4620,3 +4620,19 @@ def test_dataset_join_collisions(tempdir):
         [10, 20, None, 99],
         ["A", "B", None, "Z"],
     ], names=["colA", "colB", "colVals", "colB_r", "colVals_r"])
+
+
+@pytest.mark.dataset
+def test_dataset_filter(tempdir):
+    t1 = pa.table({
+        "colA": [1, 2, 6],
+        "col2": ["a", "b", "f"]
+    })
+    ds.write_dataset(t1, tempdir / "t1", format="parquet")
+    ds1 = ds.dataset(tempdir / "t1")
+
+    result = ds1.scanner(filter=pc.field("colA") < 3)
+    assert result.to_table() == pa.table({
+        "colA": [1, 2],
+        "col2": ["a", "b"]
+    })
diff --git a/python/pyarrow/tests/test_table.py b/python/pyarrow/tests/test_table.py
index bdbd2a7c8fc..81a8b27f4d7 100644
--- a/python/pyarrow/tests/test_table.py
+++ b/python/pyarrow/tests/test_table.py
@@ -2121,3 +2121,27 @@ def test_table_join_collisions():
         [10, 20, None, 99],
         ["A", "B", None, "Z"],
     ], names=["colA", "colB", "colVals", "colB", "colVals"])
+
+
+@pytest.mark.dataset
+def test_table_filter_expression():
+    t1 = pa.table({
+        "colA": [1, 2, 6],
+        "colB": [10, 20, 60],
+        "colVals": ["a", "b", "f"]
+    })
+
+    t2 = pa.table({
+        "colA": [99, 2, 1],
+        "colB": [99, 20, 10],
+        "colVals": ["Z", "B", "A"]
+    })
+
+    t3 = pa.concat_tables([t1, t2])
+
+    result = t3.filter(pc.field("colA") < 10)
+    assert result.combine_chunks() == pa.table({
+        "colA": [1, 2, 6, 2, 1],
+        "colB": [10, 20, 60, 20, 10],
+        "colVals": ["a", "b", "f", "B", "A"]
+    })