diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index b5cb3f0a0f9dc..c1a91c27eef2d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -24,7 +24,7 @@ import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo} import org.apache.spark.deploy.master.DriverState.DriverState import org.apache.spark.deploy.master.RecoveryState.MasterState import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner} -import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef} import org.apache.spark.util.Utils private[deploy] sealed trait DeployMessage extends Serializable @@ -34,6 +34,16 @@ private[deploy] object DeployMessages { // Worker to Master + /** + * @param id the worker id + * @param host the worker host + * @param port the worker port + * @param worker the worker endpoint ref + * @param cores the number of cores of the worker + * @param memory the memory size of the worker + * @param workerWebUiUrl the worker Web UI address + * @param masterAddress the master address used by the worker to connect + */ case class RegisterWorker( id: String, host: String, @@ -41,7 +51,8 @@ private[deploy] object DeployMessages { worker: RpcEndpointRef, cores: Int, memory: Int, - workerWebUiUrl: String) + workerWebUiUrl: String, + masterAddress: RpcAddress) extends DeployMessage { Utils.checkHost(host) assert (port > 0) @@ -80,8 +91,16 @@ private[deploy] object DeployMessages { sealed trait RegisterWorkerResponse - case class RegisteredWorker(master: RpcEndpointRef, masterWebUiUrl: String) extends DeployMessage - with RegisterWorkerResponse + /** + * @param master the master ref + * @param masterWebUiUrl the master Web UI address + * @param masterAddress the master address used by the worker to connect. It should be + * [[RegisterWorker.masterAddress]].
+ */ + case class RegisteredWorker( + master: RpcEndpointRef, + masterWebUiUrl: String, + masterAddress: RpcAddress) extends DeployMessage with RegisterWorkerResponse case class RegisterWorkerFailed(message: String) extends DeployMessage with RegisterWorkerResponse diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index e061939623cbb..53384e7373252 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -231,7 +231,8 @@ private[deploy] class Master( logError("Leadership has been revoked -- master shutting down.") System.exit(0) - case RegisterWorker(id, workerHost, workerPort, workerRef, cores, memory, workerWebUiUrl) => + case RegisterWorker( + id, workerHost, workerPort, workerRef, cores, memory, workerWebUiUrl, masterAddress) => logInfo("Registering worker %s:%d with %d cores, %s RAM".format( workerHost, workerPort, cores, Utils.megabytesToString(memory))) if (state == RecoveryState.STANDBY) { @@ -243,7 +244,7 @@ private[deploy] class Master( workerRef, workerWebUiUrl) if (registerWorker(worker)) { persistenceEngine.addWorker(worker) - workerRef.send(RegisteredWorker(self, masterWebUiUrl)) + workerRef.send(RegisteredWorker(self, masterWebUiUrl, masterAddress)) schedule() } else { val workerAddress = worker.endpoint.address diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 34e3a4c020c80..1198e3cb05eaa 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -99,6 +99,20 @@ private[deploy] class Worker( private val testing: Boolean = sys.props.contains("spark.testing") private var master: Option[RpcEndpointRef] = None + + /** + * Whether to use the master address in `masterRpcAddresses` if possible. If it's disabled, Worker + * will just use the address received from Master. + */ + private val preferConfiguredMasterAddress = + conf.getBoolean("spark.worker.preferConfiguredMasterAddress", false) + /** + * The master address to reconnect to in case of failure. When the connection is broken, the + * worker will use this address to reconnect. This is usually just one of `masterRpcAddresses`. + * However, when a master is restarted or takes over leadership, it will be an address sent from + * the master, which may not be in `masterRpcAddresses`. + */ + private var masterAddressToConnect: Option[RpcAddress] = None private var activeMasterUrl: String = "" private[worker] var activeMasterWebUiUrl : String = "" private var workerWebUiUrl: String = "" @@ -196,10 +210,19 @@ private[deploy] class Worker( metricsSystem.getServletHandlers.foreach(webUi.attachHandler) } - private def changeMaster(masterRef: RpcEndpointRef, uiUrl: String) { + /** + * Change to use the new master. + * + * @param masterRef the new master ref + * @param uiUrl the new master Web UI address + * @param masterAddress the new master address which the worker should use to connect in case of + * failure + */ + private def changeMaster(masterRef: RpcEndpointRef, uiUrl: String, masterAddress: RpcAddress) { // activeMasterUrl is a valid Spark url since we receive it from master.
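The heart of SPARK-20529: the master echoes back the address the worker used (`masterAddress`), the worker remembers it (`masterAddressToConnect`), and `spark.worker.preferConfiguredMasterAddress` decides which address to dial on reconnect. Below is a minimal standalone sketch of that selection rule; `RpcAddress` and `EndpointRef` here are simplified stand-ins for Spark's private RPC types, and the proxy hostname is hypothetical.

```scala
object ReconnectAddressSketch extends App {
  // Simplified stand-ins for Spark's private RPC types.
  case class RpcAddress(host: String, port: Int) {
    def toSparkURL: String = s"spark://$host:$port"
  }
  case class EndpointRef(address: RpcAddress)

  // Mirrors the choice the worker now makes before re-registering: reuse the
  // address it originally connected with, unless the preference is disabled.
  def addressToReconnect(
      preferConfigured: Boolean,
      masterAddressToConnect: Option[RpcAddress],
      masterRef: EndpointRef): RpcAddress =
    if (preferConfigured) masterAddressToConnect.get else masterRef.address

  // A master reached through a proxy or DNS alias may report a different bind address.
  val configured = Some(RpcAddress("master-proxy.example.com", 7077))
  val ref = EndpointRef(RpcAddress("10.0.0.5", 7077))
  println(addressToReconnect(preferConfigured = true, configured, ref).toSparkURL)
  // spark://master-proxy.example.com:7077
  println(addressToReconnect(preferConfigured = false, configured, ref).toSparkURL)
  // spark://10.0.0.5:7077
}
```

Without the echoed address, a worker behind such a proxy would try to reconnect to the bind address and never reach the master again.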
activeMasterUrl = masterRef.address.toSparkURL activeMasterWebUiUrl = uiUrl + masterAddressToConnect = Some(masterAddress) master = Some(masterRef) connected = true if (conf.getBoolean("spark.ui.reverseProxy", false)) { @@ -266,7 +289,8 @@ private[deploy] class Worker( if (registerMasterFutures != null) { registerMasterFutures.foreach(_.cancel(true)) } - val masterAddress = masterRef.address + val masterAddress = + if (preferConfiguredMasterAddress) masterAddressToConnect.get else masterRef.address registerMasterFutures = Array(registerMasterThreadPool.submit(new Runnable { override def run(): Unit = { try { @@ -342,15 +366,27 @@ private[deploy] class Worker( } private def sendRegisterMessageToMaster(masterEndpoint: RpcEndpointRef): Unit = { - masterEndpoint.send(RegisterWorker(workerId, host, port, self, cores, memory, workerWebUiUrl)) + masterEndpoint.send(RegisterWorker( + workerId, + host, + port, + self, + cores, + memory, + workerWebUiUrl, + masterEndpoint.address)) } private def handleRegisterResponse(msg: RegisterWorkerResponse): Unit = synchronized { msg match { - case RegisteredWorker(masterRef, masterWebUiUrl) => - logInfo("Successfully registered with master " + masterRef.address.toSparkURL) + case RegisteredWorker(masterRef, masterWebUiUrl, masterAddress) => + if (preferConfiguredMasterAddress) { + logInfo("Successfully registered with master " + masterAddress.toSparkURL) + } else { + logInfo("Successfully registered with master " + masterRef.address.toSparkURL) + } registered = true - changeMaster(masterRef, masterWebUiUrl) + changeMaster(masterRef, masterWebUiUrl, masterAddress) forwordMessageScheduler.scheduleAtFixedRate(new Runnable { override def run(): Unit = Utils.tryLogNonFatalError { self.send(SendHeartbeat) @@ -419,7 +455,7 @@ private[deploy] class Worker( case MasterChanged(masterRef, masterWebUiUrl) => logInfo("Master has changed, new master is at " + masterRef.address.toSparkURL) - changeMaster(masterRef, masterWebUiUrl) + changeMaster(masterRef, masterWebUiUrl, masterRef.address) val execs = executors.values. 
map(e => new ExecutorDescription(e.appId, e.execId, e.cores, e.state)) @@ -561,7 +597,8 @@ private[deploy] class Worker( } override def onDisconnected(remoteAddress: RpcAddress): Unit = { - if (master.exists(_.address == remoteAddress)) { + if (master.exists(_.address == remoteAddress) || + masterAddressToConnect.exists(_ == remoteAddress)) { logInfo(s"$remoteAddress Disassociated !") masterDisconnected() } diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index 2127da48ece49..539264652d7d5 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -34,7 +34,7 @@ import other.supplier.{CustomPersistenceEngine, CustomRecoveryModeFactory} import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.deploy._ import org.apache.spark.deploy.DeployMessages._ -import org.apache.spark.rpc.{RpcEndpoint, RpcEnv} +import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, RpcEnv} class MasterSuite extends SparkFunSuite with Matchers with Eventually with PrivateMethodTester with BeforeAndAfter { @@ -447,8 +447,15 @@ class MasterSuite extends SparkFunSuite } }) - master.self.send( - RegisterWorker("1", "localhost", 9999, fakeWorker, 10, 1024, "http://localhost:8080")) + master.self.send(RegisterWorker( + "1", + "localhost", + 9999, + fakeWorker, + 10, + 1024, + "http://localhost:8080", + RpcAddress("localhost", 9999))) val executors = (0 until 3).map { i => new ExecutorDescription(appId = i.toString, execId = i, 2, ExecutorState.RUNNING) } @@ -459,4 +466,37 @@ class MasterSuite extends SparkFunSuite assert(killedDrivers.asScala.toList.sorted === List("0", "1", "2")) } } + + test("SPARK-20529: Master should reply the address received from worker") { + val master = makeMaster() + master.rpcEnv.setupEndpoint(Master.ENDPOINT_NAME, master) + eventually(timeout(10.seconds)) { + val masterState = master.self.askSync[MasterStateResponse](RequestMasterState) + assert(masterState.status === RecoveryState.ALIVE, "Master is not alive") + } + + @volatile var receivedMasterAddress: RpcAddress = null + val fakeWorker = master.rpcEnv.setupEndpoint("worker", new RpcEndpoint { + override val rpcEnv: RpcEnv = master.rpcEnv + + override def receive: PartialFunction[Any, Unit] = { + case RegisteredWorker(_, _, masterAddress) => + receivedMasterAddress = masterAddress + } + }) + + master.self.send(RegisterWorker( + "1", + "localhost", + 9999, + fakeWorker, + 10, + 1024, + "http://localhost:8080", + RpcAddress("localhost2", 10000))) + + eventually(timeout(10.seconds)) { + assert(receivedMasterAddress === RpcAddress("localhost2", 10000)) + } + } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java index 81970b7c81f40..60ef03d89d17b 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java @@ -113,7 +113,15 @@ public static void main(String[] args) { .setPredictionCol("prediction"); Double rmse = evaluator.evaluate(predictions); System.out.println("Root-mean-square error = " + rmse); + + // Generate top 10 movie recommendations for each user + Dataset<Row> userRecs = model.recommendForAllUsers(10); + // Generate top 10 user recommendations for each movie + Dataset<Row> movieRecs =
model.recommendForAllItems(10); // $example off$ + userRecs.show(); + movieRecs.show(); + spark.stop(); } } diff --git a/examples/src/main/python/ml/als_example.py b/examples/src/main/python/ml/als_example.py index 2e7214ed56f98..1672d552eb1d5 100644 --- a/examples/src/main/python/ml/als_example.py +++ b/examples/src/main/python/ml/als_example.py @@ -55,5 +55,13 @@ predictionCol="prediction") rmse = evaluator.evaluate(predictions) print("Root-mean-square error = " + str(rmse)) + + # Generate top 10 movie recommendations for each user + userRecs = model.recommendForAllUsers(10) + # Generate top 10 user recommendations for each movie + movieRecs = model.recommendForAllItems(10) # $example off$ + userRecs.show() + movieRecs.show() + spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala index 868f49b16f218..07b15dfa178f7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala @@ -75,7 +75,14 @@ object ALSExample { .setPredictionCol("prediction") val rmse = evaluator.evaluate(predictions) println(s"Root-mean-square error = $rmse") + + // Generate top 10 movie recommendations for each user + val userRecs = model.recommendForAllUsers(10) + // Generate top 10 user recommendations for each movie + val movieRecs = model.recommendForAllItems(10) // $example off$ + userRecs.show() + movieRecs.show() spark.stop() } diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala index f31ebf1ec8da0..88b294246bb30 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala @@ -21,7 +21,7 @@ import scala.collection.JavaConverters._ import scala.reflect.ClassTag import scala.util.control.NonFatal -import com.amazonaws.auth.{AWSCredentials, DefaultAWSCredentialsProviderChain} +import com.amazonaws.auth.AWSCredentials import com.amazonaws.services.kinesis.AmazonKinesisClient import com.amazonaws.services.kinesis.clientlibrary.types.UserRecord import com.amazonaws.services.kinesis.model._ @@ -81,9 +81,9 @@ class KinesisBackedBlockRDD[T: ClassTag]( @transient private val _blockIds: Array[BlockId], @transient val arrayOfseqNumberRanges: Array[SequenceNumberRanges], @transient private val isBlockIdValid: Array[Boolean] = Array.empty, - val retryTimeoutMs: Int = 10000, val messageHandler: Record => T = KinesisInputDStream.defaultMessageHandler _, - val kinesisCreds: SparkAWSCredentials = DefaultCredentials + val kinesisCreds: SparkAWSCredentials = DefaultCredentials, + val kinesisReadConfigs: KinesisReadConfigurations = KinesisReadConfigurations() ) extends BlockRDD[T](sc, _blockIds) { require(_blockIds.length == arrayOfseqNumberRanges.length, @@ -112,7 +112,7 @@ class KinesisBackedBlockRDD[T: ClassTag]( val credentials = kinesisCreds.provider.getCredentials partition.seqNumberRanges.ranges.iterator.flatMap { range => new KinesisSequenceRangeIterator(credentials, endpointUrl, regionName, - range, retryTimeoutMs).map(messageHandler) + range, kinesisReadConfigs).map(messageHandler) } } if (partition.isBlockIdValid) { @@ -135,7 +135,7 @@ class KinesisSequenceRangeIterator( endpointUrl: String, regionId: 
String, range: SequenceNumberRange, - retryTimeoutMs: Int) extends NextIterator[Record] with Logging { + kinesisReadConfigs: KinesisReadConfigurations) extends NextIterator[Record] with Logging { private val client = new AmazonKinesisClient(credentials) private val streamName = range.streamName @@ -251,21 +251,19 @@ class KinesisSequenceRangeIterator( /** Helper method to retry Kinesis API request with exponential backoff and timeouts */ private def retryOrTimeout[T](message: String)(body: => T): T = { - import KinesisSequenceRangeIterator._ - - var startTimeMs = System.currentTimeMillis() + val startTimeMs = System.currentTimeMillis() var retryCount = 0 - var waitTimeMs = MIN_RETRY_WAIT_TIME_MS var result: Option[T] = None var lastError: Throwable = null + var waitTimeInterval = kinesisReadConfigs.retryWaitTimeMs - def isTimedOut = (System.currentTimeMillis() - startTimeMs) >= retryTimeoutMs - def isMaxRetryDone = retryCount >= MAX_RETRIES + def isTimedOut = (System.currentTimeMillis() - startTimeMs) >= kinesisReadConfigs.retryTimeoutMs + def isMaxRetryDone = retryCount >= kinesisReadConfigs.maxRetries while (result.isEmpty && !isTimedOut && !isMaxRetryDone) { if (retryCount > 0) { // wait only if this is a retry - Thread.sleep(waitTimeMs) - waitTimeMs *= 2 // if you have waited, then double wait time for next round + Thread.sleep(waitTimeInterval) + waitTimeInterval *= 2 // if you have waited, then double wait time for next round } try { result = Some(body) @@ -284,7 +282,8 @@ class KinesisSequenceRangeIterator( result.getOrElse { if (isTimedOut) { throw new SparkException( - s"Timed out after $retryTimeoutMs ms while $message, last exception: ", lastError) + s"Timed out after ${kinesisReadConfigs.retryTimeoutMs} ms while " + + s"$message, last exception: ", lastError) } else { throw new SparkException( s"Gave up after $retryCount retries while $message, last exception: ", lastError) @@ -292,9 +291,3 @@ class KinesisSequenceRangeIterator( } } } - -private[streaming] -object KinesisSequenceRangeIterator { - val MAX_RETRIES = 3 - val MIN_RETRY_WAIT_TIME_MS = 100 -} diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala index 77553412eda56..decfb6b3ebd31 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala @@ -21,6 +21,7 @@ import scala.reflect.ClassTag import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream import com.amazonaws.services.kinesis.model.Record +import KinesisReadConfigurations._ import org.apache.spark.annotation.InterfaceStability import org.apache.spark.rdd.RDD @@ -60,12 +61,13 @@ private[kinesis] class KinesisInputDStream[T: ClassTag]( val isBlockIdValid = blockInfos.map { _.isBlockIdValid() }.toArray logDebug(s"Creating KinesisBackedBlockRDD for $time with ${seqNumRanges.length} " + s"seq number ranges: ${seqNumRanges.mkString(", ")} ") + new KinesisBackedBlockRDD( context.sc, regionName, endpointUrl, blockIds, seqNumRanges, isBlockIdValid = isBlockIdValid, - retryTimeoutMs = ssc.graph.batchDuration.milliseconds.toInt, messageHandler = messageHandler, - kinesisCreds = kinesisCreds) + kinesisCreds = kinesisCreds, + kinesisReadConfigs = KinesisReadConfigurations(ssc)) } else { logWarning("Kinesis sequence number information was not 
present with some block metadata," + " it may not be possible to recover from failures") diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReadConfigurations.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReadConfigurations.scala new file mode 100644 index 0000000000000..871071e4677e3 --- /dev/null +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReadConfigurations.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kinesis + +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.streaming.StreamingContext + +/** + * Configurations to pass to the [[KinesisBackedBlockRDD]]. + * + * @param maxRetries The maximum number of attempts to be made to Kinesis. Defaults to 3. + * @param retryWaitTimeMs The interval between consecutive Kinesis retries. + * Defaults to 100ms. + * @param retryTimeoutMs The timeout in milliseconds for a Kinesis request. + * Defaults to the batch duration provided for streaming, + * else uses 10000 if invoked directly. + */ +private[kinesis] case class KinesisReadConfigurations( + maxRetries: Int, + retryWaitTimeMs: Long, + retryTimeoutMs: Long) + +private[kinesis] object KinesisReadConfigurations { + def apply(): KinesisReadConfigurations = { + KinesisReadConfigurations(maxRetries = DEFAULT_MAX_RETRIES, + retryWaitTimeMs = JavaUtils.timeStringAsMs(DEFAULT_RETRY_WAIT_TIME), + retryTimeoutMs = DEFAULT_RETRY_TIMEOUT) + } + + def apply(ssc: StreamingContext): KinesisReadConfigurations = { + KinesisReadConfigurations( + maxRetries = ssc.sc.getConf.getInt(RETRY_MAX_ATTEMPTS_KEY, DEFAULT_MAX_RETRIES), + retryWaitTimeMs = JavaUtils.timeStringAsMs( + ssc.sc.getConf.get(RETRY_WAIT_TIME_KEY, DEFAULT_RETRY_WAIT_TIME)), + retryTimeoutMs = ssc.graph.batchDuration.milliseconds) + } + + /** + * SparkConf key for configuring the maximum number of retries used when attempting a Kinesis + * request. + */ + val RETRY_MAX_ATTEMPTS_KEY = "spark.streaming.kinesis.retry.maxAttempts" + + /** + * SparkConf key for configuring the wait time to use before retrying a Kinesis attempt.
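For context, the retry loop these three settings feed (the rewritten `KinesisSequenceRangeIterator.retryOrTimeout` above) follows a standard exponential-backoff-with-deadline shape. A self-contained sketch, with `ReadConf` as a stand-in for `KinesisReadConfigurations` rather than Spark's implementation:

```scala
import scala.util.control.NonFatal

object RetrySketch extends App {
  // Stand-in mirroring KinesisReadConfigurations' three fields.
  case class ReadConf(maxRetries: Int = 3, retryWaitTimeMs: Long = 100L, retryTimeoutMs: Long = 10000L)

  def retryOrTimeout[T](conf: ReadConf, message: String)(body: => T): T = {
    val startTimeMs = System.currentTimeMillis()
    var retryCount = 0
    var waitTimeMs = conf.retryWaitTimeMs
    var result: Option[T] = None
    var lastError: Throwable = null

    def isTimedOut = System.currentTimeMillis() - startTimeMs >= conf.retryTimeoutMs
    def isMaxRetryDone = retryCount >= conf.maxRetries

    while (result.isEmpty && !isTimedOut && !isMaxRetryDone) {
      if (retryCount > 0) { // wait only if this is a retry
        Thread.sleep(waitTimeMs)
        waitTimeMs *= 2 // double the wait time for the next round
      }
      try {
        result = Some(body)
      } catch {
        case NonFatal(e) =>
          lastError = e
          retryCount += 1
      }
    }
    result.getOrElse(
      throw new RuntimeException(s"Gave up after $retryCount retries while $message", lastError))
  }

  var attempts = 0
  val value = retryOrTimeout(ReadConf(), "fetching records") {
    attempts += 1
    if (attempts < 3) throw new RuntimeException("transient error") else "ok"
  }
  println(s"got '$value' after $attempts attempts")
}
```

The commit's change is essentially to source `maxRetries` and the initial wait from SparkConf keys instead of the hard-coded constants it deletes.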
+ */ + val RETRY_WAIT_TIME_KEY = "spark.streaming.kinesis.retry.waitTime" + + /** + * Default value for the RETRY_MAX_ATTEMPTS_KEY + */ + val DEFAULT_MAX_RETRIES = 3 + + /** + * Default value for the RETRY_WAIT_TIME_KEY + */ + val DEFAULT_RETRY_WAIT_TIME = "100ms" + + /** + * Default value for the retry timeout + */ + val DEFAULT_RETRY_TIMEOUT = 10000 +} diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index 341a6898cbbff..7e5bda923f63e 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -34,6 +34,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.storage.{StorageLevel, StreamBlockId} import org.apache.spark.streaming._ import org.apache.spark.streaming.dstream.ReceiverInputDStream +import org.apache.spark.streaming.kinesis.KinesisReadConfigurations._ import org.apache.spark.streaming.kinesis.KinesisTestUtils._ import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult import org.apache.spark.streaming.scheduler.ReceivedBlockInfo @@ -136,7 +137,7 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun val kinesisRDD = nonEmptyRDD.asInstanceOf[KinesisBackedBlockRDD[_]] assert(kinesisRDD.regionName === dummyRegionName) assert(kinesisRDD.endpointUrl === dummyEndpointUrl) - assert(kinesisRDD.retryTimeoutMs === batchDuration.milliseconds) + assert(kinesisRDD.kinesisReadConfigs.retryTimeoutMs === batchDuration.milliseconds) assert(kinesisRDD.kinesisCreds === BasicCredentials( awsAccessKeyId = dummyAWSAccessKey, awsSecretKey = dummyAWSSecretKey)) @@ -234,6 +235,52 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun ssc.stop(stopSparkContext = false) } + test("Kinesis read with custom configurations") { + try { + ssc.sc.conf.set(RETRY_WAIT_TIME_KEY, "2000ms") + ssc.sc.conf.set(RETRY_MAX_ATTEMPTS_KEY, "5") + + val kinesisStream = KinesisInputDStream.builder.streamingContext(ssc) + .checkpointAppName(appName) + .streamName("dummyStream") + .endpointUrl(dummyEndpointUrl) + .regionName(dummyRegionName) + .initialPositionInStream(InitialPositionInStream.LATEST) + .checkpointInterval(Seconds(10)) + .storageLevel(StorageLevel.MEMORY_ONLY) + .build() + .asInstanceOf[KinesisInputDStream[Array[Byte]]] + + val time = Time(1000) + // Generate block info data for testing + val seqNumRanges1 = SequenceNumberRanges( + SequenceNumberRange("fakeStream", "fakeShardId", "xxx", "yyy", 67)) + val blockId1 = StreamBlockId(kinesisStream.id, 123) + val blockInfo1 = ReceivedBlockInfo( + 0, None, Some(seqNumRanges1), new BlockManagerBasedStoreResult(blockId1, None)) + + val seqNumRanges2 = SequenceNumberRanges( + SequenceNumberRange("fakeStream", "fakeShardId", "aaa", "bbb", 89)) + val blockId2 = StreamBlockId(kinesisStream.id, 345) + val blockInfo2 = ReceivedBlockInfo( + 0, None, Some(seqNumRanges2), new BlockManagerBasedStoreResult(blockId2, None)) + + // Verify that the generated KinesisBackedBlockRDD has the all the right information + val blockInfos = Seq(blockInfo1, blockInfo2) + + val kinesisRDD = + kinesisStream.createBlockRDD(time, blockInfos).asInstanceOf[KinesisBackedBlockRDD[_]] + + assert(kinesisRDD.kinesisReadConfigs.retryWaitTimeMs === 2000) + assert(kinesisRDD.kinesisReadConfigs.maxRetries === 5) + 
assert(kinesisRDD.kinesisReadConfigs.retryTimeoutMs === batchDuration.milliseconds) + } finally { + ssc.sc.conf.remove(RETRY_WAIT_TIME_KEY) + ssc.sc.conf.remove(RETRY_MAX_ATTEMPTS_KEY) + ssc.stop(stopSparkContext = false) + } + } + testIfEnabled("split and merge shards in a stream") { // Since this test tries to split and merge shards in a stream, we create another // temporary stream and then remove it when finished. diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala index ef3890962494d..2a0f8c11d0a50 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala @@ -29,7 +29,7 @@ private[spark] object BLAS extends Serializable { @transient private var _nativeBLAS: NetlibBLAS = _ // For level-1 routines, we use Java implementation. - private def f2jBLAS: NetlibBLAS = { + private[ml] def f2jBLAS: NetlibBLAS = { if (_f2jBLAS == null) { _f2jBLAS = new F2jBLAS } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala index 7507c7539d4ef..9900fbc9edda7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala @@ -51,6 +51,7 @@ private[classification] trait LinearSVCParams extends ClassifierParams with HasR * Linear SVM Classifier * * This binary classifier optimizes the Hinge Loss using the OWLQN optimizer. + * Only supports L2 regularization currently. * */ @Since("2.2.0") @@ -148,7 +149,7 @@ class LinearSVC @Since("2.2.0") ( @Since("2.2.0") override def copy(extra: ParamMap): LinearSVC = defaultCopy(extra) - override protected[classification] def train(dataset: Dataset[_]): LinearSVCModel = { + override protected def train(dataset: Dataset[_]): LinearSVCModel = { val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map { @@ -264,7 +265,7 @@ object LinearSVC extends DefaultParamsReadable[LinearSVC] { /** * :: Experimental :: - * SVM Model trained by [[LinearSVC]] + * Linear SVM Model trained by [[LinearSVC]] */ @Since("2.2.0") @Experimental diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 053487242edd8..567af0488e1b4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -267,8 +267,12 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas } /** - * Logistic regression. Supports multinomial logistic (softmax) regression and binomial logistic - * regression. + * Logistic regression. Supports: + * - Multinomial logistic (softmax) regression. + * - Binomial logistic regression. + * + * This class supports fitting a traditional logistic regression model by LBFGS/OWLQN and + * a bound (box) constrained logistic regression model by LBFGSB.
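To make the new LinearSVC doc line concrete: with labels mapped to {-1, +1}, the optimizer minimizes an L2-regularized hinge loss. A toy evaluation of that objective follows; this is an illustrative sketch, not Spark's aggregator or OWLQN code path.

```scala
object HingeSketch extends App {
  def dot(w: Array[Double], x: Array[Double]): Double =
    w.zip(x).map { case (a, b) => a * b }.sum

  // J(w) = (1/n) * sum_i max(0, 1 - y_i * (w . x_i)) + (lambda / 2) * ||w||^2
  def objective(
      w: Array[Double],
      xs: Array[Array[Double]],
      ys: Array[Double],
      lambda: Double): Double = {
    val hinge = xs.zip(ys).map { case (x, y) =>
      math.max(0.0, 1.0 - y * dot(w, x))
    }.sum / xs.length
    hinge + (lambda / 2.0) * w.map(v => v * v).sum
  }

  val xs = Array(Array(1.0, 2.0), Array(-1.0, -1.5))
  val ys = Array(1.0, -1.0) // labels in {-1, +1}
  println(objective(Array(0.5, 0.5), xs, ys, lambda = 0.1))
}
```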
*/ @Since("1.2.0") class LogisticRegression @Since("1.2.0") ( diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala index a41bd8e689d56..9e023b9dd469b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala @@ -102,7 +102,8 @@ private[feature] trait ImputerParams extends Params with HasInputCols { * computing median, DataFrameStatFunctions.approxQuantile is used with a relative error of 0.001. */ @Experimental -class Imputer @Since("2.2.0")(override val uid: String) +@Since("2.2.0") +class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String) extends Estimator[ImputerModel] with ImputerParams with DefaultParamsWritable { @Since("2.2.0") @@ -165,8 +166,8 @@ class Imputer @Since("2.2.0")(override val uid: String) object Imputer extends DefaultParamsReadable[Imputer] { /** strategy names that Imputer currently supports. */ - private[ml] val mean = "mean" - private[ml] val median = "median" + private[feature] val mean = "mean" + private[feature] val median = "median" @Since("2.2.0") override def load(path: String): Imputer = super.load(path) @@ -180,9 +181,10 @@ object Imputer extends DefaultParamsReadable[Imputer] { * which are used to replace the missing values in the input DataFrame. */ @Experimental -class ImputerModel private[ml]( - override val uid: String, - val surrogateDF: DataFrame) +@Since("2.2.0") +class ImputerModel private[ml] ( + @Since("2.2.0") override val uid: String, + @Since("2.2.0") val surrogateDF: DataFrame) extends Model[ImputerModel] with ImputerParams with MLWritable { import ImputerModel._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index 12804d08a4bc6..aa7871d6ff29d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -200,7 +200,7 @@ object FPGrowth extends DefaultParamsReadable[FPGrowth] { @Experimental class FPGrowthModel private[ml] ( @Since("2.2.0") override val uid: String, - @transient val freqItemsets: DataFrame) + @Since("2.2.0") @transient val freqItemsets: DataFrame) extends Model[FPGrowthModel] with FPGrowthParams with MLWritable { /** @group setParam */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index d626f04599670..0955d3e6e1f8f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -35,6 +35,7 @@ import org.apache.spark.{Dependency, Partitioner, ShuffleDependency, SparkContex import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.internal.Logging import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.linalg.BLAS import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ @@ -363,7 +364,7 @@ class ALSModel private[ml] ( * relatively efficient, the approach implemented here is significantly more efficient. * * This approach groups factors into blocks and computes the top-k elements per block, - * using a simple dot product (instead of gemm) and an efficient [[BoundedPriorityQueue]]. + * using dot product and an efficient [[BoundedPriorityQueue]] (instead of gemm). 
* It then computes the global top-k by aggregating the per block top-k elements with * a [[TopByKeyAggregator]]. This significantly reduces the size of intermediate and shuffle data. * This is the DataFrame equivalent to the approach used in @@ -393,31 +394,18 @@ class ALSModel private[ml] ( val m = srcIter.size val n = math.min(dstIter.size, num) val output = new Array[(Int, Int, Float)](m * n) - var j = 0 + var i = 0 val pq = new BoundedPriorityQueue[(Int, Float)](num)(Ordering.by(_._2)) srcIter.foreach { case (srcId, srcFactor) => dstIter.foreach { case (dstId, dstFactor) => - /* - * The below code is equivalent to - * `val score = blas.sdot(rank, srcFactor, 1, dstFactor, 1)` - * This handwritten version is as or more efficient as BLAS calls in this case. - */ - var score = 0.0f - var k = 0 - while (k < rank) { - score += srcFactor(k) * dstFactor(k) - k += 1 - } + // We use F2jBLAS which is faster than a call to native BLAS for vector dot product + val score = BLAS.f2jBLAS.sdot(rank, srcFactor, 1, dstFactor, 1) pq += dstId -> score } - val pqIter = pq.iterator - var i = 0 - while (i < n) { - val (dstId, score) = pqIter.next() - output(j + i) = (srcId, dstId, score) + pq.foreach { case (dstId, score) => + output(i) = (srcId, dstId, score) i += 1 } - j += n pq.clear() } output.toSeq diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala index e185bc8a6faaa..6e885d7c8aec5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.types.{StructField, StructType} /** - * API for correlation functions in MLlib, compatible with Dataframes and Datasets. + * API for correlation functions in MLlib, compatible with DataFrames and Datasets. * * The functions in this package generalize the functions in [[org.apache.spark.sql.Dataset#stat]] * to spark.ml's Vector types. diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index cd1950bd76c05..3fc3ac58b7795 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -110,77 +110,77 @@ private[ml] trait DecisionTreeParams extends PredictorParams maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10) /** - * @deprecated This method is deprecated and will be removed in 2.2.0. + * @deprecated This method is deprecated and will be removed in 3.0.0. * @group setParam */ - @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") + @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") def setMaxDepth(value: Int): this.type = set(maxDepth, value) /** @group getParam */ final def getMaxDepth: Int = $(maxDepth) /** - * @deprecated This method is deprecated and will be removed in 2.2.0. + * @deprecated This method is deprecated and will be removed in 3.0.0. * @group setParam */ - @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") + @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") def setMaxBins(value: Int): this.type = set(maxBins, value) /** @group getParam */ final def getMaxBins: Int = $(maxBins) /** - * @deprecated This method is deprecated and will be removed in 2.2.0. 
+ * @deprecated This method is deprecated and will be removed in 3.0.0. * @group setParam */ - @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") + @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) /** @group getParam */ final def getMinInstancesPerNode: Int = $(minInstancesPerNode) /** - * @deprecated This method is deprecated and will be removed in 2.2.0. + * @deprecated This method is deprecated and will be removed in 3.0.0. * @group setParam */ - @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") + @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) /** @group getParam */ final def getMinInfoGain: Double = $(minInfoGain) /** - * @deprecated This method is deprecated and will be removed in 2.2.0. + * @deprecated This method is deprecated and will be removed in 3.0.0. * @group setParam */ - @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") + @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") def setSeed(value: Long): this.type = set(seed, value) /** - * @deprecated This method is deprecated and will be removed in 2.2.0. + * @deprecated This method is deprecated and will be removed in 3.0.0. * @group expertSetParam */ - @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") + @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) /** @group expertGetParam */ final def getMaxMemoryInMB: Int = $(maxMemoryInMB) /** - * @deprecated This method is deprecated and will be removed in 2.2.0. + * @deprecated This method is deprecated and will be removed in 3.0.0. * @group expertSetParam */ - @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") + @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) /** @group expertGetParam */ final def getCacheNodeIds: Boolean = $(cacheNodeIds) /** - * @deprecated This method is deprecated and will be removed in 2.2.0. + * @deprecated This method is deprecated and will be removed in 3.0.0. * @group setParam */ - @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") + @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) /** (private[ml]) Create a Strategy instance to use with the old API. */ @@ -226,10 +226,10 @@ private[ml] trait TreeClassifierParams extends Params { setDefault(impurity -> "gini") /** - * @deprecated This method is deprecated and will be removed in 2.2.0. + * @deprecated This method is deprecated and will be removed in 3.0.0. * @group setParam */ - @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") + @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") def setImpurity(value: String): this.type = set(impurity, value) /** @group getParam */ @@ -277,10 +277,10 @@ private[ml] trait TreeRegressorParams extends Params { setDefault(impurity -> "variance") /** - * @deprecated This method is deprecated and will be removed in 2.2.0. 
+ * @deprecated This method is deprecated and will be removed in 3.0.0. * @group setParam */ - @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") + @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") def setImpurity(value: String): this.type = set(impurity, value) /** @group getParam */ @@ -339,10 +339,10 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams { setDefault(subsamplingRate -> 1.0) /** - * @deprecated This method is deprecated and will be removed in 2.2.0. + * @deprecated This method is deprecated and will be removed in 3.0.0. * @group setParam */ - @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") + @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) /** @group getParam */ @@ -383,10 +383,10 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams { setDefault(numTrees -> 20) /** - * @deprecated This method is deprecated and will be removed in 2.2.0. + * @deprecated This method is deprecated and will be removed in 3.0.0. * @group setParam */ - @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") + @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") def setNumTrees(value: Int): this.type = set(numTrees, value) /** @group getParam */ @@ -431,10 +431,10 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams { setDefault(featureSubsetStrategy -> "auto") /** - * @deprecated This method is deprecated and will be removed in 2.2.0. + * @deprecated This method is deprecated and will be removed in 3.0.0. * @group setParam */ - @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") + @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) /** @group getParam */ @@ -472,10 +472,10 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter { // validationTol -> 1e-5 /** - * @deprecated This method is deprecated and will be removed in 2.2.0. + * @deprecated This method is deprecated and will be removed in 3.0.0. * @group setParam */ - @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") + @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") def setMaxIter(value: Int): this.type = set(maxIter, value) /** @@ -492,10 +492,10 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter { final def getStepSize: Double = $(stepSize) /** - * @deprecated This method is deprecated and will be removed in 2.2.0. + * @deprecated This method is deprecated and will be removed in 3.0.0. * @group setParam */ - @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") + @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") def setStepSize(value: Double): this.type = set(stepSize, value) setDefault(maxIter -> 20, stepSize -> 0.1) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index a8b80031faf86..b54e258cff2f8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -44,9 +44,11 @@ private[util] sealed trait BaseReadWrite { /** * Sets the Spark SQLContext to use for saving/loading. 
+ * + * @deprecated Use session instead. This method will be removed in 3.0.0. */ @Since("1.6.0") - @deprecated("Use session instead, This method will be removed in 2.2.0.", "2.0.0") + @deprecated("Use session instead. This method will be removed in 3.0.0.", "2.0.0") def context(sqlContext: SQLContext): this.type = { optionSparkSession = Option(sqlContext.sparkSession) this diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index 0cd68a633c0b5..cb97742245689 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -31,7 +31,7 @@ private[spark] object BLAS extends Serializable with Logging { @transient private var _nativeBLAS: NetlibBLAS = _ // For level-1 routines, we use Java implementation. - private def f2jBLAS: NetlibBLAS = { + private[mllib] def f2jBLAS: NetlibBLAS = { if (_f2jBLAS == null) { _f2jBLAS = new F2jBLAS } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index d45866c016d91..ac709ad72f0c0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -20,8 +20,6 @@ package org.apache.spark.mllib.recommendation import java.io.IOException import java.lang.{Integer => JavaInteger} -import scala.collection.mutable - import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.hadoop.fs.Path @@ -33,7 +31,7 @@ import org.apache.spark.SparkContext import org.apache.spark.annotation.Since import org.apache.spark.api.java.{JavaPairRDD, JavaRDD} import org.apache.spark.internal.Logging -import org.apache.spark.mllib.linalg._ +import org.apache.spark.mllib.linalg.BLAS import org.apache.spark.mllib.rdd.MLPairRDDFunctions._ import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD @@ -263,6 +261,19 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { /** * Makes recommendations for all users (or products). + * + * Note: the previous approach used for computing top-k recommendations aimed to group + * individual factor vectors into blocks, so that Level 3 BLAS operations (gemm) could + * be used for efficiency. However, this causes excessive GC pressure due to the large + * arrays required for intermediate result storage, as well as a high sensitivity to the + * block size used. + * + * The following approach still groups factors into blocks, but instead computes the + * top-k elements per block, using dot product and an efficient [[BoundedPriorityQueue]] + * (instead of gemm). This avoids any large intermediate data structures and results + * in significantly reduced GC pressure as well as shuffle data, which far outweighs + * any cost incurred from not using Level 3 BLAS operations. 
+ * * @param rank rank * @param srcFeatures src features to receive recommendations * @param dstFeatures dst features used to make recommendations @@ -277,46 +288,22 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { num: Int): RDD[(Int, Array[(Int, Double)])] = { val srcBlocks = blockify(srcFeatures) val dstBlocks = blockify(dstFeatures) - /** - * The previous approach used for computing top-k recommendations aimed to group - * individual factor vectors into blocks, so that Level 3 BLAS operations (gemm) could - * be used for efficiency. However, this causes excessive GC pressure due to the large - * arrays required for intermediate result storage, as well as a high sensitivity to the - * block size used. - * The following approach still groups factors into blocks, but instead computes the - * top-k elements per block, using a simple dot product (instead of gemm) and an efficient - * [[BoundedPriorityQueue]]. This avoids any large intermediate data structures and results - * in significantly reduced GC pressure as well as shuffle data, which far outweighs - * any cost incurred from not using Level 3 BLAS operations. - */ val ratings = srcBlocks.cartesian(dstBlocks).flatMap { case (srcIter, dstIter) => val m = srcIter.size val n = math.min(dstIter.size, num) val output = new Array[(Int, (Int, Double))](m * n) - var j = 0 + var i = 0 val pq = new BoundedPriorityQueue[(Int, Double)](n)(Ordering.by(_._2)) srcIter.foreach { case (srcId, srcFactor) => dstIter.foreach { case (dstId, dstFactor) => - /* - * The below code is equivalent to - * `val score = blas.ddot(rank, srcFactor, 1, dstFactor, 1)` - * This handwritten version is as or more efficient as BLAS calls in this case. - */ - var score: Double = 0 - var k = 0 - while (k < rank) { - score += srcFactor(k) * dstFactor(k) - k += 1 - } + // We use F2jBLAS which is faster than a call to native BLAS for vector dot product + val score = BLAS.f2jBLAS.ddot(rank, srcFactor, 1, dstFactor, 1) pq += dstId -> score } - val pqIter = pq.iterator - var i = 0 - while (i < n) { - output(j + i) = (srcId, pqIter.next()) + pq.foreach { case (dstId, score) => + output(i) = (srcId, (dstId, score)) i += 1 } - j += n pq.clear() } output.toSeq diff --git a/python/docs/pyspark.ml.rst b/python/docs/pyspark.ml.rst index 930646de9cd86..01627ba92b633 100644 --- a/python/docs/pyspark.ml.rst +++ b/python/docs/pyspark.ml.rst @@ -96,3 +96,11 @@ pyspark.ml.fpm module :members: :undoc-members: :inherited-members: + +pyspark.ml.util module +---------------------------- + +.. automodule:: pyspark.ml.util + :members: + :undoc-members: + :inherited-members: diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index dcc12d93e979f..60bdeedd6a144 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -70,6 +70,7 @@ class LinearSVC(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, Ha `Linear SVM Classifier `_ This binary classifier optimizes the Hinge Loss using the OWLQN optimizer. + Only supports L2 regularization currently. >>> from pyspark.sql import Row >>> from pyspark.ml.linalg import Vectors diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 02016f172aebc..7863edda7e7ab 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -79,7 +79,8 @@ def overwrite(self): def context(self, sqlContext): """ Sets the SQL context to use for saving. - .. note:: Deprecated in 2.1 and will be removed in 2.2, use session instead. + + .. 
note:: Deprecated in 2.1 and will be removed in 3.0, use session instead. """ raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self)) @@ -113,9 +114,10 @@ def overwrite(self): def context(self, sqlContext): """ Sets the SQL context to use for saving. - .. note:: Deprecated in 2.1 and will be removed in 2.2, use session instead. + + .. note:: Deprecated in 2.1 and will be removed in 3.0, use session instead. """ - warnings.warn("Deprecated in 2.1 and will be removed in 2.2, use session instead.") + warnings.warn("Deprecated in 2.1 and will be removed in 3.0, use session instead.") self._jwrite.context(sqlContext._ssql_ctx) return self @@ -168,7 +170,8 @@ def load(self, path): def context(self, sqlContext): """ Sets the SQL context to use for loading. - .. note:: Deprecated in 2.1 and will be removed in 2.2, use session instead. + + .. note:: Deprecated in 2.1 and will be removed in 3.0, use session instead. """ raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self)) @@ -200,9 +203,10 @@ def load(self, path): def context(self, sqlContext): """ Sets the SQL context to use for loading. - .. note:: Deprecated in 2.1 and will be removed in 2.2, use session instead. + + .. note:: Deprecated in 2.1 and will be removed in 3.0, use session instead. """ - warnings.warn("Deprecated in 2.1 and will be removed in 2.2, use session instead.") + warnings.warn("Deprecated in 2.1 and will be removed in 3.0, use session instead.") self._jread.context(sqlContext._ssql_ctx) return self diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala index 3ca9e6a8da5b5..1fc3a654cfeb1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala @@ -155,7 +155,7 @@ object ExternalCatalogUtils { }) inputPartitions.filter { p => - boundPredicate(p.toRow(partitionSchema, defaultTimeZoneId)) + boundPredicate.eval(p.toRow(partitionSchema, defaultTimeZoneId)) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 760ead42c762c..f8da78b5f5e3e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -27,7 +27,10 @@ import scala.language.existentials import scala.util.control.NonFatal import com.google.common.cache.{CacheBuilder, CacheLoader} -import org.codehaus.janino.{ByteArrayClassLoader, ClassBodyEvaluator, SimpleCompiler} +import com.google.common.util.concurrent.{ExecutionError, UncheckedExecutionException} +import org.apache.commons.lang3.exception.ExceptionUtils +import org.codehaus.commons.compiler.CompileException +import org.codehaus.janino.{ByteArrayClassLoader, ClassBodyEvaluator, JaninoRuntimeException, SimpleCompiler} import org.codehaus.janino.util.ClassFile import org.apache.spark.{SparkEnv, TaskContext, TaskKilledException} @@ -899,8 +902,14 @@ object CodeGenerator extends Logging { /** * Compile the Java source code into a Java class, using Janino. 
*/ - def compile(code: CodeAndComment): GeneratedClass = { + def compile(code: CodeAndComment): GeneratedClass = try { cache.get(code) + } catch { + // Cache.get() may wrap the original exception. See the following URL + // http://google.github.io/guava/releases/14.0/api/docs/com/google/common/cache/ + // Cache.html#get(K,%20java.util.concurrent.Callable) + case e @ (_: UncheckedExecutionException | _: ExecutionError) => + throw e.getCause } /** @@ -951,10 +960,14 @@ object CodeGenerator extends Logging { evaluator.cook("generated.java", code.body) recordCompilationStats(evaluator) } catch { - case e: Exception => + case e: JaninoRuntimeException => val msg = s"failed to compile: $e\n$formatted" logError(msg, e) - throw new Exception(msg, e) + throw new JaninoRuntimeException(msg, e) + case e: CompileException => + val msg = s"failed to compile: $e\n$formatted" + logError(msg, e) + throw new CompileException(msg, e.getLocation) } evaluator.getClazz().newInstance().asInstanceOf[GeneratedClass] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 5034566132f7a..c15ee2ab270bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -20,20 +20,22 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{Predicate => BasePredicate} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ object InterpretedPredicate { - def create(expression: Expression, inputSchema: Seq[Attribute]): (InternalRow => Boolean) = + def create(expression: Expression, inputSchema: Seq[Attribute]): InterpretedPredicate = create(BindReferences.bindReference(expression, inputSchema)) - def create(expression: Expression): (InternalRow => Boolean) = { - (r: InternalRow) => expression.eval(r).asInstanceOf[Boolean] - } + def create(expression: Expression): InterpretedPredicate = new InterpretedPredicate(expression) } +case class InterpretedPredicate(expression: Expression) extends BasePredicate { + override def eval(r: InternalRow): Boolean = expression.eval(r).asInstanceOf[Boolean] +} /** * An [[Expression]] that returns a boolean value. 
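The new catch in `compile` leans on documented Guava behavior: `LoadingCache.get` wraps an unchecked loader exception in `UncheckedExecutionException` (and an `Error` in `ExecutionError`), so the original failure has to be recovered via `getCause`. A small sketch of that round trip, assuming Guava is on the classpath:

```scala
import com.google.common.cache.{CacheBuilder, CacheLoader}
import com.google.common.util.concurrent.{ExecutionError, UncheckedExecutionException}

object CacheUnwrapSketch extends App {
  val cache = CacheBuilder.newBuilder().maximumSize(100).build(
    new CacheLoader[String, String] {
      override def load(key: String): String =
        throw new IllegalStateException(s"failed to compile: $key")
    })

  def compile(key: String): String =
    try {
      cache.get(key)
    } catch {
      // Unwrap so callers see the loader's original failure, not Guava's wrapper.
      case e @ (_: UncheckedExecutionException | _: ExecutionError) => throw e.getCause
    }

  try compile("bad-code") catch {
    case e: Throwable => println(s"${e.getClass.getSimpleName}: ${e.getMessage}")
  }
  // prints: IllegalStateException: failed to compile: bad-code
}
```

The unwrapping matters here because callers (such as `SparkPlan.newPredicate` below) match on the concrete compiler exception types to decide whether to fall back.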
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala index 8e2e973485e1c..09598ffe770c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala @@ -78,7 +78,7 @@ abstract class AbstractSqlParser extends ParserInterface with Logging { protected def astBuilder: AstBuilder protected def parse[T](command: String)(toResult: SqlBaseParser => T): T = { - logInfo(s"Parsing command: $command") + logDebug(s"Parsing command: $command") val lexer = new SqlBaseLexer(new ANTLRNoCaseStringStream(command)) lexer.removeErrorListeners() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index cadab37a449aa..c4ed96640eb19 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -22,6 +22,9 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, Da import scala.collection.mutable.ArrayBuffer import scala.concurrent.ExecutionContext +import org.codehaus.commons.compiler.CompileException +import org.codehaus.janino.JaninoRuntimeException + import org.apache.spark.{broadcast, SparkEnv} import org.apache.spark.internal.Logging import org.apache.spark.io.CompressionCodec @@ -353,9 +356,27 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ GenerateMutableProjection.generate(expressions, inputSchema, useSubexprElimination) } + private def genInterpretedPredicate( + expression: Expression, inputSchema: Seq[Attribute]): InterpretedPredicate = { + val str = expression.toString + val logMessage = if (str.length > 256) { + str.substring(0, 256 - 3) + "..." 
+ } else { + str + } + logWarning(s"Codegen disabled for this expression:\n $logMessage") + InterpretedPredicate.create(expression, inputSchema) + } + protected def newPredicate( expression: Expression, inputSchema: Seq[Attribute]): GenPredicate = { - GeneratePredicate.generate(expression, inputSchema) + try { + GeneratePredicate.generate(expression, inputSchema) + } catch { + case e @ (_: JaninoRuntimeException | _: CompileException) + if sqlContext == null || sqlContext.conf.wholeStageFallback => + genInterpretedPredicate(expression, inputSchema) + } } protected def newOrdering( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala index ffd7f6c750f85..6b6f6388d54e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala @@ -177,7 +177,7 @@ abstract class PartitioningAwareFileIndex( }) val selected = partitions.filter { - case PartitionPath(values, _) => boundPredicate(values) + case PartitionPath(values, _) => boundPredicate.eval(values) } logInfo { val total = partitions.length diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index ef0de6f6f4ff1..2f52192b54030 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1844,4 +1844,15 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { .filter($"x1".isNotNull || !$"y".isin("a!")) .count } + + test("SPARK-19372: Filter can be executed w/o generated code due to JVM code size limit") { + val N = 400 + val rows = Seq(Row.fromSeq(Seq.fill(N)("string"))) + val schema = StructType(Seq.tabulate(N)(i => StructField(s"_c$i", StringType))) + val df = spark.createDataFrame(spark.sparkContext.makeRDD(rows), schema) + + val filter = (0 until N) + .foldLeft(lit(false))((e, index) => e.or(df.col(df.columns(index)) =!= "string")) + df.filter(filter).count + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index 9f4009bfe402a..60a4638f610b3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -103,7 +103,7 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister { // `Cast`ed values are always of internal types (e.g. UTF8String instead of String) Cast(Literal(value), dataType).eval() }) - }.filter(predicate).map(projection) + }.filter(predicate.eval).map(projection) // Appends partition values val fullOutput = requiredSchema.toAttributes ++ partitionSchema.toAttributes
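Taken together, `SparkPlan.newPredicate` plus the new `InterpretedPredicate` amount to a compile-then-fall-back strategy: try codegen, and on a compiler error evaluate the same expression interpretively. The sketch below models that with simplified stand-in types; the 64KB figure echoes the JVM method-size limit the SPARK-19372 test above exercises.

```scala
object FallbackSketch extends App {
  type Row = Map[String, Any]

  trait Predicate { def eval(r: Row): Boolean }

  // Stand-in for GeneratePredicate.generate: pretend compilation fails once the
  // generated method would exceed the JVM's 64KB bytecode limit.
  def compiledPredicate(exprSize: Int, f: Row => Boolean): Predicate =
    if (exprSize > 1000) throw new RuntimeException("Code of method grows beyond 64 KB")
    else new Predicate { def eval(r: Row): Boolean = f(r) }

  // Stand-in for InterpretedPredicate: same interface, no codegen involved.
  def interpretedPredicate(f: Row => Boolean): Predicate =
    new Predicate { def eval(r: Row): Boolean = f(r) }

  def newPredicate(exprSize: Int, f: Row => Boolean): Predicate =
    try {
      compiledPredicate(exprSize, f)
    } catch {
      case e: RuntimeException =>
        println(s"Codegen disabled for this expression: ${e.getMessage}")
        interpretedPredicate(f)
    }

  val pred = newPredicate(4000, (r: Row) => r.get("_c0").contains("string"))
  println(pred.eval(Map("_c0" -> "string"))) // true, via the interpreted path
}
```

Turning `InterpretedPredicate` from a bare `InternalRow => Boolean` function into a class with an `eval` method is what lets callers such as `ExternalCatalogUtils` and `PartitioningAwareFileIndex` treat compiled and interpreted predicates uniformly.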