Skip to content
Closed
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ private[spark] class TaskSchedulerImpl private[scheduler](

if (!isLocal && conf.getBoolean("spark.speculation", false)) {
logInfo("Starting speculative execution thread")
speculationScheduler.scheduleAtFixedRate(new Runnable {
speculationScheduler.scheduleWithFixedDelay(new Runnable {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the reason for this change?

@jinxing64 jinxing64 Mar 16, 2017

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking checkSpeculatableTasks will synchronize TaskSchedulerImpl. If checkSpeculatableTasks doesn't finish with 100ms, then the possibility exists for that thread to release and then immediately re-acquire the lock. How do you think about it?Should it be included in this pr?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you are right, this is a good change. Its somewhat related to other changes here, so I personally don't feel too strongly about needing to put it in its own pr.

I would say that if its in this pr, the change description should be updated to mention this as well.

@jinxing64 jinxing64 Mar 16, 2017

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@squito
Thanks a lot for looking into this. I will put the change in this pr : )

override def run(): Unit = Utils.tryOrStopSparkContext(sc) {
checkSpeculatableTasks()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,18 @@ package org.apache.spark.scheduler

import java.io.NotSerializableException
import java.nio.ByteBuffer
import java.util.Arrays
import java.util.concurrent.ConcurrentLinkedQueue

import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
import scala.math.{max, min}
import scala.math.max
import scala.util.control.NonFatal

import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.SchedulingMode._
import org.apache.spark.TaskState.TaskState
import org.apache.spark.util.{AccumulatorV2, Clock, SystemClock, Utils}
import org.apache.spark.util.collection.MedianHeap

/**
* Schedules the tasks within a single TaskSet in the TaskSchedulerImpl. This class keeps track of
Expand Down Expand Up @@ -63,6 +63,8 @@ private[spark] class TaskSetManager(
// Limit of bytes for total size of results (default is 1GB)
val maxResultSize = Utils.getMaxResultSize(conf)

val speculationEnabled = conf.getBoolean("spark.speculation", false)

// Serializer for closures and tasks.
val env = SparkEnv.get
val ser = env.closureSerializer.newInstance()
Expand Down Expand Up @@ -141,6 +143,11 @@ private[spark] class TaskSetManager(
// Task index, start and finish time for each task attempt (indexed by task ID)
val taskInfos = new HashMap[Long, TaskInfo]

// Use a MedianHeap to record durations of successful tasks so we know when to launch
// speculative tasks. This is only used when speculation is enabled, to avoid the overhead
// of inserting into the heap when the heap won't be used.
val successfulTaskDurations = new MedianHeap()

// How frequently to reprint duplicate exceptions in full, in milliseconds
val EXCEPTION_PRINT_INTERVAL =
conf.getLong("spark.logging.exceptionPrintInterval", 10000)
Expand Down Expand Up @@ -696,6 +703,9 @@ private[spark] class TaskSetManager(
val info = taskInfos(tid)
val index = info.index
info.markFinished(TaskState.FINISHED, clock.getTimeMillis())
if (speculationEnabled) {
successfulTaskDurations.insert(info.duration)
}
removeRunningTask(tid)
// This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the
// "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, we should not
Expand Down Expand Up @@ -917,11 +927,10 @@ private[spark] class TaskSetManager(
var foundTasks = false
val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt
logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation)

if (tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) {
val time = clock.getTimeMillis()
val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray
Arrays.sort(durations)
val medianDuration = durations(min((0.5 * tasksSuccessful).round.toInt, durations.length - 1))
var medianDuration = successfulTaskDurations.median
val threshold = max(SPECULATION_MULTIPLIER * medianDuration, minTimeToSpeculation)
// TODO: Threshold should also look at standard deviation of task durations and have a lower
// bound based on that.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.util.collection

import scala.collection.mutable.PriorityQueue

/**
* MedianHeap is designed to be used to quickly track the median of a group of numbers
* that may contain duplicates. Inserting a new number has O(log n) time complexity and
* determining the median has O(1) time complexity.
* The basic idea is to maintain two heaps: a smallerHalf and a largerHalf. The smallerHalf
* stores the smaller half of all numbers while the largerHalf stores the larger half.
* The sizes of two heaps need to be balanced each time when a new number is inserted so
* that their sizes will not be different by more than 1. Therefore each time when
* findMedian() is called we check if two heaps have the same size. If they do, we should
* return the average of the two top values of heaps. Otherwise we return the top of the
* heap which has one more element.
*/
private[spark] class MedianHeap(implicit val ord: Ordering[Double]) {

/**
* Stores all the numbers less than the current median in a smallerHalf,
* i.e median is the maximum, at the root.
*/
private[this] var smallerHalf = PriorityQueue.empty[Double](ord)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

very minor -- could you make this comment a doc with /**? Even though its private, I find that helpful as that is useful in IDEs where they'll show this text w/ a hover on a reference


/**
* Stores all the numbers greater than the current median in a largerHalf,
* i.e median is the minimum, at the root.
*/
private[this] var largerHalf = PriorityQueue.empty[Double](ord.reverse)

def isEmpty(): Boolean = {
smallerHalf.isEmpty && largerHalf.isEmpty
}

def size(): Int = {
smallerHalf.size + largerHalf.size
}

def insert(x: Double): Unit = {
// If both heaps are empty, we arbitrarily insert it into a heap, let's say, the largerHalf.
if (isEmpty) {
largerHalf.enqueue(x)
} else {
// If the number is larger than current median, it should be inserted into largerHalf,
// otherwise smallerHalf.
if (x > median) {
largerHalf.enqueue(x)
} else {
smallerHalf.enqueue(x)
}
}
rebalance()
}

private[this] def rebalance(): Unit = {
if (largerHalf.size - smallerHalf.size > 1) {
smallerHalf.enqueue(largerHalf.dequeue())
}
if (smallerHalf.size - largerHalf.size > 1) {
largerHalf.enqueue(smallerHalf.dequeue)
}
}

def median: Double = {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor: I find comments which basically just restate the method name to be pretty pointless. I'd only include them if they add something else, eg. preconditions, or complexity, etc. Mostly I'd say they're not necessary for any of the methods here.

if (isEmpty) {
throw new NoSuchElementException("MedianHeap is empty.")
}
if (largerHalf.size == smallerHalf.size) {
(largerHalf.head + smallerHalf.head) / 2.0
} else if (largerHalf.size > smallerHalf.size) {
largerHalf.head
} else {
smallerHalf.head
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -893,6 +893,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
val taskSet = FakeTask.createTaskSet(4)
// Set the speculation multiplier to be 0 so speculative tasks are launched immediately
sc.conf.set("spark.speculation.multiplier", "0.0")
sc.conf.set("spark.speculation", "true")

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can remove this once you make the change I suggested above to eliminate the (redundant) check

@jinxing64 jinxing64 Mar 16, 2017

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be set. Because the duration is inserted to MedianHeap only when spark.speculation=true(e.g. If I remove this, MedianHeap will be empty when call checkSpeculatableTasks).

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ohhh cool that makes sense

val clock = new ManualClock()
val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock)
val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet.tasks.map { task =>
Expand Down Expand Up @@ -948,6 +949,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
// Set the speculation multiplier to be 0 so speculative tasks are launched immediately
sc.conf.set("spark.speculation.multiplier", "0.0")
sc.conf.set("spark.speculation.quantile", "0.6")
sc.conf.set("spark.speculation", "true")
val clock = new ManualClock()
val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock)
val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet.tasks.map { task =>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.util.collection

import java.util.Arrays

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

super nit: can you combine these into one import (import java.util.{Arrays, NoSuchElementException})

import java.util.NoSuchElementException

import scala.collection.mutable.ArrayBuffer
import scala.util.Random

import org.apache.spark.SparkFunSuite

class MedianHeapSuite extends SparkFunSuite {

test("If no numbers in MedianHeap, NoSuchElementException is thrown.") {
val medianHeap = new MedianHeap()
intercept[NoSuchElementException] {
medianHeap.median
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

scalatest has a simpler pattern for this:

intercept[NoSuchElementException] {
  medianHeap.median
}

http://www.scalatest.org/user_guide/using_assertions

(I guess you could use assertThrows in this case, but I tend to always use intercept since it also lets you inspect the thrown exception.)

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for the recommendation :)

}

test("Median should be correct when size of MedianHeap is even") {
val array = Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
val medianHeap = new MedianHeap()
array.foreach(medianHeap.insert(_))
assert(medianHeap.size() === 10)
assert(medianHeap.median === ((array(4) + array(5)) / 2.0))

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instead of indexing into the array, I think it would be clearer here to just hard-code 4.5 (it's easier to see that the median is 4.5 than to reason about the indices in the array)

}

test("Median should be correct when size of MedianHeap is odd") {
val array = Array(0, 1, 2, 3, 4, 5, 6, 7, 8)
val medianHeap = new MedianHeap()
array.foreach(medianHeap.insert(_))
assert(medianHeap.size() === 9)
assert(medianHeap.median === (array(4)))

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

similarly here -- just medianHeap.median === 4

}

test("Median should be correct though there are duplicated numbers inside.") {
val array = Array(0, 0, 1, 1, 2, 2, 3, 3, 4, 4)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you change this to something like:

Array(0, 0, 1, 1, 2, 3, 4)?

Otherwise the median heap could be handling the duplicates wrong (e.g., by not actually inserting duplicates), and the assertion at the bottom would still old. Then the check at the end can be medianHeap.median === 1.

val medianHeap = new MedianHeap()
array.foreach(medianHeap.insert(_))
assert(medianHeap.size === 10)
assert(medianHeap.median === ((array(4) + array(5)) / 2.0))
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know Kay asked for tests with a some hardcoded data, but I think these tests are too simplistic. All of these tests insert data in order, and none have significant skew.

Can you add a test case which does something like:

  1. inserts 10 elements with the same value (eg. 5), check the median
  2. insert 100 elements with a larger value (eg 10) check the median
  3. insert 1000 elements with an even smaller value (eg 0), check the median

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I added this change.


test("Median should be correct when skew situations.") {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"when skew situations" --> "when input data is skewed"

val medianHeap = new MedianHeap()
(0 until 10).foreach(_ => medianHeap.insert(5))
assert(medianHeap.median === 5)
(0 until 100).foreach(_ => medianHeap.insert(10))
assert(medianHeap.median === 10)
(0 until 1000).foreach(_ => medianHeap.insert(0))
assert(medianHeap.median === 0)
}
}