Commit 4f8309d
committed: address comment, add tests

1 parent 0a5b6eb

File tree: 3 files changed (+29, −5 lines)

python/pyspark/accumulators.py

Lines changed: 2 additions & 2 deletions

@@ -215,8 +215,8 @@ def addInPlace(self, value1, value2):
 COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j)


-class StatsParam(AccumulatorParam):
-    """StatsParam is used to merge pstats.Stats"""
+class PStatsParam(AccumulatorParam):
+    """PStatsParam is used to merge pstats.Stats"""

     @staticmethod
     def zero(value):
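
The hunk above only shows the rename; the zero/addInPlace bodies are not part of this diff. A minimal sketch of how such an AccumulatorParam could merge pstats.Stats objects, assuming the merge relies on pstats.Stats.add (an assumption, not the commit's actual code):

from pyspark.accumulators import AccumulatorParam

class PStatsParam(AccumulatorParam):
    """PStatsParam is used to merge pstats.Stats"""

    @staticmethod
    def zero(value):
        # Assumed identity element: no profile collected yet.
        return None

    @staticmethod
    def addInPlace(value1, value2):
        # pstats.Stats.add() folds another Stats object into this one,
        # so timings accumulate across merged profiles.
        if value1 is None:
            return value2
        if value2 is not None:
            value1.add(value2)
        return value1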

python/pyspark/rdd.py

Lines changed: 2 additions & 3 deletions

@@ -15,7 +15,6 @@
 # limitations under the License.
 #

-from base64 import standard_b64encode as b64enc
 import copy
 from collections import defaultdict
 from collections import namedtuple
@@ -35,7 +34,7 @@
 from random import Random
 from math import sqrt, log, isinf, isnan

-from pyspark.accumulators import StatsParam
+from pyspark.accumulators import PStatsParam
 from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
     BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
     PickleSerializer, pack_long, CompressedSerializer
@@ -2114,7 +2113,7 @@ def _jrdd(self):
         if self._bypass_serializer:
             self._jrdd_deserializer = NoOpSerializer()
         enable_profile = self.ctx._conf.get("spark.python.profile", "false") == "true"
-        profileStats = self.ctx.accumulator(None, StatsParam) if enable_profile else None
+        profileStats = self.ctx.accumulator(None, PStatsParam) if enable_profile else None
         command = (self.func, profileStats, self._prev_jrdd_deserializer,
                    self._jrdd_deserializer)
         ser = CloudPickleSerializer()
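
As the _jrdd hunk shows, profiling is gated on the "spark.python.profile" configuration key. A hedged usage sketch of enabling it from an application (the app name and workload here are illustrative, not from the commit):

from pyspark import SparkConf, SparkContext

# Enable the Python-side profiler; each stage's pstats.Stats is then
# merged through the PStatsParam accumulator created in _jrdd.
conf = SparkConf().set("spark.python.profile", "true")
sc = SparkContext("local[4]", "profile-demo", conf=conf)
sc.parallelize(range(100)).map(lambda x: x * x).count()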

python/pyspark/tests.py

Lines changed: 25 additions & 0 deletions

@@ -554,6 +554,31 @@ def test_repartitionAndSortWithinPartitions(self):
         self.assertEquals(partitions[1], [(1, 3), (3, 8), (3, 8)])


+class TestProfiler(PySparkTestCase):
+
+    def setUp(self):
+        self._old_sys_path = list(sys.path)
+        class_name = self.__class__.__name__
+        conf = SparkConf().set("spark.python.profile", "true")
+        self.sc = SparkContext('local[4]', class_name, batchSize=2, conf=conf)
+
+    def test_profiler(self):
+
+        def heavy_foo(x):
+            for i in range(1 << 20):
+                x = 1
+        rdd = self.sc.parallelize(range(100)).foreach(heavy_foo)
+        from pyspark.rdd import PipelinedRDD
+        profiles = PipelinedRDD._created_profiles
+        self.assertEqual(1, len(profiles))
+        id, acc = profiles.pop()
+        stats = acc.value
+        self.assertTrue(stats is not None)
+        width, stat_list = stats.get_print_list([])
+        func_names = [func_name for fname, n, func_name in stat_list]
+        self.assertTrue("heavy_foo" in func_names)
+
+
 class TestSQL(PySparkTestCase):

     def setUp(self):
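
The new test reads the merged profile through pstats' internal get_print_list helper. A more conventional way to inspect the same accumulator value would be the public pstats API; a hedged sketch reusing the test's acc variable:

# stats is a pstats.Stats merged across tasks via the PStatsParam accumulator.
stats = acc.value
stats.sort_stats("cumulative")  # order entries by cumulative time
stats.print_stats(10)           # print the ten most expensive entries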
