diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java
index f3eaf22c0166e..51d7fda0cb260 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java
@@ -18,9 +18,11 @@
 package org.apache.spark.network.util;
 
 import java.io.Closeable;
+import java.io.EOFException;
 import java.io.File;
 import java.io.IOException;
 import java.nio.ByteBuffer;
+import java.nio.channels.ReadableByteChannel;
 import java.nio.charset.StandardCharsets;
 import java.util.concurrent.TimeUnit;
 import java.util.regex.Matcher;
@@ -344,4 +346,17 @@ public static byte[] bufferToArray(ByteBuffer buffer) {
     }
   }
 
+  /**
+   * Fills a buffer with data read from the channel.
+   */
+  public static void readFully(ReadableByteChannel channel, ByteBuffer dst) throws IOException {
+    int expected = dst.remaining();
+    while (dst.hasRemaining()) {
+      if (channel.read(dst) < 0) {
+        throw new EOFException(String.format("Not enough bytes in channel (expected %d).",
+          expected));
+      }
+    }
+  }
+
 }
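The new helper loops because `ReadableByteChannel.read` may return fewer bytes than requested. A minimal usage sketch (the file path and the 16-byte IV size are illustrative assumptions, not part of the change):

```scala
import java.io.FileInputStream
import java.nio.ByteBuffer
import org.apache.spark.network.util.JavaUtils

// Read exactly 16 bytes from the head of a file; unlike a single read(),
// this either fills the buffer or throws EOFException.
val channel = new FileInputStream("/tmp/example-block").getChannel()  // hypothetical path
try {
  val iv = ByteBuffer.allocate(16)
  JavaUtils.readFully(channel, iv)
  iv.flip()  // prepare the filled buffer for reading
} finally {
  channel.close()
}
```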
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index 22d01c47e645d..039df75ce74fd 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -29,7 +29,7 @@ import org.apache.spark._
 import org.apache.spark.internal.Logging
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.serializer.Serializer
-import org.apache.spark.storage.{BlockId, BroadcastBlockId, StorageLevel}
+import org.apache.spark.storage._
 import org.apache.spark.util.{ByteBufferInputStream, Utils}
 import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}
 
@@ -141,10 +141,10 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
   }
 
   /** Fetch torrent blocks from the driver and/or other executors. */
-  private def readBlocks(): Array[ChunkedByteBuffer] = {
+  private def readBlocks(): Array[BlockData] = {
     // Fetch chunks of data. Note that all these chunks are stored in the BlockManager and reported
     // to the driver, so other executors can pull these chunks from this executor as well.
-    val blocks = new Array[ChunkedByteBuffer](numBlocks)
+    val blocks = new Array[BlockData](numBlocks)
     val bm = SparkEnv.get.blockManager
 
     for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
@@ -173,7 +173,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
               throw new SparkException(
                 s"Failed to store $pieceId of $broadcastId in local BlockManager")
             }
-            blocks(pid) = b
+            blocks(pid) = new ByteBufferBlockData(b, true)
           case None =>
             throw new SparkException(s"Failed to get $pieceId of $broadcastId")
         }
       }
@@ -219,18 +219,22 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
       case None =>
         logInfo("Started reading broadcast variable " + id)
         val startTimeMs = System.currentTimeMillis()
-        val blocks = readBlocks().flatMap(_.getChunks())
+        val blocks = readBlocks()
        logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs))
-        val obj = TorrentBroadcast.unBlockifyObject[T](
-          blocks, SparkEnv.get.serializer, compressionCodec)
-        // Store the merged copy in BlockManager so other tasks on this executor don't
-        // need to re-fetch it.
-        val storageLevel = StorageLevel.MEMORY_AND_DISK
-        if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) {
-          throw new SparkException(s"Failed to store $broadcastId in BlockManager")
+        try {
+          val obj = TorrentBroadcast.unBlockifyObject[T](
+            blocks.map(_.toInputStream()), SparkEnv.get.serializer, compressionCodec)
+          // Store the merged copy in BlockManager so other tasks on this executor don't
+          // need to re-fetch it.
+          val storageLevel = StorageLevel.MEMORY_AND_DISK
+          if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) {
+            throw new SparkException(s"Failed to store $broadcastId in BlockManager")
+          }
+          obj
+        } finally {
+          blocks.foreach(_.dispose())
         }
-        obj
     }
   }
 }
@@ -277,12 +281,11 @@ private object TorrentBroadcast extends Logging {
   }
 
   def unBlockifyObject[T: ClassTag](
-      blocks: Array[ByteBuffer],
+      blocks: Array[InputStream],
       serializer: Serializer,
       compressionCodec: Option[CompressionCodec]): T = {
     require(blocks.nonEmpty, "Cannot unblockify an empty array of blocks")
-    val is = new SequenceInputStream(
-      blocks.iterator.map(new ByteBufferInputStream(_)).asJavaEnumeration)
+    val is = new SequenceInputStream(blocks.iterator.asJavaEnumeration)
     val in: InputStream = compressionCodec.map(c => c.compressedInputStream(is)).getOrElse(is)
     val ser = serializer.newInstance()
     val serIn = ser.deserializeStream(in)
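The broadcast read path now deserializes straight from per-block streams instead of first materializing every chunk as a `ByteBuffer`. A sketch of the underlying pattern (names are illustrative):

```scala
import java.io.{InputStream, SequenceInputStream}
import scala.collection.JavaConverters._

// Concatenate per-block streams lazily; bytes are pulled from each block
// only as the deserializer consumes them.
def concatBlocks(blocks: Array[InputStream]): InputStream =
  new SequenceInputStream(blocks.iterator.asJavaEnumeration)
```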
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index 1e50eb6635651..77005aa9040b5 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -485,12 +485,17 @@
 
     // In client mode, launch the application main class directly
     // In addition, add the main application jar and any added jars (if any) to the classpath
-    if (deployMode == CLIENT) {
+    // Also add the main application jar and any added jars to the classpath in case the YARN
+    // client requires these jars.
+    if (deployMode == CLIENT || isYarnCluster) {
       childMainClass = args.mainClass
       if (isUserJar(args.primaryResource)) {
         childClasspath += args.primaryResource
       }
       if (args.jars != null) { childClasspath ++= args.jars.split(",") }
+    }
+
+    if (deployMode == CLIENT) {
       if (args.childArgs != null) { childArgs ++= args.childArgs }
     }
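For illustration, a yarn-cluster submission that would now get the main jar plus the added jars on the launcher's classpath (jar names borrowed from the updated SparkSubmitSuite test below; the `--class` value is a hypothetical placeholder):

```
spark-submit \
  --master yarn \
  --deploy-mode cluster \
  --jars one.jar,two.jar,three.jar \
  --class MyMainClass \
  thejar.jar
```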
+ */ + def createReadableChannel( + channel: ReadableByteChannel, + sparkConf: SparkConf, + key: Array[Byte]): ReadableByteChannel = { + val iv = new Array[Byte](IV_LENGTH_IN_BYTES) + val buf = ByteBuffer.wrap(iv) + JavaUtils.readFully(channel, buf) + + val params = new CryptoParams(key, sparkConf) + new CryptoInputStream(params.transformation, params.conf, channel, params.keySpec, + new IvParameterSpec(iv)) } def toCryptoConf(conf: SparkConf): Properties = { @@ -102,4 +135,34 @@ private[spark] object CryptoStreamUtils extends Logging { } iv } + + /** + * This class is a workaround for CRYPTO-125, that forces all bytes to be written to the + * underlying channel. Since the callers of this API are using blocking I/O, there are no + * concerns with regards to CPU usage here. + */ + private class CryptoHelperChannel(sink: WritableByteChannel) extends WritableByteChannel { + + override def write(src: ByteBuffer): Int = { + val count = src.remaining() + while (src.hasRemaining()) { + sink.write(src) + } + count + } + + override def isOpen(): Boolean = sink.isOpen() + + override def close(): Unit = sink.close() + + } + + private class CryptoParams(key: Array[Byte], sparkConf: SparkConf) { + + val keySpec = new SecretKeySpec(key, "AES") + val transformation = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION) + val conf = toCryptoConf(sparkConf) + + } + } diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala index 96b288b9cfb81..bb7ed8709ba8a 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala @@ -148,14 +148,14 @@ private[spark] class SerializerManager( /** * Wrap an output stream for compression if block compression is enabled for its block type */ - private[this] def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = { + def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = { if (shouldCompress(blockId)) compressionCodec.compressedOutputStream(s) else s } /** * Wrap an input stream for compression if block compression is enabled for its block type */ - private[this] def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = { + def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = { if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s } @@ -167,30 +167,26 @@ private[spark] class SerializerManager( val byteStream = new BufferedOutputStream(outputStream) val autoPick = !blockId.isInstanceOf[StreamBlockId] val ser = getSerializer(implicitly[ClassTag[T]], autoPick).newInstance() - ser.serializeStream(wrapStream(blockId, byteStream)).writeAll(values).close() + ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close() } /** Serializes into a chunked byte buffer. */ def dataSerialize[T: ClassTag]( blockId: BlockId, - values: Iterator[T], - allowEncryption: Boolean = true): ChunkedByteBuffer = { - dataSerializeWithExplicitClassTag(blockId, values, implicitly[ClassTag[T]], - allowEncryption = allowEncryption) + values: Iterator[T]): ChunkedByteBuffer = { + dataSerializeWithExplicitClassTag(blockId, values, implicitly[ClassTag[T]]) } /** Serializes into a chunked byte buffer. 
diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
index 96b288b9cfb81..bb7ed8709ba8a 100644
--- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
@@ -148,14 +148,14 @@ private[spark] class SerializerManager(
   /**
    * Wrap an output stream for compression if block compression is enabled for its block type
    */
-  private[this] def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = {
+  def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = {
     if (shouldCompress(blockId)) compressionCodec.compressedOutputStream(s) else s
   }
 
   /**
    * Wrap an input stream for compression if block compression is enabled for its block type
    */
-  private[this] def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = {
+  def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = {
     if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s
   }
 
@@ -167,30 +167,26 @@ private[spark] class SerializerManager(
     val byteStream = new BufferedOutputStream(outputStream)
     val autoPick = !blockId.isInstanceOf[StreamBlockId]
     val ser = getSerializer(implicitly[ClassTag[T]], autoPick).newInstance()
-    ser.serializeStream(wrapStream(blockId, byteStream)).writeAll(values).close()
+    ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
   }
 
   /** Serializes into a chunked byte buffer. */
   def dataSerialize[T: ClassTag](
       blockId: BlockId,
-      values: Iterator[T],
-      allowEncryption: Boolean = true): ChunkedByteBuffer = {
-    dataSerializeWithExplicitClassTag(blockId, values, implicitly[ClassTag[T]],
-      allowEncryption = allowEncryption)
+      values: Iterator[T]): ChunkedByteBuffer = {
+    dataSerializeWithExplicitClassTag(blockId, values, implicitly[ClassTag[T]])
   }
 
   /** Serializes into a chunked byte buffer. */
   def dataSerializeWithExplicitClassTag(
       blockId: BlockId,
       values: Iterator[_],
-      classTag: ClassTag[_],
-      allowEncryption: Boolean = true): ChunkedByteBuffer = {
+      classTag: ClassTag[_]): ChunkedByteBuffer = {
     val bbos = new ChunkedByteBufferOutputStream(1024 * 1024 * 4, ByteBuffer.allocate)
     val byteStream = new BufferedOutputStream(bbos)
     val autoPick = !blockId.isInstanceOf[StreamBlockId]
     val ser = getSerializer(classTag, autoPick).newInstance()
-    val encrypted = if (allowEncryption) wrapForEncryption(byteStream) else byteStream
-    ser.serializeStream(wrapForCompression(blockId, encrypted)).writeAll(values).close()
+    ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
     bbos.toChunkedByteBuffer
   }
 
@@ -200,15 +196,13 @@ private[spark] class SerializerManager(
    */
   def dataDeserializeStream[T](
       blockId: BlockId,
-      inputStream: InputStream,
-      maybeEncrypted: Boolean = true)
+      inputStream: InputStream)
       (classTag: ClassTag[T]): Iterator[T] = {
     val stream = new BufferedInputStream(inputStream)
     val autoPick = !blockId.isInstanceOf[StreamBlockId]
-    val decrypted = if (maybeEncrypted) wrapForEncryption(inputStream) else inputStream
     getSerializer(classTag, autoPick)
       .newInstance()
-      .deserializeStream(wrapForCompression(blockId, decrypted))
+      .deserializeStream(wrapForCompression(blockId, inputStream))
       .asIterator.asInstanceOf[Iterator[T]]
   }
 }
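With encryption removed from this layer, serialization composes only with compression; a round-trip sketch (assumes a `SerializerManager` instance, e.g. from `SparkEnv`; the block id is arbitrary):

```scala
import scala.reflect.ClassTag
import org.apache.spark.serializer.SerializerManager
import org.apache.spark.storage.TestBlockId

def roundTrip[T: ClassTag](sm: SerializerManager, values: Iterator[T]): Iterator[T] = {
  val blockId = TestBlockId("round-trip")  // illustrative block id
  // Serialize (compressing if enabled for this block type)...
  val bytes = sm.dataSerialize(blockId, values)
  // ...and read it back; decryption is no longer attempted here.
  sm.dataDeserializeStream(blockId, bytes.toInputStream(dispose = true))(implicitly[ClassTag[T]])
}
```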
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 991346a40af4e..fcda9fa65303a 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -19,6 +19,7 @@ package org.apache.spark.storage
 
 import java.io._
 import java.nio.ByteBuffer
+import java.nio.channels.Channels
 
 import scala.collection.mutable
 import scala.collection.mutable.HashMap
@@ -35,7 +36,7 @@ import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics}
 import org.apache.spark.internal.Logging
 import org.apache.spark.memory.{MemoryManager, MemoryMode}
 import org.apache.spark.network._
-import org.apache.spark.network.buffer.{ManagedBuffer, NettyManagedBuffer}
+import org.apache.spark.network.buffer.ManagedBuffer
 import org.apache.spark.network.netty.SparkTransportConf
 import org.apache.spark.network.shuffle.ExternalShuffleClient
 import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo
@@ -55,6 +56,55 @@ private[spark] class BlockResult(
     val readMethod: DataReadMethod.Value,
     val bytes: Long)
 
+/**
+ * Abstracts away how blocks are stored and provides different ways to read the underlying block
+ * data. Callers should call [[dispose()]] when they're done with the block.
+ */
+private[spark] trait BlockData {
+
+  def toInputStream(): InputStream
+
+  /**
+   * Returns a Netty-friendly wrapper for the block's data.
+   *
+   * @see [[ManagedBuffer#convertToNetty()]]
+   */
+  def toNetty(): Object
+
+  def toChunkedByteBuffer(allocator: Int => ByteBuffer): ChunkedByteBuffer
+
+  def toByteBuffer(): ByteBuffer
+
+  def size: Long
+
+  def dispose(): Unit
+
+}
+
+private[spark] class ByteBufferBlockData(
+    val buffer: ChunkedByteBuffer,
+    val shouldDispose: Boolean) extends BlockData {
+
+  override def toInputStream(): InputStream = buffer.toInputStream(dispose = false)
+
+  override def toNetty(): Object = buffer.toNetty
+
+  override def toChunkedByteBuffer(allocator: Int => ByteBuffer): ChunkedByteBuffer = {
+    buffer.copy(allocator)
+  }
+
+  override def toByteBuffer(): ByteBuffer = buffer.toByteBuffer
+
+  override def size: Long = buffer.size
+
+  override def dispose(): Unit = {
+    if (shouldDispose) {
+      buffer.dispose()
+    }
+  }
+
+}
+
 /**
  * Manager running on every node (driver and executors) which provides interfaces for putting and
  * retrieving blocks both locally and remotely into various stores (memory, disk, and off-heap).
@@ -94,7 +144,7 @@ private[spark] class BlockManager(
 
   // Actual storage of where blocks are kept
   private[spark] val memoryStore = new MemoryStore(conf, blockInfoManager, serializerManager, memoryManager, this)
-  private[spark] val diskStore = new DiskStore(conf, diskBlockManager)
+  private[spark] val diskStore = new DiskStore(conf, diskBlockManager, securityManager)
   memoryManager.setMemoryStore(memoryStore)
 
   // Note: depending on the memory manager, `maxMemory` may actually vary over time.
@@ -304,7 +354,8 @@
       shuffleManager.shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId])
     } else {
       getLocalBytes(blockId) match {
-        case Some(buffer) => new BlockManagerManagedBuffer(blockInfoManager, blockId, buffer)
+        case Some(blockData) =>
+          new BlockManagerManagedBuffer(blockInfoManager, blockId, blockData, true)
         case None =>
           // If this block manager receives a request for a block that it doesn't have then it's
           // likely that the master has outdated block statuses for this block. Therefore, we send
@@ -463,21 +514,22 @@
       val ci = CompletionIterator[Any, Iterator[Any]](iter, releaseLock(blockId))
       Some(new BlockResult(ci, DataReadMethod.Memory, info.size))
     } else if (level.useDisk && diskStore.contains(blockId)) {
+      val diskData = diskStore.getBytes(blockId)
       val iterToReturn: Iterator[Any] = {
-        val diskBytes = diskStore.getBytes(blockId)
         if (level.deserialized) {
           val diskValues = serializerManager.dataDeserializeStream(
             blockId,
-            diskBytes.toInputStream(dispose = true))(info.classTag)
+            diskData.toInputStream())(info.classTag)
           maybeCacheDiskValuesInMemory(info, blockId, level, diskValues)
         } else {
-          val stream = maybeCacheDiskBytesInMemory(info, blockId, level, diskBytes)
-            .map {_.toInputStream(dispose = false)}
-            .getOrElse { diskBytes.toInputStream(dispose = true) }
+          val stream = maybeCacheDiskBytesInMemory(info, blockId, level, diskData)
+            .map { _.toInputStream(dispose = false) }
+            .getOrElse { diskData.toInputStream() }
           serializerManager.dataDeserializeStream(blockId, stream)(info.classTag)
         }
       }
-      val ci = CompletionIterator[Any, Iterator[Any]](iterToReturn, releaseLock(blockId))
+      val ci = CompletionIterator[Any, Iterator[Any]](iterToReturn,
+        releaseLockAndDispose(blockId, diskData))
       Some(new BlockResult(ci, DataReadMethod.Disk, info.size))
     } else {
       handleLocalReadFailure(blockId)
@@ -488,7 +540,7 @@
   /**
    * Get block from the local block manager as serialized bytes.
    */
-  def getLocalBytes(blockId: BlockId): Option[ChunkedByteBuffer] = {
+  def getLocalBytes(blockId: BlockId): Option[BlockData] = {
     logDebug(s"Getting local block $blockId as bytes")
     // As an optimization for map output fetches, if the block is for a shuffle, return it
     // without acquiring a lock; the disk store never deletes (recent) items so this should work
@@ -496,9 +548,9 @@
       val shuffleBlockResolver = shuffleManager.shuffleBlockResolver
       // TODO: This should gracefully handle case where local block is not available. Currently
       // downstream code will throw an exception.
-      Option(
-        new ChunkedByteBuffer(
-          shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]).nioByteBuffer()))
+      val buf = new ChunkedByteBuffer(
+        shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]).nioByteBuffer())
+      Some(new ByteBufferBlockData(buf, true))
     } else {
       blockInfoManager.lockForReading(blockId).map { info => doGetLocalBytes(blockId, info) }
     }
@@ -510,7 +562,7 @@
    * Must be called while holding a read lock on the block.
    * Releases the read lock upon exception; keeps the read lock upon successful return.
    */
-  private def doGetLocalBytes(blockId: BlockId, info: BlockInfo): ChunkedByteBuffer = {
+  private def doGetLocalBytes(blockId: BlockId, info: BlockInfo): BlockData = {
     val level = info.level
     logDebug(s"Level for block $blockId is $level")
     // In order, try to read the serialized bytes from memory, then from disk, then fall back to
@@ -525,17 +577,19 @@
         diskStore.getBytes(blockId)
       } else if (level.useMemory && memoryStore.contains(blockId)) {
         // The block was not found on disk, so serialize an in-memory copy:
-        serializerManager.dataSerializeWithExplicitClassTag(
-          blockId, memoryStore.getValues(blockId).get, info.classTag)
+        new ByteBufferBlockData(serializerManager.dataSerializeWithExplicitClassTag(
+          blockId, memoryStore.getValues(blockId).get, info.classTag), true)
       } else {
         handleLocalReadFailure(blockId)
       }
     } else {  // storage level is serialized
       if (level.useMemory && memoryStore.contains(blockId)) {
-        memoryStore.getBytes(blockId).get
+        new ByteBufferBlockData(memoryStore.getBytes(blockId).get, false)
       } else if (level.useDisk && diskStore.contains(blockId)) {
-        val diskBytes = diskStore.getBytes(blockId)
-        maybeCacheDiskBytesInMemory(info, blockId, level, diskBytes).getOrElse(diskBytes)
+        val diskData = diskStore.getBytes(blockId)
+        maybeCacheDiskBytesInMemory(info, blockId, level, diskData)
+          .map(new ByteBufferBlockData(_, false))
+          .getOrElse(diskData)
       } else {
         handleLocalReadFailure(blockId)
       }
@@ -761,43 +815,15 @@
    * '''Important!''' Callers must not mutate or release the data buffer underlying `bytes`. Doing
    * so may corrupt or change the data stored by the `BlockManager`.
    *
-   * @param encrypt If true, asks the block manager to encrypt the data block before storing,
-   *                when I/O encryption is enabled. This is required for blocks that have been
-   *                read from unencrypted sources, since all the BlockManager read APIs
-   *                automatically do decryption.
    * @return true if the block was stored or false if an error occurred.
    */
   def putBytes[T: ClassTag](
       blockId: BlockId,
       bytes: ChunkedByteBuffer,
       level: StorageLevel,
-      tellMaster: Boolean = true,
-      encrypt: Boolean = false): Boolean = {
+      tellMaster: Boolean = true): Boolean = {
     require(bytes != null, "Bytes is null")
-
-    val bytesToStore =
-      if (encrypt && securityManager.ioEncryptionKey.isDefined) {
-        try {
-          val data = bytes.toByteBuffer
-          val in = new ByteBufferInputStream(data)
-          val byteBufOut = new ByteBufferOutputStream(data.remaining())
-          val out = CryptoStreamUtils.createCryptoOutputStream(byteBufOut, conf,
-            securityManager.ioEncryptionKey.get)
-          try {
-            ByteStreams.copy(in, out)
-          } finally {
-            in.close()
-            out.close()
-          }
-          new ChunkedByteBuffer(byteBufOut.toByteBuffer)
-        } finally {
-          bytes.dispose()
-        }
-      } else {
-        bytes
-      }
-
-    doPutBytes(blockId, bytesToStore, level, implicitly[ClassTag[T]], tellMaster)
+    doPutBytes(blockId, bytes, level, implicitly[ClassTag[T]], tellMaster)
   }
 
   /**
@@ -828,8 +854,9 @@
       val replicationFuture = if (level.replication > 1) {
         Future {
           // This is a blocking action and should run in futureExecutionContext which is a cached
-          // thread pool
-          replicate(blockId, bytes, level, classTag)
+          // thread pool. The ByteBufferBlockData wrapper is not disposed of to avoid releasing
+          // buffers that are owned by the caller.
+          replicate(blockId, new ByteBufferBlockData(bytes, false), level, classTag)
         }(futureExecutionContext)
       } else {
         null
@@ -1008,8 +1035,9 @@
           // Not enough space to unroll this block; drop to disk if applicable
           if (level.useDisk) {
             logWarning(s"Persisting block $blockId to disk instead.")
-            diskStore.put(blockId) { fileOutputStream =>
-              serializerManager.dataSerializeStream(blockId, fileOutputStream, iter)(classTag)
+            diskStore.put(blockId) { channel =>
+              val out = Channels.newOutputStream(channel)
+              serializerManager.dataSerializeStream(blockId, out, iter)(classTag)
             }
             size = diskStore.getSize(blockId)
           } else {
@@ -1024,8 +1052,9 @@
           // Not enough space to unroll this block; drop to disk if applicable
           if (level.useDisk) {
             logWarning(s"Persisting block $blockId to disk instead.")
-            diskStore.put(blockId) { fileOutputStream =>
-              partiallySerializedValues.finishWritingToStream(fileOutputStream)
+            diskStore.put(blockId) { channel =>
+              val out = Channels.newOutputStream(channel)
+              partiallySerializedValues.finishWritingToStream(out)
             }
             size = diskStore.getSize(blockId)
           } else {
@@ -1035,8 +1064,9 @@
       }
 
     } else if (level.useDisk) {
-      diskStore.put(blockId) { fileOutputStream =>
-        serializerManager.dataSerializeStream(blockId, fileOutputStream, iterator())(classTag)
+      diskStore.put(blockId) { channel =>
+        val out = Channels.newOutputStream(channel)
+        serializerManager.dataSerializeStream(blockId, out, iterator())(classTag)
       }
       size = diskStore.getSize(blockId)
     }
@@ -1065,7 +1095,7 @@
           try {
             replicate(blockId, bytesToReplicate, level, remoteClassTag)
           } finally {
-            bytesToReplicate.unmap()
+            bytesToReplicate.dispose()
          }
          logDebug("Put block %s remotely took %s"
            .format(blockId, Utils.getUsedTimeMs(remoteStartTime)))
@@ -1089,29 +1119,29 @@
       blockInfo: BlockInfo,
       blockId: BlockId,
       level: StorageLevel,
-      diskBytes: ChunkedByteBuffer): Option[ChunkedByteBuffer] = {
+      diskData: BlockData): Option[ChunkedByteBuffer] = {
     require(!level.deserialized)
     if (level.useMemory) {
      // Synchronize on blockInfo to guard against a race
condition where two readers both try to
      // put values read from disk into the MemoryStore.
      blockInfo.synchronized {
        if (memoryStore.contains(blockId)) {
-          diskBytes.dispose()
+          diskData.dispose()
          Some(memoryStore.getBytes(blockId).get)
        } else {
          val allocator = level.memoryMode match {
            case MemoryMode.ON_HEAP => ByteBuffer.allocate _
            case MemoryMode.OFF_HEAP => Platform.allocateDirectBuffer _
          }
-          val putSucceeded = memoryStore.putBytes(blockId, diskBytes.size, level.memoryMode, () => {
+          val putSucceeded = memoryStore.putBytes(blockId, diskData.size, level.memoryMode, () => {
            // https://issues.apache.org/jira/browse/SPARK-6076
            // If the file size is bigger than the free memory, OOM will happen. So if we
            // cannot put it into MemoryStore, copyForMemory should not be created. That's why
            // this action is put into a `() => ChunkedByteBuffer` and created lazily.
-            diskBytes.copy(allocator)
+            diskData.toChunkedByteBuffer(allocator)
          })
          if (putSucceeded) {
-            diskBytes.dispose()
+            diskData.dispose()
            Some(memoryStore.getBytes(blockId).get)
          } else {
            None
          }
@@ -1203,7 +1233,7 @@
           replicate(blockId, data, storageLevel, info.classTag, existingReplicas)
         } finally {
           logDebug(s"Releasing lock for $blockId")
-          releaseLock(blockId)
+          releaseLockAndDispose(blockId, data)
         }
       }
     }
@@ -1214,7 +1244,7 @@
    */
   private def replicate(
       blockId: BlockId,
-      data: ChunkedByteBuffer,
+      data: BlockData,
       level: StorageLevel,
       classTag: ClassTag[_],
       existingReplicas: Set[BlockManagerId] = Set.empty): Unit = {
@@ -1256,7 +1286,7 @@
             peer.port,
             peer.executorId,
             blockId,
-            new NettyManagedBuffer(data.toNetty),
+            new BlockManagerManagedBuffer(blockInfoManager, blockId, data, false),
             tLevel,
             classTag)
           logTrace(s"Replicated $blockId of ${data.size} bytes to $peer" +
@@ -1339,10 +1369,11 @@
         logInfo(s"Writing block $blockId to disk")
         data() match {
           case Left(elements) =>
-            diskStore.put(blockId) { fileOutputStream =>
+            diskStore.put(blockId) { channel =>
+              val out = Channels.newOutputStream(channel)
               serializerManager.dataSerializeStream(
                 blockId,
-                fileOutputStream,
+                out,
                 elements.toIterator)(info.classTag.asInstanceOf[ClassTag[T]])
             }
           case Right(bytes) =>
@@ -1434,6 +1465,11 @@
     }
   }
 
+  def releaseLockAndDispose(blockId: BlockId, data: BlockData): Unit = {
+    blockInfoManager.unlock(blockId)
+    data.dispose()
+  }
+
   def stop(): Unit = {
     blockTransferService.close()
     if (shuffleClient ne blockTransferService) {
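A sketch of the calling convention the new trait expects (hypothetical consumer; `BlockData` is the trait added above):

```scala
import org.apache.spark.storage.BlockData

// Read a block once through the stream view, then always dispose so that
// disk-backed or encrypted implementations can release their resources.
def countBytes(data: BlockData): Long = {
  try {
    val in = data.toInputStream()
    try {
      var n = 0L
      while (in.read() != -1) { n += 1 }
      n
    } finally in.close()
  } finally data.dispose()
}
```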
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala
index f66f942798550..1ea0d378cbe87 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala
@@ -17,31 +17,52 @@
 
 package org.apache.spark.storage
 
-import org.apache.spark.network.buffer.{ManagedBuffer, NettyManagedBuffer}
+import java.io.InputStream
+import java.nio.ByteBuffer
+import java.util.concurrent.atomic.AtomicInteger
+
+import org.apache.spark.network.buffer.ManagedBuffer
 import org.apache.spark.util.io.ChunkedByteBuffer
 
 /**
- * This [[ManagedBuffer]] wraps a [[ChunkedByteBuffer]] retrieved from the [[BlockManager]]
+ * This [[ManagedBuffer]] wraps a [[BlockData]] instance retrieved from the [[BlockManager]]
  * so that the corresponding block's read lock can be released once this buffer's references
  * are released.
  *
+ * If `dispose` is set to true, the [[BlockData]] will be disposed when the buffer's reference
+ * count drops to zero.
+ *
  * This is effectively a wrapper / bridge to connect the BlockManager's notion of read locks
 * to the network layer's notion of retain / release counts.
 */
private[storage] class BlockManagerManagedBuffer(
    blockInfoManager: BlockInfoManager,
    blockId: BlockId,
-    chunkedBuffer: ChunkedByteBuffer) extends NettyManagedBuffer(chunkedBuffer.toNetty) {
+    data: BlockData,
+    dispose: Boolean) extends ManagedBuffer {
+
+  private val refCount = new AtomicInteger(1)
+
+  override def size(): Long = data.size
+
+  override def nioByteBuffer(): ByteBuffer = data.toByteBuffer()
+
+  override def createInputStream(): InputStream = data.toInputStream()
+
+  override def convertToNetty(): Object = data.toNetty()
 
   override def retain(): ManagedBuffer = {
-    super.retain()
+    refCount.incrementAndGet()
     val locked = blockInfoManager.lockForReading(blockId, blocking = false)
     assert(locked.isDefined)
     this
- }
+  }
 
   override def release(): ManagedBuffer = {
     blockInfoManager.unlock(blockId)
-    super.release()
+    if (refCount.decrementAndGet() == 0 && dispose) {
+      data.dispose()
+    }
+    this
   }
 }
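A sketch of the reference-counting contract (assumes `blockInfoManager`, `blockId` and `blockData` are in scope; counts shown as comments):

```scala
val buf = new BlockManagerManagedBuffer(blockInfoManager, blockId, blockData, true)
buf.retain()   // refCount 1 -> 2, takes an extra read lock
buf.release()  // refCount 2 -> 1, drops one read lock
buf.release()  // refCount 1 -> 0, drops the lock and disposes blockData
```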
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
index ca23e2391ed02..c6656341fcd15 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
@@ -17,48 +17,67 @@
 
 package org.apache.spark.storage
 
-import java.io.{FileOutputStream, IOException, RandomAccessFile}
+import java.io._
 import java.nio.ByteBuffer
+import java.nio.channels.{Channels, ReadableByteChannel, WritableByteChannel}
 import java.nio.channels.FileChannel.MapMode
+import java.nio.charset.StandardCharsets.UTF_8
+import java.util.concurrent.ConcurrentHashMap
 
-import com.google.common.io.Closeables
+import scala.collection.mutable.ListBuffer
 
-import org.apache.spark.SparkConf
+import com.google.common.io.{ByteStreams, Closeables, Files}
+import io.netty.channel.FileRegion
+import io.netty.util.AbstractReferenceCounted
+
+import org.apache.spark.{SecurityManager, SparkConf}
 import org.apache.spark.internal.Logging
-import org.apache.spark.util.Utils
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.util.JavaUtils
+import org.apache.spark.security.CryptoStreamUtils
+import org.apache.spark.util.{ByteBufferInputStream, Utils}
 import org.apache.spark.util.io.ChunkedByteBuffer
 
 /**
  * Stores BlockManager blocks on disk.
  */
-private[spark] class DiskStore(conf: SparkConf, diskManager: DiskBlockManager) extends Logging {
+private[spark] class DiskStore(
+    conf: SparkConf,
+    diskManager: DiskBlockManager,
+    securityManager: SecurityManager) extends Logging {
 
   private val minMemoryMapBytes = conf.getSizeAsBytes("spark.storage.memoryMapThreshold", "2m")
+  private val blockSizes = new ConcurrentHashMap[String, Long]()
 
-  def getSize(blockId: BlockId): Long = {
-    diskManager.getFile(blockId.name).length
-  }
+  def getSize(blockId: BlockId): Long = blockSizes.get(blockId.name)
 
   /**
    * Invokes the provided callback function to write the specific block.
    *
    * @throws IllegalStateException if the block already exists in the disk store.
   */
-  def put(blockId: BlockId)(writeFunc: FileOutputStream => Unit): Unit = {
+  def put(blockId: BlockId)(writeFunc: WritableByteChannel => Unit): Unit = {
     if (contains(blockId)) {
       throw new IllegalStateException(s"Block $blockId is already present in the disk store")
     }
     logDebug(s"Attempting to put block $blockId")
     val startTime = System.currentTimeMillis
     val file = diskManager.getFile(blockId)
-    val fileOutputStream = new FileOutputStream(file)
+    val out = new CountingWritableChannel(openForWrite(file))
     var threwException: Boolean = true
     try {
-      writeFunc(fileOutputStream)
+      writeFunc(out)
+      blockSizes.put(blockId.name, out.getCount)
       threwException = false
     } finally {
       try {
-        Closeables.close(fileOutputStream, threwException)
+        out.close()
+      } catch {
+        case ioe: IOException =>
+          if (!threwException) {
+            threwException = true
+            throw ioe
+          }
       } finally {
         if (threwException) {
          remove(blockId)
@@ -73,41 +92,46 @@ private[spark] class DiskStore(conf: SparkConf, diskManager: DiskBlockManager) e
   }
 
   def putBytes(blockId: BlockId, bytes: ChunkedByteBuffer): Unit = {
-    put(blockId) { fileOutputStream =>
-      val channel = fileOutputStream.getChannel
-      Utils.tryWithSafeFinally {
-        bytes.writeFully(channel)
-      } {
-        channel.close()
-      }
+    put(blockId) { channel =>
+      bytes.writeFully(channel)
     }
   }
 
-  def getBytes(blockId: BlockId): ChunkedByteBuffer = {
+  def getBytes(blockId: BlockId): BlockData = {
     val file = diskManager.getFile(blockId.name)
-    val channel = new RandomAccessFile(file, "r").getChannel
-    Utils.tryWithSafeFinally {
-      // For small files, directly read rather than memory map
-      if (file.length < minMemoryMapBytes) {
-        val buf = ByteBuffer.allocate(file.length.toInt)
-        channel.position(0)
-        while (buf.remaining() != 0) {
-          if (channel.read(buf) == -1) {
-            throw new IOException("Reached EOF before filling buffer\n" +
-              s"offset=0\nfile=${file.getAbsolutePath}\nbuf.remaining=${buf.remaining}")
+    val blockSize = getSize(blockId)
+
+    securityManager.getIOEncryptionKey() match {
+      case Some(key) =>
+        // Encrypted blocks cannot be memory mapped; return a special object that does decryption
+        // and provides InputStream / FileRegion implementations for reading the data.
+        new EncryptedBlockData(file, blockSize, conf, key)
+
+      case _ =>
+        val channel = new FileInputStream(file).getChannel()
+        if (blockSize < minMemoryMapBytes) {
+          // For small files, directly read rather than memory map.
+          Utils.tryWithSafeFinally {
+            val buf = ByteBuffer.allocate(blockSize.toInt)
+            JavaUtils.readFully(channel, buf)
+            buf.flip()
+            new ByteBufferBlockData(new ChunkedByteBuffer(buf), true)
+          } {
+            channel.close()
+          }
+        } else {
+          Utils.tryWithSafeFinally {
+            new ByteBufferBlockData(
+              new ChunkedByteBuffer(channel.map(MapMode.READ_ONLY, 0, file.length)), true)
+          } {
+            channel.close()
          }
        }
-        buf.flip()
-        new ChunkedByteBuffer(buf)
-      } else {
-        new ChunkedByteBuffer(channel.map(MapMode.READ_ONLY, 0, file.length))
-      }
-    } {
-      channel.close()
    }
  }
 
  def remove(blockId: BlockId): Boolean = {
+    blockSizes.remove(blockId.name)
    val file = diskManager.getFile(blockId.name)
    if (file.exists()) {
      val ret = file.delete()
@@ -124,4 +148,142 @@ private[spark] class DiskStore(conf: SparkConf, diskManager: DiskBlockManager) e
     val file = diskManager.getFile(blockId.name)
     file.exists()
   }
+
+  private def openForWrite(file: File): WritableByteChannel = {
+    val out = new FileOutputStream(file).getChannel()
+    try {
+      securityManager.getIOEncryptionKey().map { key =>
+        CryptoStreamUtils.createWritableChannel(out, conf, key)
+      }.getOrElse(out)
+    } catch {
+      case e: Exception =>
+        Closeables.close(out, true)
+        file.delete()
+        throw e
+    }
+  }
+
+}
+
+private class EncryptedBlockData(
+    file: File,
+    blockSize: Long,
+    conf: SparkConf,
+    key: Array[Byte]) extends BlockData {
+
+  override def toInputStream(): InputStream = Channels.newInputStream(open())
+
+  override def toNetty(): Object = new ReadableChannelFileRegion(open(), blockSize)
+
+  override def toChunkedByteBuffer(allocator: Int => ByteBuffer): ChunkedByteBuffer = {
+    val source = open()
+    try {
+      var remaining = blockSize
+      val chunks = new ListBuffer[ByteBuffer]()
+      while (remaining > 0) {
+        val chunkSize = math.min(remaining, Int.MaxValue)
+        val chunk = allocator(chunkSize.toInt)
+        remaining -= chunkSize
+        JavaUtils.readFully(source, chunk)
+        chunk.flip()
+        chunks += chunk
+      }
+
+      new ChunkedByteBuffer(chunks.toArray)
+    } finally {
+      source.close()
+    }
+  }
+
+  override def toByteBuffer(): ByteBuffer = {
+    // This is used by the block transfer service to replicate blocks. The upload code reads
+    // all bytes into memory to send the block to the remote executor, so it's ok to do this
+    // as long as the block fits in a Java array.
+    assert(blockSize <= Int.MaxValue, "Block is too large to be wrapped in a byte buffer.")
+    val dst = ByteBuffer.allocate(blockSize.toInt)
+    val in = open()
+    try {
+      JavaUtils.readFully(in, dst)
+      dst.flip()
+      dst
+    } finally {
+      Closeables.close(in, true)
+    }
+  }
+
+  override def size: Long = blockSize
+
+  override def dispose(): Unit = { }
+
+  private def open(): ReadableByteChannel = {
+    val channel = new FileInputStream(file).getChannel()
+    try {
+      CryptoStreamUtils.createReadableChannel(channel, conf, key)
+    } catch {
+      case e: Exception =>
+        Closeables.close(channel, true)
+        throw e
+    }
+  }
+
+}
+
+private class ReadableChannelFileRegion(source: ReadableByteChannel, blockSize: Long)
+  extends AbstractReferenceCounted with FileRegion {
+
+  private var _transferred = 0L
+
+  private val buffer = ByteBuffer.allocateDirect(64 * 1024)
+  buffer.flip()
+
+  override def count(): Long = blockSize
+
+  override def position(): Long = 0
+
+  override def transfered(): Long = _transferred
+
+  override def transferTo(target: WritableByteChannel, pos: Long): Long = {
+    assert(pos == transfered(), "Invalid position.")
+
+    var written = 0L
+    var lastWrite = -1L
+    while (lastWrite != 0) {
+      if (!buffer.hasRemaining()) {
+        buffer.clear()
+        source.read(buffer)
+        buffer.flip()
+      }
+      if (buffer.hasRemaining()) {
+        lastWrite = target.write(buffer)
+        written += lastWrite
+      } else {
+        lastWrite = 0
+      }
+    }
+
+    _transferred += written
+    written
+  }
+
+  override def deallocate(): Unit = source.close()
+}
+
+private class CountingWritableChannel(sink: WritableByteChannel) extends WritableByteChannel {
+
+  private var count = 0L
+
+  def getCount: Long = count
+
+  override def write(src: ByteBuffer): Int = {
+    val written = sink.write(src)
+    if (written > 0) {
+      count += written
+    }
+    written
+  }
+
+  override def isOpen(): Boolean = sink.isOpen()
+
+  override def close(): Unit = sink.close()
+
 }
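This is how BlockManager drives the channel-based API (a condensed sketch of the call sites shown earlier; `diskStore`, `blockId`, `serializerManager`, `values` and `classTag` are assumed to be in scope):

```scala
import java.nio.channels.Channels

diskStore.put(blockId) { channel =>
  // The channel is wrapped for counting (and encryption when an IO key is set);
  // whatever is written here becomes the size reported by getSize().
  val out = Channels.newOutputStream(channel)
  serializerManager.dataSerializeStream(blockId, out, values.toIterator)(classTag)
}
val size = diskStore.getSize(blockId)
```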
diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala
index 5efdd23f79a21..241aacd74b586 100644
--- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala
+++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala
@@ -236,14 +236,6 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) {
 
 /** Helper methods for storage-related objects. */
 private[spark] object StorageUtils extends Logging {
 
-  // Ewwww... Reflection!!! See the unmap method for justification
-  private val memoryMappedBufferFileDescriptorField = {
-    val mappedBufferClass = classOf[java.nio.MappedByteBuffer]
-    val fdField = mappedBufferClass.getDeclaredField("fd")
-    fdField.setAccessible(true)
-    fdField
-  }
-
   /**
    * Attempt to clean up a ByteBuffer if it is direct or memory-mapped. This uses an *unsafe* Sun
    * API that will cause errors if one attempts to read from the disposed buffer. However, neither
@@ -251,8 +243,6 @@ private[spark] object StorageUtils extends Logging {
    * pressure on the garbage collector. Waiting for garbage collection may lead to the depletion of
    * off-heap memory or huge numbers of open files. There's unfortunately no standard API to
    * manually dispose of these kinds of buffers.
-   *
-   * See also [[unmap]]
    */
   def dispose(buffer: ByteBuffer): Unit = {
     if (buffer != null && buffer.isInstanceOf[MappedByteBuffer]) {
@@ -261,28 +251,6 @@
     }
   }
 
-  /**
-   * Attempt to unmap a ByteBuffer if it is memory-mapped. This uses an *unsafe* Sun API that will
-   * cause errors if one attempts to read from the unmapped buffer. However, the file descriptors of
-   * memory-mapped buffers do not put pressure on the garbage collector. Waiting for garbage
-   * collection may lead to huge numbers of open files. There's unfortunately no standard API to
-   * manually unmap memory-mapped buffers.
-   *
-   * See also [[dispose]]
-   */
-  def unmap(buffer: ByteBuffer): Unit = {
-    if (buffer != null && buffer.isInstanceOf[MappedByteBuffer]) {
-      // Note that direct buffers are instances of MappedByteBuffer. As things stand in Java 8, the
-      // JDK does not provide a public API to distinguish between direct buffers and memory-mapped
-      // buffers. As an alternative, we peek beneath the curtains and look for a non-null file
-      // descriptor in mappedByteBuffer
-      if (memoryMappedBufferFileDescriptorField.get(buffer) != null) {
-        logTrace(s"Unmapping $buffer")
-        cleanDirectBuffer(buffer.asInstanceOf[DirectBuffer])
-      }
-    }
-  }
-
   private def cleanDirectBuffer(buffer: DirectBuffer) = {
     val cleaner = buffer.cleaner()
     if (cleaner != null) {
diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
index fb54dd66a39a9..90e3af2d0ec74 100644
--- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
@@ -344,7 +344,7 @@ private[spark] class MemoryStore(
     val serializationStream: SerializationStream = {
       val autoPick = !blockId.isInstanceOf[StreamBlockId]
       val ser = serializerManager.getSerializer(classTag, autoPick).newInstance()
-      ser.serializeStream(serializerManager.wrapStream(blockId, redirectableStream))
+      ser.serializeStream(serializerManager.wrapForCompression(blockId, redirectableStream))
     }
 
     // Request enough memory to begin unrolling
diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
index 1667516663b35..2f905c8af0f63 100644
--- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
@@ -138,8 +138,6 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {
   /**
    * Attempt to clean up any ByteBuffer in this ChunkedByteBuffer which is direct or memory-mapped.
    * See [[StorageUtils.dispose]] for more information.
-   *
-   * See also [[unmap]]
    */
   def dispose(): Unit = {
     if (!disposed) {
@@ -148,18 +146,6 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {
     }
   }
 
-  /**
-   * Attempt to unmap any ByteBuffer in this ChunkedByteBuffer if it is memory-mapped. See
-   * [[StorageUtils.unmap]] for more information.
-   *
-   * See also [[dispose]]
-   */
-  def unmap(): Unit = {
-    if (!disposed) {
-      chunks.foreach(StorageUtils.unmap)
-      disposed = true
-    }
-  }
 }
 
 /**
diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
index 4e36adc8baf3f..84f7f1fc8eb09 100644
--- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
@@ -21,6 +21,7 @@ import org.scalatest.concurrent.Timeouts._
 import org.scalatest.Matchers
 import org.scalatest.time.{Millis, Span}
 
+import org.apache.spark.security.EncryptionFunSuite
 import org.apache.spark.storage.{RDDBlockId, StorageLevel}
 import org.apache.spark.util.io.ChunkedByteBuffer
 
@@ -28,7 +29,8 @@ class NotSerializableClass
 class NotSerializableExn(val notSer: NotSerializableClass) extends Throwable() {}
 
 
-class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContext {
+class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContext
+  with EncryptionFunSuite {
 
   val clusterUrl = "local-cluster[2,1,1024]"
 
@@ -149,8 +151,8 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex
     sc.parallelize(1 to 10).count()
   }
 
-  private def testCaching(storageLevel: StorageLevel): Unit = {
-    sc = new SparkContext(clusterUrl, "test")
+  private def testCaching(conf: SparkConf, storageLevel: StorageLevel): Unit = {
+    sc = new SparkContext(conf.setMaster(clusterUrl).setAppName("test"))
     sc.jobProgressListener.waitUntilExecutorsUp(2, 30000)
     val data = sc.parallelize(1 to 1000, 10)
     val cachedData = data.persist(storageLevel)
@@ -187,8 +189,8 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex
     "caching in memory and disk, replicated" -> StorageLevel.MEMORY_AND_DISK_2,
     "caching in memory and disk, serialized, replicated" -> StorageLevel.MEMORY_AND_DISK_SER_2
   ).foreach { case (testName, storageLevel) =>
-    test(testName) {
-      testCaching(storageLevel)
+    encryptionTest(testName) { conf =>
+      testCaching(conf, storageLevel)
     }
   }
diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
index 6646068d5080b..82760fe92f76a 100644
--- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
+++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
@@ -24,8 +24,10 @@ import org.scalatest.Assertions
 import org.apache.spark._
 import org.apache.spark.io.SnappyCompressionCodec
 import org.apache.spark.rdd.RDD
+import org.apache.spark.security.EncryptionFunSuite
 import org.apache.spark.serializer.JavaSerializer
 import org.apache.spark.storage._
+import org.apache.spark.util.io.ChunkedByteBuffer
 
 // Dummy class that creates a broadcast variable but doesn't use it
 class DummyBroadcastClass(rdd: RDD[Int]) extends Serializable {
@@ -43,7 +45,7 @@ class DummyBroadcastClass(rdd: RDD[Int]) extends Serializable {
   }
 }
 
-class BroadcastSuite extends SparkFunSuite with LocalSparkContext {
+class BroadcastSuite extends SparkFunSuite with LocalSparkContext with EncryptionFunSuite {
 
   test("Using TorrentBroadcast locally") {
     sc = new SparkContext("local", "test")
@@ -61,9 +63,8 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext {
     assert(results.collect().toSet === (1 to 10).map(x => (x, 10)).toSet)
   }
 
-  test("Accessing TorrentBroadcast variables in a local cluster") {
+  encryptionTest("Accessing TorrentBroadcast variables in a local cluster") { conf =>
     val numSlaves = 4
-    val conf = new SparkConf
     conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
     conf.set("spark.broadcast.compress", "true")
     sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", conf)
@@ -85,7 +86,9 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext {
       val size = 1 + rand.nextInt(1024 * 10)
       val data: Array[Byte] = new Array[Byte](size)
       rand.nextBytes(data)
-      val blocks = blockifyObject(data, blockSize, serializer, compressionCodec)
+      val blocks = blockifyObject(data, blockSize, serializer, compressionCodec).map { b =>
+        new ChunkedByteBuffer(b).toInputStream(dispose = true)
+      }
       val unblockified = unBlockifyObject[Array[Byte]](blocks, serializer, compressionCodec)
       assert(unblockified === data)
     }
@@ -137,9 +140,8 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext {
     sc.stop()
   }
 
-  test("Cache broadcast to disk") {
-    val conf = new SparkConf()
-      .setMaster("local")
+  encryptionTest("Cache broadcast to disk") { conf =>
+    conf.setMaster("local")
       .setAppName("test")
       .set("spark.memory.useLegacyMode", "true")
      .set("spark.storage.memoryFraction", "0.0")
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
index 9417930d02405..a591b98bca488 100644
--- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
@@ -213,7 +213,12 @@ class SparkSubmitSuite
     childArgsStr should include ("--arg arg1 --arg arg2")
     childArgsStr should include regex ("--jar .*thejar.jar")
     mainClass should be ("org.apache.spark.deploy.yarn.Client")
-    classpath should have length (0)
+
+    // In yarn-cluster mode, the jars are also added to the classpath.
+    classpath(0) should endWith ("thejar.jar")
+    classpath(1) should endWith ("one.jar")
+    classpath(2) should endWith ("two.jar")
+    classpath(3) should endWith ("three.jar")
 
     sysProps("spark.executor.memory") should be ("5g")
     sysProps("spark.driver.memory") should be ("4g")
diff --git a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala
index 0f3a4a03618ed..608052f5ed855 100644
--- a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala
@@ -16,9 +16,11 @@
  */
 package org.apache.spark.security
 
-import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileInputStream, FileOutputStream}
+import java.nio.channels.Channels
 import java.nio.charset.StandardCharsets.UTF_8
-import java.util.UUID
+import java.nio.file.Files
+import java.util.{Arrays, Random, UUID}
 
 import com.google.common.io.ByteStreams
 
@@ -121,6 +123,46 @@ class CryptoStreamUtilsSuite extends SparkFunSuite {
     }
   }
 
+  test("crypto stream wrappers") {
+    val testData = new Array[Byte](128 * 1024)
+    new Random().nextBytes(testData)
+
+    val conf = createConf()
+    val key = createKey(conf)
+    val file = Files.createTempFile("crypto", ".test").toFile()
+
+    val outStream = createCryptoOutputStream(new FileOutputStream(file), conf, key)
+    try {
+      ByteStreams.copy(new ByteArrayInputStream(testData), outStream)
+    } finally {
+      outStream.close()
+    }
+
+    val inStream = createCryptoInputStream(new FileInputStream(file), conf, key)
+    try {
+      val inStreamData = ByteStreams.toByteArray(inStream)
+      assert(Arrays.equals(inStreamData, testData))
+    } finally {
+      inStream.close()
+    }
+
+    val outChannel = createWritableChannel(new FileOutputStream(file).getChannel(), conf, key)
+    try {
+      val inByteChannel = Channels.newChannel(new ByteArrayInputStream(testData))
+      ByteStreams.copy(inByteChannel, outChannel)
+    } finally {
+      outChannel.close()
+    }
+
+    val inChannel = createReadableChannel(new FileInputStream(file).getChannel(), conf, key)
+    try {
+      val inChannelData = ByteStreams.toByteArray(Channels.newInputStream(inChannel))
+      assert(Arrays.equals(inChannelData, testData))
+    } finally {
+      inChannel.close()
+    }
+  }
+
   private def createConf(extra: (String, String)*): SparkConf = {
     val conf = new SparkConf()
     extra.foreach { case (k, v) => conf.set(k, v) }
diff --git a/core/src/test/scala/org/apache/spark/security/EncryptionFunSuite.scala b/core/src/test/scala/org/apache/spark/security/EncryptionFunSuite.scala
new file mode 100644
index 0000000000000..3f52dc41abf6d
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/security/EncryptionFunSuite.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.security
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.internal.config._
+
+trait EncryptionFunSuite {
+
+  this: SparkFunSuite =>
+
+  /**
+   * Runs a test twice, initializing a SparkConf object with encryption off, then on. It's ok
+   * for the test to modify the provided SparkConf.
+   */
+  final protected def encryptionTest(name: String)(fn: SparkConf => Unit) {
+    Seq(false, true).foreach { encrypt =>
+      test(s"$name (encryption = ${ if (encrypt) "on" else "off" })") {
+        val conf = new SparkConf().set(IO_ENCRYPTION_ENABLED, encrypt)
+        fn(conf)
+      }
+    }
+  }
+
+}
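A sketch of how a suite adopts the helper (hypothetical suite; each `encryptionTest` body runs twice, with `spark.io.encryption.enabled` off and on):

```scala
class MySuite extends SparkFunSuite with LocalSparkContext with EncryptionFunSuite {
  encryptionTest("caching") { conf =>
    // conf arrives with IO_ENCRYPTION_ENABLED pre-set for this run.
    sc = new SparkContext(conf.setMaster("local").setAppName("test"))
    assert(sc.parallelize(1 to 10).count() === 10)
  }
}
```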
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 64a67b4c4cbab..a8b9604899838 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -35,6 +35,7 @@ import org.scalatest.concurrent.Timeouts._
 import org.apache.spark._
 import org.apache.spark.broadcast.BroadcastManager
 import org.apache.spark.executor.DataReadMethod
+import org.apache.spark.internal.config._
 import org.apache.spark.memory.UnifiedMemoryManager
 import org.apache.spark.network.{BlockDataManager, BlockTransferService}
@@ -42,6 +43,7 @@ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
 import org.apache.spark.network.netty.NettyBlockTransferService
 import org.apache.spark.network.shuffle.BlockFetchingListener
 import org.apache.spark.rpc.RpcEnv
 import org.apache.spark.scheduler.LiveListenerBus
+import org.apache.spark.security.{CryptoStreamUtils, EncryptionFunSuite}
 import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerManager}
 import org.apache.spark.shuffle.sort.SortShuffleManager
 import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat
@@ -49,7 +51,8 @@ import org.apache.spark.util._
 import org.apache.spark.util.io.ChunkedByteBuffer
 
 class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach
-  with PrivateMethodTester with LocalSparkContext with ResetSystemProperties {
+  with PrivateMethodTester with LocalSparkContext with ResetSystemProperties
+  with EncryptionFunSuite {
 
   import BlockManagerSuite._
 
@@ -75,16 +78,24 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
       maxMem: Long,
       name: String = SparkContext.DRIVER_IDENTIFIER,
       master: BlockManagerMaster = this.master,
-      transferService: Option[BlockTransferService] = Option.empty): BlockManager = {
-    conf.set("spark.testing.memory", maxMem.toString)
-    conf.set("spark.memory.offHeap.size", maxMem.toString)
-    val serializer = new KryoSerializer(conf)
+      transferService: Option[BlockTransferService] = Option.empty,
+      testConf: Option[SparkConf] = None): BlockManager = {
+    val bmConf = testConf.map(_.setAll(conf.getAll)).getOrElse(conf)
+    bmConf.set("spark.testing.memory", maxMem.toString)
+    bmConf.set("spark.memory.offHeap.size", maxMem.toString)
+    val serializer = new KryoSerializer(bmConf)
+    val encryptionKey = if (bmConf.get(IO_ENCRYPTION_ENABLED)) {
+      Some(CryptoStreamUtils.createKey(bmConf))
+    } else {
+      None
+    }
+    val bmSecurityMgr = new SecurityManager(bmConf, encryptionKey)
     val transfer = transferService
       .getOrElse(new NettyBlockTransferService(conf, securityMgr, "localhost", "localhost", 0, 1))
-    val memManager = UnifiedMemoryManager(conf, numCores = 1)
-    val serializerManager = new SerializerManager(serializer, conf)
-    val blockManager = new BlockManager(name, rpcEnv, master, serializerManager, conf,
-      memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0)
+    val memManager = UnifiedMemoryManager(bmConf, numCores = 1)
+    val serializerManager = new SerializerManager(serializer, bmConf)
+    val blockManager = new BlockManager(name, rpcEnv, master, serializerManager, bmConf,
+      memManager, mapOutputTracker, shuffleManager, transfer, bmSecurityMgr, 0)
     memManager.setMemoryStore(blockManager.memoryStore)
     blockManager.initialize("app-id")
     blockManager
@@ -610,8 +621,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     assert(store.memoryStore.contains(rdd(0, 3)), "rdd_0_3 was not in store")
   }
 
-  test("on-disk storage") {
-    store = makeBlockManager(1200)
+  encryptionTest("on-disk storage") { _conf =>
+    store = makeBlockManager(1200, testConf = Some(_conf))
     val a1 = new Array[Byte](400)
     val a2 = new Array[Byte](400)
     val a3 = new Array[Byte](400)
@@ -623,34 +634,35 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     assert(store.getSingleAndReleaseLock("a1").isDefined, "a1 was in store")
   }
 
-  test("disk and memory storage") {
-    testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK, getAsBytes = false)
+  encryptionTest("disk and memory storage") { _conf =>
+    testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK, getAsBytes = false, testConf = _conf)
   }
 
-  test("disk and memory storage with getLocalBytes") {
-    testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK, getAsBytes = true)
+  encryptionTest("disk and memory storage with getLocalBytes") { _conf =>
+    testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK, getAsBytes = true, testConf = _conf)
   }
 
-  test("disk and memory storage with serialization") {
-    testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK_SER, getAsBytes = false)
+  encryptionTest("disk and memory storage with serialization") { _conf =>
+    testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK_SER, getAsBytes = false,
+      testConf = _conf)
   }
 
-  test("disk and memory storage with serialization and getLocalBytes") {
-    testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK_SER, getAsBytes = true)
+  encryptionTest("disk and memory storage with serialization and getLocalBytes") { _conf =>
+    testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK_SER, getAsBytes = true,
+      testConf = _conf)
   }
 
-  test("disk and off-heap memory storage") {
-    testDiskAndMemoryStorage(StorageLevel.OFF_HEAP, getAsBytes = false)
+  encryptionTest("disk and off-heap memory storage") { _conf =>
+    testDiskAndMemoryStorage(StorageLevel.OFF_HEAP, getAsBytes = false, testConf = _conf)
   }
 
-  test("disk and off-heap memory storage with getLocalBytes") {
-    testDiskAndMemoryStorage(StorageLevel.OFF_HEAP, getAsBytes = true)
+  encryptionTest("disk and off-heap memory storage with getLocalBytes") { _conf =>
+    testDiskAndMemoryStorage(StorageLevel.OFF_HEAP, getAsBytes = true, testConf = _conf)
   }
 
   def testDiskAndMemoryStorage(
       storageLevel: StorageLevel,
-      getAsBytes: Boolean): Unit = {
-    store = makeBlockManager(12000)
+      getAsBytes: Boolean,
+      testConf: SparkConf): Unit = {
+    store = makeBlockManager(12000, testConf = Some(testConf))
     val accessMethod =
       if (getAsBytes) store.getLocalBytesAndReleaseLock else store.getSingleAndReleaseLock
     val a1 = new Array[Byte](4000)
@@ -678,8 +690,8 @@
     }
   }
 
-  test("LRU with mixed storage levels") {
-    store = makeBlockManager(12000)
+  encryptionTest("LRU with mixed storage levels") { _conf =>
+    store = makeBlockManager(12000, testConf = Some(_conf))
     val a1 = new Array[Byte](4000)
     val a2 = new Array[Byte](4000)
     val a3 = new Array[Byte](4000)
@@ -700,8 +712,8 @@
     assert(store.getSingleAndReleaseLock("a4").isDefined, "a4 was not in store")
   }
 
-
test("in-memory LRU with streams") { - store = makeBlockManager(12000) + encryptionTest("in-memory LRU with streams") { _conf => + store = makeBlockManager(12000, testConf = Some(_conf)) val list1 = List(new Array[Byte](2000), new Array[Byte](2000)) val list2 = List(new Array[Byte](2000), new Array[Byte](2000)) val list3 = List(new Array[Byte](2000), new Array[Byte](2000)) @@ -728,8 +740,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(store.getAndReleaseLock("list3") === None, "list1 was in store") } - test("LRU with mixed storage levels and streams") { - store = makeBlockManager(12000) + encryptionTest("LRU with mixed storage levels and streams") { _conf => + store = makeBlockManager(12000, testConf = Some(_conf)) val list1 = List(new Array[Byte](2000), new Array[Byte](2000)) val list2 = List(new Array[Byte](2000), new Array[Byte](2000)) val list3 = List(new Array[Byte](2000), new Array[Byte](2000)) @@ -1325,7 +1337,8 @@ private object BlockManagerSuite { val getAndReleaseLock: (BlockId) => Option[BlockResult] = wrapGet(store.get) val getSingleAndReleaseLock: (BlockId) => Option[Any] = wrapGet(store.getSingle) val getLocalBytesAndReleaseLock: (BlockId) => Option[ChunkedByteBuffer] = { - wrapGet(store.getLocalBytes) + val allocator = ByteBuffer.allocate _ + wrapGet { bid => store.getLocalBytes(bid).map(_.toChunkedByteBuffer(allocator)) } } } diff --git a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala index 9e6b02b9eac4d..67fc084e8a13d 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala @@ -18,15 +18,23 @@ package org.apache.spark.storage import java.nio.{ByteBuffer, MappedByteBuffer} -import java.util.Arrays +import java.util.{Arrays, Random} -import org.apache.spark.{SparkConf, SparkFunSuite} +import com.google.common.io.{ByteStreams, Files} +import io.netty.channel.FileRegion + +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.network.util.{ByteArrayWritableChannel, JavaUtils} +import org.apache.spark.security.CryptoStreamUtils import org.apache.spark.util.io.ChunkedByteBuffer import org.apache.spark.util.Utils class DiskStoreSuite extends SparkFunSuite { test("reads of memory-mapped and non memory-mapped files are equivalent") { + val conf = new SparkConf() + val securityManager = new SecurityManager(conf) + // It will cause error when we tried to re-open the filestore and the // memory-mapped byte buffer tot he file has not been GC on Windows. 
assume(!Utils.isWindows) @@ -37,16 +45,18 @@ class DiskStoreSuite extends SparkFunSuite { val byteBuffer = new ChunkedByteBuffer(ByteBuffer.wrap(bytes)) val blockId = BlockId("rdd_1_2") - val diskBlockManager = new DiskBlockManager(new SparkConf(), deleteFilesOnStop = true) + val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true) - val diskStoreMapped = new DiskStore(new SparkConf().set(confKey, "0"), diskBlockManager) + val diskStoreMapped = new DiskStore(conf.clone().set(confKey, "0"), diskBlockManager, + securityManager) diskStoreMapped.putBytes(blockId, byteBuffer) - val mapped = diskStoreMapped.getBytes(blockId) + val mapped = diskStoreMapped.getBytes(blockId).asInstanceOf[ByteBufferBlockData].buffer assert(diskStoreMapped.remove(blockId)) - val diskStoreNotMapped = new DiskStore(new SparkConf().set(confKey, "1m"), diskBlockManager) + val diskStoreNotMapped = new DiskStore(conf.clone().set(confKey, "1m"), diskBlockManager, + securityManager) diskStoreNotMapped.putBytes(blockId, byteBuffer) - val notMapped = diskStoreNotMapped.getBytes(blockId) + val notMapped = diskStoreNotMapped.getBytes(blockId).asInstanceOf[ByteBufferBlockData].buffer // Not possible to do isInstanceOf due to visibility of HeapByteBuffer assert(notMapped.getChunks().forall(_.getClass.getName.endsWith("HeapByteBuffer")), @@ -63,4 +73,95 @@ class DiskStoreSuite extends SparkFunSuite { assert(Arrays.equals(mapped.toArray, bytes)) assert(Arrays.equals(notMapped.toArray, bytes)) } + + test("block size tracking") { + val conf = new SparkConf() + val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true) + val diskStore = new DiskStore(conf, diskBlockManager, new SecurityManager(conf)) + + val blockId = BlockId("rdd_1_2") + diskStore.put(blockId) { chan => + val buf = ByteBuffer.wrap(new Array[Byte](32)) + while (buf.hasRemaining()) { + chan.write(buf) + } + } + + assert(diskStore.getSize(blockId) === 32L) + diskStore.remove(blockId) + assert(diskStore.getSize(blockId) === 0L) + } + + test("block data encryption") { + val testDir = Utils.createTempDir() + val testData = new Array[Byte](128 * 1024) + new Random().nextBytes(testData) + + val conf = new SparkConf() + val securityManager = new SecurityManager(conf, Some(CryptoStreamUtils.createKey(conf))) + val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true) + val diskStore = new DiskStore(conf, diskBlockManager, securityManager) + + val blockId = BlockId("rdd_1_2") + diskStore.put(blockId) { chan => + val buf = ByteBuffer.wrap(testData) + while (buf.hasRemaining()) { + chan.write(buf) + } + } + + assert(diskStore.getSize(blockId) === testData.length) + + val diskData = Files.toByteArray(diskBlockManager.getFile(blockId.name)) + assert(!Arrays.equals(testData, diskData)) + + val blockData = diskStore.getBytes(blockId) + assert(blockData.isInstanceOf[EncryptedBlockData]) + assert(blockData.size === testData.length) + Map( + "input stream" -> readViaInputStream _, + "chunked byte buffer" -> readViaChunkedByteBuffer _, + "nio byte buffer" -> readViaNioBuffer _, + "managed buffer" -> readViaManagedBuffer _ + ).foreach { case (name, fn) => + val readData = fn(blockData) + assert(readData.length === blockData.size, s"Size of data read via $name did not match.") + assert(Arrays.equals(testData, readData), s"Data read via $name did not match.") + } + } + + private def readViaInputStream(data: BlockData): Array[Byte] = { + val is = data.toInputStream() + try { + ByteStreams.toByteArray(is) + } finally { + is.close() + } + } 
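+
+  // Editor's note (illustrative helper, not part of the original patch): the put()
+  // calls above loop on chan.write(buf) because a WritableByteChannel is free to
+  // accept fewer bytes than requested on each call. The same idiom as a reusable
+  // sketch:
+  private def writeFully(
+      channel: java.nio.channels.WritableByteChannel,
+      buf: java.nio.ByteBuffer): Unit = {
+    while (buf.hasRemaining()) {
+      channel.write(buf)
+    }
+  }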
+ + private def readViaChunkedByteBuffer(data: BlockData): Array[Byte] = { + val buf = data.toChunkedByteBuffer(ByteBuffer.allocate _) + try { + buf.toArray + } finally { + buf.dispose() + } + } + + private def readViaNioBuffer(data: BlockData): Array[Byte] = { + JavaUtils.bufferToArray(data.toByteBuffer()) + } + + private def readViaManagedBuffer(data: BlockData): Array[Byte] = { + val region = data.toNetty().asInstanceOf[FileRegion] + val byteChannel = new ByteArrayWritableChannel(data.size.toInt) + + while (region.transfered() < region.count()) { + region.transferTo(byteChannel, region.transfered()) + } + + byteChannel.close() + byteChannel.getData + } + } diff --git a/dev/run-pip-tests b/dev/run-pip-tests index af1b1feb70cd1..d51dde12a03c5 100755 --- a/dev/run-pip-tests +++ b/dev/run-pip-tests @@ -35,9 +35,28 @@ function delete_virtualenv() { } trap delete_virtualenv EXIT +PYTHON_EXECS=() # Some systems don't have pip or virtualenv - in those cases our tests won't work. -if ! hash virtualenv 2>/dev/null; then - echo "Missing virtualenv skipping pip installability tests." +if hash virtualenv 2>/dev/null && [ ! -n "$USE_CONDA" ]; then + echo "virtualenv installed - using it. Note: if this is a conda virtual env you may wish to set USE_CONDA" + # Figure out which Python execs we should test pip installation with + if hash python2 2>/dev/null; then + # We do this since we are testing with virtualenv and the default virtual env python + # is in /usr/bin/python + PYTHON_EXECS+=('python2') + elif hash python 2>/dev/null; then + # If python2 isn't installed, fall back to python if available + PYTHON_EXECS+=('python') + fi + if hash python3 2>/dev/null; then + PYTHON_EXECS+=('python3') + fi +elif hash conda 2>/dev/null; then + echo "Using conda virtual environments" + PYTHON_EXECS=('3.5') + USE_CONDA=1 +else + echo "Missing virtualenv & conda, skipping pip installability tests" exit 0 fi if ! hash pip 2>/dev/null; then @@ -45,22 +64,8 @@ if !
hash pip 2>/dev/null; then exit 0 fi -# Figure out which Python execs we should test pip installation with -PYTHON_EXECS=() -if hash python2 2>/dev/null; then - # We do this since we are testing with virtualenv and the default virtual env python - # is in /usr/bin/python - PYTHON_EXECS+=('python2') -elif hash python 2>/dev/null; then - # If python2 isn't installed fallback to python if available - PYTHON_EXECS+=('python') -fi -if hash python3 2>/dev/null; then - PYTHON_EXECS+=('python3') -fi - # Determine which version of PySpark we are building for archive name -PYSPARK_VERSION=$(python -c "exec(open('python/pyspark/version.py').read());print __version__") +PYSPARK_VERSION=$(python3 -c "exec(open('python/pyspark/version.py').read());print(__version__)") PYSPARK_DIST="$FWDIR/python/dist/pyspark-$PYSPARK_VERSION.tar.gz" # The pip install options we use for all the pip commands PIP_OPTIONS="--upgrade --no-cache-dir --force-reinstall " @@ -75,18 +80,24 @@ for python in "${PYTHON_EXECS[@]}"; do echo "Using $VIRTUALENV_BASE for virtualenv" VIRTUALENV_PATH="$VIRTUALENV_BASE"/$python rm -rf "$VIRTUALENV_PATH" - mkdir -p "$VIRTUALENV_PATH" - virtualenv --python=$python "$VIRTUALENV_PATH" - source "$VIRTUALENV_PATH"/bin/activate - # Upgrade pip & friends - pip install --upgrade pip pypandoc wheel - pip install numpy # Needed so we can verify mllib imports + if [ -n "$USE_CONDA" ]; then + conda create -y -p "$VIRTUALENV_PATH" python=$python numpy pandas pip setuptools + source activate "$VIRTUALENV_PATH" + else + mkdir -p "$VIRTUALENV_PATH" + virtualenv --python=$python "$VIRTUALENV_PATH" + source "$VIRTUALENV_PATH"/bin/activate + fi + # Upgrade pip & friends if using virtual env + if [ ! -n "$USE_CONDA" ]; then + pip install --upgrade pip pypandoc wheel numpy + fi echo "Creating pip installable source dist" cd "$FWDIR"/python # Delete the egg info file if it exists, this can cache the setup file. rm -rf pyspark.egg-info || echo "No existing egg info file, skipping deletion" - $python setup.py sdist + python setup.py sdist echo "Installing dist into virtual env" @@ -112,6 +123,13 @@ for python in "${PYTHON_EXECS[@]}"; do cd "$FWDIR" + # conda / virtualenv environments need to be deactivated differently + if [ -n "$USE_CONDA" ]; then + source deactivate + else + deactivate + fi + done done diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index e79accf9e987a..f41f1ac79e381 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -22,7 +22,8 @@ # Environment variables are populated by the code here: #+ https://github.com/jenkinsci/ghprb-plugin/blob/master/src/main/java/org/jenkinsci/plugins/ghprb/GhprbTrigger.java#L139 -FWDIR="$(cd "`dirname $0`"/..; pwd)" +FWDIR="$( cd "$( dirname "$0" )/.."
&& pwd )" cd "$FWDIR" +export PATH=/home/anaconda/bin:$PATH exec python -u ./dev/run-tests-jenkins.py "$@" diff --git a/python/run-tests.py b/python/run-tests.py index 53a0aef229b08..b2e50435bb192 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -111,9 +111,9 @@ def run_individual_python_test(test_name, pyspark_python): def get_default_python_executables(): - python_execs = [x for x in ["python2.6", "python3.4", "pypy"] if which(x)] - if "python2.6" not in python_execs: - LOGGER.warning("Not testing against `python2.6` because it could not be found; falling" + python_execs = [x for x in ["python2.7", "python3.4", "pypy"] if which(x)] + if "python2.7" not in python_execs: + LOGGER.warning("Not testing against `python2.7` because it could not be found; falling" " back to `python` instead") python_execs.insert(0, "python") return python_execs diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HBaseCredentialProvider.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HBaseCredentialProvider.scala index 5571df09a2ec9..5adeb8e605ff4 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HBaseCredentialProvider.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HBaseCredentialProvider.scala @@ -26,6 +26,7 @@ import org.apache.hadoop.security.token.{Token, TokenIdentifier} import org.apache.spark.SparkConf import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils private[security] class HBaseCredentialProvider extends ServiceCredentialProvider with Logging { @@ -36,7 +37,7 @@ private[security] class HBaseCredentialProvider extends ServiceCredentialProvide sparkConf: SparkConf, creds: Credentials): Option[Long] = { try { - val mirror = universe.runtimeMirror(getClass.getClassLoader) + val mirror = universe.runtimeMirror(Utils.getContextOrSparkClassLoader) val obtainToken = mirror.classLoader. loadClass("org.apache.hadoop.hbase.security.token.TokenUtil"). getMethod("obtainToken", classOf[Configuration]) @@ -60,7 +61,7 @@ private[security] class HBaseCredentialProvider extends ServiceCredentialProvide private def hbaseConf(conf: Configuration): Configuration = { try { - val mirror = universe.runtimeMirror(getClass.getClassLoader) + val mirror = universe.runtimeMirror(Utils.getContextOrSparkClassLoader) val confCreate = mirror.classLoader. loadClass("org.apache.hadoop.hbase.HBaseConfiguration"). 
getMethod("create", classOf[Configuration]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 1c7720afe1ca3..da37eb00dcd97 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -307,7 +307,8 @@ object ScalaReflection extends ScalaReflection { } } - val cls = t.dealias.companion.decl(TermName("newBuilder")) match { + val companion = t.normalize.typeSymbol.companionSymbol.typeSignature + val cls = companion.declaration(newTermName("newBuilder")) match { case NoSymbol => classOf[Seq[_]] case _ => mirror.runtimeClass(t.typeSymbol.asClass) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index f14df93160b75..b32374c5742ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -24,6 +24,7 @@ import scala.math.BigDecimal.RoundingMode import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, LeafNode, Statistics} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ @@ -104,12 +105,23 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo val percent2 = calculateFilterSelectivity(cond2, update = false).getOrElse(1.0) Some(percent1 + percent2 - (percent1 * percent2)) + // Not-operator pushdown case Not(And(cond1, cond2)) => calculateFilterSelectivity(Or(Not(cond1), Not(cond2)), update = false) + // Not-operator pushdown case Not(Or(cond1, cond2)) => calculateFilterSelectivity(And(Not(cond1), Not(cond2)), update = false) + // Collapse two consecutive Not operators which could be generated after Not-operator pushdown + case Not(Not(cond)) => + calculateFilterSelectivity(cond, update = false) + + // The foldable Not has been processed in the ConstantFolding rule + // This is a top-down traversal. The Not could be pushed down by the above two cases. + case Not(l @ Literal(null, _)) => + calculateSingleCondition(l, update = false) + case Not(cond) => calculateFilterSelectivity(cond, update = false) match { case Some(percent) => Some(1.0 - percent) @@ -134,13 +146,16 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo */ def calculateSingleCondition(condition: Expression, update: Boolean): Option[Double] = { condition match { + case l: Literal => + evaluateLiteral(l) + // For evaluateBinary method, we assume the literal on the right side of an operator. // So we will change the order if not. 
// EqualTo/EqualNullSafe do not care about the order - case op @ Equality(ar: Attribute, l: Literal) => + case Equality(ar: Attribute, l: Literal) => evaluateEquality(ar, l, update) - case op @ Equality(l: Literal, ar: Attribute) => + case Equality(l: Literal, ar: Attribute) => evaluateEquality(ar, l, update) case op @ LessThan(ar: Attribute, l: Literal) => @@ -342,6 +357,26 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo } + /** + * Returns a percentage of rows meeting a Literal expression. + * This method evaluates all the possible literal cases in Filter. + * + * FalseLiteral and TrueLiteral should be eliminated by the optimizer, but a null literal might be + * introduced by the NullPropagation optimizer rule. For safety, we handle all the cases here. + * + * @param literal a literal value (or constant) + * @return an optional double value to show the percentage of rows meeting a given condition + */ + def evaluateLiteral(literal: Literal): Option[Double] = { + literal match { + case Literal(null, _) => Some(0.0) + case FalseLiteral => Some(0.0) + case TrueLiteral => Some(1.0) + // Ideally, we should not hit the following branch + case _ => None + } + } + /** * Returns a percentage of rows meeting "IN" operator expression. + This method evaluates the equality predicate for all data types. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 8d8b5b86d5aa1..54006e20a3eb6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -417,6 +417,12 @@ object StructType extends AbstractDataType { } } + /** + * Creates a StructType for a given DDL-formatted string, which is a comma-separated list of field + * definitions, e.g., a INT, b STRING.
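+ *
+ * A usage sketch (editor's illustration, not part of the original patch):
+ * {{{
+ *   val schema = StructType.fromDDL("a INT, b STRING")
+ *   // same type, up to nullability and metadata, as:
+ *   // new StructType().add("a", IntegerType).add("b", StringType)
+ * }}}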
+ */ + def fromDDL(ddl: String): StructType = CatalystSqlParser.parseTableSchema(ddl) + def apply(fields: Seq[StructField]): StructType = StructType(fields.toArray) def apply(fields: java.util.List[StructField]): StructType = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index 07abe1ed28533..1966c96c05294 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.statsEstimation import java.sql.Date import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.plans.LeftOuter import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, Join, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ @@ -76,6 +77,82 @@ class FilterEstimationSuite extends StatsEstimationTestBase { attrDouble -> colStatDouble, attrString -> colStatString)) + test("true") { + validateEstimatedStats( + Filter(TrueLiteral, childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> colStatInt), + expectedRowCount = 10) + } + + test("false") { + validateEstimatedStats( + Filter(FalseLiteral, childStatsTestPlan(Seq(attrInt), 10L)), + Nil, + expectedRowCount = 0) + } + + test("null") { + validateEstimatedStats( + Filter(Literal(null, IntegerType), childStatsTestPlan(Seq(attrInt), 10L)), + Nil, + expectedRowCount = 0) + } + + test("Not(null)") { + validateEstimatedStats( + Filter(Not(Literal(null, IntegerType)), childStatsTestPlan(Seq(attrInt), 10L)), + Nil, + expectedRowCount = 0) + } + + test("Not(Not(null))") { + validateEstimatedStats( + Filter(Not(Not(Literal(null, IntegerType))), childStatsTestPlan(Seq(attrInt), 10L)), + Nil, + expectedRowCount = 0) + } + + test("cint < 3 AND null") { + val condition = And(LessThan(attrInt, Literal(3)), Literal(null, IntegerType)) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), + Nil, + expectedRowCount = 0) + } + + test("cint < 3 OR null") { + val condition = Or(LessThan(attrInt, Literal(3)), Literal(null, IntegerType)) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> colStatInt), + expectedRowCount = 3) + } + + test("Not(cint < 3 AND null)") { + val condition = Not(And(LessThan(attrInt, Literal(3)), Literal(null, IntegerType))) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> colStatInt), + expectedRowCount = 8) + } + + test("Not(cint < 3 OR null)") { + val condition = Not(Or(LessThan(attrInt, Literal(3)), Literal(null, IntegerType))) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), + Nil, + expectedRowCount = 0) + } + + test("Not(cint < 3 AND Not(null))") { + val condition = Not(And(LessThan(attrInt, Literal(3)), Not(Literal(null, IntegerType)))) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> colStatInt), + expectedRowCount = 8) + } + test("cint = 2") { validateEstimatedStats( Filter(EqualTo(attrInt, Literal(2)),
childStatsTestPlan(Seq(attrInt), 10L)), @@ -163,6 +240,16 @@ class FilterEstimationSuite extends StatsEstimationTestBase { expectedRowCount = 10) } + test("cint IS NOT NULL && null") { + // 'cint < null' will be optimized to 'cint IS NOT NULL && null'. + // More similar cases can be found in the Optimizer NullPropagation. + val condition = And(IsNotNull(attrInt), Literal(null, IntegerType)) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), + Nil, + expectedRowCount = 0) + } + test("cint > 3 AND cint <= 6") { val condition = And(GreaterThan(attrInt, Literal(3)), LessThanOrEqual(attrInt, Literal(6))) validateEstimatedStats( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 61e1ec7c7ab35..05cb999af6a50 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -169,30 +169,72 @@ class DataTypeSuite extends SparkFunSuite { assert(!arrayType.existsRecursively(_.isInstanceOf[IntegerType])) } - def checkDataTypeJsonRepr(dataType: DataType): Unit = { - test(s"JSON - $dataType") { + def checkDataTypeFromJson(dataType: DataType): Unit = { + test(s"from Json - $dataType") { assert(DataType.fromJson(dataType.json) === dataType) } } - checkDataTypeJsonRepr(NullType) - checkDataTypeJsonRepr(BooleanType) - checkDataTypeJsonRepr(ByteType) - checkDataTypeJsonRepr(ShortType) - checkDataTypeJsonRepr(IntegerType) - checkDataTypeJsonRepr(LongType) - checkDataTypeJsonRepr(FloatType) - checkDataTypeJsonRepr(DoubleType) - checkDataTypeJsonRepr(DecimalType(10, 5)) - checkDataTypeJsonRepr(DecimalType.SYSTEM_DEFAULT) - checkDataTypeJsonRepr(DateType) - checkDataTypeJsonRepr(TimestampType) - checkDataTypeJsonRepr(StringType) - checkDataTypeJsonRepr(BinaryType) - checkDataTypeJsonRepr(ArrayType(DoubleType, true)) - checkDataTypeJsonRepr(ArrayType(StringType, false)) - checkDataTypeJsonRepr(MapType(IntegerType, StringType, true)) - checkDataTypeJsonRepr(MapType(IntegerType, ArrayType(DoubleType), false)) + def checkDataTypeFromDDL(dataType: DataType): Unit = { + test(s"from DDL - $dataType") { + val parsed = StructType.fromDDL(s"a ${dataType.sql}") + val expected = new StructType().add("a", dataType) + assert(parsed.sameType(expected)) + } + } + + checkDataTypeFromJson(NullType) + + checkDataTypeFromJson(BooleanType) + checkDataTypeFromDDL(BooleanType) + + checkDataTypeFromJson(ByteType) + checkDataTypeFromDDL(ByteType) + + checkDataTypeFromJson(ShortType) + checkDataTypeFromDDL(ShortType) + + checkDataTypeFromJson(IntegerType) + checkDataTypeFromDDL(IntegerType) + + checkDataTypeFromJson(LongType) + checkDataTypeFromDDL(LongType) + + checkDataTypeFromJson(FloatType) + checkDataTypeFromDDL(FloatType) + + checkDataTypeFromJson(DoubleType) + checkDataTypeFromDDL(DoubleType) + + checkDataTypeFromJson(DecimalType(10, 5)) + checkDataTypeFromDDL(DecimalType(10, 5)) + + checkDataTypeFromJson(DecimalType.SYSTEM_DEFAULT) + checkDataTypeFromDDL(DecimalType.SYSTEM_DEFAULT) + + checkDataTypeFromJson(DateType) + checkDataTypeFromDDL(DateType) + + checkDataTypeFromJson(TimestampType) + checkDataTypeFromDDL(TimestampType) + + checkDataTypeFromJson(StringType) + checkDataTypeFromDDL(StringType) + + checkDataTypeFromJson(BinaryType) + checkDataTypeFromDDL(BinaryType) + + checkDataTypeFromJson(ArrayType(DoubleType, true)) + checkDataTypeFromDDL(ArrayType(DoubleType, 
true)) + + checkDataTypeFromJson(ArrayType(StringType, false)) + checkDataTypeFromDDL(ArrayType(StringType, false)) + + checkDataTypeFromJson(MapType(IntegerType, StringType, true)) + checkDataTypeFromDDL(MapType(IntegerType, StringType, true)) + + checkDataTypeFromJson(MapType(IntegerType, ArrayType(DoubleType), false)) + checkDataTypeFromDDL(MapType(IntegerType, ArrayType(DoubleType), false)) val metadata = new MetadataBuilder() .putString("name", "age") @@ -201,7 +243,8 @@ class DataTypeSuite extends SparkFunSuite { StructField("a", IntegerType, nullable = true), StructField("b", ArrayType(DoubleType), nullable = false), StructField("c", DoubleType, nullable = false, metadata))) - checkDataTypeJsonRepr(structType) + checkDataTypeFromJson(structType) + checkDataTypeFromDDL(structType) def checkDefaultSize(dataType: DataType, expectedDefaultSize: Int): Unit = { test(s"Check the default size of $dataType") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 49562578b23cd..a97297892b5e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.ui.SQLListener -import org.apache.spark.sql.internal.{CatalogImpl, SessionState, SharedState} +import org.apache.spark.sql.internal.{BaseSessionStateBuilder, CatalogImpl, SessionState, SessionStateBuilder, SharedState} import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.streaming._ @@ -194,7 +194,7 @@ class SparkSession private( * * @since 2.0.0 */ - def udf: UDFRegistration = sessionState.udf + def udf: UDFRegistration = sessionState.udfRegistration /** * :: Experimental :: @@ -990,28 +990,28 @@ object SparkSession { /** Reference to the root SparkSession. */ private val defaultSession = new AtomicReference[SparkSession] - private val HIVE_SESSION_STATE_CLASS_NAME = "org.apache.spark.sql.hive.HiveSessionState" + private val HIVE_SESSION_STATE_BUILDER_CLASS_NAME = + "org.apache.spark.sql.hive.HiveSessionStateBuilder" private def sessionStateClassName(conf: SparkConf): String = { conf.get(CATALOG_IMPLEMENTATION) match { - case "hive" => HIVE_SESSION_STATE_CLASS_NAME - case "in-memory" => classOf[SessionState].getCanonicalName + case "hive" => HIVE_SESSION_STATE_BUILDER_CLASS_NAME + case "in-memory" => classOf[SessionStateBuilder].getCanonicalName } } /** * Helper method to create an instance of `SessionState` based on `className` from conf. - * The result is either `SessionState` or `HiveSessionState`. + * The result is either `SessionState` or a Hive based `SessionState`. 
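+ * For example (editor's sketch of the call shape, not from the original patch): with the
+ * Hive catalog implementation, `className` is HIVE_SESSION_STATE_BUILDER_CLASS_NAME and
+ * the reflective call below amounts to `new HiveSessionStateBuilder(sparkSession, None).build()`.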
*/ private def instantiateSessionState( className: String, sparkSession: SparkSession): SessionState = { try { - // get `SessionState.apply(SparkSession)` + // invoke `new [Hive]SessionStateBuilder(SparkSession, Option[SessionState])` val clazz = Utils.classForName(className) - val method = clazz.getMethod("apply", sparkSession.getClass) - method.invoke(null, sparkSession).asInstanceOf[SessionState] + val ctor = clazz.getConstructors.head + ctor.newInstance(sparkSession, None).asInstanceOf[BaseSessionStateBuilder].build() } catch { case NonFatal(e) => throw new IllegalArgumentException(s"Error while instantiating '$className':", e) @@ -1023,7 +1023,7 @@ object SparkSession { */ private[spark] def hiveClassesArePresent: Boolean = { try { - Utils.classForName(HIVE_SESSION_STATE_CLASS_NAME) + Utils.classForName(HIVE_SESSION_STATE_BUILDER_CLASS_NAME) Utils.classForName("org.apache.hadoop.hive.conf.HiveConf") true } catch { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index d876688a8aabd..66a8e044ab879 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -628,13 +628,7 @@ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode { val dataSize = rows.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum longMetric("dataSize") += dataSize - // There are some cases we don't care about the metrics and call `SparkPlan.doExecute` - // directly without setting an execution id. We should be tolerant to it. - if (executionId != null) { - sparkContext.listenerBus.post(SparkListenerDriverAccumUpdates( - executionId.toLong, metrics.values.map(m => m.id -> m.value).toSeq)) - } - + SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq) rows } }(SubqueryExec.executionContext) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index 7be5d31d4a765..efcaca9338ad6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -97,13 +97,7 @@ case class BroadcastExchangeExec( val broadcasted = sparkContext.broadcast(relation) longMetric("broadcastTime") += (System.nanoTime() - beforeBroadcast) / 1000000 - // There are some cases we don't care about the metrics and call `SparkPlan.doExecute` - // directly without setting an execution id. We should be tolerant to it.
- if (executionId != null) { - sparkContext.listenerBus.post(SparkListenerDriverAccumUpdates( - executionId.toLong, metrics.values.map(m => m.id -> m.value).toSeq)) - } - + SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq) broadcasted } catch { case oe: OutOfMemoryError => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index dbc27d8b237f3..ef982a4ebd10d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -22,9 +22,15 @@ import java.util.Locale import org.apache.spark.SparkContext import org.apache.spark.scheduler.AccumulableInfo +import org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, Utils} +/** + * A metric used in a SQL query plan. This is implemented as an [[AccumulatorV2]]. Updates on + * the executor side are automatically propagated and shown in the SQL UI through metrics. Updates + * on the driver side must be explicitly posted using [[SQLMetrics.postDriverMetricUpdates()]]. + */ class SQLMetric(val metricType: String, initValue: Long = 0L) extends AccumulatorV2[Long, Long] { // This is a workaround for SPARK-11013. // We may use -1 as initial value of the accumulator, if the accumulator is valid, we will @@ -126,4 +132,18 @@ object SQLMetrics { s"\n$sum ($min, $med, $max)" } } + + /** + * Updates metrics based on the driver-side value. This is useful for metrics that are only + * updated on the driver, e.g. subquery execution time or number of files. + */ + def postDriverMetricUpdates( + sc: SparkContext, executionId: String, metrics: Seq[SQLMetric]): Unit = { + // There are cases where we don't care about the metrics and call `SparkPlan.doExecute` + // directly without setting an execution id. We should be tolerant of that. + if (executionId != null) { + sc.listenerBus.post( + SparkListenerDriverAccumUpdates(executionId.toLong, metrics.map(m => m.id -> m.value))) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index 12d3bc9281f35..b4a91230a0012 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -47,6 +47,13 @@ case class SparkListenerSQLExecutionStart( case class SparkListenerSQLExecutionEnd(executionId: Long, time: Long) extends SparkListenerEvent +/** + * A message used to update SQL metric values for driver-side updates (which don't get reflected + * automatically). + * + * @param executionId The execution id for a query, so we can find the query plan. + * @param accumUpdates Map from accumulator id to the metric value (metrics are always 64-bit ints).
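+ *
+ * Posting sketch (editor's illustration): driver-side code should go through
+ * {{{
+ *   SQLMetrics.postDriverMetricUpdates(sc, executionId, metrics)
+ * }}}
+ * which, when `executionId` is set, posts
+ * `SparkListenerDriverAccumUpdates(executionId.toLong, metrics.map(m => m.id -> m.value))`
+ * on the listener bus, as implemented in the helper added above.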
+ */ @DeveloperApi case class SparkListenerDriverAccumUpdates( executionId: Long, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index acdb8e2d3edc8..0f9203065ef05 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -21,6 +21,7 @@ import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.runtime.universe.{typeTag, TypeTag} import scala.util.Try +import scala.util.control.NonFatal import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql.catalyst.ScalaReflection @@ -3055,13 +3056,21 @@ object functions { * with the specified schema. Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. - * @param schema the schema to use when parsing the json string as a json string + * @param schema the schema to use when parsing the json string, specified as a schema string. + * In Spark 2.1, the user-provided schema has to be in JSON format. Since Spark 2.2, + * the DDL format is also supported for the schema. * * @group collection_funcs * @since 2.1.0 */ - def from_json(e: Column, schema: String, options: java.util.Map[String, String]): Column = - from_json(e, DataType.fromJson(schema), options) + def from_json(e: Column, schema: String, options: java.util.Map[String, String]): Column = { + val dataType = try { + DataType.fromJson(schema) + } catch { + case NonFatal(_) => StructType.fromDDL(schema) + } + from_json(e, dataType, options) + } /** * (Scala-specific) Converts a column containing a `StructType` or `ArrayType` of `StructType`s diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/sessionStateBuilders.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala similarity index 92% rename from sql/core/src/main/scala/org/apache/spark/sql/internal/sessionStateBuilders.scala rename to sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index b8f645fdee85a..2b14eca919fa4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/sessionStateBuilders.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.internal import org.apache.spark.SparkConf import org.apache.spark.annotation.{Experimental, InterfaceStability} -import org.apache.spark.sql.{ExperimentalMethods, SparkSession, Strategy} +import org.apache.spark.sql.{ExperimentalMethods, SparkSession, Strategy, UDFRegistration} import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry} import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.optimizer.Optimizer @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{QueryExecution, SparkOptimizer, SparkPlanner, SparkSqlParser} import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.streaming.StreamingQueryManager +import org.apache.spark.sql.util.ExecutionListenerManager /** * Builder class that coordinates construction of a new [[SessionState]]. @@ -133,6 +134,14 @@ abstract class BaseSessionStateBuilder( catalog } + /** + * Interface exposed to the user for registering user-defined functions. + * + * Note 1: The user-defined functions must be deterministic.
+ * Note 2: This depends on the `functionRegistry` field. + */ + protected def udfRegistration: UDFRegistration = new UDFRegistration(functionRegistry) + /** * Logical query plan analyzer for resolving unresolved attributes and relations. * @@ -232,6 +241,16 @@ abstract class BaseSessionStateBuilder( */ protected def streamingQueryManager: StreamingQueryManager = new StreamingQueryManager(session) + /** + * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s + * that listen for execution metrics. + * + * This is cloned from the parent session state if available; otherwise a new instance is created. + */ + protected def listenerManager: ExecutionListenerManager = { + parentState.map(_.listenerManager.clone()).getOrElse(new ExecutionListenerManager) + } + /** * Function used to make clones of the session state. */ @@ -245,17 +264,18 @@ abstract class BaseSessionStateBuilder( */ def build(): SessionState = { new SessionState( - session.sparkContext, session.sharedState, conf, experimentalMethods, functionRegistry, + udfRegistration, catalog, sqlParser, analyzer, optimizer, planner, streamingQueryManager, + listenerManager, resourceLoader, createQueryExecution, createClone) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index c6241d923d7b3..1b341a12fc609 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -32,43 +32,46 @@ import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution._ import org.apache.spark.sql.streaming.StreamingQueryManager -import org.apache.spark.sql.util.ExecutionListenerManager +import org.apache.spark.sql.util.{ExecutionListenerManager, QueryExecutionListener} /** * A class that holds all session-specific state in a given [[SparkSession]]. * - * @param sparkContext The [[SparkContext]]. - * @param sharedState The shared state. + * @param sharedState The state shared across sessions, e.g. global view manager, external catalog. * @param conf SQL-specific key-value configurations. - * @param experimentalMethods The experimental methods. + * @param experimentalMethods Interface to add custom planning strategies and optimizers. * @param functionRegistry Internal catalog for managing functions registered by the user. + * @param udfRegistration Interface exposed to the user for registering user-defined functions. * @param catalog Internal catalog for managing table and database states. * @param sqlParser Parser that extracts expressions, plans, table identifiers etc. from SQL texts. * @param analyzer Logical query plan analyzer for resolving unresolved attributes and relations. * @param optimizer Logical query plan optimizer. - * @param planner Planner that converts optimized logical plans to physical plans + * @param planner Planner that converts optimized logical plans to physical plans. * @param streamingQueryManager Interface to start and stop streaming queries. + * @param listenerManager Interface to register custom [[QueryExecutionListener]]s. + * @param resourceLoader Session shared resource loader to load JARs, files, etc. * @param createQueryExecution Function used to create QueryExecution objects. * @param createClone Function used to create clones of the session state.
*/ private[sql] class SessionState( - sparkContext: SparkContext, sharedState: SharedState, val conf: SQLConf, val experimentalMethods: ExperimentalMethods, val functionRegistry: FunctionRegistry, + val udfRegistration: UDFRegistration, val catalog: SessionCatalog, val sqlParser: ParserInterface, val analyzer: Analyzer, val optimizer: Optimizer, val planner: SparkPlanner, val streamingQueryManager: StreamingQueryManager, + val listenerManager: ExecutionListenerManager, val resourceLoader: SessionResourceLoader, createQueryExecution: LogicalPlan => QueryExecution, createClone: (SparkSession, SessionState) => SessionState) { def newHadoopConf(): Configuration = SessionState.newHadoopConf( - sparkContext.hadoopConfiguration, + sharedState.sparkContext.hadoopConfiguration, conf) def newHadoopConfWithOptions(options: Map[String, String]): Configuration = { @@ -81,18 +84,6 @@ private[sql] class SessionState( hadoopConf } - /** - * Interface exposed to the user for registering user-defined functions. - * Note that the user-defined functions must be deterministic. - */ - val udf: UDFRegistration = new UDFRegistration(functionRegistry) - - /** - * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s - * that listen for execution metrics. - */ - val listenerManager: ExecutionListenerManager = new ExecutionListenerManager - /** * Get an identical copy of the `SessionState` and associate it with the given `SparkSession` */ @@ -110,13 +101,6 @@ private[sql] class SessionState( } private[sql] object SessionState { - /** - * Create a new [[SessionState]] for the given session. - */ - def apply(session: SparkSession): SessionState = { - new SessionStateBuilder(session).build() - } - def newHadoopConf(hadoopConf: Configuration, sqlConf: SQLConf): Configuration = { val newHadoopConf = new Configuration(hadoopConf) sqlConf.getAllConfs.foreach { case (k, v) => if (v ne null) newHadoopConf.set(k, v) } @@ -155,7 +139,7 @@ class SessionResourceLoader(session: SparkSession) extends FunctionResourceLoade /** * Add a jar path to [[SparkContext]] and the classloader. * - * Note: this method seems not access any session state, but the subclass `HiveSessionState` needs + * Note: this method does not seem to access any session state, but a Hive based `SessionState` needs * to add the jar to its hive client for the current session. Hence, it still needs to be in * [[SessionState]]. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala index 26ad0eadd9d4c..f6240d85fba6f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala @@ -98,6 +98,16 @@ class ExecutionListenerManager private[sql] () extends Logging { listeners.clear() } + /** + * Get an identical copy of this listener manager.
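+ * The copy starts with the same registered listeners, but the two managers are
+ * independent afterwards: registering or unregistering a listener on one does not
+ * affect the other (see the "fork new session and inherit listener manager" test
+ * added in SessionStateSuite below).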
+ */ + @DeveloperApi + override def clone(): ExecutionListenerManager = writeLock { + val newListenerManager = new ExecutionListenerManager + listeners.foreach(newListenerManager.register) + newListenerManager + } + private[sql] def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = { readLock { withErrorHandling { listener => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 170c238c53438..8465e8d036a6d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -156,6 +156,13 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { Row(Seq(Row(1, "a"), Row(2, null), Row(null, null)))) } + test("from_json uses DDL strings for defining a schema") { + val df = Seq("""{"a": 1, "b": "haa"}""").toDS() + checkAnswer( + df.select(from_json($"value", "a INT, b STRING", new java.util.HashMap[String, String]())), + Row(Row(1, "haa")) :: Nil) + } + test("to_json - struct") { val df = Seq(Tuple1(Tuple1(1))).toDF("a") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala index 2d5e37242a58b..5638c8eeda842 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala @@ -19,10 +19,13 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfterAll import org.scalatest.BeforeAndAfterEach +import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.util.QueryExecutionListener class SessionStateSuite extends SparkFunSuite with BeforeAndAfterEach with BeforeAndAfterAll { @@ -122,6 +125,56 @@ class SessionStateSuite extends SparkFunSuite } } + test("fork new session and inherit listener manager") { + class CommandCollector extends QueryExecutionListener { + val commands: ArrayBuffer[String] = ArrayBuffer.empty[String] + override def onFailure(funcName: String, qe: QueryExecution, ex: Exception) : Unit = {} + override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = { + commands += funcName + } + } + val collectorA = new CommandCollector + val collectorB = new CommandCollector + val collectorC = new CommandCollector + + try { + def runCollectQueryOn(sparkSession: SparkSession): Unit = { + val tupleEncoder = Encoders.tuple(Encoders.scalaInt, Encoders.STRING) + val df = sparkSession.createDataset(Seq(1 -> "a"))(tupleEncoder).toDF("i", "j") + df.select("i").collect() + } + + activeSession.listenerManager.register(collectorA) + val forkedSession = activeSession.cloneSession() + + // inheritance + assert(forkedSession ne activeSession) + assert(forkedSession.listenerManager ne activeSession.listenerManager) + runCollectQueryOn(forkedSession) + assert(collectorA.commands.length == 1) // forked should callback to A + assert(collectorA.commands(0) == "collect") + + // independence + // => changes to forked do not affect original + forkedSession.listenerManager.register(collectorB) + runCollectQueryOn(activeSession) + assert(collectorB.commands.isEmpty) // original should not callback to B + assert(collectorA.commands.length == 2) // 
original should still callback to A + assert(collectorA.commands(1) == "collect") + // <= changes to original do not affect forked + activeSession.listenerManager.register(collectorC) + runCollectQueryOn(forkedSession) + assert(collectorC.commands.isEmpty) // forked should not callback to C + assert(collectorA.commands.length == 3) // forked should still callback to A + assert(collectorB.commands.length == 1) // forked should still callback to B + assert(collectorA.commands(2) == "collect") + assert(collectorB.commands(0) == "collect") + } finally { + activeSession.listenerManager.unregister(collectorA) + activeSession.listenerManager.unregister(collectorC) + } + } + test("fork new sessions and run query on inherited table") { def checkTableExists(sparkSession: SparkSession): Unit = { QueryTest.checkAnswer(sparkSession.sql( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index e41c00ecec271..e6cd41e4facf1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -477,9 +477,11 @@ private case class MyPlan(sc: SparkContext, expectedValue: Long) extends LeafExe override def doExecute(): RDD[InternalRow] = { longMetric("dummy") += expectedValue - sc.listenerBus.post(SparkListenerDriverAccumUpdates( - sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY).toLong, - metrics.values.map(m => m.id -> m.value).toSeq)) + + SQLMetrics.postDriverMetricUpdates( + sc, + sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY), + metrics.values.toSeq) sc.emptyRDD } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index 517b01f183926..ff3784cab9e26 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -292,7 +292,7 @@ object SparkExecuteStatementOperation { def getTableSchema(structType: StructType): TableSchema = { val schema = structType.map { field => val attrTypeString = if (field.dataType == NullType) "void" else field.dataType.catalogString - new FieldSchema(field.name, attrTypeString, "") + new FieldSchema(field.name, attrTypeString, field.getComment.getOrElse("")) } new TableSchema(schema.asJava) } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 0c79b6f4211ff..1bc5c3c62f045 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -34,11 +34,12 @@ import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.exec.Utilities import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.log4j.{Level, Logger} import org.apache.thrift.transport.TSocket import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException -import 
org.apache.spark.sql.hive.{HiveSessionState, HiveUtils} +import org.apache.spark.sql.hive.HiveUtils import org.apache.spark.util.ShutdownHookManager /** @@ -275,6 +276,10 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { private val console = new SessionState.LogHelper(LOG) + if (sessionState.getIsSilent) { + Logger.getRootLogger.setLevel(Level.WARN) + } + private val isRemoteMode = { SparkSQLCLIDriver.isRemoteMode(sessionState) } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperationSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperationSuite.scala index 32ded0d254ef8..06e3980662048 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperationSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperationSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.hive.thriftserver import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.types.{NullType, StructField, StructType} +import org.apache.spark.sql.types.{IntegerType, NullType, StringType, StructField, StructType} class SparkExecuteStatementOperationSuite extends SparkFunSuite { test("SPARK-17112 `select null` via JDBC triggers IllegalArgumentException in ThriftServer") { @@ -30,4 +30,16 @@ class SparkExecuteStatementOperationSuite extends SparkFunSuite { assert(columns.get(0).getType() == org.apache.hive.service.cli.Type.NULL_TYPE) assert(columns.get(1).getType() == org.apache.hive.service.cli.Type.NULL_TYPE) } + + test("SPARK-20146 Comment should be preserved") { + val field1 = StructField("column1", StringType).withComment("comment 1") + val field2 = StructField("column2", IntegerType) + val tableSchema = StructType(Seq(field1, field2)) + val columns = SparkExecuteStatementOperation.getTableSchema(tableSchema).getColumnDescriptors() + assert(columns.size() == 2) + assert(columns.get(0).getType() == org.apache.hive.service.cli.Type.STRING_TYPE) + assert(columns.get(0).getComment() == "comment 1") + assert(columns.get(1).getType() == org.apache.hive.service.cli.Type.INT_TYPE) + assert(columns.get(1).getComment() == "") + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala similarity index 92% rename from sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala rename to sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index f49e6bb418644..8048c2ba2c2e4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -28,19 +28,7 @@ import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.internal.{BaseSessionStateBuilder, SessionResourceLoader, SessionState} /** - * Entry object for creating a Hive aware [[SessionState]]. - */ -private[hive] object HiveSessionState { - /** - * Create a new Hive aware [[SessionState]]. for the given session. - */ - def apply(session: SparkSession): SessionState = { - new HiveSessionStateBuilder(session).build() - } -} - -/** - * Builder that produces a [[HiveSessionState]]. + * Builder that produces a Hive aware [[SessionState]]. 
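+ * (Reached reflectively from `SparkSession.instantiateSessionState` when the catalog
+ * implementation is `hive`; equivalent to calling
+ * `new HiveSessionStateBuilder(session, None).build()` directly.)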
*/ @Experimental @InterfaceStability.Unstable diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 0bcf219922764..d9bb1f8c7edcc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.command.CacheTableCommand import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.client.HiveClient -import org.apache.spark.sql.internal._ +import org.apache.spark.sql.internal.{SessionState, SharedState, SQLConf, WithTestConf} import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.util.{ShutdownHookManager, Utils} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala index 67c77fb62f4e1..958ad3e1c3ce8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.hive.test.TestHiveSingleton /** - * Run all tests from `SessionStateSuite` with a `HiveSessionState`. + * Run all tests from `SessionStateSuite` with a Hive based `SessionState`. */ class HiveSessionStateSuite extends SessionStateSuite with TestHiveSingleton with BeforeAndAfterEach { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index 1607c97cd6acb..9f4009bfe402a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -21,7 +21,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} -import org.apache.spark.sql.{sources, Row, SparkSession} +import org.apache.spark.sql.{sources, SparkSession} import org.apache.spark.sql.catalyst.{expressions, InternalRow} import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, GenericInternalRow, InterpretedPredicate, InterpretedProjection, JoinedRow, Literal} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala index d0864fd3678b2..844760ab61d2e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala @@ -158,16 +158,14 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( logInfo(s"Read partition data of $this from write ahead log, record handle " + partition.walRecordHandle) if (storeInBlockManager) { - blockManager.putBytes(blockId, new ChunkedByteBuffer(dataRead.duplicate()), storageLevel, - encrypt = true) + blockManager.putBytes(blockId, new ChunkedByteBuffer(dataRead.duplicate()), storageLevel) logDebug(s"Stored partition data of $this into block manager with level $storageLevel") dataRead.rewind() } serializerManager 
         .dataDeserializeStream(
           blockId,
-          new ChunkedByteBuffer(dataRead).toInputStream(),
-          maybeEncrypted = false)(elementClassTag)
+          new ChunkedByteBuffer(dataRead).toInputStream())(elementClassTag)
         .asInstanceOf[Iterator[T]]
     }
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala
index 2b488038f0620..80c07958b41f2 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala
@@ -87,8 +87,7 @@ private[streaming] class BlockManagerBasedBlockHandler(
         putResult
       case ByteBufferBlock(byteBuffer) =>
         blockManager.putBytes(
-          blockId, new ChunkedByteBuffer(byteBuffer.duplicate()), storageLevel, tellMaster = true,
-          encrypt = true)
+          blockId, new ChunkedByteBuffer(byteBuffer.duplicate()), storageLevel, tellMaster = true)
       case o =>
         throw new SparkException(
           s"Could not store $blockId to block manager, unexpected block type ${o.getClass.getName}")
@@ -176,11 +175,10 @@ private[streaming] class WriteAheadLogBasedBlockHandler(
     val serializedBlock = block match {
       case ArrayBufferBlock(arrayBuffer) =>
         numRecords = Some(arrayBuffer.size.toLong)
-        serializerManager.dataSerialize(blockId, arrayBuffer.iterator, allowEncryption = false)
+        serializerManager.dataSerialize(blockId, arrayBuffer.iterator)
       case IteratorBlock(iterator) =>
         val countIterator = new CountingIterator(iterator)
-        val serializedBlock = serializerManager.dataSerialize(blockId, countIterator,
-          allowEncryption = false)
+        val serializedBlock = serializerManager.dataSerialize(blockId, countIterator)
         numRecords = countIterator.count
         serializedBlock
       case ByteBufferBlock(byteBuffer) =>
@@ -195,8 +193,7 @@ private[streaming] class WriteAheadLogBasedBlockHandler(
       blockId,
       serializedBlock,
       effectiveStorageLevel,
-      tellMaster = true,
-      encrypt = true)
+      tellMaster = true)
     if (!putSucceeded) {
       throw new SparkException(
         s"Could not store $blockId to block manager with storage level $storageLevel")
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
index c2b0389b8c6f0..3c4a2716caf90 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
@@ -175,8 +175,7 @@ abstract class BaseReceivedBlockHandlerSuite(enableEncryption: Boolean)
       reader.close()
       serializerManager.dataDeserializeStream(
         generateBlockId(),
-        new ChunkedByteBuffer(bytes).toInputStream(),
-        maybeEncrypted = false)(ClassTag.Any).toList
+        new ChunkedByteBuffer(bytes).toInputStream())(ClassTag.Any).toList
     }
     loggedData shouldEqual data
   }
@@ -357,7 +356,7 @@ abstract class BaseReceivedBlockHandlerSuite(enableEncryption: Boolean)
   }
 
   def dataToByteBuffer(b: Seq[String]) =
-    serializerManager.dataSerialize(generateBlockId, b.iterator, allowEncryption = false)
+    serializerManager.dataSerialize(generateBlockId, b.iterator)
 
   val blocks = data.grouped(10).toSeq
 
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala
index 2ac0dc96916c5..aa69be7ca9939 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala
@@ -250,8 +250,7 @@ class WriteAheadLogBackedBlockRDDSuite
     require(blockData.size === blockIds.size)
     val writer = new FileBasedWriteAheadLogWriter(new File(dir, "logFile").toString, hadoopConf)
     val segments = blockData.zip(blockIds).map { case (data, id) =>
-      writer.write(serializerManager.dataSerialize(id, data.iterator, allowEncryption = false)
-        .toByteBuffer)
+      writer.write(serializerManager.dataSerialize(id, data.iterator).toByteBuffer)
    }
     writer.close()
     segments
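
Note on the SPARK-20146 hunk above: a StructField carries its comment in the field's metadata, and the Thrift server's getTableSchema surfaces it on the Hive column descriptors, rendering an absent comment as an empty string (exactly what the new test asserts). Below is a minimal, self-contained sketch of the schema-side behavior only; the object name is hypothetical, and it uses just the public org.apache.spark.sql.types API rather than the Thrift server classes:

    import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

    // Hypothetical illustration object, not part of the patch.
    object FieldCommentSketch {
      def main(args: Array[String]): Unit = {
        // withComment stores the comment in the field's metadata;
        // the second field is left uncommented on purpose.
        val schema = StructType(Seq(
          StructField("column1", StringType).withComment("comment 1"),
          StructField("column2", IntegerType)))

        // getComment() returns Option[String]: Some(...) when a comment was
        // set, None otherwise; the Thrift server maps None to "".
        assert(schema("column1").getComment().contains("comment 1"))
        assert(schema("column2").getComment().isEmpty)
      }
    }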