Skip to content

Commit d053a31

Browse files
animesh authored and rxin committed
[SPARK-7980] [SQL] Support SQLContext.range(end)
1. range() overloaded in SQLContext.scala 2. range() modified in python sql context.py 3. Tests added accordingly in DataFrameSuite.scala and python sql tests.py Author: animesh <[email protected]> Closes #6609 from animeshbaranawal/SPARK-7980 and squashes the following commits: 935899c [animesh] SPARK-7980:python+scala changes
1 parent 2c4d550 commit d053a31

File tree

4 files changed

+31
-2
lines changed

4 files changed

+31
-2
lines changed

python/pyspark/sql/context.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def udf(self):
131131
return UDFRegistration(self)
132132

133133
@since(1.4)
134-
def range(self, start, end, step=1, numPartitions=None):
134+
def range(self, start, end=None, step=1, numPartitions=None):
135135
"""
136136
Create a :class:`DataFrame` with single LongType column named `id`,
137137
containing elements in a range from `start` to `end` (exclusive) with
@@ -145,10 +145,18 @@ def range(self, start, end, step=1, numPartitions=None):
145145
146146
>>> sqlContext.range(1, 7, 2).collect()
147147
[Row(id=1), Row(id=3), Row(id=5)]
148+
149+
>>> sqlContext.range(3).collect()
150+
[Row(id=0), Row(id=1), Row(id=2)]
148151
"""
149152
if numPartitions is None:
150153
numPartitions = self._sc.defaultParallelism
151-
jdf = self._ssql_ctx.range(int(start), int(end), int(step), int(numPartitions))
154+
155+
if end is None:
156+
jdf = self._ssql_ctx.range(0, int(start), int(step), int(numPartitions))
157+
else:
158+
jdf = self._ssql_ctx.range(int(start), int(end), int(step), int(numPartitions))
159+
152160
return DataFrame(jdf, self)
153161

154162
@ignore_unicode_prefix

python/pyspark/sql/tests.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,8 @@ def test_range(self):
131131
self.assertEqual(self.sqlCtx.range(1, 1).count(), 0)
132132
self.assertEqual(self.sqlCtx.range(1, 0, -1).count(), 1)
133133
self.assertEqual(self.sqlCtx.range(0, 1 << 40, 1 << 39).count(), 2)
134+
self.assertEqual(self.sqlCtx.range(-2).count(), 0)
135+
self.assertEqual(self.sqlCtx.range(3).count(), 3)
134136

135137
def test_explode(self):
136138
from pyspark.sql.functions import explode

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -717,6 +717,17 @@ class SQLContext(@transient val sparkContext: SparkContext)
717717
StructType(StructField("id", LongType, nullable = false) :: Nil))
718718
}
719719

720+
/**
721+
* :: Experimental ::
722+
* Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing elements
723+
* in a range from 0 to `end` (exclusive) with step value 1.
724+
*
725+
* @since 1.4.0
726+
* @group dataframe
727+
*/
728+
@Experimental
729+
def range(end: Long): DataFrame = range(0, end)
730+
720731
/**
721732
* :: Experimental ::
722733
* Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing elements

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -576,5 +576,13 @@ class DataFrameSuite extends QueryTest {
576576
val res9 = TestSQLContext.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id")
577577
assert(res9.count == 2)
578578
assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1)))
579+
580+
// only end provided as argument
581+
val res10 = TestSQLContext.range(10).select("id")
582+
assert(res10.count == 10)
583+
assert(res10.agg(sum("id")).as("sumid").collect() === Seq(Row(45)))
584+
585+
val res11 = TestSQLContext.range(-1).select("id")
586+
assert(res11.count == 0)
579587
}
580588
}

0 commit comments

Comments (0)