core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -29,7 +29,7 @@ import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.io.CompressionCodec
import org.apache.spark.serializer.Serializer
import org.apache.spark.storage.{BlockId, BroadcastBlockId, StorageLevel}
import org.apache.spark.storage._
import org.apache.spark.util.{ByteBufferInputStream, Utils}
import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}

@@ -141,10 +141,10 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
}

/** Fetch torrent blocks from the driver and/or other executors. */
private def readBlocks(): Array[ChunkedByteBuffer] = {
private def readBlocks(): Array[BlockData] = {
// Fetch chunks of data. Note that all these chunks are stored in the BlockManager and reported
// to the driver, so other executors can pull these chunks from this executor as well.
val blocks = new Array[ChunkedByteBuffer](numBlocks)
val blocks = new Array[BlockData](numBlocks)
val bm = SparkEnv.get.blockManager

for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
@@ -173,7 +173,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
throw new SparkException(
s"Failed to store $pieceId of $broadcastId in local BlockManager")
}
blocks(pid) = b
blocks(pid) = new ByteBufferBlockData(b)
case None =>
throw new SparkException(s"Failed to get $pieceId of $broadcastId")
}
@@ -219,7 +219,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
case None =>
logInfo("Started reading broadcast variable " + id)
val startTimeMs = System.currentTimeMillis()
val blocks = readBlocks().flatMap(_.getChunks())
val blocks = readBlocks().map(_.toInputStream())
logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs))

val obj = TorrentBroadcast.unBlockifyObject[T](
@@ -277,12 +277,11 @@ private object TorrentBroadcast extends Logging {
}

def unBlockifyObject[T: ClassTag](
blocks: Array[ByteBuffer],
blocks: Array[InputStream],
serializer: Serializer,
compressionCodec: Option[CompressionCodec]): T = {
require(blocks.nonEmpty, "Cannot unblockify an empty array of blocks")
val is = new SequenceInputStream(
blocks.iterator.map(new ByteBufferInputStream(_)).asJavaEnumeration)
val is = new SequenceInputStream(blocks.iterator.asJavaEnumeration)
val in: InputStream = compressionCodec.map(c => c.compressedInputStream(is)).getOrElse(is)
val ser = serializer.newInstance()
val serIn = ser.deserializeStream(in)
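The streaming rewrite above never flattens the fetched pieces into ByteBuffer chunks; unBlockifyObject now stitches the per-piece streams together with a SequenceInputStream. A minimal, self-contained sketch of that pattern (the two in-memory pieces are stand-ins for fetched broadcast blocks):

    import java.io.{ByteArrayInputStream, InputStream, SequenceInputStream}
    import scala.collection.JavaConverters._

    object SequenceStreamDemo {
      def main(args: Array[String]): Unit = {
        // Stand-ins for the fetched broadcast pieces.
        val pieces: Array[InputStream] = Array(
          new ByteArrayInputStream("hello ".getBytes("UTF-8")),
          new ByteArrayInputStream("world".getBytes("UTF-8")))

        // Same pattern as unBlockifyObject: expose all pieces as one logical
        // stream without first copying them into one contiguous buffer.
        val combined = new SequenceInputStream(pieces.iterator.asJavaEnumeration)
        println(scala.io.Source.fromInputStream(combined, "UTF-8").mkString) // "hello world"
      }
    }

Decompression and deserialization then run directly over the combined stream, which is what lets the broadcast read path drop the intermediate getChunks() copies.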
core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala (89 additions, 11 deletions)
@@ -16,7 +16,9 @@
*/
package org.apache.spark.security

import java.io.{InputStream, OutputStream}
import java.io.{EOFException, InputStream, OutputStream}
import java.nio.ByteBuffer
import java.nio.channels.{ReadableByteChannel, WritableByteChannel}
import java.util.Properties
import javax.crypto.KeyGenerator
import javax.crypto.spec.{IvParameterSpec, SecretKeySpec}
@@ -48,12 +50,30 @@ private[spark] object CryptoStreamUtils extends Logging {
os: OutputStream,
sparkConf: SparkConf,
key: Array[Byte]): OutputStream = {
val properties = toCryptoConf(sparkConf)
val iv = createInitializationVector(properties)
val params = new CryptoParams(key, sparkConf)
val iv = createInitializationVector(params.conf)
os.write(iv)
val transformationStr = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION)
new CryptoOutputStream(transformationStr, properties, os,
new SecretKeySpec(key, "AES"), new IvParameterSpec(iv))
new CryptoOutputStream(params.transformation, params.conf, os, params.keySpec,
new IvParameterSpec(iv))
}

/**
* Wrap a `WritableByteChannel` for encryption.
*/
def createWritableChannel(
channel: WritableByteChannel,
sparkConf: SparkConf,
key: Array[Byte]): WritableByteChannel = {
val params = new CryptoParams(key, sparkConf)
val iv = createInitializationVector(params.conf)
val buf = ByteBuffer.wrap(iv)
while (buf.remaining() > 0) {
Contributor: nit: buf.hasRemaining for this pattern of use

channel.write(buf)
}

val helper = new CryptoHelperChannel(channel)
new CryptoOutputStream(params.transformation, params.conf, helper, params.keySpec,
new IvParameterSpec(iv))
}

/**
@@ -63,12 +83,40 @@
is: InputStream,
sparkConf: SparkConf,
key: Array[Byte]): InputStream = {
val properties = toCryptoConf(sparkConf)
val iv = new Array[Byte](IV_LENGTH_IN_BYTES)
is.read(iv, 0, iv.length)
val transformationStr = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION)
new CryptoInputStream(transformationStr, properties, is,
new SecretKeySpec(key, "AES"), new IvParameterSpec(iv))
var read = 0
while (read < iv.length) {
Contributor: what does this while loop do?

Contributor Author: It avoids issues with short reads. It's unlikely to happen but I always write read code like this to be safe.

Member: Yeah, you can just use ByteStreams.readFully(is, iv).

Contributor Author: Ah, missed that one. +1 for shorter code.

val _read = is.read(iv, 0, iv.length)
if (_read < 0) {
throw new EOFException("Failed to read IV from stream.")
}
read += _read
}
Contributor: ByteStreams.readFully instead of the loop


val params = new CryptoParams(key, sparkConf)
new CryptoInputStream(params.transformation, params.conf, is, params.keySpec,
new IvParameterSpec(iv))
}
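As the reviewers suggest above, Guava's ByteStreams.readFully (already on Spark's classpath) collapses the manual short-read loop into one call. A sketch of the suggested simplification, as a hypothetical helper:

    import java.io.InputStream

    import com.google.common.io.ByteStreams

    // Drop-in replacement for the short-read loop above: reads exactly
    // ivLength bytes, looping over short reads internally, and throws
    // EOFException if the stream ends early.
    def readIv(is: InputStream, ivLength: Int): Array[Byte] = {
      val iv = new Array[Byte](ivLength)
      ByteStreams.readFully(is, iv)
      iv
    }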

/**
* Wrap a `ReadableByteChannel` for decryption.
*/
def createReadableChannel(
channel: ReadableByteChannel,
sparkConf: SparkConf,
key: Array[Byte]): ReadableByteChannel = {
val iv = new Array[Byte](IV_LENGTH_IN_BYTES)
val buf = ByteBuffer.wrap(iv)
buf.clear()
Contributor: nit: The clear is not required.

while (buf.remaining() > 0) {
if (channel.read(buf) < 0) {
throw new EOFException("Failed to read IV from channel.")
}
}

val params = new CryptoParams(key, sparkConf)
new CryptoInputStream(params.transformation, params.conf, channel, params.keySpec,
new IvParameterSpec(iv))
}
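A hedged round-trip sketch of the two new channel wrappers. It only compiles inside Spark (CryptoStreamUtils is private[spark]), and the in-memory channels and 128-bit key are illustrative choices, not part of this diff:

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
    import java.nio.ByteBuffer
    import java.nio.channels.Channels
    import javax.crypto.KeyGenerator

    import org.apache.spark.SparkConf
    import org.apache.spark.security.CryptoStreamUtils

    val conf = new SparkConf()
    val keyGen = KeyGenerator.getInstance("AES")
    keyGen.init(128)
    val key = keyGen.generateKey().getEncoded

    // Encrypt: the wrapper writes the random IV first, then the ciphertext.
    val sink = new ByteArrayOutputStream()
    val encrypted = CryptoStreamUtils.createWritableChannel(Channels.newChannel(sink), conf, key)
    encrypted.write(ByteBuffer.wrap("secret".getBytes("UTF-8")))
    encrypted.close()

    // Decrypt: the wrapper consumes the IV before handing back plaintext.
    val source = new ByteArrayInputStream(sink.toByteArray)
    val decrypted = CryptoStreamUtils.createReadableChannel(Channels.newChannel(source), conf, key)
    val buf = ByteBuffer.allocate(16)
    decrypted.read(buf)

Both wrappers use the same IV-prefixed wire format as the stream-based ones above, so the stream and channel variants should interoperate.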

def toCryptoConf(conf: SparkConf): Properties = {
@@ -102,4 +150,34 @@
}
iv
}

/**
* This class is a workaround for CRYPTO-125, that forces all bytes to be written to the
* underlying channel. Since the callers of this API are using blocking I/O, there are no
* concerns with regards to CPU usage here.
*/

Contributor: This is a lousy bug! Good thing that we don't seem to be hit by it (yet).

Contributor Author: There's a pretty nasty workaround for it in the network library... (the non-blocking workaround is a lot worse than this.)

Contributor: is it a separate bug fix?

Contributor Author: No. As the comment states, it's a workaround for a bug in the commons-crypto library, which would affect the code being added.
private class CryptoHelperChannel(sink: WritableByteChannel) extends WritableByteChannel {

override def write(src: ByteBuffer): Int = {
val count = src.remaining()
while (src.remaining() > 0) {
sink.write(src)
}
count
}

override def isOpen(): Boolean = sink.isOpen()

override def close(): Unit = sink.close()

}
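CryptoHelperChannel exists because a single WritableByteChannel.write() call may legitimately consume only part of the buffer, and CRYPTO-125 means commons-crypto mishandles such partial writes. The generic shape of the guarantee it restores, written with the reviewer's preferred buf.hasRemaining style (hypothetical helper, not part of the patch):

    import java.nio.ByteBuffer
    import java.nio.channels.WritableByteChannel

    // Loop until the channel has accepted every byte, since a single
    // write() may return after a partial write.
    def writeFully(ch: WritableByteChannel, src: ByteBuffer): Unit = {
      while (src.hasRemaining) {
        ch.write(src)
      }
    }

Blocking channels make this loop cheap, which is why the class comment notes there is no CPU-usage concern.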

private class CryptoParams(key: Array[Byte], sparkConf: SparkConf) {

val keySpec = new SecretKeySpec(key, "AES")
val transformation = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION)
val conf = toCryptoConf(sparkConf)

}

}
core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
@@ -148,14 +148,14 @@ private[spark] class SerializerManager(
/**
* Wrap an output stream for compression if block compression is enabled for its block type
*/
private[this] def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = {
def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = {
if (shouldCompress(blockId)) compressionCodec.compressedOutputStream(s) else s
}

/**
* Wrap an input stream for compression if block compression is enabled for its block type
*/
private[this] def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = {
def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = {
if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s
}

@@ -167,30 +167,26 @@
val byteStream = new BufferedOutputStream(outputStream)
val autoPick = !blockId.isInstanceOf[StreamBlockId]
val ser = getSerializer(implicitly[ClassTag[T]], autoPick).newInstance()
ser.serializeStream(wrapStream(blockId, byteStream)).writeAll(values).close()
Contributor: the wrapStream and wrapForEncryption methods can be removed from this class

Contributor Author: They're still used in a bunch of places.

ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
}

/** Serializes into a chunked byte buffer. */
def dataSerialize[T: ClassTag](
blockId: BlockId,
values: Iterator[T],
allowEncryption: Boolean = true): ChunkedByteBuffer = {
dataSerializeWithExplicitClassTag(blockId, values, implicitly[ClassTag[T]],
allowEncryption = allowEncryption)
values: Iterator[T]): ChunkedByteBuffer = {
dataSerializeWithExplicitClassTag(blockId, values, implicitly[ClassTag[T]])
}

/** Serializes into a chunked byte buffer. */
def dataSerializeWithExplicitClassTag(
blockId: BlockId,
values: Iterator[_],
classTag: ClassTag[_],
allowEncryption: Boolean = true): ChunkedByteBuffer = {
classTag: ClassTag[_]): ChunkedByteBuffer = {
val bbos = new ChunkedByteBufferOutputStream(1024 * 1024 * 4, ByteBuffer.allocate)
val byteStream = new BufferedOutputStream(bbos)
val autoPick = !blockId.isInstanceOf[StreamBlockId]
val ser = getSerializer(classTag, autoPick).newInstance()
val encrypted = if (allowEncryption) wrapForEncryption(byteStream) else byteStream
ser.serializeStream(wrapForCompression(blockId, encrypted)).writeAll(values).close()
ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
bbos.toChunkedByteBuffer
}

@@ -200,15 +196,13 @@
*/
def dataDeserializeStream[T](
blockId: BlockId,
inputStream: InputStream,
maybeEncrypted: Boolean = true)
inputStream: InputStream)
(classTag: ClassTag[T]): Iterator[T] = {
val stream = new BufferedInputStream(inputStream)
val autoPick = !blockId.isInstanceOf[StreamBlockId]
val decrypted = if (maybeEncrypted) wrapForEncryption(inputStream) else inputStream
getSerializer(classTag, autoPick)
.newInstance()
.deserializeStream(wrapForCompression(blockId, decrypted))
.deserializeStream(wrapForCompression(blockId, inputStream))
.asIterator.asInstanceOf[Iterator[T]]
}
}
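A hedged round-trip sketch of the simplified SerializerManager API after this change. It assumes a running SparkEnv; TestBlockId is just a convenient block type for the example:

    import scala.reflect.classTag

    import org.apache.spark.SparkEnv
    import org.apache.spark.storage.TestBlockId

    val sm = SparkEnv.get.serializerManager
    val blockId = TestBlockId("demo")

    // Encryption is no longer toggled per call; compression still depends
    // on the block type via wrapForCompression.
    val bytes = sm.dataSerialize(blockId, Iterator("a", "b", "c"))
    val values = sm.dataDeserializeStream(blockId, bytes.toInputStream())(classTag[String])
    println(values.toList) // List(a, b, c)

Dropping the allowEncryption / maybeEncrypted flags moves the encrypt-or-not decision out of the serialization path entirely, which is what the rest of this patch is about.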