
Commit f03cdfa

[SPARK-2314][SQL] Override collect and take in sql.py, forwarding to SchemaRDD's collect.
1 parent 60050f4 commit f03cdfa


2 files changed: +67 / -15 lines


python/pyspark/sql.py

Lines changed: 42 additions & 5 deletions
@@ -1550,6 +1550,18 @@ def id(self):
         self._id = self._jrdd.id()
         return self._id
 
+    def limit(self, num):
+        """Limit the result count to the number specified.
+
+        >>> srdd = sqlCtx.inferSchema(rdd)
+        >>> srdd.limit(2).collect()
+        [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')]
+        >>> srdd.limit(0).collect()
+        []
+        """
+        rdd = self._jschema_rdd.limit(num)
+        return SchemaRDD(rdd, self.sql_ctx)
+
     def saveAsParquetFile(self, path):
         """Save the contents as a Parquet file, preserving the schema.
 
@@ -1626,15 +1638,40 @@ def count(self):
         return self._jschema_rdd.count()
 
     def collect(self):
-        """
-        Return a list that contains all of the rows in this RDD.
+        """Return a list that contains all of the rows in this RDD.
 
-        Each object in the list is on Row, the fields can be accessed as
+        Each object in the list is a Row, the fields can be accessed as
         attributes.
+
+        Unlike the base RDD implementation of collect, this implementation
+        leverages the query optimizer to perform a collect on the SchemaRDD,
+        which supports features such as filter pushdown.
+
+        >>> srdd = sqlCtx.inferSchema(rdd)
+        >>> srdd.collect()
+        [Row(field1=1, field2=u'row1'), ..., Row(field1=3, field2=u'row3')]
         """
-        rows = RDD.collect(self)
+        from pyspark.context import JavaStackTrace
+        with JavaStackTrace(self.context) as st:
+            bytesInJava = self._jschema_rdd.collectToPython().iterator()
         cls = _create_cls(self.schema())
-        return map(cls, rows)
+        return map(cls, self._collect_iterator_through_file(bytesInJava))
+
+    def take(self, num):
+        """Take the first num rows of the RDD.
+
+        Each object in the list is a Row, the fields can be accessed as
+        attributes.
+
+        Unlike the base RDD implementation of take, this implementation
+        leverages the query optimizer to perform a collect on a SchemaRDD,
+        which supports features such as filter pushdown.
+
+        >>> srdd = sqlCtx.inferSchema(rdd)
+        >>> srdd.take(2)
+        [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')]
+        """
+        return self.limit(num).collect()
 
     # Convert each object in the RDD to a Row with the right class
     # for this SchemaRDD, so that fields can be accessed as attributes.
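
Taken together, the new Python API reads as follows. This is a minimal usage sketch mirroring the doctests above; it assumes an existing SparkContext named sc, and the sample records are the same illustrative rows used in the doctests:

    from pyspark.sql import SQLContext

    sqlCtx = SQLContext(sc)  # assumes a running SparkContext `sc`

    # Sample records mirroring the doctest data; inferSchema derives the schema.
    rdd = sc.parallelize([{"field1": 1, "field2": "row1"},
                          {"field1": 2, "field2": "row2"},
                          {"field1": 3, "field2": "row3"}])
    srdd = sqlCtx.inferSchema(rdd)

    srdd.limit(2).collect()  # [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')]
    srdd.take(2)             # same rows; take(num) now forwards to limit(num).collect()
    srdd.collect()           # all rows, fetched through the optimized SchemaRDD collect path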

sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala

Lines changed: 25 additions & 10 deletions
@@ -377,15 +377,15 @@ class SchemaRDD(
   def toJavaSchemaRDD: JavaSchemaRDD = new JavaSchemaRDD(sqlContext, logicalPlan)
 
   /**
-   * Converts a JavaRDD to a PythonRDD. It is used by pyspark.
+   * Helper for converting a Row to a simple Array suitable for pyspark serialization.
    */
-  private[sql] def javaToPython: JavaRDD[Array[Byte]] = {
+  private def rowToJArray(row: Row, structType: StructType): Array[Any] = {
     import scala.collection.Map
 
     def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match {
       case (null, _) => null
 
-      case (obj: Row, struct: StructType) => rowToArray(obj, struct)
+      case (obj: Row, struct: StructType) => rowToJArray(obj, struct)
 
       case (seq: Seq[Any], array: ArrayType) =>
         seq.map(x => toJava(x, array.elementType)).asJava
@@ -402,22 +402,37 @@ class SchemaRDD(
       case (other, _) => other
     }
 
-    def rowToArray(row: Row, structType: StructType): Array[Any] = {
-      val fields = structType.fields.map(field => field.dataType)
-      row.zip(fields).map {
-        case (obj, dataType) => toJava(obj, dataType)
-      }.toArray
-    }
+    val fields = structType.fields.map(field => field.dataType)
+    row.zip(fields).map {
+      case (obj, dataType) => toJava(obj, dataType)
+    }.toArray
+  }
 
+  /**
+   * Converts a JavaRDD to a PythonRDD. It is used by pyspark.
+   */
+  private[sql] def javaToPython: JavaRDD[Array[Byte]] = {
     val rowSchema = StructType.fromAttributes(this.queryExecution.analyzed.output)
     this.mapPartitions { iter =>
       val pickle = new Pickler
       iter.map { row =>
-        rowToArray(row, rowSchema)
+        rowToJArray(row, rowSchema)
       }.grouped(100).map(batched => pickle.dumps(batched.toArray))
     }
   }
 
+  /**
+   * Serializes the Array[Row] returned by SchemaRDD's optimized collect(), using the same
+   * format as javaToPython. It is used by pyspark.
+   */
+  private[sql] def collectToPython: JList[Array[Byte]] = {
+    val rowSchema = StructType.fromAttributes(this.queryExecution.analyzed.output)
+    val pickle = new Pickler
+    new java.util.ArrayList(collect().map { row =>
+      rowToJArray(row, rowSchema)
+    }.grouped(100).map(batched => pickle.dumps(batched.toArray)).toIterable)
+  }
+
   /**
    * Creates SchemaRDD by applying own schema to derived RDD. Typically used to wrap return value
    * of base RDD functions that do not change schema.
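
On the wire, collectToPython returns a list of pickle blobs, one per batch of up to 100 rows (grouped(100) followed by pickle.dumps on each batch), in the same format that javaToPython streams per partition. The following is a rough illustration of the decoding step on the Python side using only the standard pickle module; the helper name and inputs are hypothetical, not Spark internals:

    import pickle

    def rows_from_batches(byte_batches):
        # Each blob holds one pickled batch of up to 100 row arrays;
        # unpickle each batch and flatten back into a single list of rows.
        rows = []
        for blob in byte_batches:
            rows.extend(pickle.loads(blob))
        return rows

    # Example: two rows pickled as one batch round-trip back to plain tuples.
    print(rows_from_batches([pickle.dumps([(1, u"row1"), (2, u"row2")])]))
    # -> [(1, u'row1'), (2, u'row2')]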
