[SPARK-25921][PYSPARK] Fix barrier task run without BarrierTaskContext while python worker reuse

## What changes were proposed in this pull request?

Running a barrier job after a normal Spark job causes the barrier job to run without a BarrierTaskContext. This happens when Python workers are reused: after the normal job, the reused worker has already cached a plain `TaskContext`, so `BarrierTaskContext._getOrCreate()` still returns that `TaskContext`, and calling `barrier()` on it fails with `AttributeError: 'TaskContext' object has no attribute 'barrier'`. Fix this by adding a check in `BarrierTaskContext._getOrCreate()` so that it returns a `BarrierTaskContext` in this scenario.
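For orientation, here is a minimal repro sketch of the failure mode described above. It assumes a local Spark 2.4-era PySpark; the job bodies and names (`barrier_stage`, the `local[4]` master) are illustrative, not taken from the PR:

```python
# Repro sketch: run a normal job first, then a barrier job on reused workers.
# Before this fix, BarrierTaskContext.get() inside the barrier stage handed
# back the cached plain TaskContext, and ctx.barrier() raised AttributeError.
from pyspark import SparkConf, SparkContext
from pyspark.taskcontext import BarrierTaskContext

conf = SparkConf().set("spark.python.worker.reuse", "true")  # default is already true
sc = SparkContext(master="local[4]", conf=conf)

rdd = sc.parallelize(range(4), 4)
rdd.map(lambda x: x * x).collect()  # normal job: each reused worker caches a TaskContext

def barrier_stage(iterator):
    ctx = BarrierTaskContext.get()  # must be a BarrierTaskContext, not a TaskContext
    ctx.barrier()                   # AttributeError here before the fix
    yield ctx.partitionId()

print(sorted(rdd.barrier().mapPartitions(barrier_stage).collect()))
sc.stop()
```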

## How was this patch tested?

Added a new unit test in pyspark-core.

Closes apache#22962 from xuanyuanking/SPARK-25921.

Authored-by: Yuanjian Li <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
xuanyuanking authored and cloud-fan committed Nov 13, 2018
1 parent e25bce5 commit c00e72f
Showing 2 changed files with 17 additions and 2 deletions.
4 changes: 2 additions & 2 deletions python/pyspark/taskcontext.py

@@ -147,8 +147,8 @@ def __init__(self):
     @classmethod
     def _getOrCreate(cls):
         """Internal function to get or create global BarrierTaskContext."""
-        if cls._taskContext is None:
-            cls._taskContext = BarrierTaskContext()
+        if not isinstance(cls._taskContext, BarrierTaskContext):
+            cls._taskContext = object.__new__(cls)
         return cls._taskContext
 
     @classmethod
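Two details of the new guard are worth spelling out. The `isinstance` check (rather than `is None`) replaces a plain `TaskContext` cached by an earlier job on a reused worker, and `object.__new__(cls)` allocates a fresh instance directly, sidestepping `TaskContext`'s singleton construction path (a plain `BarrierTaskContext()` call could hand back the stale cached object again). Below is a toy, Spark-free sketch of that guard; the class bodies are stand-ins, not the real pyspark implementations:

```python
# Toy model of the guard change; TaskContext/BarrierTaskContext here are
# simplified stand-ins for the real classes in python/pyspark/taskcontext.py.
class TaskContext(object):
    _taskContext = None  # class-level cache; survives across jobs in a reused worker


class BarrierTaskContext(TaskContext):
    @classmethod
    def _getOrCreate(cls):
        # Old guard (`if cls._taskContext is None`) kept a leftover plain
        # TaskContext from a normal job; the isinstance guard replaces it.
        if not isinstance(cls._taskContext, BarrierTaskContext):
            cls._taskContext = object.__new__(cls)  # allocate without __init__
        return cls._taskContext


# Simulate a reused worker that previously ran a normal (non-barrier) job:
TaskContext._taskContext = TaskContext()
ctx = BarrierTaskContext._getOrCreate()
assert isinstance(ctx, BarrierTaskContext)  # the old `is None` guard failed this
```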
15 changes: 15 additions & 0 deletions python/pyspark/tests.py

@@ -614,6 +614,21 @@ def context_barrier(x):
         times = rdd.barrier().mapPartitions(f).map(context_barrier).collect()
         self.assertTrue(max(times) - min(times) < 1)
 
+    def test_barrier_with_python_worker_reuse(self):
+        """
+        Verify that BarrierTaskContext.barrier() works with a reused python worker.
+        """
+        self.sc._conf.set("spark.python.worker.reuse", "true")
+        rdd = self.sc.parallelize(range(4), 4)
+        # start a normal job first to start all workers
+        result = rdd.map(lambda x: x ** 2).collect()
+        self.assertEqual([0, 1, 4, 9], result)
+        # make sure `spark.python.worker.reuse=true`
+        self.assertEqual(self.sc._conf.get("spark.python.worker.reuse"), "true")
+
+        # worker will be reused in this barrier job
+        self.test_barrier()
+
     def test_barrier_infos(self):
         """
         Verify that BarrierTaskContext.getTaskInfos() returns a list of all task infos in the
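To run just the new test locally, one option is the standard `unittest` loader. This is a sketch under two assumptions: PySpark from a built apache/spark checkout is importable, and the test lives in `pyspark.tests.TaskContextTests` (inferred from the surrounding `test_barrier`/`test_barrier_infos` methods, not stated in this diff):

```python
# Hypothetical local runner for only the new test; assumes PYTHONPATH includes
# python/ and py4j from a built Spark, and that the test-case class is
# pyspark.tests.TaskContextTests (an assumption based on the surrounding diff).
import unittest

suite = unittest.defaultTestLoader.loadTestsFromName(
    "pyspark.tests.TaskContextTests.test_barrier_with_python_worker_reuse")
unittest.TextTestRunner(verbosity=2).run(suite)
```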
