Skip to content

Commit fd41eb9

Browse files
jbencook authored and JoshRosen committed
[SPARK-4860][pyspark][sql] speeding up sample() and takeSample()
This PR modifies the Python `SchemaRDD` to use `sample()` and `takeSample()` from Scala instead of the slower Python implementations from `rdd.py`. This is worthwhile because the `Row`s are already serialized as Java objects. In order to use the faster `takeSample()`, a `takeSampleToPython()` method was implemented in `SchemaRDD.scala` following the pattern of `collectToPython()`. Author: jbencook <[email protected]> Author: J. Benjamin Cook <[email protected]> Closes apache#3764 from jbencook/master and squashes the following commits: 6fbc769 [J. Benjamin Cook] [SPARK-4860][pyspark][sql] fixing sloppy indentation for takeSampleToPython() arguments 5170da2 [J. Benjamin Cook] [SPARK-4860][pyspark][sql] fixing typo: from RDD to SchemaRDD de22f70 [jbencook] [SPARK-4860][pyspark][sql] using sample() method from JavaSchemaRDD b916442 [jbencook] [SPARK-4860][pyspark][sql] adding sample() to JavaSchemaRDD 020cbdf [jbencook] [SPARK-4860][pyspark][sql] using Scala implementations of `sample()` and `takeSample()`
1 parent 7e2deb7 commit fd41eb9

File tree

3 files changed

+49
-0
lines changed

3 files changed

+49
-0
lines changed

python/pyspark/sql.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2085,6 +2085,34 @@ def subtract(self, other, numPartitions=None):
20852085
else:
20862086
raise ValueError("Can only subtract another SchemaRDD")
20872087

2088+
def sample(self, withReplacement, fraction, seed=None):
    """
    Sample a fraction of the rows of this SchemaRDD.

    The work is delegated to the ``sample()`` method of the wrapped Java
    SchemaRDD and the result is wrapped in a new SchemaRDD that shares
    this one's SQL context. When ``seed`` is ``None`` a seed is drawn
    with ``random.randint(0, sys.maxint)``.

    >>> srdd = sqlCtx.inferSchema(rdd)
    >>> srdd.sample(False, 0.5, 97).count()
    2L
    """
    assert fraction >= 0.0, "Negative fraction value: %s" % fraction
    if seed is None:
        seed = random.randint(0, sys.maxint)
    # The seed is converted with long() before being handed to the Java side.
    sampled = self._jschema_rdd.sample(withReplacement, fraction, long(seed))
    return SchemaRDD(sampled, self.sql_ctx)
2100+
2101+
def takeSample(self, withReplacement, num, seed=None):
    """Take a fixed-size random sample of this SchemaRDD's rows.

    The sample is produced by ``takeSampleToPython()`` on the base
    SchemaRDD of the wrapped Java object; the returned data is read back
    through ``_collect_iterator_through_file`` and each item is wrapped
    in the row class built from this SchemaRDD's schema. When ``seed``
    is ``None`` a seed is drawn with ``random.randint(0, sys.maxint)``.

    >>> srdd = sqlCtx.inferSchema(rdd)
    >>> srdd.takeSample(False, 2, 97)
    [Row(field1=3, field2=u'row3'), Row(field1=1, field2=u'row1')]
    """
    if seed is None:
        seed = random.randint(0, sys.maxint)
    with SCCallSiteSync(self.context) as css:
        base_rdd = self._jschema_rdd.baseSchemaRDD()
        sampled = base_rdd.takeSampleToPython(withReplacement, num, long(seed))
        bytesInJava = sampled.iterator()
    row_cls = _create_cls(self.schema())
    return [row_cls(item) for item in self._collect_iterator_through_file(bytesInJava)]
2115+
20882116

20892117
def _test():
20902118
import doctest

sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,21 @@ class SchemaRDD(
437437
}.grouped(100).map(batched => pickle.dumps(batched.toArray)).toIterable)
438438
}
439439

440+
/**
 * Takes a sample of this SchemaRDD via takeSample() and pickles the resulting rows for
 * pyspark. Rows are converted with EvaluatePython.rowToArray and pickled in groups of
 * up to 100, one byte array per group.
 */
private[sql] def takeSampleToPython(
    withReplacement: Boolean,
    num: Int,
    seed: Long): JList[Array[Byte]] = {
  val columnTypes = schema.fields.map(_.dataType)
  val pickler = new Pickler
  val sampledRows = this.takeSample(withReplacement, num, seed)
  val converted = sampledRows.map(row => EvaluatePython.rowToArray(row, columnTypes))
  val pickledBatches = converted.grouped(100).map(batch => pickler.dumps(batch.toArray))
  new java.util.ArrayList(pickledBatches.toIterable)
}
454+
440455
/**
441456
* Creates SchemaRDD by applying own schema to derived RDD. Typically used to wrap return value
442457
* of base RDD functions that do not change schema.

sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,4 +218,10 @@ class JavaSchemaRDD(
218218
*/
219219
def subtract(other: JavaSchemaRDD, p: Partitioner): JavaSchemaRDD =
220220
this.baseSchemaRDD.subtract(other.baseSchemaRDD, p).toJavaSchemaRDD
221+
222+
/**
 * Samples the underlying SchemaRDD with the given replacement mode, fraction and seed,
 * and returns the result as a JavaSchemaRDD.
 */
def sample(withReplacement: Boolean, fraction: Double, seed: Long): JavaSchemaRDD = {
  val sampled = this.baseSchemaRDD.sample(withReplacement, fraction, seed)
  sampled.toJavaSchemaRDD
}
221227
}

0 commit comments

Comments
 (0)