diff --git a/dask_sql/physical/rel/logical/limit.py b/dask_sql/physical/rel/logical/limit.py index 58cd68fe8..09de05d67 100644 --- a/dask_sql/physical/rel/logical/limit.py +++ b/dask_sql/physical/rel/logical/limit.py @@ -5,7 +5,6 @@ from dask_sql.datacontainer import DataContainer from dask_sql.physical.rel.base import BaseRelPlugin from dask_sql.physical.rex import RexConverter -from dask_sql.physical.utils.map import map_on_partition_index if TYPE_CHECKING: import dask_sql @@ -80,17 +79,17 @@ def _apply_offset(self, df: dd.DataFrame, offset: int, end: int) -> dd.DataFrame # its pandas representation at this point and we can calculate the cumsum # (which is not possible on the dask object). Recalculating it should not cost # us much, as we assume the number of partitions is rather small. - def select_from_to(df, partition_index, partition_borders): - partition_borders = partition_borders.cumsum().to_dict() + def select_from_to(index, partition, partition_borders): + partition_borders = partition_borders.compute().cumsum().to_dict() this_partition_border_left = ( - partition_borders[partition_index - 1] if partition_index > 0 else 0 + partition_borders[index - 1] if index > 0 else 0 ) - this_partition_border_right = partition_borders[partition_index] + this_partition_border_right = partition_borders[index] if (end and end < this_partition_border_left) or ( offset and offset >= this_partition_border_right ): - return df.iloc[0:0] + return partition.compute().iloc[0:0] from_index = max(offset - this_partition_border_left, 0) if offset else 0 to_index = ( @@ -99,10 +98,14 @@ def select_from_to(df, partition_index, partition_borders): else this_partition_border_right ) - this_partition_border_left - return df.iloc[from_index:to_index] + return partition.compute().iloc[from_index:to_index] # (b) Now we just need to apply the function on every partition # We do this via the delayed interface, which seems the easiest one. - return map_on_partition_index( - df, select_from_to, partition_borders, meta=df._meta + return dd.from_map( + select_from_to, + list(range(df.npartitions)), + list(df.partitions), + args=(partition_borders,), + meta=df._meta, ) diff --git a/dask_sql/physical/utils/map.py b/dask_sql/physical/utils/map.py deleted file mode 100644 index 791342ccc..000000000 --- a/dask_sql/physical/utils/map.py +++ /dev/null @@ -1,17 +0,0 @@ -from typing import Any, Callable - -import dask -import dask.dataframe as dd - - -def map_on_partition_index( - df: dd.DataFrame, f: Callable, *args: Any, **kwargs: Any -) -> dd.DataFrame: - meta = kwargs.pop("meta", None) - return dd.from_delayed( - [ - dask.delayed(f)(partition, partition_number, *args, **kwargs) - for partition_number, partition in enumerate(df.partitions) - ], - meta=meta, - )