diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 7b31fa93c32e..14d5df14ed85 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -403,6 +403,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
         // the decrypted data to python
         val idsAndFiles = broadcastVars.flatMap { broadcast =>
           if (!oldBids.contains(broadcast.id)) {
+            oldBids.add(broadcast.id)
             Some((broadcast.id, broadcast.value.path))
           } else {
             None
@@ -416,7 +417,6 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
           idsAndFiles.foreach { case (id, _) =>
             // send new broadcast
             dataOut.writeLong(id)
-            oldBids.add(id)
           }
           dataOut.flush()
           logTrace("waiting for python to read decrypted broadcast data from server")
diff --git a/python/pyspark/tests/test_broadcast.py b/python/pyspark/tests/test_broadcast.py
index 8185e812e66b..bc4587ffa645 100644
--- a/python/pyspark/tests/test_broadcast.py
+++ b/python/pyspark/tests/test_broadcast.py
@@ -26,6 +26,7 @@
 from pyspark import SparkConf, SparkContext, Broadcast
 from pyspark.java_gateway import launch_gateway
 from pyspark.serializers import ChunkedStream
+from pyspark.sql import SparkSession, Row
 
 
 class BroadcastTest(unittest.TestCase):
@@ -126,6 +127,19 @@ def test_broadcast_for_error_condition(self):
         with self.assertRaisesRegex(Py4JJavaError, "RuntimeError.*Broadcast.*unpersisted.*driver"):
             self.sc.parallelize([1]).map(lambda x: bs.unpersist()).collect()
 
+    def test_broadcast_in_udfs_with_encryption(self):
+        conf = SparkConf()
+        conf.set("spark.io.encryption.enabled", "true")
+        conf.setMaster("local-cluster[2,1,1024]")
+        self.sc = SparkContext(conf=conf)
+        bar = {"a": "aa", "b": "bb"}
+        foo = self.sc.broadcast(bar)
+        spark = SparkSession(self.sc)
+        spark.udf.register("MYUDF", lambda x: foo.value[x] if x else "")
+        sel = spark.sql("SELECT MYUDF('a') AS a, MYUDF('b') AS b")
+        self.assertEqual(sel.collect(), [Row(a="aa", b="bb")])
+        spark.stop()
+
 class BroadcastFrameProtocolTest(unittest.TestCase):
 
     @classmethod
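
Note on the Scala change above: by moving `oldBids.add(broadcast.id)` into the `flatMap` that builds `idsAndFiles` (instead of adding it only in the later write loop), a broadcast id that appears more than once in `broadcastVars` is collected, and therefore written to the Python worker, at most once; the new PySpark test exercises this path with `spark.io.encryption.enabled` and a broadcast used from a UDF. Below is a minimal, self-contained sketch of that dedup-while-collecting pattern. It is illustrative only, not Spark code: the names `BroadcastDedupSketch`, `seen`, `broadcastIds`, and `toSend` are hypothetical stand-ins for `oldBids`, the broadcast ids, and `idsAndFiles`.

```scala
// Minimal sketch (illustrative names, not Spark code): recording an id in a
// "seen" set while the list is being built drops duplicates; recording it only
// in a later loop would let the same id through twice.
object BroadcastDedupSketch {
  def main(args: Array[String]): Unit = {
    val seen = scala.collection.mutable.Set.empty[Long] // stands in for oldBids
    val broadcastIds = Seq(1L, 1L, 2L)                   // hypothetical: id 1 referenced twice

    val toSend = broadcastIds.flatMap { id =>
      if (!seen.contains(id)) {
        seen.add(id) // mirrors moving oldBids.add(broadcast.id) into the flatMap
        Some(id)
      } else {
        None         // duplicate reference, already scheduled to be sent
      }
    }

    assert(toSend == Seq(1L, 2L)) // each id is collected at most once
    println(s"ids to send: $toSend")
  }
}
```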