
Commit 3af53e6

zsxwing authored and Marcelo Vanzin committed
[SPARK-12084][CORE] Fix code that uses ByteBuffer.array incorrectly
`ByteBuffer` doesn't guarantee that everything in `ByteBuffer.array` is valid; for example, a `ByteBuffer` returned by `ByteBuffer.slice` shares its backing array with the original buffer. We should not use the whole content of a `ByteBuffer`'s backing array unless we know that is correct. This patch fixes all places that use `ByteBuffer.array` incorrectly.

Author: Shixiong Zhu <[email protected]>

Closes #10083 from zsxwing/bytebuffer-array.
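To make the pitfall concrete, here is a minimal stand-alone sketch (illustration only, not code from the patch) contrasting the incorrect `array()` call on a sliced buffer with a safe copy of just the valid bytes:

```scala
import java.nio.ByteBuffer

object ByteBufferArrayPitfall {
  def main(args: Array[String]): Unit = {
    val original = ByteBuffer.wrap("0123456789".getBytes("UTF-8"))
    original.position(4)
    val slice = original.slice() // logically "456789", but shares the full backing array

    // Incorrect: array() exposes the entire backing array and ignores
    // the slice's arrayOffset(), position() and remaining().
    val wrong = slice.array()

    // Correct: copy exactly the bytes the buffer currently exposes.
    val safe = new Array[Byte](slice.remaining())
    slice.duplicate().get(safe) // duplicate() keeps the slice's position untouched

    println(new String(wrong, "UTF-8")) // prints 0123456789
    println(new String(safe, "UTF-8"))  // prints 456789
  }
}
```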
1 parent f30373f commit 3af53e6

File tree

22 files changed: +81 -69 lines changed


core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala

Lines changed: 3 additions & 9 deletions
@@ -30,6 +30,7 @@ import org.apache.spark.network.sasl.{SaslClientBootstrap, SaslServerBootstrap}
 import org.apache.spark.network.server._
 import org.apache.spark.network.shuffle.{RetryingBlockFetcher, BlockFetchingListener, OneForOneBlockFetcher}
 import org.apache.spark.network.shuffle.protocol.UploadBlock
+import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.serializer.JavaSerializer
 import org.apache.spark.storage.{BlockId, StorageLevel}
 import org.apache.spark.util.Utils
@@ -123,17 +124,10 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage

     // StorageLevel is serialized as bytes using our JavaSerializer. Everything else is encoded
     // using our binary protocol.
-    val levelBytes = serializer.newInstance().serialize(level).array()
+    val levelBytes = JavaUtils.bufferToArray(serializer.newInstance().serialize(level))

     // Convert or copy nio buffer into array in order to serialize it.
-    val nioBuffer = blockData.nioByteBuffer()
-    val array = if (nioBuffer.hasArray) {
-      nioBuffer.array()
-    } else {
-      val data = new Array[Byte](nioBuffer.remaining())
-      nioBuffer.get(data)
-      data
-    }
+    val array = JavaUtils.bufferToArray(blockData.nioByteBuffer())

     client.sendRpc(new UploadBlock(appId, execId, blockId.toString, levelBytes, array).toByteBuffer,
       new RpcResponseCallback {
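The removed hand-rolled copy above is replaced by `JavaUtils.bufferToArray`. As a rough idea of what such a helper does, here is a minimal sketch (an assumption about its behaviour, not the actual source of Spark's `JavaUtils.bufferToArray`): return the backing array only when it is exactly the buffer's content, otherwise copy the remaining bytes.

```scala
import java.nio.ByteBuffer

// Sketch of a bufferToArray-style helper (hypothetical; the real implementation
// lives in Spark's network-common module and may differ).
def bufferToArray(buffer: ByteBuffer): Array[Byte] = {
  val arrayIsExactlyContent =
    buffer.hasArray && buffer.arrayOffset() == 0 && buffer.position() == 0 &&
      buffer.array().length == buffer.remaining()
  if (arrayIsExactlyContent) {
    buffer.array() // zero-copy fast path: the whole backing array is valid
  } else {
    val bytes = new Array[Byte](buffer.remaining())
    buffer.duplicate().get(bytes) // duplicate() so the caller's position is preserved
    bytes
  }
}
```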

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 4 additions & 2 deletions
@@ -34,6 +34,7 @@ import org.apache.commons.lang3.SerializationUtils
 import org.apache.spark._
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.rpc.RpcTimeout
@@ -997,9 +998,10 @@ class DAGScheduler(
       // For ResultTask, serialize and broadcast (rdd, func).
       val taskBinaryBytes: Array[Byte] = stage match {
         case stage: ShuffleMapStage =>
-          closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef).array()
+          JavaUtils.bufferToArray(
+            closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef))
         case stage: ResultStage =>
-          closureSerializer.serialize((stage.rdd, stage.func): AnyRef).array()
+          JavaUtils.bufferToArray(closureSerializer.serialize((stage.rdd, stage.func): AnyRef))
       }

       taskBinary = sc.broadcast(taskBinaryBytes)

core/src/main/scala/org/apache/spark/scheduler/Task.scala

Lines changed: 2 additions & 2 deletions
@@ -191,8 +191,8 @@ private[spark] object Task {

     // Write the task itself and finish
     dataOut.flush()
-    val taskBytes = serializer.serialize(task).array()
-    out.write(taskBytes)
+    val taskBytes = serializer.serialize(task)
+    Utils.writeByteBuffer(taskBytes, out)
     ByteBuffer.wrap(out.toByteArray)
   }

core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala

Lines changed: 4 additions & 1 deletion
@@ -81,7 +81,10 @@ private[serializer] class GenericAvroSerializer(schemas: Map[Long, String])
    * seen values so to limit the number of times that decompression has to be done.
    */
   def decompress(schemaBytes: ByteBuffer): Schema = decompressCache.getOrElseUpdate(schemaBytes, {
-    val bis = new ByteArrayInputStream(schemaBytes.array())
+    val bis = new ByteArrayInputStream(
+      schemaBytes.array(),
+      schemaBytes.arrayOffset() + schemaBytes.position(),
+      schemaBytes.remaining())
     val bytes = IOUtils.toByteArray(codec.compressedInputStream(bis))
     new Schema.Parser().parse(new String(bytes, "UTF-8"))
   })
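The three-argument `ByteArrayInputStream` constructor used above streams over just the valid region of the buffer without copying it. A small hypothetical example (names invented) of the same pattern:

```scala
import java.io.ByteArrayInputStream
import java.nio.ByteBuffer

// Hypothetical example: read only the bytes a sliced buffer exposes,
// starting at arrayOffset() + position() and spanning remaining() bytes.
val whole = ByteBuffer.wrap("headerPAYLOAD".getBytes("UTF-8"))
whole.position(6)
val schemaBytes = whole.slice() // arrayOffset() == 6, remaining() == 7

val bis = new ByteArrayInputStream(
  schemaBytes.array(),
  schemaBytes.arrayOffset() + schemaBytes.position(),
  schemaBytes.remaining())

val read = new Array[Byte](schemaBytes.remaining())
bis.read(read)
println(new String(read, "UTF-8")) // prints PAYLOAD
```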

core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala

Lines changed: 2 additions & 2 deletions
@@ -309,7 +309,7 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ
   override def deserialize[T: ClassTag](bytes: ByteBuffer): T = {
     val kryo = borrowKryo()
     try {
-      input.setBuffer(bytes.array)
+      input.setBuffer(bytes.array(), bytes.arrayOffset() + bytes.position(), bytes.remaining())
       kryo.readClassAndObject(input).asInstanceOf[T]
     } finally {
       releaseKryo(kryo)
@@ -321,7 +321,7 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ
     val oldClassLoader = kryo.getClassLoader
     try {
       kryo.setClassLoader(loader)
-      input.setBuffer(bytes.array)
+      input.setBuffer(bytes.array(), bytes.arrayOffset() + bytes.position(), bytes.remaining())
       kryo.readClassAndObject(input).asInstanceOf[T]
     } finally {
       kryo.setClassLoader(oldClassLoader)

core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala

Lines changed: 1 addition & 1 deletion
@@ -103,7 +103,7 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log
       val file = getFile(blockId)
       val os = file.getOutStream(WriteType.TRY_CACHE)
       try {
-        os.write(bytes.array())
+        Utils.writeByteBuffer(bytes, os)
       } catch {
         case NonFatal(e) =>
           logWarning(s"Failed to put bytes of block $blockId into Tachyon", e)

core/src/main/scala/org/apache/spark/util/Utils.scala

Lines changed: 14 additions & 1 deletion
@@ -178,7 +178,20 @@ private[spark] object Utils extends Logging {
   /**
    * Primitive often used when writing [[java.nio.ByteBuffer]] to [[java.io.DataOutput]]
    */
-  def writeByteBuffer(bb: ByteBuffer, out: ObjectOutput): Unit = {
+  def writeByteBuffer(bb: ByteBuffer, out: DataOutput): Unit = {
+    if (bb.hasArray) {
+      out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
+    } else {
+      val bbval = new Array[Byte](bb.remaining())
+      bb.get(bbval)
+      out.write(bbval)
+    }
+  }
+
+  /**
+   * Primitive often used when writing [[java.nio.ByteBuffer]] to [[java.io.OutputStream]]
+   */
+  def writeByteBuffer(bb: ByteBuffer, out: OutputStream): Unit = {
     if (bb.hasArray) {
       out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
     } else {
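A hypothetical call site for the new `OutputStream` overload (assuming code inside Spark's own packages, since `Utils` is `private[spark]`), showing that only the buffer's remaining bytes reach the stream even when the backing array is larger:

```scala
import java.io.ByteArrayOutputStream
import java.nio.ByteBuffer

import org.apache.spark.util.Utils

val buf = ByteBuffer.wrap(Array[Byte](0, 1, 2, 3, 4, 5))
buf.position(2)
val payload = buf.slice() // shares the 6-byte array, but only exposes bytes 2..5

val out = new ByteArrayOutputStream()
Utils.writeByteBuffer(payload, out) // writes bytes 2, 3, 4, 5 only
assert(out.toByteArray.sameElements(Array[Byte](2, 3, 4, 5)))
```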

core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java

Lines changed: 3 additions & 2 deletions
@@ -41,6 +41,7 @@
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.memory.TestMemoryManager;
 import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.network.util.JavaUtils;
 import org.apache.spark.serializer.SerializerInstance;
 import org.apache.spark.storage.*;
 import org.apache.spark.unsafe.Platform;
@@ -430,7 +431,7 @@ public void randomizedStressTest() {
     }

     for (Map.Entry<ByteBuffer, byte[]> entry : expected.entrySet()) {
-      final byte[] key = entry.getKey().array();
+      final byte[] key = JavaUtils.bufferToArray(entry.getKey());
       final byte[] value = entry.getValue();
       final BytesToBytesMap.Location loc =
         map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length);
@@ -480,7 +481,7 @@ public void randomizedTestWithRecordsLargerThanPageSize() {
       }
     }
     for (Map.Entry<ByteBuffer, byte[]> entry : expected.entrySet()) {
-      final byte[] key = entry.getKey().array();
+      final byte[] key = JavaUtils.bufferToArray(entry.getKey());
       final byte[] value = entry.getValue();
       final BytesToBytesMap.Location loc =
         map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length);

core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala

Lines changed: 2 additions & 1 deletion
@@ -23,6 +23,7 @@ import org.mockito.Matchers.any
 import org.scalatest.BeforeAndAfter

 import org.apache.spark._
+import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException}
 import org.apache.spark.metrics.source.JvmSource
@@ -57,7 +58,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
     }
     val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
     val func = (c: TaskContext, i: Iterator[String]) => i.next()
-    val taskBinary = sc.broadcast(closureSerializer.serialize((rdd, func)).array)
+    val taskBinary = sc.broadcast(JavaUtils.bufferToArray(closureSerializer.serialize((rdd, func))))
     val task = new ResultTask[String, String](
       0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, Seq.empty)
     intercept[RuntimeException] {

examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala

Lines changed: 4 additions & 1 deletion
@@ -79,7 +79,10 @@ object AvroConversionUtil extends Serializable {

   def unpackBytes(obj: Any): Array[Byte] = {
     val bytes: Array[Byte] = obj match {
-      case buf: java.nio.ByteBuffer => buf.array()
+      case buf: java.nio.ByteBuffer =>
+        val arr = new Array[Byte](buf.remaining())
+        buf.get(arr)
+        arr
       case arr: Array[Byte] => arr
       case other => throw new SparkException(
         s"Unknown BYTES type ${other.getClass.getName}")
