JavaUtils.java
@@ -18,9 +18,11 @@
package org.apache.spark.network.util;

import java.io.Closeable;
import java.io.EOFException;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ReadableByteChannel;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.TimeUnit;
import java.util.regex.Matcher;
@@ -344,4 +346,17 @@ public static byte[] bufferToArray(ByteBuffer buffer) {
}
}

/**
 * Fills a buffer with data read from the channel; throws EOFException if the channel
 * is exhausted before the buffer has been filled.
 */
public static void readFully(ReadableByteChannel channel, ByteBuffer dst) throws IOException {
int expected = dst.remaining();
while (dst.hasRemaining()) {
if (channel.read(dst) < 0) {
throw new EOFException(String.format("Not enough bytes in channel (expected %d).",
expected));
}
}
}

}
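As a quick illustration (not part of the change), the snippet below shows how the new helper might be called from Scala; the readHeader name and the 16-byte size are made up for this sketch:

```scala
import java.nio.ByteBuffer
import java.nio.channels.ReadableByteChannel

import org.apache.spark.network.util.JavaUtils

// Hypothetical caller: read a fixed-size 16-byte header from a channel.
// JavaUtils.readFully keeps calling channel.read() until the buffer is full,
// and throws EOFException if the channel runs out of data first, so the
// caller never has to handle short reads itself.
def readHeader(channel: ReadableByteChannel): Array[Byte] = {
  val header = new Array[Byte](16)
  JavaUtils.readFully(channel, ByteBuffer.wrap(header))
  header
}
```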
TorrentBroadcast.scala
@@ -29,7 +29,7 @@ import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.io.CompressionCodec
import org.apache.spark.serializer.Serializer
import org.apache.spark.storage.{BlockId, BroadcastBlockId, StorageLevel}
import org.apache.spark.storage._
import org.apache.spark.util.{ByteBufferInputStream, Utils}
import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}

@@ -141,10 +141,10 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
}

/** Fetch torrent blocks from the driver and/or other executors. */
private def readBlocks(): Array[ChunkedByteBuffer] = {
private def readBlocks(): Array[BlockData] = {
// Fetch chunks of data. Note that all these chunks are stored in the BlockManager and reported
// to the driver, so other executors can pull these chunks from this executor as well.
val blocks = new Array[ChunkedByteBuffer](numBlocks)
val blocks = new Array[BlockData](numBlocks)
val bm = SparkEnv.get.blockManager

for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
@@ -173,7 +173,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
throw new SparkException(
s"Failed to store $pieceId of $broadcastId in local BlockManager")
}
blocks(pid) = b
blocks(pid) = new ByteBufferBlockData(b, true)
case None =>
throw new SparkException(s"Failed to get $pieceId of $broadcastId")
}
@@ -219,18 +219,22 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
case None =>
logInfo("Started reading broadcast variable " + id)
val startTimeMs = System.currentTimeMillis()
val blocks = readBlocks().flatMap(_.getChunks())
val blocks = readBlocks()
logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs))

val obj = TorrentBroadcast.unBlockifyObject[T](
blocks, SparkEnv.get.serializer, compressionCodec)
// Store the merged copy in BlockManager so other tasks on this executor don't
// need to re-fetch it.
val storageLevel = StorageLevel.MEMORY_AND_DISK
if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) {
throw new SparkException(s"Failed to store $broadcastId in BlockManager")
try {
val obj = TorrentBroadcast.unBlockifyObject[T](
blocks.map(_.toInputStream()), SparkEnv.get.serializer, compressionCodec)
// Store the merged copy in BlockManager so other tasks on this executor don't
// need to re-fetch it.
val storageLevel = StorageLevel.MEMORY_AND_DISK
if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) {
throw new SparkException(s"Failed to store $broadcastId in BlockManager")
}
obj
} finally {
blocks.foreach(_.dispose())

Contributor: ah good catch! we should dispose the blocks here

}
obj
}
}
}
@@ -277,12 +281,11 @@ private object TorrentBroadcast extends Logging {
}

def unBlockifyObject[T: ClassTag](
blocks: Array[ByteBuffer],
blocks: Array[InputStream],
serializer: Serializer,
compressionCodec: Option[CompressionCodec]): T = {
require(blocks.nonEmpty, "Cannot unblockify an empty array of blocks")
val is = new SequenceInputStream(
blocks.iterator.map(new ByteBufferInputStream(_)).asJavaEnumeration)
val is = new SequenceInputStream(blocks.iterator.asJavaEnumeration)
val in: InputStream = compressionCodec.map(c => c.compressedInputStream(is)).getOrElse(is)
val ser = serializer.newInstance()
val serIn = ser.deserializeStream(in)
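For context on the new unBlockifyObject signature, here is a small standalone sketch of the stream-concatenation pattern it relies on: per-block InputStreams are chained with SequenceInputStream instead of materializing ByteBuffers first. The chunk contents below are made up for illustration:

```scala
import java.io.{ByteArrayInputStream, InputStream, SequenceInputStream}
import java.nio.charset.StandardCharsets

import scala.collection.JavaConverters._

// Hypothetical chunks standing in for the per-block streams returned by
// BlockData.toInputStream() in the real code path.
val chunks: Array[InputStream] = Array(
  new ByteArrayInputStream("foo".getBytes(StandardCharsets.UTF_8)),
  new ByteArrayInputStream("bar".getBytes(StandardCharsets.UTF_8)),
  new ByteArrayInputStream("baz".getBytes(StandardCharsets.UTF_8)))

// SequenceInputStream drains the streams in order, so the combined stream
// yields "foobarbaz" without copying everything into a single buffer first.
val combined = new SequenceInputStream(chunks.iterator.asJavaEnumeration)
val restored = scala.io.Source.fromInputStream(combined, "UTF-8").mkString  // "foobarbaz"
```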
CryptoStreamUtils.scala
@@ -16,20 +16,23 @@
*/
package org.apache.spark.security

import java.io.{InputStream, OutputStream}
import java.io.{EOFException, InputStream, OutputStream}
import java.nio.ByteBuffer
import java.nio.channels.{ReadableByteChannel, WritableByteChannel}
import java.util.Properties
import javax.crypto.KeyGenerator
import javax.crypto.spec.{IvParameterSpec, SecretKeySpec}

import scala.collection.JavaConverters._

import com.google.common.io.ByteStreams
import org.apache.commons.crypto.random._
import org.apache.commons.crypto.stream._

import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
import org.apache.spark.network.util.CryptoUtils
import org.apache.spark.network.util.{CryptoUtils, JavaUtils}

/**
* A util class for manipulating IO encryption and decryption streams.
@@ -48,12 +51,27 @@ private[spark] object CryptoStreamUtils extends Logging {
os: OutputStream,
sparkConf: SparkConf,
key: Array[Byte]): OutputStream = {
val properties = toCryptoConf(sparkConf)
val iv = createInitializationVector(properties)
val params = new CryptoParams(key, sparkConf)
val iv = createInitializationVector(params.conf)
os.write(iv)
val transformationStr = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION)
new CryptoOutputStream(transformationStr, properties, os,
new SecretKeySpec(key, "AES"), new IvParameterSpec(iv))
new CryptoOutputStream(params.transformation, params.conf, os, params.keySpec,
new IvParameterSpec(iv))
}

/**
* Wrap a `WritableByteChannel` for encryption.
*/
def createWritableChannel(
channel: WritableByteChannel,
sparkConf: SparkConf,
key: Array[Byte]): WritableByteChannel = {
val params = new CryptoParams(key, sparkConf)
val iv = createInitializationVector(params.conf)
val helper = new CryptoHelperChannel(channel)

helper.write(ByteBuffer.wrap(iv))
new CryptoOutputStream(params.transformation, params.conf, helper, params.keySpec,
new IvParameterSpec(iv))
}

/**
@@ -63,12 +81,27 @@
is: InputStream,
sparkConf: SparkConf,
key: Array[Byte]): InputStream = {
val properties = toCryptoConf(sparkConf)
val iv = new Array[Byte](IV_LENGTH_IN_BYTES)
is.read(iv, 0, iv.length)
val transformationStr = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION)
new CryptoInputStream(transformationStr, properties, is,
new SecretKeySpec(key, "AES"), new IvParameterSpec(iv))
ByteStreams.readFully(is, iv)
val params = new CryptoParams(key, sparkConf)
new CryptoInputStream(params.transformation, params.conf, is, params.keySpec,
new IvParameterSpec(iv))
}

/**
* Wrap a `ReadableByteChannel` for decryption.
*/
def createReadableChannel(
channel: ReadableByteChannel,
sparkConf: SparkConf,
key: Array[Byte]): ReadableByteChannel = {
val iv = new Array[Byte](IV_LENGTH_IN_BYTES)
val buf = ByteBuffer.wrap(iv)
JavaUtils.readFully(channel, buf)

Contributor: why not use ByteStreams.readFully? the buf is not used elsewhere

Contributor (author): There's no ByteStreams.readFully for ReadableByteChannel that I'm aware of.

val params = new CryptoParams(key, sparkConf)
new CryptoInputStream(params.transformation, params.conf, channel, params.keySpec,
new IvParameterSpec(iv))
}

def toCryptoConf(conf: SparkConf): Properties = {
@@ -102,4 +135,34 @@
}
iv
}

Contributor: This is a lousy bug! Good thing that we don't seem to be hit by it (yet).

Contributor (author): There's a pretty nasty workaround for it in the network library... (the non-blocking workaround is a lot worse than this.)

Contributor: is it a separate bug fix?

Contributor (author): No. As the comment states, it's a workaround for a bug in the commons-crypto library, which would affect the code being added.

/**
 * This class is a workaround for CRYPTO-125: it forces all bytes to be written to the
 * underlying channel. Since the callers of this API are using blocking I/O, there are no
 * concerns with regards to CPU usage here.
 */
private class CryptoHelperChannel(sink: WritableByteChannel) extends WritableByteChannel {

override def write(src: ByteBuffer): Int = {
val count = src.remaining()
while (src.hasRemaining()) {
sink.write(src)
}
count
}

override def isOpen(): Boolean = sink.isOpen()

override def close(): Unit = sink.close()

}

private class CryptoParams(key: Array[Byte], sparkConf: SparkConf) {

val keySpec = new SecretKeySpec(key, "AES")
val transformation = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION)
val conf = toCryptoConf(sparkConf)

}

}
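To make the new channel wrappers concrete, below is a rough roundtrip sketch (not part of the change). It assumes code in a package under org.apache.spark, since CryptoStreamUtils is private[spark]; the key and payload are placeholders, and in real use the key comes from Spark's IO-encryption key handling:

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.nio.ByteBuffer
import java.nio.channels.Channels
import java.nio.charset.StandardCharsets

import org.apache.spark.SparkConf
import org.apache.spark.security.CryptoStreamUtils

// Hypothetical 128-bit key and default configuration, for illustration only.
val key = Array.fill[Byte](16)(1)
val conf = new SparkConf()

// Encrypt: the wrapper writes the random IV to the sink first, then ciphertext.
val sink = new ByteArrayOutputStream()
val encrypted = CryptoStreamUtils.createWritableChannel(Channels.newChannel(sink), conf, key)
encrypted.write(ByteBuffer.wrap("secret payload".getBytes(StandardCharsets.UTF_8)))
encrypted.close()

// Decrypt: the wrapper reads the IV prefix back before returning plaintext reads.
val source = Channels.newChannel(new ByteArrayInputStream(sink.toByteArray))
val decrypted = CryptoStreamUtils.createReadableChannel(source, conf, key)
val buf = ByteBuffer.allocate(1024)
decrypted.read(buf)  // a single read suffices for this small payload
```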
SerializerManager.scala
@@ -148,14 +148,14 @@ private[spark] class SerializerManager(
/**
* Wrap an output stream for compression if block compression is enabled for its block type
*/
private[this] def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = {
def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = {
if (shouldCompress(blockId)) compressionCodec.compressedOutputStream(s) else s
}

/**
* Wrap an input stream for compression if block compression is enabled for its block type
*/
private[this] def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = {
def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = {
if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s
}

@@ -167,30 +167,26 @@
val byteStream = new BufferedOutputStream(outputStream)
val autoPick = !blockId.isInstanceOf[StreamBlockId]
val ser = getSerializer(implicitly[ClassTag[T]], autoPick).newInstance()
ser.serializeStream(wrapStream(blockId, byteStream)).writeAll(values).close()

Contributor: the wrapStream and wrapForEncryption methods can be removed from this class

Contributor (author): They're still used in a bunch of places.

ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
}

/** Serializes into a chunked byte buffer. */
def dataSerialize[T: ClassTag](
blockId: BlockId,
values: Iterator[T],
allowEncryption: Boolean = true): ChunkedByteBuffer = {
dataSerializeWithExplicitClassTag(blockId, values, implicitly[ClassTag[T]],
allowEncryption = allowEncryption)
values: Iterator[T]): ChunkedByteBuffer = {
dataSerializeWithExplicitClassTag(blockId, values, implicitly[ClassTag[T]])
}

/** Serializes into a chunked byte buffer. */
def dataSerializeWithExplicitClassTag(
blockId: BlockId,
values: Iterator[_],
classTag: ClassTag[_],
allowEncryption: Boolean = true): ChunkedByteBuffer = {
classTag: ClassTag[_]): ChunkedByteBuffer = {
val bbos = new ChunkedByteBufferOutputStream(1024 * 1024 * 4, ByteBuffer.allocate)
val byteStream = new BufferedOutputStream(bbos)
val autoPick = !blockId.isInstanceOf[StreamBlockId]
val ser = getSerializer(classTag, autoPick).newInstance()
val encrypted = if (allowEncryption) wrapForEncryption(byteStream) else byteStream
ser.serializeStream(wrapForCompression(blockId, encrypted)).writeAll(values).close()
ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
bbos.toChunkedByteBuffer
}

@@ -200,15 +196,13 @@
*/
def dataDeserializeStream[T](
blockId: BlockId,
inputStream: InputStream,
maybeEncrypted: Boolean = true)
inputStream: InputStream)
(classTag: ClassTag[T]): Iterator[T] = {
val stream = new BufferedInputStream(inputStream)
val autoPick = !blockId.isInstanceOf[StreamBlockId]
val decrypted = if (maybeEncrypted) wrapForEncryption(inputStream) else inputStream
getSerializer(classTag, autoPick)
.newInstance()
.deserializeStream(wrapForCompression(blockId, decrypted))
.deserializeStream(wrapForCompression(blockId, stream))
.asIterator.asInstanceOf[Iterator[T]]
}
}
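Finally, a hedged sketch of how serialization behaves without the removed encryption flags: dataSerialize produces plain (optionally compressed) bytes, dataDeserializeStream reads them back, and encryption is applied elsewhere by the storage layer. It assumes a running SparkEnv and code compiled in a package under org.apache.spark, since SerializerManager is private[spark]; the block id and values are placeholders:

```scala
import scala.reflect.classTag

import org.apache.spark.SparkEnv
import org.apache.spark.storage.TestBlockId

// Hypothetical block id and values for illustration.
val serializerManager = SparkEnv.get.serializerManager
val blockId = TestBlockId("example")

// Serialization no longer encrypts here: the result is serialized (and, if the
// block type calls for it, compressed) bytes only.
val bytes = serializerManager.dataSerialize(blockId, Iterator("a", "b", "c"))

// Deserialization reads the same values back from the unencrypted stream.
val restored = serializerManager
  .dataDeserializeStream(blockId, bytes.toInputStream())(classTag[String])
  .toList
```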