1 change: 1 addition & 0 deletions python/pyspark/sql/_typing.pyi
@@ -36,6 +36,7 @@ from pyspark.sql.column import Column
 
 ColumnOrName = Union[Column, str]
 ColumnOrName_ = TypeVar("ColumnOrName_", bound=ColumnOrName)
+ColumnOrNameOrOrdinal = Union[Column, str, int]
 DecimalLiteral = decimal.Decimal
 DateTimeLiteral = Union[datetime.datetime, datetime.date]
 LiteralType = PrimitiveType
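For orientation, a minimal usage sketch (not part of the diff): a parameter annotated with the new ColumnOrNameOrOrdinal alias may be a Column expression, a column name, or a 1-based integer ordinal. The snippet assumes an active SparkSession and relies on the groupBy support added later in this change.

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([("Alice", 2), ("Bob", 5)], ["name", "age"])

    df.groupBy(df.name).count()  # Column expression
    df.groupBy("name").count()   # str: column name
    df.groupBy(1).count()        # int: 1-based column ordinal, added by this change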
2 changes: 2 additions & 0 deletions python/pyspark/sql/connect/_typing.py
@@ -37,6 +37,8 @@
 
 ColumnOrName = Union[Column, str]
 
+ColumnOrNameOrOrdinal = Union[Column, str, int]
+
 PrimitiveType = Union[bool, float, int, str]
 
 OptionalPrimitiveType = Optional[PrimitiveType]
9 changes: 8 additions & 1 deletion python/pyspark/sql/connect/dataframe.py
@@ -85,6 +85,7 @@
 if TYPE_CHECKING:
     from pyspark.sql.connect._typing import (
         ColumnOrName,
+        ColumnOrNameOrOrdinal,
         LiteralType,
         PrimitiveType,
         OptionalPrimitiveType,
@@ -476,7 +477,7 @@ def first(self) -> Optional[Row]:
 
     first.__doc__ = PySparkDataFrame.first.__doc__
 
-    def groupBy(self, *cols: "ColumnOrName") -> GroupedData:
+    def groupBy(self, *cols: "ColumnOrNameOrOrdinal") -> GroupedData:
         if len(cols) == 1 and isinstance(cols[0], list):
             cols = cols[0]
 
@@ -486,6 +487,12 @@ def groupBy(self, *cols: "ColumnOrNameOrOrdinal") -> GroupedData:
                 _cols.append(c)
             elif isinstance(c, str):
                 _cols.append(self[c])
+            elif isinstance(c, int) and not isinstance(c, bool):
+                # TODO: should introduce dedicated error class
+                if c < 1:
+                    raise IndexError(f"Column ordinal must be positive but got {c}")
+                # ordinal is 1-based
+                _cols.append(self[c - 1])
             else:
                 raise PySparkTypeError(
                     error_class="NOT_COLUMN_OR_STR",
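The new elif branch above is the whole ordinal story on the Connect client: a positive, 1-based int is translated into the corresponding column, and bool is rejected explicitly because bool subclasses int. A standalone sketch of that rule (the helper name is hypothetical, not part of the diff; the same logic reappears in _jcols_ordinal in the classic DataFrame below):

    from pyspark.sql import Column, DataFrame

    def ordinal_to_column(df: DataFrame, c: int) -> Column:
        # Translate a 1-based column ordinal into a Column, mirroring the new groupBy branch.
        if isinstance(c, bool) or not isinstance(c, int):
            raise TypeError(f"expected an integer ordinal, got {type(c).__name__}")
        if c < 1:
            raise IndexError(f"Column ordinal must be positive but got {c}")
        return df[c - 1]  # ordinals are 1-based; DataFrame.__getitem__ is 0-based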
66 changes: 60 additions & 6 deletions python/pyspark/sql/dataframe.py
@@ -67,7 +67,12 @@
 if TYPE_CHECKING:
     from pyspark._typing import PrimitiveType
     from pyspark.pandas.frame import DataFrame as PandasOnSparkDataFrame
-    from pyspark.sql._typing import ColumnOrName, LiteralType, OptionalPrimitiveType
+    from pyspark.sql._typing import (
+        ColumnOrName,
+        ColumnOrNameOrOrdinal,
+        LiteralType,
+        OptionalPrimitiveType,
+    )
     from pyspark.sql.context import SQLContext
     from pyspark.sql.session import SparkSession
     from pyspark.sql.group import GroupedData
@@ -2844,6 +2849,26 @@ def _jcols(self, *cols: "ColumnOrName") -> JavaObject:
             cols = cols[0]
         return self._jseq(cols, _to_java_column)
 
+    def _jcols_ordinal(self, *cols: "ColumnOrNameOrOrdinal") -> JavaObject:
+        """Return a JVM Seq of Columns from a list of Column or column names or column ordinals.
+
+        If `cols` has only one list in it, cols[0] will be used as the list.
+        """
+        if len(cols) == 1 and isinstance(cols[0], list):
+            cols = cols[0]
+
+        _cols = []
+        for c in cols:
+            if isinstance(c, int) and not isinstance(c, bool):
+                # TODO: should introduce dedicated error class
+                if c < 1:
+                    raise IndexError(f"Column ordinal must be positive but got {c}")
+                # ordinal is 1-based
+                _cols.append(self[c - 1])
+            else:
+                _cols.append(c)  # type: ignore[arg-type]
+        return self._jseq(_cols, _to_java_column)
+
     def _sort_cols(
         self, cols: Sequence[Union[str, Column, List[Union[str, Column]]]], kwargs: Dict[str, Any]
     ) -> JavaObject:
@@ -3513,14 +3538,14 @@ def filter(self, condition: "ColumnOrName") -> "DataFrame":
         return DataFrame(jdf, self.sparkSession)
 
     @overload
-    def groupBy(self, *cols: "ColumnOrName") -> "GroupedData":
+    def groupBy(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":
         ...
 
     @overload
-    def groupBy(self, __cols: Union[List[Column], List[str]]) -> "GroupedData":
+    def groupBy(self, __cols: Union[List[Column], List[str], List[int]]) -> "GroupedData":
         ...
 
-    def groupBy(self, *cols: "ColumnOrName") -> "GroupedData":  # type: ignore[misc]
+    def groupBy(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":  # type: ignore[misc]
         """Groups the :class:`DataFrame` using the specified columns,
         so we can run aggregation on them. See :class:`GroupedData`
         for all the available aggregate functions.
@@ -3532,18 +3557,26 @@ def groupBy(self, *cols: "ColumnOrName") -> "GroupedData": # type: ignore[misc]
         .. versionchanged:: 3.4.0
             Supports Spark Connect.
 
+        .. versionchanged:: 4.0.0
+            Supports column ordinal.
+
         Parameters
         ----------
         cols : list, str or :class:`Column`
             columns to group by.
             Each element should be a column name (string) or an expression (:class:`Column`)
-            or list of them.
+            or a column ordinal (int, 1-based) or list of them.
 
         Returns
         -------
         :class:`GroupedData`
             Grouped data by given columns.
 
+        Notes
+        -----
+        A column ordinal starts from 1, which is different from the
+        0-based :meth:`__getitem__`.
+
         Examples
         --------
         >>> df = spark.createDataFrame([
@@ -3578,6 +3611,16 @@ def groupBy(self, *cols: "ColumnOrName") -> "GroupedData": # type: ignore[misc]
         |  Bob|       5|
         +-----+--------+
 
+        Also group-by 'name', but using the column ordinal.
+
+        >>> df.groupBy(2).max().sort("name").show()
+        +-----+--------+
+        | name|max(age)|
+        +-----+--------+
+        |Alice|       2|
+        |  Bob|       5|
+        +-----+--------+
+
         Group-by 'name' and 'age', and calculate the number of rows in each group.
 
         >>> df.groupBy(["name", df.age]).count().sort("name", "age").show()
@@ -3588,8 +3631,19 @@ def groupBy(self, *cols: "ColumnOrName") -> "GroupedData": # type: ignore[misc]
         |  Bob|  2|    2|
         |  Bob|  5|    1|
         +-----+---+-----+
+
+        Also group-by 'name' and 'age', but using the column ordinal.
+
+        >>> df.groupBy([df.name, 1]).count().sort("name", "age").show()
+        +-----+---+-----+
+        | name|age|count|
+        +-----+---+-----+
+        |Alice|  2|    1|
+        |  Bob|  2|    2|
+        |  Bob|  5|    1|
+        +-----+---+-----+
         """
-        jgd = self._jdf.groupBy(self._jcols(*cols))
+        jgd = self._jdf.groupBy(self._jcols_ordinal(*cols))
         from pyspark.sql.group import GroupedData
 
         return GroupedData(jgd, self)
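Before the tests, a short usage sketch of how the new DataFrame ordinal lines up with SQL's GROUP BY <n> (assuming an active SparkSession named spark; the expected behavior is what the docstring above and the new tests assert):

    from pyspark.sql import functions as sf

    df = spark.createDataFrame([("Alice", 2), ("Bob", 2), ("Bob", 5)], ["name", "age"])
    df.createOrReplaceTempView("people")

    # Both group by the first (1-based) column position, i.e. 'name'.
    spark.sql("SELECT name, sum(age) FROM people GROUP BY 1").show()
    df.groupBy(1).agg(sf.sum("age")).show()

    # Ordinals must be positive: 0 or a negative value raises IndexError on the Python side.
    try:
        df.groupBy(0).count()
    except IndexError as e:
        print(e)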
61 changes: 61 additions & 0 deletions python/pyspark/sql/tests/test_group.py
@@ -16,7 +16,9 @@
 #
 
 from pyspark.sql import Row
+from pyspark.sql import functions as sf
 from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.testing import assertDataFrameEqual, assertSchemaEqual
 
 
 class GroupTestsMixin:
@@ -35,6 +37,65 @@ def test_aggregator(self):
         # test deprecated countDistinct
         self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0])
 
+    def test_group_by_ordinal(self):
+        spark = self.spark
+        df = spark.createDataFrame(
+            [
+                (1, 1),
+                (1, 2),
+                (2, 1),
+                (2, 2),
+                (3, 1),
+                (3, 2),
+            ],
+            ["a", "b"],
+        )
+
+        with self.tempView("v"):
+            df.createOrReplaceTempView("v")
+
+            # basic case
+            df1 = spark.sql("select a, sum(b) from v group by 1;")
+            df2 = df.groupBy(1).agg(sf.sum("b"))
+            assertSchemaEqual(df1.schema, df2.schema)
+            assertDataFrameEqual(df1, df2)
+
+            # constant case
+            df1 = spark.sql("select 1, 2, sum(b) from v group by 1, 2;")
+            df2 = df.select(sf.lit(1), sf.lit(2), "b").groupBy(1, 2).agg(sf.sum("b"))
+            assertSchemaEqual(df1.schema, df2.schema)
+            assertDataFrameEqual(df1, df2)
+
+            # duplicate group by column
+            df1 = spark.sql("select a, 1, sum(b) from v group by a, 1;")
+            df2 = df.select("a", sf.lit(1), "b").groupBy("a", 2).agg(sf.sum("b"))
+            assertSchemaEqual(df1.schema, df2.schema)
+            assertDataFrameEqual(df1, df2)
+
+            df1 = spark.sql("select a, 1, sum(b) from v group by 1, 2;")
+            df2 = df.select("a", sf.lit(1), "b").groupBy(1, 2).agg(sf.sum("b"))
+            assertSchemaEqual(df1.schema, df2.schema)
+            assertDataFrameEqual(df1, df2)
+
+            # group by a non-aggregate expression's ordinal
+            df1 = spark.sql("select a, b + 2, count(2) from v group by a, 2;")
+            df2 = df.select("a", df.b + 2).groupBy(1, 2).agg(sf.count(sf.lit(2)))
+            assertSchemaEqual(df1.schema, df2.schema)
+            assertDataFrameEqual(df1, df2)
+
+            # negative cases: ordinal out of range
+            with self.assertRaises(IndexError):
+                df.groupBy(0).agg(sf.sum("b"))
+
+            with self.assertRaises(IndexError):
+                df.groupBy(-1).agg(sf.sum("b"))
+
+            with self.assertRaises(IndexError):
+                df.groupBy(3).agg(sf.sum("b"))
+
+            with self.assertRaises(IndexError):
+                df.groupBy(10).agg(sf.sum("b"))
+
 
 class GroupTests(GroupTestsMixin, ReusedSQLTestCase):
     pass
2 changes: 1 addition & 1 deletion python/pyspark/sql/tests/typing/test_dataframe.yml
@@ -71,7 +71,7 @@
     df.groupby(["name", "age"])
     df.groupBy([col("name"), col("age")])
    df.groupby([col("name"), col("age")])
-    df.groupBy(["name", col("age")])  # E: Argument 1 to "groupBy" of "DataFrame" has incompatible type "List[object]"; expected "Union[List[Column], List[str]]" [arg-type]
+    df.groupBy(["name", col("age")])  # E: Argument 1 to "groupBy" of "DataFrame" has incompatible type "List[object]"; expected "Union[List[Column], List[str], List[int]]" [arg-type]
 
 
 - case: rollup