
Commit ca55dc9

[SPARK-7152][SQL] Add a Column expression for partition ID.

Author: Reynold Xin <[email protected]>

Closes #5705 from rxin/df-pid and squashes the following commits:

401018f [Reynold Xin] [SPARK-7152][SQL] Add a Column expression for partition ID.
1 parent 9a5bbe0 commit ca55dc9

File tree: 5 files changed, +110 -19 lines


python/pyspark/sql/functions.py  (21 additions, 9 deletions)

@@ -75,6 +75,20 @@ def _(col):
 __all__.sort()


+def approxCountDistinct(col, rsd=None):
+    """Returns a new :class:`Column` for approximate distinct count of ``col``.
+
+    >>> df.agg(approxCountDistinct(df.age).alias('c')).collect()
+    [Row(c=2)]
+    """
+    sc = SparkContext._active_spark_context
+    if rsd is None:
+        jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col))
+    else:
+        jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col), rsd)
+    return Column(jc)
+
+
 def countDistinct(col, *cols):
     """Returns a new :class:`Column` for distinct count of ``col`` or ``cols``.

@@ -89,18 +103,16 @@ def countDistinct(col, *cols):
     return Column(jc)


-def approxCountDistinct(col, rsd=None):
-    """Returns a new :class:`Column` for approximate distinct count of ``col``.
+def sparkPartitionId():
+    """Returns a column for partition ID of the Spark task.

-    >>> df.agg(approxCountDistinct(df.age).alias('c')).collect()
-    [Row(c=2)]
+    Note that this is indeterministic because it depends on data partitioning and task scheduling.
+
+    >>> df.repartition(1).select(sparkPartitionId().alias("pid")).collect()
+    [Row(pid=0), Row(pid=0)]
     """
     sc = SparkContext._active_spark_context
-    if rsd is None:
-        jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col))
-    else:
-        jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col), rsd)
-    return Column(jc)
+    return Column(sc._jvm.functions.sparkPartitionId())


 class UserDefinedFunction(object):
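
For orientation, here is a minimal PySpark sketch of how the new sparkPartitionId() column might be used to see how rows are spread across partitions; the local master, app name, column name, and sample data are illustrative assumptions, not part of this commit.

from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.sql.functions import sparkPartitionId

sc = SparkContext("local[2]", "partition-id-demo")   # hypothetical local app
sqlContext = SQLContext(sc)
df = sqlContext.createDataFrame([(i,) for i in range(10)], ["x"])

# Tag each row with the partition that produced it, then count rows per
# partition; noticeably uneven counts suggest data skew.
df.select(sparkPartitionId().alias("pid")).groupBy("pid").count().show()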

sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala  (new file, 39 additions)

@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.expressions
+
+import org.apache.spark.TaskContext
+import org.apache.spark.sql.catalyst.expressions.{Row, Expression}
+import org.apache.spark.sql.catalyst.trees
+import org.apache.spark.sql.types.{IntegerType, DataType}
+
+
+/**
+ * Expression that returns the current partition id of the Spark task.
+ */
+case object SparkPartitionID extends Expression with trees.LeafNode[Expression] {
+  self: Product =>
+
+  override type EvaluatedType = Int
+
+  override def nullable: Boolean = false
+
+  override def dataType: DataType = IntegerType
+
+  override def eval(input: Row): Int = TaskContext.get().partitionId()
+}
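
The expression is evaluated per row on the executors and simply reads the partition index of the running task from TaskContext. As a rough illustration (a sketch, reusing the df from the example above), the same numbering is visible from the RDD side via mapPartitionsWithIndex:

# Each element is tagged with the index of the partition it sits in, which is
# the same value SparkPartitionID reads from TaskContext.get().partitionId().
pids = df.rdd.mapPartitionsWithIndex(lambda idx, rows: (idx for _ in rows)).collect()
print(sorted(set(pids)))   # e.g. [0, 1], depending on how df is partitioned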

sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/package.scala  (new file, 23 additions)

@@ -0,0 +1,23 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+/**
+ * Package containing expressions that are specific to Spark runtime.
+ */
+package object expressions

sql/core/src/main/scala/org/apache/spark/sql/functions.scala  (19 additions, 10 deletions)

@@ -276,6 +276,13 @@ object functions {
   // Non-aggregate functions
   //////////////////////////////////////////////////////////////////////////////////////////////

+  /**
+   * Computes the absolute value.
+   *
+   * @group normal_funcs
+   */
+  def abs(e: Column): Column = Abs(e.expr)
+
   /**
    * Returns the first column that is not null.
    * {{{

@@ -287,6 +294,13 @@ object functions {
   @scala.annotation.varargs
   def coalesce(e: Column*): Column = Coalesce(e.map(_.expr))

+  /**
+   * Converts a string exprsesion to lower case.
+   *
+   * @group normal_funcs
+   */
+  def lower(e: Column): Column = Lower(e.expr)
+
   /**
    * Unary minus, i.e. negate the expression.
    * {{{

@@ -317,18 +331,13 @@ object functions {
   def not(e: Column): Column = !e

   /**
-   * Converts a string expression to upper case.
+   * Partition ID of the Spark task.
    *
-   * @group normal_funcs
-   */
-  def upper(e: Column): Column = Upper(e.expr)
-
-  /**
-   * Converts a string exprsesion to lower case.
+   * Note that this is indeterministic because it depends on data partitioning and task scheduling.
    *
    * @group normal_funcs
    */
-  def lower(e: Column): Column = Lower(e.expr)
+  def sparkPartitionId(): Column = execution.expressions.SparkPartitionID

   /**
    * Computes the square root of the specified float value.

@@ -338,11 +347,11 @@ object functions {
   def sqrt(e: Column): Column = Sqrt(e.expr)

   /**
-   * Computes the absolutle value.
+   * Converts a string expression to upper case.
    *
    * @group normal_funcs
    */
-  def abs(e: Column): Column = Abs(e.expr)
+  def upper(e: Column): Column = Upper(e.expr)

   //////////////////////////////////////////////////////////////////////////////////////////////
   //////////////////////////////////////////////////////////////////////////////////////////////
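
Because the reported value depends on how the data happens to be laid out, the same rows can yield different IDs after a repartition, which is what the "indeterministic" note above warns about. A small hedged illustration, again reusing the df from the earlier PySpark sketch:

# The observed IDs change when the physical partitioning changes.
df.repartition(1).select(sparkPartitionId().alias("pid")).distinct().show()   # only pid 0
df.repartition(4).select(sparkPartitionId().alias("pid")).distinct().show()   # up to pids 0..3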

sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala  (8 additions)

@@ -310,6 +310,14 @@ class ColumnExpressionSuite extends QueryTest {
     )
   }

+  test("sparkPartitionId") {
+    val df = TestSQLContext.sparkContext.parallelize(1 to 1, 1).map(i => (i, i)).toDF("a", "b")
+    checkAnswer(
+      df.select(sparkPartitionId()),
+      Row(0)
+    )
+  }
+
   test("lift alias out of cast") {
     compareExpressions(
       col("1234").as("name").cast("int").expr,
