
Commit d4e4336

cem-anyscale and gemini-code-assist[bot] authored and committed
[Data] Callback-based stat computation for preprocessors and ValueCounter (#56848)
* Updated preprocessors to use a callback-based approach for stat computation. This improves code organization and reduces duplication.
* Added ValueCounter aggregator and value_counts method to BlockColumnAccessor. Includes implementations for both Arrow and Pandas backends.

Signed-off-by: cem <[email protected]>
Signed-off-by: cem-anyscale <[email protected]>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: elliot-barn <[email protected]>
1 parent 678c503 commit d4e4336
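The summary above leaves the callback-based flow implicit. Below is a minimal sketch of the pattern under stated assumptions: only reset() and compute() are visible in the preprocessor.py diff further down, so register() and the callback shape here are illustrative stand-ins, not the real StatComputationPlan API from ray.data.preprocessors.utils.

from typing import Any, Callable, Dict


class SketchStatComputationPlan:
    """Illustrative stand-in for ray.data.preprocessors.utils.StatComputationPlan."""

    def __init__(self) -> None:
        self._callbacks: Dict[str, Callable[[Any], Any]] = {}

    def register(self, name: str, callback: Callable[[Any], Any]) -> None:
        # Each preprocessor's _fit() registers the stats it needs...
        self._callbacks[name] = callback

    def reset(self) -> None:
        self._callbacks.clear()

    def compute(self, dataset: Any) -> Dict[str, Any]:
        # ...and fit() later computes every registered stat in one place,
        # instead of each preprocessor duplicating its own aggregation code.
        return {name: cb(dataset) for name, cb in self._callbacks.items()}


plan = SketchStatComputationPlan()
plan.register("max(x)", lambda ds: max(ds))  # toy "dataset": a plain list
plan.register("min(x)", lambda ds: min(ds))
print(plan.compute([3, 1, 4]))  # {'max(x)': 4, 'min(x)': 1}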

File tree

17 files changed: +768 −284 lines


doc/source/data/api/aggregate.rst

Lines changed: 1 addition & 0 deletions

@@ -25,6 +25,7 @@ compute aggregations.
    AbsMax
    Quantile
    Unique
+   ValueCounter
    MissingValuePercentage
    ZeroPercentage
    ApproximateQuantile

python/ray/data/_internal/arrow_block.py

Lines changed: 28 additions & 0 deletions

@@ -530,11 +530,39 @@ def unique(self) -> BlockColumn:
 
         return pac.unique(self._column)
 
+    def value_counts(self) -> Optional[Dict[str, List]]:
+        import pyarrow.compute as pac
+
+        value_counts: pyarrow.StructArray = pac.value_counts(self._column)
+        if len(value_counts) == 0:
+            return None
+        return {
+            "values": value_counts.field("values").to_pylist(),
+            "counts": value_counts.field("counts").to_pylist(),
+        }
+
+    def hash(self) -> BlockColumn:
+        import polars as pl
+
+        df = pl.DataFrame({"col": self._column})
+        hashes = df.hash_rows().cast(pl.Int64, wrap_numerical=True)
+        return hashes.to_arrow()
+
     def flatten(self) -> BlockColumn:
         import pyarrow.compute as pac
 
         return pac.list_flatten(self._column)
 
+    def dropna(self) -> BlockColumn:
+        import pyarrow.compute as pac
+
+        return pac.drop_null(self._column)
+
+    def is_composed_of_lists(self, types: Optional[Tuple] = None) -> bool:
+        if not types:
+            types = (pyarrow.lib.ListType, pyarrow.lib.LargeListType)
+        return isinstance(self._column.type, types)
+
     def to_pylist(self) -> List[Any]:
         return self._column.to_pylist()
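For reference, a minimal standalone sketch of what the Arrow-backed value_counts() above returns (the column contents are made up):

import pyarrow as pa
import pyarrow.compute as pac

# pac.value_counts returns a StructArray with "values" and "counts" fields,
# which the accessor above converts to plain Python lists.
col = pa.chunked_array([["a", "b", "a"], ["c", "a"]])
vc = pac.value_counts(col)
print({
    "values": vc.field("values").to_pylist(),
    "counts": vc.field("counts").to_pylist(),
})  # {'values': ['a', 'b', 'c'], 'counts': [3, 1, 1]}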

python/ray/data/_internal/pandas_block.py

Lines changed: 45 additions & 1 deletion

@@ -174,8 +174,33 @@ def quantile(
     ) -> Optional[U]:
         return self._column.quantile(q=q)
 
+    def value_counts(self) -> Optional[Dict[str, List]]:
+        value_counts = self._column.value_counts()
+        if len(value_counts) == 0:
+            return None
+        return {
+            "values": value_counts.index.tolist(),
+            "counts": value_counts.values.tolist(),
+        }
+
+    def hash(self) -> BlockColumn:
+        from ray.air.util.tensor_extensions.pandas import TensorArrayElement
+
+        first_non_null = next((x for x in self._column if x is not None), None)
+        if isinstance(first_non_null, TensorArrayElement):
+            self._column = self._column.apply(lambda x: x.to_numpy())
+
+        import polars as pl
+
+        df = pl.from_pandas(self._column.to_frame())
+        hashes = df.hash_rows().cast(pl.Int64, wrap_numerical=True)
+        return hashes.to_pandas()
+
     def unique(self) -> BlockColumn:
         pd = lazy_import_pandas()
         try:
             return pd.Series(self._column.unique())
         except ValueError as e:
@@ -187,7 +212,18 @@ def unique(self) -> BlockColumn:
             raise
 
     def flatten(self) -> BlockColumn:
-        return self._column.list.flatten()
+        from ray.air.util.tensor_extensions.pandas import TensorArrayElement
+
+        first_non_null = next((x for x in self._column if x is not None), None)
+        if isinstance(first_non_null, TensorArrayElement):
+            self._column = self._column.apply(
+                lambda x: x.to_numpy() if isinstance(x, TensorArrayElement) else x
+            )
+
+        return self._column.explode(ignore_index=True)
+
+    def dropna(self) -> BlockColumn:
+        return self._column.dropna()
 
     def sum_of_squared_diffs_from_mean(
         self,
@@ -219,6 +255,14 @@ def _as_arrow_compatible(self) -> Union[List[Any], "pyarrow.Array"]:
     def _is_all_null(self):
         return not self._column.notna().any()
 
+    def is_composed_of_lists(self, types: Optional[Tuple] = None) -> bool:
+        from ray.air.util.tensor_extensions.pandas import TensorArrayElement
+
+        if not types:
+            types = (list, np.ndarray, TensorArrayElement)
+        first_non_null = next((x for x in self._column if x is not None), None)
+        return isinstance(first_non_null, types)
+
 
 class PandasBlockBuilder(TableBlockBuilder):
     def __init__(self):
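A self-contained sketch of the Pandas-side behavior shown above, assuming plain pandas objects (no Ray tensor-extension elements involved):

import pandas as pd

col = pd.Series(["a", "b", "a", None, "a"])

# Series.value_counts drops nulls by default, matching ignore_nulls semantics.
vc = col.value_counts()
print({"values": vc.index.tolist(), "counts": vc.values.tolist()})
# {'values': ['a', 'b'], 'counts': [3, 1]}

# flatten() now relies on Series.explode, which unnests list-like elements.
nested = pd.Series([[1, 2], [3], None])
print(nested.explode(ignore_index=True).tolist())  # [1, 2, 3, nan]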

python/ray/data/aggregate.py

Lines changed: 83 additions & 1 deletion

@@ -1,6 +1,6 @@
 import abc
 import math
-from typing import TYPE_CHECKING, Any, Callable, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
 
 import numpy as np
 import pyarrow.compute as pc
@@ -889,6 +889,88 @@ def _to_set(x):
         return {x}
 
 
+@PublicAPI
+class ValueCounter(AggregateFnV2):
+    """Counts the number of times each value appears in a column.
+
+    This aggregation computes value counts for a specified column, similar to pandas'
+    `value_counts()` method. It returns a dictionary with two lists: "values" containing
+    the unique values found in the column, and "counts" containing the corresponding
+    count for each value.
+
+    Example:
+
+        .. testcode::
+
+            import ray
+            from ray.data.aggregate import ValueCounter
+
+            # Create a dataset with repeated values
+            ds = ray.data.from_items([
+                {"category": "A"}, {"category": "B"}, {"category": "A"},
+                {"category": "C"}, {"category": "A"}, {"category": "B"}
+            ])
+
+            # Count occurrences of each category
+            result = ds.aggregate(ValueCounter(on="category"))
+            # result: {'value_counter(category)': {'values': ['A', 'B', 'C'], 'counts': [3, 2, 1]}}
+
+            # Using with groupby
+            ds = ray.data.from_items([
+                {"group": "X", "category": "A"}, {"group": "X", "category": "B"},
+                {"group": "Y", "category": "A"}, {"group": "Y", "category": "A"}
+            ])
+            result = ds.groupby("group").aggregate(ValueCounter(on="category")).take_all()
+            # result: [{'group': 'X', 'value_counter(category)': {'values': ['A', 'B'], 'counts': [1, 1]}},
+            #          {'group': 'Y', 'value_counter(category)': {'values': ['A'], 'counts': [2]}}]
+
+    Args:
+        on: The name of the column to count values in. Must be provided.
+        alias_name: Optional name for the resulting column. If not provided,
+            defaults to "value_counter({column_name})".
+    """
+
+    def __init__(
+        self,
+        on: str,
+        alias_name: Optional[str] = None,
+    ):
+        super().__init__(
+            alias_name if alias_name else f"value_counter({str(on)})",
+            on=on,
+            ignore_nulls=True,
+            zero_factory=lambda: {"values": [], "counts": []},
+        )
+
+    def aggregate_block(self, block: Block) -> Dict[str, List]:
+        col_accessor = BlockColumnAccessor.for_column(block[self._target_col_name])
+        return col_accessor.value_counts()
+
+    def combine(
+        self,
+        current_accumulator: Dict[str, List],
+        new_accumulator: Dict[str, List],
+    ) -> Dict[str, List]:
+        values = current_accumulator["values"]
+        counts = current_accumulator["counts"]
+
+        # Build a value → index map once (avoid repeated lookups)
+        value_to_index = {v: i for i, v in enumerate(values)}
+
+        for v_new, c_new in zip(new_accumulator["values"], new_accumulator["counts"]):
+            if v_new in value_to_index:
+                idx = value_to_index[v_new]
+                counts[idx] += c_new
+            else:
+                value_to_index[v_new] = len(values)
+                values.append(v_new)
+                counts.append(c_new)
+
+        return current_accumulator
+
+
 def _null_safe_zero_factory(zero_factory, ignore_nulls: bool):
     """NOTE: PLEASE READ CAREFULLY BEFORE CHANGING
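The combine() step above is plain Python and easy to verify in isolation; here is the same merge logic extracted into a standalone function (merge_value_counts is a hypothetical name, not part of the Ray API):

def merge_value_counts(current, new):
    """Merge two {"values": [...], "counts": [...]} accumulators in place."""
    values, counts = current["values"], current["counts"]
    value_to_index = {v: i for i, v in enumerate(values)}
    for v_new, c_new in zip(new["values"], new["counts"]):
        if v_new in value_to_index:
            counts[value_to_index[v_new]] += c_new
        else:
            value_to_index[v_new] = len(values)
            values.append(v_new)
            counts.append(c_new)
    return current


a = {"values": ["A", "B"], "counts": [3, 2]}
b = {"values": ["B", "C"], "counts": [1, 4]}
assert merge_value_counts(a, b) == {"values": ["A", "B", "C"], "counts": [3, 3, 4]}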

python/ray/data/block.py

Lines changed: 33 additions & 0 deletions

@@ -685,11 +685,44 @@ def unique(self) -> BlockColumn:
         """Returns new column holding only distinct values of the current one"""
         raise NotImplementedError()
 
+    def value_counts(self) -> Dict[str, List]:
+        raise NotImplementedError()
+
+    def hash(self) -> BlockColumn:
+        """
+        Computes a 64-bit hash value for each row in the column.
+
+        Provides a unified hashing method across supported backends.
+        Handles complex types like lists or nested structures by producing a single hash per row.
+        These hashes are useful for downstream operations such as deduplication, grouping, or partitioning.
+
+        Internally, Polars is used to compute row-level hashes even when the original column
+        is backed by Pandas or PyArrow.
+
+        :return: A column of 64-bit integer hashes, returned in the same format as the
+            underlying backend (e.g., Pandas Series or PyArrow Array).
+        """
+        raise NotImplementedError()
+
     def flatten(self) -> BlockColumn:
         """Flattens nested lists merging them into top-level container"""
 
         raise NotImplementedError()
 
+    def dropna(self) -> BlockColumn:
+        raise NotImplementedError()
+
+    def is_composed_of_lists(self, types: Optional[Tuple] = None) -> bool:
+        """
+        Checks whether the column is composed of list-like elements.
+
+        :param types: Optional tuple of backend-specific types to check against.
+            If not provided, defaults to list-like types appropriate
+            for the underlying backend (e.g., PyArrow list types).
+        :return: True if the column is made up of list-like values; False otherwise.
+        """
+        raise NotImplementedError()
+
     def sum_of_squared_diffs_from_mean(
         self,
         *,
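Both backend implementations route hash() through the same Polars calls; a minimal sketch of that shared path, using a toy string column:

import polars as pl

# hash_rows() yields one UInt64 hash per row; casting to Int64 with
# wrap_numerical=True wraps out-of-range values instead of raising,
# matching the cast used by both accessors above.
df = pl.DataFrame({"col": ["a", "b", "a"]})
hashes = df.hash_rows().cast(pl.Int64, wrap_numerical=True)
print(hashes.to_list())  # rows 0 and 2 are equal, so their hashes match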

python/ray/data/preprocessor.py

Lines changed: 25 additions & 2 deletions

@@ -47,6 +47,12 @@ class Preprocessor(abc.ABC):
     implemented method.
     """
 
+    def __init__(self):
+        from ray.data.preprocessors.utils import StatComputationPlan
+
+        self.stat_computation_plan = StatComputationPlan()
+        self.stats_ = {}
+
     class FitStatus(str, Enum):
         """The fit status of preprocessor."""
 
@@ -72,7 +78,7 @@ def _check_has_fitted_state(self):
         used to transform data in newer versions.
         """
 
-        fitted_vars = [v for v in vars(self) if v.endswith("_")]
+        fitted_vars = [v for v in vars(self) if v.endswith("_") and getattr(self, v)]
         return bool(fitted_vars)
 
     def fit_status(self) -> "Preprocessor.FitStatus":
@@ -114,10 +120,15 @@ def fit(self, ds: "Dataset") -> "Preprocessor":
             "All previously fitted state will be overwritten!"
         )
 
-        fitted_ds = self._fit(ds)
+        self.stat_computation_plan.reset()
+        fitted_ds = self._fit(ds)._fit_execute(ds)
         self._fitted = True
         return fitted_ds
 
+    def _fit_execute(self, dataset: "Dataset"):
+        self.stats_ |= self.stat_computation_plan.compute(dataset)
+        return self
+
     def fit_transform(
         self,
         ds: "Dataset",
@@ -373,6 +384,18 @@ def preferred_batch_format(cls) -> BatchFormat:
         """
         return BatchFormat.PANDAS
 
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        # Exclude unpicklable attributes
+        state.pop("stat_computation_plan", None)
+        return state
+
+    def __setstate__(self, state):
+        from ray.data.preprocessors.utils import StatComputationPlan
+
+        self.__dict__.update(state)
+        self.stat_computation_plan = StatComputationPlan()
+
     @DeveloperAPI
     def serialize(self) -> str:
         """Return this preprocessor serialized as a string.

python/ray/data/preprocessors/chain.py

Lines changed: 1 addition & 0 deletions

@@ -66,6 +66,7 @@ def fit_status(self):
         return Preprocessor.FitStatus.NOT_FITTABLE
 
     def __init__(self, *preprocessors: Preprocessor):
+        super().__init__()
         self.preprocessors = preprocessors
 
     def _fit(self, ds: "Dataset") -> Preprocessor:
