12 changes: 8 additions & 4 deletions python/pyspark/accumulators.py
@@ -109,10 +109,14 @@

 def _deserialize_accumulator(aid, zero_value, accum_param):
     from pyspark.accumulators import _accumulatorRegistry
-    accum = Accumulator(aid, zero_value, accum_param)
-    accum._deserialized = True
-    _accumulatorRegistry[aid] = accum
-    return accum
+    # If this certain accumulator was deserialized, don't overwrite it.
+    if aid in _accumulatorRegistry:

Reviewer:
Should it be if aid in _accumulatorRegistry and _accumulatorRegistry[aid]._deserialized is True, or:

    if aid in _accumulatorRegistry:
        _accumulatorRegistry[aid]._deserialized = True
        return _accumulatorRegistry[aid]

to make doubly sure that this function always returns a deserialized version of the accum?

Member Author:

We only save deserialized accumulators (_deserialized is True) into this dict.

@AbdealiLoKo (Oct 5, 2018):

That doesn't seem right, because the constructor for Accumulator has:

        ...
        self._deserialized = False
        _accumulatorRegistry[aid] = self

PS: This is the first time I'm looking at this code, so I'm not too familiar with it.

Member Author:

Yeah, but _deserialize_accumulator is only called when doing deserialization at executors. The constructor saves accumulators into _accumulatorRegistry at the driver.

@AbdealiLoKo:

I see - got it 👍

Member:

Ah, so the problem is that this accumulator is de/serialized multiple times, and _deserialize_accumulator modifies the global state each time. I see. LGTM.

Member Author:

Yes.

+        return _accumulatorRegistry[aid]
+    else:
+        accum = Accumulator(aid, zero_value, accum_param)
+        accum._deserialized = True
+        _accumulatorRegistry[aid] = accum
+        return accum


 class Accumulator(object):
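The thread above turns on the same accumulator being de/serialized more than once inside one executor. A minimal standalone sketch of the registry guard — plain Python, with Accum, _registry, and _deserialize as hypothetical stand-ins for pyspark's Accumulator, _accumulatorRegistry, and _deserialize_accumulator — showing how it deduplicates repeated pickle round-trips:

```python
import pickle

# Hypothetical stand-ins, reduced to just the registry/deserialization
# behavior under discussion; this is not the real pyspark code.
_registry = {}

class Accum:
    def __init__(self, aid, value=0):
        self.aid = aid
        self.value = value
        self._deserialized = False
        _registry[aid] = self

    def __reduce__(self):
        # Pickle down to (aid, value), mirroring how pyspark routes
        # accumulator deserialization through a module-level helper.
        return (_deserialize, (self.aid, self.value))

def _deserialize(aid, value):
    # The guarded version: if this accumulator id was already
    # deserialized here, reuse it instead of overwriting it.
    if aid in _registry:
        return _registry[aid]
    accum = Accum(aid, value)
    accum._deserialized = True
    return accum

# "Driver" side: one accumulator, captured by two different UDF
# closures, so it is pickled twice.
a = Accum(7)
blob1 = pickle.dumps(a)
blob2 = pickle.dumps(a)

# "Executor" side: fresh registry, both blobs get deserialized.
_registry.clear()
b1 = pickle.loads(blob1)
b1.value += 1          # first UDF's update
b2 = pickle.loads(blob2)
b2.value += 100        # second UDF's update

print(b1 is b2, b1.value)  # → True 101
```

Without the `if aid in _registry` branch, the second `loads` would create a fresh object and overwrite the registry entry, losing the first update.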
25 changes: 25 additions & 0 deletions python/pyspark/sql/tests.py
@@ -3603,6 +3603,31 @@ def test_repr_behaviors(self):
         self.assertEquals(None, df._repr_html_())
         self.assertEquals(expected, df.__repr__())
 
+    # SPARK-25591
+    def test_same_accumulator_in_udfs(self):
+        from pyspark.sql.functions import udf
+
+        data_schema = StructType([StructField("a", IntegerType(), True),
+                                  StructField("b", IntegerType(), True)])
+        data = self.spark.createDataFrame([[1, 2]], schema=data_schema)
+
+        test_accum = self.sc.accumulator(0)
+
+        def first_udf(x):
+            test_accum.add(1)
+            return x
+
+        def second_udf(x):
+            test_accum.add(100)
+            return x
+
+        func_udf = udf(first_udf, IntegerType())
+        func_udf2 = udf(second_udf, IntegerType())
+        data = data.withColumn("out1", func_udf(data["a"]))
+        data = data.withColumn("out2", func_udf2(data["b"]))
+        data.collect()
+        self.assertEqual(test_accum.value, 101)
Member:

@viirya, can we just use int for data and accumulator as well in this test case?

Member Author:

Ok.
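The asserted total of 101 is exactly what the registry guard buys: one add(1) plus one add(100) landing on a single shared instance. A minimal model contrasting the unguarded and guarded deserialization paths — plain Python with hypothetical names, not the real pyspark internals:

```python
_registry = {}

class Accum:
    """Hypothetical minimal stand-in for pyspark's Accumulator."""
    def __init__(self, aid, value=0):
        self.aid = aid
        self.value = value

def deserialize(aid, guarded):
    # guarded=True models the fixed _deserialize_accumulator;
    # guarded=False models the old behavior, which always created a
    # fresh object and overwrote the registry entry.
    if guarded and aid in _registry:
        return _registry[aid]
    accum = Accum(aid)
    _registry[aid] = accum
    return accum

def run(guarded):
    _registry.clear()
    first = deserialize(0, guarded)   # accumulator seen by first_udf
    first.value += 1                  # test_accum.add(1)
    second = deserialize(0, guarded)  # accumulator seen by second_udf
    second.value += 100               # test_accum.add(100)
    # The worker reports whatever instance ended up in the registry.
    return _registry[0].value

print(run(guarded=False))  # old behavior: first_udf's +1 is lost → 100
print(run(guarded=True))   # fixed: both updates hit one instance → 101
```

In the unguarded run the two UDFs update two different objects and only the last one survives in the registry, which is the lost-update the test above guards against.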



 class HiveSparkSubmitTests(SparkSubmitTests):