[CHORE] Turn v0.3 deprecations into breaking changes (#2663)
Two things being removed:
- `read_delta_lake` (the deprecated alias of `read_deltalake`)
- Tuple arguments to `DataFrame.agg` and `GroupedDataFrame.agg`

A migration sketch follows the changed-files summary below.
kevinzwang authored Aug 15, 2024
1 parent bdf8aca commit c7848ae
Showing 11 changed files with 49 additions and 87 deletions.
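
For anyone migrating, here is a minimal before/after sketch. The column values and table path are illustrative only, not taken from this commit; the `DataFrame.agg` and `read_deltalake` APIs are the ones this commit changes.

```python
import daft
from daft import col

df = daft.from_pydict({"a": [1, 2, 3], "b": [True, False, True]})

# Before (deprecated, now removed): tuple arguments to agg
# df = df.agg([("a", "sum"), ("b", "count")])

# After: aggregation expressions
df = df.agg(col("a").sum(), col("b").count())

# Before (deprecated alias, now removed):
# df = daft.read_delta_lake("path/to/table")  # hypothetical path
# After:
# df = daft.read_deltalake("path/to/table")
```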
2 changes: 0 additions & 2 deletions daft/__init__.py
@@ -80,7 +80,6 @@ def refresh_logger() -> None:
     from_glob_path,
     read_csv,
     read_deltalake,
-    read_delta_lake,
     read_hudi,
     read_iceberg,
     read_json,
@@ -107,7 +106,6 @@ def refresh_logger() -> None:
     "read_hudi",
     "read_iceberg",
     "read_deltalake",
-    "read_delta_lake",
     "read_sql",
     "read_lance",
     "DataCatalogType",
66 changes: 31 additions & 35 deletions daft/dataframe/dataframe.py
@@ -7,6 +7,7 @@
 import io
 import os
 import pathlib
+import typing
 import warnings
 from dataclasses import dataclass
 from functools import partial, reduce
@@ -1862,16 +1863,15 @@ def transform(self, func: Callable[..., "DataFrame"], *args: Any, **kwargs: Any)
         )
         return result
 
-    def _agg(self, to_agg: List[Expression], group_by: Optional[ExpressionsProjection] = None) -> "DataFrame":
-        builder = self._builder.agg(to_agg, list(group_by) if group_by is not None else None)
+    def _agg(
+        self,
+        to_agg: Iterable[Expression],
+        group_by: Optional[ExpressionsProjection] = None,
+    ) -> "DataFrame":
+        builder = self._builder.agg(list(to_agg), list(group_by) if group_by is not None else None)
         return DataFrame(builder)
 
-    def _agg_tuple_to_expression(self, agg_tuple: Tuple[ColumnInputType, str]) -> Expression:
-        expr, op = agg_tuple
-
-        if isinstance(expr, str):
-            expr = col(expr)
-
+    def _map_agg_string_to_expr(self, expr: Expression, op: str) -> Expression:
         if op == "sum":
             return expr.sum()
         elif op == "count":
@@ -1891,30 +1891,6 @@ def _agg_tuple_to_expression(self, agg_tuple: Tuple[ColumnInputType, str]) -> Expression:
 
         raise NotImplementedError(f"Aggregation {op} is not implemented.")
 
-    def _agg_inputs_to_expressions(
-        self, to_agg: Tuple[Union[Expression, Iterable[Expression]], ...]
-    ) -> List[Expression]:
-        def is_agg_column_input(x: Any) -> bool:
-            # aggs currently support Expression or tuple of (ColumnInputType, str) [deprecated]
-            if isinstance(x, Expression):
-                return True
-            if isinstance(x, tuple) and len(x) == 2:
-                tuple_type = list(map(type, x))
-                return tuple_type == [Expression, str] or tuple_type == [str, str]
-            return False
-
-        columns: Iterable[Expression] = to_agg[0] if len(to_agg) == 1 and not is_agg_column_input(to_agg[0]) else to_agg  # type: ignore
-
-        if any(isinstance(col, tuple) for col in columns):
-            warnings.warn(
-                "Tuple arguments in aggregations is deprecated and will be removed "
-                "in Daft v0.3. Please use aggregation expressions instead.",
-                DeprecationWarning,
-            )
-            return [self._agg_tuple_to_expression(col) if isinstance(col, tuple) else col for col in columns]  # type: ignore
-        else:
-            return list(columns)
-
     def _apply_agg_fn(
         self,
         fn: Callable[[Expression], Expression],
@@ -2058,7 +2034,17 @@ def agg(self, *to_agg: Union[Expression, Iterable[Expression]]) -> "DataFrame":
         Returns:
             DataFrame: DataFrame with aggregated results
         """
-        return self._agg(self._agg_inputs_to_expressions(to_agg), group_by=None)
+        to_agg_list = (
+            list(to_agg[0])
+            if (len(to_agg) == 1 and not isinstance(to_agg[0], Expression))
+            else list(typing.cast(Tuple[Expression], to_agg))
+        )
+
+        for expr in to_agg_list:
+            if not isinstance(expr, Expression):
+                raise ValueError(f"DataFrame.agg() only accepts expression type, received: {type(expr)}")
+
+        return self._agg(to_agg_list, group_by=None)
 
     @DataframePublicAPI
     def groupby(self, *group_by: ManyColumnsInputType) -> "GroupedDataFrame":
@@ -2151,7 +2137,7 @@ def pivot(
         """
         group_by_expr = self._column_inputs_to_expressions(group_by)
         [pivot_col_expr, value_col_expr] = self._column_inputs_to_expressions([pivot_col, value_col])
-        agg_expr = self._agg_tuple_to_expression((value_col_expr, agg_fn))
+        agg_expr = self._map_agg_string_to_expr(value_col_expr, agg_fn)
 
         if names is None:
             names = self.select(pivot_col_expr).distinct().to_pydict()[pivot_col_expr.name()]
@@ -2705,7 +2691,17 @@ def agg(self, *to_agg: Union[Expression, Iterable[Expression]]) -> "DataFrame":
         Returns:
             DataFrame: DataFrame with grouped aggregations
         """
-        return self.df._agg(self.df._agg_inputs_to_expressions(to_agg), group_by=self.group_by)
+        to_agg_list = (
+            list(to_agg[0])
+            if (len(to_agg) == 1 and not isinstance(to_agg[0], Expression))
+            else list(typing.cast(Tuple[Expression], to_agg))
+        )
+
+        for expr in to_agg_list:
+            if not isinstance(expr, Expression):
+                raise ValueError(f"GroupedDataFrame.agg() only accepts expression type, received: {type(expr)}")
+
+        return self.df._agg(to_agg_list, group_by=self.group_by)
 
     def map_groups(self, udf: Expression) -> "DataFrame":
         """Apply a user-defined function to each group. The name of the resultant column will default to the name of the first input column.
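A note on the two new `agg` bodies above: `DataFrame.agg` and `GroupedDataFrame.agg` now normalize their arguments identically, accepting either varargs of expressions or a single iterable of expressions, and raising `ValueError` (instead of the old `DeprecationWarning`) for anything else. A minimal sketch of the two accepted call styles (column name and alias are illustrative):

```python
import daft
from daft import col

df = daft.from_pydict({"a": [1, 2, 3]})

# Style 1: varargs of aggregation expressions
df.agg(col("a").sum(), col("a").mean().alias("avg_a")).collect()

# Style 2: a single iterable (e.g. a list) of aggregation expressions
df.agg([col("a").sum(), col("a").mean().alias("avg_a")]).collect()
```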
3 changes: 1 addition & 2 deletions daft/io/__init__.py
@@ -12,7 +12,7 @@
     set_io_pool_num_threads,
 )
 from daft.io._csv import read_csv
-from daft.io._delta_lake import read_deltalake, read_delta_lake
+from daft.io._delta_lake import read_deltalake
 from daft.io._hudi import read_hudi
 from daft.io._iceberg import read_iceberg
 from daft.io._json import read_json
@@ -45,7 +45,6 @@ def _set_linux_cert_paths():
     "read_hudi",
     "read_iceberg",
     "read_deltalake",
-    "read_delta_lake",
     "read_lance",
     "read_sql",
     "IOConfig",
13 changes: 0 additions & 13 deletions daft/io/_delta_lake.py
@@ -1,6 +1,5 @@
 # isort: dont-add-import: from __future__ import annotations
 
-import warnings
 from typing import Optional, Union
 
 from daft import context
@@ -17,18 +16,6 @@
     _UNITY_CATALOG_AVAILABLE = False
 
 
-def read_delta_lake(
-    table: Union[str, DataCatalogTable],
-    io_config: Optional["IOConfig"] = None,
-    _multithreaded_io: Optional[bool] = None,
-) -> DataFrame:
-    warnings.warn(
-        "read_delta_lake has been renamed to read_deltalake and will be removed in Daft v0.3",
-        DeprecationWarning,
-    )
-    return read_deltalake(table, io_config, _multithreaded_io)
-
-
 @PublicAPI
 def read_deltalake(
     table: Union[str, DataCatalogTable, "UnityCatalogTable"],
2 changes: 1 addition & 1 deletion daft/unity_catalog/unity_catalog.py
@@ -23,7 +23,7 @@ class UnityCatalog:
     >>> cat = UnityCatalog("https://<databricks_workspace_id>.cloud.databricks.com", token="my-token")
     >>> table = cat.load_table("my_catalog.my_schema.my_table")
-    >>> df = daft.read_delta_lake(table)
+    >>> df = daft.read_deltalake(table)
     """
 
     def __init__(self, endpoint: str, token: str | None = None):
2 changes: 1 addition & 1 deletion docs/source/user_guide/fotw/fotw-000-data-access.ipynb
@@ -1148,7 +1148,7 @@
    "source": [
     "### Delta Lake\n",
     "\n",
-    "You can easily read Delta Lake tables using the `read_delta_lake()` method."
+    "You can easily read Delta Lake tables using the `read_deltalake()` method."
    ]
   },
   {
2 changes: 1 addition & 1 deletion docs/source/user_guide/integrations/unity-catalog.rst
@@ -46,7 +46,7 @@ Loading a Daft Dataframe from a Delta Lake table in Unity Catalog
     unity_table = unity.load_table("my_catalog_name.my_schema_name.my_table_name")
 
-    df = daft.read_delta_lake(unity_table)
+    df = daft.read_deltalake(unity_table)
     df.show()
 
 Any subsequent filter operations on the Daft ``df`` DataFrame object will be correctly optimized to take advantage of DeltaLake features
16 changes: 0 additions & 16 deletions tests/dataframe/test_aggregations.py
@@ -385,22 +385,6 @@ def test_groupby_result_partitions_smaller_than_input(shuffle_aggregation_defaul
     assert df.num_partitions() == min(min_partitions, partition_size)
 
 
-def test_agg_deprecation():
-    with pytest.deprecated_call():
-        df = daft.from_pydict({"a": [1, 2, 3], "b": [True, False, True]})
-        df = df.agg([("a", "sum"), ("b", "count")])
-        df.collect()
-
-    assert df.to_pydict() == {"a": [6], "b": [3]}
-
-    with pytest.deprecated_call():
-        df = daft.from_pydict({"a": [1, 2, 3], "b": [True, False, True]})
-        df = df.groupby("b").agg([("a", "sum")])
-        df.collect()
-
-    assert df.to_pydict() == {"b": [True, False], "a": [4, 2]}
-
-
 @pytest.mark.parametrize("repartition_nparts", [1, 2, 4])
 def test_agg_any_value(make_df, repartition_nparts):
     daft_df = make_df(
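The deprecation test above is deleted without a direct replacement in this diff. A hypothetical follow-up test for the new fail-fast behavior (not part of this commit) might look like:

```python
import pytest
import daft


def test_agg_tuple_rejected():
    df = daft.from_pydict({"a": [1, 2, 3], "b": [True, False, True]})
    # Tuple arguments now raise eagerly instead of emitting a DeprecationWarning.
    with pytest.raises(ValueError):
        df.agg([("a", "sum"), ("b", "count")])
```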
10 changes: 5 additions & 5 deletions tutorials/mnist.ipynb
@@ -854,11 +854,11 @@
     "    .with_column(\"correct\", (col(\"model_classification\") == col(\"label\")).cast(DataType.int64())) \\\n",
     "    .with_column(\"wrong\", (col(\"model_classification\") != col(\"label\")).cast(DataType.int64())) \\\n",
     "    .groupby(col(\"label\")) \\\n",
-    "    .agg([\n",
-    "        (col(\"label\").alias(\"num_rows\"), \"count\"),\n",
-    "        (col(\"correct\"), \"sum\"),\n",
-    "        (col(\"wrong\"), \"sum\"),\n",
-    "    ]) \\\n",
+    "    .agg(\n",
+    "        col(\"label\").count().alias(\"num_rows\"),\n",
+    "        col(\"correct\").sum(),\n",
+    "        col(\"wrong\").sum(),\n",
+    "    ) \\\n",
     "    .sort(col(\"label\"))\n",
     "\n",
     "analysis_df.show()"
2 changes: 1 addition & 1 deletion tutorials/talks_and_demos/data-ai-summit-2024.ipynb
@@ -643,7 +643,7 @@
    }
   ],
   "source": [
-    "read_df = daft.read_delta_lake(\"my_table.delta_lake\")\n",
+    "read_df = daft.read_deltalake(\"my_table.delta_lake\")\n",
    "read_df"
   ]
  },
18 changes: 8 additions & 10 deletions tutorials/talks_and_demos/pydata_global_2023.ipynb
@@ -283,16 +283,14 @@
     "    lineitem.where(col(\"L_SHIPDATE\") <= datetime.date(1998, 9, 2))\n",
     "    .groupby(col(\"L_RETURNFLAG\"), col(\"L_LINESTATUS\"))\n",
     "    .agg(\n",
-    "        [\n",
-    "            (col(\"L_QUANTITY\").alias(\"sum_qty\"), \"sum\"),\n",
-    "            (col(\"L_EXTENDEDPRICE\").alias(\"sum_base_price\"), \"sum\"),\n",
-    "            (discounted_price.alias(\"sum_disc_price\"), \"sum\"),\n",
-    "            (taxed_discounted_price.alias(\"sum_charge\"), \"sum\"),\n",
-    "            (col(\"L_QUANTITY\").alias(\"avg_qty\"), \"mean\"),\n",
-    "            (col(\"L_EXTENDEDPRICE\").alias(\"avg_price\"), \"mean\"),\n",
-    "            (col(\"L_DISCOUNT\").alias(\"avg_disc\"), \"mean\"),\n",
-    "            (col(\"L_QUANTITY\").alias(\"count_order\"), \"count\"),\n",
-    "        ]\n",
+    "        col(\"L_QUANTITY\").alias(\"sum_qty\").sum(),\n",
+    "        col(\"L_EXTENDEDPRICE\").alias(\"sum_base_price\").sum(),\n",
+    "        discounted_price.alias(\"sum_disc_price\").sum(),\n",
+    "        taxed_discounted_price.alias(\"sum_charge\").sum(),\n",
+    "        col(\"L_QUANTITY\").alias(\"avg_qty\").mean(),\n",
+    "        col(\"L_EXTENDEDPRICE\").alias(\"avg_price\").mean(),\n",
+    "        col(\"L_DISCOUNT\").alias(\"avg_disc\").mean(),\n",
+    "        col(\"L_QUANTITY\").alias(\"count_order\").count(),\n",
     "    )\n",
     "    .sort([\"L_RETURNFLAG\", \"L_LINESTATUS\"])\n",
     ")"
