Skip to content

Commit 215281d

Browse files
zero323 authored and ueshin committed
[SPARK-20830][PYSPARK][SQL] Add explode_outer and posexplode_outer
## What changes were proposed in this pull request?

Add Python wrappers for `o.a.s.sql.functions.explode_outer` and `o.a.s.sql.functions.posexplode_outer`.

## How was this patch tested?

Unit tests, doctests.

Author: zero323 <[email protected]>

Closes #18049 from zero323/SPARK-20830.
1 parent ba78514 commit 215281d

File tree

2 files changed

+83
-2
lines changed

2 files changed

+83
-2
lines changed

python/pyspark/sql/functions.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1727,6 +1727,71 @@ def posexplode(col):
17271727
return Column(jc)
17281728

17291729

1730+
@since(2.3)
def explode_outer(col):
    """Generates one output row per element of the given array or map column.
    In contrast to explode, a null or empty array/map still yields a row, with null
    in place of the element.

    >>> df = spark.createDataFrame(
    ...     [(1, ["foo", "bar"], {"x": 1.0}), (2, [], {}), (3, None, None)],
    ...     ("id", "an_array", "a_map")
    ... )
    >>> df.select("id", "an_array", explode_outer("a_map")).show()
    +---+----------+----+-----+
    | id|  an_array| key|value|
    +---+----------+----+-----+
    |  1|[foo, bar]|   x|  1.0|
    |  2|        []|null| null|
    |  3|      null|null| null|
    +---+----------+----+-----+

    >>> df.select("id", "a_map", explode_outer("an_array")).show()
    +---+-------------+----+
    | id|        a_map| col|
    +---+-------------+----+
    |  1|Map(x -> 1.0)| foo|
    |  1|Map(x -> 1.0)| bar|
    |  2|        Map()|null|
    |  3|         null|null|
    +---+-------------+----+
    """
    # Look up `functions` through the active SparkContext's `_jvm` handle, then
    # apply `explode_outer` to the converted column and wrap the result.
    jvm_functions = SparkContext._active_spark_context._jvm.functions
    exploded = jvm_functions.explode_outer(_to_java_column(col))
    return Column(exploded)
1761+
1762+
1763+
@since(2.3)
def posexplode_outer(col):
    """Generates one output row, including its position, per element of the given
    array or map column.
    In contrast to posexplode, a null or empty array/map still yields the row (null, null).

    >>> df = spark.createDataFrame(
    ...     [(1, ["foo", "bar"], {"x": 1.0}), (2, [], {}), (3, None, None)],
    ...     ("id", "an_array", "a_map")
    ... )
    >>> df.select("id", "an_array", posexplode_outer("a_map")).show()
    +---+----------+----+----+-----+
    | id|  an_array| pos| key|value|
    +---+----------+----+----+-----+
    |  1|[foo, bar]|   0|   x|  1.0|
    |  2|        []|null|null| null|
    |  3|      null|null|null| null|
    +---+----------+----+----+-----+
    >>> df.select("id", "a_map", posexplode_outer("an_array")).show()
    +---+-------------+----+----+
    | id|        a_map| pos| col|
    +---+-------------+----+----+
    |  1|Map(x -> 1.0)|   0| foo|
    |  1|Map(x -> 1.0)|   1| bar|
    |  2|        Map()|null|null|
    |  3|         null|null|null|
    +---+-------------+----+----+
    """
    # Look up `functions` through the active SparkContext's `_jvm` handle, then
    # apply `posexplode_outer` to the converted column and wrap the result.
    jvm_functions = SparkContext._active_spark_context._jvm.functions
    exploded = jvm_functions.posexplode_outer(_to_java_column(col))
    return Column(exploded)
1793+
1794+
17301795
@ignore_unicode_prefix
17311796
@since(1.6)
17321797
def get_json_object(col, path):

python/pyspark/sql/tests.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -258,8 +258,12 @@ def test_column_name_encoding(self):
258258
self.assertTrue(isinstance(columns[1], str))
259259

260260
def test_explode(self):
261-
from pyspark.sql.functions import explode
262-
d = [Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"})]
261+
from pyspark.sql.functions import explode, explode_outer, posexplode_outer
262+
d = [
263+
Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"}),
264+
Row(a=1, intlist=[], mapfield={}),
265+
Row(a=1, intlist=None, mapfield=None),
266+
]
263267
rdd = self.sc.parallelize(d)
264268
data = self.spark.createDataFrame(rdd)
265269

@@ -272,6 +276,18 @@ def test_explode(self):
272276
self.assertEqual(result[0][0], "a")
273277
self.assertEqual(result[0][1], "b")
274278

279+
result = [tuple(x) for x in data.select(posexplode_outer("intlist")).collect()]
280+
self.assertEqual(result, [(0, 1), (1, 2), (2, 3), (None, None), (None, None)])
281+
282+
result = [tuple(x) for x in data.select(posexplode_outer("mapfield")).collect()]
283+
self.assertEqual(result, [(0, 'a', 'b'), (None, None, None), (None, None, None)])
284+
285+
result = [x[0] for x in data.select(explode_outer("intlist")).collect()]
286+
self.assertEqual(result, [1, 2, 3, None, None])
287+
288+
result = [tuple(x) for x in data.select(explode_outer("mapfield")).collect()]
289+
self.assertEqual(result, [('a', 'b'), (None, None), (None, None)])
290+
275291
def test_and_in_expression(self):
276292
self.assertEqual(4, self.df.filter((self.df.key <= 10) & (self.df.value <= "2")).count())
277293
self.assertRaises(ValueError, lambda: (self.df.key <= 10) and (self.df.value <= "2"))

0 commit comments

Comments
 (0)