Skip to content
Merged
27 changes: 23 additions & 4 deletions core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo}
import org.apache.spark.deploy.master.DriverState.DriverState
import org.apache.spark.deploy.master.RecoveryState.MasterState
import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner}
import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef}
import org.apache.spark.util.Utils

private[deploy] sealed trait DeployMessage extends Serializable
Expand All @@ -34,14 +34,25 @@ private[deploy] object DeployMessages {

// Worker to Master

/**
* @param id the worker id
* @param host the worker host
* @param port the worker port
* @param worker the worker endpoint ref
* @param cores the number of cores of the worker
* @param memory the memory size of the worker
* @param workerWebUiUrl the worker Web UI address
* @param masterAddress the master address used by the worker to connect
*/
case class RegisterWorker(
id: String,
host: String,
port: Int,
worker: RpcEndpointRef,
cores: Int,
memory: Int,
workerWebUiUrl: String)
workerWebUiUrl: String,
masterAddress: RpcAddress)
extends DeployMessage {
Utils.checkHost(host)
assert (port > 0)
Expand Down Expand Up @@ -80,8 +91,16 @@ private[deploy] object DeployMessages {

/**
 * Marker trait for the master's replies to a [[RegisterWorker]] request
 * (sealed so matches on responses are exhaustiveness-checked).
 */
sealed trait RegisterWorkerResponse

case class RegisteredWorker(master: RpcEndpointRef, masterWebUiUrl: String) extends DeployMessage
with RegisterWorkerResponse
/**
* @param master the master ref
* @param masterWebUiUrl the master Web UI address
* @param masterAddress the master address used by the worker to connect. It should be
* [[RegisterWorker.masterAddress]].
*/
case class RegisteredWorker(
master: RpcEndpointRef,
masterWebUiUrl: String,
masterAddress: RpcAddress) extends DeployMessage with RegisterWorkerResponse

/**
 * Reply from the master when worker registration did not succeed.
 *
 * @param message an explanatory message for the failure
 */
case class RegisterWorkerFailed(message: String) extends DeployMessage with RegisterWorkerResponse

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,8 @@ private[deploy] class Master(
logError("Leadership has been revoked -- master shutting down.")
System.exit(0)

case RegisterWorker(id, workerHost, workerPort, workerRef, cores, memory, workerWebUiUrl) =>
case RegisterWorker(
id, workerHost, workerPort, workerRef, cores, memory, workerWebUiUrl, masterAddress) =>
logInfo("Registering worker %s:%d with %d cores, %s RAM".format(
workerHost, workerPort, cores, Utils.megabytesToString(memory)))
if (state == RecoveryState.STANDBY) {
Expand All @@ -243,7 +244,7 @@ private[deploy] class Master(
workerRef, workerWebUiUrl)
if (registerWorker(worker)) {
persistenceEngine.addWorker(worker)
workerRef.send(RegisteredWorker(self, masterWebUiUrl))
workerRef.send(RegisteredWorker(self, masterWebUiUrl, masterAddress))
schedule()
} else {
val workerAddress = worker.endpoint.address
Expand Down
53 changes: 45 additions & 8 deletions core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,20 @@ private[deploy] class Worker(

private val testing: Boolean = sys.props.contains("spark.testing")
private var master: Option[RpcEndpointRef] = None

/**
* Whether to use the master address in `masterRpcAddresses` if possible. If it's disabled, Worker
* will just use the address received from Master.
*/
private val preferConfiguredMasterAddress =
conf.getBoolean("spark.worker.preferConfiguredMasterAddress", false)
/**
* The master address to connect to in case of failure. When the connection is broken, the worker
* use this address to connect. This is usually just one of `masterRpcAddresses`. However, when
* a master is restarted or takes over leadership, it will be an address sent from master, which
* may not be in `masterRpcAddresses`.
*/
private var masterAddressToConnect: Option[RpcAddress] = None
private var activeMasterUrl: String = ""
private[worker] var activeMasterWebUiUrl : String = ""
private var workerWebUiUrl: String = ""
Expand Down Expand Up @@ -196,10 +210,19 @@ private[deploy] class Worker(
metricsSystem.getServletHandlers.foreach(webUi.attachHandler)
}

private def changeMaster(masterRef: RpcEndpointRef, uiUrl: String) {
/**
* Change to use the new master.
*
* @param masterRef the new master ref
* @param uiUrl the new master Web UI address
* @param masterAddress the new master address which the worker should use to connect in case of
* failure
*/
private def changeMaster(masterRef: RpcEndpointRef, uiUrl: String, masterAddress: RpcAddress) {
// activeMasterUrl is a valid Spark url since we receive it from the master.
activeMasterUrl = masterRef.address.toSparkURL
activeMasterWebUiUrl = uiUrl
masterAddressToConnect = Some(masterAddress)
master = Some(masterRef)
connected = true
if (conf.getBoolean("spark.ui.reverseProxy", false)) {
Expand Down Expand Up @@ -266,7 +289,8 @@ private[deploy] class Worker(
if (registerMasterFutures != null) {
registerMasterFutures.foreach(_.cancel(true))
}
val masterAddress = masterRef.address
val masterAddress =
if (preferConfiguredMasterAddress) masterAddressToConnect.get else masterRef.address
registerMasterFutures = Array(registerMasterThreadPool.submit(new Runnable {
override def run(): Unit = {
try {
Expand Down Expand Up @@ -342,15 +366,27 @@ private[deploy] class Worker(
}

private def sendRegisterMessageToMaster(masterEndpoint: RpcEndpointRef): Unit = {
masterEndpoint.send(RegisterWorker(workerId, host, port, self, cores, memory, workerWebUiUrl))
masterEndpoint.send(RegisterWorker(
workerId,
host,
port,
self,
cores,
memory,
workerWebUiUrl,
masterEndpoint.address))
}

private def handleRegisterResponse(msg: RegisterWorkerResponse): Unit = synchronized {
msg match {
case RegisteredWorker(masterRef, masterWebUiUrl) =>
logInfo("Successfully registered with master " + masterRef.address.toSparkURL)
case RegisteredWorker(masterRef, masterWebUiUrl, masterAddress) =>
if (preferConfiguredMasterAddress) {
logInfo("Successfully registered with master " + masterAddress.toSparkURL)
} else {
logInfo("Successfully registered with master " + masterRef.address.toSparkURL)
}
registered = true
changeMaster(masterRef, masterWebUiUrl)
changeMaster(masterRef, masterWebUiUrl, masterAddress)
forwordMessageScheduler.scheduleAtFixedRate(new Runnable {
override def run(): Unit = Utils.tryLogNonFatalError {
self.send(SendHeartbeat)
Expand Down Expand Up @@ -419,7 +455,7 @@ private[deploy] class Worker(

case MasterChanged(masterRef, masterWebUiUrl) =>
logInfo("Master has changed, new master is at " + masterRef.address.toSparkURL)
changeMaster(masterRef, masterWebUiUrl)
changeMaster(masterRef, masterWebUiUrl, masterRef.address)

val execs = executors.values.
map(e => new ExecutorDescription(e.appId, e.execId, e.cores, e.state))
Expand Down Expand Up @@ -561,7 +597,8 @@ private[deploy] class Worker(
}

override def onDisconnected(remoteAddress: RpcAddress): Unit = {
if (master.exists(_.address == remoteAddress)) {
if (master.exists(_.address == remoteAddress) ||
masterAddressToConnect.exists(_ == remoteAddress)) {
logInfo(s"$remoteAddress Disassociated !")
masterDisconnected()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import other.supplier.{CustomPersistenceEngine, CustomRecoveryModeFactory}
import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.apache.spark.deploy._
import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.rpc.{RpcEndpoint, RpcEnv}
import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, RpcEnv}

class MasterSuite extends SparkFunSuite
with Matchers with Eventually with PrivateMethodTester with BeforeAndAfter {
Expand Down Expand Up @@ -447,8 +447,15 @@ class MasterSuite extends SparkFunSuite
}
})

master.self.send(
RegisterWorker("1", "localhost", 9999, fakeWorker, 10, 1024, "http://localhost:8080"))
master.self.send(RegisterWorker(
"1",
"localhost",
9999,
fakeWorker,
10,
1024,
"http://localhost:8080",
RpcAddress("localhost", 9999)))
val executors = (0 until 3).map { i =>
new ExecutorDescription(appId = i.toString, execId = i, 2, ExecutorState.RUNNING)
}
Expand All @@ -459,4 +466,37 @@ class MasterSuite extends SparkFunSuite
assert(killedDrivers.asScala.toList.sorted === List("0", "1", "2"))
}
}

test("SPARK-20529: Master should reply the address received from worker") {
val master = makeMaster()
master.rpcEnv.setupEndpoint(Master.ENDPOINT_NAME, master)
// Wait until the master has finished recovery and reports ALIVE before registering.
eventually(timeout(10.seconds)) {
val masterState = master.self.askSync[MasterStateResponse](RequestMasterState)
assert(masterState.status === RecoveryState.ALIVE, "Master is not alive")
}

// Stub worker endpoint that only records the masterAddress echoed back in RegisteredWorker.
@volatile var receivedMasterAddress: RpcAddress = null
val fakeWorker = master.rpcEnv.setupEndpoint("worker", new RpcEndpoint {
override val rpcEnv: RpcEnv = master.rpcEnv

override def receive: PartialFunction[Any, Unit] = {
case RegisteredWorker(_, _, masterAddress) =>
receivedMasterAddress = masterAddress
}
})

// Register with a masterAddress ("localhost2":10000) that differs from the master's own RPC
// address, so the assertion below proves the master replies with the address it received
// from the worker rather than its own bound address.
master.self.send(RegisterWorker(
"1",
"localhost",
9999,
fakeWorker,
10,
1024,
"http://localhost:8080",
RpcAddress("localhost2", 10000)))

eventually(timeout(10.seconds)) {
assert(receivedMasterAddress === RpcAddress("localhost2", 10000))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,15 @@ public static void main(String[] args) {
.setPredictionCol("prediction");
Double rmse = evaluator.evaluate(predictions);
System.out.println("Root-mean-square error = " + rmse);

// Generate top 10 movie recommendations for each user
Dataset<Row> userRecs = model.recommendForAllUsers(10);
// Generate top 10 user recommendations for each movie
Dataset<Row> movieRecs = model.recommendForAllItems(10);
// $example off$
userRecs.show();
movieRecs.show();

spark.stop();
}
}
8 changes: 8 additions & 0 deletions examples/src/main/python/ml/als_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,5 +55,13 @@
predictionCol="prediction")
rmse = evaluator.evaluate(predictions)
print("Root-mean-square error = " + str(rmse))

# Generate top 10 movie recommendations for each user
userRecs = model.recommendForAllUsers(10)
# Generate top 10 user recommendations for each movie
movieRecs = model.recommendForAllItems(10)
# $example off$
userRecs.show()
movieRecs.show()

spark.stop()
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,14 @@ object ALSExample {
.setPredictionCol("prediction")
val rmse = evaluator.evaluate(predictions)
println(s"Root-mean-square error = $rmse")

// Generate top 10 movie recommendations for each user
val userRecs = model.recommendForAllUsers(10)
// Generate top 10 user recommendations for each movie
val movieRecs = model.recommendForAllItems(10)
// $example off$
userRecs.show()
movieRecs.show()

spark.stop()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import scala.collection.JavaConverters._
import scala.reflect.ClassTag
import scala.util.control.NonFatal

import com.amazonaws.auth.{AWSCredentials, DefaultAWSCredentialsProviderChain}
import com.amazonaws.auth.AWSCredentials
import com.amazonaws.services.kinesis.AmazonKinesisClient
import com.amazonaws.services.kinesis.clientlibrary.types.UserRecord
import com.amazonaws.services.kinesis.model._
Expand Down Expand Up @@ -81,9 +81,9 @@ class KinesisBackedBlockRDD[T: ClassTag](
@transient private val _blockIds: Array[BlockId],
@transient val arrayOfseqNumberRanges: Array[SequenceNumberRanges],
@transient private val isBlockIdValid: Array[Boolean] = Array.empty,
val retryTimeoutMs: Int = 10000,
val messageHandler: Record => T = KinesisInputDStream.defaultMessageHandler _,
val kinesisCreds: SparkAWSCredentials = DefaultCredentials
val kinesisCreds: SparkAWSCredentials = DefaultCredentials,
val kinesisReadConfigs: KinesisReadConfigurations = KinesisReadConfigurations()
) extends BlockRDD[T](sc, _blockIds) {

require(_blockIds.length == arrayOfseqNumberRanges.length,
Expand Down Expand Up @@ -112,7 +112,7 @@ class KinesisBackedBlockRDD[T: ClassTag](
val credentials = kinesisCreds.provider.getCredentials
partition.seqNumberRanges.ranges.iterator.flatMap { range =>
new KinesisSequenceRangeIterator(credentials, endpointUrl, regionName,
range, retryTimeoutMs).map(messageHandler)
range, kinesisReadConfigs).map(messageHandler)
}
}
if (partition.isBlockIdValid) {
Expand All @@ -135,7 +135,7 @@ class KinesisSequenceRangeIterator(
endpointUrl: String,
regionId: String,
range: SequenceNumberRange,
retryTimeoutMs: Int) extends NextIterator[Record] with Logging {
kinesisReadConfigs: KinesisReadConfigurations) extends NextIterator[Record] with Logging {

private val client = new AmazonKinesisClient(credentials)
private val streamName = range.streamName
Expand Down Expand Up @@ -251,21 +251,19 @@ class KinesisSequenceRangeIterator(

/** Helper method to retry Kinesis API request with exponential backoff and timeouts */
private def retryOrTimeout[T](message: String)(body: => T): T = {
import KinesisSequenceRangeIterator._

var startTimeMs = System.currentTimeMillis()
val startTimeMs = System.currentTimeMillis()
var retryCount = 0
var waitTimeMs = MIN_RETRY_WAIT_TIME_MS
var result: Option[T] = None
var lastError: Throwable = null
var waitTimeInterval = kinesisReadConfigs.retryWaitTimeMs

def isTimedOut = (System.currentTimeMillis() - startTimeMs) >= retryTimeoutMs
def isMaxRetryDone = retryCount >= MAX_RETRIES
def isTimedOut = (System.currentTimeMillis() - startTimeMs) >= kinesisReadConfigs.retryTimeoutMs
def isMaxRetryDone = retryCount >= kinesisReadConfigs.maxRetries

while (result.isEmpty && !isTimedOut && !isMaxRetryDone) {
if (retryCount > 0) { // wait only if this is a retry
Thread.sleep(waitTimeMs)
waitTimeMs *= 2 // if you have waited, then double wait time for next round
Thread.sleep(waitTimeInterval)
waitTimeInterval *= 2 // if you have waited, then double wait time for next round
}
try {
result = Some(body)
Expand All @@ -284,17 +282,12 @@ class KinesisSequenceRangeIterator(
result.getOrElse {
if (isTimedOut) {
throw new SparkException(
s"Timed out after $retryTimeoutMs ms while $message, last exception: ", lastError)
s"Timed out after ${kinesisReadConfigs.retryTimeoutMs} ms while " +
s"$message, last exception: ", lastError)
} else {
throw new SparkException(
s"Gave up after $retryCount retries while $message, last exception: ", lastError)
}
}
}
}

private[streaming]
object KinesisSequenceRangeIterator {
val MAX_RETRIES = 3
val MIN_RETRY_WAIT_TIME_MS = 100
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import scala.reflect.ClassTag

import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
import com.amazonaws.services.kinesis.model.Record
import KinesisReadConfigurations._

import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.rdd.RDD
Expand Down Expand Up @@ -60,12 +61,13 @@ private[kinesis] class KinesisInputDStream[T: ClassTag](
val isBlockIdValid = blockInfos.map { _.isBlockIdValid() }.toArray
logDebug(s"Creating KinesisBackedBlockRDD for $time with ${seqNumRanges.length} " +
s"seq number ranges: ${seqNumRanges.mkString(", ")} ")

new KinesisBackedBlockRDD(
context.sc, regionName, endpointUrl, blockIds, seqNumRanges,
isBlockIdValid = isBlockIdValid,
retryTimeoutMs = ssc.graph.batchDuration.milliseconds.toInt,
messageHandler = messageHandler,
kinesisCreds = kinesisCreds)
kinesisCreds = kinesisCreds,
kinesisReadConfigs = KinesisReadConfigurations(ssc))
} else {
logWarning("Kinesis sequence number information was not present with some block metadata," +
" it may not be possible to recover from failures")
Expand Down
Loading