
Commit fe82049

Merge pull request #159 from markhamstra/temp
Merged Apache branch-1.6
2 parents: 0e93593 + 041565b

File tree: 42 files changed (+640 / -208 lines)


R/pkg/R/DataFrame.R

Lines changed: 21 additions & 1 deletion
@@ -303,8 +303,28 @@ setMethod("colnames",
 #' @rdname columns
 #' @name colnames<-
 setMethod("colnames<-",
-          signature(x = "DataFrame", value = "character"),
+          signature(x = "DataFrame"),
           function(x, value) {
+
+            # Check parameter integrity
+            if (class(value) != "character") {
+              stop("Invalid column names.")
+            }
+
+            if (length(value) != ncol(x)) {
+              stop(
+                "Column names must have the same length as the number of columns in the dataset.")
+            }
+
+            if (any(is.na(value))) {
+              stop("Column names cannot be NA.")
+            }
+
+            # Check if the column names have . in it
+            if (any(regexec(".", value, fixed=TRUE)[[1]][1] != -1)) {
+              stop("Colum names cannot contain the '.' symbol.")
+            }
+
             sdf <- callJMethod(x@sdf, "toDF", as.list(value))
             dataFrame(sdf)
           })
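
Editorial note: the SparkR wrapper above only validates `value` and then dispatches to the JVM-side rename via `callJMethod(x@sdf, "toDF", as.list(value))`. A minimal Scala sketch of the equivalent check-then-rename on the JVM side, assuming an existing `sqlContext`; the DataFrame `df` and the new names are made up for illustration:

    import org.apache.spark.sql.DataFrame

    // Hypothetical two-column DataFrame used only for illustration.
    val df: DataFrame = sqlContext.createDataFrame(Seq((1, "a"), (2, "b"))).toDF("old1", "old2")

    def renameAll(df: DataFrame, newNames: Seq[String]): DataFrame = {
      // Mirrors the checks the R wrapper performs before calling toDF.
      require(newNames.length == df.columns.length,
        "Column names must have the same length as the number of columns in the dataset.")
      require(newNames.forall(n => n != null && !n.contains(".")),
        "Column names cannot be NA or contain the '.' symbol.")
      df.toDF(newNames: _*)
    }

    renameAll(df, Seq("col1", "col2")).printSchema()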

R/pkg/inst/tests/testthat/test_sparkSQL.R

Lines changed: 11 additions & 0 deletions
@@ -692,6 +692,17 @@ test_that("names() colnames() set the column names", {
   colnames(df) <- c("col3", "col4")
   expect_equal(names(df)[1], "col3")
 
+  expect_error(colnames(df) <- c("sepal.length", "sepal_width"),
+               "Colum names cannot contain the '.' symbol.")
+  expect_error(colnames(df) <- c(1, 2), "Invalid column names.")
+  expect_error(colnames(df) <- c("a"),
+               "Column names must have the same length as the number of columns in the dataset.")
+  expect_error(colnames(df) <- c("1", NA), "Column names cannot be NA.")
+
+  # Note: if this test is broken, remove check for "." character on colnames<- method
+  irisDF <- suppressWarnings(createDataFrame(sqlContext, iris))
+  expect_equal(names(irisDF)[1], "Sepal_Length")
+
   # Test base::colnames base::names
   m2 <- cbind(1, 1:4)
   expect_equal(colnames(m2, do.NULL = FALSE), c("col1", "col2"))

core/src/main/scala/org/apache/spark/MapOutputTracker.scala

Lines changed: 29 additions & 23 deletions
@@ -384,8 +384,6 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
    * @param numReducers total number of reducers in the shuffle
    * @param fractionThreshold fraction of total map output size that a location must have
    *                          for it to be considered large.
-   *
-   * This method is not thread-safe.
    */
   def getLocationsWithLargestOutputs(
       shuffleId: Int,
@@ -394,28 +392,36 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
       fractionThreshold: Double)
     : Option[Array[BlockManagerId]] = {
 
-    if (mapStatuses.contains(shuffleId)) {
-      val statuses = mapStatuses(shuffleId)
-      if (statuses.nonEmpty) {
-        // HashMap to add up sizes of all blocks at the same location
-        val locs = new HashMap[BlockManagerId, Long]
-        var totalOutputSize = 0L
-        var mapIdx = 0
-        while (mapIdx < statuses.length) {
-          val status = statuses(mapIdx)
-          val blockSize = status.getSizeForBlock(reducerId)
-          if (blockSize > 0) {
-            locs(status.location) = locs.getOrElse(status.location, 0L) + blockSize
-            totalOutputSize += blockSize
+    val statuses = mapStatuses.get(shuffleId).orNull
+    if (statuses != null) {
+      statuses.synchronized {
+        if (statuses.nonEmpty) {
+          // HashMap to add up sizes of all blocks at the same location
+          val locs = new HashMap[BlockManagerId, Long]
+          var totalOutputSize = 0L
+          var mapIdx = 0
+          while (mapIdx < statuses.length) {
+            val status = statuses(mapIdx)
+            // status may be null here if we are called between registerShuffle, which creates an
+            // array with null entries for each output, and registerMapOutputs, which populates it
+            // with valid status entries. This is possible if one thread schedules a job which
+            // depends on an RDD which is currently being computed by another thread.
+            if (status != null) {
+              val blockSize = status.getSizeForBlock(reducerId)
+              if (blockSize > 0) {
+                locs(status.location) = locs.getOrElse(status.location, 0L) + blockSize
+                totalOutputSize += blockSize
+              }
+            }
+            mapIdx = mapIdx + 1
+          }
+          val topLocs = locs.filter { case (loc, size) =>
+            size.toDouble / totalOutputSize >= fractionThreshold
+          }
+          // Return if we have any locations which satisfy the required threshold
+          if (topLocs.nonEmpty) {
+            return Some(topLocs.keys.toArray)
           }
-          mapIdx = mapIdx + 1
-        }
-        val topLocs = locs.filter { case (loc, size) =>
-          size.toDouble / totalOutputSize >= fractionThreshold
-        }
-        // Return if we have any locations which satisfy the required threshold
-        if (topLocs.nonEmpty) {
-          return Some(topLocs.map(_._1).toArray)
         }
       }
     }
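
Editorial note: the substance of this change is that the statuses array is now fetched once, iterated under `statuses.synchronized`, and individual entries are allowed to be null while `registerMapOutputs` is still filling them in. A reduced sketch of that aggregation pattern, with a made-up `Status(location, size)` case class standing in for `MapStatus`/`BlockManagerId`:

    import scala.collection.mutable.HashMap

    // Hypothetical stand-in for a MapStatus keyed by a string location.
    case class Status(location: String, size: Long)

    def largestOutputs(statuses: Array[Status], fractionThreshold: Double): Option[Array[String]] = {
      if (statuses != null) {
        // Guard against concurrent updates to the shared array.
        statuses.synchronized {
          val locs = new HashMap[String, Long]
          var total = 0L
          statuses.foreach { status =>
            // Entries can still be null between registerShuffle and registerMapOutputs.
            if (status != null && status.size > 0) {
              locs(status.location) = locs.getOrElse(status.location, 0L) + status.size
              total += status.size
            }
          }
          val top = locs.filter { case (_, size) => size.toDouble / total >= fractionThreshold }
          if (top.nonEmpty) return Some(top.keys.toArray)
        }
      }
      None
    }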

core/src/main/scala/org/apache/spark/TaskContext.scala

Lines changed: 26 additions & 2 deletions
@@ -23,7 +23,7 @@ import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.memory.TaskMemoryManager
 import org.apache.spark.metrics.source.Source
-import org.apache.spark.util.TaskCompletionListener
+import org.apache.spark.util.{TaskCompletionListener, TaskFailureListener}
 
 
 object TaskContext {
@@ -108,15 +108,39 @@ abstract class TaskContext extends Serializable {
    * Adds a (Java friendly) listener to be executed on task completion.
    * This will be called in all situation - success, failure, or cancellation.
    * An example use is for HadoopRDD to register a callback to close the input stream.
+   *
+   * Exceptions thrown by the listener will result in failure of the task.
    */
   def addTaskCompletionListener(listener: TaskCompletionListener): TaskContext
 
   /**
    * Adds a listener in the form of a Scala closure to be executed on task completion.
    * This will be called in all situations - success, failure, or cancellation.
    * An example use is for HadoopRDD to register a callback to close the input stream.
+   *
+   * Exceptions thrown by the listener will result in failure of the task.
+   */
+  def addTaskCompletionListener(f: (TaskContext) => Unit): TaskContext = {
+    addTaskCompletionListener(new TaskCompletionListener {
+      override def onTaskCompletion(context: TaskContext): Unit = f(context)
+    })
+  }
+
+  /**
+   * Adds a listener to be executed on task failure.
+   * Operations defined here must be idempotent, as `onTaskFailure` can be called multiple times.
    */
-  def addTaskCompletionListener(f: (TaskContext) => Unit): TaskContext
+  def addTaskFailureListener(listener: TaskFailureListener): TaskContext
+
+  /**
+   * Adds a listener to be executed on task failure.
+   * Operations defined here must be idempotent, as `onTaskFailure` can be called multiple times.
+   */
+  def addTaskFailureListener(f: (TaskContext, Throwable) => Unit): TaskContext = {
+    addTaskFailureListener(new TaskFailureListener {
+      override def onTaskFailure(context: TaskContext, error: Throwable): Unit = f(context, error)
+    })
+  }
 
  /**
   * Adds a callback function to be executed on task completion. An example use
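
Editorial note: the new failure-listener hooks mirror the existing completion hooks. Completion listeners fire on success, failure, or cancellation; failure listeners fire only when the task throws, and must be idempotent. A sketch of how partition code might register both, assuming a local SparkContext and a purely illustrative side resource:

    import org.apache.spark.{SparkConf, SparkContext, TaskContext}

    val sc = new SparkContext(new SparkConf().setAppName("listener-sketch").setMaster("local[2]"))

    sc.parallelize(1 to 4, 2).mapPartitions { iter =>
      val resource = new java.io.ByteArrayOutputStream()  // stand-in for a real stream
      val context = TaskContext.get()
      // Runs in every case: success, failure, or cancellation.
      context.addTaskCompletionListener { _ => resource.close() }
      // Runs only when the task fails; may be invoked more than once, so keep it idempotent.
      context.addTaskFailureListener { (ctx, error) =>
        System.err.println(s"partition ${ctx.partitionId()} failed: ${error.getMessage}")
      }
      iter
    }.count()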

core/src/main/scala/org/apache/spark/TaskContextImpl.scala

Lines changed: 32 additions & 7 deletions
@@ -23,7 +23,7 @@ import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.memory.TaskMemoryManager
 import org.apache.spark.metrics.MetricsSystem
 import org.apache.spark.metrics.source.Source
-import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException}
+import org.apache.spark.util._
 
 private[spark] class TaskContextImpl(
     val stageId: Int,
@@ -41,24 +41,28 @@ private[spark] class TaskContextImpl(
   // For backwards-compatibility; this method is now deprecated as of 1.3.0.
   override def attemptId(): Long = taskAttemptId
 
-  // List of callback functions to execute when the task completes.
+  /** List of callback functions to execute when the task completes. */
   @transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener]
 
+  /** List of callback functions to execute when the task fails. */
+  @transient private val onFailureCallbacks = new ArrayBuffer[TaskFailureListener]
+
   // Whether the corresponding task has been killed.
   @volatile private var interrupted: Boolean = false
 
   // Whether the task has completed.
   @volatile private var completed: Boolean = false
 
+  // Whether the task has failed.
+  @volatile private var failed: Boolean = false
+
   override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = {
     onCompleteCallbacks += listener
     this
   }
 
-  override def addTaskCompletionListener(f: TaskContext => Unit): this.type = {
-    onCompleteCallbacks += new TaskCompletionListener {
-      override def onTaskCompletion(context: TaskContext): Unit = f(context)
-    }
+  override def addTaskFailureListener(listener: TaskFailureListener): this.type = {
+    onFailureCallbacks += listener
     this
   }
 
@@ -69,7 +73,28 @@ private[spark] class TaskContextImpl(
     }
   }
 
-  /** Marks the task as completed and triggers the listeners. */
+  /** Marks the task as failed and triggers the failure listeners. */
+  private[spark] def markTaskFailed(error: Throwable): Unit = {
+    // failure callbacks should only be called once
+    if (failed) return
+    failed = true
+    val errorMsgs = new ArrayBuffer[String](2)
+    // Process failure callbacks in the reverse order of registration
+    onFailureCallbacks.reverse.foreach { listener =>
+      try {
+        listener.onTaskFailure(this, error)
+      } catch {
+        case e: Throwable =>
+          errorMsgs += e.getMessage
+          logError("Error in TaskFailureListener", e)
+      }
+    }
+    if (errorMsgs.nonEmpty) {
+      throw new TaskCompletionListenerException(errorMsgs, Option(error))
+    }
+  }
+
+  /** Marks the task as completed and triggers the completion listeners. */
   private[spark] def markTaskCompleted(): Unit = {
     completed = true
     val errorMsgs = new ArrayBuffer[String](2)
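
Editorial note: markTaskFailed runs the registered TaskFailureListeners in reverse registration order, collects any exceptions they throw, and rethrows them wrapped in a TaskCompletionListenerException. For reference, a sketch of a listener that could be registered through this path, assuming TaskFailureListener exposes exactly the single onTaskFailure(context, error) method used in this commit:

    import org.apache.spark.TaskContext
    import org.apache.spark.util.TaskFailureListener

    // Illustrative listener that just reports which partition failed and why.
    class LoggingFailureListener extends TaskFailureListener {
      override def onTaskFailure(context: TaskContext, error: Throwable): Unit = {
        // Keep this idempotent: the contract allows onTaskFailure to run more than once.
        System.err.println(s"Task for partition ${context.partitionId()} failed: ${error.getMessage}")
      }
    }

    // Registered from task code via: TaskContext.get().addTaskFailureListener(new LoggingFailureListener)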

core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala

Lines changed: 8 additions & 10 deletions
@@ -126,16 +126,14 @@ private[spark] class AppClient(
     registerMasterFutures.set(tryRegisterAllMasters())
     registrationRetryTimer.set(registrationRetryThread.scheduleAtFixedRate(new Runnable {
       override def run(): Unit = {
-        Utils.tryOrExit {
-          if (registered.get) {
-            registerMasterFutures.get.foreach(_.cancel(true))
-            registerMasterThreadPool.shutdownNow()
-          } else if (nthRetry >= REGISTRATION_RETRIES) {
-            markDead("All masters are unresponsive! Giving up.")
-          } else {
-            registerMasterFutures.get.foreach(_.cancel(true))
-            registerWithMaster(nthRetry + 1)
-          }
+        if (registered.get) {
+          registerMasterFutures.get.foreach(_.cancel(true))
+          registerMasterThreadPool.shutdownNow()
+        } else if (nthRetry >= REGISTRATION_RETRIES) {
+          markDead("All masters are unresponsive! Giving up.")
+        } else {
+          registerMasterFutures.get.foreach(_.cancel(true))
+          registerWithMaster(nthRetry + 1)
         }
       }
     }, REGISTRATION_TIMEOUT_SECONDS, REGISTRATION_TIMEOUT_SECONDS, TimeUnit.SECONDS))
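
Editorial note: this change only removes the Utils.tryOrExit wrapper around the retry body, so the body now runs directly inside the scheduled task; the fixed-rate scheduling itself is unchanged. A generic sketch of that schedule-and-retry shape with plain java.util.concurrent types and made-up names:

    import java.util.concurrent.{Executors, TimeUnit}
    import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger}

    val registered = new AtomicBoolean(false)   // flipped elsewhere once a master answers
    val attempts = new AtomicInteger(0)
    val maxRetries = 3
    val scheduler = Executors.newSingleThreadScheduledExecutor()

    scheduler.scheduleAtFixedRate(new Runnable {
      override def run(): Unit = {
        if (registered.get) {
          scheduler.shutdownNow()               // registration succeeded, stop retrying
        } else if (attempts.incrementAndGet() > maxRetries) {
          System.err.println("All masters are unresponsive! Giving up.")
          scheduler.shutdownNow()
        } else {
          // cancel outstanding attempts and register again here
        }
      }
    }, 20, 20, TimeUnit.SECONDS)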

core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala

Lines changed: 2 additions & 2 deletions
@@ -1108,7 +1108,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
       val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K, V]]
       require(writer != null, "Unable to obtain RecordWriter")
       var recordsWritten = 0L
-      Utils.tryWithSafeFinally {
+      Utils.tryWithSafeFinallyAndFailureCallbacks {
         while (iter.hasNext) {
           val pair = iter.next()
           writer.write(pair._1, pair._2)
@@ -1194,7 +1194,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
       writer.open()
       var recordsWritten = 0L
 
-      Utils.tryWithSafeFinally {
+      Utils.tryWithSafeFinallyAndFailureCallbacks {
         while (iter.hasNext) {
           val record = iter.next()
           writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef])

core/src/main/scala/org/apache/spark/scheduler/Task.scala

Lines changed: 5 additions & 0 deletions
@@ -87,7 +87,12 @@ private[spark] abstract class Task[T](
     }
     try {
       (runTask(context), context.collectAccumulators())
+    } catch { case e: Throwable =>
+      // Catch all errors; run task failure callbacks, and rethrow the exception.
+      context.markTaskFailed(e)
+      throw e
     } finally {
+      // Call the task completion callbacks.
       context.markTaskCompleted()
       try {
         Utils.tryLogNonFatalError {
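
Editorial note: the ordering in run() matters here. Failure callbacks fire first, from the catch block, the exception is rethrown so the scheduler still observes it, and completion callbacks then fire from the finally block in every case. A stripped-down sketch of that control flow with hypothetical callback stubs:

    // Hypothetical wrapper showing the catch/finally ordering used above.
    def runWithCallbacks[T](body: => T)(onFailure: Throwable => Unit)(onCompletion: () => Unit): T = {
      try {
        body
      } catch { case e: Throwable =>
        onFailure(e)   // failure callbacks run first, while the error is in hand
        throw e        // rethrow so the caller still sees the failure
      } finally {
        onCompletion() // completion callbacks run on success, failure, or cancellation
      }
    }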

core/src/main/scala/org/apache/spark/util/TaskCompletionListenerException.scala

Lines changed: 0 additions & 34 deletions
This file was deleted.

core/src/main/scala/org/apache/spark/util/Utils.scala

Lines changed: 42 additions & 3 deletions
@@ -1244,7 +1244,6 @@ private[spark] object Utils extends Logging {
    * exception from the original `out.write` call.
    */
   def tryWithSafeFinally[T](block: => T)(finallyBlock: => Unit): T = {
-    // It would be nice to find a method on Try that did this
     var originalThrowable: Throwable = null
     try {
       block
@@ -1270,6 +1269,44 @@ private[spark] object Utils extends Logging {
     }
   }
 
+  /**
+   * Execute a block of code, call the failure callbacks before finally block if there is any
+   * exceptions happen. But if exceptions happen in the finally block, do not suppress the original
+   * exception.
+   *
+   * This is primarily an issue with `finally { out.close() }` blocks, where
+   * close needs to be called to clean up `out`, but if an exception happened
+   * in `out.write`, it's likely `out` may be corrupted and `out.close` will
+   * fail as well. This would then suppress the original/likely more meaningful
+   * exception from the original `out.write` call.
+   */
+  def tryWithSafeFinallyAndFailureCallbacks[T](block: => T)(finallyBlock: => Unit): T = {
+    var originalThrowable: Throwable = null
+    try {
+      block
+    } catch {
+      case t: Throwable =>
+        // Purposefully not using NonFatal, because even fatal exceptions
+        // we don't want to have our finallyBlock suppress
+        originalThrowable = t
+        TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(t)
+        throw originalThrowable
+    } finally {
+      try {
+        finallyBlock
+      } catch {
+        case t: Throwable =>
+          if (originalThrowable != null) {
+            originalThrowable.addSuppressed(t)
+            logWarning(s"Suppressing exception in finally: " + t.getMessage, t)
+            throw originalThrowable
+          } else {
+            throw t
+          }
+      }
+    }
+  }
+
   /** Default filtering function for finding call sites using `getCallSite`. */
   private def sparkInternalExclusionFunction(className: String): Boolean = {
     // A regular expression to match classes of the internal Spark API's
@@ -1991,8 +2028,10 @@ private[spark] object Utils extends Logging {
     } catch {
       case e: Exception if isBindCollision(e) =>
         if (offset >= maxRetries) {
-          val exceptionMessage =
-            s"${e.getMessage}: Service$serviceString failed after $maxRetries retries!"
+          val exceptionMessage = s"${e.getMessage}: Service$serviceString failed after " +
+            s"$maxRetries retries! Consider explicitly setting the appropriate port for the " +
+            s"service$serviceString (for example spark.ui.port for SparkUI) to an available " +
+            "port or increasing spark.port.maxRetries."
           val exception = new BindException(exceptionMessage)
           // restore original stack trace
           exception.setStackTrace(e.getStackTrace)
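
Editorial note: the new helper keeps the suppression behaviour of tryWithSafeFinally (a cleanup exception is attached to the original with addSuppressed rather than replacing it) and additionally calls markTaskFailed on the current TaskContext before the finally block runs. Since Utils is Spark-internal and TaskContext.get() is only set inside a running task, the call shape below is only a sketch; the writer and records are made up:

    // Hypothetical writer and records, used only to show the two-argument-list call shape.
    // Must run inside a task, where TaskContext.get() is non-null.
    val writer = new java.io.PrintWriter("/tmp/part-00000")
    val records = Seq("a", "b", "c")

    Utils.tryWithSafeFinallyAndFailureCallbacks {
      records.foreach(r => writer.println(r))  // a failure here triggers the task's failure callbacks first
    } {
      writer.close()                           // always runs; its exceptions never mask the original one
    }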
