
Commit 3f3201a

beliefer authored and cloud-fan committed
[SPARK-37203][SQL] Fix NotSerializableException when observe with TypedImperativeAggregate
### What changes were proposed in this pull request?

Currently,

```
val namedObservation = Observation("named")
val df = spark.range(100)
val observed_df = df.observe(
  namedObservation, percentile_approx($"id", lit(0.5), lit(100)).as("percentile_approx_val"))
observed_df.collect()
namedObservation.get
```

throws the following exception:

```
15:16:27.994 ERROR org.apache.spark.util.Utils: Exception encountered
java.io.NotSerializableException: org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile$PercentileDigest
    at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1184)
    at java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1548)
    at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1509)
    at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1432)
    at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1178)
    at java.io.ObjectOutputStream.writeArray(ObjectOutputStream.java:1378)
    at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1174)
    at java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1548)
    at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1509)
    at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1432)
    at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1178)
    at java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1548)
    at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1509)
    at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1432)
    at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1178)
    at java.io.ObjectOutputStream.writeObject(ObjectOutputStream.java:348)
    at org.apache.spark.scheduler.DirectTaskResult.$anonfun$writeExternal$2(TaskResult.scala:55)
    at org.apache.spark.scheduler.DirectTaskResult.$anonfun$writeExternal$2$adapted(TaskResult.scala:55)
    at scala.collection.Iterator.foreach(Iterator.scala:943)
    at scala.collection.Iterator.foreach$(Iterator.scala:943)
    at scala.collection.AbstractIterator.foreach(Iterator.scala:1431)
    at scala.collection.IterableLike.foreach(IterableLike.scala:74)
    at scala.collection.IterableLike.foreach$(IterableLike.scala:73)
    at scala.collection.AbstractIterable.foreach(Iterable.scala:56)
    at org.apache.spark.scheduler.DirectTaskResult.$anonfun$writeExternal$1(TaskResult.scala:55)
    at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
    at org.apache.spark.util.Utils$.tryOrIOException(Utils.scala:1434)
    at org.apache.spark.scheduler.DirectTaskResult.writeExternal(TaskResult.scala:51)
    at java.io.ObjectOutputStream.writeExternalData(ObjectOutputStream.java:1459)
    at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1430)
    at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1178)
    at java.io.ObjectOutputStream.writeObject(ObjectOutputStream.java:348)
    at org.apache.spark.serializer.JavaSerializationStream.writeObject(JavaSerializer.scala:44)
    at org.apache.spark.serializer.JavaSerializerInstance.serialize(JavaSerializer.scala:101)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:616)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
    at java.lang.Thread.run(Thread.java:748)
```

This PR fixes the issue.

After the change, `assert(namedObservation.get === Map("percentile_approx_val" -> 49))` passes and the `java.io.NotSerializableException` no longer occurs.

### Why are the changes needed?

Fixes the `NotSerializableException` thrown when `observe` is used with a `TypedImperativeAggregate`.

### Does this PR introduce _any_ user-facing change?

No. This PR changes the implementation of `AggregatingAccumulator`, which now uses the serialize and deserialize support of `TypedImperativeAggregate`.

### How was this patch tested?

New tests.

Closes apache#34474 from beliefer/SPARK-37203.

Authored-by: Jiaan Geng <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
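For background on the mechanism the patch relies on: a `TypedImperativeAggregate` keeps an arbitrary JVM object (here `ApproximatePercentile$PercentileDigest`) in its aggregation buffer, and that object only becomes serializable once it is explicitly flattened to bytes. Before this fix, the accumulator's buffer still held the live `PercentileDigest` when the task result was serialized, so `ObjectOutputStream` recursed into it and threw the exception shown above. The following is a minimal sketch of the general pattern, using hypothetical `Digest` and `MedianBuffer` types (not Spark classes): the live object is flattened to a byte array before leaving the executor and decoded again on the driver before merging.

```scala
import java.nio.ByteBuffer

// Stand-in for PercentileDigest: plain mutable state, NOT java.io.Serializable.
final class Digest {
  var values: Vector[Long] = Vector.empty
  def add(v: Long): Unit = values = values :+ v
}

// Hypothetical buffer holder showing the serialize-in-place round trip.
final class MedianBuffer extends Serializable {
  // The live object; transient so Java serialization never touches it.
  @transient var digest: Digest = new Digest
  // Flattened form, populated just before the buffer leaves the executor.
  var flat: Array[Byte] = Array.emptyByteArray

  // Executor side, analogous to serializeAggregateBufferInPlace.
  def serializeInPlace(): Unit = {
    val buf = ByteBuffer.allocate(8 * digest.values.size)
    digest.values.foreach(v => buf.putLong(v))
    flat = buf.array()
  }

  // Driver side: rebuild the live object from bytes before merging.
  def deserialize(): Digest = {
    val d = new Digest
    val buf = ByteBuffer.wrap(flat)
    while (buf.remaining() >= 8) d.add(buf.getLong)
    d
  }
}
```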
1 parent 2b4099f commit 3f3201a

File tree

5 files changed: +46 −8 lines changed

core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala

Lines changed: 5 additions & 1 deletion
```diff
@@ -157,6 +157,10 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable {
    */
   def value: OUT
 
+  // Serialize the buffer of this accumulator before sending back this accumulator to the driver.
+  // By default this method does nothing.
+  protected def withBufferSerialized(): AccumulatorV2[IN, OUT] = this
+
   // Called by Java when serializing an object
   final protected def writeReplace(): Any = {
     if (atDriverSide) {
@@ -179,7 +183,7 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable {
       }
       copyAcc
     } else {
-      this
+      withBufferSerialized()
     }
   }
 
```
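The hook is deliberately placed in `writeReplace()`: that method runs on the executor just before Java serialization captures the accumulator into the task result, so a subclass gets one last chance to flatten non-serializable state. Below is a hedged sketch of how a custom subclass might use it; `SketchAcc` and `RunningSketch` are illustrative names, not part of Spark.

```scala
import org.apache.spark.util.AccumulatorV2

// Stand-in for a working buffer that Java serialization cannot handle.
final class RunningSketch { // deliberately NOT Serializable
  var sum: Long = 0L
  var count: Long = 0L
}

class SketchAcc extends AccumulatorV2[Long, Double] {
  // Live working state on the executor; never serialized directly.
  @transient private var sketch: RunningSketch = new RunningSketch
  // Serializable mirror filled in by withBufferSerialized().
  private var flat: Array[Long] = Array(0L, 0L)

  // Rebuild the live object after deserialization on the driver.
  private def live: RunningSketch = {
    if (sketch == null) {
      sketch = new RunningSketch
      sketch.sum = flat(0)
      sketch.count = flat(1)
    }
    sketch
  }

  override def isZero: Boolean = live.count == 0L
  override def copy(): SketchAcc = {
    val acc = new SketchAcc
    acc.live.sum = live.sum
    acc.live.count = live.count
    acc
  }
  override def reset(): Unit = sketch = new RunningSketch
  override def add(v: Long): Unit = { live.sum += v; live.count += 1L }
  override def merge(other: AccumulatorV2[Long, Double]): Unit = other match {
    case o: SketchAcc => live.sum += o.live.sum; live.count += o.live.count
    case _ => throw new UnsupportedOperationException("incompatible accumulator")
  }
  override def value: Double =
    if (live.count == 0L) 0.0 else live.sum.toDouble / live.count

  // Called from writeReplace() on the executor: flatten the live object into
  // the serializable mirror, then let serialization proceed on `this`.
  override protected def withBufferSerialized(): AccumulatorV2[Long, Double] = {
    flat = Array(live.sum, live.count)
    this
  }
}
```

Design-wise, the default no-op implementation keeps all existing accumulators untouched; only accumulators that hold object buffers need to override the hook.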

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -284,7 +284,7 @@ object ApproximatePercentile {
   }
 
   /**
-   * Serializer  for class [[PercentileDigest]]
+   * Serializer for class [[PercentileDigest]]
    *
    * This class is thread safe.
    */
```

sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala

Lines changed: 23 additions & 3 deletions
```diff
@@ -163,9 +163,18 @@ class AggregatingAccumulator private(
           i += 1
         }
         i = 0
-        while (i < typedImperatives.length) {
-          typedImperatives(i).mergeBuffersObjects(buffer, otherBuffer)
-          i += 1
+        if (isAtDriverSide) {
+          while (i < typedImperatives.length) {
+            // The input buffer stores serialized data
+            typedImperatives(i).merge(buffer, otherBuffer)
+            i += 1
+          }
+        } else {
+          while (i < typedImperatives.length) {
+            // The input buffer stores deserialized object
+            typedImperatives(i).mergeBuffersObjects(buffer, otherBuffer)
+            i += 1
+          }
         }
       case _ =>
         throw QueryExecutionErrors.cannotMergeClassWithOtherClassError(
@@ -188,6 +197,17 @@ class AggregatingAccumulator private(
     resultProjection(input)
   }
 
+  override def withBufferSerialized(): AggregatingAccumulator = {
+    assert(!isAtDriverSide)
+    var i = 0
+    // AggregatingAccumulator runs on executor, we should serialize all TypedImperativeAggregate.
+    while (i < typedImperatives.length) {
+      typedImperatives(i).serializeAggregateBufferInPlace(buffer)
+      i += 1
+    }
+    this
+  }
+
   /**
    * Get the output schema of the aggregating accumulator.
    */
```
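The `isAtDriverSide` branch exists because the two sides of a merge see different buffer representations: on executors both buffers hold live objects, while an accumulator that arrives at the driver has already had `withBufferSerialized()` applied, so its buffer slots contain bytes. A simplified sketch of the distinction, using hypothetical types rather than Spark's internal API:

```scala
import java.nio.ByteBuffer

// Two states a TypedImperativeAggregate-style buffer slot can be in.
sealed trait Slot
final case class Live(values: Vector[Long]) extends Slot        // executor side
final case class Serialized(payload: Array[Byte]) extends Slot  // shipped to driver

def decode(payload: Array[Byte]): Vector[Long] = {
  val buf = ByteBuffer.wrap(payload)
  Vector.fill(payload.length / 8)(buf.getLong)
}

// Executor-side merge: both slots are live objects
// (the mergeBuffersObjects path in the diff above).
def mergeObjects(a: Live, b: Live): Live = Live(a.values ++ b.values)

// Driver-side merge: the incoming slot is serialized and must be decoded
// first (the merge path in the diff above).
def mergeSerialized(a: Live, b: Serialized): Live =
  Live(a.values ++ decode(b.payload))
```

This mirrors the `TypedImperativeAggregate` contract reflected in the diff's comments: `merge` expects the incoming buffer to hold serialized data, while `mergeBuffersObjects` expects a live object.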

sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala

Lines changed: 13 additions & 0 deletions
```diff
@@ -754,6 +754,19 @@ class DatasetSuite extends QueryTest
     assert(err2.getMessage.contains("Name must not be empty"))
   }
 
+  test("SPARK-37203: Fix NotSerializableException when observe with TypedImperativeAggregate") {
+    def observe[T](df: Dataset[T], expected: Map[String, _]): Unit = {
+      val namedObservation = Observation("named")
+      val observed_df = df.observe(
+        namedObservation, percentile_approx($"id", lit(0.5), lit(100)).as("percentile_approx_val"))
+      observed_df.collect()
+      assert(namedObservation.get === expected)
+    }
+
+    observe(spark.range(100), Map("percentile_approx_val" -> 49))
+    observe(spark.range(0), Map("percentile_approx_val" -> null))
+  }
+
   test("sample with replacement") {
     val n = 100
     val data = sparkContext.parallelize(1 to n, 2).toDS()
```

sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala

Lines changed: 4 additions & 3 deletions
```diff
@@ -417,7 +417,8 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
           min($"value").as("min_val"),
           max($"value").as("max_val"),
           sum($"value").as("sum_val"),
-          count(when($"value" % 2 === 0, 1)).as("num_even"))
+          count(when($"value" % 2 === 0, 1)).as("num_even"),
+          percentile_approx($"value", lit(0.5), lit(100)).as("percentile_approx_val"))
         .observe(
           name = "other_event",
           avg($"value").cast("int").as("avg_val"))
@@ -444,15 +445,15 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
       AddData(inputData, 1, 2),
       AdvanceManualClock(100),
       checkMetrics { metrics =>
-        assert(metrics.get("my_event") === Row(1, 2, 3L, 1L))
+        assert(metrics.get("my_event") === Row(1, 2, 3L, 1L, 1))
         assert(metrics.get("other_event") === Row(1))
       },
 
       // Batch 2
       AddData(inputData, 10, 30, -10, 5),
       AdvanceManualClock(100),
       checkMetrics { metrics =>
-        assert(metrics.get("my_event") === Row(-10, 30, 35L, 3L))
+        assert(metrics.get("my_event") === Row(-10, 30, 35L, 3L, 5))
         assert(metrics.get("other_event") === Row(8))
       },
 
```
